diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 000000000..4425e0dda --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,14 @@ +[alias] +lint = "clippy --workspace --tests --examples --bins -- -Dclippy::todo" +lint-all = "clippy --workspace --all-features --tests --examples --bins -- -Dclippy::todo" + +# lib checking +ci-check-min = "hack --workspace check --no-default-features" +ci-check-default = "hack --workspace check" +ci-check-default-tests = "check --workspace --tests" +ci-check-all-feature-powerset="hack --workspace --feature-powerset --skip=__compress,io-uring check" +ci-check-all-feature-powerset-linux="hack --workspace --feature-powerset --skip=__compress check" + +# testing +ci-doctest-default = "test --workspace --doc --no-fail-fast -- --nocapture" +ci-doctest = "test --workspace --all-features --doc --no-fail-fast -- --nocapture" diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..6164c657c --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [robjtede] diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 2df863ae8..fa06a137a 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -33,5 +33,5 @@ Please search on the [Actix Web issue tracker](https://github.com/actix/actix-we ## Your Environment -* Rust Version (I.e, output of `rustc -V`): -* Actix Web Version: +- Rust Version (I.e, output of `rustc -V`): +- Actix Web Version: diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index d8c6d66ca..bfa124ffd 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,15 +1,8 @@ blank_issues_enabled: true contact_links: - - name: GitHub Discussions - url: https://github.com/actix/actix-web/discussions - about: Actix Web Q&A - - name: Gitter chat (actix-web) - url: https://gitter.im/actix/actix-web - about: Actix Web Q&A - - name: Gitter chat (actix) - url: https://gitter.im/actix/actix - about: Actix (actor framework) Q&A - name: Actix Discord url: https://discord.gg/NWpN5mmg3x about: Actix developer discussion and community chat - + - name: GitHub Discussions + url: https://github.com/actix/actix-web/discussions + about: Actix Web Q&A diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 42deadf5a..d617cf708 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,7 +8,7 @@ PR_TYPE ## PR Checklist - - [ ] Tests for the changes have been added / updated. diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index 828d62561..a4b54ca7a 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -1,8 +1,6 @@ name: Benchmark on: - pull_request: - types: [opened, synchronize, reopened] push: branches: - master diff --git a/.github/workflows/ci-master.yml b/.github/workflows/ci-master.yml new file mode 100644 index 000000000..b78617dc5 --- /dev/null +++ b/.github/workflows/ci-master.yml @@ -0,0 +1,153 @@ +name: CI (master only) + +on: + push: + branches: [master] + +jobs: + build_and_test_nightly: + strategy: + fail-fast: false + matrix: + target: + - { name: Linux, os: ubuntu-latest, triple: x86_64-unknown-linux-gnu } + - { name: macOS, os: macos-latest, triple: x86_64-apple-darwin } + - { name: Windows, os: windows-2022, triple: x86_64-pc-windows-msvc } + version: + - nightly + + name: ${{ matrix.target.name }} / ${{ matrix.version }} + runs-on: ${{ matrix.target.os }} + + env: + CI: 1 + CARGO_INCREMENTAL: 0 + VCPKGRS_DYNAMIC: 1 + + steps: + - uses: actions/checkout@v2 + + # install OpenSSL on Windows + # TODO: GitHub actions docs state that OpenSSL is + # already installed on these Windows machines somewhere + - name: Set vcpkg root + if: matrix.target.triple == 'x86_64-pc-windows-msvc' + run: echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append + - name: Install OpenSSL + if: matrix.target.triple == 'x86_64-pc-windows-msvc' + run: vcpkg install openssl:x64-windows + + - name: Install ${{ matrix.version }} + uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.version }}-${{ matrix.target.triple }} + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + + - name: Install cargo-hack + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-hack + + - name: check minimal + uses: actions-rs/cargo@v1 + with: { command: ci-check-min } + + - name: check default + uses: actions-rs/cargo@v1 + with: { command: ci-check-default } + + - name: tests + timeout-minutes: 60 + run: | + cargo test --lib --tests -p=actix-router --all-features + cargo test --lib --tests -p=actix-http --all-features + cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls + cargo test --lib --tests -p=actix-web-codegen --all-features + cargo test --lib --tests -p=awc --all-features + cargo test --lib --tests -p=actix-http-test --all-features + cargo test --lib --tests -p=actix-test --all-features + cargo test --lib --tests -p=actix-files + cargo test --lib --tests -p=actix-multipart --all-features + cargo test --lib --tests -p=actix-web-actors --all-features + + - name: tests (io-uring) + if: matrix.target.os == 'ubuntu-latest' + timeout-minutes: 60 + run: > + sudo bash -c "ulimit -Sl 512 + && ulimit -Hl 512 + && PATH=$PATH:/usr/share/rust/.cargo/bin + && RUSTUP_TOOLCHAIN=${{ matrix.version }} cargo test --lib --tests -p=actix-files --all-features" + + - name: Clear the cargo caches + run: | + cargo install cargo-cache --version 0.6.3 --no-default-features --features ci-autoclean + cargo-cache + + ci_feature_powerset_check: + name: Verify Feature Combinations + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + + - name: Install cargo-hack + uses: actions-rs/cargo@v1 + with: + command: install + args: cargo-hack + + - name: check feature combinations + uses: actions-rs/cargo@v1 + with: { command: ci-check-all-feature-powerset } + + - name: check feature combinations + uses: actions-rs/cargo@v1 + with: { command: ci-check-all-feature-powerset-linux } + + coverage: + name: coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + + - name: Generate coverage file + run: | + cargo install cargo-tarpaulin --vers "^0.13" + cargo tarpaulin --workspace --features=rustls,openssl --out Xml --verbose + - name: Upload to Codecov + uses: codecov/codecov-action@v1 + with: { file: cobertura.xml } diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0126238ef..f41aa972f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,18 +14,32 @@ jobs: target: - { name: Linux, os: ubuntu-latest, triple: x86_64-unknown-linux-gnu } - { name: macOS, os: macos-latest, triple: x86_64-apple-darwin } - - { name: Windows, os: windows-latest, triple: x86_64-pc-windows-msvc } + - { name: Windows, os: windows-2022, triple: x86_64-pc-windows-msvc } version: - - 1.46.0 # MSRV + - 1.54.0 # MSRV - stable - - nightly name: ${{ matrix.target.name }} / ${{ matrix.version }} runs-on: ${{ matrix.target.os }} + env: + CI: 1 + CARGO_INCREMENTAL: 0 + VCPKGRS_DYNAMIC: 1 + steps: - uses: actions/checkout@v2 + # install OpenSSL on Windows + # TODO: GitHub actions docs state that OpenSSL is + # already installed on these Windows machines somewhere + - name: Set vcpkg root + if: matrix.target.triple == 'x86_64-pc-windows-msvc' + run: echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append + - name: Install OpenSSL + if: matrix.target.triple == 'x86_64-pc-windows-msvc' + run: vcpkg install openssl:x64-windows + - name: Install ${{ matrix.version }} uses: actions-rs/toolchain@v1 with: @@ -35,10 +49,9 @@ jobs: - name: Generate Cargo.lock uses: actions-rs/cargo@v1 - with: - command: generate-lockfile + with: { command: generate-lockfile } - name: Cache Dependencies - uses: Swatinem/rust-cache@v1.0.1 + uses: Swatinem/rust-cache@v1.2.0 - name: Install cargo-hack uses: actions-rs/cargo@v1 @@ -48,56 +61,60 @@ jobs: - name: check minimal uses: actions-rs/cargo@v1 - with: - command: hack - args: --clean-per-run check --workspace --no-default-features --tests + with: { command: ci-check-min } - - name: check full + - name: check default uses: actions-rs/cargo@v1 - with: - command: check - args: --workspace --bins --examples --tests + with: { command: ci-check-default } - name: tests - uses: actions-rs/cargo@v1 - with: - command: test - args: -v --workspace --all-features --no-fail-fast -- --nocapture - --skip=test_h2_content_length - --skip=test_reading_deflate_encoding_large_random_rustls - - - name: tests (actix-http) - uses: actions-rs/cargo@v1 - timeout-minutes: 40 - with: - command: test - args: --package=actix-http --no-default-features --features=rustls -- --nocapture - - - name: tests (awc) - uses: actions-rs/cargo@v1 - timeout-minutes: 40 - with: - command: test - args: --package=awc --no-default-features --features=rustls -- --nocapture - - - name: Generate coverage file - if: > - matrix.target.os == 'ubuntu-latest' - && matrix.version == 'stable' - && github.ref == 'refs/heads/master' + timeout-minutes: 60 run: | - cargo install cargo-tarpaulin --vers "^0.13" - cargo tarpaulin --out Xml - - name: Upload to Codecov - if: > - matrix.target.os == 'ubuntu-latest' - && matrix.version == 'stable' - && github.ref == 'refs/heads/master' - uses: codecov/codecov-action@v1 - with: - file: cobertura.xml + cargo test --lib --tests -p=actix-router --all-features + cargo test --lib --tests -p=actix-http --all-features + cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --skip=test_reading_deflate_encoding_large_random_rustls + cargo test --lib --tests -p=actix-web-codegen --all-features + cargo test --lib --tests -p=awc --all-features + cargo test --lib --tests -p=actix-http-test --all-features + cargo test --lib --tests -p=actix-test --all-features + cargo test --lib --tests -p=actix-files + cargo test --lib --tests -p=actix-multipart --all-features + cargo test --lib --tests -p=actix-web-actors --all-features + + - name: tests (io-uring) + if: matrix.target.os == 'ubuntu-latest' + timeout-minutes: 60 + run: > + sudo bash -c "ulimit -Sl 512 + && ulimit -Hl 512 + && PATH=$PATH:/usr/share/rust/.cargo/bin + && RUSTUP_TOOLCHAIN=${{ matrix.version }} cargo test --lib --tests -p=actix-files --all-features" - name: Clear the cargo caches run: | - cargo install cargo-cache --no-default-features --features ci-autoclean + cargo install cargo-cache --version 0.6.3 --no-default-features --features ci-autoclean cargo-cache + + rustdoc: + name: doc tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install Rust (nightly) + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-x86_64-unknown-linux-gnu + profile: minimal + override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.3.0 + + - name: doc tests + uses: actions-rs/cargo@v1 + timeout-minutes: 60 + with: { command: ci-doctest } diff --git a/.github/workflows/clippy-fmt.yml b/.github/workflows/clippy-fmt.yml index e966fa4ab..9fcb0a561 100644 --- a/.github/workflows/clippy-fmt.yml +++ b/.github/workflows/clippy-fmt.yml @@ -14,6 +14,7 @@ jobs: uses: actions-rs/toolchain@v1 with: toolchain: stable + profile: minimal components: rustfmt - name: Check with rustfmt uses: actions-rs/cargo@v1 @@ -30,10 +31,18 @@ jobs: uses: actions-rs/toolchain@v1 with: toolchain: stable + profile: minimal components: clippy override: true + + - name: Generate Cargo.lock + uses: actions-rs/cargo@v1 + with: { command: generate-lockfile } + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1.2.0 + - name: Check with Clippy uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - args: --workspace --tests --all-features + args: --workspace --tests --examples --all-features diff --git a/.github/workflows/upload-doc.yml b/.github/workflows/upload-doc.yml index c080dd8c3..94a2ddfbe 100644 --- a/.github/workflows/upload-doc.yml +++ b/.github/workflows/upload-doc.yml @@ -1,14 +1,12 @@ -name: Upload documentation +name: Upload Documentation on: push: - branches: - - master + branches: [master] jobs: build: runs-on: ubuntu-latest - if: github.repository == 'actix/actix-web' steps: - uses: actions/checkout@v2 @@ -20,14 +18,14 @@ jobs: profile: minimal override: true - - name: check build + - name: Build Docs uses: actions-rs/cargo@v1 with: command: doc args: --workspace --all-features --no-deps - name: Tweak HTML - run: echo "" > target/doc/index.html + run: echo '' > target/doc/index.html - name: Deploy to GitHub Pages uses: JamesIves/github-pages-deploy-action@3.7.1 diff --git a/.gitignore b/.gitignore index 638a4397a..543403267 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ guide/build/ # Configuration directory generated by CLion .idea + +# Configuration directory generated by VSCode +.vscode diff --git a/CHANGES.md b/CHANGES.md index 39de18371..c379c1545 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,44 +2,350 @@ ## Unreleased - 2021-xx-xx ### Added -* Added `ServiceConfig::configure` to allow easy nesting of configuration. [#1988] +- `GuardContext::header` [#2569] +- `ServiceConfig::configure` to allow easy nesting of configuration functions. [#1988] + ### Changed -* Feature `cookies` is now optional and enabled by default. [#1981] +- `HttpResponse` can now be used as a `Responder` with any body type. [#2567] + +[#1988]: https://github.com/actix/actix-web/pull/1988 +[#2567]: https://github.com/actix/actix-web/pull/2567 +[#2569]: https://github.com/actix/actix-web/pull/2569 + + +## 4.0.0-beta.19 - 2022-01-04 +### Added +- `impl Hash` for `http::header::Encoding`. [#2501] +- `AcceptEncoding::negotiate()`. [#2501] + +### Changed +- `AcceptEncoding::preference` now returns `Option>`. [#2501] +- Rename methods `BodyEncoding::{encoding => encode_with, get_encoding => preferred_encoding}`. [#2501] +- `http::header::Encoding` now only represents `Content-Encoding` types. [#2501] + +### Fixed +- Auto-negotiation of content encoding is more fault-tolerant when using the `Compress` middleware. [#2501] + +### Removed +- `Compress::new`; restricting compression algorithm is done through feature flags. [#2501] +- `BodyEncoding` trait; signalling content encoding is now only done via the `Content-Encoding` header. [#2565] + +[#2501]: https://github.com/actix/actix-web/pull/2501 +[#2565]: https://github.com/actix/actix-web/pull/2565 + + +## 4.0.0-beta.18 - 2021-12-29 +### Changed +- Update `cookie` dependency (re-exported) to `0.16`. [#2555] +- Minimum supported Rust version (MSRV) is now 1.54. + +### Security +- `cookie` upgrade addresses [`RUSTSEC-2020-0071`]. + +[#2555]: https://github.com/actix/actix-web/pull/2555 +[`RUSTSEC-2020-0071`]: https://rustsec.org/advisories/RUSTSEC-2020-0071.html + + +## 4.0.0-beta.17 - 2021-12-29 +### Added +- `guard::GuardContext` for use with the `Guard` trait. [#2552] +- `ServiceRequest::guard_ctx` for obtaining a guard context. [#2552] + +### Changed +- `Guard` trait now receives a `&GuardContext`. [#2552] +- `guard::fn_guard` functions now receives a `&GuardContext`. [#2552] +- Some guards now return `impl Guard` and their concrete types are made private: `guard::Header` and all the method guards. [#2552] +- The `Not` guard is now generic over the type of guard it wraps. [#2552] + +### Fixed +- Rename `ConnectionInfo::{remote_addr => peer_addr}`, deprecating the old name. [#2554] +- `ConnectionInfo::peer_addr` will not return the port number. [#2554] +- `ConnectionInfo::realip_remote_addr` will not return the port number if sourcing the IP from the peer's socket address. [#2554] + +[#2552]: https://github.com/actix/actix-web/pull/2552 +[#2554]: https://github.com/actix/actix-web/pull/2554 + + +## 4.0.0-beta.16 - 2021-12-27 +### Changed +- No longer require `Scope` service body type to be boxed. [#2523] +- No longer require `Resource` service body type to be boxed. [#2526] + +[#2523]: https://github.com/actix/actix-web/pull/2523 +[#2526]: https://github.com/actix/actix-web/pull/2526 + + +## 4.0.0-beta.15 - 2021-12-17 +### Added +- Method on `Responder` trait (`customize`) for customizing responders and `CustomizeResponder` struct. [#2510] +- Implement `Debug` for `DefaultHeaders`. [#2510] + +### Changed +- Align `DefaultHeader` method terminology, deprecating previous methods. [#2510] +- Response service types in `ErrorHandlers` middleware now use `ServiceResponse>` to allow changing the body type. [#2515] +- Both variants in `ErrorHandlerResponse` now use `ServiceResponse>`. [#2515] +- Rename `test::{default_service => simple_service}`. Old name is deprecated. [#2518] +- Rename `test::{read_response_json => call_and_read_body_json}`. Old name is deprecated. [#2518] +- Rename `test::{read_response => call_and_read_body}`. Old name is deprecated. [#2518] +- Relax body type and error bounds on test utilities. [#2518] + +### Removed +- Top-level `EitherExtractError` export. [#2510] +- Conversion implementations for `either` crate. [#2516] +- `test::load_stream` and `test::load_body`; replace usage with `body::to_bytes`. [#2518] + +[#2510]: https://github.com/actix/actix-web/pull/2510 +[#2515]: https://github.com/actix/actix-web/pull/2515 +[#2516]: https://github.com/actix/actix-web/pull/2516 +[#2518]: https://github.com/actix/actix-web/pull/2518 + + +## 4.0.0-beta.14 - 2021-12-11 +### Added +- Methods on `AcceptLanguage`: `ranked` and `preference`. [#2480] +- `AcceptEncoding` typed header. [#2482] +- `Range` typed header. [#2485] +- `HttpResponse::map_into_{left,right}_body` and `HttpResponse::map_into_boxed_body`. [#2468] +- `ServiceResponse::map_into_{left,right}_body` and `HttpResponse::map_into_boxed_body`. [#2468] +- Connection data set through the `HttpServer::on_connect` callback is now accessible only from the new `HttpRequest::conn_data()` and `ServiceRequest::conn_data()` methods. [#2491] +- `HttpRequest::{req_data,req_data_mut}`. [#2487] +- `ServiceResponse::into_parts`. [#2499] + +### Changed +- Rename `Accept::{mime_precedence => ranked}`. [#2480] +- Rename `Accept::{mime_preference => preference}`. [#2480] +- Un-deprecate `App::data_factory`. [#2484] +- `HttpRequest::url_for` no longer constructs URLs with query or fragment components. [#2430] +- Remove `B` (body) type parameter on `App`. [#2493] +- Add `B` (body) type parameter on `Scope`. [#2492] +- Request-local data container is no longer part of a `RequestHead`. Instead it is a distinct part of a `Request`. [#2487] + +### Fixed +- Accept wildcard `*` items in `AcceptLanguage`. [#2480] +- Re-exports `dev::{BodySize, MessageBody, SizedStream}`. They are exposed through the `body` module. [#2468] +- Typed headers containing lists that require one or more items now enforce this minimum. [#2482] + +### Removed +- `ConnectionInfo::get`. [#2487] + +[#2430]: https://github.com/actix/actix-web/pull/2430 +[#2468]: https://github.com/actix/actix-web/pull/2468 +[#2480]: https://github.com/actix/actix-web/pull/2480 +[#2482]: https://github.com/actix/actix-web/pull/2482 +[#2484]: https://github.com/actix/actix-web/pull/2484 +[#2485]: https://github.com/actix/actix-web/pull/2485 +[#2487]: https://github.com/actix/actix-web/pull/2487 +[#2491]: https://github.com/actix/actix-web/pull/2491 +[#2492]: https://github.com/actix/actix-web/pull/2492 +[#2493]: https://github.com/actix/actix-web/pull/2493 +[#2499]: https://github.com/actix/actix-web/pull/2499 + + +## 4.0.0-beta.13 - 2021-11-30 +### Changed +- Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 4.0.0-beta.12 - 2021-11-22 +### Changed +- Compress middleware's response type is now `AnyBody>`. [#2448] + +### Fixed +- Relax `Unpin` bound on `S` (stream) parameter of `HttpResponseBuilder::streaming`. [#2448] + +### Removed +- `dev::ResponseBody` re-export; is function is replaced by the new `dev::AnyBody` enum. [#2446] + +[#2446]: https://github.com/actix/actix-web/pull/2446 +[#2448]: https://github.com/actix/actix-web/pull/2448 + + +## 4.0.0-beta.11 - 2021-11-15 +### Added +- Re-export `dev::ServerHandle` from `actix-server`. [#2442] + +### Changed +- `ContentType::html` now produces `text/html; charset=utf-8` instead of `text/html`. [#2423] +- Update `actix-server` to `2.0.0-beta.9`. [#2442] + +[#2423]: https://github.com/actix/actix-web/pull/2423 +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 4.0.0-beta.10 - 2021-10-20 +### Added +- Option to allow `Json` extractor to work without a `Content-Type` header present. [#2362] +- `#[actix_web::test]` macro for setting up tests with a runtime. [#2409] + +### Changed +- Associated type `FromRequest::Config` was removed. [#2233] +- Inner field made private on `web::Payload`. [#2384] +- `Data::into_inner` and `Data::get_ref` no longer requires `T: Sized`. [#2403] +- Updated rustls to v0.20. [#2414] +- Minimum supported Rust version (MSRV) is now 1.52. + +### Removed +- Useless `ServiceResponse::checked_expr` method. [#2401] + +[#2233]: https://github.com/actix/actix-web/pull/2233 +[#2362]: https://github.com/actix/actix-web/pull/2362 +[#2384]: https://github.com/actix/actix-web/pull/2384 +[#2401]: https://github.com/actix/actix-web/pull/2401 +[#2403]: https://github.com/actix/actix-web/pull/2403 +[#2409]: https://github.com/actix/actix-web/pull/2409 +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 4.0.0-beta.9 - 2021-09-09 +### Added +- Re-export actix-service `ServiceFactory` in `dev` module. [#2325] + +### Changed +- Compress middleware will return 406 Not Acceptable when no content encoding is acceptable to the client. [#2344] +- Move `BaseHttpResponse` to `dev::Response`. [#2379] +- Enable `TestRequest::param` to accept more than just static strings. [#2172] +- Minimum supported Rust version (MSRV) is now 1.51. + +### Fixed +- Fix quality parse error in Accept-Encoding header. [#2344] +- Re-export correct type at `web::HttpResponse`. [#2379] + +[#2172]: https://github.com/actix/actix-web/pull/2172 +[#2325]: https://github.com/actix/actix-web/pull/2325 +[#2344]: https://github.com/actix/actix-web/pull/2344 +[#2379]: https://github.com/actix/actix-web/pull/2379 + + +## 4.0.0-beta.8 - 2021-06-26 +### Added +- Add `ServiceRequest::parts_mut`. [#2177] +- Add extractors for `Uri` and `Method`. [#2263] +- Add extractors for `ConnectionInfo` and `PeerAddr`. [#2263] +- Add `Route::service` for using hand-written services as handlers. [#2262] + +### Changed +- Change compression algorithm features flags. [#2250] +- Deprecate `App::data` and `App::data_factory`. [#2271] +- Smarter extraction of `ConnectionInfo` parts. [#2282] + +### Fixed +- Scope and Resource middleware can access data items set on their own layer. [#2288] + +[#2177]: https://github.com/actix/actix-web/pull/2177 +[#2250]: https://github.com/actix/actix-web/pull/2250 +[#2271]: https://github.com/actix/actix-web/pull/2271 +[#2262]: https://github.com/actix/actix-web/pull/2262 +[#2263]: https://github.com/actix/actix-web/pull/2263 +[#2282]: https://github.com/actix/actix-web/pull/2282 +[#2288]: https://github.com/actix/actix-web/pull/2288 + + +## 4.0.0-beta.7 - 2021-06-17 +### Added +- `HttpServer::worker_max_blocking_threads` for setting block thread pool. [#2200] + +### Changed +- Adjusted default JSON payload limit to 2MB (from 32kb) and included size and limits in the `JsonPayloadError::Overflow` error variant. [#2162] +[#2162]: (https://github.com/actix/actix-web/pull/2162) +- `ServiceResponse::error_response` now uses body type of `Body`. [#2201] +- `ServiceResponse::checked_expr` now returns a `Result`. [#2201] +- Update `language-tags` to `0.3`. +- `ServiceResponse::take_body`. [#2201] +- `ServiceResponse::map_body` closure receives and returns `B` instead of `ResponseBody` types. [#2201] +- All error trait bounds in server service builders have changed from `Into` to `Into>`. [#2253] +- All error trait bounds in message body and stream impls changed from `Into` to `Into>`. [#2253] +- `HttpServer::{listen_rustls(), bind_rustls()}` now honor the ALPN protocols in the configuation parameter. [#2226] +- `middleware::normalize` now will not try to normalize URIs with no valid path [#2246] + +### Removed +- `HttpResponse::take_body` and old `HttpResponse::into_body` method that casted body type. [#2201] + +[#2200]: https://github.com/actix/actix-web/pull/2200 +[#2201]: https://github.com/actix/actix-web/pull/2201 +[#2253]: https://github.com/actix/actix-web/pull/2253 +[#2246]: https://github.com/actix/actix-web/pull/2246 + + +## 4.0.0-beta.6 - 2021-04-17 +### Added +- `HttpResponse` and `HttpResponseBuilder` structs. [#2065] + +### Changed +- Most error types are now marked `#[non_exhaustive]`. [#2148] +- Methods on `ContentDisposition` that took `T: AsRef` now take `impl AsRef`. + +[#2065]: https://github.com/actix/actix-web/pull/2065 +[#2148]: https://github.com/actix/actix-web/pull/2148 + + +## 4.0.0-beta.5 - 2021-04-02 +### Added +- `Header` extractor for extracting common HTTP headers in handlers. [#2094] +- Added `TestServer::client_headers` method. [#2097] + +### Fixed +- Double ampersand in Logger format is escaped correctly. [#2067] + +### Changed +- `CustomResponder` would return error as `HttpResponse` when `CustomResponder::with_header` failed + instead of skipping. (Only the first error is kept when multiple error occur) [#2093] + +### Removed +- The `client` mod was removed. Clients should now use `awc` directly. + [871ca5e4](https://github.com/actix/actix-web/commit/871ca5e4ae2bdc22d1ea02701c2992fa8d04aed7) +- Integration testing was moved to new `actix-test` crate. Namely these items from the `test` + module: `TestServer`, `TestServerConfig`, `start`, `start_with`, and `unused_addr`. [#2112] + +[#2067]: https://github.com/actix/actix-web/pull/2067 +[#2093]: https://github.com/actix/actix-web/pull/2093 +[#2094]: https://github.com/actix/actix-web/pull/2094 +[#2097]: https://github.com/actix/actix-web/pull/2097 +[#2112]: https://github.com/actix/actix-web/pull/2112 + + +## 4.0.0-beta.4 - 2021-03-09 +### Changed +- Feature `cookies` is now optional and enabled by default. [#1981] +- `JsonBody::new` returns a default limit of 32kB to be consistent with `JsonConfig` and the default + behaviour of the `web::Json` extractor. [#2010] [#1981]: https://github.com/actix/actix-web/pull/1981 -[#1988]: https://github.com/actix/actix-web/pull/1988 +[#2010]: https://github.com/actix/actix-web/pull/2010 + ## 4.0.0-beta.3 - 2021-02-10 -* Update `actix-web-codegen` to `0.5.0-beta.1`. +- Update `actix-web-codegen` to `0.5.0-beta.1`. ## 4.0.0-beta.2 - 2021-02-10 ### Added -* The method `Either, web::Form>::into_inner()` which returns the inner type for +- The method `Either, web::Form>::into_inner()` which returns the inner type for whichever variant was created. Also works for `Either, web::Json>`. [#1894] -* Add `services!` macro for helping register multiple services to `App`. [#1933] -* Enable registering a vec of services of the same type to `App` [#1933] +- Add `services!` macro for helping register multiple services to `App`. [#1933] +- Enable registering a vec of services of the same type to `App` [#1933] ### Changed -* Rework `Responder` trait to be sync and returns `Response`/`HttpResponse` directly. +- Rework `Responder` trait to be sync and returns `Response`/`HttpResponse` directly. Making it simpler and more performant. [#1891] -* `ServiceRequest::into_parts` and `ServiceRequest::from_parts` can no longer fail. [#1893] -* `ServiceRequest::from_request` can no longer fail. [#1893] -* Our `Either` type now uses `Left`/`Right` variants (instead of `A`/`B`) [#1894] -* `test::{call_service, read_response, read_response_json, send_request}` take `&Service` +- `ServiceRequest::into_parts` and `ServiceRequest::from_parts` can no longer fail. [#1893] +- `ServiceRequest::from_request` can no longer fail. [#1893] +- Our `Either` type now uses `Left`/`Right` variants (instead of `A`/`B`) [#1894] +- `test::{call_service, read_response, read_response_json, send_request}` take `&Service` in argument [#1905] -* `App::wrap_fn`, `Resource::wrap_fn` and `Scope::wrap_fn` provide `&Service` in closure +- `App::wrap_fn`, `Resource::wrap_fn` and `Scope::wrap_fn` provide `&Service` in closure argument. [#1905] -* `web::block` no longer requires the output is a Result. [#1957] +- `web::block` no longer requires the output is a Result. [#1957] ### Fixed -* Multiple calls to `App::data` with the same type now keeps the latest call's data. [#1906] +- Multiple calls to `App::data` with the same type now keeps the latest call's data. [#1906] ### Removed -* Public field of `web::Path` has been made private. [#1894] -* Public field of `web::Query` has been made private. [#1894] -* `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] -* `AppService::set_service_data`; for custom HTTP service factories adding application data, use the +- Public field of `web::Path` has been made private. [#1894] +- Public field of `web::Query` has been made private. [#1894] +- `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] +- `AppService::set_service_data`; for custom HTTP service factories adding application data, use the layered data model by calling `ServiceRequest::add_data_container` when handling requests instead. [#1906] @@ -55,26 +361,26 @@ ## 4.0.0-beta.1 - 2021-01-07 ### Added -* `Compat` middleware enabling generic response body/error type of middlewares like `Logger` and +- `Compat` middleware enabling generic response body/error type of middlewares like `Logger` and `Compress` to be used in `middleware::Condition` and `Resource`, `Scope` services. [#1865] ### Changed -* Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] -* Bumped `rand` to `0.8`. -* Update `rust-tls` to `0.19`. [#1813] -* Rename `Handler` to `HandlerService` and rename `Factory` to `Handler`. [#1852] -* The default `TrailingSlash` is now `Trim`, in line with existing documentation. See migration +- Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] +- Bumped `rand` to `0.8`. +- Update `rust-tls` to `0.19`. [#1813] +- Rename `Handler` to `HandlerService` and rename `Factory` to `Handler`. [#1852] +- The default `TrailingSlash` is now `Trim`, in line with existing documentation. See migration guide for implications. [#1875] -* Rename `DefaultHeaders::{content_type => add_content_type}`. [#1875] -* MSRV is now 1.46.0. +- Rename `DefaultHeaders::{content_type => add_content_type}`. [#1875] +- MSRV is now 1.46.0. ### Fixed -* Added the underlying parse error to `test::read_body_json`'s panic message. [#1812] +- Added the underlying parse error to `test::read_body_json`'s panic message. [#1812] ### Removed -* Public modules `middleware::{normalize, err_handlers}`. All necessary middleware structs are now +- Public modules `middleware::{normalize, err_handlers}`. All necessary middleware structs are now exposed directly by the `middleware` module. -* Remove `actix-threadpool` as dependency. `actix_threadpool::BlockingError` error type can be imported +- Remove `actix-threadpool` as dependency. `actix_threadpool::BlockingError` error type can be imported from `actix_web::error` module. [#1878] [#1812]: https://github.com/actix/actix-web/pull/1812 @@ -84,11 +390,19 @@ [#1875]: https://github.com/actix/actix-web/pull/1875 [#1878]: https://github.com/actix/actix-web/pull/1878 + +## 3.3.3 - 2021-12-18 +### Changed +- Soft-deprecate `NormalizePath::default()`, noting upcoming behavior change in v4. [#2529] + +[#2529]: https://github.com/actix/actix-web/pull/2529 + + ## 3.3.2 - 2020-12-01 ### Fixed -* Removed an occasional `unwrap` on `None` panic in `NormalizePathNormalization`. [#1762] -* Fix `match_pattern()` returning `None` for scope with empty path resource. [#1798] -* Increase minimum `socket2` version. [#1803] +- Removed an occasional `unwrap` on `None` panic in `NormalizePathNormalization`. [#1762] +- Fix `match_pattern()` returning `None` for scope with empty path resource. [#1798] +- Increase minimum `socket2` version. [#1803] [#1762]: https://github.com/actix/actix-web/pull/1762 [#1798]: https://github.com/actix/actix-web/pull/1798 @@ -96,15 +410,15 @@ ## 3.3.1 - 2020-11-29 -* Ensure `actix-http` dependency uses same `serde_urlencoded`. +- Ensure `actix-http` dependency uses same `serde_urlencoded`. ## 3.3.0 - 2020-11-25 ### Added -* Add `Either` extractor helper. [#1788] +- Add `Either` extractor helper. [#1788] ### Changed -* Upgrade `serde_urlencoded` to `0.7`. [#1773] +- Upgrade `serde_urlencoded` to `0.7`. [#1773] [#1773]: https://github.com/actix/actix-web/pull/1773 [#1788]: https://github.com/actix/actix-web/pull/1788 @@ -112,17 +426,17 @@ ## 3.2.0 - 2020-10-30 ### Added -* Implement `exclude_regex` for Logger middleware. [#1723] -* Add request-local data extractor `web::ReqData`. [#1748] -* Add ability to register closure for request middleware logging. [#1749] -* Add `app_data` to `ServiceConfig`. [#1757] -* Expose `on_connect` for access to the connection stream before request is handled. [#1754] +- Implement `exclude_regex` for Logger middleware. [#1723] +- Add request-local data extractor `web::ReqData`. [#1748] +- Add ability to register closure for request middleware logging. [#1749] +- Add `app_data` to `ServiceConfig`. [#1757] +- Expose `on_connect` for access to the connection stream before request is handled. [#1754] ### Changed -* Updated actix-web-codegen dependency for access to new `#[route(...)]` multi-method macro. -* Print non-configured `Data` type when attempting extraction. [#1743] -* Re-export bytes::Buf{Mut} in web module. [#1750] -* Upgrade `pin-project` to `1.0`. +- Updated actix-web-codegen dependency for access to new `#[route(...)]` multi-method macro. +- Print non-configured `Data` type when attempting extraction. [#1743] +- Re-export bytes::Buf{Mut} in web module. [#1750] +- Upgrade `pin-project` to `1.0`. [#1723]: https://github.com/actix/actix-web/pull/1723 [#1743]: https://github.com/actix/actix-web/pull/1743 @@ -134,13 +448,13 @@ ## 3.1.0 - 2020-09-29 ### Changed -* Add `TrailingSlash::MergeOnly` behaviour to `NormalizePath`, which allows `NormalizePath` +- Add `TrailingSlash::MergeOnly` behaviour to `NormalizePath`, which allows `NormalizePath` to retain any trailing slashes. [#1695] -* Remove bound `std::marker::Sized` from `web::Data` to support storing `Arc` +- Remove bound `std::marker::Sized` from `web::Data` to support storing `Arc` via `web::Data::from` [#1710] ### Fixed -* `ResourceMap` debug printing is no longer infinitely recursive. [#1708] +- `ResourceMap` debug printing is no longer infinitely recursive. [#1708] [#1695]: https://github.com/actix/actix-web/pull/1695 [#1708]: https://github.com/actix/actix-web/pull/1708 @@ -149,33 +463,33 @@ ## 3.0.2 - 2020-09-15 ### Fixed -* `NormalizePath` when used with `TrailingSlash::Trim` no longer trims the root path "/". [#1678] +- `NormalizePath` when used with `TrailingSlash::Trim` no longer trims the root path "/". [#1678] [#1678]: https://github.com/actix/actix-web/pull/1678 ## 3.0.1 - 2020-09-13 ### Changed -* `middleware::normalize::TrailingSlash` enum is now accessible. [#1673] +- `middleware::normalize::TrailingSlash` enum is now accessible. [#1673] [#1673]: https://github.com/actix/actix-web/pull/1673 ## 3.0.0 - 2020-09-11 -* No significant changes from `3.0.0-beta.4`. +- No significant changes from `3.0.0-beta.4`. ## 3.0.0-beta.4 - 2020-09-09 ### Added -* `middleware::NormalizePath` now has configurable behaviour for either always having a trailing +- `middleware::NormalizePath` now has configurable behavior for either always having a trailing slash, or as the new addition, always trimming trailing slashes. [#1639] ### Changed -* Update actix-codec and actix-utils dependencies. [#1634] -* `FormConfig` and `JsonConfig` configurations are now also considered when set +- Update actix-codec and actix-utils dependencies. [#1634] +- `FormConfig` and `JsonConfig` configurations are now also considered when set using `App::data`. [#1641] -* `HttpServer::maxconn` is renamed to the more expressive `HttpServer::max_connections`. [#1655] -* `HttpServer::maxconnrate` is renamed to the more expressive +- `HttpServer::maxconn` is renamed to the more expressive `HttpServer::max_connections`. [#1655] +- `HttpServer::maxconnrate` is renamed to the more expressive `HttpServer::max_connection_rate`. [#1655] [#1639]: https://github.com/actix/actix-web/pull/1639 @@ -185,22 +499,22 @@ ## 3.0.0-beta.3 - 2020-08-17 ### Changed -* Update `rustls` to 0.18 +- Update `rustls` to 0.18 ## 3.0.0-beta.2 - 2020-08-17 ### Changed -* `PayloadConfig` is now also considered in `Bytes` and `String` extractors when set +- `PayloadConfig` is now also considered in `Bytes` and `String` extractors when set using `App::data`. [#1610] -* `web::Path` now has a public representation: `web::Path(pub T)` that enables +- `web::Path` now has a public representation: `web::Path(pub T)` that enables destructuring. [#1594] -* `ServiceRequest::app_data` allows retrieval of non-Data data without splitting into parts to +- `ServiceRequest::app_data` allows retrieval of non-Data data without splitting into parts to access `HttpRequest` which already allows this. [#1618] -* Re-export all error types from `awc`. [#1621] -* MSRV is now 1.42.0. +- Re-export all error types from `awc`. [#1621] +- MSRV is now 1.42.0. ### Fixed -* Memory leak of app data in pooled requests. [#1609] +- Memory leak of app data in pooled requests. [#1609] [#1594]: https://github.com/actix/actix-web/pull/1594 [#1609]: https://github.com/actix/actix-web/pull/1609 @@ -211,29 +525,29 @@ ## 3.0.0-beta.1 - 2020-07-13 ### Added -* Re-export `actix_rt::main` as `actix_web::main`. -* `HttpRequest::match_pattern` and `ServiceRequest::match_pattern` for extracting the matched +- Re-export `actix_rt::main` as `actix_web::main`. +- `HttpRequest::match_pattern` and `ServiceRequest::match_pattern` for extracting the matched resource pattern. -* `HttpRequest::match_name` and `ServiceRequest::match_name` for extracting matched resource name. +- `HttpRequest::match_name` and `ServiceRequest::match_name` for extracting matched resource name. ### Changed -* Fix actix_http::h1::dispatcher so it returns when HW_BUFFER_SIZE is reached. Should reduce peak memory consumption during large uploads. [#1550] -* Migrate cookie handling to `cookie` crate. Actix-web no longer requires `ring` dependency. -* MSRV is now 1.41.1 +- Fix actix_http::h1::dispatcher so it returns when HW_BUFFER_SIZE is reached. Should reduce peak memory consumption during large uploads. [#1550] +- Migrate cookie handling to `cookie` crate. Actix-web no longer requires `ring` dependency. +- MSRV is now 1.41.1 ### Fixed -* `NormalizePath` improved consistency when path needs slashes added _and_ removed. +- `NormalizePath` improved consistency when path needs slashes added _and_ removed. ## 3.0.0-alpha.3 - 2020-05-21 ### Added -* Add option to create `Data` from `Arc` [#1509] +- Add option to create `Data` from `Arc` [#1509] ### Changed -* Resources and Scopes can now access non-overridden data types set on App (or containing scopes) when setting their own data. [#1486] -* Fix audit issue logging by default peer address [#1485] -* Bump minimum supported Rust version to 1.40 -* Replace deprecated `net2` crate with `socket2` +- Resources and Scopes can now access non-overridden data types set on App (or containing scopes) when setting their own data. [#1486] +- Fix audit issue logging by default peer address [#1485] +- Bump minimum supported Rust version to 1.40 +- Replace deprecated `net2` crate with `socket2` [#1485]: https://github.com/actix/actix-web/pull/1485 [#1509]: https://github.com/actix/actix-web/pull/1509 @@ -242,10 +556,10 @@ ### Changed -* `{Resource,Scope}::default_service(f)` handlers now support app data extraction. [#1452] -* Implement `std::error::Error` for our custom errors [#1422] -* NormalizePath middleware now appends trailing / so that routes of form /example/ respond to /example requests. [#1433] -* Remove the `failure` feature and support. +- `{Resource,Scope}::default_service(f)` handlers now support app data extraction. [#1452] +- Implement `std::error::Error` for our custom errors [#1422] +- NormalizePath middleware now appends trailing / so that routes of form /example/ respond to /example requests. [#1433] +- Remove the `failure` feature and support. [#1422]: https://github.com/actix/actix-web/pull/1422 [#1433]: https://github.com/actix/actix-web/pull/1433 @@ -257,16 +571,16 @@ ### Added -* Add helper function for creating routes with `TRACE` method guard `web::trace()` -* Add convenience functions `test::read_body_json()` and `test::TestRequest::send_request()` for testing. +- Add helper function for creating routes with `TRACE` method guard `web::trace()` +- Add convenience functions `test::read_body_json()` and `test::TestRequest::send_request()` for testing. ### Changed -* Use `sha-1` crate instead of unmaintained `sha1` crate -* Skip empty chunks when returning response from a `Stream` [#1308] -* Update the `time` dependency to 0.2.7 -* Update `actix-tls` dependency to 2.0.0-alpha.1 -* Update `rustls` dependency to 0.17 +- Use `sha-1` crate instead of unmaintained `sha1` crate +- Skip empty chunks when returning response from a `Stream` [#1308] +- Update the `time` dependency to 0.2.7 +- Update `actix-tls` dependency to 2.0.0-alpha.1 +- Update `rustls` dependency to 0.17 [#1308]: https://github.com/actix/actix-web/pull/1308 @@ -274,408 +588,408 @@ ### Changed -* Rename `HttpServer::start()` to `HttpServer::run()` +- Rename `HttpServer::start()` to `HttpServer::run()` -* Allow to gracefully stop test server via `TestServer::stop()` +- Allow to gracefully stop test server via `TestServer::stop()` -* Allow to specify multi-patterns for resources +- Allow to specify multi-patterns for resources ## [2.0.0-rc] - 2019-12-20 ### Changed -* Move `BodyEncoding` to `dev` module #1220 +- Move `BodyEncoding` to `dev` module #1220 -* Allow to set `peer_addr` for TestRequest #1074 +- Allow to set `peer_addr` for TestRequest #1074 -* Make web::Data deref to Arc #1214 +- Make web::Data deref to Arc #1214 -* Rename `App::register_data()` to `App::app_data()` +- Rename `App::register_data()` to `App::app_data()` -* `HttpRequest::app_data()` returns `Option<&T>` instead of `Option<&Data>` +- `HttpRequest::app_data()` returns `Option<&T>` instead of `Option<&Data>` ### Fixed -* Fix `AppConfig::secure()` is always false. #1202 +- Fix `AppConfig::secure()` is always false. #1202 ## [2.0.0-alpha.6] - 2019-12-15 ### Fixed -* Fixed compilation with default features off +- Fixed compilation with default features off ## [2.0.0-alpha.5] - 2019-12-13 ### Added -* Add test server, `test::start()` and `test::start_with()` +- Add test server, `test::start()` and `test::start_with()` ## [2.0.0-alpha.4] - 2019-12-08 ### Deleted -* Delete HttpServer::run(), it is not useful with async/await +- Delete HttpServer::run(), it is not useful with async/await ## [2.0.0-alpha.3] - 2019-12-07 ### Changed -* Migrate to tokio 0.2 +- Migrate to tokio 0.2 ## [2.0.0-alpha.1] - 2019-11-22 ### Changed -* Migrated to `std::future` +- Migrated to `std::future` -* Remove implementation of `Responder` for `()`. (#1167) +- Remove implementation of `Responder` for `()`. (#1167) ## [1.0.9] - 2019-11-14 ### Added -* Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) +- Add `Payload::into_inner` method and make stored `def::Payload` public. (#1110) ### Changed -* Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129) +- Support `Host` guards when the `Host` header is unset (e.g. HTTP/2 requests) (#1129) ## [1.0.8] - 2019-09-25 ### Added -* Add `Scope::register_data` and `Resource::register_data` methods, parallel to +- Add `Scope::register_data` and `Resource::register_data` methods, parallel to `App::register_data`. -* Add `middleware::Condition` that conditionally enables another middleware +- Add `middleware::Condition` that conditionally enables another middleware -* Allow to re-construct `ServiceRequest` from `HttpRequest` and `Payload` +- Allow to re-construct `ServiceRequest` from `HttpRequest` and `Payload` -* Add `HttpServer::listen_uds` for ability to listen on UDS FD rather than path, +- Add `HttpServer::listen_uds` for ability to listen on UDS FD rather than path, which is useful for example with systemd. ### Changed -* Make UrlEncodedError::Overflow more informative +- Make UrlEncodedError::Overflow more informative -* Use actix-testing for testing utils +- Use actix-testing for testing utils ## [1.0.7] - 2019-08-29 ### Fixed -* Request Extensions leak #1062 +- Request Extensions leak #1062 ## [1.0.6] - 2019-08-28 ### Added -* Re-implement Host predicate (#989) +- Re-implement Host predicate (#989) -* Form implements Responder, returning a `application/x-www-form-urlencoded` response +- Form implements Responder, returning a `application/x-www-form-urlencoded` response -* Add `into_inner` to `Data` +- Add `into_inner` to `Data` -* Add `test::TestRequest::set_form()` convenience method to automatically serialize data and set +- Add `test::TestRequest::set_form()` convenience method to automatically serialize data and set the header in test requests. ### Changed -* `Query` payload made `pub`. Allows user to pattern-match the payload. +- `Query` payload made `pub`. Allows user to pattern-match the payload. -* Enable `rust-tls` feature for client #1045 +- Enable `rust-tls` feature for client #1045 -* Update serde_urlencoded to 0.6.1 +- Update serde_urlencoded to 0.6.1 -* Update url to 2.1 +- Update url to 2.1 ## [1.0.5] - 2019-07-18 ### Added -* Unix domain sockets (HttpServer::bind_uds) #92 +- Unix domain sockets (HttpServer::bind_uds) #92 -* Actix now logs errors resulting in "internal server error" responses always, with the `error` +- Actix now logs errors resulting in "internal server error" responses always, with the `error` logging level ### Fixed -* Restored logging of errors through the `Logger` middleware +- Restored logging of errors through the `Logger` middleware ## [1.0.4] - 2019-07-17 ### Added -* Add `Responder` impl for `(T, StatusCode) where T: Responder` +- Add `Responder` impl for `(T, StatusCode) where T: Responder` -* Allow to access app's resource map via +- Allow to access app's resource map via `ServiceRequest::resource_map()` and `HttpRequest::resource_map()` methods. ### Changed -* Upgrade `rand` dependency version to 0.7 +- Upgrade `rand` dependency version to 0.7 ## [1.0.3] - 2019-06-28 ### Added -* Support asynchronous data factories #850 +- Support asynchronous data factories #850 ### Changed -* Use `encoding_rs` crate instead of unmaintained `encoding` crate +- Use `encoding_rs` crate instead of unmaintained `encoding` crate ## [1.0.2] - 2019-06-17 ### Changed -* Move cors middleware to `actix-cors` crate. +- Move cors middleware to `actix-cors` crate. -* Move identity middleware to `actix-identity` crate. +- Move identity middleware to `actix-identity` crate. ## [1.0.1] - 2019-06-17 ### Added -* Add support for PathConfig #903 +- Add support for PathConfig #903 -* Add `middleware::identity::RequestIdentity` trait to `get_identity` from `HttpMessage`. +- Add `middleware::identity::RequestIdentity` trait to `get_identity` from `HttpMessage`. ### Changed -* Move cors middleware to `actix-cors` crate. +- Move cors middleware to `actix-cors` crate. -* Move identity middleware to `actix-identity` crate. +- Move identity middleware to `actix-identity` crate. -* Disable default feature `secure-cookies`. +- Disable default feature `secure-cookies`. -* Allow to test an app that uses async actors #897 +- Allow to test an app that uses async actors #897 -* Re-apply patch from #637 #894 +- Re-apply patch from #637 #894 ### Fixed -* HttpRequest::url_for is broken with nested scopes #915 +- HttpRequest::url_for is broken with nested scopes #915 ## [1.0.0] - 2019-06-05 ### Added -* Add `Scope::configure()` method. +- Add `Scope::configure()` method. -* Add `ServiceRequest::set_payload()` method. +- Add `ServiceRequest::set_payload()` method. -* Add `test::TestRequest::set_json()` convenience method to automatically +- Add `test::TestRequest::set_json()` convenience method to automatically serialize data and set header in test requests. -* Add macros for head, options, trace, connect and patch http methods +- Add macros for head, options, trace, connect and patch http methods ### Changed -* Drop an unnecessary `Option<_>` indirection around `ServerBuilder` from `HttpServer`. #863 +- Drop an unnecessary `Option<_>` indirection around `ServerBuilder` from `HttpServer`. #863 ### Fixed -* Fix Logger request time format, and use rfc3339. #867 +- Fix Logger request time format, and use rfc3339. #867 -* Clear http requests pool on app service drop #860 +- Clear http requests pool on app service drop #860 ## [1.0.0-rc] - 2019-05-18 -### Add +### Added -* Add `Query::from_query()` to extract parameters from a query string. #846 -* `QueryConfig`, similar to `JsonConfig` for customizing error handling of query extractors. +- Add `Query::from_query()` to extract parameters from a query string. #846 +- `QueryConfig`, similar to `JsonConfig` for customizing error handling of query extractors. ### Changed -* `JsonConfig` is now `Send + Sync`, this implies that `error_handler` must be `Send + Sync` too. +- `JsonConfig` is now `Send + Sync`, this implies that `error_handler` must be `Send + Sync` too. ### Fixed -* Codegen with parameters in the path only resolves the first registered endpoint #841 +- Codegen with parameters in the path only resolves the first registered endpoint #841 ## [1.0.0-beta.4] - 2019-05-12 -### Add +### Added -* Allow to set/override app data on scope level +- Allow to set/override app data on scope level ### Changed -* `App::configure` take an `FnOnce` instead of `Fn` -* Upgrade actix-net crates +- `App::configure` take an `FnOnce` instead of `Fn` +- Upgrade actix-net crates ## [1.0.0-beta.3] - 2019-05-04 ### Added -* Add helper function for executing futures `test::block_fn()` +- Add helper function for executing futures `test::block_fn()` ### Changed -* Extractor configuration could be registered with `App::data()` +- Extractor configuration could be registered with `App::data()` or with `Resource::data()` #775 -* Route data is unified with app data, `Route::data()` moved to resource +- Route data is unified with app data, `Route::data()` moved to resource level to `Resource::data()` -* CORS handling without headers #702 +- CORS handling without headers #702 -* Allow to construct `Data` instances to avoid double `Arc` for `Send + Sync` types. +- Allow constructing `Data` instances to avoid double `Arc` for `Send + Sync` types. ### Fixed -* Fix `NormalizePath` middleware impl #806 +- Fix `NormalizePath` middleware impl #806 ### Deleted -* `App::data_factory()` is deleted. +- `App::data_factory()` is deleted. ## [1.0.0-beta.2] - 2019-04-24 ### Added -* Add raw services support via `web::service()` +- Add raw services support via `web::service()` -* Add helper functions for reading response body `test::read_body()` +- Add helper functions for reading response body `test::read_body()` -* Add support for `remainder match` (i.e "/path/{tail}*") +- Add support for `remainder match` (i.e "/path/{tail}*") -* Extend `Responder` trait, allow to override status code and headers. +- Extend `Responder` trait, allow to override status code and headers. -* Store visit and login timestamp in the identity cookie #502 +- Store visit and login timestamp in the identity cookie #502 ### Changed -* `.to_async()` handler can return `Responder` type #792 +- `.to_async()` handler can return `Responder` type #792 ### Fixed -* Fix async web::Data factory handling +- Fix async web::Data factory handling ## [1.0.0-beta.1] - 2019-04-20 ### Added -* Add helper functions for reading test response body, +- Add helper functions for reading test response body, `test::read_response()` and test::read_response_json()` -* Add `.peer_addr()` #744 +- Add `.peer_addr()` #744 -* Add `NormalizePath` middleware +- Add `NormalizePath` middleware ### Changed -* Rename `RouterConfig` to `ServiceConfig` +- Rename `RouterConfig` to `ServiceConfig` -* Rename `test::call_success` to `test::call_service` +- Rename `test::call_success` to `test::call_service` -* Removed `ServiceRequest::from_parts()` as it is unsafe to create from parts. +- Removed `ServiceRequest::from_parts()` as it is unsafe to create from parts. -* `CookieIdentityPolicy::max_age()` accepts value in seconds +- `CookieIdentityPolicy::max_age()` accepts value in seconds ### Fixed -* Fixed `TestRequest::app_data()` +- Fixed `TestRequest::app_data()` ## [1.0.0-alpha.6] - 2019-04-14 ### Changed -* Allow to use any service as default service. +- Allow using any service as default service. -* Remove generic type for request payload, always use default. +- Remove generic type for request payload, always use default. -* Removed `Decompress` middleware. Bytes, String, Json, Form extractors +- Removed `Decompress` middleware. Bytes, String, Json, Form extractors automatically decompress payload. -* Make extractor config type explicit. Add `FromRequest::Config` associated type. +- Make extractor config type explicit. Add `FromRequest::Config` associated type. ## [1.0.0-alpha.5] - 2019-04-12 ### Added -* Added async io `TestBuffer` for testing. +- Added async io `TestBuffer` for testing. ### Deleted -* Removed native-tls support +- Removed native-tls support ## [1.0.0-alpha.4] - 2019-04-08 ### Added -* `App::configure()` allow to offload app configuration to different methods +- `App::configure()` allow to offload app configuration to different methods -* Added `URLPath` option for logger +- Added `URLPath` option for logger -* Added `ServiceRequest::app_data()`, returns `Data` +- Added `ServiceRequest::app_data()`, returns `Data` -* Added `ServiceFromRequest::app_data()`, returns `Data` +- Added `ServiceFromRequest::app_data()`, returns `Data` ### Changed -* `FromRequest` trait refactoring +- `FromRequest` trait refactoring -* Move multipart support to actix-multipart crate +- Move multipart support to actix-multipart crate ### Fixed -* Fix body propagation in Response::from_error. #760 +- Fix body propagation in Response::from_error. #760 ## [1.0.0-alpha.3] - 2019-04-02 ### Changed -* Renamed `TestRequest::to_service()` to `TestRequest::to_srv_request()` +- Renamed `TestRequest::to_service()` to `TestRequest::to_srv_request()` -* Renamed `TestRequest::to_response()` to `TestRequest::to_srv_response()` +- Renamed `TestRequest::to_response()` to `TestRequest::to_srv_response()` -* Removed `Deref` impls +- Removed `Deref` impls ### Removed -* Removed unused `actix_web::web::md()` +- Removed unused `actix_web::web::md()` ## [1.0.0-alpha.2] - 2019-03-29 ### Added -* rustls support +- Rustls support ### Changed -* use forked cookie +- Use forked cookie -* multipart::Field renamed to MultipartField +- Multipart::Field renamed to MultipartField ## [1.0.0-alpha.1] - 2019-03-28 ### Changed -* Complete architecture re-design. +- Complete architecture re-design. -* Return 405 response if no matching route found within resource #538 +- Return 405 response if no matching route found within resource #538 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ae97b3240..dbd092095 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -8,19 +8,19 @@ In the interest of fostering an open and welcoming environment, we as contributo Examples of behavior that contributes to creating a positive environment include: -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery and unwelcome sexual attention or advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a professional setting +- The use of sexualized language or imagery and unwelcome sexual attention or advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities diff --git a/Cargo.toml b/Cargo.toml index 1a1b8645c..a2b8fb885 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,130 +1,131 @@ [package] name = "actix-web" -version = "4.0.0-beta.3" +version = "4.0.0-beta.19" authors = ["Nikolay Kim "] description = "Actix Web is a powerful, pragmatic, and extremely fast web framework for Rust" -readme = "README.md" keywords = ["actix", "http", "web", "framework", "async"] +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket" +] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-web/" -categories = ["network-programming", "asynchronous", - "web-programming::http-server", - "web-programming::websocket"] license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] # features that docs.rs will build with -features = ["openssl", "rustls", "compress", "secure-cookies"] - -[badges] -travis-ci = { repository = "actix/actix-web", branch = "master" } -codecov = { repository = "actix/actix-web", branch = "master", service = "github" } +features = ["openssl", "rustls", "compress-brotli", "compress-gzip", "compress-zstd", "cookies", "secure-cookies"] +rustdoc-args = ["--cfg", "docsrs"] [lib] name = "actix_web" path = "src/lib.rs" [workspace] +resolver = "2" members = [ - ".", - "awc", - "actix-http", - "actix-files", - "actix-multipart", - "actix-web-actors", - "actix-web-codegen", - "actix-http-test", + ".", + "actix-files", + "actix-http-test", + "actix-http", + "actix-multipart", + "actix-router", + "actix-test", + "actix-web-actors", + "actix-web-codegen", + "awc", ] [features] -default = ["compress", "cookies"] +default = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] -# content-encoding support -compress = ["actix-http/compress", "awc/compress"] +# Brotli algorithm content-encoding support +compress-brotli = ["actix-http/compress-brotli", "__compress"] +# Gzip and deflate algorithms content-encoding support +compress-gzip = ["actix-http/compress-gzip", "__compress"] +# Zstd algorithm content-encoding support +compress-zstd = ["actix-http/compress-zstd", "__compress"] # support for cookies -cookies = ["actix-http/cookies", "awc/cookies"] +cookies = ["cookie"] # secure cookies feature -secure-cookies = ["actix-http/secure-cookies"] +secure-cookies = ["cookie/secure"] # openssl -openssl = ["tls-openssl", "actix-tls/accept", "actix-tls/openssl", "awc/openssl"] +openssl = ["actix-http/openssl", "actix-tls/accept", "actix-tls/openssl"] # rustls -rustls = ["tls-rustls", "actix-tls/accept", "actix-tls/rustls", "awc/rustls"] +rustls = ["actix-http/rustls", "actix-tls/accept", "actix-tls/rustls"] -[[example]] -name = "basic" -required-features = ["compress"] +# Internal (PRIVATE!) features used to aid testing and checking feature status. +# Don't rely on these whatsoever. They may disappear at anytime. +__compress = [] -[[example]] -name = "uds" -required-features = ["compress"] - -[[test]] -name = "test_server" -required-features = ["compress"] - -[[example]] -name = "on_connect" -required-features = [] - -[[example]] -name = "client" -required-features = ["rustls"] +# io-uring feature only avaiable for Linux OSes. +experimental-io-uring = ["actix-server/io-uring"] [dependencies] -actix-codec = "0.4.0-beta.1" -actix-macros = "0.2.0" -actix-router = "0.2.7" -actix-rt = "2" -actix-server = "2.0.0-beta.3" -actix-service = "2.0.0-beta.4" -actix-utils = "3.0.0-beta.2" -actix-tls = { version = "3.0.0-beta.3", default-features = false, optional = true } +actix-codec = "0.4.1" +actix-macros = "0.2.3" +actix-rt = "2.3" +actix-server = "2.0.0-rc.2" +actix-service = "2.0.0" +actix-utils = "3.0.0" +actix-tls = { version = "3.0.0", default-features = false, optional = true } -actix-web-codegen = "0.5.0-beta.1" -actix-http = "3.0.0-beta.3" -awc = { version = "3.0.0-beta.2", default-features = false } +actix-http = "3.0.0-beta.18" +actix-router = "0.5.0-beta.4" +actix-web-codegen = "0.5.0-rc.1" ahash = "0.7" bytes = "1" +cfg-if = "1" +cookie = { version = "0.16", features = ["percent-encode"], optional = true } derive_more = "0.99.5" -either = "1.5.3" encoding_rs = "0.8" futures-core = { version = "0.3.7", default-features = false } futures-util = { version = "0.3.7", default-features = false } +itoa = "1" +language-tags = "0.3" +once_cell = "1.5" log = "0.4" mime = "0.3" -pin-project = "1.0.0" +pin-project-lite = "0.2.7" regex = "1.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" -smallvec = "1.6" -socket2 = "0.3.16" -time = { version = "0.2.23", default-features = false, features = ["std"] } -tls-openssl = { package = "openssl", version = "0.10.9", optional = true } -tls-rustls = { package = "rustls", version = "0.19.0", optional = true } +smallvec = "1.6.1" +socket2 = "0.4.0" +time = { version = "0.3", default-features = false, features = ["formatting"] } url = "2.1" -[target.'cfg(windows)'.dependencies.tls-openssl] -version = "0.10.9" -package = "openssl" -features = ["vendored"] -optional = true - [dev-dependencies] +actix-files = "0.6.0-beta.13" +actix-test = { version = "0.1.0-beta.11", features = ["openssl", "rustls"] } +awc = { version = "3.0.0-beta.18", features = ["openssl"] } + brotli2 = "0.3.2" -criterion = "0.3" -env_logger = "0.8" +const-str = "0.3" +criterion = { version = "0.3", features = ["html_reports"] } +env_logger = "0.9" flate2 = "1.0.13" +futures-util = { version = "0.3.7", default-features = false, features = ["std"] } rand = "0.8" rcgen = "0.8" -serde_derive = "1.0" +rustls-pemfile = "0.2" +static_assertions = "1" +tls-openssl = { package = "openssl", version = "0.10.9" } +tls-rustls = { package = "rustls", version = "0.20.0" } +zstd = "0.9" + +[profile.dev] +# Disabling debug info speeds up builds a bunch and we don't rely on it for debugging that much. +debug = 0 [profile.release] lto = true @@ -132,15 +133,42 @@ opt-level = 3 codegen-units = 1 [patch.crates-io] -actix-web = { path = "." } +actix-files = { path = "actix-files" } actix-http = { path = "actix-http" } actix-http-test = { path = "actix-http-test" } +actix-multipart = { path = "actix-multipart" } +actix-router = { path = "actix-router" } +actix-test = { path = "actix-test" } +actix-web = { path = "." } actix-web-actors = { path = "actix-web-actors" } actix-web-codegen = { path = "actix-web-codegen" } -actix-multipart = { path = "actix-multipart" } -actix-files = { path = "actix-files" } awc = { path = "awc" } +# uncomment for quick testing against local actix-net repo +# actix-service = { path = "../actix-net/actix-service" } +# actix-macros = { path = "../actix-net/actix-macros" } +# actix-rt = { path = "../actix-net/actix-rt" } +# actix-codec = { path = "../actix-net/actix-codec" } +# actix-utils = { path = "../actix-net/actix-utils" } +# actix-tls = { path = "../actix-net/actix-tls" } +# actix-server = { path = "../actix-net/actix-server" } + +[[test]] +name = "test_server" +required-features = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] + +[[example]] +name = "basic" +required-features = ["compress-gzip"] + +[[example]] +name = "uds" +required-features = ["compress-gzip"] + +[[example]] +name = "on-connect" +required-features = [] + [[bench]] name = "server" harness = false diff --git a/LICENSE-MIT b/LICENSE-MIT index 95938ef15..d559b1cd1 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright (c) 2017 Actix Team +Copyright (c) 2017-NOW Actix Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated diff --git a/MIGRATION.md b/MIGRATION.md index e01702868..338a04389 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,42 +1,57 @@ ## Unreleased -* The default `NormalizePath` behavior now strips trailing slashes by default. This was +- The default `NormalizePath` behavior now strips trailing slashes by default. This was previously documented to be the case in v3 but the behavior now matches. The effect is that routes defined with trailing slashes will become inaccessible when - using `NormalizePath::default()`. + using `NormalizePath::default()`. As such, calling `NormalizePath::default()` will log a warning. + It is advised that the `new` method be used instead. - Before: `#[get("/test/")` - After: `#[get("/test")` + Before: `#[get("/test/")]` + After: `#[get("/test")]` Alternatively, explicitly require trailing slashes: `NormalizePath::new(TrailingSlash::Always)`. +- The `type Config` of `FromRequest` was removed. + +- Feature flag `compress` has been split into its supported algorithm (brotli, gzip, zstd). + By default all compression algorithms are enabled. + To select algorithm you want to include with `middleware::Compress` use following flags: + - `compress-brotli` + - `compress-gzip` + - `compress-zstd` + If you have set in your `Cargo.toml` dedicated `actix-web` features and you still want + to have compression enabled. Please change features selection like bellow: + + Before: `"compress"` + After: `"compress-brotli", "compress-gzip", "compress-zstd"` + ## 3.0.0 -* The return type for `ServiceRequest::app_data::()` was changed from returning a `Data` to +- The return type for `ServiceRequest::app_data::()` was changed from returning a `Data` to simply a `T`. To access a `Data` use `ServiceRequest::app_data::>()`. -* Cookie handling has been offloaded to the `cookie` crate: +- Cookie handling has been offloaded to the `cookie` crate: * `USERINFO_ENCODE_SET` is no longer exposed. Percent-encoding is still supported; check docs. * Some types now require lifetime parameters. -* The time crate was updated to `v0.2`, a major breaking change to the time crate, which affects +- The time crate was updated to `v0.2`, a major breaking change to the time crate, which affects any `actix-web` method previously expecting a time v0.1 input. -* Setting a cookie's SameSite property, explicitly, to `SameSite::None` will now +- Setting a cookie's SameSite property, explicitly, to `SameSite::None` will now result in `SameSite=None` being sent with the response Set-Cookie header. To create a cookie without a SameSite attribute, remove any calls setting same_site. -* actix-http support for Actors messages was moved to actix-http crate and is enabled +- actix-http support for Actors messages was moved to actix-http crate and is enabled with feature `actors` -* content_length function is removed from actix-http. +- content_length function is removed from actix-http. You can set Content-Length by normally setting the response body or calling no_chunking function. -* `BodySize::Sized64` variant has been removed. `BodySize::Sized` now receives a +- `BodySize::Sized64` variant has been removed. `BodySize::Sized` now receives a `u64` instead of a `usize`. -* Code that was using `path.` to access a `web::Path<(A, B, C)>`s elements now needs to use +- Code that was using `path.` to access a `web::Path<(A, B, C)>`s elements now needs to use destructuring or `.into_inner()`. For example: ```rust @@ -56,35 +71,35 @@ } ``` -* `middleware::NormalizePath` can now also be configured to trim trailing slashes instead of always keeping one. +- `middleware::NormalizePath` can now also be configured to trim trailing slashes instead of always keeping one. It will need `middleware::normalize::TrailingSlash` when being constructed with `NormalizePath::new(...)`, or for an easier migration you can replace `wrap(middleware::NormalizePath)` with `wrap(middleware::NormalizePath::new(TrailingSlash::MergeOnly))`. -* `HttpServer::maxconn` is renamed to the more expressive `HttpServer::max_connections`. +- `HttpServer::maxconn` is renamed to the more expressive `HttpServer::max_connections`. -* `HttpServer::maxconnrate` is renamed to the more expressive `HttpServer::max_connection_rate`. +- `HttpServer::maxconnrate` is renamed to the more expressive `HttpServer::max_connection_rate`. ## 2.0.0 -* `HttpServer::start()` renamed to `HttpServer::run()`. It also possible to +- `HttpServer::start()` renamed to `HttpServer::run()`. It also possible to `.await` on `run` method result, in that case it awaits server exit. -* `App::register_data()` renamed to `App::app_data()` and accepts any type `T: 'static`. +- `App::register_data()` renamed to `App::app_data()` and accepts any type `T: 'static`. Stored data is available via `HttpRequest::app_data()` method at runtime. -* Extractor configuration must be registered with `App::app_data()` instead of `App::data()` +- Extractor configuration must be registered with `App::app_data()` instead of `App::data()` -* Sync handlers has been removed. `.to_async()` method has been renamed to `.to()` +- Sync handlers has been removed. `.to_async()` method has been renamed to `.to()` replace `fn` with `async fn` to convert sync handler to async -* `actix_http_test::TestServer` moved to `actix_web::test` module. To start +- `actix_http_test::TestServer` moved to `actix_web::test` module. To start test server use `test::start()` or `test_start_with_config()` methods -* `ResponseError` trait has been reafctored. `ResponseError::error_response()` renders +- `ResponseError` trait has been reafctored. `ResponseError::error_response()` renders http response. -* Feature `rust-tls` renamed to `rustls` +- Feature `rust-tls` renamed to `rustls` instead of @@ -98,7 +113,7 @@ actix-web = { version = "2.0.0", features = ["rustls"] } ``` -* Feature `ssl` renamed to `openssl` +- Feature `ssl` renamed to `openssl` instead of @@ -111,11 +126,11 @@ ```rust actix-web = { version = "2.0.0", features = ["openssl"] } ``` -* `Cors` builder now requires that you call `.finish()` to construct the middleware +- `Cors` builder now requires that you call `.finish()` to construct the middleware ## 1.0.1 -* Cors middleware has been moved to `actix-cors` crate +- Cors middleware has been moved to `actix-cors` crate instead of @@ -129,7 +144,7 @@ use actix_cors::Cors; ``` -* Identity middleware has been moved to `actix-identity` crate +- Identity middleware has been moved to `actix-identity` crate instead of @@ -146,7 +161,7 @@ ## 1.0.0 -* Extractor configuration. In version 1.0 this is handled with the new `Data` mechanism for both setting and retrieving the configuration +- Extractor configuration. In version 1.0 this is handled with the new `Data` mechanism for both setting and retrieving the configuration instead of @@ -204,7 +219,7 @@ ) ``` -* Resource registration. 1.0 version uses generalized resource +- Resource registration. 1.0 version uses generalized resource registration via `.service()` method. instead of @@ -224,7 +239,7 @@ .route(web::post().to(post_handler)) ``` -* Scope registration. +- Scope registration. instead of @@ -248,7 +263,7 @@ ); ``` -* `.with()`, `.with_async()` registration methods have been renamed to `.to()` and `.to_async()`. +- `.with()`, `.with_async()` registration methods have been renamed to `.to()` and `.to_async()`. instead of @@ -262,7 +277,7 @@ App.new().service(web::resource("/welcome").to(welcome)) ``` -* Passing arguments to handler with extractors, multiple arguments are allowed +- Passing arguments to handler with extractors, multiple arguments are allowed instead of @@ -280,7 +295,7 @@ } ``` -* `.f()`, `.a()` and `.h()` handler registration methods have been removed. +- `.f()`, `.a()` and `.h()` handler registration methods have been removed. Use `.to()` for handlers and `.to_async()` for async handlers. Handler function must use extractors. @@ -296,7 +311,7 @@ App.new().service(web::resource("/welcome").to(welcome)) ``` -* `HttpRequest` does not provide access to request's payload stream. +- `HttpRequest` does not provide access to request's payload stream. instead of @@ -326,7 +341,7 @@ } ``` -* `State` is now `Data`. You register Data during the App initialization process +- `State` is now `Data`. You register Data during the App initialization process and then access it from handlers either using a Data extractor or using HttpRequest's api. @@ -362,7 +377,7 @@ ``` -* AsyncResponder is removed, use `.to_async()` registration method and `impl Future<>` as result type. +- AsyncResponder is removed, use `.to_async()` registration method and `impl Future<>` as result type. instead of @@ -378,7 +393,7 @@ .. simply omit AsyncResponder and the corresponding responder() finish method -* Middleware +- Middleware instead of @@ -395,7 +410,7 @@ .route("/index.html", web::get().to(index)); ``` -* `HttpRequest::body()`, `HttpRequest::urlencoded()`, `HttpRequest::json()`, `HttpRequest::multipart()` +- `HttpRequest::body()`, `HttpRequest::urlencoded()`, `HttpRequest::json()`, `HttpRequest::multipart()` method have been removed. Use `Bytes`, `String`, `Form`, `Json`, `Multipart` extractors instead. instead of @@ -417,9 +432,9 @@ } ``` -* `actix_web::server` module has been removed. To start http server use `actix_web::HttpServer` type +- `actix_web::server` module has been removed. To start http server use `actix_web::HttpServer` type -* StaticFiles and NamedFile have been moved to a separate crate. +- StaticFiles and NamedFile have been moved to a separate crate. instead of `use actix_web::fs::StaticFile` @@ -429,20 +444,20 @@ use `use actix_files::NamedFile` -* Multipart has been moved to a separate crate. +- Multipart has been moved to a separate crate. instead of `use actix_web::multipart::Multipart` use `use actix_multipart::Multipart` -* Response compression is not enabled by default. +- Response compression is not enabled by default. To enable, use `Compress` middleware, `App::new().wrap(Compress::default())`. -* Session middleware moved to actix-session crate +- Session middleware moved to actix-session crate -* Actors support have been moved to `actix-web-actors` crate +- Actors support have been moved to `actix-web-actors` crate -* Custom Error +- Custom Error Instead of error_response method alone, ResponseError now provides two methods: error_response and render_response respectively. Where, error_response creates the error response and render_response returns the error response to the caller. @@ -456,7 +471,7 @@ ## 0.7.15 -* The `' '` character is not percent decoded anymore before matching routes. If you need to use it in +- The `' '` character is not percent decoded anymore before matching routes. If you need to use it in your routes, you should use `%20`. instead of @@ -481,18 +496,18 @@ } ``` -* If you used `AsyncResult::async` you need to replace it with `AsyncResult::future` +- If you used `AsyncResult::async` you need to replace it with `AsyncResult::future` ## 0.7.4 -* `Route::with_config()`/`Route::with_async_config()` always passes configuration objects as tuple +- `Route::with_config()`/`Route::with_async_config()` always passes configuration objects as tuple even for handler with one parameter. ## 0.7 -* `HttpRequest` does not implement `Stream` anymore. If you need to read request payload +- `HttpRequest` does not implement `Stream` anymore. If you need to read request payload use `HttpMessage::payload()` method. instead of @@ -518,10 +533,10 @@ } ``` -* [Middleware](https://actix.rs/actix-web/actix_web/middleware/trait.Middleware.html) +- [Middleware](https://actix.rs/actix-web/actix_web/middleware/trait.Middleware.html) trait uses `&HttpRequest` instead of `&mut HttpRequest`. -* Removed `Route::with2()` and `Route::with3()` use tuple of extractors instead. +- Removed `Route::with2()` and `Route::with3()` use tuple of extractors instead. instead of @@ -535,17 +550,17 @@ fn index((query, json): (Query<..>, Json impl Responder {} ``` -* `Handler::handle()` uses `&self` instead of `&mut self` +- `Handler::handle()` uses `&self` instead of `&mut self` -* `Handler::handle()` accepts reference to `HttpRequest<_>` instead of value +- `Handler::handle()` accepts reference to `HttpRequest<_>` instead of value -* Removed deprecated `HttpServer::threads()`, use +- Removed deprecated `HttpServer::threads()`, use [HttpServer::workers()](https://actix.rs/actix-web/actix_web/server/struct.HttpServer.html#method.workers) instead. -* Renamed `client::ClientConnectorError::Connector` to +- Renamed `client::ClientConnectorError::Connector` to `client::ClientConnectorError::Resolver` -* `Route::with()` does not return `ExtractorConfig`, to configure +- `Route::with()` does not return `ExtractorConfig`, to configure extractor use `Route::with_config()` instead of @@ -574,26 +589,26 @@ } ``` -* `Route::with_async()` does not return `ExtractorConfig`, to configure +- `Route::with_async()` does not return `ExtractorConfig`, to configure extractor use `Route::with_async_config()` ## 0.6 -* `Path` extractor return `ErrorNotFound` on failure instead of `ErrorBadRequest` +- `Path` extractor return `ErrorNotFound` on failure instead of `ErrorBadRequest` -* `ws::Message::Close` now includes optional close reason. +- `ws::Message::Close` now includes optional close reason. `ws::CloseCode::Status` and `ws::CloseCode::Empty` have been removed. -* `HttpServer::threads()` renamed to `HttpServer::workers()`. +- `HttpServer::threads()` renamed to `HttpServer::workers()`. -* `HttpServer::start_ssl()` and `HttpServer::start_tls()` deprecated. +- `HttpServer::start_ssl()` and `HttpServer::start_tls()` deprecated. Use `HttpServer::bind_ssl()` and `HttpServer::bind_tls()` instead. -* `HttpRequest::extensions()` returns read only reference to the request's Extension +- `HttpRequest::extensions()` returns read only reference to the request's Extension `HttpRequest::extensions_mut()` returns mutable reference. -* Instead of +- Instead of `use actix_web::middleware::{ CookieSessionBackend, CookieSessionError, RequestSession, @@ -604,15 +619,15 @@ `use actix_web::middleware::session{CookieSessionBackend, CookieSessionError, RequestSession, Session, SessionBackend, SessionImpl, SessionStorage};` -* `FromRequest::from_request()` accepts mutable reference to a request +- `FromRequest::from_request()` accepts mutable reference to a request -* `FromRequest::Result` has to implement `Into>` +- `FromRequest::Result` has to implement `Into>` -* [`Responder::respond_to()`]( +- [`Responder::respond_to()`]( https://actix.rs/actix-web/actix_web/trait.Responder.html#tymethod.respond_to) is generic over `S` -* Use `Query` extractor instead of HttpRequest::query()`. +- Use `Query` extractor instead of HttpRequest::query()`. ```rust fn index(q: Query>) -> Result<..> { @@ -626,37 +641,37 @@ let q = Query::>::extract(req); ``` -* Websocket operations are implemented as `WsWriter` trait. +- Websocket operations are implemented as `WsWriter` trait. you need to use `use actix_web::ws::WsWriter` ## 0.5 -* `HttpResponseBuilder::body()`, `.finish()`, `.json()` +- `HttpResponseBuilder::body()`, `.finish()`, `.json()` methods return `HttpResponse` instead of `Result` -* `actix_web::Method`, `actix_web::StatusCode`, `actix_web::Version` +- `actix_web::Method`, `actix_web::StatusCode`, `actix_web::Version` moved to `actix_web::http` module -* `actix_web::header` moved to `actix_web::http::header` +- `actix_web::header` moved to `actix_web::http::header` -* `NormalizePath` moved to `actix_web::http` module +- `NormalizePath` moved to `actix_web::http` module -* `HttpServer` moved to `actix_web::server`, added new `actix_web::server::new()` function, +- `HttpServer` moved to `actix_web::server`, added new `actix_web::server::new()` function, shortcut for `actix_web::server::HttpServer::new()` -* `DefaultHeaders` middleware does not use separate builder, all builder methods moved to type itself +- `DefaultHeaders` middleware does not use separate builder, all builder methods moved to type itself -* `StaticFiles::new()`'s show_index parameter removed, use `show_files_listing()` method instead. +- `StaticFiles::new()`'s show_index parameter removed, use `show_files_listing()` method instead. -* `CookieSessionBackendBuilder` removed, all methods moved to `CookieSessionBackend` type +- `CookieSessionBackendBuilder` removed, all methods moved to `CookieSessionBackend` type -* `actix_web::httpcodes` module is deprecated, `HttpResponse::Ok()`, `HttpResponse::Found()` and other `HttpResponse::XXX()` +- `actix_web::httpcodes` module is deprecated, `HttpResponse::Ok()`, `HttpResponse::Found()` and other `HttpResponse::XXX()` functions should be used instead -* `ClientRequestBuilder::body()` returns `Result<_, actix_web::Error>` +- `ClientRequestBuilder::body()` returns `Result<_, actix_web::Error>` instead of `Result<_, http::Error>` -* `Application` renamed to a `App` +- `Application` renamed to a `App` -* `actix_web::Reply`, `actix_web::Resource` moved to `actix_web::dev` +- `actix_web::Reply`, `actix_web::Resource` moved to `actix_web::dev` diff --git a/README.md b/README.md index cc7c4cd52..5c5a55743 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@

[![crates.io](https://img.shields.io/crates/v/actix-web?label=latest)](https://crates.io/crates/actix-web) -[![Documentation](https://docs.rs/actix-web/badge.svg?version=4.0.0-beta.2)](https://docs.rs/actix-web/4.0.0-beta.2) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web/badge.svg?version=4.0.0-beta.19)](https://docs.rs/actix-web/4.0.0-beta.19) +![MSRV](https://img.shields.io/badge/rustc-1.54+-ab6000.svg) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-web.svg) -[![Dependency Status](https://deps.rs/crate/actix-web/4.0.0-beta.2/status.svg)](https://deps.rs/crate/actix-web/4.0.0-beta.2) +[![Dependency Status](https://deps.rs/crate/actix-web/4.0.0-beta.19/status.svg)](https://deps.rs/crate/actix-web/4.0.0-beta.19)
-[![build status](https://github.com/actix/actix-web/workflows/CI%20%28Linux%29/badge.svg?branch=master&event=push)](https://github.com/actix/actix-web/actions) -[![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) +[![CI](https://github.com/actix/actix-web/actions/workflows/ci.yml/badge.svg)](https://github.com/actix/actix-web/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) ![downloads](https://img.shields.io/crates/d/actix-web.svg) [![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) @@ -21,25 +21,25 @@ ## Features -* Supports *HTTP/1.x* and *HTTP/2* -* Streaming and pipelining -* Keep-alive and slow requests handling -* Client/server [WebSockets](https://actix.rs/docs/websockets/) support -* Transparent content compression/decompression (br, gzip, deflate) -* Powerful [request routing](https://actix.rs/docs/url-dispatch/) -* Multipart streams -* Static assets -* SSL support using OpenSSL or Rustls -* Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) -* Includes an async [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) -* Runs on stable Rust 1.46+ +- Supports *HTTP/1.x* and *HTTP/2* +- Streaming and pipelining +- Keep-alive and slow requests handling +- Client/server [WebSockets](https://actix.rs/docs/websockets/) support +- Transparent content compression/decompression (br, gzip, deflate, zstd) +- Powerful [request routing](https://actix.rs/docs/url-dispatch/) +- Multipart streams +- Static assets +- SSL support using OpenSSL or Rustls +- Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) +- Includes an async [HTTP client](https://docs.rs/awc/) +- Runs on stable Rust 1.54+ ## Documentation -* [Website & User Guide](https://actix.rs) -* [Examples Repository](https://github.com/actix/examples) -* [API Documentation](https://docs.rs/actix-web) -* [API Documentation (master branch)](https://actix.rs/actix-web/actix_web) +- [Website & User Guide](https://actix.rs) +- [Examples Repository](https://github.com/actix/examples) +- [API Documentation](https://docs.rs/actix-web) +- [API Documentation (master branch)](https://actix.rs/actix-web/actix_web) ## Example @@ -71,18 +71,18 @@ async fn main() -> std::io::Result<()> { ### More examples -* [Basic Setup](https://github.com/actix/examples/tree/master/basics/) -* [Application State](https://github.com/actix/examples/tree/master/state/) -* [JSON Handling](https://github.com/actix/examples/tree/master/json/) -* [Multipart Streams](https://github.com/actix/examples/tree/master/multipart/) -* [Diesel Integration](https://github.com/actix/examples/tree/master/diesel/) -* [r2d2 Integration](https://github.com/actix/examples/tree/master/r2d2/) -* [Simple WebSocket](https://github.com/actix/examples/tree/master/websocket/) -* [Tera Templates](https://github.com/actix/examples/tree/master/template_tera/) -* [Askama Templates](https://github.com/actix/examples/tree/master/template_askama/) -* [HTTPS using Rustls](https://github.com/actix/examples/tree/master/rustls/) -* [HTTPS using OpenSSL](https://github.com/actix/examples/tree/master/openssl/) -* [WebSocket Chat](https://github.com/actix/examples/tree/master/websocket-chat/) +- [Basic Setup](https://github.com/actix/examples/tree/master/basics/basics/) +- [Application State](https://github.com/actix/examples/tree/master/basics/state/) +- [JSON Handling](https://github.com/actix/examples/tree/master/json/json/) +- [Multipart Streams](https://github.com/actix/examples/tree/master/forms/multipart/) +- [Diesel Integration](https://github.com/actix/examples/tree/master/database_interactions/diesel/) +- [r2d2 Integration](https://github.com/actix/examples/tree/master/database_interactions/r2d2/) +- [Simple WebSocket](https://github.com/actix/examples/tree/master/websockets/websocket/) +- [Tera Templates](https://github.com/actix/examples/tree/master/template_engines/tera/) +- [Askama Templates](https://github.com/actix/examples/tree/master/template_engines/askama/) +- [HTTPS using Rustls](https://github.com/actix/examples/tree/master/security/rustls/) +- [HTTPS using OpenSSL](https://github.com/actix/examples/tree/master/security/openssl/) +- [WebSocket Chat](https://github.com/actix/examples/tree/master/websockets/chat/) You may consider checking out [this directory](https://github.com/actix/examples/tree/master/) for more examples. @@ -90,15 +90,15 @@ You may consider checking out ## Benchmarks One of the fastest web frameworks available according to the -[TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r19). +[TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r20&test=composite). ## License This project is licensed under either of -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0]) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or +- MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT]) at your option. diff --git a/actix-files/CHANGES.md b/actix-files/CHANGES.md index 6e2a241ac..e8a07d884 100644 --- a/actix-files/CHANGES.md +++ b/actix-files/CHANGES.md @@ -3,121 +3,191 @@ ## Unreleased - 2021-xx-xx +## 0.6.0-beta.13 - 2022-01-04 +- The `Files` service now rejects requests with URL paths that include `%2F` (decoded: `/`). [#2398] +- The `Files` service now correctly decodes `%25` in the URL path to `%` for the file path. [#2398] +- Minimum supported Rust version (MSRV) is now 1.54. + +[#2398]: https://github.com/actix/actix-web/pull/2398 + + +## 0.6.0-beta.12 - 2021-12-29 +- No significant changes since `0.6.0-beta.11`. + + +## 0.6.0-beta.11 - 2021-12-27 +- No significant changes since `0.6.0-beta.10`. + + +## 0.6.0-beta.10 - 2021-12-11 +- No significant changes since `0.6.0-beta.9`. + + +## 0.6.0-beta.9 - 2021-11-22 +- Add crate feature `experimental-io-uring`, enabling async file I/O to be utilized. This feature is only available on Linux OSes with recent kernel versions. This feature is semver-exempt. [#2408] +- Add `NamedFile::open_async`. [#2408] +- Fix 304 Not Modified responses to omit the Content-Length header, as per the spec. [#2453] +- The `Responder` impl for `NamedFile` now has a boxed future associated type. [#2408] +- The `Service` impl for `NamedFileService` now has a boxed future associated type. [#2408] +- Add `impl Clone` for `FilesService`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 +[#2453]: https://github.com/actix/actix-web/pull/2453 + + +## 0.6.0-beta.8 - 2021-10-20 +- Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.6.0-beta.7 - 2021-09-09 +- Minimum supported Rust version (MSRV) is now 1.51. + + +## 0.6.0-beta.6 - 2021-06-26 +- Added `Files::path_filter()`. [#2274] +- `Files::show_files_listing()` can now be used with `Files::index_file()` to show files listing as a fallback when the index file is not found. [#2228] + +[#2274]: https://github.com/actix/actix-web/pull/2274 +[#2228]: https://github.com/actix/actix-web/pull/2228 + + +## 0.6.0-beta.5 - 2021-06-17 +- `NamedFile` now implements `ServiceFactory` and `HttpServiceFactory` making it much more useful in routing. For example, it can be used directly as a default service. [#2135] +- For symbolic links, `Content-Disposition` header no longer shows the filename of the original file. [#2156] +- `Files::redirect_to_slash_directory()` now works as expected when used with `Files::show_files_listing()`. [#2225] +- `application/{javascript, json, wasm}` mime type now have `inline` disposition by default. [#2257] + +[#2135]: https://github.com/actix/actix-web/pull/2135 +[#2156]: https://github.com/actix/actix-web/pull/2156 +[#2225]: https://github.com/actix/actix-web/pull/2225 +[#2257]: https://github.com/actix/actix-web/pull/2257 + + +## 0.6.0-beta.4 - 2021-04-02 +- Add support for `.guard` in `Files` to selectively filter `Files` services. [#2046] + +[#2046]: https://github.com/actix/actix-web/pull/2046 + + +## 0.6.0-beta.3 - 2021-03-09 +- No notable changes. + + ## 0.6.0-beta.2 - 2021-02-10 -* Fix If-Modified-Since and If-Unmodified-Since to not compare using sub-second timestamps. [#1887] -* Replace `v_htmlescape` with `askama_escape`. [#1953] +- Fix If-Modified-Since and If-Unmodified-Since to not compare using sub-second timestamps. [#1887] +- Replace `v_htmlescape` with `askama_escape`. [#1953] [#1887]: https://github.com/actix/actix-web/pull/1887 [#1953]: https://github.com/actix/actix-web/pull/1953 ## 0.6.0-beta.1 - 2021-01-07 -* `HttpRange::parse` now has its own error type. -* Update `bytes` to `1.0`. [#1813] +- `HttpRange::parse` now has its own error type. +- Update `bytes` to `1.0`. [#1813] [#1813]: https://github.com/actix/actix-web/pull/1813 ## 0.5.0 - 2020-12-26 -* Optionally support hidden files/directories. [#1811] +- Optionally support hidden files/directories. [#1811] [#1811]: https://github.com/actix/actix-web/pull/1811 ## 0.4.1 - 2020-11-24 -* Clarify order of parameters in `Files::new` and improve docs. +- Clarify order of parameters in `Files::new` and improve docs. ## 0.4.0 - 2020-10-06 -* Add `Files::prefer_utf8` option that adds UTF-8 charset on certain response types. [#1714] +- Add `Files::prefer_utf8` option that adds UTF-8 charset on certain response types. [#1714] [#1714]: https://github.com/actix/actix-web/pull/1714 ## 0.3.0 - 2020-09-11 -* No significant changes from 0.3.0-beta.1. +- No significant changes from 0.3.0-beta.1. ## 0.3.0-beta.1 - 2020-07-15 -* Update `v_htmlescape` to 0.10 -* Update `actix-web` and `actix-http` dependencies to beta.1 +- Update `v_htmlescape` to 0.10 +- Update `actix-web` and `actix-http` dependencies to beta.1 ## 0.3.0-alpha.1 - 2020-05-23 -* Update `actix-web` and `actix-http` dependencies to alpha -* Fix some typos in the docs -* Bump minimum supported Rust version to 1.40 -* Support sending Content-Length when Content-Range is specified [#1384] +- Update `actix-web` and `actix-http` dependencies to alpha +- Fix some typos in the docs +- Bump minimum supported Rust version to 1.40 +- Support sending Content-Length when Content-Range is specified [#1384] [#1384]: https://github.com/actix/actix-web/pull/1384 ## 0.2.1 - 2019-12-22 -* Use the same format for file URLs regardless of platforms +- Use the same format for file URLs regardless of platforms ## 0.2.0 - 2019-12-20 -* Fix BodyEncoding trait import #1220 +- Fix BodyEncoding trait import #1220 ## 0.2.0-alpha.1 - 2019-12-07 -* Migrate to `std::future` +- Migrate to `std::future` ## 0.1.7 - 2019-11-06 -* Add an additional `filename*` param in the `Content-Disposition` header of +- Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) ## 0.1.6 - 2019-10-14 -* Add option to redirect to a slash-ended path `Files` #1132 +- Add option to redirect to a slash-ended path `Files` #1132 ## 0.1.5 - 2019-10-08 -* Bump up `mime_guess` crate version to 2.0.1 -* Bump up `percent-encoding` crate version to 2.1 -* Allow user defined request guards for `Files` #1113 +- Bump up `mime_guess` crate version to 2.0.1 +- Bump up `percent-encoding` crate version to 2.1 +- Allow user defined request guards for `Files` #1113 ## 0.1.4 - 2019-07-20 -* Allow to disable `Content-Disposition` header #686 +- Allow to disable `Content-Disposition` header #686 ## 0.1.3 - 2019-06-28 -* Do not set `Content-Length` header, let actix-http set it #930 +- Do not set `Content-Length` header, let actix-http set it #930 ## 0.1.2 - 2019-06-13 -* Content-Length is 0 for NamedFile HEAD request #914 -* Fix ring dependency from actix-web default features for #741 +- Content-Length is 0 for NamedFile HEAD request #914 +- Fix ring dependency from actix-web default features for #741 ## 0.1.1 - 2019-06-01 -* Static files are incorrectly served as both chunked and with length #812 +- Static files are incorrectly served as both chunked and with length #812 ## 0.1.0 - 2019-05-25 -* NamedFile last-modified check always fails due to nano-seconds in file modified date #820 +- NamedFile last-modified check always fails due to nano-seconds in file modified date #820 ## 0.1.0-beta.4 - 2019-05-12 -* Update actix-web to beta.4 +- Update actix-web to beta.4 ## 0.1.0-beta.1 - 2019-04-20 -* Update actix-web to beta.1 +- Update actix-web to beta.1 ## 0.1.0-alpha.6 - 2019-04-14 -* Update actix-web to alpha6 +- Update actix-web to alpha6 ## 0.1.0-alpha.4 - 2019-04-08 -* Update actix-web to alpha4 +- Update actix-web to alpha4 ## 0.1.0-alpha.2 - 2019-04-02 -* Add default handler support +- Add default handler support ## 0.1.0-alpha.1 - 2019-03-28 -* Initial impl +- Initial impl diff --git a/actix-files/Cargo.toml b/actix-files/Cargo.toml index 45fa18a69..fe780f5ab 100644 --- a/actix-files/Cargo.toml +++ b/actix-files/Cargo.toml @@ -1,13 +1,15 @@ [package] name = "actix-files" -version = "0.6.0-beta.2" -authors = ["Nikolay Kim "] +version = "0.6.0-beta.13" +authors = [ + "Nikolay Kim ", + "fakeshadow <24548779@qq.com>", + "Rob Ede ", +] description = "Static file serving for Actix Web" -readme = "README.md" keywords = ["actix", "http", "async", "futures"] homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-files/" +repository = "https://github.com/actix/actix-web" categories = ["asynchronous", "web-programming::http-server"] license = "MIT OR Apache-2.0" edition = "2018" @@ -16,21 +18,31 @@ edition = "2018" name = "actix_files" path = "src/lib.rs" +[features] +experimental-io-uring = ["actix-web/experimental-io-uring", "tokio-uring"] + [dependencies] -actix-web = { version = "4.0.0-beta.3", default-features = false } -actix-service = "2.0.0-beta.4" +actix-http = "3.0.0-beta.18" +actix-service = "2" +actix-utils = "3" +actix-web = { version = "4.0.0-beta.19", default-features = false } askama_escape = "0.10" bitflags = "1" bytes = "1" -futures-core = { version = "0.3.7", default-features = false } -futures-util = { version = "0.3.7", default-features = false } derive_more = "0.99.5" +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } +http-range = "0.1.4" log = "0.4" mime = "0.3" mime_guess = "2.0.1" percent-encoding = "2.1" +pin-project-lite = "0.2.7" + +tokio-uring = { version = "0.1", optional = true } [dev-dependencies] -actix-rt = "2" -actix-web = "4.0.0-beta.3" +actix-rt = "2.2" +actix-test = "0.1.0-beta.11" +actix-web = "4.0.0-beta.19" +tempfile = "3.2" diff --git a/actix-files/README.md b/actix-files/README.md index 463f20224..be878d958 100644 --- a/actix-files/README.md +++ b/actix-files/README.md @@ -3,17 +3,16 @@ > Static file serving for Actix Web [![crates.io](https://img.shields.io/crates/v/actix-files?label=latest)](https://crates.io/crates/actix-files) -[![Documentation](https://docs.rs/actix-files/badge.svg?version=0.5.0)](https://docs.rs/actix-files/0.5.0) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-files/badge.svg?version=0.6.0-beta.13)](https://docs.rs/actix-files/0.6.0-beta.13) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![License](https://img.shields.io/crates/l/actix-files.svg)
-[![dependency status](https://deps.rs/crate/actix-files/0.5.0/status.svg)](https://deps.rs/crate/actix-files/0.5.0) +[![dependency status](https://deps.rs/crate/actix-files/0.6.0-beta.13/status.svg)](https://deps.rs/crate/actix-files/0.6.0-beta.13) [![Download](https://img.shields.io/crates/d/actix-files.svg)](https://crates.io/crates/actix-files) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-files/) -- [Example Project](https://github.com/actix/examples/tree/master/static_index) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum supported Rust version: 1.46 or later +- [Example Project](https://github.com/actix/examples/tree/master/basics/static_index) +- Minimum Supported Rust Version (MSRV): 1.54 diff --git a/actix-files/src/chunked.rs b/actix-files/src/chunked.rs index f639848c9..68221ccc3 100644 --- a/actix-files/src/chunked.rs +++ b/actix-files/src/chunked.rs @@ -1,98 +1,277 @@ use std::{ cmp, fmt, - fs::File, future::Future, - io::{self, Read, Seek}, + io, pin::Pin, task::{Context, Poll}, }; -use actix_web::{ - error::{BlockingError, Error}, - rt::task::{spawn_blocking, JoinHandle}, -}; -use bytes::Bytes; +use actix_web::{error::Error, web::Bytes}; use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; -#[doc(hidden)] -/// A helper created from a `std::fs::File` which reads the file -/// chunk-by-chunk on a `ThreadPool`. -pub struct ChunkedReadFile { - size: u64, - offset: u64, - state: ChunkedReadFileState, - counter: u64, -} +use super::named::File; -enum ChunkedReadFileState { - File(Option), - Future(JoinHandle>), -} - -impl ChunkedReadFile { - pub(crate) fn new(size: u64, offset: u64, file: File) -> Self { - Self { - size, - offset, - state: ChunkedReadFileState::File(Some(file)), - counter: 0, - } +pin_project! { + /// Adapter to read a `std::file::File` in chunks. + #[doc(hidden)] + pub struct ChunkedReadFile { + size: u64, + offset: u64, + #[pin] + state: ChunkedReadFileState, + counter: u64, + callback: F, } } -impl fmt::Debug for ChunkedReadFile { +#[cfg(not(feature = "experimental-io-uring"))] +pin_project! { + #[project = ChunkedReadFileStateProj] + #[project_replace = ChunkedReadFileStateProjReplace] + enum ChunkedReadFileState { + File { file: Option, }, + Future { #[pin] fut: Fut }, + } +} + +#[cfg(feature = "experimental-io-uring")] +pin_project! { + #[project = ChunkedReadFileStateProj] + #[project_replace = ChunkedReadFileStateProjReplace] + enum ChunkedReadFileState { + File { file: Option<(File, BytesMut)> }, + Future { #[pin] fut: Fut }, + } +} + +impl fmt::Debug for ChunkedReadFile { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("ChunkedReadFile") } } -impl Stream for ChunkedReadFile { +pub(crate) fn new_chunked_read( + size: u64, + offset: u64, + file: File, +) -> impl Stream> { + ChunkedReadFile { + size, + offset, + #[cfg(not(feature = "experimental-io-uring"))] + state: ChunkedReadFileState::File { file: Some(file) }, + #[cfg(feature = "experimental-io-uring")] + state: ChunkedReadFileState::File { + file: Some((file, BytesMut::new())), + }, + counter: 0, + callback: chunked_read_file_callback, + } +} + +#[cfg(not(feature = "experimental-io-uring"))] +async fn chunked_read_file_callback( + mut file: File, + offset: u64, + max_bytes: usize, +) -> Result<(File, Bytes), Error> { + use io::{Read as _, Seek as _}; + + let res = actix_web::rt::task::spawn_blocking(move || { + let mut buf = Vec::with_capacity(max_bytes); + + file.seek(io::SeekFrom::Start(offset))?; + + let n_bytes = file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + + if n_bytes == 0 { + Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + Ok((file, Bytes::from(buf))) + } + }) + .await + .map_err(|_| actix_web::error::BlockingError)??; + + Ok(res) +} + +#[cfg(feature = "experimental-io-uring")] +async fn chunked_read_file_callback( + file: File, + offset: u64, + max_bytes: usize, + mut bytes_mut: BytesMut, +) -> io::Result<(File, Bytes, BytesMut)> { + bytes_mut.reserve(max_bytes); + + let (res, mut bytes_mut) = file.read_at(bytes_mut, offset).await; + let n_bytes = res?; + + if n_bytes == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + + let bytes = bytes_mut.split_to(n_bytes).freeze(); + + Ok((file, bytes, bytes_mut)) +} + +#[cfg(feature = "experimental-io-uring")] +impl Stream for ChunkedReadFile +where + F: Fn(File, u64, usize, BytesMut) -> Fut, + Fut: Future>, +{ type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.as_mut().get_mut(); - match this.state { - ChunkedReadFileState::File(ref mut file) => { - let size = this.size; - let offset = this.offset; - let counter = this.counter; + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + ChunkedReadFileStateProj::File { file } => { + let size = *this.size; + let offset = *this.offset; + let counter = *this.counter; if size == counter { Poll::Ready(None) } else { - let mut file = file + let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + + let (file, bytes_mut) = file .take() .expect("ChunkedReadFile polled after completion"); - let fut = spawn_blocking(move || { - let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + let fut = (this.callback)(file, offset, max_bytes, bytes_mut); - let mut buf = Vec::with_capacity(max_bytes); - file.seek(io::SeekFrom::Start(offset))?; + this.state + .project_replace(ChunkedReadFileState::Future { fut }); - let n_bytes = - file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; - - if n_bytes == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - - Ok((file, Bytes::from(buf))) - }); - this.state = ChunkedReadFileState::Future(fut); self.poll_next(cx) } } - ChunkedReadFileState::Future(ref mut fut) => { - let (file, bytes) = - ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; - this.state = ChunkedReadFileState::File(Some(file)); + ChunkedReadFileStateProj::Future { fut } => { + let (file, bytes, bytes_mut) = ready!(fut.poll(cx))?; - this.offset += bytes.len() as u64; - this.counter += bytes.len() as u64; + this.state.project_replace(ChunkedReadFileState::File { + file: Some((file, bytes_mut)), + }); + + *this.offset += bytes.len() as u64; + *this.counter += bytes.len() as u64; Poll::Ready(Some(Ok(bytes))) } } } } + +#[cfg(not(feature = "experimental-io-uring"))] +impl Stream for ChunkedReadFile +where + F: Fn(File, u64, usize) -> Fut, + Fut: Future>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + ChunkedReadFileStateProj::File { file } => { + let size = *this.size; + let offset = *this.offset; + let counter = *this.counter; + + if size == counter { + Poll::Ready(None) + } else { + let max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; + + let file = file + .take() + .expect("ChunkedReadFile polled after completion"); + + let fut = (this.callback)(file, offset, max_bytes); + + this.state + .project_replace(ChunkedReadFileState::Future { fut }); + + self.poll_next(cx) + } + } + ChunkedReadFileStateProj::Future { fut } => { + let (file, bytes) = ready!(fut.poll(cx))?; + + this.state + .project_replace(ChunkedReadFileState::File { file: Some(file) }); + + *this.offset += bytes.len() as u64; + *this.counter += bytes.len() as u64; + + Poll::Ready(Some(Ok(bytes))) + } + } + } +} + +#[cfg(feature = "experimental-io-uring")] +use bytes_mut::BytesMut; + +// TODO: remove new type and use bytes::BytesMut directly +#[doc(hidden)] +#[cfg(feature = "experimental-io-uring")] +mod bytes_mut { + use std::ops::{Deref, DerefMut}; + + use tokio_uring::buf::{IoBuf, IoBufMut}; + + #[derive(Debug)] + pub struct BytesMut(bytes::BytesMut); + + impl BytesMut { + pub(super) fn new() -> Self { + Self(bytes::BytesMut::new()) + } + } + + impl Deref for BytesMut { + type Target = bytes::BytesMut; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for BytesMut { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + unsafe impl IoBuf for BytesMut { + fn stable_ptr(&self) -> *const u8 { + self.0.as_ptr() + } + + fn bytes_init(&self) -> usize { + self.0.len() + } + + fn bytes_total(&self) -> usize { + self.0.capacity() + } + } + + unsafe impl IoBufMut for BytesMut { + fn stable_mut_ptr(&mut self) -> *mut u8 { + self.0.as_mut_ptr() + } + + unsafe fn set_init(&mut self, init_len: usize) { + if self.len() < init_len { + self.0.set_len(init_len); + } + } + } +} diff --git a/actix-files/src/directory.rs b/actix-files/src/directory.rs index 80e0c98d0..26225ea5c 100644 --- a/actix-files/src/directory.rs +++ b/actix-files/src/directory.rs @@ -40,14 +40,23 @@ impl Directory { pub(crate) type DirectoryRenderer = dyn Fn(&Directory, &HttpRequest) -> Result; -// show file url as relative to static path +/// Returns percent encoded file URL path. macro_rules! encode_file_url { ($path:ident) => { utf8_percent_encode(&$path, CONTROLS) }; } -// " -- " & -- & ' -- ' < -- < > -- > / -- / +/// Returns HTML entity encoded formatter. +/// +/// ```plain +/// " => " +/// & => & +/// ' => ' +/// < => < +/// > => > +/// / => / +/// ``` macro_rules! encode_file_name { ($entry:ident) => { escape_html_entity(&$entry.file_name().to_string_lossy(), Html) diff --git a/actix-files/src/error.rs b/actix-files/src/error.rs index 9b30cbaa2..6682529f8 100644 --- a/actix-files/src/error.rs +++ b/actix-files/src/error.rs @@ -1,4 +1,4 @@ -use actix_web::{http::StatusCode, HttpResponse, ResponseError}; +use actix_web::{http::StatusCode, ResponseError}; use derive_more::Display; /// Errors which can occur when serving static files. @@ -16,22 +16,30 @@ pub enum FilesError { /// Return `NotFound` for `FilesError` impl ResponseError for FilesError { - fn error_response(&self) -> HttpResponse { - HttpResponse::new(StatusCode::NOT_FOUND) + fn status_code(&self) -> StatusCode { + StatusCode::NOT_FOUND } } +#[allow(clippy::enum_variant_names)] #[derive(Display, Debug, PartialEq)] +#[non_exhaustive] pub enum UriSegmentError { /// The segment started with the wrapped invalid character. #[display(fmt = "The segment started with the wrapped invalid character")] BadStart(char), + /// The segment contained the wrapped invalid character. #[display(fmt = "The segment contained the wrapped invalid character")] BadChar(char), + /// The segment ended with the wrapped invalid character. #[display(fmt = "The segment ended with the wrapped invalid character")] BadEnd(char), + + /// The path is not a valid UTF-8 string after doing percent decoding. + #[display(fmt = "The path is not a valid UTF-8 string after percent-decoding")] + NotValidUtf8, } /// Return `BadRequest` for `UriSegmentError` diff --git a/actix-files/src/files.rs b/actix-files/src/files.rs index 6f8b28bbf..adfb93232 100644 --- a/actix-files/src/files.rs +++ b/actix-files/src/files.rs @@ -1,25 +1,35 @@ -use std::{cell::RefCell, fmt, io, path::PathBuf, rc::Rc}; +use std::{ + cell::RefCell, + fmt, io, + path::{Path, PathBuf}, + rc::Rc, +}; use actix_service::{boxed, IntoServiceFactory, ServiceFactory, ServiceFactoryExt}; use actix_web::{ - dev::{AppService, HttpServiceFactory, ResourceDef, ServiceRequest, ServiceResponse}, + dev::{ + AppService, HttpServiceFactory, RequestHead, ResourceDef, ServiceRequest, + ServiceResponse, + }, error::Error, guard::Guard, http::header::DispositionType, HttpRequest, }; -use futures_util::future::{ok, FutureExt, LocalBoxFuture}; +use futures_core::future::LocalBoxFuture; use crate::{ - directory_listing, named, Directory, DirectoryRenderer, FilesService, HttpNewService, - MimeOverride, + directory_listing, named, + service::{FilesService, FilesServiceInner}, + Directory, DirectoryRenderer, HttpNewService, MimeOverride, PathFilter, }; /// Static files handling service. /// /// `Files` service must be registered with `App::service()` method. /// -/// ```rust +/// # Examples +/// ``` /// use actix_web::App; /// use actix_files::Files; /// @@ -35,8 +45,10 @@ pub struct Files { default: Rc>>>, renderer: Rc, mime_override: Option>, + path_filter: Option>, file_flags: named::Flags, - guards: Option>, + use_guards: Option>, + guards: Vec>, hidden_files: bool, } @@ -58,6 +70,8 @@ impl Clone for Files { file_flags: self.file_flags, path: self.path.clone(), mime_override: self.mime_override.clone(), + path_filter: self.path_filter.clone(), + use_guards: self.use_guards.clone(), guards: self.guards.clone(), hidden_files: self.hidden_files, } @@ -79,10 +93,9 @@ impl Files { /// If the mount path is set as the root path `/`, services registered after this one will /// be inaccessible. Register more specific handlers and services first. /// - /// `Files` uses a threadpool for blocking filesystem operations. By default, the pool uses a - /// max number of threads equal to `512 * HttpServer::worker`. Real time thread count are - /// adjusted with work load. More threads would spawn when need and threads goes idle for a - /// period of time would be de-spawned. + /// `Files` utilizes the existing Tokio thread-pool for blocking filesystem operations. + /// The number of running threads is adjusted over time as needed, up to a maximum of 512 times + /// the number of server [workers](actix_web::HttpServer::workers), by default. pub fn new>(mount_path: &str, serve_from: T) -> Files { let orig_dir = serve_from.into(); let dir = match orig_dir.canonicalize() { @@ -94,7 +107,7 @@ impl Files { }; Files { - path: mount_path.to_owned(), + path: mount_path.trim_end_matches('/').to_owned(), directory: dir, index: None, show_index: false, @@ -102,8 +115,10 @@ impl Files { default: Rc::new(RefCell::new(None)), renderer: Rc::new(directory_listing), mime_override: None, + path_filter: None, file_flags: named::Flags::default(), - guards: None, + use_guards: None, + guards: Vec::new(), hidden_files: false, } } @@ -111,6 +126,9 @@ impl Files { /// Show files listing for directories. /// /// By default show files listing is disabled. + /// + /// When used with [`Files::index_file()`], files listing is shown as a fallback + /// when the index file is not found. pub fn show_files_listing(mut self) -> Self { self.show_index = true; self @@ -143,10 +161,45 @@ impl Files { self } + /// Sets path filtering closure. + /// + /// The path provided to the closure is relative to `serve_from` path. + /// You can safely join this path with the `serve_from` path to get the real path. + /// However, the real path may not exist since the filter is called before checking path existence. + /// + /// When a path doesn't pass the filter, [`Files::default_handler`] is called if set, otherwise, + /// `404 Not Found` is returned. + /// + /// # Examples + /// ``` + /// use std::path::Path; + /// use actix_files::Files; + /// + /// // prevent searching subdirectories and following symlinks + /// let files_service = Files::new("/", "./static").path_filter(|path, _| { + /// path.components().count() == 1 + /// && Path::new("./static") + /// .join(path) + /// .symlink_metadata() + /// .map(|m| !m.file_type().is_symlink()) + /// .unwrap_or(false) + /// }); + /// ``` + pub fn path_filter(mut self, f: F) -> Self + where + F: Fn(&Path, &RequestHead) -> bool + 'static, + { + self.path_filter = Some(Rc::new(f)); + self + } + /// Set index file /// - /// Shows specific index file for directory "/" instead of + /// Shows specific index file for directories instead of /// showing files listing. + /// + /// If the index file is not found, files listing is shown as a fallback if + /// [`Files::show_files_listing()`] is set. pub fn index_file>(mut self, index: T) -> Self { self.index = Some(index.into()); self @@ -155,7 +208,6 @@ impl Files { /// Specifies whether to use ETag or not. /// /// Default is true. - #[inline] pub fn use_etag(mut self, value: bool) -> Self { self.file_flags.set(named::Flags::ETAG, value); self @@ -164,7 +216,6 @@ impl Files { /// Specifies whether to use Last-Modified or not. /// /// Default is true. - #[inline] pub fn use_last_modified(mut self, value: bool) -> Self { self.file_flags.set(named::Flags::LAST_MD, value); self @@ -173,31 +224,80 @@ impl Files { /// Specifies whether text responses should signal a UTF-8 encoding. /// /// Default is false (but will default to true in a future version). - #[inline] pub fn prefer_utf8(mut self, value: bool) -> Self { self.file_flags.set(named::Flags::PREFER_UTF8, value); self } - /// Specifies custom guards to use for directory listings and files. + /// Adds a routing guard. /// - /// Default behaviour allows GET and HEAD. - #[inline] - pub fn use_guards(mut self, guards: G) -> Self { - self.guards = Some(Rc::new(guards)); + /// Use this to allow multiple chained file services that respond to strictly different + /// properties of a request. Due to the way routing works, if a guard check returns true and the + /// request starts being handled by the file service, it will not be able to back-out and try + /// the next service, you will simply get a 404 (or 405) error response. + /// + /// To allow `POST` requests to retrieve files, see [`Files::use_guards`]. + /// + /// # Examples + /// ``` + /// use actix_web::{guard::Header, App}; + /// use actix_files::Files; + /// + /// App::new().service( + /// Files::new("/","/my/site/files") + /// .guard(Header("Host", "example.com")) + /// ); + /// ``` + pub fn guard(mut self, guard: G) -> Self { + self.guards.push(Rc::new(guard)); self } + /// Specifies guard to check before fetching directory listings or files. + /// + /// Note that this guard has no effect on routing; it's main use is to guard on the request's + /// method just before serving the file, only allowing `GET` and `HEAD` requests by default. + /// See [`Files::guard`] for routing guards. + pub fn method_guard(mut self, guard: G) -> Self { + self.use_guards = Some(Rc::new(guard)); + self + } + + /// See [`Files::method_guard`]. + #[doc(hidden)] + #[deprecated(since = "0.6.0", note = "Renamed to `method_guard`.")] + pub fn use_guards(self, guard: G) -> Self { + self.method_guard(guard) + } + /// Disable `Content-Disposition` header. /// /// By default Content-Disposition` header is enabled. - #[inline] pub fn disable_content_disposition(mut self) -> Self { self.file_flags.remove(named::Flags::CONTENT_DISPOSITION); self } /// Sets default handler which is used when no matched file could be found. + /// + /// # Examples + /// Setting a fallback static file handler: + /// ``` + /// use actix_files::{Files, NamedFile}; + /// use actix_web::dev::{ServiceRequest, ServiceResponse, fn_service}; + /// + /// # fn run() -> Result<(), actix_web::Error> { + /// let files = Files::new("/", "./static") + /// .index_file("index.html") + /// .default_handler(fn_service(|req: ServiceRequest| async { + /// let (req, _) = req.into_parts(); + /// let file = NamedFile::open_async("./static/404.html").await?; + /// let res = file.into_response(&req); + /// Ok(ServiceResponse::new(req, res)) + /// })); + /// # Ok(()) + /// # } + /// ``` pub fn default_handler(mut self, f: F) -> Self where F: IntoServiceFactory, @@ -217,7 +317,6 @@ impl Files { } /// Enables serving hidden files and directories, allowing a leading dots in url fragments. - #[inline] pub fn use_hidden_files(mut self) -> Self { self.hidden_files = true; self @@ -225,7 +324,19 @@ impl Files { } impl HttpServiceFactory for Files { - fn register(self, config: &mut AppService) { + fn register(mut self, config: &mut AppService) { + let guards = if self.guards.is_empty() { + None + } else { + let guards = std::mem::take(&mut self.guards); + Some( + guards + .into_iter() + .map(|guard| -> Box { Box::new(guard) }) + .collect::>(), + ) + }; + if self.default.borrow().is_none() { *self.default.borrow_mut() = Some(config.default_service()); } @@ -236,7 +347,7 @@ impl HttpServiceFactory for Files { ResourceDef::prefix(&self.path) }; - config.register_service(rdef, None, self, None) + config.register_service(rdef, guards, self, None) } } @@ -249,7 +360,7 @@ impl ServiceFactory for Files { type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - let mut srv = FilesService { + let mut inner = FilesServiceInner { directory: self.directory.clone(), index: self.index.clone(), show_index: self.show_index, @@ -257,24 +368,25 @@ impl ServiceFactory for Files { default: None, renderer: self.renderer.clone(), mime_override: self.mime_override.clone(), + path_filter: self.path_filter.clone(), file_flags: self.file_flags, - guards: self.guards.clone(), + guards: self.use_guards.clone(), hidden_files: self.hidden_files, }; if let Some(ref default) = *self.default.borrow() { - default - .new_service(()) - .map(move |result| match result { + let fut = default.new_service(()); + Box::pin(async { + match fut.await { Ok(default) => { - srv.default = Some(default); - Ok(srv) + inner.default = Some(default); + Ok(FilesService(Rc::new(inner))) } Err(_) => Err(()), - }) - .boxed_local() + } + }) } else { - ok(srv).boxed_local() + Box::pin(async move { Ok(FilesService(Rc::new(inner))) }) } } } diff --git a/actix-files/src/lib.rs b/actix-files/src/lib.rs index 04dd9f07f..a11aa32c7 100644 --- a/actix-files/src/lib.rs +++ b/actix-files/src/lib.rs @@ -3,7 +3,7 @@ //! Provides a non-blocking service for serving static files from disk. //! //! # Example -//! ```rust +//! ``` //! use actix_web::App; //! use actix_files::Files; //! @@ -11,16 +11,17 @@ //! .service(Files::new("/static", ".").prefer_utf8(true)); //! ``` -#![deny(rust_2018_idioms)] -#![warn(missing_docs, missing_debug_implementations)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible, missing_docs, missing_debug_implementations)] use actix_service::boxed::{BoxService, BoxServiceFactory}; use actix_web::{ - dev::{ServiceRequest, ServiceResponse}, + dev::{RequestHead, ServiceRequest, ServiceResponse}, error::Error, http::header::DispositionType, }; use mime_guess::from_ext; +use std::path::Path; mod chunked; mod directory; @@ -32,12 +33,12 @@ mod path_buf; mod range; mod service; -pub use crate::chunked::ChunkedReadFile; -pub use crate::directory::Directory; -pub use crate::files::Files; -pub use crate::named::NamedFile; -pub use crate::range::HttpRange; -pub use crate::service::FilesService; +pub use self::chunked::ChunkedReadFile; +pub use self::directory::Directory; +pub use self::files::Files; +pub use self::named::NamedFile; +pub use self::range::HttpRange; +pub use self::service::FilesService; use self::directory::{directory_listing, DirectoryRenderer}; use self::error::FilesError; @@ -56,10 +57,12 @@ pub fn file_extension_to_mime(ext: &str) -> mime::Mime { type MimeOverride = dyn Fn(&mime::Name<'_>) -> DispositionType; +type PathFilter = dyn Fn(&Path, &RequestHead) -> bool; + #[cfg(test)] mod tests { use std::{ - fs::{self, File}, + fs::{self}, ops::Add, time::{Duration, SystemTime}, }; @@ -76,11 +79,11 @@ mod tests { web::{self, Bytes}, App, HttpResponse, Responder, }; - use futures_util::future::ok; use super::*; + use crate::named::File; - #[actix_rt::test] + #[actix_web::test] async fn test_file_extension_to_mime() { let m = file_extension_to_mime(""); assert_eq!(m, mime::APPLICATION_OCTET_STREAM); @@ -97,7 +100,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_without_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() @@ -109,7 +112,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_without_if_none_match_same() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = file.last_modified().unwrap(); let req = TestRequest::default() @@ -121,7 +124,7 @@ mod tests { #[actix_rt::test] async fn test_if_modified_since_with_if_none_match() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::now().add(Duration::from_secs(60))); let req = TestRequest::default() @@ -134,7 +137,7 @@ mod tests { #[actix_rt::test] async fn test_if_unmodified_since() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = file.last_modified().unwrap(); let req = TestRequest::default() @@ -146,7 +149,7 @@ mod tests { #[actix_rt::test] async fn test_if_unmodified_since_failed() { - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let since = header::HttpDate::from(SystemTime::UNIX_EPOCH); let req = TestRequest::default() @@ -158,8 +161,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_text() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); + assert!(NamedFile::open_async("test--").await.is_err()); + let mut file = NamedFile::open_async("Cargo.toml").await.unwrap(); { file.file(); let _f: &File = &file; @@ -182,8 +185,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_content_disposition() { - assert!(NamedFile::open("test--").is_err()); - let mut file = NamedFile::open("Cargo.toml").unwrap(); + assert!(NamedFile::open_async("test--").await.is_err()); + let mut file = NamedFile::open_async("Cargo.toml").await.unwrap(); { file.file(); let _f: &File = &file; @@ -199,7 +202,8 @@ mod tests { "inline; filename=\"Cargo.toml\"" ); - let file = NamedFile::open("Cargo.toml") + let file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .disable_content_disposition(); let req = TestRequest::default().to_http_request(); @@ -209,8 +213,19 @@ mod tests { #[actix_rt::test] async fn test_named_file_non_ascii_file_name() { - let mut file = - NamedFile::from_file(File::open("Cargo.toml").unwrap(), "貨物.toml").unwrap(); + let file = { + #[cfg(feature = "experimental-io-uring")] + { + crate::named::File::open("Cargo.toml").await.unwrap() + } + + #[cfg(not(feature = "experimental-io-uring"))] + { + crate::named::File::open("Cargo.toml").unwrap() + } + }; + + let mut file = NamedFile::from_file(file, "貨物.toml").unwrap(); { file.file(); let _f: &File = &file; @@ -233,7 +248,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_set_content_type() { - let mut file = NamedFile::open("Cargo.toml") + let mut file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_type(mime::TEXT_XML); { @@ -258,7 +274,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_image() { - let mut file = NamedFile::open("tests/test.png").unwrap(); + let mut file = NamedFile::open_async("tests/test.png").await.unwrap(); { file.file(); let _f: &File = &file; @@ -279,13 +295,30 @@ mod tests { ); } + #[actix_rt::test] + async fn test_named_file_javascript() { + let file = NamedFile::open_async("tests/test.js").await.unwrap(); + + let req = TestRequest::default().to_http_request(); + let resp = file.respond_to(&req).await.unwrap(); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "application/javascript" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"test.js\"" + ); + } + #[actix_rt::test] async fn test_named_file_image_attachment() { let cd = ContentDisposition { disposition: DispositionType::Attachment, parameters: vec![DispositionParam::Filename(String::from("test.png"))], }; - let mut file = NamedFile::open("tests/test.png") + let mut file = NamedFile::open_async("tests/test.png") + .await .unwrap() .set_content_disposition(cd); { @@ -310,7 +343,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_binary() { - let mut file = NamedFile::open("tests/test.binary").unwrap(); + let mut file = NamedFile::open_async("tests/test.binary").await.unwrap(); { file.file(); let _f: &File = &file; @@ -333,7 +366,8 @@ mod tests { #[actix_rt::test] async fn test_named_file_status_code_text() { - let mut file = NamedFile::open("Cargo.toml") + let mut file = NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_status_code(StatusCode::NOT_FOUND); { @@ -413,7 +447,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_content_range_headers() { - let srv = test::start(|| App::new().service(Files::new("/", "."))); + let srv = actix_test::start(|| App::new().service(Files::new("/", "."))); // Valid range header let response = srv @@ -438,7 +472,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_content_length_headers() { - let srv = test::start(|| App::new().service(Files::new("/", "."))); + let srv = actix_test::start(|| App::new().service(Files::new("/", "."))); // Valid range header let response = srv @@ -477,7 +511,7 @@ mod tests { #[actix_rt::test] async fn test_head_content_length_headers() { - let srv = test::start(|| App::new().service(Files::new("/", "."))); + let srv = actix_test::start(|| App::new().service(Files::new("/", "."))); let response = srv.head("/tests/test.binary").send().await.unwrap(); @@ -532,7 +566,7 @@ mod tests { #[actix_rt::test] async fn test_files_guards() { let srv = test::init_service( - App::new().service(Files::new("/", ".").use_guards(guard::Post())), + App::new().service(Files::new("/", ".").method_guard(guard::Post())), ) .await; @@ -549,7 +583,8 @@ mod tests { async fn test_named_file_content_encoding() { let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { - NamedFile::open("Cargo.toml") + NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_encoding(header::ContentEncoding::Identity) }), @@ -562,14 +597,16 @@ mod tests { .to_request(); let res = test::call_service(&srv, request).await; assert_eq!(res.status(), StatusCode::OK); - assert!(!res.headers().contains_key(header::CONTENT_ENCODING)); + assert!(res.headers().contains_key(header::CONTENT_ENCODING)); + assert!(!test::read_body(res).await.is_empty()); } #[actix_rt::test] async fn test_named_file_content_encoding_gzip() { let srv = test::init_service(App::new().wrap(Compress::default()).service( web::resource("/").to(|| async { - NamedFile::open("Cargo.toml") + NamedFile::open_async("Cargo.toml") + .await .unwrap() .set_content_encoding(header::ContentEncoding::Gzip) }), @@ -595,7 +632,7 @@ mod tests { #[actix_rt::test] async fn test_named_file_allowed_method() { let req = TestRequest::default().method(Method::GET).to_http_request(); - let file = NamedFile::open("Cargo.toml").unwrap(); + let file = NamedFile::open_async("Cargo.toml").await.unwrap(); let resp = file.respond_to(&req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } @@ -632,7 +669,7 @@ mod tests { #[actix_rt::test] async fn test_redirect_to_slash_directory() { - // should not redirect if no index + // should not redirect if no index and files listing is disabled let srv = test::init_service( App::new().service(Files::new("/", ".").redirect_to_slash_directory()), ) @@ -654,6 +691,19 @@ mod tests { let resp = test::call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::FOUND); + // should redirect if files listing is enabled + let srv = test::init_service( + App::new().service( + Files::new("/", ".") + .show_files_listing() + .redirect_to_slash_directory(), + ), + ) + .await; + let req = TestRequest::with_uri("/tests").to_request(); + let resp = test::call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::FOUND); + // should not redirect if the path is wrong let req = TestRequest::with_uri("/not_existing").to_request(); let resp = test::call_service(&srv, req).await; @@ -662,89 +712,52 @@ mod tests { #[actix_rt::test] async fn test_static_files_bad_directory() { - let _st: Files = Files::new("/", "missing"); - let _st: Files = Files::new("/", "Cargo.toml"); + let service = Files::new("/", "./missing").new_service(()).await.unwrap(); + + let req = TestRequest::with_uri("/").to_srv_request(); + let resp = test::call_service(&service, req).await; + + assert_eq!(resp.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] async fn test_default_handler_file_missing() { let st = Files::new("/", ".") - .default_handler(|req: ServiceRequest| { - ok(req.into_response(HttpResponse::Ok().body("default content"))) + .default_handler(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().body("default content"))) }) .new_service(()) .await .unwrap(); let req = TestRequest::with_uri("/missing").to_srv_request(); - let resp = test::call_service(&st, req).await; + assert_eq!(resp.status(), StatusCode::OK); let bytes = test::read_body(resp).await; assert_eq!(bytes, web::Bytes::from_static(b"default content")); } - // #[actix_rt::test] - // async fn test_serve_index() { - // let st = Files::new(".").index_file("test.binary"); - // let req = TestRequest::default().uri("/tests").finish(); + #[actix_rt::test] + async fn test_serve_index_nested() { + let service = Files::new(".", ".") + .index_file("lib.rs") + .new_service(()) + .await + .unwrap(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers() - // .get(header::CONTENT_TYPE) - // .expect("content type"), - // "application/octet-stream" - // ); - // assert_eq!( - // resp.headers() - // .get(header::CONTENT_DISPOSITION) - // .expect("content disposition"), - // "attachment; filename=\"test.binary\"" - // ); + let req = TestRequest::default().uri("/src").to_srv_request(); + let resp = test::call_service(&service, req).await; - // let req = TestRequest::default().uri("/tests/").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers().get(header::CONTENT_TYPE).unwrap(), - // "application/octet-stream" - // ); - // assert_eq!( - // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - // "attachment; filename=\"test.binary\"" - // ); - - // // nonexistent index file - // let req = TestRequest::default().uri("/tests/unknown").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::NOT_FOUND); - - // let req = TestRequest::default().uri("/tests/unknown/").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::NOT_FOUND); - // } - - // #[actix_rt::test] - // async fn test_serve_index_nested() { - // let st = Files::new(".").index_file("mod.rs"); - // let req = TestRequest::default().uri("/src/client").finish(); - // let resp = st.handle(&req).respond_to(&req).unwrap(); - // let resp = resp.as_msg(); - // assert_eq!(resp.status(), StatusCode::OK); - // assert_eq!( - // resp.headers().get(header::CONTENT_TYPE).unwrap(), - // "text/x-rust" - // ); - // assert_eq!( - // resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), - // "inline; filename=\"mod.rs\"" - // ); - // } + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-rust" + ); + assert_eq!( + resp.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"lib.rs\"" + ); + } #[actix_rt::test] async fn integration_serve_index() { @@ -790,5 +803,187 @@ mod tests { let req = TestRequest::get().uri("/test/%43argo.toml").to_request(); let res = test::call_service(&srv, req).await; assert_eq!(res.status(), StatusCode::OK); + + // `%2F` == `/` + let req = TestRequest::get().uri("/test%2Ftest.binary").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::get().uri("/test/Cargo.toml%00").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_percent_encoding_2() { + let tmpdir = tempfile::tempdir().unwrap(); + let filename = match cfg!(unix) { + true => "ض:?#[]{}<>()@!$&'`|*+,;= %20.test", + false => "ض#[]{}()@!$&'`+,;= %20.test", + }; + let filename_encoded = filename + .as_bytes() + .iter() + .map(|c| format!("%{:02X}", c)) + .collect::(); + std::fs::File::create(tmpdir.path().join(filename)).unwrap(); + + let srv = test::init_service(App::new().service(Files::new("", tmpdir.path()))).await; + + let req = TestRequest::get() + .uri(&format!("/{}", filename_encoded)) + .to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_serve_named_file() { + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = test::init_service(App::new().service(factory)).await; + + let req = TestRequest::get().uri("/Cargo.toml").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + + let bytes = test::read_body(res).await; + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); + + let req = TestRequest::get().uri("/test/unknown").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_serve_named_file_prefix() { + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = + test::init_service(App::new().service(web::scope("/test").service(factory))).await; + + let req = TestRequest::get().uri("/test/Cargo.toml").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + + let bytes = test::read_body(res).await; + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); + + let req = TestRequest::get().uri("/Cargo.toml").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_named_file_default_service() { + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let srv = test::init_service(App::new().default_service(factory)).await; + + for route in ["/foobar", "/baz", "/"].iter() { + let req = TestRequest::get().uri(route).to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + + let bytes = test::read_body(res).await; + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); + } + } + + #[actix_rt::test] + async fn test_default_handler_named_file() { + let factory = NamedFile::open_async("Cargo.toml").await.unwrap(); + let st = Files::new("/", ".") + .default_handler(factory) + .new_service(()) + .await + .unwrap(); + let req = TestRequest::with_uri("/missing").to_srv_request(); + let resp = test::call_service(&st, req).await; + + assert_eq!(resp.status(), StatusCode::OK); + let bytes = test::read_body(resp).await; + let data = Bytes::from(fs::read("Cargo.toml").unwrap()); + assert_eq!(bytes, data); + } + + #[actix_rt::test] + async fn test_symlinks() { + let srv = test::init_service(App::new().service(Files::new("test", "."))).await; + + let req = TestRequest::get() + .uri("/test/tests/symlink-test.png") + .to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(header::CONTENT_DISPOSITION).unwrap(), + "inline; filename=\"symlink-test.png\"" + ); + } + + #[actix_rt::test] + async fn test_index_with_show_files_listing() { + let service = Files::new(".", ".") + .index_file("lib.rs") + .show_files_listing() + .new_service(()) + .await + .unwrap(); + + // Serve the index if exists + let req = TestRequest::default().uri("/src").to_srv_request(); + let resp = test::call_service(&service, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/x-rust" + ); + + // Show files listing, otherwise. + let req = TestRequest::default().uri("/tests").to_srv_request(); + let resp = test::call_service(&service, req).await; + assert_eq!( + resp.headers().get(header::CONTENT_TYPE).unwrap(), + "text/html; charset=utf-8" + ); + let bytes = test::read_body(resp).await; + assert!(format!("{:?}", bytes).contains("/tests/test.png")); + } + + #[actix_rt::test] + async fn test_path_filter() { + // prevent searching subdirectories + let st = Files::new("/", ".") + .path_filter(|path, _| path.components().count() == 1) + .new_service(()) + .await + .unwrap(); + + let req = TestRequest::with_uri("/Cargo.toml").to_srv_request(); + let resp = test::call_service(&st, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/src/lib.rs").to_srv_request(); + let resp = test::call_service(&st, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } + + #[actix_rt::test] + async fn test_default_handler_filter() { + let st = Files::new("/", ".") + .default_handler(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::Ok().body("default content"))) + }) + .path_filter(|path, _| path.extension() == Some("png".as_ref())) + .new_service(()) + .await + .unwrap(); + let req = TestRequest::with_uri("/Cargo.toml").to_srv_request(); + let resp = test::call_service(&st, req).await; + + assert_eq!(resp.status(), StatusCode::OK); + let bytes = test::read_body(resp).await; + assert_eq!(bytes, web::Bytes::from_static(b"default content")); } } diff --git a/actix-files/src/named.rs b/actix-files/src/named.rs index a688b2e6c..019730dc6 100644 --- a/actix-files/src/named.rs +++ b/actix-files/src/named.rs @@ -1,26 +1,29 @@ -use std::fs::{File, Metadata}; -use std::io; -use std::ops::{Deref, DerefMut}; -use std::path::{Path, PathBuf}; -use std::time::{SystemTime, UNIX_EPOCH}; - -#[cfg(unix)] -use std::os::unix::fs::MetadataExt; +use std::{ + fmt, + fs::Metadata, + io, + path::{Path, PathBuf}, + time::{SystemTime, UNIX_EPOCH}, +}; +use actix_service::{Service, ServiceFactory}; use actix_web::{ - dev::{BodyEncoding, SizedStream}, + body::{self, BoxBody, SizedStream}, + dev::{AppService, HttpServiceFactory, ResourceDef, ServiceRequest, ServiceResponse}, http::{ header::{ - self, Charset, ContentDisposition, DispositionParam, DispositionType, ExtendedValue, + self, Charset, ContentDisposition, ContentEncoding, DispositionParam, + DispositionType, ExtendedValue, HeaderValue, }, - ContentEncoding, StatusCode, + StatusCode, }, - HttpMessage, HttpRequest, HttpResponse, Responder, + Error, HttpMessage, HttpRequest, HttpResponse, Responder, }; use bitflags::bitflags; +use derive_more::{Deref, DerefMut}; +use futures_core::future::LocalBoxFuture; use mime_guess::from_path; -use crate::ChunkedReadFile; use crate::{encoding::equiv_utf8_text, range::HttpRange}; bitflags! { @@ -39,9 +42,34 @@ impl Default for Flags { } /// A file with an associated name. -#[derive(Debug)] +/// +/// `NamedFile` can be registered as services: +/// ``` +/// use actix_web::App; +/// use actix_files::NamedFile; +/// +/// # async fn run() -> Result<(), Box> { +/// let file = NamedFile::open_async("./static/index.html").await?; +/// let app = App::new().service(file); +/// # Ok(()) +/// # } +/// ``` +/// +/// They can also be returned from handlers: +/// ``` +/// use actix_web::{Responder, get}; +/// use actix_files::NamedFile; +/// +/// #[get("/")] +/// async fn index() -> impl Responder { +/// NamedFile::open_async("./static/index.html").await +/// } +/// ``` +#[derive(Deref, DerefMut)] pub struct NamedFile { path: PathBuf, + #[deref] + #[deref_mut] file: File, modified: Option, pub(crate) md: Metadata, @@ -52,6 +80,39 @@ pub struct NamedFile { pub(crate) encoding: Option, } +impl fmt::Debug for NamedFile { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NamedFile") + .field("path", &self.path) + .field( + "file", + #[cfg(feature = "experimental-io-uring")] + { + &"tokio_uring::File" + }, + #[cfg(not(feature = "experimental-io-uring"))] + { + &self.file + }, + ) + .field("modified", &self.modified) + .field("md", &self.md) + .field("flags", &self.flags) + .field("status_code", &self.status_code) + .field("content_type", &self.content_type) + .field("content_disposition", &self.content_disposition) + .field("encoding", &self.encoding) + .finish() + } +} + +#[cfg(not(feature = "experimental-io-uring"))] +pub(crate) use std::fs::File; +#[cfg(feature = "experimental-io-uring")] +pub(crate) use tokio_uring::fs::File; + +use super::chunked; + impl NamedFile { /// Creates an instance from a previously opened file. /// @@ -59,8 +120,7 @@ impl NamedFile { /// `ContentDisposition` headers. /// /// # Examples - /// - /// ```rust + /// ```ignore /// use actix_files::NamedFile; /// use std::io::{self, Write}; /// use std::env; @@ -94,6 +154,11 @@ impl NamedFile { let disposition = match ct.type_() { mime::IMAGE | mime::TEXT | mime::VIDEO => DispositionType::Inline, + mime::APPLICATION => match ct.subtype() { + mime::JAVASCRIPT | mime::JSON => DispositionType::Inline, + name if name == "wasm" => DispositionType::Inline, + _ => DispositionType::Attachment, + }, _ => DispositionType::Attachment, }; @@ -116,7 +181,30 @@ impl NamedFile { (ct, cd) }; - let md = file.metadata()?; + let md = { + #[cfg(not(feature = "experimental-io-uring"))] + { + file.metadata()? + } + + #[cfg(feature = "experimental-io-uring")] + { + use std::os::unix::prelude::{AsRawFd, FromRawFd}; + + let fd = file.as_raw_fd(); + + // SAFETY: fd is borrowed and lives longer than the unsafe block + unsafe { + let file = std::fs::File::from_raw_fd(fd); + let md = file.metadata(); + // SAFETY: forget the fd before exiting block in success or error case but don't + // run destructor (that would close file handle) + std::mem::forget(file); + md? + } + } + }; + let modified = md.modified().ok(); let encoding = None; @@ -136,14 +224,42 @@ impl NamedFile { /// Attempts to open a file in read-only mode. /// /// # Examples - /// - /// ```rust + /// ``` /// use actix_files::NamedFile; - /// /// let file = NamedFile::open("foo.txt"); /// ``` + #[cfg(not(feature = "experimental-io-uring"))] pub fn open>(path: P) -> io::Result { - Self::from_file(File::open(&path)?, path) + let file = File::open(&path)?; + Self::from_file(file, path) + } + + /// Attempts to open a file asynchronously in read-only mode. + /// + /// When the `experimental-io-uring` crate feature is enabled, this will be async. + /// Otherwise, it will be just like [`open`][Self::open]. + /// + /// # Examples + /// ``` + /// use actix_files::NamedFile; + /// # async fn open() { + /// let file = NamedFile::open_async("foo.txt").await.unwrap(); + /// # } + /// ``` + pub async fn open_async>(path: P) -> io::Result { + let file = { + #[cfg(not(feature = "experimental-io-uring"))] + { + File::open(&path)? + } + + #[cfg(feature = "experimental-io-uring")] + { + File::open(&path).await? + } + }; + + Self::from_file(file, path) } /// Returns reference to the underlying `File` object. @@ -155,13 +271,12 @@ impl NamedFile { /// Retrieve the path of this file. /// /// # Examples - /// - /// ```rust + /// ``` /// # use std::io; /// use actix_files::NamedFile; /// - /// # fn path() -> io::Result<()> { - /// let file = NamedFile::open("test.txt")?; + /// # async fn path() -> io::Result<()> { + /// let file = NamedFile::open_async("test.txt").await?; /// assert_eq!(file.path().as_os_str(), "foo.txt"); /// # Ok(()) /// # } @@ -177,21 +292,21 @@ impl NamedFile { self } - /// Set the MIME Content-Type for serving this file. By default - /// the Content-Type is inferred from the filename extension. + /// Set the MIME Content-Type for serving this file. By default the Content-Type is inferred + /// from the filename extension. #[inline] pub fn set_content_type(mut self, mime_type: mime::Mime) -> Self { self.content_type = mime_type; self } - /// Set the Content-Disposition for serving this file. This allows - /// changing the inline/attachment disposition as well as the filename - /// sent to the peer. By default the disposition is `inline` for text, - /// image, and video content types, and `attachment` otherwise, and - /// the filename is taken from the path provided in the `open` method - /// after converting it to UTF-8 using. - /// [`std::ffi::OsStr::to_string_lossy`] + /// Set the Content-Disposition for serving this file. This allows changing the + /// `inline/attachment` disposition as well as the filename sent to the peer. + /// + /// By default the disposition is `inline` for `text/*`, `image/*`, `video/*` and + /// `application/{javascript, json, wasm}` mime types, and `attachment` otherwise, and the + /// filename is taken from the path provided in the `open` method after converting it to UTF-8 + /// (using `to_string_lossy`). #[inline] pub fn set_content_disposition(mut self, cd: header::ContentDisposition) -> Self { self.content_disposition = cd; @@ -209,13 +324,15 @@ impl NamedFile { } /// Set content encoding for serving this file + /// + /// Must be used with [`actix_web::middleware::Compress`] to take effect. #[inline] pub fn set_content_encoding(mut self, enc: ContentEncoding) -> Self { self.encoding = Some(enc); self } - /// Specifies whether to use ETag or not. + /// Specifies whether to return `ETag` header in response. /// /// Default is true. #[inline] @@ -224,7 +341,7 @@ impl NamedFile { self } - /// Specifies whether to use Last-Modified or not. + /// Specifies whether to return `Last-Modified` header in response. /// /// Default is true. #[inline] @@ -242,14 +359,18 @@ impl NamedFile { self } + /// Creates an `ETag` in a format is similar to Apache's. pub(crate) fn etag(&self) -> Option { - // This etag format is similar to Apache's. self.modified.as_ref().map(|mtime| { let ino = { #[cfg(unix)] { + #[cfg(unix)] + use std::os::unix::fs::MetadataExt as _; + self.md.ino() } + #[cfg(not(unix))] { 0 @@ -260,7 +381,7 @@ impl NamedFile { .duration_since(UNIX_EPOCH) .expect("modification time must be after epoch"); - header::EntityTag::strong(format!( + header::EntityTag::new_strong(format!( "{:x}:{:x}:{:x}:{:x}", ino, self.md.len(), @@ -275,16 +396,17 @@ impl NamedFile { } /// Creates an `HttpResponse` with file as a streaming body. - pub fn into_response(self, req: &HttpRequest) -> HttpResponse { + pub fn into_response(self, req: &HttpRequest) -> HttpResponse { if self.status_code != StatusCode::OK { let mut res = HttpResponse::build(self.status_code); - if self.flags.contains(Flags::PREFER_UTF8) { - let ct = equiv_utf8_text(self.content_type.clone()); - res.insert_header((header::CONTENT_TYPE, ct.to_string())); + let ct = if self.flags.contains(Flags::PREFER_UTF8) { + equiv_utf8_text(self.content_type.clone()) } else { - res.insert_header((header::CONTENT_TYPE, self.content_type.to_string())); - } + self.content_type + }; + + res.insert_header((header::CONTENT_TYPE, ct.to_string())); if self.flags.contains(Flags::CONTENT_DISPOSITION) { res.insert_header(( @@ -294,10 +416,10 @@ impl NamedFile { } if let Some(current_encoding) = self.encoding { - res.encoding(current_encoding); + res.insert_header((header::CONTENT_ENCODING, current_encoding.as_str())); } - let reader = ChunkedReadFile::new(self.md.len(), 0, self.file); + let reader = chunked::new_chunked_read(self.md.len(), 0, self.file); return res.streaming(reader); } @@ -320,8 +442,8 @@ impl NamedFile { } else if let (Some(ref m), Some(header::IfUnmodifiedSince(ref since))) = (last_modified, req.get_header()) { - let t1: SystemTime = m.clone().into(); - let t2: SystemTime = since.clone().into(); + let t1: SystemTime = (*m).into(); + let t2: SystemTime = (*since).into(); match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { (Ok(t1), Ok(t2)) => t1.as_secs() > t2.as_secs(), @@ -339,8 +461,8 @@ impl NamedFile { } else if let (Some(ref m), Some(header::IfModifiedSince(ref since))) = (last_modified, req.get_header()) { - let t1: SystemTime = m.clone().into(); - let t2: SystemTime = since.clone().into(); + let t1: SystemTime = (*m).into(); + let t2: SystemTime = (*since).into(); match (t1.duration_since(UNIX_EPOCH), t2.duration_since(UNIX_EPOCH)) { (Ok(t1), Ok(t2)) => t1.as_secs() <= t2.as_secs(), @@ -350,36 +472,36 @@ impl NamedFile { false }; - let mut resp = HttpResponse::build(self.status_code); + let mut res = HttpResponse::build(self.status_code); - if self.flags.contains(Flags::PREFER_UTF8) { - let ct = equiv_utf8_text(self.content_type.clone()); - resp.insert_header((header::CONTENT_TYPE, ct.to_string())); + let ct = if self.flags.contains(Flags::PREFER_UTF8) { + equiv_utf8_text(self.content_type.clone()) } else { - resp.insert_header((header::CONTENT_TYPE, self.content_type.to_string())); - } + self.content_type + }; + + res.insert_header((header::CONTENT_TYPE, ct.to_string())); if self.flags.contains(Flags::CONTENT_DISPOSITION) { - resp.insert_header(( + res.insert_header(( header::CONTENT_DISPOSITION, self.content_disposition.to_string(), )); } - // default compressing if let Some(current_encoding) = self.encoding { - resp.encoding(current_encoding); + res.insert_header((header::CONTENT_ENCODING, current_encoding.as_str())); } if let Some(lm) = last_modified { - resp.insert_header((header::LAST_MODIFIED, lm.to_string())); + res.insert_header((header::LAST_MODIFIED, lm.to_string())); } if let Some(etag) = etag { - resp.insert_header((header::ETAG, etag.to_string())); + res.insert_header((header::ETAG, etag.to_string())); } - resp.insert_header((header::ACCEPT_RANGES, "bytes")); + res.insert_header((header::ACCEPT_RANGES, "bytes")); let mut length = self.md.len(); let mut offset = 0; @@ -391,47 +513,41 @@ impl NamedFile { length = ranges[0].length; offset = ranges[0].start; - resp.encoding(ContentEncoding::Identity); - resp.insert_header(( + // don't allow compression middleware to modify partial content + res.insert_header(( + header::CONTENT_ENCODING, + HeaderValue::from_static("identity"), + )); + + res.insert_header(( header::CONTENT_RANGE, format!("bytes {}-{}/{}", offset, offset + length - 1, self.md.len()), )); } else { - resp.insert_header((header::CONTENT_RANGE, format!("bytes */{}", length))); - return resp.status(StatusCode::RANGE_NOT_SATISFIABLE).finish(); + res.insert_header((header::CONTENT_RANGE, format!("bytes */{}", length))); + return res.status(StatusCode::RANGE_NOT_SATISFIABLE).finish(); }; } else { - return resp.status(StatusCode::BAD_REQUEST).finish(); + return res.status(StatusCode::BAD_REQUEST).finish(); }; }; if precondition_failed { - return resp.status(StatusCode::PRECONDITION_FAILED).finish(); + return res.status(StatusCode::PRECONDITION_FAILED).finish(); } else if not_modified { - return resp.status(StatusCode::NOT_MODIFIED).finish(); + return res + .status(StatusCode::NOT_MODIFIED) + .body(body::None::new()) + .map_into_boxed_body(); } - let reader = ChunkedReadFile::new(length, offset, self.file); + let reader = chunked::new_chunked_read(length, offset, self.file); if offset != 0 || length != self.md.len() { - resp.status(StatusCode::PARTIAL_CONTENT); + res.status(StatusCode::PARTIAL_CONTENT); } - resp.body(SizedStream::new(length, reader)) - } -} - -impl Deref for NamedFile { - type Target = File; - - fn deref(&self) -> &File { - &self.file - } -} - -impl DerefMut for NamedFile { - fn deref_mut(&mut self) -> &mut File { - &mut self.file + res.body(SizedStream::new(length, reader)) } } @@ -476,7 +592,62 @@ fn none_match(etag: Option<&header::EntityTag>, req: &HttpRequest) -> bool { } impl Responder for NamedFile { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { self.into_response(req) } } + +impl ServiceFactory for NamedFile { + type Response = ServiceResponse; + type Error = Error; + type Config = (); + type Service = NamedFileService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; + + fn new_service(&self, _: ()) -> Self::Future { + let service = NamedFileService { + path: self.path.clone(), + }; + + Box::pin(async move { Ok(service) }) + } +} + +#[doc(hidden)] +#[derive(Debug)] +pub struct NamedFileService { + path: PathBuf, +} + +impl Service for NamedFileService { + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let (req, _) = req.into_parts(); + + let path = self.path.clone(); + Box::pin(async move { + let file = NamedFile::open_async(path).await?; + let res = file.into_response(&req); + Ok(ServiceResponse::new(req, res)) + }) + } +} + +impl HttpServiceFactory for NamedFile { + fn register(self, config: &mut AppService) { + config.register_service( + ResourceDef::root_prefix(self.path.to_string_lossy().as_ref()), + None, + self, + None, + ) + } +} diff --git a/actix-files/src/path_buf.rs b/actix-files/src/path_buf.rs index dd8e5b503..03b2cd766 100644 --- a/actix-files/src/path_buf.rs +++ b/actix-files/src/path_buf.rs @@ -1,14 +1,14 @@ use std::{ - path::{Path, PathBuf}, + path::{Component, Path, PathBuf}, str::FromStr, }; +use actix_utils::future::{ready, Ready}; use actix_web::{dev::Payload, FromRequest, HttpRequest}; -use futures_util::future::{ready, Ready}; use crate::error::UriSegmentError; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub(crate) struct PathBufWrap(PathBuf); impl FromStr for PathBufWrap { @@ -21,11 +21,28 @@ impl FromStr for PathBufWrap { impl PathBufWrap { /// Parse a path, giving the choice of allowing hidden files to be considered valid segments. + /// + /// Path traversal is guarded by this method. pub fn parse_path(path: &str, hidden_files: bool) -> Result { let mut buf = PathBuf::new(); + // equivalent to `path.split('/').count()` + let mut segment_count = path.matches('/').count() + 1; + + // we can decode the whole path here (instead of per-segment decoding) + // because we will reject `%2F` in paths using `segement_count`. + let path = percent_encoding::percent_decode_str(path) + .decode_utf8() + .map_err(|_| UriSegmentError::NotValidUtf8)?; + + // disallow decoding `%2F` into `/` + if segment_count != path.matches('/').count() + 1 { + return Err(UriSegmentError::BadChar('/')); + } + for segment in path.split('/') { if segment == ".." { + segment_count -= 1; buf.pop(); } else if !hidden_files && segment.starts_with('.') { return Err(UriSegmentError::BadStart('.')); @@ -38,6 +55,7 @@ impl PathBufWrap { } else if segment.ends_with('<') { return Err(UriSegmentError::BadEnd('<')); } else if segment.is_empty() { + segment_count -= 1; continue; } else if cfg!(windows) && segment.contains('\\') { return Err(UriSegmentError::BadChar('\\')); @@ -46,6 +64,12 @@ impl PathBufWrap { } } + // make sure we agree with stdlib parser + for (i, component) in buf.components().enumerate() { + assert!(matches!(component, Component::Normal(_))); + assert!(i < segment_count); + } + Ok(PathBufWrap(buf)) } } @@ -59,7 +83,6 @@ impl AsRef for PathBufWrap { impl FromRequest for PathBufWrap { type Error = UriSegmentError; type Future = Ready>; - type Config = (); fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { ready(req.match_info().path().parse()) @@ -116,4 +139,24 @@ mod tests { PathBuf::from_iter(vec!["test", ".tt"]) ); } + + #[test] + fn path_traversal() { + assert_eq!( + PathBufWrap::parse_path("/../README.md", false).unwrap().0, + PathBuf::from_iter(vec!["README.md"]) + ); + + assert_eq!( + PathBufWrap::parse_path("/../README.md", true).unwrap().0, + PathBuf::from_iter(vec!["README.md"]) + ); + + assert_eq!( + PathBufWrap::parse_path("/../../../../../../../../../../etc/passwd", false) + .unwrap() + .0, + PathBuf::from_iter(vec!["etc/passwd"]) + ); + } } diff --git a/actix-files/src/range.rs b/actix-files/src/range.rs index 6718980cb..8d9fe9445 100644 --- a/actix-files/src/range.rs +++ b/actix-files/src/range.rs @@ -10,9 +10,6 @@ pub struct HttpRange { pub length: u64, } -const PREFIX: &str = "bytes="; -const PREFIX_LEN: usize = 6; - #[derive(Debug, Clone, Display, Error)] #[display(fmt = "Parse HTTP Range failed")] pub struct ParseRangeErr(#[error(not(source))] ()); @@ -23,82 +20,16 @@ impl HttpRange { /// `header` is HTTP Range header (e.g. `bytes=bytes=0-9`). /// `size` is full size of response (file). pub fn parse(header: &str, size: u64) -> Result, ParseRangeErr> { - if header.is_empty() { - return Ok(Vec::new()); + match http_range::HttpRange::parse(header, size) { + Ok(ranges) => Ok(ranges + .iter() + .map(|range| HttpRange { + start: range.start, + length: range.length, + }) + .collect()), + Err(_) => Err(ParseRangeErr(())), } - if !header.starts_with(PREFIX) { - return Err(ParseRangeErr(())); - } - - let size_sig = size as i64; - let mut no_overlap = false; - - let all_ranges: Vec> = header[PREFIX_LEN..] - .split(',') - .map(|x| x.trim()) - .filter(|x| !x.is_empty()) - .map(|ra| { - let mut start_end_iter = ra.split('-'); - - let start_str = start_end_iter.next().ok_or(ParseRangeErr(()))?.trim(); - let end_str = start_end_iter.next().ok_or(ParseRangeErr(()))?.trim(); - - if start_str.is_empty() { - // If no start is specified, end specifies the - // range start relative to the end of the file. - let mut length: i64 = end_str.parse().map_err(|_| ParseRangeErr(()))?; - - if length > size_sig { - length = size_sig; - } - - Ok(Some(HttpRange { - start: (size_sig - length) as u64, - length: length as u64, - })) - } else { - let start: i64 = start_str.parse().map_err(|_| ParseRangeErr(()))?; - - if start < 0 { - return Err(ParseRangeErr(())); - } - if start >= size_sig { - no_overlap = true; - return Ok(None); - } - - let length = if end_str.is_empty() { - // If no end is specified, range extends to end of the file. - size_sig - start - } else { - let mut end: i64 = end_str.parse().map_err(|_| ParseRangeErr(()))?; - - if start > end { - return Err(ParseRangeErr(())); - } - - if end >= size_sig { - end = size_sig - 1; - } - - end - start + 1 - }; - - Ok(Some(HttpRange { - start: start as u64, - length: length as u64, - })) - } - }) - .collect::>()?; - - let ranges: Vec = all_ranges.into_iter().filter_map(|x| x).collect(); - - if no_overlap && ranges.is_empty() { - return Err(ParseRangeErr(())); - } - - Ok(ranges) } } diff --git a/actix-files/src/service.rs b/actix-files/src/service.rs index 14eea6ebc..152e1855e 100644 --- a/actix-files/src/service.rs +++ b/actix-files/src/service.rs @@ -1,22 +1,33 @@ -use std::{fmt, io, path::PathBuf, rc::Rc, task::Poll}; +use std::{fmt, io, ops::Deref, path::PathBuf, rc::Rc}; -use actix_service::Service; use actix_web::{ - dev::{ServiceRequest, ServiceResponse}, + body::BoxBody, + dev::{Service, ServiceRequest, ServiceResponse}, error::Error, guard::Guard, http::{header, Method}, HttpResponse, }; -use futures_util::future::{ok, Either, LocalBoxFuture, Ready}; +use futures_core::future::LocalBoxFuture; use crate::{ named, Directory, DirectoryRenderer, FilesError, HttpService, MimeOverride, NamedFile, - PathBufWrap, + PathBufWrap, PathFilter, }; /// Assembled file serving service. -pub struct FilesService { +#[derive(Clone)] +pub struct FilesService(pub(crate) Rc); + +impl Deref for FilesService { + type Target = FilesServiceInner; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +pub struct FilesServiceInner { pub(crate) directory: PathBuf, pub(crate) index: Option, pub(crate) show_index: bool, @@ -24,26 +35,56 @@ pub struct FilesService { pub(crate) default: Option, pub(crate) renderer: Rc, pub(crate) mime_override: Option>, + pub(crate) path_filter: Option>, pub(crate) file_flags: named::Flags, pub(crate) guards: Option>, pub(crate) hidden_files: bool, } -type FilesServiceFuture = Either< - Ready>, - LocalBoxFuture<'static, Result>, ->; +impl fmt::Debug for FilesServiceInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("FilesServiceInner") + } +} impl FilesService { - fn handle_err(&self, e: io::Error, req: ServiceRequest) -> FilesServiceFuture { - log::debug!("Failed to handle {}: {}", req.path(), e); + async fn handle_err( + &self, + err: io::Error, + req: ServiceRequest, + ) -> Result { + log::debug!("error handling {}: {}", req.path(), err); if let Some(ref default) = self.default { - Either::Right(default.call(req)) + default.call(req).await } else { - Either::Left(ok(req.error_response(e))) + Ok(req.error_response(err)) } } + + fn serve_named_file( + &self, + req: ServiceRequest, + mut named_file: NamedFile, + ) -> ServiceResponse { + if let Some(ref mime_override) = self.mime_override { + let new_disposition = mime_override(&named_file.content_type.type_()); + named_file.content_disposition.disposition = new_disposition; + } + named_file.flags = self.file_flags; + + let (req, _) = req.into_parts(); + let res = named_file.into_response(&req); + ServiceResponse::new(req, res) + } + + fn show_index(&self, req: ServiceRequest, path: PathBuf) -> ServiceResponse { + let dir = Directory::new(self.directory.clone(), path); + + let (req, _) = req.into_parts(); + + (self.renderer)(&dir, &req).unwrap_or_else(|e| ServiceResponse::from_err(e, req)) + } } impl fmt::Debug for FilesService { @@ -53,102 +94,100 @@ impl fmt::Debug for FilesService { } impl Service for FilesService { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; - type Future = FilesServiceFuture; + type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); fn call(&self, req: ServiceRequest) -> Self::Future { let is_method_valid = if let Some(guard) = &self.guards { // execute user defined guards - (**guard).check(req.head()) + (**guard).check(&req.guard_ctx()) } else { // default behavior matches!(*req.method(), Method::HEAD | Method::GET) }; - if !is_method_valid { - return Either::Left(ok(req.into_response( - actix_web::HttpResponse::MethodNotAllowed() - .insert_header(header::ContentType(mime::TEXT_PLAIN_UTF_8)) - .body("Request did not meet this resource's requirements."), - ))); - } + let this = self.clone(); - let real_path = - match PathBufWrap::parse_path(req.match_info().path(), self.hidden_files) { - Ok(item) => item, - Err(e) => return Either::Left(ok(req.error_response(e))), - }; + Box::pin(async move { + if !is_method_valid { + return Ok(req.into_response( + HttpResponse::MethodNotAllowed() + .insert_header(header::ContentType(mime::TEXT_PLAIN_UTF_8)) + .body("Request did not meet this resource's requirements."), + )); + } - // full file path - let path = match self.directory.join(&real_path).canonicalize() { - Ok(path) => path, - Err(e) => return self.handle_err(e, req), - }; + let real_path = + match PathBufWrap::parse_path(req.match_info().path(), this.hidden_files) { + Ok(item) => item, + Err(err) => return Ok(req.error_response(err)), + }; - if path.is_dir() { - if let Some(ref redir_index) = self.index { - if self.redirect_to_slash && !req.path().ends_with('/') { + if let Some(filter) = &this.path_filter { + if !filter(real_path.as_ref(), req.head()) { + if let Some(ref default) = this.default { + return default.call(req).await; + } else { + return Ok(req.into_response(HttpResponse::NotFound().finish())); + } + } + } + + // full file path + let path = this.directory.join(&real_path); + if let Err(err) = path.canonicalize() { + return this.handle_err(err, req).await; + } + + if path.is_dir() { + if this.redirect_to_slash + && !req.path().ends_with('/') + && (this.index.is_some() || this.show_index) + { let redirect_to = format!("{}/", req.path()); - return Either::Left(ok(req.into_response( + return Ok(req.into_response( HttpResponse::Found() .insert_header((header::LOCATION, redirect_to)) - .body("") - .into_body(), - ))); + .finish(), + )); } - let path = path.join(redir_index); - - match NamedFile::open(path) { + match this.index { + Some(ref index) => { + let named_path = path.join(index); + match NamedFile::open_async(named_path).await { + Ok(named_file) => Ok(this.serve_named_file(req, named_file)), + Err(_) if this.show_index => Ok(this.show_index(req, path)), + Err(err) => this.handle_err(err, req).await, + } + } + None if this.show_index => Ok(this.show_index(req, path)), + _ => Ok(ServiceResponse::from_err( + FilesError::IsDirectory, + req.into_parts().0, + )), + } + } else { + match NamedFile::open_async(&path).await { Ok(mut named_file) => { - if let Some(ref mime_override) = self.mime_override { + if let Some(ref mime_override) = this.mime_override { let new_disposition = mime_override(&named_file.content_type.type_()); named_file.content_disposition.disposition = new_disposition; } - named_file.flags = self.file_flags; + named_file.flags = this.file_flags; let (req, _) = req.into_parts(); let res = named_file.into_response(&req); - Either::Left(ok(ServiceResponse::new(req, res))) + Ok(ServiceResponse::new(req, res)) } - Err(e) => self.handle_err(e, req), + Err(err) => this.handle_err(err, req).await, } - } else if self.show_index { - let dir = Directory::new(self.directory.clone(), path); - - let (req, _) = req.into_parts(); - let x = (self.renderer)(&dir, &req); - - match x { - Ok(resp) => Either::Left(ok(resp)), - Err(e) => Either::Left(ok(ServiceResponse::from_err(e, req))), - } - } else { - Either::Left(ok(ServiceResponse::from_err( - FilesError::IsDirectory, - req.into_parts().0, - ))) } - } else { - match NamedFile::open(path) { - Ok(mut named_file) => { - if let Some(ref mime_override) = self.mime_override { - let new_disposition = mime_override(&named_file.content_type.type_()); - named_file.content_disposition.disposition = new_disposition; - } - named_file.flags = self.file_flags; - - let (req, _) = req.into_parts(); - let res = named_file.into_response(&req); - Either::Left(ok(ServiceResponse::new(req, res))) - } - Err(e) => self.handle_err(e, req), - } - } + }) } } diff --git a/actix-files/tests/encoding.rs b/actix-files/tests/encoding.rs index d21d4f8fd..652a7c12b 100644 --- a/actix-files/tests/encoding.rs +++ b/actix-files/tests/encoding.rs @@ -8,7 +8,7 @@ use actix_web::{ App, }; -#[actix_rt::test] +#[actix_web::test] async fn test_utf8_file_contents() { // use default ISO-8859-1 encoding let srv = test::init_service(App::new().service(Files::new("/", "./tests"))).await; diff --git a/actix-files/tests/fixtures/guards/first/index.txt b/actix-files/tests/fixtures/guards/first/index.txt new file mode 100644 index 000000000..fe4f02ad0 --- /dev/null +++ b/actix-files/tests/fixtures/guards/first/index.txt @@ -0,0 +1 @@ +first \ No newline at end of file diff --git a/actix-files/tests/fixtures/guards/second/index.txt b/actix-files/tests/fixtures/guards/second/index.txt new file mode 100644 index 000000000..2147e4188 --- /dev/null +++ b/actix-files/tests/fixtures/guards/second/index.txt @@ -0,0 +1 @@ +second \ No newline at end of file diff --git a/actix-files/tests/guard.rs b/actix-files/tests/guard.rs new file mode 100644 index 000000000..d053f3fdc --- /dev/null +++ b/actix-files/tests/guard.rs @@ -0,0 +1,36 @@ +use actix_files::Files; +use actix_web::{ + guard::Host, + http::StatusCode, + test::{self, TestRequest}, + App, +}; +use bytes::Bytes; + +#[actix_web::test] +async fn test_guard_filter() { + let srv = test::init_service( + App::new() + .service(Files::new("/", "./tests/fixtures/guards/first").guard(Host("first.com"))) + .service( + Files::new("/", "./tests/fixtures/guards/second").guard(Host("second.com")), + ), + ) + .await; + + let req = TestRequest::with_uri("/index.txt") + .append_header(("Host", "first.com")) + .to_request(); + let res = test::call_service(&srv, req).await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(test::read_body(res).await, Bytes::from("first")); + + let req = TestRequest::with_uri("/index.txt") + .append_header(("Host", "second.com")) + .to_request(); + let res = test::call_service(&srv, req).await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(test::read_body(res).await, Bytes::from("second")); +} diff --git a/actix-files/tests/symlink-test.png b/actix-files/tests/symlink-test.png new file mode 120000 index 000000000..65c0dcfd6 --- /dev/null +++ b/actix-files/tests/symlink-test.png @@ -0,0 +1 @@ +test.png \ No newline at end of file diff --git a/actix-files/tests/test.js b/actix-files/tests/test.js new file mode 100644 index 000000000..2ee135561 --- /dev/null +++ b/actix-files/tests/test.js @@ -0,0 +1 @@ +// this file is empty. diff --git a/actix-files/tests/traversal.rs b/actix-files/tests/traversal.rs new file mode 100644 index 000000000..c890b3fe4 --- /dev/null +++ b/actix-files/tests/traversal.rs @@ -0,0 +1,27 @@ +use actix_files::Files; +use actix_web::{ + http::StatusCode, + test::{self, TestRequest}, + App, +}; + +#[actix_rt::test] +async fn test_directory_traversal_prevention() { + let srv = test::init_service(App::new().service(Files::new("/", "./tests"))).await; + + let req = + TestRequest::with_uri("/../../../../../../../../../../../etc/passwd").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri( + "/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/%2e%2e/etc/passwd", + ) + .to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/%00/etc/passwd%00").to_request(); + let res = test::call_service(&srv, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} diff --git a/actix-http-test/CHANGES.md b/actix-http-test/CHANGES.md index 2f47d700d..b62281798 100644 --- a/actix-http-test/CHANGES.md +++ b/actix-http-test/CHANGES.md @@ -3,88 +3,136 @@ ## Unreleased - 2021-xx-xx +## 3.0.0-beta.11 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 3.0.0-beta.10 - 2021-12-27 +- Update `actix-server` to `2.0.0-rc.2`. [#2550] + +[#2550]: https://github.com/actix/actix-web/pull/2550 + + +## 3.0.0-beta.9 - 2021-12-11 +- No significant changes since `3.0.0-beta.8`. + + +## 3.0.0-beta.8 - 2021-11-30 +- Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.7 - 2021-11-22 +- Fix compatibility with experimental `io-uring` feature of `actix-rt`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 + + +## 3.0.0-beta.6 - 2021-11-15 +- `TestServer::stop` is now async and will wait for the server and system to shutdown. [#2442] +- Update `actix-server` to `2.0.0-beta.9`. [#2442] +- Minimum supported Rust version (MSRV) is now 1.52. + +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 3.0.0-beta.5 - 2021-09-09 +- Minimum supported Rust version (MSRV) is now 1.51. + + +## 3.0.0-beta.4 - 2021-04-02 +- Added `TestServer::client_headers` method. [#2097] + +[#2097]: https://github.com/actix/actix-web/pull/2097 + + +## 3.0.0-beta.3 - 2021-03-09 +- No notable changes. + + ## 3.0.0-beta.2 - 2021-02-10 -* No notable changes. +- No notable changes. ## 3.0.0-beta.1 - 2021-01-07 -* Update `bytes` to `1.0`. [#1813] +- Update `bytes` to `1.0`. [#1813] [#1813]: https://github.com/actix/actix-web/pull/1813 ## 2.1.0 - 2020-11-25 -* Add ability to set address for `TestServer`. [#1645] -* Upgrade `base64` to `0.13`. -* Upgrade `serde_urlencoded` to `0.7`. [#1773] +- Add ability to set address for `TestServer`. [#1645] +- Upgrade `base64` to `0.13`. +- Upgrade `serde_urlencoded` to `0.7`. [#1773] [#1773]: https://github.com/actix/actix-web/pull/1773 [#1645]: https://github.com/actix/actix-web/pull/1645 ## 2.0.0 - 2020-09-11 -* Update actix-codec and actix-utils dependencies. +- Update actix-codec and actix-utils dependencies. ## 2.0.0-alpha.1 - 2020-05-23 -* Update the `time` dependency to 0.2.7 -* Update `actix-connect` dependency to 2.0.0-alpha.2 -* Make `test_server` `async` fn. -* Bump minimum supported Rust version to 1.40 -* Replace deprecated `net2` crate with `socket2` -* Update `base64` dependency to 0.12 -* Update `env_logger` dependency to 0.7 +- Update the `time` dependency to 0.2.7 +- Update `actix-connect` dependency to 2.0.0-alpha.2 +- Make `test_server` `async` fn. +- Bump minimum supported Rust version to 1.40 +- Replace deprecated `net2` crate with `socket2` +- Update `base64` dependency to 0.12 +- Update `env_logger` dependency to 0.7 ## 1.0.0 - 2019-12-13 -* Replaced `TestServer::start()` with `test_server()` +- Replaced `TestServer::start()` with `test_server()` ## 1.0.0-alpha.3 - 2019-12-07 -* Migrate to `std::future` +- Migrate to `std::future` ## 0.2.5 - 2019-09-17 -* Update serde_urlencoded to "0.6.1" -* Increase TestServerRuntime timeouts from 500ms to 3000ms -* Do not override current `System` +- Update serde_urlencoded to "0.6.1" +- Increase TestServerRuntime timeouts from 500ms to 3000ms +- Do not override current `System` ## 0.2.4 - 2019-07-18 -* Update actix-server to 0.6 +- Update actix-server to 0.6 ## 0.2.3 - 2019-07-16 -* Add `delete`, `options`, `patch` methods to `TestServerRunner` +- Add `delete`, `options`, `patch` methods to `TestServerRunner` ## 0.2.2 - 2019-06-16 -* Add .put() and .sput() methods +- Add .put() and .sput() methods ## 0.2.1 - 2019-06-05 -* Add license files +- Add license files ## 0.2.0 - 2019-05-12 -* Update awc and actix-http deps +- Update awc and actix-http deps ## 0.1.1 - 2019-04-24 -* Always make new connection for http client +- Always make new connection for http client ## 0.1.0 - 2019-04-16 -* No changes +- No changes ## 0.1.0-alpha.3 - 2019-04-02 -* Request functions accept path #743 +- Request functions accept path #743 ## 0.1.0-alpha.2 - 2019-03-29 -* Added TestServerRuntime::load_body() method -* Update actix-http and awc libraries +- Added TestServerRuntime::load_body() method +- Update actix-http and awc libraries ## 0.1.0-alpha.1 - 2019-03-28 -* Initial impl +- Initial impl diff --git a/actix-http-test/Cargo.toml b/actix-http-test/Cargo.toml index 6dcf73637..db92f1983 100644 --- a/actix-http-test/Cargo.toml +++ b/actix-http-test/Cargo.toml @@ -1,18 +1,18 @@ [package] name = "actix-http-test" -version = "3.0.0-beta.2" +version = "3.0.0-beta.11" authors = ["Nikolay Kim "] description = "Various helpers for Actix applications to use during testing" -readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-http-test/" -categories = ["network-programming", "asynchronous", - "web-programming::http-server", - "web-programming::websocket"] +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] license = "MIT OR Apache-2.0" -exclude = [".gitignore", ".cargo/config"] edition = "2018" [package.metadata.docs.rs] @@ -29,33 +29,27 @@ default = [] openssl = ["tls-openssl", "awc/openssl"] [dependencies] -actix-service = "2.0.0-beta.4" -actix-codec = "0.4.0-beta.1" -actix-tls = "3.0.0-beta.3" -actix-utils = "3.0.0-beta.2" -actix-rt = "2" -actix-server = "2.0.0-beta.3" -awc = "3.0.0-beta.2" +actix-service = "2.0.0" +actix-codec = "0.4.1" +actix-tls = "3.0.0" +actix-utils = "3.0.0" +actix-rt = "2.2" +actix-server = "2.0.0-rc.2" +awc = { version = "3.0.0-beta.18", default-features = false } base64 = "0.13" bytes = "1" futures-core = { version = "0.3.7", default-features = false } -http = "0.2.2" +http = "0.2.5" log = "0.4" -socket2 = "0.3" +socket2 = "0.4" serde = "1.0" serde_json = "1.0" slab = "0.4" serde_urlencoded = "0.7" -time = { version = "0.2.23", default-features = false, features = ["std"] } tls-openssl = { version = "0.10.9", package = "openssl", optional = true } - -[target.'cfg(windows)'.dependencies.tls-openssl] -version = "0.10.9" -package = "openssl" -features = ["vendored"] -optional = true +tokio = { version = "1.8.4", features = ["sync"] } [dev-dependencies] -actix-web = "4.0.0-beta.3" -actix-http = "3.0.0-beta.3" +actix-web = { version = "4.0.0-beta.19", default-features = false, features = ["cookies"] } +actix-http = "3.0.0-beta.18" diff --git a/actix-http-test/README.md b/actix-http-test/README.md index 66f15979d..10c04b368 100644 --- a/actix-http-test/README.md +++ b/actix-http-test/README.md @@ -3,13 +3,15 @@ > Various helpers for Actix applications to use during testing. [![crates.io](https://img.shields.io/crates/v/actix-http-test?label=latest)](https://crates.io/crates/actix-http-test) -[![Documentation](https://docs.rs/actix-http-test/badge.svg?version=2.1.0)](https://docs.rs/actix-http-test/2.1.0) +[![Documentation](https://docs.rs/actix-http-test/badge.svg?version=3.0.0-beta.11)](https://docs.rs/actix-http-test/3.0.0-beta.11) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http-test) -[![Dependency Status](https://deps.rs/crate/actix-http-test/2.1.0/status.svg)](https://deps.rs/crate/actix-http-test/2.1.0) -[![Join the chat at https://gitter.im/actix/actix-web](https://badges.gitter.im/actix/actix-web.svg)](https://gitter.im/actix/actix-web?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +
+[![Dependency Status](https://deps.rs/crate/actix-http-test/3.0.0-beta.11/status.svg)](https://deps.rs/crate/actix-http-test/3.0.0-beta.11) +[![Download](https://img.shields.io/crates/d/actix-http-test.svg)](https://crates.io/crates/actix-http-test) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-http-test) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.54 diff --git a/actix-http-test/src/lib.rs b/actix-http-test/src/lib.rs index df5774998..8636ef9c4 100644 --- a/actix-http-test/src/lib.rs +++ b/actix-http-test/src/lib.rs @@ -1,46 +1,48 @@ //! Various helpers for Actix applications to use during testing. -#![deny(rust_2018_idioms)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] #[cfg(feature = "openssl")] extern crate tls_openssl as openssl; -use std::sync::mpsc; -use std::{net, thread, time}; +use std::{net, thread, time::Duration}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::{net::TcpStream, System}; -use actix_server::{Server, ServiceFactory}; -use awc::{error::PayloadError, ws, Client, ClientRequest, ClientResponse, Connector}; +use actix_server::{Server, ServerServiceFactory}; +use awc::{ + error::PayloadError, http::header::HeaderMap, ws, Client, ClientRequest, ClientResponse, + Connector, +}; use bytes::Bytes; use futures_core::stream::Stream; use http::Method; use socket2::{Domain, Protocol, Socket, Type}; +use tokio::sync::mpsc; -/// Start test server +/// Start test server. /// -/// `TestServer` is very simple test server that simplify process of writing -/// integration tests cases for actix web applications. +/// `TestServer` is very simple test server that simplify process of writing integration tests cases +/// for HTTP applications. /// /// # Examples -/// -/// ```rust +/// ```no_run /// use actix_http::HttpService; -/// use actix_http_test::TestServer; +/// use actix_http_test::test_server; /// use actix_web::{web, App, HttpResponse, Error}; /// /// async fn my_handler() -> Result { /// Ok(HttpResponse::Ok().into()) /// } /// -/// #[actix_rt::test] +/// #[actix_web::test] /// async fn test_example() { -/// let mut srv = TestServer::start( -/// || HttpService::new( -/// App::new().service( -/// web::resource("/").to(my_handler)) +/// let mut srv = TestServer::start(|| +/// HttpService::new( +/// App::new().service(web::resource("/").to(my_handler)) /// ) /// ); /// @@ -49,89 +51,91 @@ use socket2::{Domain, Protocol, Socket, Type}; /// assert!(response.status().is_success()); /// } /// ``` -pub async fn test_server>(factory: F) -> TestServer { +pub async fn test_server>(factory: F) -> TestServer { let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); test_server_with_addr(tcp, factory).await } -/// Start [`test server`](test_server()) on a concrete Address -pub async fn test_server_with_addr>( +/// Start [`test server`](test_server()) on an existing address binding. +pub async fn test_server_with_addr>( tcp: net::TcpListener, factory: F, ) -> TestServer { - let (tx, rx) = mpsc::channel(); + let (started_tx, started_rx) = std::sync::mpsc::channel(); + let (thread_stop_tx, thread_stop_rx) = mpsc::channel(1); // run server in separate thread thread::spawn(move || { - let sys = System::new(); - let local_addr = tcp.local_addr().unwrap(); + System::new().block_on(async move { + let local_addr = tcp.local_addr().unwrap(); - let srv = Server::build() - .listen("test", tcp, factory)? - .workers(1) - .disable_signals(); + let srv = Server::build() + .workers(1) + .disable_signals() + .system_exit() + .listen("test", tcp, factory) + .expect("test server could not be created"); - sys.block_on(async { - srv.run(); - tx.send((System::current(), local_addr)).unwrap(); + let srv = srv.run(); + started_tx + .send((System::current(), srv.handle(), local_addr)) + .unwrap(); + + // drive server loop + srv.await.unwrap(); }); - sys.run() + // notify TestServer that server and system have shut down + // all thread managed resources should be dropped at this point + let _ = thread_stop_tx.send(()); }); - let (system, addr) = rx.recv().unwrap(); + let (system, server, addr) = started_rx.recv().unwrap(); let client = { + #[cfg(feature = "openssl")] let connector = { - #[cfg(feature = "openssl")] - { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder - .set_alpn_protos(b"\x02h2\x08http/1.1") - .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - .ssl(builder.build()) - .finish() - } - #[cfg(not(feature = "openssl"))] - { - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - .finish() - } + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) + .openssl(builder.build()) + }; + + #[cfg(not(feature = "openssl"))] + let connector = { + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) }; Client::builder().connector(connector).finish() }; TestServer { - addr, + server, client, system, + addr, + thread_stop_rx, } } -/// Get first available unused address -pub fn unused_addr() -> net::SocketAddr { - let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); - socket.bind(&addr.into()).unwrap(); - socket.set_reuse_address(true).unwrap(); - let tcp = socket.into_tcp_listener(); - tcp.local_addr().unwrap() -} - /// Test server controller pub struct TestServer { + server: actix_server::ServerHandle, + client: awc::Client, + system: actix_rt::System, addr: net::SocketAddr, - client: Client, - system: System, + thread_stop_rx: mpsc::Receiver<()>, } impl TestServer { @@ -260,14 +264,49 @@ impl TestServer { self.ws_at("/").await } - /// Stop HTTP server - fn stop(&mut self) { + /// Get default HeaderMap of Client. + /// + /// Returns Some(&mut HeaderMap) when Client object is unique + /// (No other clone of client exists at the same time). + pub fn client_headers(&mut self) -> Option<&mut HeaderMap> { + self.client.headers() + } + + /// Stop HTTP server. + /// + /// Waits for spawned `Server` and `System` to (force) shutdown. + pub async fn stop(&mut self) { + // signal server to stop + self.server.stop(false).await; + + // also signal system to stop + // though this is handled by `ServerBuilder::exit_system` too self.system.stop(); + + // wait for thread to be stopped but don't care about result + let _ = self.thread_stop_rx.recv().await; } } impl Drop for TestServer { fn drop(&mut self) { - self.stop() + // calls in this Drop impl should be enough to shut down the server, system, and thread + // without needing to await anything + + // signal server to stop + let _ = self.server.stop(true); + + // signal system to stop + self.system.stop(); } } + +/// Get a localhost socket address with random, unused port. +pub fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).unwrap(); + socket.bind(&addr.into()).unwrap(); + socket.set_reuse_address(true).unwrap(); + let tcp = net::TcpListener::from(socket); + tcp.local_addr().unwrap() +} diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 54f7357f1..935e35561 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,55 +1,363 @@ # Changes ## Unreleased - 2021-xx-xx + + +## 3.0.0-beta.18 - 2022-01-04 +### Added +- `impl Eq` for `header::ContentEncoding`. [#2501] +- `impl Copy` for `QualityItem` where `T: Copy`. [#2501] +- `Quality::ZERO` equivalent to `q=0`. [#2501] +- `QualityItem::zero` that uses `Quality::ZERO`. [#2501] +- `ContentEncoding::to_header_value()`. [#2501] + ### Changed -* Feature `cookies` is now optional and disabled by default. [#1981] +- `Quality::MIN` is now the smallest non-zero value. [#2501] +- `QualityItem::min` semantics changed with `QualityItem::MIN`. [#2501] +- Rename `ContentEncoding::{Br => Brotli}`. [#2501] +- Minimum supported Rust version (MSRV) is now 1.54. +- Rename `header::EntityTag::{weak => new_weak, strong => new_strong}`. [#2565] + +### Fixed +- `ContentEncoding::Identity` can now be parsed from a string. [#2501] +- A `Vary` header is now correctly sent along with compressed content. [#2501] + +### Removed +- `ContentEncoding::Auto` variant. [#2501] +- `ContentEncoding::is_compression()`. [#2501] + +[#2501]: https://github.com/actix/actix-web/pull/2501 +[#2565]: https://github.com/actix/actix-web/pull/2565 + + +## 3.0.0-beta.17 - 2021-12-27 +### Changes +- `HeaderMap::get_all` now returns a `std::slice::Iter`. [#2527] +- `Payload` inner fields are now named. [#2545] +- `impl Stream` for `Payload` no longer requires the `Stream` variant be `Unpin`. [#2545] +- `impl Future` for `h1::SendResponse` no longer requires the body type be `Unpin`. [#2545] +- `impl Stream` for `encoding::Decoder` no longer requires the stream type be `Unpin`. [#2545] +- Rename `PayloadStream` to `BoxedPayloadStream`. [#2545] + +### Removed +- `h1::Payload::readany`. [#2545] + +[#2527]: https://github.com/actix/actix-web/pull/2527 +[#2545]: https://github.com/actix/actix-web/pull/2545 + + +## 3.0.0-beta.16 - 2021-12-17 +### Added +- New method on `MessageBody` trait, `try_into_bytes`, with default implementation, for optimizations on body types that complete in exactly one poll. Replaces `is_complete_body` and `take_complete_body`. [#2522] + +### Changed +- Rename trait `IntoHeaderPair => TryIntoHeaderPair`. [#2510] +- Rename `TryIntoHeaderPair::{try_into_header_pair => try_into_pair}`. [#2510] +- Rename trait `IntoHeaderValue => TryIntoHeaderValue`. [#2510] + +### Removed +- `MessageBody::{is_complete_body,take_complete_body}`. [#2522] + +[#2510]: https://github.com/actix/actix-web/pull/2510 +[#2522]: https://github.com/actix/actix-web/pull/2522 + + +## 3.0.0-beta.15 - 2021-12-11 +### Added +- Add timeout for canceling HTTP/2 server side connection handshake. Default to 5 seconds. [#2483] +- HTTP/2 handshake timeout can be configured with `ServiceConfig::client_timeout`. [#2483] +- `Response::map_into_boxed_body`. [#2468] +- `body::EitherBody` enum. [#2468] +- `body::None` struct. [#2468] +- Impl `MessageBody` for `bytestring::ByteString`. [#2468] +- `impl Clone for ws::HandshakeError`. [#2468] +- `#[must_use]` for `ws::Codec` to prevent subtle bugs. [#1920] +- `impl Default ` for `ws::Codec`. [#1920] +- `header::QualityItem::{max, min}`. [#2486] +- `header::Quality::{MAX, MIN}`. [#2486] +- `impl Display` for `header::Quality`. [#2486] +- Connection data set through the `on_connect_ext` callbacks is now accessible only from the new `Request::conn_data()` method. [#2491] +- `Request::take_conn_data()`. [#2491] +- `Request::take_req_data()`. [#2487] +- `impl Clone` for `RequestHead`. [#2487] +- New methods on `MessageBody` trait, `is_complete_body` and `take_complete_body`, both with default implementations, for optimizations on body types that are done in exactly one poll/chunk. [#2497] +- New `boxed` method on `MessageBody` trait for wrapping body type. [#2520] + +### Changed +- Rename `body::BoxBody::{from_body => new}`. [#2468] +- Body type for `Responses` returned from `Response::{new, ok, etc...}` is now `BoxBody`. [#2468] +- The `Error` associated type on `MessageBody` type now requires `impl Error` (or similar). [#2468] +- Error types using in service builders now require `Into>`. [#2468] +- `From` implementations on error types now return a `Response`. [#2468] +- `ResponseBuilder::body(B)` now returns `Response>`. [#2468] +- `ResponseBuilder::finish()` now returns `Response>`. [#2468] + +### Removed +- `ResponseBuilder::streaming`. [#2468] +- `impl Future` for `ResponseBuilder`. [#2468] +- Remove unnecessary `MessageBody` bound on types passed to `body::AnyBody::new`. [#2468] +- Move `body::AnyBody` to `awc`. Replaced with `EitherBody` and `BoxBody`. [#2468] +- `impl Copy` for `ws::Codec`. [#1920] +- `header::qitem` helper. Replaced with `header::QualityItem::max`. [#2486] +- `impl TryFrom` for `header::Quality`. [#2486] +- `http` module. Most everything it contained is exported at the crate root. [#2488] + +[#2483]: https://github.com/actix/actix-web/pull/2483 +[#2468]: https://github.com/actix/actix-web/pull/2468 +[#1920]: https://github.com/actix/actix-web/pull/1920 +[#2486]: https://github.com/actix/actix-web/pull/2486 +[#2487]: https://github.com/actix/actix-web/pull/2487 +[#2488]: https://github.com/actix/actix-web/pull/2488 +[#2491]: https://github.com/actix/actix-web/pull/2491 +[#2497]: https://github.com/actix/actix-web/pull/2497 +[#2520]: https://github.com/actix/actix-web/pull/2520 + + +## 3.0.0-beta.14 - 2021-11-30 +### Changed +- Guarantee ordering of `header::GetAll` iterator to be same as insertion order. [#2467] +- Expose `header::map` module. [#2467] +- Implement `ExactSizeIterator` and `FusedIterator` for all `HeaderMap` iterators. [#2470] +- Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2467]: https://github.com/actix/actix-web/pull/2467 +[#2470]: https://github.com/actix/actix-web/pull/2470 +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.13 - 2021-11-22 +### Added +- `body::AnyBody::empty` for quickly creating an empty body. [#2446] +- `body::AnyBody::none` for quickly creating a "none" body. [#2456] +- `impl Clone` for `body::AnyBody where S: Clone`. [#2448] +- `body::AnyBody::into_boxed` for quickly converting to a type-erased, boxed body type. [#2448] + +### Changed +- Rename `body::AnyBody::{Message => Body}`. [#2446] +- Rename `body::AnyBody::{from_message => new_boxed}`. [#2448] +- Rename `body::AnyBody::{from_slice => copy_from_slice}`. [#2448] +- Rename `body::{BoxAnyBody => BoxBody}`. [#2448] +- Change representation of `AnyBody` to include a type parameter in `Body` variant. Defaults to `BoxBody`. [#2448] +- `Encoder::response` now returns `AnyBody>`. [#2448] + +### Removed +- `body::AnyBody::Empty`; an empty body can now only be represented as a zero-length `Bytes` variant. [#2446] +- `body::BodySize::Empty`; an empty body can now only be represented as a `Sized(0)` variant. [#2446] +- `EncoderError::Boxed`; it is no longer required. [#2446] +- `body::ResponseBody`; is function is replaced by the new `body::AnyBody` enum. [#2446] + +[#2446]: https://github.com/actix/actix-web/pull/2446 +[#2448]: https://github.com/actix/actix-web/pull/2448 +[#2456]: https://github.com/actix/actix-web/pull/2456 + + +## 3.0.0-beta.12 - 2021-11-15 +### Changed +- Update `actix-server` to `2.0.0-beta.9`. [#2442] + +### Removed +- `client` module. [#2425] +- `trust-dns` feature. [#2425] + +[#2425]: https://github.com/actix/actix-web/pull/2425 +[#2442]: https://github.com/actix/actix-web/pull/2442 + + +## 3.0.0-beta.11 - 2021-10-20 +### Changed +- Updated rustls to v0.20. [#2414] +- Minimum supported Rust version (MSRV) is now 1.52. + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 3.0.0-beta.10 - 2021-09-09 +### Changed +- `ContentEncoding` is now marked `#[non_exhaustive]`. [#2377] +- Minimum supported Rust version (MSRV) is now 1.51. + +### Fixed +- Remove slice creation pointing to potential uninitialized data on h1 encoder. [#2364] +- Remove `Into` bound on `Encoder` body types. [#2375] +- Fix quality parse error in Accept-Encoding header. [#2344] + +[#2364]: https://github.com/actix/actix-web/pull/2364 +[#2375]: https://github.com/actix/actix-web/pull/2375 +[#2344]: https://github.com/actix/actix-web/pull/2344 +[#2377]: https://github.com/actix/actix-web/pull/2377 + + +## 3.0.0-beta.9 - 2021-08-09 +### Fixed +- Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977) + + +## 3.0.0-beta.8 - 2021-06-26 +### Changed +- Change compression algorithm features flags. [#2250] + +### Removed +- `downcast` and `downcast_get_type_id` macros. [#2291] + +[#2291]: https://github.com/actix/actix-web/pull/2291 +[#2250]: https://github.com/actix/actix-web/pull/2250 + + +## 3.0.0-beta.7 - 2021-06-17 +### Added +- Alias `body::Body` as `body::AnyBody`. [#2215] +- `BoxAnyBody`: a boxed message body with boxed errors. [#2183] +- Re-export `http` crate's `Error` type as `error::HttpError`. [#2171] +- Re-export `StatusCode`, `Method`, `Version` and `Uri` at the crate root. [#2171] +- Re-export `ContentEncoding` and `ConnectionType` at the crate root. [#2171] +- `Response::into_body` that consumes response and returns body type. [#2201] +- `impl Default` for `Response`. [#2201] +- Add zstd support for `ContentEncoding`. [#2244] + +### Changed +- The `MessageBody` trait now has an associated `Error` type. [#2183] +- All error trait bounds in server service builders have changed from `Into` to `Into>`. [#2253] +- All error trait bounds in message body and stream impls changed from `Into` to `Into>`. [#2253] +- Places in `Response` where `ResponseBody` was received or returned now simply use `B`. [#2201] +- `header` mod is now public. [#2171] +- `uri` mod is now public. [#2171] +- Update `language-tags` to `0.3`. +- Reduce the level from `error` to `debug` for the log line that is emitted when a `500 Internal Server Error` is built using `HttpResponse::from_error`. [#2201] +- `ResponseBuilder::message_body` now returns a `Result`. [#2201] +- Remove `Unpin` bound on `ResponseBuilder::streaming`. [#2253] +- `HttpServer::{listen_rustls(), bind_rustls()}` now honor the ALPN protocols in the configuation parameter. [#2226] + +### Removed +- Stop re-exporting `http` crate's `HeaderMap` types in addition to ours. [#2171] +- Down-casting for `MessageBody` types. [#2183] +- `error::Result` alias. [#2201] +- Error field from `Response` and `Response::error`. [#2205] +- `impl Future` for `Response`. [#2201] +- `Response::take_body` and old `Response::into_body` method that casted body type. [#2201] +- `InternalError` and all the error types it constructed. [#2215] +- Conversion (`impl Into`) of `Response` and `ResponseBuilder` to `Error`. [#2215] + +[#2171]: https://github.com/actix/actix-web/pull/2171 +[#2183]: https://github.com/actix/actix-web/pull/2183 +[#2196]: https://github.com/actix/actix-web/pull/2196 +[#2201]: https://github.com/actix/actix-web/pull/2201 +[#2205]: https://github.com/actix/actix-web/pull/2205 +[#2215]: https://github.com/actix/actix-web/pull/2215 +[#2253]: https://github.com/actix/actix-web/pull/2253 +[#2244]: https://github.com/actix/actix-web/pull/2244 + + + +## 3.0.0-beta.6 - 2021-04-17 +### Added +- `impl MessageBody for Pin>`. [#2152] +- `Response::{ok, bad_request, not_found, internal_server_error}`. [#2159] +- Helper `body::to_bytes` for async collecting message body into Bytes. [#2158] + +### Changes +- The type parameter of `Response` no longer has a default. [#2152] +- The `Message` variant of `body::Body` is now `Pin>`. [#2152] +- `BodyStream` and `SizedStream` are no longer restricted to Unpin types. [#2152] +- Error enum types are marked `#[non_exhaustive]`. [#2161] + +### Removed +- `cookies` feature flag. [#2065] +- Top-level `cookies` mod (re-export). [#2065] +- `HttpMessage` trait loses the `cookies` and `cookie` methods. [#2065] +- `impl ResponseError for CookieParseError`. [#2065] +- Deprecated methods on `ResponseBuilder`: `if_true`, `if_some`. [#2148] +- `ResponseBuilder::json`. [#2148] +- `ResponseBuilder::{set_header, header}`. [#2148] +- `impl From for Body`. [#2148] +- `Response::build_from`. [#2159] +- Most of the status code builders on `Response`. [#2159] + +[#2065]: https://github.com/actix/actix-web/pull/2065 +[#2148]: https://github.com/actix/actix-web/pull/2148 +[#2152]: https://github.com/actix/actix-web/pull/2152 +[#2159]: https://github.com/actix/actix-web/pull/2159 +[#2158]: https://github.com/actix/actix-web/pull/2158 +[#2161]: https://github.com/actix/actix-web/pull/2161 + + +## 3.0.0-beta.5 - 2021-04-02 +### Added +- `client::Connector::handshake_timeout` method for customizing TLS connection handshake timeout. [#2081] +- `client::ConnectorService` as `client::Connector::finish` method's return type [#2081] +- `client::ConnectionIo` trait alias [#2081] + +### Changed +- `client::Connector` type now only have one generic type for `actix_service::Service`. [#2063] + +### Removed +- Common typed HTTP headers were moved to actix-web. [2094] +- `ResponseError` impl for `actix_utils::timeout::TimeoutError`. [#2127] + +[#2063]: https://github.com/actix/actix-web/pull/2063 +[#2081]: https://github.com/actix/actix-web/pull/2081 +[#2094]: https://github.com/actix/actix-web/pull/2094 +[#2127]: https://github.com/actix/actix-web/pull/2127 + + +## 3.0.0-beta.4 - 2021-03-08 +### Changed +- Feature `cookies` is now optional and disabled by default. [#1981] +- `ws::hash_key` now returns array. [#2035] +- `ResponseBuilder::json` now takes `impl Serialize`. [#2052] + +### Removed +- Re-export of `futures_channel::oneshot::Canceled` is removed from `error` mod. [#1994] +- `ResponseError` impl for `futures_channel::oneshot::Canceled` is removed. [#1994] [#1981]: https://github.com/actix/actix-web/pull/1981 +[#1994]: https://github.com/actix/actix-web/pull/1994 +[#2035]: https://github.com/actix/actix-web/pull/2035 +[#2052]: https://github.com/actix/actix-web/pull/2052 ## 3.0.0-beta.3 - 2021-02-10 -* No notable changes. +- No notable changes. ## 3.0.0-beta.2 - 2021-02-10 ### Added -* `IntoHeaderPair` trait that allows using typed and untyped headers in the same methods. [#1869] -* `ResponseBuilder::insert_header` method which allows using typed headers. [#1869] -* `ResponseBuilder::append_header` method which allows using typed headers. [#1869] -* `TestRequest::insert_header` method which allows using typed headers. [#1869] -* `ContentEncoding` implements all necessary header traits. [#1912] -* `HeaderMap::len_keys` has the behavior of the old `len` method. [#1964] -* `HeaderMap::drain` as an efficient draining iterator. [#1964] -* Implement `IntoIterator` for owned `HeaderMap`. [#1964] -* `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] +- `TryIntoHeaderPair` trait that allows using typed and untyped headers in the same methods. [#1869] +- `ResponseBuilder::insert_header` method which allows using typed headers. [#1869] +- `ResponseBuilder::append_header` method which allows using typed headers. [#1869] +- `TestRequest::insert_header` method which allows using typed headers. [#1869] +- `ContentEncoding` implements all necessary header traits. [#1912] +- `HeaderMap::len_keys` has the behavior of the old `len` method. [#1964] +- `HeaderMap::drain` as an efficient draining iterator. [#1964] +- Implement `IntoIterator` for owned `HeaderMap`. [#1964] +- `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] ### Changed -* `ResponseBuilder::content_type` now takes an `impl IntoHeaderValue` to support using typed +- `ResponseBuilder::content_type` now takes an `impl TryIntoHeaderValue` to support using typed `mime` types. [#1894] -* Renamed `IntoHeaderValue::{try_into => try_into_value}` to avoid ambiguity with std +- Renamed `TryIntoHeaderValue::{try_into => try_into_value}` to avoid ambiguity with std `TryInto` trait. [#1894] -* `Extensions::insert` returns Option of replaced item. [#1904] -* Remove `HttpResponseBuilder::json2()`. [#1903] -* Enable `HttpResponseBuilder::json()` to receive data by value and reference. [#1903] -* `client::error::ConnectError` Resolver variant contains `Box` type. [#1905] -* `client::ConnectorConfig` default timeout changed to 5 seconds. [#1905] -* Simplify `BlockingError` type to a unit struct. It's now only triggered when blocking thread pool +- `Extensions::insert` returns Option of replaced item. [#1904] +- Remove `HttpResponseBuilder::json2()`. [#1903] +- Enable `HttpResponseBuilder::json()` to receive data by value and reference. [#1903] +- `client::error::ConnectError` Resolver variant contains `Box` type. [#1905] +- `client::ConnectorConfig` default timeout changed to 5 seconds. [#1905] +- Simplify `BlockingError` type to a unit struct. It's now only triggered when blocking thread pool is dead. [#1957] -* `HeaderMap::len` now returns number of values instead of number of keys. [#1964] -* `HeaderMap::insert` now returns iterator of removed values. [#1964] -* `HeaderMap::remove` now returns iterator of removed values. [#1964] +- `HeaderMap::len` now returns number of values instead of number of keys. [#1964] +- `HeaderMap::insert` now returns iterator of removed values. [#1964] +- `HeaderMap::remove` now returns iterator of removed values. [#1964] ### Removed -* `ResponseBuilder::set`; use `ResponseBuilder::insert_header`. [#1869] -* `ResponseBuilder::set_header`; use `ResponseBuilder::insert_header`. [#1869] -* `ResponseBuilder::header`; use `ResponseBuilder::append_header`. [#1869] -* `TestRequest::with_hdr`; use `TestRequest::default().insert_header()`. [#1869] -* `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] -* `actors` optional feature. [#1969] -* `ResponseError` impl for `actix::MailboxError`. [#1969] +- `ResponseBuilder::set`; use `ResponseBuilder::insert_header`. [#1869] +- `ResponseBuilder::set_header`; use `ResponseBuilder::insert_header`. [#1869] +- `ResponseBuilder::header`; use `ResponseBuilder::append_header`. [#1869] +- `TestRequest::with_hdr`; use `TestRequest::default().insert_header()`. [#1869] +- `TestRequest::with_header`; use `TestRequest::default().insert_header()`. [#1869] +- `actors` optional feature. [#1969] +- `ResponseError` impl for `actix::MailboxError`. [#1969] ### Documentation -* Vastly improve docs and add examples for `HeaderMap`. [#1964] +- Vastly improve docs and add examples for `HeaderMap`. [#1964] [#1869]: https://github.com/actix/actix-web/pull/1869 [#1894]: https://github.com/actix/actix-web/pull/1894 @@ -64,24 +372,24 @@ ## 3.0.0-beta.1 - 2021-01-07 ### Added -* Add `Http3` to `Protocol` enum for future compatibility and also mark `#[non_exhaustive]`. +- Add `Http3` to `Protocol` enum for future compatibility and also mark `#[non_exhaustive]`. ### Changed -* Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] -* Bumped `rand` to `0.8`. -* Update `bytes` to `1.0`. [#1813] -* Update `h2` to `0.3`. [#1813] -* The `ws::Message::Text` enum variant now contains a `bytestring::ByteString`. [#1864] +- Update `actix-*` dependencies to tokio `1.0` based versions. [#1813] +- Bumped `rand` to `0.8`. +- Update `bytes` to `1.0`. [#1813] +- Update `h2` to `0.3`. [#1813] +- The `ws::Message::Text` enum variant now contains a `bytestring::ByteString`. [#1864] ### Removed -* Deprecated `on_connect` methods have been removed. Prefer the new +- Deprecated `on_connect` methods have been removed. Prefer the new `on_connect_ext` technique. [#1857] -* Remove `ResponseError` impl for `actix::actors::resolver::ResolverError` +- Remove `ResponseError` impl for `actix::actors::resolver::ResolverError` due to deprecate of resolver actor. [#1813] -* Remove `ConnectError::SslHandshakeError` and re-export of `HandshakeError`. +- Remove `ConnectError::SslHandshakeError` and re-export of `HandshakeError`. due to the removal of this type from `tokio-openssl` crate. openssl handshake error would return as `ConnectError::SslError`. [#1813] -* Remove `actix-threadpool` dependency. Use `actix_rt::task::spawn_blocking`. +- Remove `actix-threadpool` dependency. Use `actix_rt::task::spawn_blocking`. Due to this change `actix_threadpool::BlockingError` type is moved into `actix_http::error` module. [#1878] @@ -91,17 +399,22 @@ [#1878]: https://github.com/actix/actix-web/pull/1878 +## 2.2.1 - 2021-08-09 +### Fixed +- Potential HTTP request smuggling vulnerabilities. [RUSTSEC-2021-0081](https://github.com/rustsec/advisory-db/pull/977) + + ## 2.2.0 - 2020-11-25 ### Added -* HttpResponse builders for 1xx status codes. [#1768] -* `Accept::mime_precedence` and `Accept::mime_preference`. [#1793] -* `TryFrom` and `TryFrom` for `http::header::Quality`. [#1797] +- HttpResponse builders for 1xx status codes. [#1768] +- `Accept::mime_precedence` and `Accept::mime_preference`. [#1793] +- `TryFrom` and `TryFrom` for `http::header::Quality`. [#1797] ### Fixed -* Started dropping `transfer-encoding: chunked` and `Content-Length` for 1XX and 204 responses. [#1767] +- Started dropping `transfer-encoding: chunked` and `Content-Length` for 1XX and 204 responses. [#1767] ### Changed -* Upgrade `serde_urlencoded` to `0.7`. [#1773] +- Upgrade `serde_urlencoded` to `0.7`. [#1773] [#1773]: https://github.com/actix/actix-web/pull/1773 [#1767]: https://github.com/actix/actix-web/pull/1767 @@ -112,12 +425,12 @@ ## 2.1.0 - 2020-10-30 ### Added -* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754] +- Added more flexible `on_connect_ext` methods for on-connect handling. [#1754] ### Changed -* Upgrade `base64` to `0.13`. [#1744] -* Upgrade `pin-project` to `1.0`. [#1733] -* Deprecate `ResponseBuilder::{if_some, if_true}`. [#1760] +- Upgrade `base64` to `0.13`. [#1744] +- Upgrade `pin-project` to `1.0`. [#1733] +- Deprecate `ResponseBuilder::{if_some, if_true}`. [#1760] [#1760]: https://github.com/actix/actix-web/pull/1760 [#1754]: https://github.com/actix/actix-web/pull/1754 @@ -126,28 +439,28 @@ ## 2.0.0 - 2020-09-11 -* No significant changes from `2.0.0-beta.4`. +- No significant changes from `2.0.0-beta.4`. ## 2.0.0-beta.4 - 2020-09-09 ### Changed -* Update actix-codec and actix-utils dependencies. -* Update actix-connect and actix-tls dependencies. +- Update actix-codec and actix-utils dependencies. +- Update actix-connect and actix-tls dependencies. ## 2.0.0-beta.3 - 2020-08-14 ### Fixed -* Memory leak of `client::pool::ConnectorPoolSupport`. [#1626] +- Memory leak of `client::pool::ConnectorPoolSupport`. [#1626] [#1626]: https://github.com/actix/actix-web/pull/1626 ## 2.0.0-beta.2 - 2020-07-21 ### Fixed -* Potential UB in h1 decoder using uninitialized memory. [#1614] +- Potential UB in h1 decoder using uninitialized memory. [#1614] ### Changed -* Fix illegal chunked encoding. [#1615] +- Fix illegal chunked encoding. [#1615] [#1614]: https://github.com/actix/actix-web/pull/1614 [#1615]: https://github.com/actix/actix-web/pull/1615 @@ -155,10 +468,10 @@ ## 2.0.0-beta.1 - 2020-07-11 ### Changed -* Migrate cookie handling to `cookie` crate. [#1558] -* Update `sha-1` to 0.9. [#1586] -* Fix leak in client pool. [#1580] -* MSRV is now 1.41.1. +- Migrate cookie handling to `cookie` crate. [#1558] +- Update `sha-1` to 0.9. [#1586] +- Fix leak in client pool. [#1580] +- MSRV is now 1.41.1. [#1558]: https://github.com/actix/actix-web/pull/1558 [#1586]: https://github.com/actix/actix-web/pull/1586 @@ -167,15 +480,15 @@ ## 2.0.0-alpha.4 - 2020-05-21 ### Changed -* Bump minimum supported Rust version to 1.40 -* content_length function is removed, and you can set Content-Length by calling +- Bump minimum supported Rust version to 1.40 +- content_length function is removed, and you can set Content-Length by calling no_chunking function [#1439] -* `BodySize::Sized64` variant has been removed. `BodySize::Sized` now receives a +- `BodySize::Sized64` variant has been removed. `BodySize::Sized` now receives a `u64` instead of a `usize`. -* Update `base64` dependency to 0.12 +- Update `base64` dependency to 0.12 ### Fixed -* Support parsing of `SameSite=None` [#1503] +- Support parsing of `SameSite=None` [#1503] [#1439]: https://github.com/actix/actix-web/pull/1439 [#1503]: https://github.com/actix/actix-web/pull/1503 @@ -183,13 +496,13 @@ ## 2.0.0-alpha.3 - 2020-05-08 ### Fixed -* Correct spelling of ConnectError::Unresolved [#1487] -* Fix a mistake in the encoding of websocket continuation messages wherein +- Correct spelling of ConnectError::Unresolved [#1487] +- Fix a mistake in the encoding of websocket continuation messages wherein Item::FirstText and Item::FirstBinary are each encoded as the other. ### Changed -* Implement `std::error::Error` for our custom errors [#1422] -* Remove `failure` support for `ResponseError` since that crate +- Implement `std::error::Error` for our custom errors [#1422] +- Remove `failure` support for `ResponseError` since that crate will be deprecated in the near future. [#1422]: https://github.com/actix/actix-web/pull/1422 @@ -198,12 +511,12 @@ ## 2.0.0-alpha.2 - 2020-03-07 ### Changed -* Update `actix-connect` and `actix-tls` dependency to 2.0.0-alpha.1. [#1395] -* Change default initial window size and connection window size for HTTP2 to 2MB and 1MB +- Update `actix-connect` and `actix-tls` dependency to 2.0.0-alpha.1. [#1395] +- Change default initial window size and connection window size for HTTP2 to 2MB and 1MB respectively to improve download speed for awc when downloading large objects. [#1394] -* client::Connector accepts initial_window_size and initial_connection_window_size +- client::Connector accepts initial_window_size and initial_connection_window_size HTTP2 configuration. [#1394] -* client::Connector allowing to set max_http_version to limit HTTP version to be used. [#1394] +- client::Connector allowing to set max_http_version to limit HTTP version to be used. [#1394] [#1394]: https://github.com/actix/actix-web/pull/1394 [#1395]: https://github.com/actix/actix-web/pull/1395 @@ -211,61 +524,61 @@ ## 2.0.0-alpha.1 - 2020-02-27 ### Changed -* Update the `time` dependency to 0.2.7. -* Moved actors messages support from actix crate, enabled with feature `actors`. -* Breaking change: trait MessageBody requires Unpin and accepting `Pin<&mut Self>` instead of +- Update the `time` dependency to 0.2.7. +- Moved actors messages support from actix crate, enabled with feature `actors`. +- Breaking change: trait MessageBody requires Unpin and accepting `Pin<&mut Self>` instead of `&mut self` in the poll_next(). -* MessageBody is not implemented for &'static [u8] anymore. +- MessageBody is not implemented for &'static [u8] anymore. ### Fixed -* Allow `SameSite=None` cookies to be sent in a response. +- Allow `SameSite=None` cookies to be sent in a response. ## 1.0.1 - 2019-12-20 ### Fixed -* Poll upgrade service's readiness from HTTP service handlers -* Replace brotli with brotli2 #1224 +- Poll upgrade service's readiness from HTTP service handlers +- Replace brotli with brotli2 #1224 ## 1.0.0 - 2019-12-13 ### Added -* Add websockets continuation frame support +- Add websockets continuation frame support ### Changed -* Replace `flate2-xxx` features with `compress` +- Replace `flate2-xxx` features with `compress` ## 1.0.0-alpha.5 - 2019-12-09 ### Fixed -* Check `Upgrade` service readiness before calling it -* Fix buffer remaining capacity calculation +- Check `Upgrade` service readiness before calling it +- Fix buffer remaining capacity calculation ### Changed -* Websockets: Ping and Pong should have binary data #1049 +- Websockets: Ping and Pong should have binary data #1049 ## 1.0.0-alpha.4 - 2019-12-08 ### Added -* Add impl ResponseBuilder for Error +- Add impl ResponseBuilder for Error ### Changed -* Use rust based brotli compression library +- Use rust based brotli compression library ## 1.0.0-alpha.3 - 2019-12-07 ### Changed -* Migrate to tokio 0.2 -* Migrate to `std::future` +- Migrate to tokio 0.2 +- Migrate to `std::future` ## 0.2.11 - 2019-11-06 ### Added -* Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() -* Add an additional `filename*` param in the `Content-Disposition` header of +- Add support for serde_json::Value to be passed as argument to ResponseBuilder.body() +- Add an additional `filename*` param in the `Content-Disposition` header of `actix_files::NamedFile` to be more compatible. (#1151) -* Allow to use `std::convert::Infallible` as `actix_http::error::Error` +- Allow to use `std::convert::Infallible` as `actix_http::error::Error` ### Fixed -* To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; +- To be compatible with non-English error responses, `ResponseError` rendered with `text/plain; charset=utf-8` header [#1118] [#1878]: https://github.com/actix/actix-web/pull/1878 @@ -273,169 +586,169 @@ ## 0.2.10 - 2019-09-11 ### Added -* Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests +- Add support for sending HTTP requests with `Rc` in addition to sending HTTP requests with `RequestHead` ### Fixed -* h2 will use error response #1080 -* on_connect result isn't added to request extensions for http2 requests #1009 +- h2 will use error response #1080 +- on_connect result isn't added to request extensions for http2 requests #1009 ## 0.2.9 - 2019-08-13 ### Changed -* Dropped the `byteorder`-dependency in favor of `stdlib`-implementation -* Update percent-encoding to 2.1 -* Update serde_urlencoded to 0.6.1 +- Dropped the `byteorder`-dependency in favor of `stdlib`-implementation +- Update percent-encoding to 2.1 +- Update serde_urlencoded to 0.6.1 ### Fixed -* Fixed a panic in the HTTP2 handshake in client HTTP requests (#1031) +- Fixed a panic in the HTTP2 handshake in client HTTP requests (#1031) ## 0.2.8 - 2019-08-01 ### Added -* Add `rustls` support -* Add `Clone` impl for `HeaderMap` +- Add `rustls` support +- Add `Clone` impl for `HeaderMap` ### Fixed -* awc client panic #1016 -* Invalid response with compression middleware enabled, but compression-related features +- awc client panic #1016 +- Invalid response with compression middleware enabled, but compression-related features disabled #997 ## 0.2.7 - 2019-07-18 ### Added -* Add support for downcasting response errors #986 +- Add support for downcasting response errors #986 ## 0.2.6 - 2019-07-17 ### Changed -* Replace `ClonableService` with local copy -* Upgrade `rand` dependency version to 0.7 +- Replace `ClonableService` with local copy +- Upgrade `rand` dependency version to 0.7 ## 0.2.5 - 2019-06-28 ### Added -* Add `on-connect` callback, `HttpServiceBuilder::on_connect()` #946 +- Add `on-connect` callback, `HttpServiceBuilder::on_connect()` #946 ### Changed -* Use `encoding_rs` crate instead of unmaintained `encoding` crate -* Add `Copy` and `Clone` impls for `ws::Codec` +- Use `encoding_rs` crate instead of unmaintained `encoding` crate +- Add `Copy` and `Clone` impls for `ws::Codec` ## 0.2.4 - 2019-06-16 ### Fixed -* Do not compress NoContent (204) responses #918 +- Do not compress NoContent (204) responses #918 ## 0.2.3 - 2019-06-02 ### Added -* Debug impl for ResponseBuilder -* From SizedStream and BodyStream for Body +- Debug impl for ResponseBuilder +- From SizedStream and BodyStream for Body ### Changed -* SizedStream uses u64 +- SizedStream uses u64 ## 0.2.2 - 2019-05-29 ### Fixed -* Parse incoming stream before closing stream on disconnect #868 +- Parse incoming stream before closing stream on disconnect #868 ## 0.2.1 - 2019-05-25 ### Fixed -* Handle socket read disconnect +- Handle socket read disconnect ## 0.2.0 - 2019-05-12 ### Changed -* Update actix-service to 0.4 -* Expect and upgrade services accept `ServerConfig` config. +- Update actix-service to 0.4 +- Expect and upgrade services accept `ServerConfig` config. ### Deleted -* `OneRequest` service +- `OneRequest` service ## 0.1.5 - 2019-05-04 ### Fixed -* Clean up response extensions in response pool #817 +- Clean up response extensions in response pool #817 ## 0.1.4 - 2019-04-24 ### Added -* Allow to render h1 request headers in `Camel-Case` +- Allow to render h1 request headers in `Camel-Case` ### Fixed -* Read until eof for http/1.0 responses #771 +- Read until eof for http/1.0 responses #771 ## 0.1.3 - 2019-04-23 ### Fixed -* Fix http client pool management -* Fix http client wait queue management #794 +- Fix http client pool management +- Fix http client wait queue management #794 ## 0.1.2 - 2019-04-23 ### Fixed -* Fix BorrowMutError panic in client connector #793 +- Fix BorrowMutError panic in client connector #793 ## 0.1.1 - 2019-04-19 ### Changed -* Cookie::max_age() accepts value in seconds -* Cookie::max_age_time() accepts value in time::Duration -* Allow to specify server address for client connector +- Cookie::max_age() accepts value in seconds +- Cookie::max_age_time() accepts value in time::Duration +- Allow to specify server address for client connector ## 0.1.0 - 2019-04-16 ### Added -* Expose peer addr via `Request::peer_addr()` and `RequestHead::peer_addr` +- Expose peer addr via `Request::peer_addr()` and `RequestHead::peer_addr` ### Changed -* `actix_http::encoding` always available -* use trust-dns-resolver 0.11.0 +- `actix_http::encoding` always available +- use trust-dns-resolver 0.11.0 ## 0.1.0-alpha.5 - 2019-04-12 ### Added -* Allow to use custom service for upgrade requests -* Added `h1::SendResponse` future. +- Allow to use custom service for upgrade requests +- Added `h1::SendResponse` future. ### Changed -* MessageBody::length() renamed to MessageBody::size() for consistency -* ws handshake verification functions take RequestHead instead of Request +- MessageBody::length() renamed to MessageBody::size() for consistency +- ws handshake verification functions take RequestHead instead of Request ## 0.1.0-alpha.4 - 2019-04-08 ### Added -* Allow to use custom `Expect` handler -* Add minimal `std::error::Error` impl for `Error` +- Allow to use custom `Expect` handler +- Add minimal `std::error::Error` impl for `Error` ### Changed -* Export IntoHeaderValue -* Render error and return as response body -* Use thread pool for response body compression +- Export IntoHeaderValue +- Render error and return as response body +- Use thread pool for response body compression ### Deleted -* Removed PayloadBuffer +- Removed PayloadBuffer ## 0.1.0-alpha.3 - 2019-04-02 ### Added -* Warn when an unsealed private cookie isn't valid UTF-8 +- Warn when an unsealed private cookie isn't valid UTF-8 ### Fixed -* Rust 1.31.0 compatibility -* Preallocate read buffer for h1 codec -* Detect socket disconnection during protocol selection +- Rust 1.31.0 compatibility +- Preallocate read buffer for h1 codec +- Detect socket disconnection during protocol selection ## 0.1.0-alpha.2 - 2019-03-29 ### Added -* Added ws::Message::Nop, no-op websockets message +- Added ws::Message::Nop, no-op websockets message ### Changed -* Do not use thread pool for decompression if chunk size is smaller than 2048. +- Do not use thread pool for decompression if chunk size is smaller than 2048. ## 0.1.0-alpha.1 - 2019-03-28 -* Initial impl +- Initial impl diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 69a344d4a..b0aa199b9 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,22 +1,23 @@ [package] name = "actix-http" -version = "3.0.0-beta.3" +version = "3.0.0-beta.18" authors = ["Nikolay Kim "] description = "HTTP primitives for the Actix ecosystem" -readme = "README.md" keywords = ["actix", "http", "framework", "async", "futures"] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-http/" -categories = ["network-programming", "asynchronous", - "web-programming::http-server", - "web-programming::websocket"] +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] license = "MIT OR Apache-2.0" edition = "2018" [package.metadata.docs.rs] # features that docs.rs will build with -features = ["openssl", "rustls", "compress", "cookies", "secure-cookies"] +features = ["openssl", "rustls", "compress-brotli", "compress-gzip", "compress-zstd"] [lib] name = "actix_http" @@ -26,84 +27,80 @@ path = "src/lib.rs" default = [] # openssl -openssl = ["actix-tls/openssl"] +openssl = ["actix-tls/accept", "actix-tls/openssl"] # rustls support -rustls = ["actix-tls/rustls"] +rustls = ["actix-tls/accept", "actix-tls/rustls"] # enable compression support -compress = ["flate2", "brotli2"] +compress-brotli = ["brotli2", "__compress"] +compress-gzip = ["flate2", "__compress"] +compress-zstd = ["zstd", "__compress"] -# support for cookies -cookies = ["cookie"] - -# support for secure cookies -secure-cookies = ["cookies", "cookie/secure"] - -# trust-dns as client dns resolver -trust-dns = ["trust-dns-resolver"] +# Internal (PRIVATE!) features used to aid testing and cheking feature status. +# Don't rely on these whatsoever. They may disappear at anytime. +__compress = [] [dependencies] -actix-service = "2.0.0-beta.4" -actix-codec = "0.4.0-beta.1" -actix-utils = "3.0.0-beta.2" -actix-rt = "2" -actix-tls = "3.0.0-beta.2" +actix-service = "2.0.0" +actix-codec = "0.4.1" +actix-utils = "3.0.0" +actix-rt = { version = "2.2", default-features = false } ahash = "0.7" base64 = "0.13" bitflags = "1.2" bytes = "1" bytestring = "1" -cfg-if = "1" -cookie = { version = "0.14.1", features = ["percent-encode"], optional = true } derive_more = "0.99.5" encoding_rs = "0.8" -futures-channel = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } -futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] } -h2 = "0.3.0" -http = "0.2.2" -httparse = "1.3" -indexmap = "1.3" -itoa = "0.4" -language-tags = "0.2" -lazy_static = "1.4" +h2 = "0.3.9" +http = "0.2.5" +httparse = "1.5.1" +httpdate = "1.0.1" +itoa = "1" +language-tags = "0.3" +local-channel = "0.1" log = "0.4" mime = "0.3" percent-encoding = "2.1" -pin-project = "1.0.0" +pin-project-lite = "0.2" rand = "0.8" -regex = "1.3" -serde = "1.0" -serde_json = "1.0" -serde_urlencoded = "0.7" -sha-1 = "0.9" -slab = "0.4" -smallvec = "1.6" -time = { version = "0.2.23", default-features = false, features = ["std"] } +sha-1 = "0.10" +smallvec = "1.6.1" + +# tls +actix-tls = { version = "3.0.0", default-features = false, optional = true } # compression brotli2 = { version="0.3.2", optional = true } flate2 = { version = "1.0.13", optional = true } - -trust-dns-resolver = { version = "0.20.0", optional = true } +zstd = { version = "0.9", optional = true } [dev-dependencies] -actix-server = "2.0.0-beta.3" -actix-http-test = { version = "3.0.0-beta.2", features = ["openssl"] } -actix-tls = { version = "3.0.0-beta.2", features = ["openssl"] } -criterion = "0.3" -env_logger = "0.8" -rcgen = "0.8" -serde_derive = "1.0" -tls-openssl = { version = "0.10", package = "openssl" } -tls-rustls = { version = "0.19", package = "rustls" } +actix-http-test = { version = "3.0.0-beta.11", features = ["openssl"] } +actix-server = "2.0.0-rc.2" +actix-tls = { version = "3.0.0", features = ["openssl"] } +actix-web = "4.0.0-beta.19" -[target.'cfg(windows)'.dev-dependencies.tls-openssl] -version = "0.10.9" -package = "openssl" -features = ["vendored"] +async-stream = "0.3" +criterion = { version = "0.3", features = ["html_reports"] } +env_logger = "0.9" +futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] } +rcgen = "0.8" +regex = "1.3" +rustls-pemfile = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +static_assertions = "1" +tls-openssl = { package = "openssl", version = "0.10.9" } +tls-rustls = { package = "rustls", version = "0.20.0" } +tokio = { version = "1.8.4", features = ["net", "rt", "macros"] } + +[[example]] +name = "ws" +required-features = ["rustls"] [[bench]] name = "write-camel-case" @@ -116,3 +113,7 @@ harness = false [[bench]] name = "uninit-headers" harness = false + +[[bench]] +name = "quality-value" +harness = false diff --git a/actix-http/README.md b/actix-http/README.md index 881fbc8c5..9883cc3f0 100644 --- a/actix-http/README.md +++ b/actix-http/README.md @@ -3,19 +3,18 @@ > HTTP primitives for the Actix ecosystem. [![crates.io](https://img.shields.io/crates/v/actix-http?label=latest)](https://crates.io/crates/actix-http) -[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.3)](https://docs.rs/actix-http/3.0.0-beta.3) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-http/badge.svg?version=3.0.0-beta.18)](https://docs.rs/actix-http/3.0.0-beta.18) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-http.svg)
-[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.3/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.3) +[![dependency status](https://deps.rs/crate/actix-http/3.0.0-beta.18/status.svg)](https://deps.rs/crate/actix-http/3.0.0-beta.18) [![Download](https://img.shields.io/crates/d/actix-http.svg)](https://crates.io/crates/actix-http) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-http) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.54 ## Example @@ -55,8 +54,8 @@ async fn main() -> io::Result<()> { This project is licensed under either of -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)) +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)) +- MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)) at your option. diff --git a/actix-http/benches/quality-value.rs b/actix-http/benches/quality-value.rs new file mode 100644 index 000000000..31b67f999 --- /dev/null +++ b/actix-http/benches/quality-value.rs @@ -0,0 +1,90 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; + +const CODES: &[u16] = &[0, 1000, 201, 800, 550]; + +fn bench_quality_display_impls(c: &mut Criterion) { + let mut group = c.benchmark_group("quality value display impls"); + + for i in CODES.iter() { + group.bench_with_input(BenchmarkId::new("New (fast?)", i), i, |b, &i| { + b.iter(|| _new::Quality(i).to_string()) + }); + + group.bench_with_input(BenchmarkId::new("Naive", i), i, |b, &i| { + b.iter(|| _naive::Quality(i).to_string()) + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_quality_display_impls); +criterion_main!(benches); + +mod _new { + use std::fmt; + + pub struct Quality(pub(crate) u16); + + impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + 1000 => f.write_str("1"), + + // some number in the range 1–999 + x => { + f.write_str("0.")?; + + // this implementation avoids string allocation otherwise required + // for `.trim_end_matches('0')` + + if x < 10 { + f.write_str("00")?; + // 0 is handled so it's not possible to have a trailing 0, we can just return + itoa::fmt(f, x) + } else if x < 100 { + f.write_str("0")?; + if x % 10 == 0 { + // trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } else { + // x is in range 101–999 + + if x % 100 == 0 { + // two trailing 0s, divide by 100 and write + itoa::fmt(f, x / 100) + } else if x % 10 == 0 { + // one trailing 0, divide by 10 and write + itoa::fmt(f, x / 10) + } else { + itoa::fmt(f, x) + } + } + } + } + } + } +} + +mod _naive { + use std::fmt; + + pub struct Quality(pub(crate) u16); + + impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + 1000 => f.write_str("1"), + + x => { + write!(f, "{}", format!("{:03}", x).trim_end_matches('0')) + } + } + } + } +} diff --git a/actix-http/benches/status-line.rs b/actix-http/benches/status-line.rs index f62d18ed8..9fe099478 100644 --- a/actix-http/benches/status-line.rs +++ b/actix-http/benches/status-line.rs @@ -189,11 +189,7 @@ mod _original { n /= 100; curr -= 2; unsafe { - ptr::copy_nonoverlapping( - lut_ptr.offset(d1 as isize), - buf_ptr.offset(curr), - 2, - ); + ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); } // decode last 1 or 2 chars @@ -206,11 +202,7 @@ mod _original { let d1 = n << 1; curr -= 2; unsafe { - ptr::copy_nonoverlapping( - lut_ptr.offset(d1 as isize), - buf_ptr.offset(curr), - 2, - ); + ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); } } diff --git a/actix-http/benches/uninit-headers.rs b/actix-http/benches/uninit-headers.rs index 83e74171c..5dfd3bc11 100644 --- a/actix-http/benches/uninit-headers.rs +++ b/actix-http/benches/uninit-headers.rs @@ -54,15 +54,10 @@ const EMPTY_HEADER_INDEX: HeaderIndex = HeaderIndex { value: (0, 0), }; -const EMPTY_HEADER_INDEX_ARRAY: [HeaderIndex; MAX_HEADERS] = - [EMPTY_HEADER_INDEX; MAX_HEADERS]; +const EMPTY_HEADER_INDEX_ARRAY: [HeaderIndex; MAX_HEADERS] = [EMPTY_HEADER_INDEX; MAX_HEADERS]; impl HeaderIndex { - fn record( - bytes: &[u8], - headers: &[httparse::Header<'_>], - indices: &mut [HeaderIndex], - ) { + fn record(bytes: &[u8], headers: &[httparse::Header<'_>], indices: &mut [HeaderIndex]) { let bytes_ptr = bytes.as_ptr() as usize; for (header, indices) in headers.iter().zip(indices.iter_mut()) { let name_start = header.name.as_ptr() as usize - bytes_ptr; @@ -78,12 +73,12 @@ impl HeaderIndex { // test cases taken from: // https://github.com/seanmonstar/httparse/blob/master/benches/parse.rs -const REQ_SHORT: &'static [u8] = b"\ +const REQ_SHORT: &[u8] = b"\ GET / HTTP/1.0\r\n\ Host: example.com\r\n\ Cookie: session=60; user_id=1\r\n\r\n"; -const REQ: &'static [u8] = b"\ +const REQ: &[u8] = b"\ GET /wp-content/uploads/2010/03/hello-kitty-darth-vader-pink.jpg HTTP/1.1\r\n\ Host: www.kittyhell.com\r\n\ User-Agent: Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.6; ja-JP-mac; rv:1.9.2.3) Gecko/20100401 Firefox/3.6.3 Pathtraq/0.9\r\n\ @@ -119,6 +114,8 @@ mod _original { use std::mem::MaybeUninit; pub fn parse_headers(src: &mut BytesMut) -> usize { + #![allow(clippy::uninit_assumed_init)] + let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { MaybeUninit::uninit().assume_init() }; diff --git a/actix-http/benches/write-camel-case.rs b/actix-http/benches/write-camel-case.rs index fa4930eb9..ccf09b37e 100644 --- a/actix-http/benches/write-camel-case.rs +++ b/actix-http/benches/write-camel-case.rs @@ -18,7 +18,8 @@ fn bench_write_camel_case(c: &mut Criterion) { group.bench_with_input(BenchmarkId::new("New", i), bts, |b, bts| { b.iter(|| { let mut buf = black_box([0; 24]); - _new::write_camel_case(black_box(bts), &mut buf) + let len = black_box(bts.len()); + _new::write_camel_case(black_box(bts), buf.as_mut_ptr(), len) }); }); } @@ -30,9 +31,12 @@ criterion_group!(benches, bench_write_camel_case); criterion_main!(benches); mod _new { - pub fn write_camel_case(value: &[u8], buffer: &mut [u8]) { + pub fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) { // first copy entire (potentially wrong) slice to output - buffer[..value.len()].copy_from_slice(value); + let buffer = unsafe { + std::ptr::copy_nonoverlapping(value.as_ptr(), buf, len); + std::slice::from_raw_parts_mut(buf, len) + }; let mut iter = value.iter(); diff --git a/actix-http/examples/actix-web.rs b/actix-http/examples/actix-web.rs new file mode 100644 index 000000000..f8226507f --- /dev/null +++ b/actix-http/examples/actix-web.rs @@ -0,0 +1,26 @@ +use actix_http::HttpService; +use actix_server::Server; +use actix_service::map_config; +use actix_web::{dev::AppConfig, get, App}; + +#[get("/")] +async fn index() -> &'static str { + "Hello, world. From Actix Web!" +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> std::io::Result<()> { + Server::build() + .bind("hello-world", "127.0.0.1:8080", || { + // construct actix-web app + let app = App::new().service(index); + + HttpService::build() + // pass the app to service builder + // map_config is used to map App's configuration to ServiceBuilder + .finish(map_config(app, |_| AppConfig::default())) + .tcp() + })? + .run() + .await +} diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs index 90d768cbe..22f553f38 100644 --- a/actix-http/examples/echo.rs +++ b/actix-http/examples/echo.rs @@ -1,19 +1,17 @@ -use std::{env, io}; +use std::io; -use actix_http::{Error, HttpService, Request, Response}; +use actix_http::{Error, HttpService, Request, Response, StatusCode}; use actix_server::Server; use bytes::BytesMut; -use futures_util::StreamExt; +use futures_util::StreamExt as _; use http::header::HeaderValue; -use log::info; #[actix_rt::main] async fn main() -> io::Result<()> { - env::set_var("RUST_LOG", "echo=info"); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); Server::build() - .bind("echo", "127.0.0.1:8080", || { + .bind("echo", ("127.0.0.1", 8080), || { HttpService::build() .client_timeout(1000) .client_disconnect(1000) @@ -23,13 +21,11 @@ async fn main() -> io::Result<()> { body.extend_from_slice(&item?); } - info!("request body: {:?}", body); + log::info!("request body: {:?}", body); + Ok::<_, Error>( - Response::Ok() - .insert_header(( - "x-head", - HeaderValue::from_static("dummy value!"), - )) + Response::build(StatusCode::OK) + .insert_header(("x-head", HeaderValue::from_static("dummy value!"))) .body(body), ) }) diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index bc932ce8f..e3b915e05 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -1,31 +1,31 @@ -use std::{env, io}; +use std::io; -use actix_http::http::HeaderValue; -use actix_http::{Error, HttpService, Request, Response}; +use actix_http::{ + body::MessageBody, header::HeaderValue, Error, HttpService, Request, Response, StatusCode, +}; use actix_server::Server; use bytes::BytesMut; -use futures_util::StreamExt; -use log::info; +use futures_util::StreamExt as _; -async fn handle_request(mut req: Request) -> Result { +async fn handle_request(mut req: Request) -> Result, Error> { let mut body = BytesMut::new(); while let Some(item) = req.payload().next().await { body.extend_from_slice(&item?) } - info!("request body: {:?}", body); - Ok(Response::Ok() + log::info!("request body: {:?}", body); + + Ok(Response::build(StatusCode::OK) .insert_header(("x-head", HeaderValue::from_static("dummy value!"))) .body(body)) } #[actix_rt::main] async fn main() -> io::Result<()> { - env::set_var("RUST_LOG", "echo=info"); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); Server::build() - .bind("echo", "127.0.0.1:8080", || { + .bind("echo", ("127.0.0.1", 8080), || { HttpService::build().finish(handle_request).tcp() })? .run() diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs index a84e9aac6..a29903cc4 100644 --- a/actix-http/examples/hello-world.rs +++ b/actix-http/examples/hello-world.rs @@ -1,29 +1,35 @@ -use std::{env, io}; +use std::{convert::Infallible, io}; -use actix_http::{HttpService, Response}; +use actix_http::{ + header::HeaderValue, HttpMessage, HttpService, Request, Response, StatusCode, +}; use actix_server::Server; -use futures_util::future; -use http::header::HeaderValue; -use log::info; #[actix_rt::main] async fn main() -> io::Result<()> { - env::set_var("RUST_LOG", "hello_world=info"); - env_logger::init(); + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); Server::build() - .bind("hello-world", "127.0.0.1:8080", || { + .bind("hello-world", ("127.0.0.1", 8080), || { HttpService::build() .client_timeout(1000) .client_disconnect(1000) - .finish(|_req| { - info!("{:?}", _req); - let mut res = Response::Ok(); + .on_connect_ext(|_, ext| { + ext.insert(42u32); + }) + .finish(|req: Request| async move { + log::info!("{:?}", req); + + let mut res = Response::build(StatusCode::OK); + res.insert_header(("x-head", HeaderValue::from_static("dummy value!"))); + + let forty_two = req.extensions().get::().unwrap().to_string(); res.insert_header(( - "x-head", - HeaderValue::from_static("dummy value!"), + "x-forty-two", + HeaderValue::from_str(&forty_two).unwrap(), )); - future::ok::<_, ()>(res.body("Hello world!")) + + Ok::<_, Infallible>(res.body("Hello world!")) }) .tcp() })? diff --git a/actix-http/examples/streaming-error.rs b/actix-http/examples/streaming-error.rs new file mode 100644 index 000000000..3988cbac2 --- /dev/null +++ b/actix-http/examples/streaming-error.rs @@ -0,0 +1,40 @@ +//! Example showing response body (chunked) stream erroring. +//! +//! Test using `nc` or `curl`. +//! ```sh +//! $ curl -vN 127.0.0.1:8080 +//! $ echo 'GET / HTTP/1.1\n\n' | nc 127.0.0.1 8080 +//! ``` + +use std::{convert::Infallible, io, time::Duration}; + +use actix_http::{body::BodyStream, HttpService, Response}; +use actix_server::Server; +use async_stream::stream; +use bytes::Bytes; + +#[actix_rt::main] +async fn main() -> io::Result<()> { + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); + + Server::build() + .bind("streaming-error", ("127.0.0.1", 8080), || { + HttpService::build() + .finish(|req| async move { + log::info!("{:?}", req); + let res = Response::ok(); + + Ok::<_, Infallible>(res.set_body(BodyStream::new(stream! { + yield Ok(Bytes::from("123")); + yield Ok(Bytes::from("456")); + + actix_rt::time::sleep(Duration::from_millis(1000)).await; + + yield Err(io::Error::new(io::ErrorKind::Other, "")); + }))) + }) + .tcp() + })? + .run() + .await +} diff --git a/actix-http/examples/ws.rs b/actix-http/examples/ws.rs new file mode 100644 index 000000000..d70e43314 --- /dev/null +++ b/actix-http/examples/ws.rs @@ -0,0 +1,112 @@ +//! Sets up a WebSocket server over TCP and TLS. +//! Sends a heartbeat message every 4 seconds but does not respond to any incoming frames. + +extern crate tls_rustls as rustls; + +use std::{ + io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use actix_codec::Encoder; +use actix_http::{body::BodyStream, error::Error, ws, HttpService, Request, Response}; +use actix_rt::time::{interval, Interval}; +use actix_server::Server; +use bytes::{Bytes, BytesMut}; +use bytestring::ByteString; +use futures_core::{ready, Stream}; + +#[actix_rt::main] +async fn main() -> io::Result<()> { + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); + + Server::build() + .bind("tcp", ("127.0.0.1", 8080), || { + HttpService::build().h1(handler).tcp() + })? + .bind("tls", ("127.0.0.1", 8443), || { + HttpService::build().finish(handler).rustls(tls_config()) + })? + .run() + .await +} + +async fn handler(req: Request) -> Result>, Error> { + log::info!("handshaking"); + let mut res = ws::handshake(req.head())?; + + // handshake will always fail under HTTP/2 + + log::info!("responding"); + Ok(res.message_body(BodyStream::new(Heartbeat::new(ws::Codec::new())))?) +} + +struct Heartbeat { + codec: ws::Codec, + interval: Interval, +} + +impl Heartbeat { + fn new(codec: ws::Codec) -> Self { + Self { + codec, + interval: interval(Duration::from_secs(4)), + } + } +} + +impl Stream for Heartbeat { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + log::trace!("poll"); + + ready!(self.as_mut().interval.poll_tick(cx)); + + let mut buffer = BytesMut::new(); + + self.as_mut() + .codec + .encode( + ws::Message::Text(ByteString::from_static("hello world")), + &mut buffer, + ) + .unwrap(); + + Poll::Ready(Some(Ok(buffer.freeze()))) + } +} + +fn tls_config() -> rustls::ServerConfig { + use std::io::BufReader; + + use rustls::{Certificate, PrivateKey}; + use rustls_pemfile::{certs, pkcs8_private_keys}; + + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); + let cert_file = cert.serialize_pem().unwrap(); + let key_file = cert.serialize_private_key_pem(); + + let cert_file = &mut BufReader::new(cert_file.as_bytes()); + let key_file = &mut BufReader::new(key_file.as_bytes()); + + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); + let mut keys = pkcs8_private_keys(key_file).unwrap(); + + let mut config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap(); + + config.alpn_protocols.push(b"http/1.1".to_vec()); + config.alpn_protocols.push(b"h2".to_vec()); + + config +} diff --git a/actix-http/rustfmt.toml b/actix-http/rustfmt.toml deleted file mode 100644 index 5fcaaca0f..000000000 --- a/actix-http/rustfmt.toml +++ /dev/null @@ -1,5 +0,0 @@ -max_width = 89 -reorder_imports = true -#wrap_comments = true -#fn_args_density = "Compressed" -#use_small_heuristics = false diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs deleted file mode 100644 index 0dbe93a4a..000000000 --- a/actix-http/src/body.rs +++ /dev/null @@ -1,714 +0,0 @@ -//! Traits and structures to aid consuming and writing HTTP payloads. - -use std::{ - fmt, mem, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::{Bytes, BytesMut}; -use futures_core::{ready, Stream}; -use pin_project::pin_project; - -use crate::error::Error; - -/// Body size hint. -#[derive(Debug, PartialEq, Copy, Clone)] -pub enum BodySize { - None, - Empty, - Sized(u64), - Stream, -} - -impl BodySize { - pub fn is_eof(&self) -> bool { - matches!(self, BodySize::None | BodySize::Empty | BodySize::Sized(0)) - } -} - -/// Type that implement this trait can be streamed to a peer. -pub trait MessageBody { - fn size(&self) -> BodySize; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>>; - - downcast_get_type_id!(); -} - -downcast!(MessageBody); - -impl MessageBody for () { - fn size(&self) -> BodySize { - BodySize::Empty - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - Poll::Ready(None) - } -} - -impl MessageBody for Box { - fn size(&self) -> BodySize { - self.as_ref().size() - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(self.get_mut().as_mut()).poll_next(cx) - } -} - -#[pin_project(project = ResponseBodyProj)] -pub enum ResponseBody { - Body(#[pin] B), - Other(Body), -} - -impl ResponseBody { - pub fn into_body(self) -> ResponseBody { - match self { - ResponseBody::Body(b) => ResponseBody::Other(b), - ResponseBody::Other(b) => ResponseBody::Other(b), - } - } -} - -impl ResponseBody { - pub fn take_body(&mut self) -> ResponseBody { - mem::replace(self, ResponseBody::Other(Body::None)) - } -} - -impl ResponseBody { - pub fn as_ref(&self) -> Option<&B> { - if let ResponseBody::Body(ref b) = self { - Some(b) - } else { - None - } - } -} - -impl MessageBody for ResponseBody { - fn size(&self) -> BodySize { - match self { - ResponseBody::Body(ref body) => body.size(), - ResponseBody::Other(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.project() { - ResponseBodyProj::Body(body) => body.poll_next(cx), - ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), - } - } -} - -impl Stream for ResponseBody { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - ResponseBodyProj::Body(body) => body.poll_next(cx), - ResponseBodyProj::Other(body) => Pin::new(body).poll_next(cx), - } - } -} - -/// Represents various types of HTTP message body. -pub enum Body { - /// Empty response. `Content-Length` header is not set. - None, - /// Zero sized response body. `Content-Length` header is set to `0`. - Empty, - /// Specific response body. - Bytes(Bytes), - /// Generic message body. - Message(Box), -} - -impl Body { - /// Create body from slice (copy) - pub fn from_slice(s: &[u8]) -> Body { - Body::Bytes(Bytes::copy_from_slice(s)) - } - - /// Create body from generic message body. - pub fn from_message(body: B) -> Body { - Body::Message(Box::new(body)) - } -} - -impl MessageBody for Body { - fn size(&self) -> BodySize { - match self { - Body::None => BodySize::None, - Body::Empty => BodySize::Empty, - Body::Bytes(ref bin) => BodySize::Sized(bin.len() as u64), - Body::Message(ref body) => body.size(), - } - } - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.get_mut() { - Body::None => Poll::Ready(None), - Body::Empty => Poll::Ready(None), - Body::Bytes(ref mut bin) => { - let len = bin.len(); - if len == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(bin)))) - } - } - Body::Message(body) => Pin::new(&mut **body).poll_next(cx), - } - } -} - -impl PartialEq for Body { - fn eq(&self, other: &Body) -> bool { - match *self { - Body::None => matches!(*other, Body::None), - Body::Empty => matches!(*other, Body::Empty), - Body::Bytes(ref b) => match *other { - Body::Bytes(ref b2) => b == b2, - _ => false, - }, - Body::Message(_) => false, - } - } -} - -impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - Body::None => write!(f, "Body::None"), - Body::Empty => write!(f, "Body::Empty"), - Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b), - Body::Message(_) => write!(f, "Body::Message(_)"), - } - } -} - -impl From<&'static str> for Body { - fn from(s: &'static str) -> Body { - Body::Bytes(Bytes::from_static(s.as_ref())) - } -} - -impl From<&'static [u8]> for Body { - fn from(s: &'static [u8]) -> Body { - Body::Bytes(Bytes::from_static(s)) - } -} - -impl From> for Body { - fn from(vec: Vec) -> Body { - Body::Bytes(Bytes::from(vec)) - } -} - -impl From for Body { - fn from(s: String) -> Body { - s.into_bytes().into() - } -} - -impl<'a> From<&'a String> for Body { - fn from(s: &'a String) -> Body { - Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) - } -} - -impl From for Body { - fn from(s: Bytes) -> Body { - Body::Bytes(s) - } -} - -impl From for Body { - fn from(s: BytesMut) -> Body { - Body::Bytes(s.freeze()) - } -} - -impl From for Body { - fn from(v: serde_json::Value) -> Body { - Body::Bytes(v.to_string().into()) - } -} - -impl From> for Body -where - S: Stream> + Unpin + 'static, -{ - fn from(s: SizedStream) -> Body { - Body::from_message(s) - } -} - -impl From> for Body -where - S: Stream> + Unpin + 'static, - E: Into + 'static, -{ - fn from(s: BodyStream) -> Body { - Body::from_message(s) - } -} - -impl MessageBody for Bytes { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut())))) - } - } -} - -impl MessageBody for BytesMut { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze()))) - } - } -} - -impl MessageBody for &'static str { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from_static( - mem::take(self.get_mut()).as_ref(), - )))) - } - } -} - -impl MessageBody for Vec { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from(mem::take(self.get_mut()))))) - } - } -} - -impl MessageBody for String { - fn size(&self) -> BodySize { - BodySize::Sized(self.len() as u64) - } - - fn poll_next( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { - if self.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(Bytes::from( - mem::take(self.get_mut()).into_bytes(), - )))) - } - } -} - -/// Type represent streaming body. -/// Response does not contain `content-length` header and appropriate transfer encoding is used. -pub struct BodyStream { - stream: S, -} - -impl BodyStream -where - S: Stream> + Unpin, - E: Into, -{ - pub fn new(stream: S) -> Self { - BodyStream { stream } - } -} - -impl MessageBody for BodyStream -where - S: Stream> + Unpin, - E: Into, -{ - fn size(&self) -> BodySize { - BodySize::Stream - } - - /// Attempts to pull out the next value of the underlying [`Stream`]. - /// - /// Empty values are skipped to prevent [`BodyStream`]'s transmission being - /// ended on a zero-length chunk, but rather proceed until the underlying - /// [`Stream`] ends. - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - let stream = &mut self.as_mut().stream; - return Poll::Ready(match ready!(Pin::new(stream).poll_next(cx)) { - Some(Ok(ref bytes)) if bytes.is_empty() => continue, - opt => opt.map(|res| res.map_err(Into::into)), - }); - } - } -} - -/// Type represent streaming body. This body implementation should be used -/// if total size of stream is known. Data get sent as is without using transfer encoding. -pub struct SizedStream { - size: u64, - stream: S, -} - -impl SizedStream -where - S: Stream> + Unpin, -{ - pub fn new(size: u64, stream: S) -> Self { - SizedStream { size, stream } - } -} - -impl MessageBody for SizedStream -where - S: Stream> + Unpin, -{ - fn size(&self) -> BodySize { - BodySize::Sized(self.size as u64) - } - - /// Attempts to pull out the next value of the underlying [`Stream`]. - /// - /// Empty values are skipped to prevent [`SizedStream`]'s transmission being - /// ended on a zero-length chunk, but rather proceed until the underlying - /// [`Stream`] ends. - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - let stream = &mut self.as_mut().stream; - return Poll::Ready(match ready!(Pin::new(stream).poll_next(cx)) { - Some(Ok(ref bytes)) if bytes.is_empty() => continue, - val => val, - }); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures_util::future::poll_fn; - use futures_util::pin_mut; - use futures_util::stream; - - impl Body { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - Body::Bytes(ref bin) => &bin, - _ => panic!(), - } - } - } - - impl ResponseBody { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - ResponseBody::Body(ref b) => b.get_ref(), - ResponseBody::Other(ref b) => b.get_ref(), - } - } - } - - #[actix_rt::test] - async fn test_static_str() { - assert_eq!(Body::from("").size(), BodySize::Sized(0)); - assert_eq!(Body::from("test").size(), BodySize::Sized(4)); - assert_eq!(Body::from("test").get_ref(), b"test"); - - assert_eq!("test".size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| Pin::new(&mut "test").poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_static_bytes() { - assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); - assert_eq!( - Body::from_slice(b"test".as_ref()).size(), - BodySize::Sized(4) - ); - assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); - let sb = Bytes::from(&b"test"[..]); - pin_mut!(sb); - - assert_eq!(sb.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| sb.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_vec() { - assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); - assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); - let test_vec = Vec::from("test"); - pin_mut!(test_vec); - - assert_eq!(test_vec.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| test_vec.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes() { - let b = Bytes::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_bytes_mut() { - let b = BytesMut::from("test"); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_string() { - let b = "test".to_owned(); - assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); - assert_eq!(Body::from(b.clone()).get_ref(), b"test"); - assert_eq!(Body::from(&b).size(), BodySize::Sized(4)); - assert_eq!(Body::from(&b).get_ref(), b"test"); - pin_mut!(b); - - assert_eq!(b.size(), BodySize::Sized(4)); - assert_eq!( - poll_fn(|cx| b.as_mut().poll_next(cx)).await.unwrap().ok(), - Some(Bytes::from("test")) - ); - } - - #[actix_rt::test] - async fn test_unit() { - assert_eq!(().size(), BodySize::Empty); - assert!(poll_fn(|cx| Pin::new(&mut ()).poll_next(cx)) - .await - .is_none()); - } - - #[actix_rt::test] - async fn test_box() { - let val = Box::new(()); - pin_mut!(val); - assert_eq!(val.size(), BodySize::Empty); - assert!(poll_fn(|cx| val.as_mut().poll_next(cx)).await.is_none()); - } - - #[actix_rt::test] - async fn test_body_eq() { - assert!( - Body::Bytes(Bytes::from_static(b"1")) - == Body::Bytes(Bytes::from_static(b"1")) - ); - assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); - } - - #[actix_rt::test] - async fn test_body_debug() { - assert!(format!("{:?}", Body::None).contains("Body::None")); - assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); - assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1')); - } - - #[actix_rt::test] - async fn test_serde_json() { - use serde_json::json; - assert_eq!( - Body::from(serde_json::Value::String("test".into())).size(), - BodySize::Sized(6) - ); - assert_eq!( - Body::from(json!({"test-key":"test-value"})).size(), - BodySize::Sized(25) - ); - } - - mod body_stream { - use super::*; - //use futures::task::noop_waker; - //use futures::stream::once; - - #[actix_rt::test] - async fn skips_empty_chunks() { - let body = BodyStream::new(stream::iter( - ["1", "", "2"] - .iter() - .map(|&v| Ok(Bytes::from(v)) as Result), - )); - pin_mut!(body); - - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("1")), - ); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("2")), - ); - } - - /* Now it does not compile as it should - #[actix_rt::test] - async fn move_pinned_pointer() { - let (sender, receiver) = futures::channel::oneshot::channel(); - let mut body_stream = Ok(BodyStream::new(once(async { - let x = Box::new(0i32); - let y = &x; - receiver.await.unwrap(); - let _z = **y; - Ok::<_, ()>(Bytes::new()) - }))); - - let waker = noop_waker(); - let mut context = Context::from_waker(&waker); - pin_mut!(body_stream); - - let _ = body_stream.as_mut().unwrap().poll_next(&mut context); - sender.send(()).unwrap(); - let _ = std::mem::replace(&mut body_stream, Err([0; 32])).unwrap().poll_next(&mut context); - }*/ - } - - mod sized_stream { - use super::*; - - #[actix_rt::test] - async fn skips_empty_chunks() { - let body = SizedStream::new( - 2, - stream::iter(["1", "", "2"].iter().map(|&v| Ok(Bytes::from(v)))), - ); - pin_mut!(body); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("1")), - ); - assert_eq!( - poll_fn(|cx| body.as_mut().poll_next(cx)) - .await - .unwrap() - .ok(), - Some(Bytes::from("2")), - ); - } - } - - #[actix_rt::test] - async fn test_body_casting() { - let mut body = String::from("hello cast"); - let resp_body: &mut dyn MessageBody = &mut body; - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast"); - let body = &mut resp_body.downcast_mut::().unwrap(); - body.push('!'); - let body = resp_body.downcast_ref::().unwrap(); - assert_eq!(body, "hello cast!"); - let not_body = resp_body.downcast_ref::<()>(); - assert!(not_body.is_none()); - } -} diff --git a/actix-http/src/body/body_stream.rs b/actix-http/src/body/body_stream.rs new file mode 100644 index 000000000..cf4f488b2 --- /dev/null +++ b/actix-http/src/body/body_stream.rs @@ -0,0 +1,215 @@ +use std::{ + error::Error as StdError, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; + +use super::{BodySize, MessageBody}; + +pin_project! { + /// Streaming response wrapper. + /// + /// Response does not contain `Content-Length` header and appropriate transfer encoding is used. + pub struct BodyStream { + #[pin] + stream: S, + } +} + +// TODO: from_infallible method + +impl BodyStream +where + S: Stream>, + E: Into> + 'static, +{ + #[inline] + pub fn new(stream: S) -> Self { + BodyStream { stream } + } +} + +impl MessageBody for BodyStream +where + S: Stream>, + E: Into> + 'static, +{ + type Error = E; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Stream + } + + /// Attempts to pull out the next value of the underlying [`Stream`]. + /// + /// Empty values are skipped to prevent [`BodyStream`]'s transmission being + /// ended on a zero-length chunk, but rather proceed until the underlying + /// [`Stream`] ends. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + let stream = self.as_mut().project().stream; + + let chunk = match ready!(stream.poll_next(cx)) { + Some(Ok(ref bytes)) if bytes.is_empty() => continue, + opt => opt, + }; + + return Poll::Ready(chunk); + } + } +} + +#[cfg(test)] +mod tests { + use std::{convert::Infallible, time::Duration}; + + use actix_rt::{ + pin, + time::{sleep, Sleep}, + }; + use actix_utils::future::poll_fn; + use derive_more::{Display, Error}; + use futures_core::ready; + use futures_util::{stream, FutureExt as _}; + use pin_project_lite::pin_project; + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + use crate::body::to_bytes; + + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + assert_impl_all!(BodyStream>>: MessageBody); + + assert_not_impl_all!(BodyStream>: MessageBody); + assert_not_impl_all!(BodyStream>: MessageBody); + // crate::Error is not Clone + assert_not_impl_all!(BodyStream>>: MessageBody); + + #[actix_rt::test] + async fn skips_empty_chunks() { + let body = BodyStream::new(stream::iter( + ["1", "", "2"] + .iter() + .map(|&v| Ok::<_, Infallible>(Bytes::from(v))), + )); + pin!(body); + + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("1")), + ); + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("2")), + ); + } + + #[actix_rt::test] + async fn read_to_bytes() { + let body = BodyStream::new(stream::iter( + ["1", "", "2"] + .iter() + .map(|&v| Ok::<_, Infallible>(Bytes::from(v))), + )); + + assert_eq!(to_bytes(body).await.ok(), Some(Bytes::from("12"))); + } + #[derive(Debug, Display, Error)] + #[display(fmt = "stream error")] + struct StreamErr; + + #[actix_rt::test] + async fn stream_immediate_error() { + let body = BodyStream::new(stream::once(async { Err(StreamErr) })); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + } + + #[actix_rt::test] + async fn stream_string_error() { + // `&'static str` does not impl `Error` + // but it does impl `Into>` + + let body = BodyStream::new(stream::once(async { Err("stringy error") })); + assert!(matches!(to_bytes(body).await, Err("stringy error"))); + } + + #[actix_rt::test] + async fn stream_boxed_error() { + // `Box` does not impl `Error` + // but it does impl `Into>` + + let body = BodyStream::new(stream::once(async { + Err(Box::::from("stringy error")) + })); + + assert_eq!( + to_bytes(body).await.unwrap_err().to_string(), + "stringy error" + ); + } + + #[actix_rt::test] + async fn stream_delayed_error() { + let body = BodyStream::new(stream::iter(vec![Ok(Bytes::from("1")), Err(StreamErr)])); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + + pin_project! { + #[derive(Debug)] + #[project = TimeDelayStreamProj] + enum TimeDelayStream { + Start, + Sleep { delay: Pin> }, + Done, + } + } + + impl Stream for TimeDelayStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.as_mut().get_mut() { + TimeDelayStream::Start => { + let sleep = sleep(Duration::from_millis(1)); + self.as_mut().set(TimeDelayStream::Sleep { + delay: Box::pin(sleep), + }); + cx.waker().wake_by_ref(); + Poll::Pending + } + + TimeDelayStream::Sleep { ref mut delay } => { + ready!(delay.poll_unpin(cx)); + self.set(TimeDelayStream::Done); + cx.waker().wake_by_ref(); + Poll::Pending + } + + TimeDelayStream::Done => Poll::Ready(Some(Err(StreamErr))), + } + } + } + + let body = BodyStream::new(TimeDelayStream::Start); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + } +} diff --git a/actix-http/src/body/boxed.rs b/actix-http/src/body/boxed.rs new file mode 100644 index 000000000..d109a6a74 --- /dev/null +++ b/actix-http/src/body/boxed.rs @@ -0,0 +1,127 @@ +use std::{ + error::Error as StdError, + fmt, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use super::{BodySize, MessageBody, MessageBodyMapErr}; +use crate::body; + +/// A boxed message body with boxed errors. +#[derive(Debug)] +pub struct BoxBody(BoxBodyInner); + +enum BoxBodyInner { + None(body::None), + Bytes(Bytes), + Stream(Pin>>>), +} + +impl fmt::Debug for BoxBodyInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None(arg0) => f.debug_tuple("None").field(arg0).finish(), + Self::Bytes(arg0) => f.debug_tuple("Bytes").field(arg0).finish(), + Self::Stream(_) => f.debug_tuple("Stream").field(&"dyn MessageBody").finish(), + } + } +} + +impl BoxBody { + /// Same as `MessageBody::boxed`. + /// + /// If the body type to wrap is unknown or generic it is better to use [`MessageBody::boxed`] to + /// avoid double boxing. + #[inline] + pub fn new(body: B) -> Self + where + B: MessageBody + 'static, + { + match body.size() { + BodySize::None => Self(BoxBodyInner::None(body::None)), + _ => match body.try_into_bytes() { + Ok(bytes) => Self(BoxBodyInner::Bytes(bytes)), + Err(body) => { + let body = MessageBodyMapErr::new(body, Into::into); + Self(BoxBodyInner::Stream(Box::pin(body))) + } + }, + } + } + + /// Returns a mutable pinned reference to the inner message body type. + #[inline] + pub fn as_pin_mut(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } +} + +impl MessageBody for BoxBody { + type Error = Box; + + #[inline] + fn size(&self) -> BodySize { + match &self.0 { + BoxBodyInner::None(none) => none.size(), + BoxBodyInner::Bytes(bytes) => bytes.size(), + BoxBodyInner::Stream(stream) => stream.size(), + } + } + + #[inline] + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match &mut self.0 { + BoxBodyInner::None(body) => { + Pin::new(body).poll_next(cx).map_err(|err| match err {}) + } + BoxBodyInner::Bytes(body) => { + Pin::new(body).poll_next(cx).map_err(|err| match err {}) + } + BoxBodyInner::Stream(body) => Pin::new(body).poll_next(cx), + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + match self.0 { + BoxBodyInner::None(body) => Ok(body.try_into_bytes().unwrap()), + BoxBodyInner::Bytes(body) => Ok(body.try_into_bytes().unwrap()), + _ => Err(self), + } + } + + #[inline] + fn boxed(self) -> BoxBody { + self + } +} + +#[cfg(test)] +mod tests { + + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + use crate::body::to_bytes; + + assert_impl_all!(BoxBody: MessageBody, fmt::Debug, Unpin); + + assert_not_impl_all!(BoxBody: Send, Sync, Unpin); + + #[actix_rt::test] + async fn nested_boxed_body() { + let body = Bytes::from_static(&[1, 2, 3]); + let boxed_body = BoxBody::new(BoxBody::new(body)); + + assert_eq!( + to_bytes(boxed_body).await.unwrap(), + Bytes::from(vec![1, 2, 3]), + ); + } +} diff --git a/actix-http/src/body/either.rs b/actix-http/src/body/either.rs new file mode 100644 index 000000000..add1eab7c --- /dev/null +++ b/actix-http/src/body/either.rs @@ -0,0 +1,108 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use pin_project_lite::pin_project; + +use super::{BodySize, BoxBody, MessageBody}; +use crate::Error; + +pin_project! { + #[project = EitherBodyProj] + #[derive(Debug, Clone)] + pub enum EitherBody { + /// A body of type `L`. + Left { #[pin] body: L }, + + /// A body of type `R`. + Right { #[pin] body: R }, + } +} + +impl EitherBody { + /// Creates new `EitherBody` using left variant and boxed right variant. + #[inline] + pub fn new(body: L) -> Self { + Self::Left { body } + } +} + +impl EitherBody { + /// Creates new `EitherBody` using left variant. + #[inline] + pub fn left(body: L) -> Self { + Self::Left { body } + } + + /// Creates new `EitherBody` using right variant. + #[inline] + pub fn right(body: R) -> Self { + Self::Right { body } + } +} + +impl MessageBody for EitherBody +where + L: MessageBody + 'static, + R: MessageBody + 'static, +{ + type Error = Error; + + #[inline] + fn size(&self) -> BodySize { + match self { + EitherBody::Left { body } => body.size(), + EitherBody::Right { body } => body.size(), + } + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + EitherBodyProj::Left { body } => body + .poll_next(cx) + .map_err(|err| Error::new_body().with_cause(err)), + EitherBodyProj::Right { body } => body + .poll_next(cx) + .map_err(|err| Error::new_body().with_cause(err)), + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + match self { + EitherBody::Left { body } => body + .try_into_bytes() + .map_err(|body| EitherBody::Left { body }), + EitherBody::Right { body } => body + .try_into_bytes() + .map_err(|body| EitherBody::Right { body }), + } + } + + #[inline] + fn boxed(self) -> BoxBody { + match self { + EitherBody::Left { body } => body.boxed(), + EitherBody::Right { body } => body.boxed(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn type_parameter_inference() { + let _body: EitherBody<(), _> = EitherBody::new(()); + + let _body: EitherBody<_, ()> = EitherBody::left(()); + let _body: EitherBody<(), _> = EitherBody::right(()); + } +} diff --git a/actix-http/src/body/message_body.rs b/actix-http/src/body/message_body.rs new file mode 100644 index 000000000..0a605a69a --- /dev/null +++ b/actix-http/src/body/message_body.rs @@ -0,0 +1,554 @@ +//! [`MessageBody`] trait and foreign implementations. + +use std::{ + convert::Infallible, + error::Error as StdError, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures_core::ready; +use pin_project_lite::pin_project; + +use super::{BodySize, BoxBody}; + +/// An interface types that can converted to bytes and used as response bodies. +// TODO: examples +pub trait MessageBody { + /// The type of error that will be returned if streaming body fails. + /// + /// Since it is not appropriate to generate a response mid-stream, it only requires `Error` for + /// internal use and logging. + type Error: Into>; + + /// Body size hint. + /// + /// If [`BodySize::None`] is returned, optimizations that skip reading the body are allowed. + fn size(&self) -> BodySize; + + /// Attempt to pull out the next chunk of body bytes. + // TODO: expand documentation + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>; + + /// Try to convert into the complete chunk of body bytes. + /// + /// Implement this method if the entire body can be trivially extracted. This is useful for + /// optimizations where `poll_next` calls can be avoided. + /// + /// Body types with [`BodySize::None`] are allowed to return empty `Bytes`. Although, if calling + /// this method, it is recommended to check `size` first and return early. + /// + /// # Errors + /// The default implementation will error and return the original type back to the caller for + /// further use. + #[inline] + fn try_into_bytes(self) -> Result + where + Self: Sized, + { + Err(self) + } + + /// Converts this body into `BoxBody`. + #[inline] + fn boxed(self) -> BoxBody + where + Self: Sized + 'static, + { + BoxBody::new(self) + } +} + +mod foreign_impls { + use super::*; + + impl MessageBody for Infallible { + type Error = Infallible; + + fn size(&self) -> BodySize { + match *self {} + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + match *self {} + } + } + + impl MessageBody for () { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(0) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::new()) + } + } + + impl MessageBody for Box + where + B: MessageBody + Unpin + ?Sized, + { + type Error = B::Error; + + #[inline] + fn size(&self) -> BodySize { + self.as_ref().size() + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + Pin::new(self.get_mut().as_mut()).poll_next(cx) + } + } + + impl MessageBody for Pin> + where + B: MessageBody + ?Sized, + { + type Error = B::Error; + + #[inline] + fn size(&self) -> BodySize { + self.as_ref().size() + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.get_mut().as_mut().poll_next(cx) + } + } + + impl MessageBody for &'static [u8] { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(Bytes::from_static(mem::take(self.get_mut()))))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::from_static(self)) + } + } + + impl MessageBody for Bytes { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(self.get_mut())))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(self) + } + } + + impl MessageBody for BytesMut { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(self.get_mut()).freeze()))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(self.freeze()) + } + } + + impl MessageBody for Vec { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(self.get_mut()).into()))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::from(self)) + } + } + + impl MessageBody for &'static str { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let string = mem::take(self.get_mut()); + let bytes = Bytes::from_static(string.as_bytes()); + Poll::Ready(Some(Ok(bytes))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::from_static(self.as_bytes())) + } + } + + impl MessageBody for String { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.is_empty() { + Poll::Ready(None) + } else { + let string = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(Bytes::from(string)))) + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::from(self)) + } + } + + impl MessageBody for bytestring::ByteString { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.len() as u64) + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + let string = mem::take(self.get_mut()); + Poll::Ready(Some(Ok(string.into_bytes()))) + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(self.into_bytes()) + } + } +} + +pin_project! { + pub(crate) struct MessageBodyMapErr { + #[pin] + body: B, + mapper: Option, + } +} + +impl MessageBodyMapErr +where + B: MessageBody, + F: FnOnce(B::Error) -> E, +{ + pub(crate) fn new(body: B, mapper: F) -> Self { + Self { + body, + mapper: Some(mapper), + } + } +} + +impl MessageBody for MessageBodyMapErr +where + B: MessageBody, + F: FnOnce(B::Error) -> E, + E: Into>, +{ + type Error = E; + + #[inline] + fn size(&self) -> BodySize { + self.body.size() + } + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.as_mut().project(); + + match ready!(this.body.poll_next(cx)) { + Some(Err(err)) => { + let f = self.as_mut().project().mapper.take().unwrap(); + let mapped_err = (f)(err); + Poll::Ready(Some(Err(mapped_err))) + } + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + } + } + + #[inline] + fn try_into_bytes(self) -> Result { + let Self { body, mapper } = self; + body.try_into_bytes().map_err(|body| Self { body, mapper }) + } +} + +#[cfg(test)] +mod tests { + use actix_rt::pin; + use actix_utils::future::poll_fn; + use bytes::{Bytes, BytesMut}; + + use super::*; + use crate::body::{self, EitherBody}; + + macro_rules! assert_poll_next { + ($pin:expr, $exp:expr) => { + assert_eq!( + poll_fn(|cx| $pin.as_mut().poll_next(cx)) + .await + .unwrap() // unwrap option + .unwrap(), // unwrap result + $exp + ); + }; + } + + macro_rules! assert_poll_next_none { + ($pin:expr) => { + assert!(poll_fn(|cx| $pin.as_mut().poll_next(cx)).await.is_none()); + }; + } + + #[actix_rt::test] + async fn boxing_equivalence() { + assert_eq!(().size(), BodySize::Sized(0)); + assert_eq!(().size(), Box::new(()).size()); + assert_eq!(().size(), Box::pin(()).size()); + + let pl = Box::new(()); + pin!(pl); + assert_poll_next_none!(pl); + + let mut pl = Box::pin(()); + assert_poll_next_none!(pl); + } + + #[actix_rt::test] + async fn test_unit() { + let pl = (); + assert_eq!(pl.size(), BodySize::Sized(0)); + pin!(pl); + assert_poll_next_none!(pl); + } + + #[actix_rt::test] + async fn test_static_str() { + assert_eq!("".size(), BodySize::Sized(0)); + assert_eq!("test".size(), BodySize::Sized(4)); + + let pl = "test"; + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_static_bytes() { + assert_eq!(b"".as_ref().size(), BodySize::Sized(0)); + assert_eq!(b"test".as_ref().size(), BodySize::Sized(4)); + + let pl = b"test".as_ref(); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_vec() { + assert_eq!(vec![0; 0].size(), BodySize::Sized(0)); + assert_eq!(Vec::from("test").size(), BodySize::Sized(4)); + + let pl = Vec::from("test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_bytes() { + assert_eq!(Bytes::new().size(), BodySize::Sized(0)); + assert_eq!(Bytes::from_static(b"test").size(), BodySize::Sized(4)); + + let pl = Bytes::from_static(b"test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_bytes_mut() { + assert_eq!(BytesMut::new().size(), BodySize::Sized(0)); + assert_eq!(BytesMut::from(b"test".as_ref()).size(), BodySize::Sized(4)); + + let pl = BytesMut::from("test"); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn test_string() { + assert_eq!(String::new().size(), BodySize::Sized(0)); + assert_eq!("test".to_owned().size(), BodySize::Sized(4)); + + let pl = "test".to_owned(); + pin!(pl); + assert_poll_next!(pl, Bytes::from("test")); + } + + #[actix_rt::test] + async fn complete_body_combinators() { + let body = Bytes::from_static(b"test"); + let body = BoxBody::new(body); + let body = EitherBody::<_, ()>::left(body); + let body = EitherBody::<(), _>::right(body); + // Do not support try_into_bytes: + // let body = Box::new(body); + // let body = Box::pin(body); + + assert_eq!(body.try_into_bytes().unwrap(), Bytes::from("test")); + } + + #[actix_rt::test] + async fn complete_body_combinators_poll() { + let body = Bytes::from_static(b"test"); + let body = BoxBody::new(body); + let body = EitherBody::<_, ()>::left(body); + let body = EitherBody::<(), _>::right(body); + let mut body = body; + + assert_eq!(body.size(), BodySize::Sized(4)); + assert_poll_next!(Pin::new(&mut body), Bytes::from("test")); + assert_poll_next_none!(Pin::new(&mut body)); + } + + #[actix_rt::test] + async fn none_body_combinators() { + fn none_body() -> BoxBody { + let body = body::None; + let body = BoxBody::new(body); + let body = EitherBody::<_, ()>::left(body); + let body = EitherBody::<(), _>::right(body); + body.boxed() + } + + assert_eq!(none_body().size(), BodySize::None); + assert_eq!(none_body().try_into_bytes().unwrap(), Bytes::new()); + assert_poll_next_none!(Pin::new(&mut none_body())); + } + + // down-casting used to be done with a method on MessageBody trait + // test is kept to demonstrate equivalence of Any trait + #[actix_rt::test] + async fn test_body_casting() { + let mut body = String::from("hello cast"); + // let mut resp_body: &mut dyn MessageBody = &mut body; + let resp_body: &mut dyn std::any::Any = &mut body; + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast"); + let body = &mut resp_body.downcast_mut::().unwrap(); + body.push('!'); + let body = resp_body.downcast_ref::().unwrap(); + assert_eq!(body, "hello cast!"); + let not_body = resp_body.downcast_ref::<()>(); + assert!(not_body.is_none()); + } +} diff --git a/actix-http/src/body/mod.rs b/actix-http/src/body/mod.rs new file mode 100644 index 000000000..af7c4626f --- /dev/null +++ b/actix-http/src/body/mod.rs @@ -0,0 +1,20 @@ +//! Traits and structures to aid consuming and writing HTTP payloads. + +mod body_stream; +mod boxed; +mod either; +mod message_body; +mod none; +mod size; +mod sized_stream; +mod utils; + +pub use self::body_stream::BodyStream; +pub use self::boxed::BoxBody; +pub use self::either::EitherBody; +pub use self::message_body::MessageBody; +pub(crate) use self::message_body::MessageBodyMapErr; +pub use self::none::None; +pub use self::size::BodySize; +pub use self::sized_stream::SizedStream; +pub use self::utils::to_bytes; diff --git a/actix-http/src/body/none.rs b/actix-http/src/body/none.rs new file mode 100644 index 000000000..0e7bbe5a9 --- /dev/null +++ b/actix-http/src/body/none.rs @@ -0,0 +1,48 @@ +use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use super::{BodySize, MessageBody}; + +/// Body type for responses that forbid payloads. +/// +/// Distinct from an empty response which would contain a Content-Length header. +/// +/// For an "empty" body, use `()` or `Bytes::new()`. +#[derive(Debug, Clone, Copy, Default)] +#[non_exhaustive] +pub struct None; + +impl None { + /// Constructs new "none" body. + #[inline] + pub fn new() -> Self { + None + } +} + +impl MessageBody for None { + type Error = Infallible; + + #[inline] + fn size(&self) -> BodySize { + BodySize::None + } + + #[inline] + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(Option::None) + } + + #[inline] + fn try_into_bytes(self) -> Result { + Ok(Bytes::new()) + } +} diff --git a/actix-http/src/body/size.rs b/actix-http/src/body/size.rs new file mode 100644 index 000000000..ec7873ca5 --- /dev/null +++ b/actix-http/src/body/size.rs @@ -0,0 +1,41 @@ +/// Body size hint. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BodySize { + /// Implicitly empty body. + /// + /// Will omit the Content-Length header. Used for responses to certain methods (e.g., `HEAD`) or + /// with particular status codes (e.g., 204 No Content). Consumers that read this as a body size + /// hint are allowed to make optimizations that skip reading or writing the payload. + None, + + /// Known size body. + /// + /// Will write `Content-Length: N` header. + Sized(u64), + + /// Unknown size body. + /// + /// Will not write Content-Length header. Can be used with chunked Transfer-Encoding. + Stream, +} + +impl BodySize { + /// Equivalent to `BodySize::Sized(0)`; + pub const ZERO: Self = Self::Sized(0); + + /// Returns true if size hint indicates omitted or empty body. + /// + /// Streams will return false because it cannot be known without reading the stream. + /// + /// ``` + /// # use actix_http::body::BodySize; + /// assert!(BodySize::None.is_eof()); + /// assert!(BodySize::Sized(0).is_eof()); + /// + /// assert!(!BodySize::Sized(64).is_eof()); + /// assert!(!BodySize::Stream.is_eof()); + /// ``` + pub fn is_eof(&self) -> bool { + matches!(self, BodySize::None | BodySize::Sized(0)) + } +} diff --git a/actix-http/src/body/sized_stream.rs b/actix-http/src/body/sized_stream.rs new file mode 100644 index 000000000..9c1727246 --- /dev/null +++ b/actix-http/src/body/sized_stream.rs @@ -0,0 +1,171 @@ +use std::{ + error::Error as StdError, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; + +use super::{BodySize, MessageBody}; + +pin_project! { + /// Known sized streaming response wrapper. + /// + /// This body implementation should be used if total size of stream is known. Data is sent as-is + /// without using chunked transfer encoding. + pub struct SizedStream { + size: u64, + #[pin] + stream: S, + } +} + +impl SizedStream +where + S: Stream>, + E: Into> + 'static, +{ + #[inline] + pub fn new(size: u64, stream: S) -> Self { + SizedStream { size, stream } + } +} + +// TODO: from_infallible method + +impl MessageBody for SizedStream +where + S: Stream>, + E: Into> + 'static, +{ + type Error = E; + + #[inline] + fn size(&self) -> BodySize { + BodySize::Sized(self.size as u64) + } + + /// Attempts to pull out the next value of the underlying [`Stream`]. + /// + /// Empty values are skipped to prevent [`SizedStream`]'s transmission being + /// ended on a zero-length chunk, but rather proceed until the underlying + /// [`Stream`] ends. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + let stream = self.as_mut().project().stream; + + let chunk = match ready!(stream.poll_next(cx)) { + Some(Ok(ref bytes)) if bytes.is_empty() => continue, + val => val, + }; + + return Poll::Ready(chunk); + } + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use actix_rt::pin; + use actix_utils::future::poll_fn; + use futures_util::stream; + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + use crate::body::to_bytes; + + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + assert_impl_all!(SizedStream>>: MessageBody); + + assert_not_impl_all!(SizedStream>: MessageBody); + assert_not_impl_all!(SizedStream>: MessageBody); + // crate::Error is not Clone + assert_not_impl_all!(SizedStream>>: MessageBody); + + #[actix_rt::test] + async fn skips_empty_chunks() { + let body = SizedStream::new( + 2, + stream::iter( + ["1", "", "2"] + .iter() + .map(|&v| Ok::<_, Infallible>(Bytes::from(v))), + ), + ); + + pin!(body); + + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("1")), + ); + + assert_eq!( + poll_fn(|cx| body.as_mut().poll_next(cx)) + .await + .unwrap() + .ok(), + Some(Bytes::from("2")), + ); + } + + #[actix_rt::test] + async fn read_to_bytes() { + let body = SizedStream::new( + 2, + stream::iter( + ["1", "", "2"] + .iter() + .map(|&v| Ok::<_, Infallible>(Bytes::from(v))), + ), + ); + + assert_eq!(to_bytes(body).await.ok(), Some(Bytes::from("12"))); + } + + #[actix_rt::test] + async fn stream_string_error() { + // `&'static str` does not impl `Error` + // but it does impl `Into>` + + let body = SizedStream::new(0, stream::once(async { Err("stringy error") })); + assert_eq!(to_bytes(body).await, Ok(Bytes::new())); + + let body = SizedStream::new(1, stream::once(async { Err("stringy error") })); + assert!(matches!(to_bytes(body).await, Err("stringy error"))); + } + + #[actix_rt::test] + async fn stream_boxed_error() { + // `Box` does not impl `Error` + // but it does impl `Into>` + + let body = SizedStream::new( + 0, + stream::once(async { Err(Box::::from("stringy error")) }), + ); + assert_eq!(to_bytes(body).await.unwrap(), Bytes::new()); + + let body = SizedStream::new( + 1, + stream::once(async { Err(Box::::from("stringy error")) }), + ); + assert_eq!( + to_bytes(body).await.unwrap_err().to_string(), + "stringy error" + ); + } +} diff --git a/actix-http/src/body/utils.rs b/actix-http/src/body/utils.rs new file mode 100644 index 000000000..194af47f8 --- /dev/null +++ b/actix-http/src/body/utils.rs @@ -0,0 +1,77 @@ +use std::task::Poll; + +use actix_rt::pin; +use actix_utils::future::poll_fn; +use bytes::{Bytes, BytesMut}; +use futures_core::ready; + +use super::{BodySize, MessageBody}; + +/// Collects the body produced by a `MessageBody` implementation into `Bytes`. +/// +/// Any errors produced by the body stream are returned immediately. +/// +/// # Examples +/// ``` +/// use actix_http::body::{self, to_bytes}; +/// use bytes::Bytes; +/// +/// # async fn test_to_bytes() { +/// let body = body::None::new(); +/// let bytes = to_bytes(body).await.unwrap(); +/// assert!(bytes.is_empty()); +/// +/// let body = Bytes::from_static(b"123"); +/// let bytes = to_bytes(body).await.unwrap(); +/// assert_eq!(bytes, b"123"[..]); +/// # } +/// ``` +pub async fn to_bytes(body: B) -> Result { + let cap = match body.size() { + BodySize::None | BodySize::Sized(0) => return Ok(Bytes::new()), + BodySize::Sized(size) => size as usize, + // good enough first guess for chunk size + BodySize::Stream => 32_768, + }; + + let mut buf = BytesMut::with_capacity(cap); + + pin!(body); + + poll_fn(|cx| loop { + let body = body.as_mut(); + + match ready!(body.poll_next(cx)) { + Some(Ok(bytes)) => buf.extend_from_slice(&*bytes), + None => return Poll::Ready(Ok(())), + Some(Err(err)) => return Poll::Ready(Err(err)), + } + }) + .await?; + + Ok(buf.freeze()) +} + +#[cfg(test)] +mod test { + use futures_util::{stream, StreamExt as _}; + + use super::*; + use crate::{body::BodyStream, Error}; + + #[actix_rt::test] + async fn test_to_bytes() { + let bytes = to_bytes(()).await.unwrap(); + assert!(bytes.is_empty()); + + let body = Bytes::from_static(b"123"); + let bytes = to_bytes(body).await.unwrap(); + assert_eq!(bytes, b"123"[..]); + + let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")]) + .map(Ok::<_, Error>); + let body = BodyStream::new(stream); + let bytes = to_bytes(body).await.unwrap(); + assert_eq!(bytes, b"123abc"[..]); + } +} diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index fa430c4fe..408ee7924 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -1,19 +1,16 @@ -use std::marker::PhantomData; -use std::rc::Rc; -use std::{fmt, net}; +use std::{fmt, marker::PhantomData, net, rc::Rc}; use actix_codec::Framed; use actix_service::{IntoServiceFactory, Service, ServiceFactory}; -use crate::body::MessageBody; -use crate::config::{KeepAlive, ServiceConfig}; -use crate::error::Error; -use crate::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler}; -use crate::h2::H2Service; -use crate::request::Request; -use crate::response::Response; -use crate::service::HttpService; -use crate::{ConnectCallback, Extensions}; +use crate::{ + body::{BoxBody, MessageBody}, + config::{KeepAlive, ServiceConfig}, + h1::{self, ExpectHandler, H1Service, UpgradeHandler}, + h2::H2Service, + service::HttpService, + ConnectCallback, Extensions, Request, Response, +}; /// A HTTP service builder /// @@ -34,11 +31,12 @@ pub struct HttpServiceBuilder { impl HttpServiceBuilder where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, >::Future: 'static, { /// Create instance of `ServiceConfigBuilder` + #[allow(clippy::new_without_default)] pub fn new() -> Self { HttpServiceBuilder { keep_alive: KeepAlive::Timeout(5), @@ -57,17 +55,15 @@ where impl HttpServiceBuilder where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, >::Future: 'static, X: ServiceFactory, - X::Error: Into, + X::Error: Into>, X::InitError: fmt::Debug, - >::Future: 'static, - U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U::Error: fmt::Display, U::InitError: fmt::Debug, - )>>::Future: 'static, { /// Set server keep-alive setting. /// @@ -125,9 +121,8 @@ where where F: IntoServiceFactory, X1: ServiceFactory, - X1::Error: Into, + X1::Error: Into>, X1::InitError: fmt::Debug, - >::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, @@ -148,11 +143,10 @@ where /// and this service get called with original request and framed object. pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder where - F: IntoServiceFactory)>, - U1: ServiceFactory<(Request, Framed), Config = (), Response = ()>, + F: IntoServiceFactory)>, + U1: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, - )>>::Future: 'static, { HttpServiceBuilder { keep_alive: self.keep_alive, @@ -185,7 +179,7 @@ where where B: MessageBody, F: IntoServiceFactory, - S::Error: Into, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, { @@ -206,12 +200,12 @@ where /// Finish service configuration and create a HTTP service for HTTP/2 protocol. pub fn h2(self, service: F) -> H2Service where - B: MessageBody + 'static, F: IntoServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - >::Future: 'static, + + B: MessageBody + 'static, { let cfg = ServiceConfig::new( self.keep_alive, @@ -221,19 +215,18 @@ where self.local_addr, ); - H2Service::with_config(cfg, service.into_factory()) - .on_connect_ext(self.on_connect_ext) + H2Service::with_config(cfg, service.into_factory()).on_connect_ext(self.on_connect_ext) } /// Finish service configuration and create `HttpService` instance. pub fn finish(self, service: F) -> HttpService where - B: MessageBody + 'static, F: IntoServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, - >::Future: 'static, + + B: MessageBody + 'static, { let cfg = ServiceConfig::new( self.keep_alive, diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs deleted file mode 100644 index 4c6a6dcb8..000000000 --- a/actix-http/src/client/connection.rs +++ /dev/null @@ -1,348 +0,0 @@ -use std::future::Future; -use std::ops::{Deref, DerefMut}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, io, time}; - -use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; -use actix_rt::task::JoinHandle; -use bytes::Bytes; -use futures_core::future::LocalBoxFuture; -use futures_util::future::{err, Either, FutureExt, Ready}; -use h2::client::SendRequest; -use pin_project::pin_project; - -use crate::body::MessageBody; -use crate::h1::ClientCodec; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::Payload; - -use super::error::SendRequestError; -use super::pool::{Acquired, Protocol}; -use super::{h1proto, h2proto}; - -pub(crate) enum ConnectionType { - H1(Io), - H2(H2Connection), -} - -// h2 connection has two parts: SendRequest and Connection. -// Connection is spawned as async task on runtime and H2Connection would hold a handle for -// this task. So it can wake up and quit the task when SendRequest is dropped. -pub(crate) struct H2Connection { - handle: JoinHandle<()>, - sender: SendRequest, -} - -impl H2Connection { - pub(crate) fn new( - sender: SendRequest, - connection: h2::client::Connection, - ) -> Self - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - { - let handle = actix_rt::spawn(async move { - let _ = connection.await; - }); - - Self { handle, sender } - } -} - -// wake up waker when drop -impl Drop for H2Connection { - fn drop(&mut self) { - self.handle.abort(); - } -} - -// only expose sender type to public. -impl Deref for H2Connection { - type Target = SendRequest; - - fn deref(&self) -> &Self::Target { - &self.sender - } -} - -impl DerefMut for H2Connection { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.sender - } -} - -pub trait Connection { - type Io: AsyncRead + AsyncWrite + Unpin; - type Future: Future>; - - fn protocol(&self) -> Protocol; - - /// Send request and body - fn send_request>( - self, - head: H, - body: B, - ) -> Self::Future; - - type TunnelFuture: Future< - Output = Result<(ResponseHead, Framed), SendRequestError>, - >; - - /// Send request, returns Response and Framed - fn open_tunnel>(self, head: H) -> Self::TunnelFuture; -} - -pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { - /// Close connection - fn close(self: Pin<&mut Self>); - - /// Release connection to the connection pool - fn release(self: Pin<&mut Self>); -} - -#[doc(hidden)] -/// HTTP client connection -pub struct IoConnection { - io: Option>, - created: time::Instant, - pool: Option>, -} - -impl fmt::Debug for IoConnection -where - T: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.io { - Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), - Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), - None => write!(f, "Connection(Empty)"), - } - } -} - -impl IoConnection { - pub(crate) fn new( - io: ConnectionType, - created: time::Instant, - pool: Option>, - ) -> Self { - IoConnection { - pool, - created, - io: Some(io), - } - } - - pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { - (self.io.unwrap(), self.created) - } -} - -impl Connection for IoConnection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Io = T; - type Future = - LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; - - fn protocol(&self) -> Protocol { - match self.io { - Some(ConnectionType::H1(_)) => Protocol::Http1, - Some(ConnectionType::H2(_)) => Protocol::Http2, - None => Protocol::Http1, - } - } - - fn send_request>( - mut self, - head: H, - body: B, - ) -> Self::Future { - match self.io.take().unwrap() { - ConnectionType::H1(io) => { - h1proto::send_request(io, head.into(), body, self.created, self.pool) - .boxed_local() - } - ConnectionType::H2(io) => { - h2proto::send_request(io, head.into(), body, self.created, self.pool) - .boxed_local() - } - } - } - - type TunnelFuture = Either< - LocalBoxFuture< - 'static, - Result<(ResponseHead, Framed), SendRequestError>, - >, - Ready), SendRequestError>>, - >; - - /// Send request, returns Response and Framed - fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { - match self.io.take().unwrap() { - ConnectionType::H1(io) => { - Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) - } - ConnectionType::H2(io) => { - if let Some(mut pool) = self.pool.take() { - pool.release(IoConnection::new( - ConnectionType::H2(io), - self.created, - None, - )); - } - Either::Right(err(SendRequestError::TunnelNotSupported)) - } - } - } -} - -#[allow(dead_code)] -pub(crate) enum EitherConnection { - A(IoConnection), - B(IoConnection), -} - -impl Connection for EitherConnection -where - A: AsyncRead + AsyncWrite + Unpin + 'static, - B: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Io = EitherIo; - type Future = - LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; - - fn protocol(&self) -> Protocol { - match self { - EitherConnection::A(con) => con.protocol(), - EitherConnection::B(con) => con.protocol(), - } - } - - fn send_request>( - self, - head: H, - body: RB, - ) -> Self::Future { - match self { - EitherConnection::A(con) => con.send_request(head, body), - EitherConnection::B(con) => con.send_request(head, body), - } - } - - type TunnelFuture = LocalBoxFuture< - 'static, - Result<(ResponseHead, Framed), SendRequestError>, - >; - - /// Send request, returns Response and Framed - fn open_tunnel>(self, head: H) -> Self::TunnelFuture { - match self { - EitherConnection::A(con) => con - .open_tunnel(head) - .map(|res| { - res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::A))) - }) - .boxed_local(), - EitherConnection::B(con) => con - .open_tunnel(head) - .map(|res| { - res.map(|(head, framed)| (head, framed.into_map_io(EitherIo::B))) - }) - .boxed_local(), - } - } -} - -#[pin_project(project = EitherIoProj)] -pub enum EitherIo { - A(#[pin] A), - B(#[pin] B), -} - -impl AsyncRead for EitherIo -where - A: AsyncRead, - B: AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.project() { - EitherIoProj::A(val) => val.poll_read(cx, buf), - EitherIoProj::B(val) => val.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for EitherIo -where - A: AsyncWrite, - B: AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.project() { - EitherIoProj::A(val) => val.poll_write(cx, buf), - EitherIoProj::B(val) => val.poll_write(cx, buf), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - EitherIoProj::A(val) => val.poll_flush(cx), - EitherIoProj::B(val) => val.poll_flush(cx), - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.project() { - EitherIoProj::A(val) => val.poll_shutdown(cx), - EitherIoProj::B(val) => val.poll_shutdown(cx), - } - } -} - -#[cfg(test)] -mod test { - use std::net; - - use actix_rt::net::TcpStream; - - use super::*; - - #[actix_rt::test] - async fn test_h2_connection_drop() { - let addr = "127.0.0.1:0".parse::().unwrap(); - let listener = net::TcpListener::bind(addr).unwrap(); - let local = listener.local_addr().unwrap(); - - std::thread::spawn(move || while listener.accept().is_ok() {}); - - let tcp = TcpStream::connect(local).await.unwrap(); - let (sender, connection) = h2::client::handshake(tcp).await.unwrap(); - let conn = H2Connection::new(sender.clone(), connection); - - assert!(sender.clone().ready().await.is_ok()); - assert!(h2::client::SendRequest::clone(&*conn).ready().await.is_ok()); - - drop(conn); - - match sender.ready().await { - Ok(_) => panic!("connection should be gone and can not be ready"), - Err(e) => assert!(e.is_io()), - }; - } -} diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs deleted file mode 100644 index 3bf424d49..000000000 --- a/actix-http/src/client/connector.rs +++ /dev/null @@ -1,614 +0,0 @@ -use std::fmt; -use std::marker::PhantomData; -use std::time::Duration; - -use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::net::TcpStream; -use actix_service::{apply_fn, Service, ServiceExt}; -use actix_tls::connect::{ - new_connector, Connect as TcpConnect, Connection as TcpConnection, Resolver, -}; -use actix_utils::timeout::{TimeoutError, TimeoutService}; -use http::Uri; - -use super::config::ConnectorConfig; -use super::connection::Connection; -use super::error::ConnectError; -use super::pool::{ConnectionPool, Protocol}; -use super::Connect; - -#[cfg(feature = "openssl")] -use actix_tls::connect::ssl::openssl::SslConnector as OpensslConnector; -#[cfg(feature = "rustls")] -use actix_tls::connect::ssl::rustls::ClientConfig; -#[cfg(feature = "rustls")] -use std::sync::Arc; - -#[cfg(any(feature = "openssl", feature = "rustls"))] -enum SslConnector { - #[cfg(feature = "openssl")] - Openssl(OpensslConnector), - #[cfg(feature = "rustls")] - Rustls(Arc), -} -#[cfg(not(any(feature = "openssl", feature = "rustls")))] -type SslConnector = (); - -/// Manages HTTP client network connectivity. -/// -/// The `Connector` type uses a builder-like combinator pattern for service -/// construction that finishes by calling the `.finish()` method. -/// -/// ```rust,ignore -/// use std::time::Duration; -/// use actix_http::client::Connector; -/// -/// let connector = Connector::new() -/// .timeout(Duration::from_secs(5)) -/// .finish(); -/// ``` -pub struct Connector { - connector: T, - config: ConnectorConfig, - #[allow(dead_code)] - ssl: SslConnector, - _phantom: PhantomData, -} - -trait Io: AsyncRead + AsyncWrite + Unpin {} -impl Io for T {} - -impl Connector<(), ()> { - #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] - pub fn new() -> Connector< - impl Service< - TcpConnect, - Response = TcpConnection, - Error = actix_tls::connect::ConnectError, - > + Clone, - TcpStream, - > { - Connector { - ssl: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), - connector: new_connector(resolver::resolver()), - config: ConnectorConfig::default(), - _phantom: PhantomData, - } - } - - // Build Ssl connector with openssl, based on supplied alpn protocols - #[cfg(feature = "openssl")] - fn build_ssl(protocols: Vec>) -> SslConnector { - use actix_tls::connect::ssl::openssl::SslMethod; - use bytes::{BufMut, BytesMut}; - - let mut alpn = BytesMut::with_capacity(20); - for proto in protocols.iter() { - alpn.put_u8(proto.len() as u8); - alpn.put(proto.as_slice()); - } - - let mut ssl = OpensslConnector::builder(SslMethod::tls()).unwrap(); - let _ = ssl - .set_alpn_protos(&alpn) - .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); - SslConnector::Openssl(ssl.build()) - } - - // Build Ssl connector with rustls, based on supplied alpn protocols - #[cfg(all(not(feature = "openssl"), feature = "rustls"))] - fn build_ssl(protocols: Vec>) -> SslConnector { - let mut config = ClientConfig::new(); - config.set_protocols(&protocols); - config.root_store.add_server_trust_anchors( - &actix_tls::connect::ssl::rustls::TLS_SERVER_ROOTS, - ); - SslConnector::Rustls(Arc::new(config)) - } - - // ssl turned off, provides empty ssl connector - #[cfg(not(any(feature = "openssl", feature = "rustls")))] - fn build_ssl(_: Vec>) -> SslConnector {} -} - -impl Connector { - /// Use custom connector. - pub fn connector(self, connector: T1) -> Connector - where - U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, - T1: Service< - TcpConnect, - Response = TcpConnection, - Error = actix_tls::connect::ConnectError, - > + Clone, - { - Connector { - connector, - config: self.config, - ssl: self.ssl, - _phantom: PhantomData, - } - } -} - -impl Connector -where - U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, - T: Service< - TcpConnect, - Response = TcpConnection, - Error = actix_tls::connect::ConnectError, - > + Clone - + 'static, -{ - /// Connection timeout, i.e. max time to connect to remote host including dns name resolution. - /// Set to 1 second by default. - pub fn timeout(mut self, timeout: Duration) -> Self { - self.config.timeout = timeout; - self - } - - #[cfg(feature = "openssl")] - /// Use custom `SslConnector` instance. - pub fn ssl(mut self, connector: OpensslConnector) -> Self { - self.ssl = SslConnector::Openssl(connector); - self - } - - #[cfg(feature = "rustls")] - pub fn rustls(mut self, connector: Arc) -> Self { - self.ssl = SslConnector::Rustls(connector); - self - } - - /// Maximum supported HTTP major version. - /// - /// Supported versions are HTTP/1.1 and HTTP/2. - pub fn max_http_version(mut self, val: http::Version) -> Self { - let versions = match val { - http::Version::HTTP_11 => vec![b"http/1.1".to_vec()], - http::Version::HTTP_2 => vec![b"h2".to_vec(), b"http/1.1".to_vec()], - _ => { - unimplemented!("actix-http:client: supported versions http/1.1, http/2") - } - }; - self.ssl = Connector::build_ssl(versions); - self - } - - /// Indicates the initial window size (in octets) for - /// HTTP2 stream-level flow control for received data. - /// - /// The default value is 65,535 and is good for APIs, but not for big objects. - pub fn initial_window_size(mut self, size: u32) -> Self { - self.config.stream_window_size = size; - self - } - - /// Indicates the initial window size (in octets) for - /// HTTP2 connection-level flow control for received data. - /// - /// The default value is 65,535 and is good for APIs, but not for big objects. - pub fn initial_connection_window_size(mut self, size: u32) -> Self { - self.config.conn_window_size = size; - self - } - - /// Set total number of simultaneous connections per type of scheme. - /// - /// If limit is 0, the connector has no limit. - /// The default limit size is 100. - pub fn limit(mut self, limit: usize) -> Self { - self.config.limit = limit; - self - } - - /// Set keep-alive period for opened connection. - /// - /// Keep-alive period is the period between connection usage. If - /// the delay between repeated usages of the same connection - /// exceeds this period, the connection is closed. - /// Default keep-alive period is 15 seconds. - pub fn conn_keep_alive(mut self, dur: Duration) -> Self { - self.config.conn_keep_alive = dur; - self - } - - /// Set max lifetime period for connection. - /// - /// Connection lifetime is max lifetime of any opened connection - /// until it is closed regardless of keep-alive period. - /// Default lifetime period is 75 seconds. - pub fn conn_lifetime(mut self, dur: Duration) -> Self { - self.config.conn_lifetime = dur; - self - } - - /// Set server connection disconnect timeout in milliseconds. - /// - /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete - /// within this time, the socket get dropped. This timeout affects only secure connections. - /// - /// To disable timeout set value to 0. - /// - /// By default disconnect timeout is set to 3000 milliseconds. - pub fn disconnect_timeout(mut self, dur: Duration) -> Self { - self.config.disconnect_timeout = Some(dur); - self - } - - /// Finish configuration process and create connector service. - /// The Connector builder always concludes by calling `finish()` last in - /// its combinator chain. - pub fn finish( - self, - ) -> impl Service + Clone - { - #[cfg(not(any(feature = "openssl", feature = "rustls")))] - { - let connector = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); - - connect_impl::InnerConnector { - tcp_pool: ConnectionPool::new( - connector, - self.config.no_disconnect_timeout(), - ), - } - } - #[cfg(any(feature = "openssl", feature = "rustls"))] - { - const H2: &[u8] = b"h2"; - use actix_service::{boxed::service, pipeline}; - #[cfg(feature = "openssl")] - use actix_tls::connect::ssl::openssl::OpensslConnector; - #[cfg(feature = "rustls")] - use actix_tls::connect::ssl::rustls::{RustlsConnector, Session}; - - let ssl_service = TimeoutService::new( - self.config.timeout, - pipeline( - apply_fn(self.connector.clone(), |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from), - ) - .and_then(match self.ssl { - #[cfg(feature = "openssl")] - SslConnector::Openssl(ssl) => service( - OpensslConnector::service(ssl) - .map(|stream| { - let sock = stream.into_parts().0; - let h2 = sock - .ssl() - .selected_alpn_protocol() - .map(|protos| protos.windows(2).any(|w| w == H2)) - .unwrap_or(false); - if h2 { - (Box::new(sock) as Box, Protocol::Http2) - } else { - (Box::new(sock) as Box, Protocol::Http1) - } - }) - .map_err(ConnectError::from), - ), - #[cfg(feature = "rustls")] - SslConnector::Rustls(ssl) => service( - RustlsConnector::service(ssl) - .map_err(ConnectError::from) - .map(|stream| { - let sock = stream.into_parts().0; - let h2 = sock - .get_ref() - .1 - .get_alpn_protocol() - .map(|protos| protos.windows(2).any(|w| w == H2)) - .unwrap_or(false); - if h2 { - (Box::new(sock) as Box, Protocol::Http2) - } else { - (Box::new(sock) as Box, Protocol::Http1) - } - }), - ), - }), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); - - let tcp_service = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); - - connect_impl::InnerConnector { - tcp_pool: ConnectionPool::new( - tcp_service, - self.config.no_disconnect_timeout(), - ), - ssl_pool: ConnectionPool::new(ssl_service, self.config), - } - } - } -} - -#[cfg(not(any(feature = "openssl", feature = "rustls")))] -mod connect_impl { - use std::task::{Context, Poll}; - - use futures_util::future::{err, Either, Ready}; - - use super::*; - use crate::client::connection::IoConnection; - - pub(crate) struct InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - pub(crate) tcp_pool: ConnectionPool, - } - - impl Clone for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), - } - } - } - - impl Service for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - type Response = IoConnection; - type Error = ConnectError; - type Future = Either< - as Service>::Future, - Ready, ConnectError>>, - >; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) - } - - fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Either::Right(err(ConnectError::SslIsNotSupported)) - } - _ => Either::Left(self.tcp_pool.call(req)), - } - } - } -} - -#[cfg(any(feature = "openssl", feature = "rustls"))] -mod connect_impl { - use std::future::Future; - use std::marker::PhantomData; - use std::pin::Pin; - use std::task::{Context, Poll}; - - use futures_core::ready; - use futures_util::future::Either; - - use super::*; - use crate::client::connection::EitherConnection; - - pub(crate) struct InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service, - T2: Service, - { - pub(crate) tcp_pool: ConnectionPool, - pub(crate) ssl_pool: ConnectionPool, - } - - impl Clone for InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service + 'static, - T2: Service + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), - ssl_pool: self.ssl_pool.clone(), - } - } - } - - impl Service for InnerConnector - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T1: Service + 'static, - T2: Service + 'static, - { - type Response = EitherConnection; - type Error = ConnectError; - type Future = Either< - InnerConnectorResponseA, - InnerConnectorResponseB, - >; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) - } - - fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => Either::Right(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), - _phantom: PhantomData, - }), - _ => Either::Left(InnerConnectorResponseA { - fut: self.tcp_pool.call(req), - _phantom: PhantomData, - }), - } - } - } - - #[pin_project::pin_project] - pub(crate) struct InnerConnectorResponseA - where - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - #[pin] - fut: as Service>::Future, - _phantom: PhantomData, - } - - impl Future for InnerConnectorResponseA - where - T: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready( - ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) - .map(EitherConnection::A), - ) - } - } - - #[pin_project::pin_project] - pub(crate) struct InnerConnectorResponseB - where - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - #[pin] - fut: as Service>::Future, - _phantom: PhantomData, - } - - impl Future for InnerConnectorResponseB - where - T: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready( - ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) - .map(EitherConnection::B), - ) - } - } -} - -#[cfg(not(feature = "trust-dns"))] -mod resolver { - use super::*; - - pub(super) fn resolver() -> Resolver { - Resolver::Default - } -} - -#[cfg(feature = "trust-dns")] -mod resolver { - use std::{cell::RefCell, net::SocketAddr}; - - use actix_tls::connect::Resolve; - use futures_core::future::LocalBoxFuture; - use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - system_conf::read_system_conf, - TokioAsyncResolver, - }; - - use super::*; - - pub(super) fn resolver() -> Resolver { - // new type for impl Resolve trait for TokioAsyncResolver. - struct TrustDnsResolver(TokioAsyncResolver); - - impl Resolve for TrustDnsResolver { - fn lookup<'a>( - &'a self, - host: &'a str, - port: u16, - ) -> LocalBoxFuture<'a, Result, Box>> - { - Box::pin(async move { - let res = self - .0 - .lookup_ip(host) - .await? - .iter() - .map(|ip| SocketAddr::new(ip, port)) - .collect(); - Ok(res) - }) - } - } - - // dns struct is cached in thread local. - // so new client constructor can reuse the existing dns resolver. - thread_local! { - static TRUST_DNS_RESOLVER: RefCell> = RefCell::new(None); - } - - // get from thread local or construct a new trust-dns resolver. - TRUST_DNS_RESOLVER.with(|local| { - let resolver = local.borrow().as_ref().map(Clone::clone); - match resolver { - Some(resolver) => resolver, - None => { - let (cfg, opts) = match read_system_conf() { - Ok((cfg, opts)) => (cfg, opts), - Err(e) => { - log::error!("TRust-DNS can not load system config: {}", e); - (ResolverConfig::default(), ResolverOpts::default()) - } - }; - - let resolver = TokioAsyncResolver::tokio(cfg, opts).unwrap(); - - // box trust dns resolver and put it in thread local. - let resolver = Resolver::new_custom(TrustDnsResolver(resolver)); - *local.borrow_mut() = Some(resolver.clone()); - resolver - } - } - }) - } -} diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs deleted file mode 100644 index 92c3c0e1b..000000000 --- a/actix-http/src/client/h1proto.rs +++ /dev/null @@ -1,291 +0,0 @@ -use std::io::Write; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{io, time}; - -use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; -use bytes::buf::BufMut; -use bytes::{Bytes, BytesMut}; -use futures_core::Stream; -use futures_util::future::poll_fn; -use futures_util::{pin_mut, SinkExt, StreamExt}; - -use crate::error::PayloadError; -use crate::h1; -use crate::header::HeaderMap; -use crate::http::header::{IntoHeaderValue, HOST}; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::{Payload, PayloadStream}; - -use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; -use super::error::{ConnectError, SendRequestError}; -use super::pool::Acquired; -use crate::body::{BodySize, MessageBody}; - -pub(crate) async fn send_request( - io: T, - mut head: RequestHeadType, - body: B, - created: time::Instant, - pool: Option>, -) -> Result<(ResponseHead, Payload), SendRequestError> -where - T: AsyncRead + AsyncWrite + Unpin + 'static, - B: MessageBody, -{ - // set request host header - if !head.as_ref().headers.contains_key(HOST) - && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) - { - if let Some(host) = head.as_ref().uri.host() { - let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); - - let _ = match head.as_ref().uri.port_u16() { - None | Some(80) | Some(443) => write!(wrt, "{}", host), - Some(port) => write!(wrt, "{}:{}", host, port), - }; - - match wrt.get_mut().split().freeze().try_into_value() { - Ok(value) => match head { - RequestHeadType::Owned(ref mut head) => { - head.headers.insert(HOST, value); - } - RequestHeadType::Rc(_, ref mut extra_headers) => { - let headers = extra_headers.get_or_insert(HeaderMap::new()); - headers.insert(HOST, value); - } - }, - Err(e) => log::error!("Can not set HOST header {}", e), - } - } - } - - let io = H1Connection { - created, - pool, - io: Some(io), - }; - - // create Framed and send request - let mut framed_inner = Framed::new(io, h1::ClientCodec::default()); - framed_inner.send((head, body.size()).into()).await?; - - // send request body - match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => {} - _ => send_body(body, Pin::new(&mut framed_inner)).await?, - }; - - // read response and init read body - let res = Pin::new(&mut framed_inner).into_future().await; - let (head, framed) = if let (Some(result), framed) = res { - let item = result.map_err(SendRequestError::from)?; - (item, framed) - } else { - return Err(SendRequestError::from(ConnectError::Disconnected)); - }; - - match framed.codec_ref().message_type() { - h1::MessageType::None => { - let force_close = !framed.codec_ref().keepalive(); - release_connection(framed, force_close); - Ok((head, Payload::None)) - } - _ => { - let pl: PayloadStream = PlStream::new(framed_inner).boxed_local(); - Ok((head, pl.into())) - } - } -} - -pub(crate) async fn open_tunnel( - io: T, - head: RequestHeadType, -) -> Result<(ResponseHead, Framed), SendRequestError> -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - // create Framed and send request - let mut framed = Framed::new(io, h1::ClientCodec::default()); - framed.send((head, BodySize::None).into()).await?; - - // read response - if let (Some(result), framed) = framed.into_future().await { - let head = result.map_err(SendRequestError::from)?; - Ok((head, framed)) - } else { - Err(SendRequestError::from(ConnectError::Disconnected)) - } -} - -/// send request body to the peer -pub(crate) async fn send_body( - body: B, - mut framed: Pin<&mut Framed>, -) -> Result<(), SendRequestError> -where - T: ConnectionLifetime + Unpin, - B: MessageBody, -{ - pin_mut!(body); - - let mut eof = false; - while !eof { - while !eof && !framed.as_ref().is_write_buf_full() { - match poll_fn(|cx| body.as_mut().poll_next(cx)).await { - Some(result) => { - framed.as_mut().write(h1::Message::Chunk(Some(result?)))?; - } - None => { - eof = true; - framed.as_mut().write(h1::Message::Chunk(None))?; - } - } - } - - if !framed.as_ref().is_write_buf_empty() { - poll_fn(|cx| match framed.as_mut().flush(cx) { - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => { - if !framed.as_ref().is_write_buf_full() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - }) - .await?; - } - } - - SinkExt::flush(Pin::into_inner(framed)).await?; - Ok(()) -} - -#[doc(hidden)] -/// HTTP client connection -pub struct H1Connection { - /// T should be `Unpin` - io: Option, - created: time::Instant, - pool: Option>, -} - -impl ConnectionLifetime for H1Connection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - /// Close connection - fn close(mut self: Pin<&mut Self>) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.close(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } - } - } - - /// Release this connection to the connection pool - fn release(mut self: Pin<&mut Self>) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.release(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } - } - } -} - -impl AsyncRead for H1Connection { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) - } -} - -impl AsyncWrite for H1Connection { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) - } -} - -#[pin_project::pin_project] -pub(crate) struct PlStream { - #[pin] - framed: Option>, -} - -impl PlStream { - fn new(framed: Framed) -> Self { - let framed = framed.into_map_codec(|codec| codec.into_payload_codec()); - - PlStream { - framed: Some(framed), - } - } -} - -impl Stream for PlStream { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let mut this = self.project(); - - match this.framed.as_mut().as_pin_mut().unwrap().next_item(cx)? { - Poll::Pending => Poll::Pending, - Poll::Ready(Some(chunk)) => { - if let Some(chunk) = chunk { - Poll::Ready(Some(Ok(chunk))) - } else { - let framed = this.framed.as_mut().as_pin_mut().unwrap(); - let force_close = !framed.codec_ref().keepalive(); - release_connection(framed, force_close); - Poll::Ready(None) - } - } - Poll::Ready(None) => Poll::Ready(None), - } - } -} - -fn release_connection(framed: Pin<&mut Framed>, force_close: bool) -where - T: ConnectionLifetime, -{ - if !force_close && framed.is_read_buf_empty() && framed.is_write_buf_empty() { - framed.io_pin().release() - } else { - framed.io_pin().close() - } -} diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs deleted file mode 100644 index 9c7f632ea..000000000 --- a/actix-http/src/client/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -//! HTTP client. - -use http::Uri; - -mod config; -mod connection; -mod connector; -mod error; -mod h1proto; -mod h2proto; -mod pool; - -pub use self::connection::Connection; -pub use self::connector::Connector; -pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; -pub use self::pool::Protocol; - -#[derive(Clone)] -pub struct Connect { - pub uri: Uri, - pub addr: Option, -} diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs deleted file mode 100644 index 867ba5c0c..000000000 --- a/actix-http/src/client/pool.rs +++ /dev/null @@ -1,645 +0,0 @@ -use std::cell::RefCell; -use std::collections::VecDeque; -use std::future::Future; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; - -use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; -use actix_rt::time::{sleep, Sleep}; -use actix_service::Service; -use actix_utils::task::LocalWaker; -use ahash::AHashMap; -use bytes::Bytes; -use futures_channel::oneshot; -use futures_core::future::LocalBoxFuture; -use futures_util::future::{poll_fn, FutureExt}; -use h2::client::{Connection, SendRequest}; -use http::uri::Authority; -use indexmap::IndexSet; -use pin_project::pin_project; -use slab::Slab; - -use super::config::ConnectorConfig; -use super::connection::{ConnectionType, IoConnection}; -use super::error::ConnectError; -use super::h2proto::handshake; -use super::Connect; -use crate::client::connection::H2Connection; - -#[derive(Clone, Copy, PartialEq)] -/// Protocol version -pub enum Protocol { - Http1, - Http2, -} - -#[derive(Hash, Eq, PartialEq, Clone, Debug)] -pub(crate) struct Key { - authority: Authority, -} - -impl From for Key { - fn from(authority: Authority) -> Key { - Key { authority } - } -} - -/// Connections pool -pub(crate) struct ConnectionPool(Rc, Rc>>); - -impl ConnectionPool -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, -{ - pub(crate) fn new(connector: T, config: ConnectorConfig) -> Self { - let connector_rc = Rc::new(connector); - let inner_rc = Rc::new(RefCell::new(Inner { - config, - acquired: 0, - waiters: Slab::new(), - waiters_queue: IndexSet::new(), - available: AHashMap::default(), - waker: LocalWaker::new(), - })); - - // start support future - actix_rt::spawn(ConnectorPoolSupport { - connector: Rc::clone(&connector_rc), - inner: Rc::clone(&inner_rc), - }); - - ConnectionPool(connector_rc, inner_rc) - } -} - -impl Clone for ConnectionPool -where - Io: 'static, -{ - fn clone(&self) -> Self { - ConnectionPool(self.0.clone(), self.1.clone()) - } -} - -impl Drop for ConnectionPool { - fn drop(&mut self) { - // wake up the ConnectorPoolSupport when dropping so it can exit properly. - self.1.borrow().waker.wake(); - } -} - -impl Service for ConnectionPool -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, -{ - type Response = IoConnection; - type Error = ConnectError; - type Future = LocalBoxFuture<'static, Result, ConnectError>>; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(cx) - } - - fn call(&self, req: Connect) -> Self::Future { - let connector = self.0.clone(); - let inner = self.1.clone(); - - let fut = async move { - let key = if let Some(authority) = req.uri.authority() { - authority.clone().into() - } else { - return Err(ConnectError::Unresolved); - }; - - // acquire connection - match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { - Acquire::Acquired(io, created) => { - // use existing connection - Ok(IoConnection::new( - io, - created, - Some(Acquired(key, Some(inner))), - )) - } - Acquire::Available => { - // open tcp connection - let (io, proto) = connector.call(req).await?; - - let config = inner.borrow().config.clone(); - - let guard = OpenGuard::new(key, inner); - - if proto == Protocol::Http1 { - Ok(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(guard.consume()), - )) - } else { - let (sender, connection) = handshake(io, &config).await?; - Ok(IoConnection::new( - ConnectionType::H2(H2Connection::new(sender, connection)), - Instant::now(), - Some(guard.consume()), - )) - } - } - _ => { - // connection is not available, wait - let (rx, token) = inner.borrow_mut().wait_for(req); - - let guard = WaiterGuard::new(key, token, inner); - let res = match rx.await { - Err(_) => Err(ConnectError::Disconnected), - Ok(res) => res, - }; - guard.consume(); - res - } - } - }; - - fut.boxed_local() - } -} - -struct WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - key: Key, - token: usize, - inner: Option>>>, -} - -impl WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn new(key: Key, token: usize, inner: Rc>>) -> Self { - Self { - key, - token, - inner: Some(inner), - } - } - - fn consume(mut self) { - let _ = self.inner.take(); - } -} - -impl Drop for WaiterGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn drop(&mut self) { - if let Some(i) = self.inner.take() { - let mut inner = i.as_ref().borrow_mut(); - inner.release_waiter(&self.key, self.token); - inner.check_availability(); - } - } -} - -struct OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - key: Key, - inner: Option>>>, -} - -impl OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn new(key: Key, inner: Rc>>) -> Self { - Self { - key, - inner: Some(inner), - } - } - - fn consume(mut self) -> Acquired { - Acquired(self.key.clone(), self.inner.take()) - } -} - -impl Drop for OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn drop(&mut self) { - if let Some(i) = self.inner.take() { - let mut inner = i.as_ref().borrow_mut(); - inner.release(); - inner.check_availability(); - } - } -} - -enum Acquire { - Acquired(ConnectionType, Instant), - Available, - NotAvailable, -} - -struct AvailableConnection { - io: ConnectionType, - used: Instant, - created: Instant, -} - -pub(crate) struct Inner { - config: ConnectorConfig, - acquired: usize, - available: AHashMap>>, - waiters: Slab< - Option<( - Connect, - oneshot::Sender, ConnectError>>, - )>, - >, - waiters_queue: IndexSet<(Key, usize)>, - waker: LocalWaker, -} - -impl Inner { - fn reserve(&mut self) { - self.acquired += 1; - } - - fn release(&mut self) { - self.acquired -= 1; - } - - fn release_waiter(&mut self, key: &Key, token: usize) { - self.waiters.remove(token); - let _ = self.waiters_queue.shift_remove(&(key.clone(), token)); - } -} - -impl Inner -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - /// connection is not available, wait - fn wait_for( - &mut self, - connect: Connect, - ) -> ( - oneshot::Receiver, ConnectError>>, - usize, - ) { - let (tx, rx) = oneshot::channel(); - - let key: Key = connect.uri.authority().unwrap().clone().into(); - let entry = self.waiters.vacant_entry(); - let token = entry.key(); - entry.insert(Some((connect, tx))); - assert!(self.waiters_queue.insert((key, token))); - - (rx, token) - } - - fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire { - // check limits - if self.config.limit > 0 && self.acquired >= self.config.limit { - return Acquire::NotAvailable; - } - - self.reserve(); - - // check if open connection is available - // cleanup stale connections at the same time - if let Some(ref mut connections) = self.available.get_mut(key) { - let now = Instant::now(); - while let Some(conn) = connections.pop_back() { - // check if it still usable - if (now - conn.used) > self.config.conn_keep_alive - || (now - conn.created) > self.config.conn_lifetime - { - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = conn.io { - actix_rt::spawn(CloseConnection::new(io, timeout)); - } - } - } else { - let mut io = conn.io; - let mut buf = [0; 2]; - let mut read_buf = ReadBuf::new(&mut buf); - if let ConnectionType::H1(ref mut s) = io { - match Pin::new(s).poll_read(cx, &mut read_buf) { - Poll::Pending => {} - Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => { - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = io { - actix_rt::spawn(CloseConnection::new( - io, timeout, - )); - } - } - continue; - } - _ => continue, - } - } - return Acquire::Acquired(io, conn.created); - } - } - } - Acquire::Available - } - - fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { - self.acquired -= 1; - self.available - .entry(key.clone()) - .or_insert_with(VecDeque::new) - .push_back(AvailableConnection { - io, - created, - used: Instant::now(), - }); - self.check_availability(); - } - - fn release_close(&mut self, io: ConnectionType) { - self.acquired -= 1; - if let Some(timeout) = self.config.disconnect_timeout { - if let ConnectionType::H1(io) = io { - actix_rt::spawn(CloseConnection::new(io, timeout)); - } - } - self.check_availability(); - } - - fn check_availability(&self) { - if !self.waiters_queue.is_empty() && self.acquired < self.config.limit { - self.waker.wake(); - } - } -} - -#[pin_project::pin_project] -struct CloseConnection { - io: T, - #[pin] - timeout: Sleep, -} - -impl CloseConnection -where - T: AsyncWrite + Unpin, -{ - fn new(io: T, timeout: Duration) -> Self { - CloseConnection { - io, - timeout: sleep(timeout), - } - } -} - -impl Future for CloseConnection -where - T: AsyncWrite + Unpin, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let this = self.project(); - - match this.timeout.poll(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => match Pin::new(this.io).poll_shutdown(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, - } - } -} - -#[pin_project] -struct ConnectorPoolSupport -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - connector: Rc, - inner: Rc>>, -} - -impl Future for ConnectorPoolSupport -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service, - T::Future: 'static, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if Rc::strong_count(this.inner) == 1 { - // If we are last copy of Inner it means the ConnectionPool is already gone - // and we are safe to exit. - return Poll::Ready(()); - } - - let mut inner = this.inner.borrow_mut(); - inner.waker.register(cx.waker()); - - // check waiters - loop { - let (key, token) = { - if let Some((key, token)) = inner.waiters_queue.get_index(0) { - (key.clone(), *token) - } else { - break; - } - }; - if inner.waiters.get(token).unwrap().is_none() { - continue; - } - - match inner.acquire(&key, cx) { - Acquire::NotAvailable => break, - Acquire::Acquired(io, created) => { - let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; - if let Err(conn) = tx.send(Ok(IoConnection::new( - io, - created, - Some(Acquired(key.clone(), Some(this.inner.clone()))), - ))) { - let (io, created) = conn.unwrap().into_inner(); - inner.release_conn(&key, io, created); - } - } - Acquire::Available => { - let (connect, tx) = - inner.waiters.get_mut(token).unwrap().take().unwrap(); - OpenWaitingConnection::spawn( - key.clone(), - tx, - this.inner.clone(), - this.connector.call(connect), - inner.config.clone(), - ); - } - } - let _ = inner.waiters_queue.swap_remove_index(0); - } - - Poll::Pending - } -} - -#[pin_project::pin_project(PinnedDrop)] -struct OpenWaitingConnection -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - #[pin] - fut: F, - key: Key, - h2: Option< - LocalBoxFuture< - 'static, - Result<(SendRequest, Connection), h2::Error>, - >, - >, - rx: Option, ConnectError>>>, - inner: Option>>>, - config: ConnectorConfig, -} - -impl OpenWaitingConnection -where - F: Future> + 'static, - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn spawn( - key: Key, - rx: oneshot::Sender, ConnectError>>, - inner: Rc>>, - fut: F, - config: ConnectorConfig, - ) { - actix_rt::spawn(OpenWaitingConnection { - key, - fut, - h2: None, - rx: Some(rx), - inner: Some(inner), - config, - }); - } -} - -#[pin_project::pinned_drop] -impl PinnedDrop for OpenWaitingConnection -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn drop(self: Pin<&mut Self>) { - if let Some(inner) = self.project().inner.take() { - let mut inner = inner.as_ref().borrow_mut(); - inner.release(); - inner.check_availability(); - } - } -} - -impl Future for OpenWaitingConnection -where - F: Future>, - Io: AsyncRead + AsyncWrite + Unpin, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - if let Some(ref mut h2) = this.h2 { - return match Pin::new(h2).poll(cx) { - Poll::Ready(Ok((sender, connection))) => { - let rx = this.rx.take().unwrap(); - let _ = rx.send(Ok(IoConnection::new( - ConnectionType::H2(H2Connection::new(sender, connection)), - Instant::now(), - Some(Acquired(this.key.clone(), this.inner.take())), - ))); - Poll::Ready(()) - } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => { - let _ = this.inner.take(); - if let Some(rx) = this.rx.take() { - let _ = rx.send(Err(ConnectError::H2(err))); - } - Poll::Ready(()) - } - }; - } - - match this.fut.poll(cx) { - Poll::Ready(Err(err)) => { - let _ = this.inner.take(); - if let Some(rx) = this.rx.take() { - let _ = rx.send(Err(err)); - } - Poll::Ready(()) - } - Poll::Ready(Ok((io, proto))) => { - if proto == Protocol::Http1 { - let rx = this.rx.take().unwrap(); - let _ = rx.send(Ok(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(Acquired(this.key.clone(), this.inner.take())), - ))); - Poll::Ready(()) - } else { - *this.h2 = Some(handshake(io, this.config).boxed_local()); - self.poll(cx) - } - } - Poll::Pending => Poll::Pending, - } - } -} - -pub(crate) struct Acquired(Key, Option>>>); - -impl Acquired -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - pub(crate) fn close(&mut self, conn: IoConnection) { - if let Some(inner) = self.1.take() { - let (io, _) = conn.into_inner(); - inner.as_ref().borrow_mut().release_close(io); - } - } - pub(crate) fn release(&mut self, conn: IoConnection) { - if let Some(inner) = self.1.take() { - let (io, created) = conn.into_inner(); - inner - .as_ref() - .borrow_mut() - .release_conn(&self.0, io, created); - } - } -} - -impl Drop for Acquired { - fn drop(&mut self) { - if let Some(inner) = self.1.take() { - inner.as_ref().borrow_mut().release(); - } - } -} diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs index 9f84b8694..5d020edfc 100644 --- a/actix-http/src/config.rs +++ b/actix-http/src/config.rs @@ -1,26 +1,29 @@ -use std::cell::Cell; -use std::fmt::Write; -use std::rc::Rc; -use std::time::Duration; -use std::{fmt, net}; +use std::{ + cell::Cell, + fmt::{self, Write}, + net, + rc::Rc, + time::{Duration, SystemTime}, +}; use actix_rt::{ task::JoinHandle, time::{interval, sleep_until, Instant, Sleep}, }; use bytes::BytesMut; -use time::OffsetDateTime; /// "Sun, 06 Nov 1994 08:49:37 GMT".len() -const DATE_VALUE_LENGTH: usize = 29; +pub(crate) const DATE_VALUE_LENGTH: usize = 29; #[derive(Debug, PartialEq, Clone, Copy)] /// Server keep-alive setting pub enum KeepAlive { /// Keep alive in seconds Timeout(usize), + /// Rely on OS to shutdown tcp connection Os, + /// Disabled Disabled, } @@ -104,6 +107,8 @@ impl ServiceConfig { } /// Returns the local address that this server is bound to. + /// + /// Returns `None` for connections via UDS (Unix Domain Socket). #[inline] pub fn local_addr(&self) -> Option { self.0.local_addr @@ -126,9 +131,7 @@ impl ServiceConfig { pub fn client_timer(&self) -> Option { let delay_time = self.0.client_timeout; if delay_time != 0 { - Some(sleep_until( - self.0.date_service.now() + Duration::from_millis(delay_time), - )) + Some(sleep_until(self.now() + Duration::from_millis(delay_time))) } else { None } @@ -138,7 +141,7 @@ impl ServiceConfig { pub fn client_timer_expire(&self) -> Option { let delay = self.0.client_timeout; if delay != 0 { - Some(self.0.date_service.now() + Duration::from_millis(delay)) + Some(self.now() + Duration::from_millis(delay)) } else { None } @@ -148,29 +151,21 @@ impl ServiceConfig { pub fn client_disconnect_timer(&self) -> Option { let delay = self.0.client_disconnect; if delay != 0 { - Some(self.0.date_service.now() + Duration::from_millis(delay)) + Some(self.now() + Duration::from_millis(delay)) } else { None } } - #[inline] /// Return keep-alive timer delay is configured. + #[inline] pub fn keep_alive_timer(&self) -> Option { - if let Some(ka) = self.0.keep_alive { - Some(sleep_until(self.0.date_service.now() + ka)) - } else { - None - } + self.keep_alive().map(|ka| sleep_until(self.now() + ka)) } /// Keep-alive expire time pub fn keep_alive_expire(&self) -> Option { - if let Some(ka) = self.0.keep_alive { - Some(self.0.date_service.now() + ka) - } else { - None - } + self.keep_alive().map(|ka| self.now() + ka) } #[inline] @@ -214,12 +209,7 @@ impl Date { fn update(&mut self) { self.pos = 0; - write!( - self, - "{}", - OffsetDateTime::now_utc().format("%a, %d %b %Y %H:%M:%S GMT") - ) - .unwrap(); + write!(self, "{}", httpdate::fmt_http_date(SystemTime::now())).unwrap(); } } @@ -277,11 +267,11 @@ impl DateService { } // TODO: move to a util module for testing all spawn handle drop style tasks. -#[cfg(test)] /// Test Module for checking the drop state of certain async tasks that are spawned /// with `actix_rt::spawn` /// /// The target task must explicitly generate `NotifyOnDrop` when spawn the task +#[cfg(test)] mod notify_on_drop { use std::cell::RefCell; @@ -291,9 +281,8 @@ mod notify_on_drop { /// Check if the spawned task is dropped. /// - /// # Panic: - /// - /// When there was no `NotifyOnDrop` instance on current thread + /// # Panics + /// Panics when there was no `NotifyOnDrop` instance on current thread. pub(crate) fn is_dropped() -> bool { NOTIFY_DROPPED.with(|bool| { bool.borrow() @@ -336,7 +325,7 @@ mod notify_on_drop { mod tests { use super::*; - use actix_rt::task::yield_now; + use actix_rt::{task::yield_now, time::sleep}; #[actix_rt::test] async fn test_date_service_update() { @@ -360,7 +349,14 @@ mod tests { assert_ne!(buf1, buf2); drop(settings); - assert!(notify_on_drop::is_dropped()); + + // Ensure the task will drop eventually + let mut times = 0; + while !notify_on_drop::is_dropped() { + sleep(Duration::from_millis(100)).await; + times += 1; + assert!(times < 10, "Timeout waiting for task drop"); + } } #[actix_rt::test] @@ -375,14 +371,21 @@ mod tests { let clone3 = service.clone(); drop(clone1); - assert_eq!(false, notify_on_drop::is_dropped()); + assert!(!notify_on_drop::is_dropped()); drop(clone2); - assert_eq!(false, notify_on_drop::is_dropped()); + assert!(!notify_on_drop::is_dropped()); drop(clone3); - assert_eq!(false, notify_on_drop::is_dropped()); + assert!(!notify_on_drop::is_dropped()); drop(service); - assert!(notify_on_drop::is_dropped()); + + // Ensure the task will drop eventually + let mut times = 0; + while !notify_on_drop::is_dropped() { + sleep(Duration::from_millis(100)).await; + times += 1; + assert!(times < 10, "Timeout waiting for task drop"); + } } #[test] diff --git a/actix-http/src/encoding/decoder.rs b/actix-http/src/encoding/decoder.rs index f0abae865..da4b56c6a 100644 --- a/actix-http/src/encoding/decoder.rs +++ b/actix-http/src/encoding/decoder.rs @@ -8,24 +8,34 @@ use std::{ }; use actix_rt::task::{spawn_blocking, JoinHandle}; -use brotli2::write::BrotliDecoder; use bytes::Bytes; -use flate2::write::{GzDecoder, ZlibDecoder}; use futures_core::{ready, Stream}; +#[cfg(feature = "compress-brotli")] +use brotli2::write::BrotliDecoder; + +#[cfg(feature = "compress-gzip")] +use flate2::write::{GzDecoder, ZlibDecoder}; + +#[cfg(feature = "compress-zstd")] +use zstd::stream::write::Decoder as ZstdDecoder; + use crate::{ encoding::Writer, error::{BlockingError, PayloadError}, - http::header::{ContentEncoding, HeaderMap, CONTENT_ENCODING}, + header::{ContentEncoding, HeaderMap, CONTENT_ENCODING}, }; const MAX_CHUNK_SIZE_DECODE_IN_PLACE: usize = 2049; -pub struct Decoder { - decoder: Option, - stream: S, - eof: bool, - fut: Option, ContentDecoder), io::Error>>>, +pin_project_lite::pin_project! { + pub struct Decoder { + decoder: Option, + #[pin] + stream: S, + eof: bool, + fut: Option, ContentDecoder), io::Error>>>, + } } impl Decoder @@ -36,14 +46,24 @@ where #[inline] pub fn new(stream: S, encoding: ContentEncoding) -> Decoder { let decoder = match encoding { - ContentEncoding::Br => Some(ContentDecoder::Br(Box::new( + #[cfg(feature = "compress-brotli")] + ContentEncoding::Brotli => Some(ContentDecoder::Brotli(Box::new( BrotliDecoder::new(Writer::new()), ))), + #[cfg(feature = "compress-gzip")] ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new( ZlibDecoder::new(Writer::new()), ))), - ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new( - GzDecoder::new(Writer::new()), + #[cfg(feature = "compress-gzip")] + ContentEncoding::Gzip => Some(ContentDecoder::Gzip(Box::new(GzDecoder::new( + Writer::new(), + )))), + #[cfg(feature = "compress-zstd")] + ContentEncoding::Zstd => Some(ContentDecoder::Zstd(Box::new( + ZstdDecoder::new(Writer::new()).expect( + "Failed to create zstd decoder. This is a bug. \ + Please report it to the actix-web repository.", + ), ))), _ => None, }; @@ -63,7 +83,7 @@ where let encoding = headers .get(&CONTENT_ENCODING) .and_then(|val| val.to_str().ok()) - .map(ContentEncoding::from) + .and_then(|x| x.parse().ok()) .unwrap_or(ContentEncoding::Identity); Self::new(stream, encoding) @@ -72,45 +92,44 @@ where impl Stream for Decoder where - S: Stream> + Unpin, + S: Stream>, { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { - if let Some(ref mut fut) = self.fut { + if let Some(ref mut fut) = this.fut { let (chunk, decoder) = ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; - self.decoder = Some(decoder); - self.fut.take(); + *this.decoder = Some(decoder); + this.fut.take(); if let Some(chunk) = chunk { return Poll::Ready(Some(Ok(chunk))); } } - if self.eof { + if *this.eof { return Poll::Ready(None); } - match ready!(Pin::new(&mut self.stream).poll_next(cx)) { + match ready!(this.stream.as_mut().poll_next(cx)) { Some(Err(err)) => return Poll::Ready(Some(Err(err))), Some(Ok(chunk)) => { - if let Some(mut decoder) = self.decoder.take() { + if let Some(mut decoder) = this.decoder.take() { if chunk.len() < MAX_CHUNK_SIZE_DECODE_IN_PLACE { let chunk = decoder.feed_data(chunk)?; - self.decoder = Some(decoder); + *this.decoder = Some(decoder); if let Some(chunk) = chunk { return Poll::Ready(Some(Ok(chunk))); } } else { - self.fut = Some(spawn_blocking(move || { + *this.fut = Some(spawn_blocking(move || { let chunk = decoder.feed_data(chunk)?; Ok((chunk, decoder)) })); @@ -123,9 +142,9 @@ where } None => { - self.eof = true; + *this.eof = true; - return if let Some(mut decoder) = self.decoder.take() { + return if let Some(mut decoder) = this.decoder.take() { match decoder.feed_eof() { Ok(Some(res)) => Poll::Ready(Some(Ok(res))), Ok(None) => Poll::Ready(None), @@ -141,15 +160,23 @@ where } enum ContentDecoder { + #[cfg(feature = "compress-gzip")] Deflate(Box>), + #[cfg(feature = "compress-gzip")] Gzip(Box>), - Br(Box>), + #[cfg(feature = "compress-brotli")] + Brotli(Box>), + // We need explicit 'static lifetime here because ZstdDecoder need lifetime + // argument, and we use `spawn_blocking` in `Decoder::poll_next` that require `FnOnce() -> R + Send + 'static` + #[cfg(feature = "compress-zstd")] + Zstd(Box>), } impl ContentDecoder { fn feed_eof(&mut self) -> io::Result> { match self { - ContentDecoder::Br(ref mut decoder) => match decoder.flush() { + #[cfg(feature = "compress-brotli")] + ContentDecoder::Brotli(ref mut decoder) => match decoder.flush() { Ok(()) => { let b = decoder.get_mut().take(); @@ -162,6 +189,7 @@ impl ContentDecoder { Err(e) => Err(e), }, + #[cfg(feature = "compress-gzip")] ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -175,6 +203,7 @@ impl ContentDecoder { Err(e) => Err(e), }, + #[cfg(feature = "compress-gzip")] ContentDecoder::Deflate(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -186,12 +215,26 @@ impl ContentDecoder { } Err(e) => Err(e), }, + + #[cfg(feature = "compress-zstd")] + ContentDecoder::Zstd(ref mut decoder) => match decoder.flush() { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, } } fn feed_data(&mut self, data: Bytes) -> io::Result> { match self { - ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) { + #[cfg(feature = "compress-brotli")] + ContentDecoder::Brotli(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; let b = decoder.get_mut().take(); @@ -205,6 +248,7 @@ impl ContentDecoder { Err(e) => Err(e), }, + #[cfg(feature = "compress-gzip")] ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; @@ -219,6 +263,7 @@ impl ContentDecoder { Err(e) => Err(e), }, + #[cfg(feature = "compress-gzip")] ContentDecoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; @@ -232,6 +277,21 @@ impl ContentDecoder { } Err(e) => Err(e), }, + + #[cfg(feature = "compress-zstd")] + ContentDecoder::Zstd(ref mut decoder) => match decoder.write_all(&data) { + Ok(_) => { + decoder.flush()?; + + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + } + Err(e) => Err(e), + }, } } } diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index 366ecb8c4..2848ad697 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,6 +1,7 @@ //! Stream encoders. use std::{ + error::Error as StdError, future::Future, io::{self, Write as _}, pin::Pin, @@ -8,143 +9,175 @@ use std::{ }; use actix_rt::task::{spawn_blocking, JoinHandle}; -use brotli2::write::BrotliEncoder; use bytes::Bytes; -use flate2::write::{GzEncoder, ZlibEncoder}; +use derive_more::Display; use futures_core::ready; -use pin_project::pin_project; +use pin_project_lite::pin_project; -use crate::{ - body::{Body, BodySize, MessageBody, ResponseBody}, - http::{ - header::{ContentEncoding, CONTENT_ENCODING}, - HeaderValue, StatusCode, - }, - Error, ResponseHead, -}; +#[cfg(feature = "compress-brotli")] +use brotli2::write::BrotliEncoder; + +#[cfg(feature = "compress-gzip")] +use flate2::write::{GzEncoder, ZlibEncoder}; + +#[cfg(feature = "compress-zstd")] +use zstd::stream::write::Encoder as ZstdEncoder; use super::Writer; -use crate::error::BlockingError; +use crate::{ + body::{self, BodySize, MessageBody}, + error::BlockingError, + header::{self, ContentEncoding, HeaderValue, CONTENT_ENCODING}, + ResponseHead, StatusCode, +}; const MAX_CHUNK_SIZE_ENCODE_IN_PLACE: usize = 1024; -#[pin_project] -pub struct Encoder { - eof: bool, - #[pin] - body: EncoderBody, - encoder: Option, - fut: Option>>, +pin_project! { + pub struct Encoder { + #[pin] + body: EncoderBody, + encoder: Option, + fut: Option>>, + eof: bool, + } } impl Encoder { - pub fn response( - encoding: ContentEncoding, - head: &mut ResponseHead, - body: ResponseBody, - ) -> ResponseBody> { - let can_encode = !(head.headers().contains_key(&CONTENT_ENCODING) + fn none() -> Self { + Encoder { + body: EncoderBody::None { + body: body::None::new(), + }, + encoder: None, + fut: None, + eof: true, + } + } + + pub fn response(encoding: ContentEncoding, head: &mut ResponseHead, body: B) -> Self { + // no need to compress an empty body + if matches!(body.size(), BodySize::None) { + return Self::none(); + } + + let should_encode = !(head.headers().contains_key(&CONTENT_ENCODING) || head.status == StatusCode::SWITCHING_PROTOCOLS || head.status == StatusCode::NO_CONTENT - || encoding == ContentEncoding::Identity - || encoding == ContentEncoding::Auto); + || encoding == ContentEncoding::Identity); - let body = match body { - ResponseBody::Other(b) => match b { - Body::None => return ResponseBody::Other(Body::None), - Body::Empty => return ResponseBody::Other(Body::Empty), - Body::Bytes(buf) => { - if can_encode { - EncoderBody::Bytes(buf) - } else { - return ResponseBody::Other(Body::Bytes(buf)); - } - } - Body::Message(stream) => EncoderBody::BoxedStream(stream), - }, - ResponseBody::Body(stream) => EncoderBody::Stream(stream), + let body = match body.try_into_bytes() { + Ok(body) => EncoderBody::Full { body }, + Err(body) => EncoderBody::Stream { body }, }; - if can_encode { - // Modify response body only if encoder is not None - if let Some(enc) = ContentEncoder::encoder(encoding) { + if should_encode { + // wrap body only if encoder is feature-enabled + if let Some(enc) = ContentEncoder::select(encoding) { update_head(encoding, head); - head.no_chunking(false); - return ResponseBody::Body(Encoder { + + return Encoder { body, - eof: false, - fut: None, encoder: Some(enc), - }); + fut: None, + eof: false, + }; } } - ResponseBody::Body(Encoder { + + Encoder { body, - eof: false, - fut: None, encoder: None, - }) + fut: None, + eof: false, + } } } -#[pin_project(project = EncoderBodyProj)] -enum EncoderBody { - Bytes(Bytes), - Stream(#[pin] B), - BoxedStream(Box), +pin_project! { + #[project = EncoderBodyProj] + enum EncoderBody { + None { body: body::None }, + Full { body: Bytes }, + Stream { #[pin] body: B }, + } } -impl MessageBody for EncoderBody { +impl MessageBody for EncoderBody +where + B: MessageBody, +{ + type Error = EncoderError; + + #[inline] fn size(&self) -> BodySize { match self { - EncoderBody::Bytes(ref b) => b.size(), - EncoderBody::Stream(ref b) => b.size(), - EncoderBody::BoxedStream(ref b) => b.size(), + EncoderBody::None { body } => body.size(), + EncoderBody::Full { body } => body.size(), + EncoderBody::Stream { body } => body.size(), } } fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match self.project() { - EncoderBodyProj::Bytes(b) => { - if b.is_empty() { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(std::mem::take(b)))) - } + EncoderBodyProj::None { body } => { + Pin::new(body).poll_next(cx).map_err(|err| match err {}) } - EncoderBodyProj::Stream(b) => b.poll_next(cx), - EncoderBodyProj::BoxedStream(ref mut b) => { - Pin::new(b.as_mut()).poll_next(cx) + EncoderBodyProj::Full { body } => { + Pin::new(body).poll_next(cx).map_err(|err| match err {}) } + EncoderBodyProj::Stream { body } => body + .poll_next(cx) + .map_err(|err| EncoderError::Body(err.into())), + } + } + + #[inline] + fn try_into_bytes(self) -> Result + where + Self: Sized, + { + match self { + EncoderBody::None { body } => Ok(body.try_into_bytes().unwrap()), + EncoderBody::Full { body } => Ok(body.try_into_bytes().unwrap()), + _ => Err(self), } } } -impl MessageBody for Encoder { +impl MessageBody for Encoder +where + B: MessageBody, +{ + type Error = EncoderError; + + #[inline] fn size(&self) -> BodySize { - if self.encoder.is_none() { - self.body.size() - } else { + if self.encoder.is_some() { BodySize::Stream + } else { + self.body.size() } } fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let mut this = self.project(); + loop { if *this.eof { return Poll::Ready(None); } if let Some(ref mut fut) = this.fut { - let mut encoder = - ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + let mut encoder = ready!(Pin::new(fut).poll(cx)) + .map_err(|_| EncoderError::Blocking(BlockingError))? + .map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -163,7 +196,7 @@ impl MessageBody for Encoder { Some(Ok(chunk)) => { if let Some(mut encoder) = this.encoder.take() { if chunk.len() < MAX_CHUNK_SIZE_ENCODE_IN_PLACE { - encoder.write(&chunk)?; + encoder.write(&chunk).map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -183,7 +216,8 @@ impl MessageBody for Encoder { None => { if let Some(encoder) = this.encoder.take() { - let chunk = encoder.finish()?; + let chunk = encoder.finish().map_err(EncoderError::Io)?; + if chunk.is_empty() { return Poll::Ready(None); } else { @@ -197,35 +231,77 @@ impl MessageBody for Encoder { } } } + + #[inline] + fn try_into_bytes(mut self) -> Result + where + Self: Sized, + { + if self.encoder.is_some() { + Err(self) + } else { + match self.body.try_into_bytes() { + Ok(body) => Ok(body), + Err(body) => { + self.body = body; + Err(self) + } + } + } + } } fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { - head.headers_mut().insert( - CONTENT_ENCODING, - HeaderValue::from_static(encoding.as_str()), - ); + head.headers_mut() + .insert(header::CONTENT_ENCODING, encoding.to_header_value()); + head.headers_mut() + .insert(header::VARY, HeaderValue::from_static("accept-encoding")); + + head.no_chunking(false); } enum ContentEncoder { + #[cfg(feature = "compress-gzip")] Deflate(ZlibEncoder), + + #[cfg(feature = "compress-gzip")] Gzip(GzEncoder), - Br(BrotliEncoder), + + #[cfg(feature = "compress-brotli")] + Brotli(BrotliEncoder), + + // Wwe need explicit 'static lifetime here because ZstdEncoder needs a lifetime argument and we + // use `spawn_blocking` in `Encoder::poll_next` that requires `FnOnce() -> R + Send + 'static`. + #[cfg(feature = "compress-zstd")] + Zstd(ZstdEncoder<'static, Writer>), } impl ContentEncoder { - fn encoder(encoding: ContentEncoding) -> Option { + fn select(encoding: ContentEncoding) -> Option { match encoding { + #[cfg(feature = "compress-gzip")] ContentEncoding::Deflate => Some(ContentEncoder::Deflate(ZlibEncoder::new( Writer::new(), flate2::Compression::fast(), ))), + + #[cfg(feature = "compress-gzip")] ContentEncoding::Gzip => Some(ContentEncoder::Gzip(GzEncoder::new( Writer::new(), flate2::Compression::fast(), ))), - ContentEncoding::Br => { - Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) + + #[cfg(feature = "compress-brotli")] + ContentEncoding::Brotli => { + Some(ContentEncoder::Brotli(BrotliEncoder::new(Writer::new(), 3))) } + + #[cfg(feature = "compress-zstd")] + ContentEncoding::Zstd => { + let encoder = ZstdEncoder::new(Writer::new(), 3).ok()?; + Some(ContentEncoder::Zstd(encoder)) + } + _ => None, } } @@ -233,38 +309,60 @@ impl ContentEncoder { #[inline] pub(crate) fn take(&mut self) -> Bytes { match *self { - ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), + #[cfg(feature = "compress-brotli")] + ContentEncoder::Brotli(ref mut encoder) => encoder.get_mut().take(), + + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), + + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), + + #[cfg(feature = "compress-zstd")] + ContentEncoder::Zstd(ref mut encoder) => encoder.get_mut().take(), } } fn finish(self) -> Result { match self { - ContentEncoder::Br(encoder) => match encoder.finish() { + #[cfg(feature = "compress-brotli")] + ContentEncoder::Brotli(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), }, + + #[cfg(feature = "compress-zstd")] + ContentEncoder::Zstd(encoder) => match encoder.finish() { + Ok(writer) => Ok(writer.buf.freeze()), + Err(err) => Err(err), + }, } } fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { match *self { - ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { + #[cfg(feature = "compress-brotli")] + ContentEncoder::Brotli(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { trace!("Error decoding br encoding: {}", err); Err(err) } }, + + #[cfg(feature = "compress-gzip")] ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { @@ -272,6 +370,8 @@ impl ContentEncoder { Err(err) } }, + + #[cfg(feature = "compress-gzip")] ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { @@ -279,6 +379,44 @@ impl ContentEncoder { Err(err) } }, + + #[cfg(feature = "compress-zstd")] + ContentEncoder::Zstd(ref mut encoder) => match encoder.write_all(data) { + Ok(_) => Ok(()), + Err(err) => { + trace!("Error decoding ztsd encoding: {}", err); + Err(err) + } + }, } } } + +#[derive(Debug, Display)] +#[non_exhaustive] +pub enum EncoderError { + #[display(fmt = "body")] + Body(Box), + + #[display(fmt = "blocking")] + Blocking(BlockingError), + + #[display(fmt = "io")] + Io(io::Error), +} + +impl StdError for EncoderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + EncoderError::Body(err) => Some(&**err), + EncoderError::Blocking(err) => Some(err), + EncoderError::Io(err) => Some(err), + } + } +} + +impl From for crate::Error { + fn from(err: EncoderError) -> Self { + crate::Error::new_encoder().with_cause(err) + } +} diff --git a/actix-http/src/encoding/mod.rs b/actix-http/src/encoding/mod.rs index cb271c638..d51dd66c0 100644 --- a/actix-http/src/encoding/mod.rs +++ b/actix-http/src/encoding/mod.rs @@ -10,6 +10,9 @@ mod encoder; pub use self::decoder::Decoder; pub use self::encoder::Encoder; +/// Special-purpose writer for streaming (de-)compression. +/// +/// Pre-allocates 8KiB of capacity. pub(self) struct Writer { buf: BytesMut, } diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 6a9de2d3e..cdf495c45 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -1,270 +1,202 @@ //! Error and Result module -use std::cell::RefCell; -use std::io::Write; -use std::str::Utf8Error; -use std::string::FromUtf8Error; -use std::{fmt, io, result}; +use std::{error::Error as StdError, fmt, io, str::Utf8Error, string::FromUtf8Error}; -use actix_codec::{Decoder, Encoder}; -use actix_utils::dispatcher::DispatcherError as FramedDispatcherError; -use actix_utils::timeout::TimeoutError; -use bytes::BytesMut; -use derive_more::{Display, From}; -pub use futures_channel::oneshot::Canceled; -use http::uri::InvalidUri; -use http::{header, Error as HttpError, StatusCode}; -use serde::de::value::Error as DeError; -use serde_json::error::Error as JsonError; -use serde_urlencoded::ser::Error as FormError; +use derive_more::{Display, Error, From}; +use http::{uri::InvalidUri, StatusCode}; -use crate::body::Body; -use crate::helpers::Writer; -use crate::response::{Response, ResponseBuilder}; +use crate::{body::BoxBody, ws, Response}; -#[cfg(feature = "cookies")] -pub use crate::cookie::ParseError as CookieParseError; +pub use http::Error as HttpError; -/// A specialized [`std::result::Result`] -/// for actix web operations -/// -/// This typedef is generally used to avoid writing out -/// `actix_http::error::Error` directly and is otherwise a direct mapping to -/// `Result`. -pub type Result = result::Result; - -/// General purpose actix web error. -/// -/// An actix web error is used to carry errors from `std::error` -/// through actix in a convenient way. It can be created through -/// converting errors with `into()`. -/// -/// Whenever it is created from an external object a response error is created -/// for it that can be used to create an HTTP response from it this means that -/// if you have access to an actix `Error` you can always get a -/// `ResponseError` reference from it. pub struct Error { - cause: Box, + inner: Box, +} + +pub(crate) struct ErrorInner { + #[allow(dead_code)] + kind: Kind, + cause: Option>, } impl Error { - /// Returns the reference to the underlying `ResponseError`. - pub fn as_response_error(&self) -> &dyn ResponseError { - self.cause.as_ref() + fn new(kind: Kind) -> Self { + Self { + inner: Box::new(ErrorInner { kind, cause: None }), + } } - /// Similar to `as_response_error` but downcasts. - pub fn as_error(&self) -> Option<&T> { - ResponseError::downcast_ref(self.cause.as_ref()) + pub(crate) fn with_cause(mut self, cause: impl Into>) -> Self { + self.inner.cause = Some(cause.into()); + self + } + + pub(crate) fn new_http() -> Self { + Self::new(Kind::Http) + } + + pub(crate) fn new_parse() -> Self { + Self::new(Kind::Parse) + } + + pub(crate) fn new_payload() -> Self { + Self::new(Kind::Payload) + } + + pub(crate) fn new_body() -> Self { + Self::new(Kind::Body) + } + + pub(crate) fn new_send_response() -> Self { + Self::new(Kind::SendResponse) + } + + #[allow(unused)] // reserved for future use (TODO: remove allow when being used) + pub(crate) fn new_io() -> Self { + Self::new(Kind::Io) + } + + #[allow(unused)] // used in encoder behind feature flag so ignore unused warning + pub(crate) fn new_encoder() -> Self { + Self::new(Kind::Encoder) + } + + pub(crate) fn new_ws() -> Self { + Self::new(Kind::Ws) } } -/// Error that can be converted to `Response` -pub trait ResponseError: fmt::Debug + fmt::Display { - /// Response's status code - /// - /// Internal server error is generated by default. - fn status_code(&self) -> StatusCode { - StatusCode::INTERNAL_SERVER_ERROR - } +impl From for Response { + fn from(err: Error) -> Self { + // TODO: more appropriate error status codes, usage assessment needed + let status_code = match err.inner.kind { + Kind::Parse => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; - /// Create response for error - /// - /// Internal server error is generated by default. - fn error_response(&self) -> Response { - let mut resp = Response::new(self.status_code()); - let mut buf = BytesMut::new(); - let _ = write!(Writer(&mut buf), "{}", self); - resp.headers_mut().insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain; charset=utf-8"), - ); - resp.set_body(Body::from(buf)) + Response::new(status_code).set_body(BoxBody::new(err.to_string())) } - - downcast_get_type_id!(); } -downcast!(ResponseError); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum Kind { + #[display(fmt = "error processing HTTP")] + Http, -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.cause, f) - } + #[display(fmt = "error parsing HTTP message")] + Parse, + + #[display(fmt = "request payload read error")] + Payload, + + #[display(fmt = "response body write error")] + Body, + + #[display(fmt = "send response error")] + SendResponse, + + #[display(fmt = "error in WebSocket process")] + Ws, + + #[display(fmt = "connection error")] + Io, + + #[display(fmt = "encoder error")] + Encoder, } impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", &self.cause) + // TODO: more detail + f.write_str("actix_http::Error") } } -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - None +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner.cause.as_ref() { + Some(err) => write!(f, "{}: {}", &self.inner.kind, err), + None => write!(f, "{}", &self.inner.kind), + } } } -impl From<()> for Error { - fn from(_: ()) -> Self { - Error::from(UnitError) +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner.cause.as_ref().map(Box::as_ref) } } impl From for Error { - fn from(_: std::convert::Infallible) -> Self { - // `std::convert::Infallible` indicates an error - // that will never happen - unreachable!() + fn from(err: std::convert::Infallible) -> Self { + match err {} } } -/// Convert `Error` to a `Response` instance -impl From for Response { - fn from(err: Error) -> Self { - Response::from_error(err) +impl From for Error { + fn from(err: HttpError) -> Self { + Self::new_http().with_cause(err) } } -/// `Error` for any error that implements `ResponseError` -impl From for Error { - fn from(err: T) -> Error { - Error { - cause: Box::new(err), - } +impl From for Error { + fn from(err: ws::HandshakeError) -> Self { + Self::new_ws().with_cause(err) } } -/// Convert Response to a Error -impl From for Error { - fn from(res: Response) -> Error { - InternalError::from_response("", res).into() +impl From for Error { + fn from(err: ws::ProtocolError) -> Self { + Self::new_ws().with_cause(err) } } -/// Convert ResponseBuilder to a Error -impl From for Error { - fn from(mut res: ResponseBuilder) -> Error { - InternalError::from_response("", res.finish()).into() - } -} - -/// Inspects the underlying enum and returns an appropriate status code. -/// -/// If the variant is [`TimeoutError::Service`], the error code of the service is returned. -/// Otherwise, [`StatusCode::GATEWAY_TIMEOUT`] is returned. -impl ResponseError for TimeoutError { - fn status_code(&self) -> StatusCode { - match self { - TimeoutError::Service(e) => e.status_code(), - TimeoutError::Timeout => StatusCode::GATEWAY_TIMEOUT, - } - } -} - -#[derive(Debug, Display)] -#[display(fmt = "UnknownError")] -struct UnitError; - -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`UnitError`]. -impl ResponseError for UnitError {} - -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`JsonError`]. -impl ResponseError for JsonError {} - -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`FormError`]. -impl ResponseError for FormError {} - -#[cfg(feature = "openssl")] -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`actix_tls::accept::openssl::SslError`]. -impl ResponseError for actix_tls::accept::openssl::SslError {} - -/// Returns [`StatusCode::BAD_REQUEST`] for [`DeError`]. -impl ResponseError for DeError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`Canceled`]. -impl ResponseError for Canceled {} - -/// Returns [`StatusCode::BAD_REQUEST`] for [`Utf8Error`]. -impl ResponseError for Utf8Error { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - -/// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`HttpError`]. -impl ResponseError for HttpError {} - -/// Inspects the underlying [`io::ErrorKind`] and returns an appropriate status code. -/// -/// If the error is [`io::ErrorKind::NotFound`], [`StatusCode::NOT_FOUND`] is returned. If the -/// error is [`io::ErrorKind::PermissionDenied`], [`StatusCode::FORBIDDEN`] is returned. Otherwise, -/// [`StatusCode::INTERNAL_SERVER_ERROR`] is returned. -impl ResponseError for io::Error { - fn status_code(&self) -> StatusCode { - match self.kind() { - io::ErrorKind::NotFound => StatusCode::NOT_FOUND, - io::ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - -/// Returns [`StatusCode::BAD_REQUEST`] for [`header::InvalidHeaderValue`]. -impl ResponseError for header::InvalidHeaderValue { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - -/// A set of errors that can occur during parsing HTTP streams -#[derive(Debug, Display)] +/// A set of errors that can occur during parsing HTTP streams. +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum ParseError { /// An invalid `Method`, such as `GE.T`. #[display(fmt = "Invalid Method specified")] Method, + /// An invalid `Uri`, such as `exam ple.domain`. #[display(fmt = "Uri error: {}", _0)] Uri(InvalidUri), + /// An invalid `HttpVersion`, such as `HTP/1.1` #[display(fmt = "Invalid HTTP version specified")] Version, + /// An invalid `Header`. #[display(fmt = "Invalid Header provided")] Header, + /// A message head is too large to be reasonable. #[display(fmt = "Message head is too large")] TooLarge, + /// A message reached EOF, but is not complete. #[display(fmt = "Message is incomplete")] Incomplete, + /// An invalid `Status`, such as `1337 ELITE`. #[display(fmt = "Invalid Status provided")] Status, + /// A timeout occurred waiting for an IO event. #[allow(dead_code)] #[display(fmt = "Timeout")] Timeout, - /// An `io::Error` that occurred while trying to read or write to a network - /// stream. + + /// An `io::Error` that occurred while trying to read or write to a network stream. #[display(fmt = "IO error: {}", _0)] Io(io::Error), - /// Parsing a field as string failed + + /// Parsing a field as string failed. #[display(fmt = "UTF8 error: {}", _0)] Utf8(Utf8Error), } -/// Return `BadRequest` for `ParseError` -impl ResponseError for ParseError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - impl From for ParseError { fn from(err: io::Error) -> ParseError { ParseError::Io(err) @@ -303,18 +235,27 @@ impl From for ParseError { } } +impl From for Error { + fn from(err: ParseError) -> Self { + Self::new_parse().with_cause(err) + } +} + +impl From for Response { + fn from(err: ParseError) -> Self { + Error::from(err).into() + } +} + /// A set of errors that can occur running blocking tasks in thread pool. -#[derive(Debug, Display)] +#[derive(Debug, Display, Error)] #[display(fmt = "Blocking thread pool is gone")] +// TODO: non-exhaustive pub struct BlockingError; -impl std::error::Error for BlockingError {} - -/// `InternalServerError` for `BlockingError` -impl ResponseError for BlockingError {} - -#[derive(Display, Debug)] -/// A set of errors that can occur during payload parsing +/// A set of errors that can occur during payload parsing. +#[derive(Debug, Display)] +#[non_exhaustive] pub enum PayloadError { /// A payload reached EOF, but is not complete. #[display( @@ -385,46 +326,35 @@ impl From for PayloadError { } } -/// `PayloadError` returns two possible results: -/// -/// - `Overflow` returns `PayloadTooLarge` -/// - Other errors returns `BadRequest` -impl ResponseError for PayloadError { - fn status_code(&self) -> StatusCode { - match *self { - PayloadError::Overflow => StatusCode::PAYLOAD_TOO_LARGE, - _ => StatusCode::BAD_REQUEST, - } - } -} - -/// Return `BadRequest` for `cookie::ParseError` -#[cfg(feature = "cookies")] -impl ResponseError for crate::cookie::ParseError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST +impl From for Error { + fn from(err: PayloadError) -> Self { + Self::new_payload().with_cause(err) } } +/// A set of errors that can occur during dispatching HTTP requests. #[derive(Debug, Display, From)] -/// A set of errors that can occur during dispatching HTTP requests pub enum DispatchError { - /// Service error - Service(Error), + /// Service error. + #[display(fmt = "Service Error")] + Service(Response), - /// Upgrade service error + /// Body streaming error. + #[display(fmt = "Body error: {}", _0)] + Body(Box), + + /// Upgrade service error. Upgrade, - /// An `io::Error` that occurred while trying to read or write to a network - /// stream. + /// An `io::Error` that occurred while trying to read or write to a network stream. #[display(fmt = "IO error: {}", _0)] Io(io::Error), - /// Http request parse error. - #[display(fmt = "Parse error: {}", _0)] + /// Request parse error. + #[display(fmt = "Request parse error: {}", _0)] Parse(ParseError), - /// Http/2 error + /// HTTP/2 error. #[display(fmt = "{}", _0)] H2(h2::Error), @@ -436,539 +366,52 @@ pub enum DispatchError { #[display(fmt = "Connection shutdown timeout")] DisconnectTimeout, - /// Payload is not consumed - #[display(fmt = "Task is completed but request's payload is not consumed")] - PayloadIsNotConsumed, - - /// Malformed request - #[display(fmt = "Malformed request")] - MalformedRequest, - - /// Internal error + /// Internal error. #[display(fmt = "Internal error")] InternalError, - - /// Unknown error - #[display(fmt = "Unknown error")] - Unknown, } -/// A set of error that can occur during parsing content type -#[derive(PartialEq, Debug, Display)] +impl StdError for DispatchError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + // TODO: error source extraction? + DispatchError::Service(_res) => None, + DispatchError::Body(err) => Some(&**err), + DispatchError::Io(err) => Some(err), + DispatchError::Parse(err) => Some(err), + DispatchError::H2(err) => Some(err), + _ => None, + } + } +} + +/// A set of error that can occur during parsing content type. +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum ContentTypeError { /// Can not parse content type #[display(fmt = "Can not parse content type")] ParseError, + /// Unknown content encoding #[display(fmt = "Unknown content encoding")] UnknownEncoding, } -impl std::error::Error for ContentTypeError {} +#[cfg(test)] +mod content_type_test_impls { + use super::*; -/// Return `BadRequest` for `ContentTypeError` -impl ResponseError for ContentTypeError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - -impl + Decoder, I> ResponseError for FramedDispatcherError -where - E: fmt::Debug + fmt::Display, - >::Error: fmt::Debug, - ::Error: fmt::Debug, -{ -} - -/// Helper type that can wrap any error and generate custom response. -/// -/// In following example any `io::Error` will be converted into "BAD REQUEST" -/// response as opposite to *INTERNAL SERVER ERROR* which is defined by -/// default. -/// -/// ```rust -/// # use std::io; -/// # use actix_http::*; -/// -/// fn index(req: Request) -> Result<&'static str> { -/// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error"))) -/// } -/// ``` -pub struct InternalError { - cause: T, - status: InternalErrorType, -} - -enum InternalErrorType { - Status(StatusCode), - Response(RefCell>), -} - -impl InternalError { - /// Create `InternalError` instance - pub fn new(cause: T, status: StatusCode) -> Self { - InternalError { - cause, - status: InternalErrorType::Status(status), - } - } - - /// Create `InternalError` with predefined `Response`. - pub fn from_response(cause: T, response: Response) -> Self { - InternalError { - cause, - status: InternalErrorType::Response(RefCell::new(Some(response))), - } - } -} - -impl fmt::Debug for InternalError -where - T: fmt::Debug + 'static, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.cause, f) - } -} - -impl fmt::Display for InternalError -where - T: fmt::Display + 'static, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.cause, f) - } -} - -impl ResponseError for InternalError -where - T: fmt::Debug + fmt::Display + 'static, -{ - fn status_code(&self) -> StatusCode { - match self.status { - InternalErrorType::Status(st) => st, - InternalErrorType::Response(ref resp) => { - if let Some(resp) = resp.borrow().as_ref() { - resp.head().status - } else { - StatusCode::INTERNAL_SERVER_ERROR + impl std::cmp::PartialEq for ContentTypeError { + fn eq(&self, other: &Self) -> bool { + match self { + Self::ParseError => matches!(other, ContentTypeError::ParseError), + Self::UnknownEncoding => { + matches!(other, ContentTypeError::UnknownEncoding) } } } } - - fn error_response(&self) -> Response { - match self.status { - InternalErrorType::Status(st) => { - let mut res = Response::new(st); - let mut buf = BytesMut::new(); - let _ = write!(Writer(&mut buf), "{}", self); - res.headers_mut().insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("text/plain; charset=utf-8"), - ); - res.set_body(Body::from(buf)) - } - InternalErrorType::Response(ref resp) => { - if let Some(resp) = resp.borrow_mut().take() { - resp - } else { - Response::new(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } - } -} - -/// Helper function that creates wrapper of any error and generate *BAD -/// REQUEST* response. -#[allow(non_snake_case)] -pub fn ErrorBadRequest(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::BAD_REQUEST).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNAUTHORIZED* response. -#[allow(non_snake_case)] -pub fn ErrorUnauthorized(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNAUTHORIZED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PAYMENT_REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorPaymentRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PAYMENT_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate *FORBIDDEN* -/// response. -#[allow(non_snake_case)] -pub fn ErrorForbidden(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::FORBIDDEN).into() -} - -/// Helper function that creates wrapper of any error and generate *NOT FOUND* -/// response. -#[allow(non_snake_case)] -pub fn ErrorNotFound(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_FOUND).into() -} - -/// Helper function that creates wrapper of any error and generate *METHOD NOT -/// ALLOWED* response. -#[allow(non_snake_case)] -pub fn ErrorMethodNotAllowed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() -} - -/// Helper function that creates wrapper of any error and generate *NOT -/// ACCEPTABLE* response. -#[allow(non_snake_case)] -pub fn ErrorNotAcceptable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_ACCEPTABLE).into() -} - -/// Helper function that creates wrapper of any error and generate *PROXY -/// AUTHENTICATION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorProxyAuthenticationRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PROXY_AUTHENTICATION_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate *REQUEST -/// TIMEOUT* response. -#[allow(non_snake_case)] -pub fn ErrorRequestTimeout(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() -} - -/// Helper function that creates wrapper of any error and generate *CONFLICT* -/// response. -#[allow(non_snake_case)] -pub fn ErrorConflict(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::CONFLICT).into() -} - -/// Helper function that creates wrapper of any error and generate *GONE* -/// response. -#[allow(non_snake_case)] -pub fn ErrorGone(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::GONE).into() -} - -/// Helper function that creates wrapper of any error and generate *LENGTH -/// REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorLengthRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LENGTH_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PAYLOAD TOO LARGE* response. -#[allow(non_snake_case)] -pub fn ErrorPayloadTooLarge(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PAYLOAD_TOO_LARGE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *URI TOO LONG* response. -#[allow(non_snake_case)] -pub fn ErrorUriTooLong(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::URI_TOO_LONG).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNSUPPORTED MEDIA TYPE* response. -#[allow(non_snake_case)] -pub fn ErrorUnsupportedMediaType(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNSUPPORTED_MEDIA_TYPE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *RANGE NOT SATISFIABLE* response. -#[allow(non_snake_case)] -pub fn ErrorRangeNotSatisfiable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::RANGE_NOT_SATISFIABLE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *IM A TEAPOT* response. -#[allow(non_snake_case)] -pub fn ErrorImATeapot(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::IM_A_TEAPOT).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *MISDIRECTED REQUEST* response. -#[allow(non_snake_case)] -pub fn ErrorMisdirectedRequest(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::MISDIRECTED_REQUEST).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNPROCESSABLE ENTITY* response. -#[allow(non_snake_case)] -pub fn ErrorUnprocessableEntity(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNPROCESSABLE_ENTITY).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *LOCKED* response. -#[allow(non_snake_case)] -pub fn ErrorLocked(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LOCKED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *FAILED DEPENDENCY* response. -#[allow(non_snake_case)] -pub fn ErrorFailedDependency(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::FAILED_DEPENDENCY).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UPGRADE REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorUpgradeRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UPGRADE_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PRECONDITION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorPreconditionFailed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PRECONDITION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorPreconditionRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PRECONDITION_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *TOO MANY REQUESTS* response. -#[allow(non_snake_case)] -pub fn ErrorTooManyRequests(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::TOO_MANY_REQUESTS).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *REQUEST HEADER FIELDS TOO LARGE* response. -#[allow(non_snake_case)] -pub fn ErrorRequestHeaderFieldsTooLarge(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNAVAILABLE FOR LEGAL REASONS* response. -#[allow(non_snake_case)] -pub fn ErrorUnavailableForLegalReasons(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *EXPECTATION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorExpectationFailed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *INTERNAL SERVER ERROR* response. -#[allow(non_snake_case)] -pub fn ErrorInternalServerError(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NOT IMPLEMENTED* response. -#[allow(non_snake_case)] -pub fn ErrorNotImplemented(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_IMPLEMENTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *BAD GATEWAY* response. -#[allow(non_snake_case)] -pub fn ErrorBadGateway(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::BAD_GATEWAY).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *SERVICE UNAVAILABLE* response. -#[allow(non_snake_case)] -pub fn ErrorServiceUnavailable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::SERVICE_UNAVAILABLE).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *GATEWAY TIMEOUT* response. -#[allow(non_snake_case)] -pub fn ErrorGatewayTimeout(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::GATEWAY_TIMEOUT).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *HTTP VERSION NOT SUPPORTED* response. -#[allow(non_snake_case)] -pub fn ErrorHttpVersionNotSupported(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::HTTP_VERSION_NOT_SUPPORTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *VARIANT ALSO NEGOTIATES* response. -#[allow(non_snake_case)] -pub fn ErrorVariantAlsoNegotiates(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::VARIANT_ALSO_NEGOTIATES).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *INSUFFICIENT STORAGE* response. -#[allow(non_snake_case)] -pub fn ErrorInsufficientStorage(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::INSUFFICIENT_STORAGE).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *LOOP DETECTED* response. -#[allow(non_snake_case)] -pub fn ErrorLoopDetected(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LOOP_DETECTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NOT EXTENDED* response. -#[allow(non_snake_case)] -pub fn ErrorNotExtended(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_EXTENDED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NETWORK AUTHENTICATION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorNetworkAuthenticationRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() } #[cfg(test)] @@ -979,55 +422,42 @@ mod tests { #[test] fn test_into_response() { - let resp: Response = ParseError::Incomplete.error_response(); + let resp: Response = ParseError::Incomplete.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); - let resp: Response = err.error_response(); + let resp: Response = Error::new_http().with_cause(err).into(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } - #[test] - fn test_cookie_parse() { - let resp: Response = CookieParseError::EmptyName.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } - #[test] fn test_as_response() { let orig = io::Error::new(io::ErrorKind::Other, "other"); - let e: Error = ParseError::Io(orig).into(); - assert_eq!(format!("{}", e.as_response_error()), "IO error: other"); - } - - #[test] - fn test_error_cause() { - let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.to_string(); - let e = Error::from(orig); - assert_eq!(format!("{}", e.as_response_error()), desc); + let err: Error = ParseError::Io(orig).into(); + assert_eq!( + format!("{}", err), + "error parsing HTTP message: IO error: other" + ); } #[test] fn test_error_display() { let orig = io::Error::new(io::ErrorKind::Other, "other"); - let desc = orig.to_string(); - let e = Error::from(orig); - assert_eq!(format!("{}", e), desc); + let err = Error::new_io().with_cause(orig); + assert_eq!("connection error: other", err.to_string()); } #[test] fn test_error_http_response() { let orig = io::Error::new(io::ErrorKind::Other, "other"); - let e = Error::from(orig); - let resp: Response = e.into(); + let err = Error::new_io().with_cause(orig); + let resp: Response = err.into(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } #[test] fn test_payload_error() { - let err: PayloadError = - io::Error::new(io::ErrorKind::Other, "ParseError").into(); + let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); assert!(err.to_string().contains("ParseError")); let err = PayloadError::Incomplete(None); @@ -1072,142 +502,4 @@ mod tests { from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); from!(httparse::Error::Version => ParseError::Version); } - - #[test] - fn test_internal_error() { - let err = - InternalError::from_response(ParseError::Method, Response::Ok().into()); - let resp: Response = err.error_response(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_error_casting() { - let err = PayloadError::Overflow; - let resp_err: &dyn ResponseError = &err; - let err = resp_err.downcast_ref::().unwrap(); - assert_eq!(err.to_string(), "Payload reached size limit."); - let not_err = resp_err.downcast_ref::(); - assert!(not_err.is_none()); - } - - #[test] - fn test_error_helpers() { - let r: Response = ErrorBadRequest("err").into(); - assert_eq!(r.status(), StatusCode::BAD_REQUEST); - - let r: Response = ErrorUnauthorized("err").into(); - assert_eq!(r.status(), StatusCode::UNAUTHORIZED); - - let r: Response = ErrorPaymentRequired("err").into(); - assert_eq!(r.status(), StatusCode::PAYMENT_REQUIRED); - - let r: Response = ErrorForbidden("err").into(); - assert_eq!(r.status(), StatusCode::FORBIDDEN); - - let r: Response = ErrorNotFound("err").into(); - assert_eq!(r.status(), StatusCode::NOT_FOUND); - - let r: Response = ErrorMethodNotAllowed("err").into(); - assert_eq!(r.status(), StatusCode::METHOD_NOT_ALLOWED); - - let r: Response = ErrorNotAcceptable("err").into(); - assert_eq!(r.status(), StatusCode::NOT_ACCEPTABLE); - - let r: Response = ErrorProxyAuthenticationRequired("err").into(); - assert_eq!(r.status(), StatusCode::PROXY_AUTHENTICATION_REQUIRED); - - let r: Response = ErrorRequestTimeout("err").into(); - assert_eq!(r.status(), StatusCode::REQUEST_TIMEOUT); - - let r: Response = ErrorConflict("err").into(); - assert_eq!(r.status(), StatusCode::CONFLICT); - - let r: Response = ErrorGone("err").into(); - assert_eq!(r.status(), StatusCode::GONE); - - let r: Response = ErrorLengthRequired("err").into(); - assert_eq!(r.status(), StatusCode::LENGTH_REQUIRED); - - let r: Response = ErrorPreconditionFailed("err").into(); - assert_eq!(r.status(), StatusCode::PRECONDITION_FAILED); - - let r: Response = ErrorPayloadTooLarge("err").into(); - assert_eq!(r.status(), StatusCode::PAYLOAD_TOO_LARGE); - - let r: Response = ErrorUriTooLong("err").into(); - assert_eq!(r.status(), StatusCode::URI_TOO_LONG); - - let r: Response = ErrorUnsupportedMediaType("err").into(); - assert_eq!(r.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); - - let r: Response = ErrorRangeNotSatisfiable("err").into(); - assert_eq!(r.status(), StatusCode::RANGE_NOT_SATISFIABLE); - - let r: Response = ErrorExpectationFailed("err").into(); - assert_eq!(r.status(), StatusCode::EXPECTATION_FAILED); - - let r: Response = ErrorImATeapot("err").into(); - assert_eq!(r.status(), StatusCode::IM_A_TEAPOT); - - let r: Response = ErrorMisdirectedRequest("err").into(); - assert_eq!(r.status(), StatusCode::MISDIRECTED_REQUEST); - - let r: Response = ErrorUnprocessableEntity("err").into(); - assert_eq!(r.status(), StatusCode::UNPROCESSABLE_ENTITY); - - let r: Response = ErrorLocked("err").into(); - assert_eq!(r.status(), StatusCode::LOCKED); - - let r: Response = ErrorFailedDependency("err").into(); - assert_eq!(r.status(), StatusCode::FAILED_DEPENDENCY); - - let r: Response = ErrorUpgradeRequired("err").into(); - assert_eq!(r.status(), StatusCode::UPGRADE_REQUIRED); - - let r: Response = ErrorPreconditionRequired("err").into(); - assert_eq!(r.status(), StatusCode::PRECONDITION_REQUIRED); - - let r: Response = ErrorTooManyRequests("err").into(); - assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS); - - let r: Response = ErrorRequestHeaderFieldsTooLarge("err").into(); - assert_eq!(r.status(), StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); - - let r: Response = ErrorUnavailableForLegalReasons("err").into(); - assert_eq!(r.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); - - let r: Response = ErrorInternalServerError("err").into(); - assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR); - - let r: Response = ErrorNotImplemented("err").into(); - assert_eq!(r.status(), StatusCode::NOT_IMPLEMENTED); - - let r: Response = ErrorBadGateway("err").into(); - assert_eq!(r.status(), StatusCode::BAD_GATEWAY); - - let r: Response = ErrorServiceUnavailable("err").into(); - assert_eq!(r.status(), StatusCode::SERVICE_UNAVAILABLE); - - let r: Response = ErrorGatewayTimeout("err").into(); - assert_eq!(r.status(), StatusCode::GATEWAY_TIMEOUT); - - let r: Response = ErrorHttpVersionNotSupported("err").into(); - assert_eq!(r.status(), StatusCode::HTTP_VERSION_NOT_SUPPORTED); - - let r: Response = ErrorVariantAlsoNegotiates("err").into(); - assert_eq!(r.status(), StatusCode::VARIANT_ALSO_NEGOTIATES); - - let r: Response = ErrorInsufficientStorage("err").into(); - assert_eq!(r.status(), StatusCode::INSUFFICIENT_STORAGE); - - let r: Response = ErrorLoopDetected("err").into(); - assert_eq!(r.status(), StatusCode::LOOP_DETECTED); - - let r: Response = ErrorNotExtended("err").into(); - assert_eq!(r.status(), StatusCode::NOT_EXTENDED); - - let r: Response = ErrorNetworkAuthenticationRequired("err").into(); - assert_eq!(r.status(), StatusCode::NETWORK_AUTHENTICATION_REQUIRED); - } } diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index 5fdcefd6d..60b769d13 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,6 +1,6 @@ use std::{ any::{Any, TypeId}, - fmt, mem, + fmt, }; use ahash::AHashMap; @@ -10,8 +10,7 @@ use ahash::AHashMap; /// All entries into this map must be owned types (or static references). #[derive(Default)] pub struct Extensions { - /// Use FxHasher with a std HashMap with for faster - /// lookups on the small `TypeId` (u64 equivalent) keys. + /// Use AHasher with a std HashMap with for faster lookups on the small `TypeId` keys. map: AHashMap>, } @@ -20,7 +19,7 @@ impl Extensions { #[inline] pub fn new() -> Extensions { Extensions { - map: AHashMap::default(), + map: AHashMap::new(), } } @@ -123,11 +122,6 @@ impl Extensions { pub fn extend(&mut self, other: Extensions) { self.map.extend(other.map); } - - /// Sets (or overrides) items from `other` into this map. - pub(crate) fn drain_from(&mut self, other: &mut Self) { - self.map.extend(mem::take(&mut other.map)); - } } impl fmt::Debug for Extensions { @@ -179,6 +173,8 @@ mod tests { #[test] fn test_integers() { + static A: u32 = 8; + let mut map = Extensions::new(); map.insert::(8); @@ -191,6 +187,7 @@ mod tests { map.insert::(32); map.insert::(64); map.insert::(128); + map.insert::<&'static u32>(&A); assert!(map.get::().is_some()); assert!(map.get::().is_some()); assert!(map.get::().is_some()); @@ -201,6 +198,7 @@ mod tests { assert!(map.get::().is_some()); assert!(map.get::().is_some()); assert!(map.get::().is_some()); + assert!(map.get::<&'static u32>().is_some()); } #[test] @@ -279,27 +277,4 @@ mod tests { assert_eq!(extensions.get(), Some(&20u8)); assert_eq!(extensions.get_mut(), Some(&mut 20u8)); } - - #[test] - fn test_drain_from() { - let mut ext = Extensions::new(); - ext.insert(2isize); - - let mut more_ext = Extensions::new(); - - more_ext.insert(5isize); - more_ext.insert(5usize); - - assert_eq!(ext.get::(), Some(&2isize)); - assert_eq!(ext.get::(), None); - assert_eq!(more_ext.get::(), Some(&5isize)); - assert_eq!(more_ext.get::(), Some(&5usize)); - - ext.drain_from(&mut more_ext); - - assert_eq!(ext.get::(), Some(&5isize)); - assert_eq!(ext.get::(), Some(&5usize)); - assert_eq!(more_ext.get::(), None); - assert_eq!(more_ext.get::(), None); - } } diff --git a/actix-http/src/h1/chunked.rs b/actix-http/src/h1/chunked.rs new file mode 100644 index 000000000..7d0532fcd --- /dev/null +++ b/actix-http/src/h1/chunked.rs @@ -0,0 +1,426 @@ +use std::{io, task::Poll}; + +use bytes::{Buf as _, Bytes, BytesMut}; + +macro_rules! byte ( + ($rdr:ident) => ({ + if $rdr.len() > 0 { + let b = $rdr[0]; + $rdr.advance(1); + b + } else { + return Poll::Pending + } + }) +); + +#[derive(Debug, PartialEq, Clone)] +pub(super) enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl ChunkedState { + pub(super) fn step( + &self, + body: &mut BytesMut, + size: &mut u64, + buf: &mut Option, + ) -> Poll> { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(body, size), + SizeLws => ChunkedState::read_size_lws(body), + Extension => ChunkedState::read_extension(body), + SizeLf => ChunkedState::read_size_lf(body, *size), + Body => ChunkedState::read_body(body, size, buf), + BodyCr => ChunkedState::read_body_cr(body), + BodyLf => ChunkedState::read_body_lf(body), + EndCr => ChunkedState::read_end_cr(body), + EndLf => ChunkedState::read_end_lf(body), + End => Poll::Ready(Ok(ChunkedState::End)), + } + } + + fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll> { + let radix = 16; + + let rem = match byte!(rdr) { + b @ b'0'..=b'9' => b - b'0', + b @ b'a'..=b'f' => b + 10 - b'a', + b @ b'A'..=b'F' => b + 10 - b'A', + b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => return Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + ))); + } + }; + + match size.checked_mul(radix) { + Some(n) => { + *size = n as u64; + *size += rem as u64; + + Poll::Ready(Ok(ChunkedState::Size)) + } + None => { + log::debug!("chunk size would overflow u64"); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Size is too big", + ))) + } + } + } + + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), + b';' => Poll::Ready(Ok(ChunkedState::Extension)), + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + ))), + } + } + fn read_extension(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), + // strictly 0x20 (space) should be disallowed but we don't parse quoted strings here + 0x00..=0x08 | 0x0a..=0x1f | 0x7f => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid character in chunk extension", + ))), + _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf(rdr: &mut BytesMut, size: u64) -> Poll> { + match byte!(rdr) { + b'\n' if size > 0 => Poll::Ready(Ok(ChunkedState::Body)), + b'\n' if size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + ))), + } + } + + fn read_body( + rdr: &mut BytesMut, + rem: &mut u64, + buf: &mut Option, + ) -> Poll> { + log::trace!("Chunked read, remaining={:?}", rem); + + let len = rdr.len() as u64; + if len == 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + let slice; + if *rem > len { + slice = rdr.split().freeze(); + *rem -= len; + } else { + slice = rdr.split_to(*rem as usize).freeze(); + *rem = 0; + } + *buf = Some(slice); + if *rem > 0 { + Poll::Ready(Ok(ChunkedState::Body)) + } else { + Poll::Ready(Ok(ChunkedState::BodyCr)) + } + } + } + + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + ))), + } + } + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::Size)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + ))), + } + } + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + ))), + } + } + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { + match byte!(rdr) { + b'\n' => Poll::Ready(Ok(ChunkedState::End)), + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + ))), + } + } +} + +#[cfg(test)] +mod tests { + use actix_codec::Decoder as _; + use bytes::{Bytes, BytesMut}; + use http::Method; + + use crate::{ + error::ParseError, + h1::decoder::{MessageDecoder, PayloadItem}, + HttpMessage as _, Request, + }; + + macro_rules! parse_ready { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Ok(Some((msg, _))) => msg, + Ok(_) => unreachable!("Eof during parsing http request"), + Err(err) => unreachable!("Error during parsing http request: {:?}", err), + } + }}; + } + + macro_rules! expect_parse_err { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Err(err) => match err { + ParseError::Io(_) => unreachable!("Parse error expected"), + _ => {} + }, + _ => unreachable!("Error expected"), + } + }}; + } + + #[test] + fn test_parse_chunked_payload_chunk_extension() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\ + \r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(msg.chunked().unwrap()); + + buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"data")); + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"line")); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + } + + #[test] + fn test_request_chunked() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(val); + } else { + unreachable!("Error"); + } + + // intentional typo in "chunked" + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chnked\r\n\r\n", + ); + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_chunked_payload() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"data" + ); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"line" + ); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert!(req.chunked().unwrap()); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } + + #[test] + fn test_http_request_chunked_payload_chunks() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\n1111\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"1111"); + + buf.extend(b"4\r\ndata\r"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + buf.extend(b"\n4"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + buf.extend(b"\n"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"li"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"li"); + + //trailers + //buf.feed_data("test: test\r\n"); + //not_ready!(reader.parse(&mut buf, &mut readbuf)); + + buf.extend(b"ne\r\n0\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"ne"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn chunk_extension_quoted() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Host: localhost:8080\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 2;hello=b;one=\"1 2 3\"\r\n\ + xx", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let chunk = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"xx"))); + } + + #[test] + fn hrs_chunk_extension_invalid() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: localhost:8080\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 2;x\nx\r\n\ + 4c\r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let err = pl.decode(&mut buf).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid character in chunk extension")); + } + + #[test] + fn hrs_chunk_size_overflow() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + f0000000000000003\r\n\ + abc\r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let err = pl.decode(&mut buf).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid chunk size line: Size is too big")); + } +} diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index 4a6104688..9bd896ae0 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -5,13 +5,15 @@ use bitflags::bitflags; use bytes::{Bytes, BytesMut}; use http::{Method, Version}; -use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; -use super::{decoder, encoder, reserve_readbuf}; -use super::{Message, MessageType}; -use crate::body::BodySize; -use crate::config::ServiceConfig; -use crate::error::{ParseError, PayloadError}; -use crate::message::{ConnectionType, RequestHeadType, ResponseHead}; +use super::{ + decoder::{self, PayloadDecoder, PayloadItem, PayloadType}, + encoder, reserve_readbuf, Message, MessageType, +}; +use crate::{ + body::BodySize, + error::{ParseError, PayloadError}, + ConnectionType, RequestHeadType, ResponseHead, ServiceConfig, +}; bitflags! { struct Flags: u8 { @@ -120,7 +122,7 @@ impl Decoder for ClientCodec { debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); if let Some((req, payload)) = self.inner.decoder.decode(src)? { - if let Some(ctype) = req.ctype() { + if let Some(ctype) = req.conn_type() { // do not use peer's keep-alive self.inner.ctype = if ctype == ConnectionType::KeepAlive { self.inner.ctype diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs index 634ca25e8..9a8907579 100644 --- a/actix-http/src/h1/codec.rs +++ b/actix-http/src/h1/codec.rs @@ -5,15 +5,13 @@ use bitflags::bitflags; use bytes::BytesMut; use http::{Method, Version}; -use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; -use super::{decoder, encoder}; -use super::{Message, MessageType}; -use crate::body::BodySize; -use crate::config::ServiceConfig; -use crate::error::ParseError; -use crate::message::ConnectionType; -use crate::request::Request; -use crate::response::Response; +use super::{ + decoder::{self, PayloadDecoder, PayloadItem, PayloadType}, + encoder, Message, MessageType, +}; +use crate::{ + body::BodySize, error::ParseError, ConnectionType, Request, Response, ServiceConfig, +}; bitflags! { struct Flags: u8 { @@ -29,7 +27,7 @@ pub struct Codec { decoder: decoder::MessageDecoder, payload: Option, version: Version, - ctype: ConnectionType, + conn_type: ConnectionType, // encoder part flags: Flags, @@ -65,7 +63,7 @@ impl Codec { decoder: decoder::MessageDecoder::default(), payload: None, version: Version::HTTP_11, - ctype: ConnectionType::Close, + conn_type: ConnectionType::Close, encoder: encoder::MessageEncoder::default(), } } @@ -73,13 +71,13 @@ impl Codec { /// Check if request is upgrade. #[inline] pub fn upgrade(&self) -> bool { - self.ctype == ConnectionType::Upgrade + self.conn_type == ConnectionType::Upgrade } /// Check if last response is keep-alive. #[inline] pub fn keepalive(&self) -> bool { - self.ctype == ConnectionType::KeepAlive + self.conn_type == ConnectionType::KeepAlive } /// Check if keep-alive enabled on server level. @@ -124,11 +122,11 @@ impl Decoder for Codec { let head = req.head(); self.flags.set(Flags::HEAD, head.method == Method::HEAD); self.version = head.version; - self.ctype = head.connection_type(); - if self.ctype == ConnectionType::KeepAlive + self.conn_type = head.connection_type(); + if self.conn_type == ConnectionType::KeepAlive && !self.flags.contains(Flags::KEEPALIVE_ENABLED) { - self.ctype = ConnectionType::Close + self.conn_type = ConnectionType::Close } match payload { PayloadType::None => self.payload = None, @@ -159,14 +157,14 @@ impl Encoder, BodySize)>> for Codec { res.head_mut().version = self.version; // connection status - self.ctype = if let Some(ct) = res.head().ctype() { + self.conn_type = if let Some(ct) = res.head().conn_type() { if ct == ConnectionType::KeepAlive { - self.ctype + self.conn_type } else { ct } } else { - self.ctype + self.conn_type }; // encode message @@ -177,10 +175,9 @@ impl Encoder, BodySize)>> for Codec { self.flags.contains(Flags::STREAM), self.version, length, - self.ctype, + self.conn_type, &self.config, )?; - // self.headers_size = (dst.len() - len) as u32; } Message::Chunk(Some(bytes)) => { self.encoder.encode_chunk(bytes.as_ref(), dst)?; @@ -189,6 +186,7 @@ impl Encoder, BodySize)>> for Codec { self.encoder.encode_eof(dst)?; } } + Ok(()) } } @@ -199,7 +197,7 @@ mod tests { use http::Method; use super::*; - use crate::HttpMessage; + use crate::HttpMessage as _; #[actix_rt::test] async fn test_http_request_chunked_payload_and_next_message() { diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index 93a4b13d2..3d3a3ceac 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -1,18 +1,15 @@ -use std::convert::TryFrom; -use std::io; -use std::marker::PhantomData; -use std::task::Poll; +use std::{convert::TryFrom, io, marker::PhantomData, mem::MaybeUninit, task::Poll}; use actix_codec::Decoder; -use bytes::{Buf, Bytes, BytesMut}; -use http::header::{HeaderName, HeaderValue}; -use http::{header, Method, StatusCode, Uri, Version}; +use bytes::{Bytes, BytesMut}; +use http::{ + header::{self, HeaderName, HeaderValue}, + Method, StatusCode, Uri, Version, +}; use log::{debug, error, trace}; -use crate::error::ParseError; -use crate::header::HeaderMap; -use crate::message::{ConnectionType, ResponseHead}; -use crate::request::Request; +use super::chunked::ChunkedState; +use crate::{error::ParseError, header::HeaderMap, ConnectionType, Request, ResponseHead}; pub(crate) const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; @@ -50,7 +47,7 @@ pub(crate) enum PayloadLength { } pub(crate) trait MessageType: Sized { - fn set_connection_type(&mut self, ctype: Option); + fn set_connection_type(&mut self, conn_type: Option); fn set_expect(&mut self); @@ -67,14 +64,14 @@ pub(crate) trait MessageType: Sized { let mut has_upgrade_websocket = false; let mut expect = false; let mut chunked = false; + let mut seen_te = false; let mut content_length = None; { let headers = self.headers_mut(); for idx in raw_headers.iter() { - let name = - HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap(); + let name = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]).unwrap(); // SAFETY: httparse already checks header value is only visible ASCII bytes // from_maybe_shared_unchecked contains debug assertions so they are omitted here @@ -85,8 +82,17 @@ pub(crate) trait MessageType: Sized { }; match name { - header::CONTENT_LENGTH => { - if let Ok(s) = value.to_str() { + header::CONTENT_LENGTH if content_length.is_some() => { + debug!("multiple Content-Length"); + return Err(ParseError::Header); + } + + header::CONTENT_LENGTH => match value.to_str() { + Ok(s) if s.trim().starts_with('+') => { + debug!("illegal Content-Length: {:?}", s); + return Err(ParseError::Header); + } + Ok(s) => { if let Ok(len) = s.parse::() { if len != 0 { content_length = Some(len); @@ -95,22 +101,38 @@ pub(crate) trait MessageType: Sized { debug!("illegal Content-Length: {:?}", s); return Err(ParseError::Header); } - } else { + } + Err(_) => { debug!("illegal Content-Length: {:?}", value); return Err(ParseError::Header); } - } + }, + // transfer-encoding + header::TRANSFER_ENCODING if seen_te => { + debug!("multiple Transfer-Encoding not allowed"); + return Err(ParseError::Header); + } + header::TRANSFER_ENCODING => { - if let Ok(s) = value.to_str().map(|s| s.trim()) { - chunked = s.eq_ignore_ascii_case("chunked"); + seen_te = true; + + if let Ok(s) = value.to_str().map(str::trim) { + if s.eq_ignore_ascii_case("chunked") { + chunked = true; + } else if s.eq_ignore_ascii_case("identity") { + // allow silently since multiple TE headers are already checked + } else { + debug!("illegal Transfer-Encoding: {:?}", s); + return Err(ParseError::Header); + } } else { return Err(ParseError::Header); } } // connection keep-alive state header::CONNECTION => { - ka = if let Ok(conn) = value.to_str().map(|conn| conn.trim()) { + ka = if let Ok(conn) = value.to_str().map(str::trim) { if conn.eq_ignore_ascii_case("keep-alive") { Some(ConnectionType::KeepAlive) } else if conn.eq_ignore_ascii_case("close") { @@ -125,7 +147,7 @@ pub(crate) trait MessageType: Sized { }; } header::UPGRADE => { - if let Ok(val) = value.to_str().map(|val| val.trim()) { + if let Ok(val) = value.to_str().map(str::trim) { if val.eq_ignore_ascii_case("websocket") { has_upgrade_websocket = true; } @@ -148,7 +170,7 @@ pub(crate) trait MessageType: Sized { self.set_expect() } - // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if chunked { // Chunked encoding Ok(PayloadLength::Payload(PayloadType::Payload( @@ -168,8 +190,8 @@ pub(crate) trait MessageType: Sized { } impl MessageType for Request { - fn set_connection_type(&mut self, ctype: Option) { - if let Some(ctype) = ctype { + fn set_connection_type(&mut self, conn_type: Option) { + if let Some(ctype) = conn_type { self.head_mut().set_connection_type(ctype); } } @@ -186,10 +208,17 @@ impl MessageType for Request { let mut headers: [HeaderIndex; MAX_HEADERS] = EMPTY_HEADER_INDEX_ARRAY; let (len, method, uri, ver, h_len) = { - let mut parsed: [httparse::Header<'_>; MAX_HEADERS] = EMPTY_HEADER_ARRAY; + // SAFETY: + // Create an uninitialized array of `MaybeUninit`. The `assume_init` is + // safe because the type we are claiming to have initialized here is a + // bunch of `MaybeUninit`s, which do not require initialization. + let mut parsed = unsafe { + MaybeUninit::<[MaybeUninit>; MAX_HEADERS]>::uninit() + .assume_init() + }; - let mut req = httparse::Request::new(&mut parsed); - match req.parse(src)? { + let mut req = httparse::Request::new(&mut []); + match req.parse_with_uninit_headers(src, &mut parsed)? { httparse::Status::Complete(len) => { let method = Method::from_bytes(req.method.unwrap().as_bytes()) .map_err(|_| ParseError::Method)?; @@ -246,8 +275,8 @@ impl MessageType for Request { } impl MessageType for ResponseHead { - fn set_connection_type(&mut self, ctype: Option) { - if let Some(ctype) = ctype { + fn set_connection_type(&mut self, conn_type: Option) { + if let Some(ctype) = conn_type { ResponseHead::set_connection_type(self, ctype); } } @@ -408,20 +437,6 @@ enum Kind { Eof, } -#[derive(Debug, PartialEq, Clone)] -enum ChunkedState { - Size, - SizeLws, - Extension, - SizeLf, - Body, - BodyCr, - BodyLf, - EndCr, - EndLf, - End, -} - impl Decoder for PayloadDecoder { type Item = PayloadItem; type Error = io::Error; @@ -451,19 +466,23 @@ impl Decoder for PayloadDecoder { Kind::Chunked(ref mut state, ref mut size) => { loop { let mut buf = None; + // advances the chunked state *state = match state.step(src, size, &mut buf) { Poll::Pending => return Ok(None), Poll::Ready(Ok(state)) => state, Poll::Ready(Err(e)) => return Err(e), }; + if *state == ChunkedState::End { trace!("End of chunked stream"); return Ok(Some(PayloadItem::Eof)); } + if let Some(buf) = buf { return Ok(Some(PayloadItem::Chunk(buf))); } + if src.is_empty() { return Ok(None); } @@ -480,201 +499,40 @@ impl Decoder for PayloadDecoder { } } -macro_rules! byte ( - ($rdr:ident) => ({ - if $rdr.len() > 0 { - let b = $rdr[0]; - $rdr.advance(1); - b - } else { - return Poll::Pending - } - }) -); - -impl ChunkedState { - fn step( - &self, - body: &mut BytesMut, - size: &mut u64, - buf: &mut Option, - ) -> Poll> { - use self::ChunkedState::*; - match *self { - Size => ChunkedState::read_size(body, size), - SizeLws => ChunkedState::read_size_lws(body), - Extension => ChunkedState::read_extension(body), - SizeLf => ChunkedState::read_size_lf(body, size), - Body => ChunkedState::read_body(body, size, buf), - BodyCr => ChunkedState::read_body_cr(body), - BodyLf => ChunkedState::read_body_lf(body), - EndCr => ChunkedState::read_end_cr(body), - EndLf => ChunkedState::read_end_lf(body), - End => Poll::Ready(Ok(ChunkedState::End)), - } - } - - fn read_size( - rdr: &mut BytesMut, - size: &mut u64, - ) -> Poll> { - let radix = 16; - match byte!(rdr) { - b @ b'0'..=b'9' => { - *size *= radix; - *size += u64::from(b - b'0'); - } - b @ b'a'..=b'f' => { - *size *= radix; - *size += u64::from(b + 10 - b'a'); - } - b @ b'A'..=b'F' => { - *size *= radix; - *size += u64::from(b + 10 - b'A'); - } - b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)), - b';' => return Poll::Ready(Ok(ChunkedState::Extension)), - b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size line: Invalid Size", - ))); - } - } - Poll::Ready(Ok(ChunkedState::Size)) - } - - fn read_size_lws(rdr: &mut BytesMut) -> Poll> { - trace!("read_size_lws"); - match byte!(rdr) { - // LWS can follow the chunk size, but no more digits can come - b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), - b';' => Poll::Ready(Ok(ChunkedState::Extension)), - b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size linear white space", - ))), - } - } - fn read_extension(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions - } - } - fn read_size_lf( - rdr: &mut BytesMut, - size: &mut u64, - ) -> Poll> { - match byte!(rdr) { - b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), - b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk size LF", - ))), - } - } - - fn read_body( - rdr: &mut BytesMut, - rem: &mut u64, - buf: &mut Option, - ) -> Poll> { - trace!("Chunked read, remaining={:?}", rem); - - let len = rdr.len() as u64; - if len == 0 { - Poll::Ready(Ok(ChunkedState::Body)) - } else { - let slice; - if *rem > len { - slice = rdr.split().freeze(); - *rem -= len; - } else { - slice = rdr.split_to(*rem as usize).freeze(); - *rem = 0; - } - *buf = Some(slice); - if *rem > 0 { - Poll::Ready(Ok(ChunkedState::Body)) - } else { - Poll::Ready(Ok(ChunkedState::BodyCr)) - } - } - } - - fn read_body_cr(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk body CR", - ))), - } - } - fn read_body_lf(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\n' => Poll::Ready(Ok(ChunkedState::Size)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk body LF", - ))), - } - } - fn read_end_cr(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk end CR", - ))), - } - } - fn read_end_lf(rdr: &mut BytesMut) -> Poll> { - match byte!(rdr) { - b'\n' => Poll::Ready(Ok(ChunkedState::End)), - _ => Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid chunk end LF", - ))), - } - } -} - #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; use http::{Method, Version}; use super::*; - use crate::error::ParseError; - use crate::http::header::{HeaderName, SET_COOKIE}; - use crate::HttpMessage; + use crate::{ + error::ParseError, + header::{HeaderName, SET_COOKIE}, + HttpMessage as _, + }; impl PayloadType { - fn unwrap(self) -> PayloadDecoder { + pub(crate) fn unwrap(self) -> PayloadDecoder { match self { PayloadType::Payload(pl) => pl, _ => panic!(), } } - fn is_unhandled(&self) -> bool { + pub(crate) fn is_unhandled(&self) -> bool { matches!(self, PayloadType::Stream(_)) } } impl PayloadItem { - fn chunk(self) -> Bytes { + pub(crate) fn chunk(self) -> Bytes { match self { PayloadItem::Chunk(chunk) => chunk, _ => panic!("error"), } } - fn eof(&self) -> bool { + + pub(crate) fn eof(&self) -> bool { matches!(*self, PayloadItem::Eof) } } @@ -743,8 +601,7 @@ mod tests { #[test] fn test_parse_body() { - let mut buf = - BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); @@ -760,8 +617,7 @@ mod tests { #[test] fn test_parse_body_crlf() { - let mut buf = - BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + let mut buf = BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let mut reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); @@ -967,34 +823,6 @@ mod tests { assert!(req.upgrade()); } - #[test] - fn test_request_chunked() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(val); - } else { - unreachable!("Error"); - } - - // intentional typo in "chunked" - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chnked\r\n\r\n", - ); - let req = parse_ready!(&mut buf); - - if let Ok(val) = req.chunked() { - assert!(!val); - } else { - unreachable!("Error"); - } - } - #[test] fn test_headers_content_length_err_1() { let mut buf = BytesMut::from( @@ -1112,128 +940,9 @@ mod tests { expect_parse_err!(&mut buf); } - #[test] - fn test_http_request_chunked_payload() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); - assert_eq!( - pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), - b"data" - ); - assert_eq!( - pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), - b"line" - ); - assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); - } - - #[test] - fn test_http_request_chunked_payload_and_next_message() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend( - b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ - POST /test2 HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n" - .iter(), - ); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"data"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"line"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert!(msg.eof()); - - let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); - assert!(req.chunked().unwrap()); - assert_eq!(*req.method(), Method::POST); - assert!(req.chunked().unwrap()); - } - - #[test] - fn test_http_request_chunked_payload_chunks() { - let mut buf = BytesMut::from( - "GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n", - ); - - let mut reader = MessageDecoder::::default(); - let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(req.chunked().unwrap()); - - buf.extend(b"4\r\n1111\r\n"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"1111"); - - buf.extend(b"4\r\ndata\r"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"data"); - - buf.extend(b"\n4"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"\r"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - buf.extend(b"\n"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"li"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"li"); - - //trailers - //buf.feed_data("test: test\r\n"); - //not_ready!(reader.parse(&mut buf, &mut readbuf)); - - buf.extend(b"ne\r\n0\r\n"); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"ne"); - assert!(pl.decode(&mut buf).unwrap().is_none()); - - buf.extend(b"\r\n"); - assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); - } - - #[test] - fn test_parse_chunked_payload_chunk_extension() { - let mut buf = BytesMut::from( - &"GET /test HTTP/1.1\r\n\ - transfer-encoding: chunked\r\n\r\n"[..], - ); - - let mut reader = MessageDecoder::::default(); - let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); - assert!(msg.chunked().unwrap()); - - buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") - let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); - assert_eq!(chunk, Bytes::from_static(b"data")); - let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); - assert_eq!(chunk, Bytes::from_static(b"line")); - let msg = pl.decode(&mut buf).unwrap().unwrap(); - assert!(msg.eof()); - } - #[test] fn test_response_http10_read_until_eof() { - let mut buf = BytesMut::from(&"HTTP/1.0 200 Ok\r\n\r\ntest data"[..]); + let mut buf = BytesMut::from("HTTP/1.0 200 Ok\r\n\r\ntest data"); let mut reader = MessageDecoder::::default(); let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); @@ -1242,4 +951,84 @@ mod tests { let chunk = pl.decode(&mut buf).unwrap().unwrap(); assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data"))); } + + #[test] + fn hrs_multiple_content_length() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 4\r\n\ + Content-Length: 2\r\n\ + \r\n\ + abcd", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_content_length_plus() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: +3\r\n\ + \r\n\ + 000", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_unknown_transfer_encoding() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Transfer-Encoding: JUNK\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn hrs_multiple_transfer_encoding() { + let mut buf = BytesMut::from( + "GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 51\r\n\ + Transfer-Encoding: identity\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 0\r\n\ + \r\n\ + GET /forbidden HTTP/1.1\r\n\ + Host: example.com\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn transfer_encoding_agrees() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 3\r\n\ + Transfer-Encoding: identity\r\n\ + \r\n\ + 0\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + + let chunk = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"0\r\n"))); + } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 2e729b78d..13055f08a 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -13,21 +13,24 @@ use actix_rt::time::{sleep_until, Instant, Sleep}; use actix_service::Service; use bitflags::bitflags; use bytes::{Buf, BytesMut}; +use futures_core::ready; use log::{error, trace}; -use pin_project::pin_project; +use pin_project_lite::pin_project; -use crate::body::{Body, BodySize, MessageBody, ResponseBody}; -use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error}; -use crate::error::{ParseError, PayloadError}; -use crate::request::Request; -use crate::response::Response; -use crate::service::HttpFlow; -use crate::OnConnectData; +use crate::{ + body::{BodySize, BoxBody, MessageBody}, + config::ServiceConfig, + error::{DispatchError, ParseError, PayloadError}, + service::HttpFlow, + Error, Extensions, OnConnectData, Request, Response, StatusCode, +}; -use super::codec::Codec; -use super::payload::{Payload, PayloadSender, PayloadStatus}; -use super::{Message, MessageType}; +use super::{ + codec::Codec, + decoder::MAX_BUFFER_SIZE, + payload::{Payload, PayloadSender, PayloadStatus}, + Message, MessageType, +}; const LW_BUFFER_SIZE: usize = 1024; const HW_BUFFER_SIZE: usize = 1024 * 8; @@ -40,74 +43,114 @@ bitflags! { const SHUTDOWN = 0b0000_0100; const READ_DISCONNECT = 0b0000_1000; const WRITE_DISCONNECT = 0b0001_0000; - const UPGRADE = 0b0010_0000; } } -#[pin_project] -/// Dispatcher for HTTP/1.1 protocol -pub struct Dispatcher -where - S: Service, - S::Error: Into, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - #[pin] - inner: DispatcherState, +// there's 2 versions of Dispatcher state because of: +// https://github.com/taiki-e/pin-project-lite/issues/3 +// +// tl;dr: pin-project-lite doesn't play well with other attribute macros - #[cfg(test)] - poll_count: u64, +#[cfg(not(test))] +pin_project! { + /// Dispatcher for HTTP/1.1 protocol + pub struct Dispatcher + where + S: Service, + S::Error: Into>, + + B: MessageBody, + + X: Service, + X::Error: Into>, + + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + #[pin] + inner: DispatcherState, + } } -#[pin_project(project = DispatcherStateProj)] -enum DispatcherState -where - S: Service, - S::Error: Into, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - Normal(#[pin] InnerDispatcher), - Upgrade(#[pin] U::Future), +#[cfg(test)] +pin_project! { + /// Dispatcher for HTTP/1.1 protocol + pub struct Dispatcher + where + S: Service, + S::Error: Into>, + + B: MessageBody, + + X: Service, + X::Error: Into>, + + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + #[pin] + inner: DispatcherState, + + // used in tests + poll_count: u64, + } } -#[pin_project(project = InnerDispatcherProj)] -struct InnerDispatcher -where - S: Service, - S::Error: Into, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - flow: Rc>, - on_connect_data: OnConnectData, - flags: Flags, - peer_addr: Option, - error: Option, +pin_project! { + #[project = DispatcherStateProj] + enum DispatcherState + where + S: Service, + S::Error: Into>, - #[pin] - state: State, - payload: Option, - messages: VecDeque, + B: MessageBody, - ka_expire: Instant, - #[pin] - ka_timer: Option, + X: Service, + X::Error: Into>, - io: Option, - read_buf: BytesMut, - write_buf: BytesMut, - codec: Codec, + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + Normal { #[pin] inner: InnerDispatcher }, + Upgrade { #[pin] fut: U::Future }, + } +} + +pin_project! { + #[project = InnerDispatcherProj] + struct InnerDispatcher + where + S: Service, + S::Error: Into>, + + B: MessageBody, + + X: Service, + X::Error: Into>, + + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + flow: Rc>, + flags: Flags, + peer_addr: Option, + conn_data: Option>, + error: Option, + + #[pin] + state: State, + payload: Option, + messages: VecDeque, + + ka_expire: Instant, + #[pin] + ka_timer: Option, + + io: Option, + read_buf: BytesMut, + write_buf: BytesMut, + codec: Codec, + } } enum DispatcherMessage { @@ -116,23 +159,29 @@ enum DispatcherMessage { Error(Response<()>), } -#[pin_project(project = StateProj)] -enum State -where - S: Service, - X: Service, - B: MessageBody, -{ - None, - ExpectCall(#[pin] X::Future), - ServiceCall(#[pin] S::Future), - SendPayload(#[pin] ResponseBody), +pin_project! { + #[project = StateProj] + enum State + where + S: Service, + X: Service, + + B: MessageBody, + { + None, + ExpectCall { #[pin] fut: X::Future }, + ServiceCall { #[pin] fut: S::Future }, + SendPayload { #[pin] body: B }, + SendErrorPayload { #[pin] body: BoxBody }, + } } impl State where S: Service, + X: Service, + B: MessageBody, { fn is_empty(&self) -> bool { @@ -149,22 +198,26 @@ enum PollResponse { impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into, + S::Error: Into>, S::Response: Into>, + B: MessageBody, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { /// Create HTTP/1 dispatcher. pub(crate) fn new( io: T, - config: ServiceConfig, flow: Rc>, - on_connect_data: OnConnectData, + config: ServiceConfig, peer_addr: Option, + conn_data: OnConnectData, ) -> Self { let flags = if config.keep_alive_enabled() { Flags::KEEPALIVE @@ -179,22 +232,27 @@ where }; Dispatcher { - inner: DispatcherState::Normal(InnerDispatcher { - read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), - write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), - payload: None, - state: State::None, - error: None, - messages: VecDeque::new(), - io: Some(io), - codec: Codec::new(config), - flow, - on_connect_data, - flags, - peer_addr, - ka_expire, - ka_timer, - }), + inner: DispatcherState::Normal { + inner: InnerDispatcher { + flow, + flags, + peer_addr, + conn_data: conn_data.0.map(Rc::new), + error: None, + + state: State::None, + payload: None, + messages: VecDeque::new(), + + ka_expire, + ka_timer, + + io: Some(io), + read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), + write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE), + codec: Codec::new(config), + }, + }, #[cfg(test)] poll_count: 0, @@ -205,20 +263,21 @@ where impl InnerDispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into, + S::Error: Into>, S::Response: Into>, + B: MessageBody, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { fn can_read(&self, cx: &mut Context<'_>) -> bool { - if self - .flags - .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE) - { + if self.flags.contains(Flags::READ_DISCONNECT) { false } else if let Some(ref info) = self.payload { info.need_read(cx) == PayloadStatus::Read @@ -237,14 +296,7 @@ where } } - /// Flush stream - /// - /// true - got WouldBlock - /// false - didn't get WouldBlock - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Result { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let InnerDispatcherProj { io, write_buf, .. } = self.project(); let mut io = Pin::new(io.as_mut().unwrap()); @@ -252,19 +304,15 @@ where let mut written = 0; while written < len { - match io.as_mut().poll_write(cx, &write_buf[written..]) { - Poll::Ready(Ok(0)) => { - return Err(DispatchError::Io(io::Error::new( - io::ErrorKind::WriteZero, - "", - ))) + match io.as_mut().poll_write(cx, &write_buf[written..])? { + Poll::Ready(0) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, ""))) } - Poll::Ready(Ok(n)) => written += n, + Poll::Ready(n) => written += n, Poll::Pending => { write_buf.advance(written); - return Ok(true); + return Poll::Pending; } - Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } @@ -272,20 +320,18 @@ where write_buf.clear(); // flush the io and check if get blocked. - let blocked = io.poll_flush(cx)?.is_pending(); - - Ok(blocked) + io.poll_flush(cx) } - fn send_response( + fn send_response_inner( self: Pin<&mut Self>, message: Response<()>, - body: ResponseBody, - ) -> Result<(), DispatchError> { + body: &impl MessageBody, + ) -> Result { let size = body.size(); - let mut this = self.project(); + let this = self.project(); this.codec - .encode(Message::Item((message, size)), &mut this.write_buf) + .encode(Message::Item((message, size)), this.write_buf) .map_err(|err| { if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); @@ -294,10 +340,35 @@ where })?; this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); - match size { - BodySize::None | BodySize::Empty => this.state.set(State::None), - _ => this.state.set(State::SendPayload(body)), + + Ok(size) + } + + fn send_response( + mut self: Pin<&mut Self>, + message: Response<()>, + body: B, + ) -> Result<(), DispatchError> { + let size = self.as_mut().send_response_inner(message, &body)?; + let state = match size { + BodySize::None | BodySize::Sized(0) => State::None, + _ => State::SendPayload { body }, }; + self.project().state.set(state); + Ok(()) + } + + fn send_error_response( + mut self: Pin<&mut Self>, + message: Response<()>, + body: BoxBody, + ) -> Result<(), DispatchError> { + let size = self.as_mut().send_response_inner(message, &body)?; + let state = match size { + BodySize::None | BodySize::Sized(0) => State::None, + _ => State::SendErrorPayload { body }, + }; + self.project().state.set(state); Ok(()) } @@ -321,12 +392,12 @@ where // Handle `EXPECT: 100-Continue` header if req.head().expect() { // set InnerDispatcher state and continue loop to poll it. - let task = this.flow.expect.call(req); - this.state.set(State::ExpectCall(task)); + let fut = this.flow.expect.call(req); + this.state.set(State::ExpectCall { fut }); } else { // the same as expect call. - let task = this.flow.service.call(req); - this.state.set(State::ServiceCall(task)); + let fut = this.flow.service.call(req); + this.state.set(State::ServiceCall { fut }); }; } @@ -335,8 +406,7 @@ where // send_response would update InnerDispatcher state to SendPayload or // None(If response body is empty). // continue loop to poll it. - self.as_mut() - .send_response(res, ResponseBody::Other(Body::Empty))?; + self.as_mut().send_error_response(res, BoxBody::new(()))?; } // return with upgrade request and poll it exclusively. @@ -347,7 +417,7 @@ where // all messages are dealt with. None => return Ok(PollResponse::DoNothing), }, - StateProj::ServiceCall(fut) => match fut.poll(cx) { + StateProj::ServiceCall { fut } => match fut.poll(cx) { // service call resolved. send response. Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); @@ -356,9 +426,9 @@ where // send service call error as response Poll::Ready(Err(err)) => { - let res: Response = err.into().into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); - self.as_mut().send_response(res, body.into_body())?; + self.as_mut().send_error_response(res, body)?; } // service call pending and could be waiting for more chunk messages. @@ -373,21 +443,18 @@ where } }, - StateProj::SendPayload(mut stream) => { + StateProj::SendPayload { mut body } => { // keep populate writer buffer until buffer size limit hit, // get blocked or finished. while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { - match stream.as_mut().poll_next(cx) { + match body.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(item))) => { - this.codec.encode( - Message::Chunk(Some(item)), - &mut this.write_buf, - )?; + this.codec + .encode(Message::Chunk(Some(item)), this.write_buf)?; } Poll::Ready(None) => { - this.codec - .encode(Message::Chunk(None), &mut this.write_buf)?; + this.codec.encode(Message::Chunk(None), this.write_buf)?; // payload stream finished. // set state to None and handle next message this.state.set(State::None); @@ -395,7 +462,7 @@ where } Poll::Ready(Some(Err(err))) => { - return Err(DispatchError::Service(err)) + return Err(DispatchError::Body(err.into())) } Poll::Pending => return Ok(PollResponse::DoNothing), @@ -406,21 +473,57 @@ where return Ok(PollResponse::DrainWriteBuf); } - StateProj::ExpectCall(fut) => match fut.poll(cx) { + StateProj::SendErrorPayload { mut body } => { + // TODO: de-dupe impl with SendPayload + + // keep populate writer buffer until buffer size limit hit, + // get blocked or finished. + while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE { + match body.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(item))) => { + this.codec + .encode(Message::Chunk(Some(item)), this.write_buf)?; + } + + Poll::Ready(None) => { + this.codec.encode(Message::Chunk(None), this.write_buf)?; + // payload stream finished. + // set state to None and handle next message + this.state.set(State::None); + continue 'res; + } + + Poll::Ready(Some(Err(err))) => { + return Err(DispatchError::Body( + Error::new_body().with_cause(err).into(), + )) + } + + Poll::Pending => return Ok(PollResponse::DoNothing), + } + } + // buffer is beyond max size. + // return and try to write the whole buffer to io stream. + return Ok(PollResponse::DrainWriteBuf); + } + + StateProj::ExpectCall { fut } => match fut.poll(cx) { // expect resolved. write continue to buffer and set InnerDispatcher state // to service call. Poll::Ready(Ok(req)) => { this.write_buf .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); let fut = this.flow.service.call(req); - this.state.set(State::ServiceCall(fut)); + this.state.set(State::ServiceCall { fut }); } + // send expect error as response Poll::Ready(Err(err)) => { - let res: Response = err.into().into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); - self.as_mut().send_response(res, body.into_body())?; + self.as_mut().send_error_response(res, body)?; } + // expect must be solved before progress can be made. Poll::Pending => return Ok(PollResponse::DoNothing), }, @@ -434,29 +537,28 @@ where cx: &mut Context<'_>, ) -> Result<(), DispatchError> { // Handle `EXPECT: 100-Continue` header + let mut this = self.as_mut().project(); if req.head().expect() { // set dispatcher state so the future is pinned. - let mut this = self.as_mut().project(); - let task = this.flow.expect.call(req); - this.state.set(State::ExpectCall(task)); + let fut = this.flow.expect.call(req); + this.state.set(State::ExpectCall { fut }); } else { // the same as above. - let mut this = self.as_mut().project(); - let task = this.flow.service.call(req); - this.state.set(State::ServiceCall(task)); + let fut = this.flow.service.call(req); + this.state.set(State::ServiceCall { fut }); }; // eagerly poll the future for once(or twice if expect is resolved immediately). loop { match self.as_mut().project().state.project() { - StateProj::ExpectCall(fut) => { + StateProj::ExpectCall { fut } => { match fut.poll(cx) { // expect is resolved. continue loop and poll the service call branch. Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); let mut this = self.as_mut().project(); - let task = this.flow.service.call(req); - this.state.set(State::ServiceCall(task)); + let fut = this.flow.service.call(req); + this.state.set(State::ServiceCall { fut }); continue; } // future is pending. return Ok(()) to notify that a new state is @@ -466,14 +568,13 @@ where // to notify the dispatcher a new state is set and the outer loop // should be continue. Poll::Ready(Err(err)) => { - let err = err.into(); - let res: Response = err.into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); - return self.send_response(res, body.into_body()); + return self.send_error_response(res, body); } } } - StateProj::ServiceCall(fut) => { + StateProj::ServiceCall { fut } => { // return no matter the service call future's result. return match fut.poll(cx) { // future is resolved. send response and return a result. On success @@ -487,21 +588,23 @@ where Poll::Pending => Ok(()), // see the comment on ExpectCall state branch's Ready(Err(err)). Poll::Ready(Err(err)) => { - let res: Response = err.into().into(); + let res: Response = err.into(); let (res, body) = res.replace_body(()); - self.send_response(res, body.into_body()) + self.send_error_response(res, body) } }; } - _ => unreachable!( - "State must be set to ServiceCall or ExceptCall in handle_request" - ), + _ => { + unreachable!( + "State must be set to ServiceCall or ExceptCall in handle_request" + ) + } } } } /// Process one incoming request. - pub(self) fn poll_request( + fn poll_request( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { @@ -513,32 +616,48 @@ where let mut updated = false; let mut this = self.as_mut().project(); loop { - match this.codec.decode(&mut this.read_buf) { + match this.codec.decode(this.read_buf) { Ok(Some(msg)) => { updated = true; this.flags.insert(Flags::STARTED); match msg { Message::Item(mut req) => { - let pl = this.codec.message_type(); req.head_mut().peer_addr = *this.peer_addr; - // merge on_connect_ext data into request extensions - this.on_connect_data.merge_into(&mut req); + req.conn_data = this.conn_data.as_ref().map(Rc::clone); - if pl == MessageType::Stream && this.flow.upgrade.is_some() { - this.messages.push_back(DispatcherMessage::Upgrade(req)); - break; - } - if pl == MessageType::Payload || pl == MessageType::Stream { - let (ps, pl) = Payload::create(false); - let (req1, _) = - req.replace_payload(crate::Payload::H1(pl)); - req = req1; - *this.payload = Some(ps); + match this.codec.message_type() { + // Request is upgradable. add upgrade message and break. + // everything remain in read buffer would be handed to + // upgraded Request. + MessageType::Stream if this.flow.upgrade.is_some() => { + this.messages.push_back(DispatcherMessage::Upgrade(req)); + break; + } + + // Request is not upgradable. + MessageType::Payload | MessageType::Stream => { + /* + PayloadSender and Payload are smart pointers share the + same state. + PayloadSender is attached to dispatcher and used to sink + new chunked request data to state. + Payload is attached to Request and passed to Service::call + where the state can be collected and consumed. + */ + let (sender, payload) = Payload::create(false); + let (req1, _) = + req.replace_payload(crate::Payload::H1 { payload }); + req = req1; + *this.payload = Some(sender); + } + + // Request has no payload. + MessageType::None => {} } - // handle request early + // handle request early when no future in InnerDispatcher state. if this.state.is_empty() { self.as_mut().handle_request(req, cx)?; this = self.as_mut().project(); @@ -550,12 +669,10 @@ where if let Some(ref mut payload) = this.payload { payload.feed_data(chunk); } else { - error!( - "Internal server error: unexpected payload chunk" - ); + error!("Internal server error: unexpected payload chunk"); this.flags.insert(Flags::READ_DISCONNECT); this.messages.push_back(DispatcherMessage::Error( - Response::InternalServerError().finish().drop_body(), + Response::internal_server_error().drop_body(), )); *this.error = Some(DispatchError::InternalError); break; @@ -568,7 +685,7 @@ where error!("Internal server error: unexpected eof"); this.flags.insert(Flags::READ_DISCONNECT); this.messages.push_back(DispatcherMessage::Error( - Response::InternalServerError().finish().drop_body(), + Response::internal_server_error().drop_body(), )); *this.error = Some(DispatchError::InternalError); break; @@ -590,9 +707,11 @@ where payload.set_error(PayloadError::Overflow); } // Requests overflow buffer size should be responded with 431 - this.messages.push_back(DispatcherMessage::Error( - Response::RequestHeaderFieldsTooLarge().finish().drop_body(), - )); + this.messages + .push_back(DispatcherMessage::Error(Response::with_body( + StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + (), + ))); this.flags.insert(Flags::READ_DISCONNECT); *this.error = Some(ParseError::TooLarge.into()); break; @@ -604,7 +723,7 @@ where // Malformed requests should be responded with 400 this.messages.push_back(DispatcherMessage::Error( - Response::BadRequest().finish().drop_body(), + Response::bad_request().drop_body(), )); this.flags.insert(Flags::READ_DISCONNECT); *this.error = Some(err.into()); @@ -634,17 +753,11 @@ where None => { // conditionally go into shutdown timeout if this.flags.contains(Flags::SHUTDOWN) { - if let Some(deadline) = this.codec.config().client_disconnect_timer() - { + if let Some(deadline) = this.codec.config().client_disconnect_timer() { // write client disconnect time out and poll again to // go into Some> branch this.ka_timer.set(Some(sleep_until(deadline))); return self.poll_keepalive(cx); - } else { - this.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = this.payload.take() { - payload.set_error(PayloadError::Incomplete(None)); - } } } } @@ -654,7 +767,7 @@ where // got timeout during shutdown, drop connection if this.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); - // exceed deadline. check for any outstanding tasks + // exceed deadline. check for any outstanding tasks } else if timer.deadline() >= *this.ka_expire { // have no task at hand. if this.state.is_empty() && this.write_buf.is_empty() { @@ -674,28 +787,21 @@ where } } else { // timeout on first request (slow request) return 408 - if !this.flags.contains(Flags::STARTED) { - trace!("Slow request timeout"); - let _ = self.as_mut().send_response( - Response::RequestTimeout().finish().drop_body(), - ResponseBody::Other(Body::Empty), - ); - this = self.project(); - } else { - trace!("Keep-alive connection timeout"); - } + trace!("Slow request timeout"); + let _ = self.as_mut().send_error_response( + Response::with_body(StatusCode::REQUEST_TIMEOUT, ()), + BoxBody::new(()), + ); + this = self.project(); this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); - this.state.set(State::None); } - // still have unfinished task. try to reset and register keep-alive. - } else if let Some(deadline) = - this.codec.config().keep_alive_expire() - { + // still have unfinished task. try to reset and register keep-alive. + } else if let Some(deadline) = this.codec.config().keep_alive_expire() { timer.as_mut().reset(deadline); let _ = timer.poll(cx); } - // timer resolved but still have not met the keep-alive expire deadline. - // reset and register for later wakeup. + // timer resolved but still have not met the keep-alive expire deadline. + // reset and register for later wakeup. } else { timer.as_mut().reset(*this.ka_expire); let _ = timer.poll(cx); @@ -709,7 +815,6 @@ where /// Returns true when io stream can be disconnected after write to it. /// /// It covers these conditions: - /// /// - `std::io::ErrorKind::ConnectionReset` after partial read. /// - all data read done. #[inline(always)] @@ -728,6 +833,46 @@ where let mut read_some = false; loop { + // Return early when read buf exceed decoder's max buffer size. + if this.read_buf.len() >= MAX_BUFFER_SIZE { + // At this point it's not known IO stream is still scheduled to be waked up so + // force wake up dispatcher just in case. + // + // Reason: + // AsyncRead mostly would only have guarantee wake up when the poll_read + // return Poll::Pending. + // + // Case: + // When read_buf is beyond max buffer size the early return could be successfully + // be parsed as a new Request. This case would not generate ParseError::TooLarge and + // at this point IO stream is not fully read to Pending and would result in + // dispatcher stuck until timeout (KA) + // + // Note: + // This is a perf choice to reduce branch on ::decode. + // + // A Request head too large to parse is only checked on + // `httparse::Status::Partial` condition. + + if this.payload.is_none() { + // When dispatcher has a payload the responsibility of wake up it would be shift + // to h1::payload::Payload. + // + // Reason: + // Self wake up when there is payload would waste poll and/or result in + // over read. + // + // Case: + // When payload is (partial) dropped by user there is no need to do + // read anymore. At this case read_buf could always remain beyond + // MAX_BUFFER_SIZE and self wake up would be busy poll dispatcher and + // waste resources. + cx.waker().wake_by_ref(); + } + + return Ok(false); + } + // grow buffer if necessary. let remaining = this.read_buf.capacity() - this.read_buf.len(); if remaining < LW_BUFFER_SIZE { @@ -735,35 +880,18 @@ where } match actix_codec::poll_read_buf(io.as_mut(), cx, this.read_buf) { - Poll::Pending => return Ok(false), Poll::Ready(Ok(n)) => { if n == 0 { return Ok(true); - } else { - // Return early when read buf exceed decoder's max buffer size. - if this.read_buf.len() >= super::decoder::MAX_BUFFER_SIZE { - // at this point it's not known io is still scheduled to - // be waked up. so force wake up dispatcher just in case. - // TODO: figure out the overhead. - if this.payload.is_none() { - // When dispatcher has a payload. The responsibility of - // wake up stream would be shift to PayloadSender. - // Therefore no self wake up is needed. - cx.waker().wake_by_ref(); - } - return Ok(false); - } - - read_some = true; } + read_some = true; } + Poll::Pending => return Ok(false), Poll::Ready(Err(err)) => { - return if err.kind() == io::ErrorKind::WouldBlock { - Ok(false) - } else if err.kind() == io::ErrorKind::ConnectionReset && read_some { - Ok(true) - } else { - Err(DispatchError::Io(err)) + return match err.kind() { + io::ErrorKind::WouldBlock => Ok(false), + io::ErrorKind::ConnectionReset if read_some => Ok(true), + _ => Err(DispatchError::Io(err)), } } } @@ -787,12 +915,16 @@ where impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into, + S::Error: Into>, S::Response: Into>, + B: MessageBody, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -808,21 +940,18 @@ where } match this.inner.project() { - DispatcherStateProj::Normal(mut inner) => { + DispatcherStateProj::Normal { mut inner } => { inner.as_mut().poll_keepalive(cx)?; if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::WRITE_DISCONNECT) { Poll::Ready(Ok(())) } else { - // flush buffer and wait on block. - if inner.as_mut().poll_flush(cx)? { - Poll::Pending - } else { - Pin::new(inner.project().io.as_mut().unwrap()) - .poll_shutdown(cx) - .map_err(DispatchError::from) - } + // flush buffer and wait on blocked. + ready!(inner.as_mut().poll_flush(cx))?; + Pin::new(inner.project().io.as_mut().unwrap()) + .poll_shutdown(cx) + .map_err(DispatchError::from) } } else { // read from io stream and fill read buffer. @@ -851,7 +980,7 @@ where self.as_mut() .project() .inner - .set(DispatcherState::Upgrade(upgrade)); + .set(DispatcherState::Upgrade { fut: upgrade }); return self.poll(cx); } }; @@ -862,7 +991,7 @@ where // // TODO: what? is WouldBlock good or bad? // want to find a reference for this macOS behavior - if inner.as_mut().poll_flush(cx)? || !drain { + if inner.as_mut().poll_flush(cx)?.is_pending() || !drain { break; } } @@ -903,8 +1032,8 @@ where } } } - DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| { - error!("Upgrade handler error: {}", e); + DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| { + error!("Upgrade handler error: {}", err); DispatchError::Upgrade }), } @@ -916,14 +1045,16 @@ mod tests { use std::str; use actix_service::fn_service; - use futures_util::future::{lazy, ready}; + use actix_utils::future::{ready, Ready}; + use bytes::Bytes; + use futures_util::future::lazy; use super::*; - use crate::test::TestBuffer; - use crate::{error::Error, KeepAlive}; use crate::{ + error::Error, h1::{ExpectHandler, UpgradeHandler}, - test::TestSeqBuffer, + test::{TestBuffer, TestSeqBuffer}, + HttpMessage, KeepAlive, Method, }; fn find_slice(haystack: &[u8], needle: &[u8], from: usize) -> Option { @@ -935,25 +1066,29 @@ mod tests { fn stabilize_date_header(payload: &mut [u8]) { let mut from = 0; - while let Some(pos) = find_slice(&payload, b"date", from) { + while let Some(pos) = find_slice(payload, b"date", from) { payload[(from + pos)..(from + pos + 35)] .copy_from_slice(b"date: Thu, 01 Jan 1970 12:34:56 UTC"); from += 35; } } - fn ok_service() -> impl Service { - fn_service(|_req: Request| ready(Ok::<_, Error>(Response::Ok().finish()))) + fn ok_service( + ) -> impl Service, Error = Error> { + fn_service(|_req: Request| ready(Ok::<_, Error>(Response::ok()))) } - fn echo_path_service() -> impl Service { + fn echo_path_service( + ) -> impl Service, Error = Error> { fn_service(|req: Request| { let path = req.path().as_bytes(); - ready(Ok::<_, Error>(Response::Ok().body(Body::from_slice(path)))) + ready(Ok::<_, Error>( + Response::ok().set_body(Bytes::copy_from_slice(path)), + )) }) } - fn echo_payload_service() -> impl Service + fn echo_payload_service() -> impl Service, Error = Error> { fn_service(|mut req: Request| { Box::pin(async move { @@ -965,7 +1100,7 @@ mod tests { body.extend_from_slice(chunk.unwrap().chunk()) } - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body.freeze())) }) }) } @@ -979,20 +1114,20 @@ mod tests { let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, - ServiceConfig::default(), services, - OnConnectData::default(), + ServiceConfig::default(), None, + OnConnectData::default(), ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); match h1.as_mut().poll(cx) { Poll::Pending => panic!(), Poll::Ready(res) => assert!(res.is_err()), } - if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert_eq!( &inner.project().io.take().unwrap().write_buf[..26], @@ -1019,15 +1154,15 @@ mod tests { let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, - cfg, services, - OnConnectData::default(), + cfg, None, + OnConnectData::default(), ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); - assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); match h1.as_mut().poll(cx) { Poll::Pending => panic!("first poll should not be pending"), @@ -1037,7 +1172,7 @@ mod tests { // polls: initial => shutdown assert_eq!(h1.poll_count, 2); - if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() { let res = &mut inner.project().io.take().unwrap().write_buf[..]; stabilize_date_header(res); @@ -1073,15 +1208,15 @@ mod tests { let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, - cfg, services, - OnConnectData::default(), + cfg, None, + OnConnectData::default(), ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); - assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); match h1.as_mut().poll(cx) { Poll::Pending => panic!("first poll should not be pending"), @@ -1091,7 +1226,7 @@ mod tests { // polls: initial => shutdown assert_eq!(h1.poll_count, 1); - if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + if let DispatcherStateProj::Normal { inner } = h1.project().inner.project() { let res = &mut inner.project().io.take().unwrap().write_buf[..]; stabilize_date_header(res); @@ -1123,10 +1258,10 @@ mod tests { let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), - cfg, services, - OnConnectData::default(), + cfg, None, + OnConnectData::default(), ); buf.extend_read_buf( @@ -1138,16 +1273,16 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_pending()); - assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); // polls: manual assert_eq!(h1.poll_count, 1); eprintln!("poll count: {}", h1.poll_count); - if let DispatcherState::Normal(ref inner) = h1.inner { + if let DispatcherState::Normal { ref inner } = h1.inner { let io = inner.io.as_ref().unwrap(); let res = &io.write_buf()[..]; assert_eq!( @@ -1162,7 +1297,7 @@ mod tests { // polls: manual manual shutdown assert_eq!(h1.poll_count, 3); - if let DispatcherState::Normal(ref inner) = h1.inner { + if let DispatcherState::Normal { ref inner } = h1.inner { let io = inner.io.as_ref().unwrap(); let mut res = (&io.write_buf()[..]).to_owned(); stabilize_date_header(&mut res); @@ -1195,10 +1330,10 @@ mod tests { let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), - cfg, services, - OnConnectData::default(), + cfg, None, + OnConnectData::default(), ); buf.extend_read_buf( @@ -1210,15 +1345,15 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_ready()); - assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + assert!(matches!(&h1.inner, DispatcherState::Normal { .. })); // polls: manual shutdown assert_eq!(h1.poll_count, 2); - if let DispatcherState::Normal(ref inner) = h1.inner { + if let DispatcherState::Normal { ref inner } = h1.inner { let io = inner.io.as_ref().unwrap(); let mut res = (&io.write_buf()[..]).to_owned(); stabilize_date_header(&mut res); @@ -1247,19 +1382,35 @@ mod tests { #[actix_rt::test] async fn test_upgrade() { + struct TestUpgrade; + + impl Service<(Request, Framed)> for TestUpgrade { + type Response = (); + type Error = Error; + type Future = Ready>; + + actix_service::always_ready!(); + + fn call(&self, (req, _framed): (Request, Framed)) -> Self::Future { + assert_eq!(req.method(), Method::GET); + assert!(req.upgrade()); + assert_eq!(req.headers().get("upgrade").unwrap(), "websocket"); + ready(Ok(())) + } + } + lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); - let services = - HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler)); + let services = HttpFlow::new(ok_service(), ExpectHandler, Some(TestUpgrade)); - let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, TestUpgrade>::new( buf.clone(), - cfg, services, - OnConnectData::default(), + cfg, None, + OnConnectData::default(), ); buf.extend_read_buf( @@ -1271,10 +1422,10 @@ mod tests { ", ); - futures_util::pin_mut!(h1); + actix_rt::pin!(h1); assert!(h1.as_mut().poll(cx).is_ready()); - assert!(matches!(&h1.inner, DispatcherState::Upgrade(_))); + assert!(matches!(&h1.inner, DispatcherState::Upgrade { .. })); // polls: manual shutdown assert_eq!(h1.poll_count, 2); diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 97916e7db..f2a862278 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -1,24 +1,26 @@ -use std::io::Write; -use std::marker::PhantomData; -use std::ptr::copy_nonoverlapping; -use std::slice::from_raw_parts_mut; -use std::{cmp, io}; +use std::{ + cmp, + io::{self, Write as _}, + marker::PhantomData, + ptr::copy_nonoverlapping, + slice::from_raw_parts_mut, +}; use bytes::{BufMut, BytesMut}; -use crate::body::BodySize; -use crate::config::ServiceConfig; -use crate::header::{map::Value, HeaderName}; -use crate::helpers; -use crate::http::header::{CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use crate::http::{HeaderMap, StatusCode, Version}; -use crate::message::{ConnectionType, RequestHeadType}; -use crate::response::Response; +use crate::{ + body::BodySize, + header::{ + map::Value, HeaderMap, HeaderName, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, + }, + helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version, +}; const AVERAGE_HEADER_SIZE: usize = 30; #[derive(Debug)] pub(crate) struct MessageEncoder { + #[allow(dead_code)] pub length: BodySize, pub te: TransferEncoding, _phantom: PhantomData, @@ -54,7 +56,7 @@ pub(crate) trait MessageType: Sized { dst: &mut BytesMut, version: Version, mut length: BodySize, - ctype: ConnectionType, + conn_type: ConnectionType, config: &ServiceConfig, ) -> io::Result<()> { let chunked = self.chunked(); @@ -69,17 +71,28 @@ pub(crate) trait MessageType: Sized { | StatusCode::PROCESSING | StatusCode::NO_CONTENT => { // skip content-length and transfer-encoding headers - // See https://tools.ietf.org/html/rfc7230#section-3.3.1 - // and https://tools.ietf.org/html/rfc7230#section-3.3.2 + // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1 + // and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2 skip_len = true; length = BodySize::None } + + StatusCode::NOT_MODIFIED => { + // 304 responses should never have a body but should retain a manually set + // content-length header + // see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 + skip_len = false; + length = BodySize::None; + } + _ => {} } } + match length { BodySize::Stream => { if chunked { + skip_len = true; if camel_case { dst.put_slice(b"\r\nTransfer-Encoding: chunked\r\n") } else { @@ -90,19 +103,14 @@ pub(crate) trait MessageType: Sized { dst.put_slice(b"\r\n"); } } - BodySize::Empty => { - if camel_case { - dst.put_slice(b"\r\nContent-Length: 0\r\n"); - } else { - dst.put_slice(b"\r\ncontent-length: 0\r\n"); - } - } + BodySize::Sized(0) if camel_case => dst.put_slice(b"\r\nContent-Length: 0\r\n"), + BodySize::Sized(0) => dst.put_slice(b"\r\ncontent-length: 0\r\n"), BodySize::Sized(len) => helpers::write_content_length(len, dst), BodySize::None => dst.put_slice(b"\r\n"), } // Connection - match ctype { + match conn_type { ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"), ConnectionType::KeepAlive if version < Version::HTTP_11 => { if camel_case { @@ -173,7 +181,7 @@ pub(crate) trait MessageType: Sized { unsafe { if camel_case { // use Camel-Case headers - write_camel_case(k, from_raw_parts_mut(buf, k_len)); + write_camel_case(k, buf, k_len); } else { write_data(k, buf, k_len); } @@ -287,7 +295,7 @@ impl MessageType for RequestHeadType { let head = self.as_ref(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( - helpers::Writer(dst), + helpers::MutWriter(dst), "{} {} {}", head.method, head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), @@ -297,11 +305,7 @@ impl MessageType for RequestHeadType { Version::HTTP_11 => "HTTP/1.1", Version::HTTP_2 => "HTTP/2.0", Version::HTTP_3 => "HTTP/3.0", - _ => - return Err(io::Error::new( - io::ErrorKind::Other, - "unsupported version" - )), + _ => return Err(io::Error::new(io::ErrorKind::Other, "unsupported version")), } ) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) @@ -327,13 +331,13 @@ impl MessageEncoder { stream: bool, version: Version, length: BodySize, - ctype: ConnectionType, + conn_type: ConnectionType, config: &ServiceConfig, ) -> io::Result<()> { // transfer encoding if !head { self.te = match length { - BodySize::Empty => TransferEncoding::empty(), + BodySize::Sized(0) => TransferEncoding::empty(), BodySize::Sized(len) => TransferEncoding::length(len), BodySize::Stream => { if message.chunked() && !stream { @@ -349,7 +353,7 @@ impl MessageEncoder { } message.encode_status(dst)?; - message.encode_headers(dst, version, length, ctype, config) + message.encode_headers(dst, version, length, conn_type, config) } } @@ -363,10 +367,12 @@ pub(crate) struct TransferEncoding { enum TransferEncodingKind { /// An Encoder for when Transfer-Encoding includes `chunked`. Chunked(bool), + /// An Encoder for when Content-Length is set. /// /// Enforces that the body is not longer than the Content-Length header. Length(u64), + /// An Encoder for when Content-Length is not known. /// /// Application decides when to stop writing. @@ -420,7 +426,7 @@ impl TransferEncoding { *eof = true; buf.extend_from_slice(b"0\r\n\r\n"); } else { - writeln!(helpers::Writer(buf), "{:X}\r", msg.len()) + writeln!(helpers::MutWriter(buf), "{:X}\r", msg.len()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; buf.reserve(msg.len() + 2); @@ -471,15 +477,22 @@ impl TransferEncoding { } /// # Safety -/// Callers must ensure that the given length matches given value length. +/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is +/// valid for writes of at least `len` bytes. unsafe fn write_data(value: &[u8], buf: *mut u8, len: usize) { debug_assert_eq!(value.len(), len); copy_nonoverlapping(value.as_ptr(), buf, len); } -fn write_camel_case(value: &[u8], buffer: &mut [u8]) { +/// # Safety +/// Callers must ensure that the given `len` matches the given `value` length and that `buf` is +/// valid for writes of at least `len` bytes. +unsafe fn write_camel_case(value: &[u8], buf: *mut u8, len: usize) { // first copy entire (potentially wrong) slice to output - buffer[..value.len()].copy_from_slice(value); + write_data(value, buf, len); + + // SAFETY: We just initialized the buffer with `value` + let buffer = from_raw_parts_mut(buf, len); let mut iter = value.iter(); @@ -512,8 +525,10 @@ mod tests { use http::header::AUTHORIZATION; use super::*; - use crate::http::header::{HeaderValue, CONTENT_TYPE}; - use crate::RequestHead; + use crate::{ + header::{HeaderValue, CONTENT_TYPE}, + RequestHead, + }; #[test] fn test_chunked_te() { @@ -543,12 +558,11 @@ mod tests { let _ = head.encode_headers( &mut bytes, Version::HTTP_11, - BodySize::Empty, + BodySize::Sized(0), ConnectionType::Close, &ServiceConfig::default(), ); - let data = - String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); assert!(data.contains("Content-Length: 0\r\n")); assert!(data.contains("Connection: close\r\n")); @@ -562,8 +576,7 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - let data = - String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); assert!(data.contains("Transfer-Encoding: chunked\r\n")); assert!(data.contains("Content-Type: plain/text\r\n")); assert!(data.contains("Date: date\r\n")); @@ -584,8 +597,7 @@ mod tests { ConnectionType::KeepAlive, &ServiceConfig::default(), ); - let data = - String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); assert!(data.contains("transfer-encoding: chunked\r\n")); assert!(data.contains("content-type: xml\r\n")); assert!(data.contains("content-type: plain/text\r\n")); @@ -614,12 +626,11 @@ mod tests { let _ = head.encode_headers( &mut bytes, Version::HTTP_11, - BodySize::Empty, + BodySize::Sized(0), ConnectionType::Close, &ServiceConfig::default(), ); - let data = - String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); assert!(data.contains("content-length: 0\r\n")); assert!(data.contains("connection: close\r\n")); assert!(data.contains("authorization: another authorization\r\n")); @@ -630,12 +641,10 @@ mod tests { async fn test_no_content_length() { let mut bytes = BytesMut::with_capacity(2048); - let mut res: Response<()> = - Response::new(StatusCode::SWITCHING_PROTOCOLS).into_body::<()>(); + let mut res = Response::with_body(StatusCode::SWITCHING_PROTOCOLS, ()); + res.headers_mut().insert(DATE, HeaderValue::from_static("")); res.headers_mut() - .insert(DATE, HeaderValue::from_static(&"")); - res.headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static(&"0")); + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")); let _ = res.encode_headers( &mut bytes, @@ -644,8 +653,7 @@ mod tests { ConnectionType::Upgrade, &ServiceConfig::default(), ); - let data = - String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); assert!(!data.contains("content-length: 0\r\n")); assert!(!data.contains("transfer-encoding: chunked\r\n")); } diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index 65856edf6..bef281e59 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,10 +1,7 @@ -use std::task::Poll; - use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, Ready}; +use actix_utils::future::{ready, Ready}; -use crate::error::Error; -use crate::request::Request; +use crate::{Error, Request}; pub struct ExpectHandler; diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs index 7e6df6ceb..64586a2dc 100644 --- a/actix-http/src/h1/mod.rs +++ b/actix-http/src/h1/mod.rs @@ -1,6 +1,8 @@ //! HTTP/1 protocol implementation. + use bytes::{Bytes, BytesMut}; +mod chunked; mod client; mod codec; mod decoder; @@ -57,7 +59,7 @@ pub(crate) fn reserve_readbuf(src: &mut BytesMut) { #[cfg(test)] mod tests { use super::*; - use crate::request::Request; + use crate::Request; impl Message { pub fn message(self) -> Request { diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs index d4cfee146..4d031c15a 100644 --- a/actix-http/src/h1/payload.rs +++ b/actix-http/src/h1/payload.rs @@ -1,11 +1,13 @@ //! Payload stream -use std::cell::RefCell; -use std::collections::VecDeque; -use std::pin::Pin; -use std::rc::{Rc, Weak}; -use std::task::{Context, Poll}; -use actix_utils::task::LocalWaker; +use std::{ + cell::RefCell, + collections::VecDeque, + pin::Pin, + rc::{Rc, Weak}, + task::{Context, Poll, Waker}, +}; + use bytes::Bytes; use futures_core::Stream; @@ -23,39 +25,32 @@ pub enum PayloadStatus { /// Buffered stream of bytes chunks /// -/// Payload stores chunks in a vector. First chunk can be received with -/// `.readany()` method. Payload stream is not thread safe. Payload does not -/// notify current task when new data is available. +/// Payload stores chunks in a vector. First chunk can be received with `poll_next`. Payload does +/// not notify current task when new data is available. /// -/// Payload stream can be used as `Response` body stream. +/// Payload can be used as `Response` body stream. #[derive(Debug)] pub struct Payload { inner: Rc>, } impl Payload { - /// Create payload stream. + /// Creates a payload stream. /// - /// This method construct two objects responsible for bytes stream - /// generation. - /// - /// * `PayloadSender` - *Sender* side of the stream - /// - /// * `Payload` - *Receiver* side of the stream + /// This method construct two objects responsible for bytes stream generation: + /// - `PayloadSender` - *Sender* side of the stream + /// - `Payload` - *Receiver* side of the stream pub fn create(eof: bool) -> (PayloadSender, Payload) { let shared = Rc::new(RefCell::new(Inner::new(eof))); ( - PayloadSender { - inner: Rc::downgrade(&shared), - }, + PayloadSender::new(Rc::downgrade(&shared)), Payload { inner: shared }, ) } - /// Create empty payload - #[doc(hidden)] - pub fn empty() -> Payload { + /// Creates an empty payload. + pub(crate) fn empty() -> Payload { Payload { inner: Rc::new(RefCell::new(Inner::new(true))), } @@ -78,14 +73,6 @@ impl Payload { pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } - - #[inline] - pub fn readany( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - self.inner.borrow_mut().readany(cx) - } } impl Stream for Payload { @@ -95,7 +82,7 @@ impl Stream for Payload { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.inner.borrow_mut().readany(cx) + Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx) } } @@ -105,6 +92,10 @@ pub struct PayloadSender { } impl PayloadSender { + fn new(inner: Weak>) -> Self { + Self { inner } + } + #[inline] pub fn set_error(&mut self, err: PayloadError) { if let Some(shared) = self.inner.upgrade() { @@ -134,7 +125,7 @@ impl PayloadSender { if shared.borrow().need_read { PayloadStatus::Read } else { - shared.borrow_mut().io_task.register(cx.waker()); + shared.borrow_mut().register_io(cx); PayloadStatus::Pause } } else { @@ -150,8 +141,8 @@ struct Inner { err: Option, need_read: bool, items: VecDeque, - task: LocalWaker, - io_task: LocalWaker, + task: Option, + io_task: Option, } impl Inner { @@ -162,8 +153,46 @@ impl Inner { err: None, items: VecDeque::new(), need_read: true, - task: LocalWaker::new(), - io_task: LocalWaker::new(), + task: None, + io_task: None, + } + } + + /// Wake up future waiting for payload data to be available. + fn wake(&mut self) { + if let Some(waker) = self.task.take() { + waker.wake(); + } + } + + /// Wake up future feeding data to Payload. + fn wake_io(&mut self) { + if let Some(waker) = self.io_task.take() { + waker.wake(); + } + } + + /// Register future waiting data from payload. + /// Waker would be used in `Inner::wake` + fn register(&mut self, cx: &mut Context<'_>) { + if self + .task + .as_ref() + .map_or(true, |w| !cx.waker().will_wake(w)) + { + self.task = Some(cx.waker().clone()); + } + } + + // Register future feeding data to payload. + /// Waker would be used in `Inner::wake_io` + fn register_io(&mut self, cx: &mut Context<'_>) { + if self + .io_task + .as_ref() + .map_or(true, |w| !cx.waker().will_wake(w)) + { + self.io_task = Some(cx.waker().clone()); } } @@ -182,7 +211,7 @@ impl Inner { self.len += data.len(); self.items.push_back(data); self.need_read = self.len < MAX_BUFFER_SIZE; - self.task.wake(); + self.wake(); } #[cfg(test)] @@ -190,8 +219,8 @@ impl Inner { self.len } - fn readany( - &mut self, + fn poll_next( + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { if let Some(data) = self.items.pop_front() { @@ -199,9 +228,9 @@ impl Inner { self.need_read = self.len < MAX_BUFFER_SIZE; if self.need_read && !self.eof { - self.task.register(cx.waker()); + self.register(cx); } - self.io_task.wake(); + self.wake_io(); Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { Poll::Ready(Some(Err(err))) @@ -209,8 +238,8 @@ impl Inner { Poll::Ready(None) } else { self.need_read = true; - self.task.register(cx.waker()); - self.io_task.wake(); + self.register(cx); + self.wake_io(); Poll::Pending } } @@ -223,8 +252,18 @@ impl Inner { #[cfg(test)] mod tests { + use std::panic::{RefUnwindSafe, UnwindSafe}; + + use actix_utils::future::poll_fn; + use static_assertions::{assert_impl_all, assert_not_impl_any}; + use super::*; - use futures_util::future::poll_fn; + + assert_impl_all!(Payload: Unpin); + assert_not_impl_any!(Payload: Send, Sync, UnwindSafe, RefUnwindSafe); + + assert_impl_all!(Inner: Unpin, Send, Sync); + assert_not_impl_any!(Inner: UnwindSafe, RefUnwindSafe); #[actix_rt::test] async fn test_unread_data() { @@ -236,7 +275,10 @@ mod tests { assert_eq!( Bytes::from("data"), - poll_fn(|cx| payload.readany(cx)).await.unwrap().unwrap() + poll_fn(|cx| Pin::new(&mut payload).poll_next(cx)) + .await + .unwrap() + .unwrap() ); } } diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index b79453ebd..43b7919a7 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,27 +1,28 @@ -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{fmt, net}; +use std::{ + fmt, + marker::PhantomData, + net, + rc::Rc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; -use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; -use futures_core::ready; -use futures_util::future::ready; +use actix_service::{ + fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, +}; +use actix_utils::future::ready; +use futures_core::future::LocalBoxFuture; -use crate::body::MessageBody; -use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error}; -use crate::request::Request; -use crate::response::Response; -use crate::service::HttpFlow; -use crate::{ConnectCallback, OnConnectData}; +use crate::{ + body::{BoxBody, MessageBody}, + config::ServiceConfig, + error::DispatchError, + service::HttpServiceHandler, + ConnectCallback, OnConnectData, Request, Response, +}; -use super::codec::Codec; -use super::dispatcher::Dispatcher; -use super::{ExpectHandler, UpgradeHandler}; +use super::{codec::Codec, dispatcher::Dispatcher, ExpectHandler, UpgradeHandler}; /// `ServiceFactory` implementation for HTTP1 transport pub struct H1Service { @@ -36,7 +37,7 @@ pub struct H1Service { impl H1Service where S: ServiceFactory, - S::Error: Into, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, @@ -60,15 +61,21 @@ where impl H1Service where S: ServiceFactory, - S::Error: Into, + S::Future: 'static, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { /// Create simple tcp stream service @@ -81,7 +88,7 @@ where Error = DispatchError, InitError = (), > { - pipeline_factory(|io: TcpStream| { + fn_service(|io: TcpStream| { let peer_addr = io.peer_addr().ok(); ready(Ok((io, peer_addr))) }) @@ -93,29 +100,39 @@ where mod openssl { use super::*; - use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; - use actix_tls::accept::TlsError; + use actix_tls::accept::{ + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, + TlsError, + }; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where S: ServiceFactory, - S::Error: Into, + S::Future: 'static, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, + U: ServiceFactory< - (Request, Framed, Codec>), + (Request, Framed, Codec>), Config = (), Response = (), >, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create openssl based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -126,47 +143,58 @@ mod openssl { Error = TlsError, InitError = (), > { - pipeline_factory( - Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(|io: SslStream| { - let peer_addr = io.get_ref().peer_addr().ok(); - ready(Ok((io, peer_addr))) - }) - .and_then(self.map_err(TlsError::Service)) + Acceptor::new(acceptor) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { + let peer_addr = io.get_ref().peer_addr().ok(); + (io, peer_addr) + }) + .and_then(self.map_err(TlsError::Service)) } } } #[cfg(feature = "rustls")] mod rustls { + + use std::io; + + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, + TlsError, + }; + use super::*; - use actix_service::ServiceFactoryExt; - use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream}; - use actix_tls::accept::TlsError; - use std::{fmt, io}; impl H1Service, S, B, X, U> where S: ServiceFactory, - S::Error: Into, + S::Future: 'static, + S::Error: Into>, S::InitError: fmt::Debug, S::Response: Into>, + B: MessageBody, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, + U: ServiceFactory< (Request, Framed, Codec>), Config = (), Response = (), >, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { - /// Create rustls based service + /// Create Rustls based service. pub fn rustls( self, config: ServerConfig, @@ -177,16 +205,16 @@ mod rustls { Error = TlsError, InitError = (), > { - pipeline_factory( - Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(|io: TlsStream| { - let peer_addr = io.get_ref().0.peer_addr().ok(); - ready(Ok((io, peer_addr))) - }) - .and_then(self.map_err(TlsError::Service)) + Acceptor::new(config) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { + let peer_addr = io.get_ref().0.peer_addr().ok(); + (io, peer_addr) + }) + .and_then(self.map_err(TlsError::Service)) } } } @@ -194,7 +222,7 @@ mod rustls { impl H1Service where S: ServiceFactory, - S::Error: Into, + S::Error: Into>, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, @@ -202,7 +230,7 @@ where pub fn expect(self, expect: X1) -> H1Service where X1: ServiceFactory, - X1::Error: Into, + X1::Error: Into>, X1::InitError: fmt::Debug, { H1Service { @@ -238,20 +266,26 @@ where } } -impl ServiceFactory<(T, Option)> - for H1Service +impl ServiceFactory<(T, Option)> for H1Service where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, + S: ServiceFactory, - S::Error: Into, + S::Future: 'static, + S::Error: Into>, S::Response: Into>, S::InitError: fmt::Debug, + B: MessageBody, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, { type Response = (); @@ -259,217 +293,77 @@ where type Config = (); type Service = H1ServiceHandler; type InitError = (); - type Future = H1ServiceResponse; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - H1ServiceResponse { - fut: self.srv.new_service(()), - fut_ex: Some(self.expect.new_service(())), - fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), - expect: None, - upgrade: None, - on_connect_ext: self.on_connect_ext.clone(), - cfg: Some(self.cfg.clone()), - _phantom: PhantomData, - } - } -} + let service = self.srv.new_service(()); + let expect = self.expect.new_service(()); + let upgrade = self.upgrade.as_ref().map(|s| s.new_service(())); + let on_connect_ext = self.on_connect_ext.clone(); + let cfg = self.cfg.clone(); -#[doc(hidden)] -#[pin_project::pin_project] -pub struct H1ServiceResponse -where - S: ServiceFactory, - S::Error: Into, - S::InitError: fmt::Debug, - X: ServiceFactory, - X::Error: Into, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Framed), Response = ()>, - U::Error: fmt::Display, - U::InitError: fmt::Debug, -{ - #[pin] - fut: S::Future, - #[pin] - fut_ex: Option, - #[pin] - fut_upg: Option, - expect: Option, - upgrade: Option, - on_connect_ext: Option>>, - cfg: Option, - _phantom: PhantomData, -} + Box::pin(async move { + let expect = expect + .await + .map_err(|e| log::error!("Init http expect service error: {:?}", e))?; -impl Future for H1ServiceResponse -where - T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, - S::Error: Into, - S::Response: Into>, - S::InitError: fmt::Debug, - B: MessageBody, - X: ServiceFactory, - X::Error: Into, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Framed), Response = ()>, - U::Error: fmt::Display, - U::InitError: fmt::Debug, -{ - type Output = Result, ()>; + let upgrade = match upgrade { + Some(upgrade) => { + let upgrade = upgrade + .await + .map_err(|e| log::error!("Init http upgrade service error: {:?}", e))?; + Some(upgrade) + } + None => None, + }; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); + let service = service + .await + .map_err(|e| log::error!("Init http service error: {:?}", e))?; - if let Some(fut) = this.fut_ex.as_pin_mut() { - let expect = ready!(fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e)))?; - this = self.as_mut().project(); - *this.expect = Some(expect); - this.fut_ex.set(None); - } - - if let Some(fut) = this.fut_upg.as_pin_mut() { - let upgrade = ready!(fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e)))?; - this = self.as_mut().project(); - *this.upgrade = Some(upgrade); - this.fut_upg.set(None); - } - - let result = ready!(this - .fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e))); - - Poll::Ready(result.map(|service| { - let this = self.as_mut().project(); - - H1ServiceHandler::new( - this.cfg.take().unwrap(), + Ok(H1ServiceHandler::new( + cfg, service, - this.expect.take().unwrap(), - this.upgrade.take(), - this.on_connect_ext.clone(), - ) - })) + expect, + upgrade, + on_connect_ext, + )) + }) } } /// `Service` implementation for HTTP/1 transport -pub struct H1ServiceHandler -where - S: Service, - X: Service, - U: Service<(Request, Framed)>, -{ - flow: Rc>, - on_connect_ext: Option>>, - cfg: ServiceConfig, - _phantom: PhantomData, -} +pub type H1ServiceHandler = HttpServiceHandler; -impl H1ServiceHandler -where - S: Service, - S::Error: Into, - S::Response: Into>, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - fn new( - cfg: ServiceConfig, - service: S, - expect: X, - upgrade: Option, - on_connect_ext: Option>>, - ) -> H1ServiceHandler { - H1ServiceHandler { - flow: HttpFlow::new(service, expect, upgrade), - cfg, - on_connect_ext, - _phantom: PhantomData, - } - } -} - -impl Service<(T, Option)> - for H1ServiceHandler +impl Service<(T, Option)> for HttpServiceHandler where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into, + S::Error: Into>, S::Response: Into>, + B: MessageBody, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display + Into, + U::Error: fmt::Display + Into>, { type Response = (); type Error = DispatchError; type Future = Dispatcher; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let ready = self - .flow - .expect - .poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready(); - - let ready = self - .flow - .service - .poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready() - && ready; - - let ready = if let Some(ref upg) = self.flow.upgrade { - upg.poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready() - && ready - } else { - ready - }; - - if ready { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + self._poll_ready(cx).map_err(|err| { + log::error!("HTTP/1 service readiness error: {:?}", err); + DispatchError::Service(err) + }) } fn call(&self, (io, addr): (T, Option)) -> Self::Future { - let on_connect_data = - OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); - - Dispatcher::new( - io, - self.cfg.clone(), - self.flow.clone(), - on_connect_data, - addr, - ) + let conn_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); + Dispatcher::new(io, self.flow.clone(), self.cfg.clone(), addr, conn_data) } } diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 5e24d84e3..f25b0718b 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -1,12 +1,8 @@ -use std::task::Poll; - use actix_codec::Framed; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, Ready}; +use futures_core::future::LocalBoxFuture; -use crate::error::Error; -use crate::h1::Codec; -use crate::request::Request; +use crate::{h1::Codec, Error, Request}; pub struct UpgradeHandler; @@ -16,7 +12,7 @@ impl ServiceFactory<(Request, Framed)> for UpgradeHandler { type Config = (); type Service = UpgradeHandler; type InitError = Error; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { unimplemented!() @@ -26,11 +22,11 @@ impl ServiceFactory<(Request, Framed)> for UpgradeHandler { impl Service<(Request, Framed)> for UpgradeHandler { type Response = (); type Error = Error; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); fn call(&self, _: (Request, Framed)) -> Self::Future { - ready(Ok(())) + unimplemented!() } } diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index 9e9c57137..5c11b1dab 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -1,27 +1,35 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use pin_project_lite::pin_project; -use crate::body::{BodySize, MessageBody, ResponseBody}; -use crate::error::Error; -use crate::h1::{Codec, Message}; -use crate::response::Response; +use crate::{ + body::{BodySize, MessageBody}, + h1::{Codec, Message}, + Error, Response, +}; -/// Send HTTP/1 response -#[pin_project::pin_project] -pub struct SendResponse { - res: Option, BodySize)>>, - #[pin] - body: Option>, - #[pin] - framed: Option>, +pin_project! { + /// Send HTTP/1 response + pub struct SendResponse { + res: Option, BodySize)>>, + + #[pin] + body: Option, + + #[pin] + framed: Option>, + } } impl SendResponse where B: MessageBody, + B::Error: Into, { pub fn new(framed: Framed, response: Response) -> Self { let (res, body) = response.into_parts(); @@ -37,7 +45,8 @@ where impl Future for SendResponse where T: AsyncRead + AsyncWrite + Unpin, - B: MessageBody + Unpin, + B: MessageBody, + B::Error: Into, { type Output = Result, Error>; @@ -60,15 +69,24 @@ where .unwrap() .is_write_buf_full() { - match this.body.as_mut().as_pin_mut().unwrap().poll_next(cx)? { + let next = match this.body.as_mut().as_pin_mut().unwrap().poll_next(cx) { + Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(item)), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }; + + match next { Poll::Ready(item) => { // body is done when item is None body_done = item.is_none(); if body_done { - let _ = this.body.take(); + this.body.set(None); } let framed = this.framed.as_mut().as_pin_mut().unwrap(); - framed.write(Message::Chunk(item))?; + framed + .write(Message::Chunk(item)) + .map_err(|err| Error::new_send_response().with_cause(err))?; } Poll::Pending => body_ready = false, } @@ -79,7 +97,10 @@ where // flush write buffer if !framed.is_write_buf_empty() { - match framed.flush(cx)? { + match framed + .flush(cx) + .map_err(|err| Error::new_send_response().with_cause(err))? + { Poll::Ready(_) => { if body_ready { continue; @@ -93,7 +114,9 @@ where // send response if let Some(res) = this.res.take() { - framed.write(res)?; + framed + .write(res) + .map_err(|err| Error::new_send_response().with_cause(err))?; continue; } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 5ccd2a9d1..a90eb3466 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,370 +1,334 @@ -use std::future::Future; -use std::marker::PhantomData; -use std::net; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{cmp, convert::TryFrom}; +use std::{ + cmp, + error::Error as StdError, + future::Future, + marker::PhantomData, + net, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::time::{Instant, Sleep}; +use actix_rt::time::{sleep, Sleep}; use actix_service::Service; +use actix_utils::future::poll_fn; use bytes::{Bytes, BytesMut}; use futures_core::ready; -use h2::server::{Connection, SendResponse}; -use h2::SendStream; -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use h2::{ + server::{Connection, SendResponse}, + Ping, PingPong, +}; use log::{error, trace}; +use pin_project_lite::pin_project; -use crate::body::{BodySize, MessageBody, ResponseBody}; -use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error}; -use crate::message::ResponseHead; -use crate::payload::Payload; -use crate::request::Request; -use crate::response::Response; -use crate::service::HttpFlow; -use crate::OnConnectData; +use crate::{ + body::{BodySize, BoxBody, MessageBody}, + config::ServiceConfig, + header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}, + service::HttpFlow, + Extensions, OnConnectData, Payload, Request, Response, ResponseHead, +}; const CHUNK_SIZE: usize = 16_384; -/// Dispatcher for HTTP/2 protocol. -#[pin_project::pin_project] -pub struct Dispatcher -where - T: AsyncRead + AsyncWrite + Unpin, - S: Service, - B: MessageBody, -{ - flow: Rc>, - connection: Connection, - on_connect_data: OnConnectData, - config: ServiceConfig, - peer_addr: Option, - ka_expire: Instant, - ka_timer: Option, - _phantom: PhantomData, +pin_project! { + /// Dispatcher for HTTP/2 protocol. + pub struct Dispatcher { + flow: Rc>, + connection: Connection, + conn_data: Option>, + config: ServiceConfig, + peer_addr: Option, + ping_pong: Option, + _phantom: PhantomData + } } impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, - S: Service, - S::Error: Into, - S::Response: Into>, - B: MessageBody, { pub(crate) fn new( + mut conn: Connection, flow: Rc>, - connection: Connection, - on_connect_data: OnConnectData, config: ServiceConfig, - timeout: Option, peer_addr: Option, + conn_data: OnConnectData, + timer: Option>>, ) -> Self { - // let keepalive = config.keep_alive_enabled(); - // let flags = if keepalive { - // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED - // } else { - // Flags::empty() - // }; + let ping_pong = config.keep_alive().map(|dur| H2PingPong { + timer: timer + .map(|mut timer| { + // reset timer if it's received from new function. + timer.as_mut().reset(config.now() + dur); + timer + }) + .unwrap_or_else(|| Box::pin(sleep(dur))), + on_flight: false, + ping_pong: conn.ping_pong().unwrap(), + }); - // keep-alive timer - let (ka_expire, ka_timer) = if let Some(delay) = timeout { - (delay.deadline(), Some(delay)) - } else if let Some(delay) = config.keep_alive_timer() { - (delay.deadline(), Some(delay)) - } else { - (config.now(), None) - }; - - Dispatcher { + Self { flow, config, peer_addr, - connection, - on_connect_data, - ka_expire, - ka_timer, + connection: conn, + conn_data: conn_data.0.map(Rc::new), + ping_pong, _phantom: PhantomData, } } } +struct H2PingPong { + timer: Pin>, + on_flight: bool, + ping_pong: PingPong, +} + impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into + 'static, + S::Error: Into>, S::Future: 'static, - S::Response: Into> + 'static, - B: MessageBody + 'static, + S::Response: Into>, + + B: MessageBody, { - type Output = Result<(), DispatchError>; + type Output = Result<(), crate::error::DispatchError>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); loop { - match ready!(Pin::new(&mut this.connection).poll_accept(cx)) { - None => return Poll::Ready(Ok(())), - - Some(Err(err)) => return Poll::Ready(Err(err.into())), - - Some(Ok((req, res))) => { - // update keep-alive expire - if this.ka_timer.is_some() { - if let Some(expire) = this.config.keep_alive_expire() { - this.ka_expire = expire; - } - } - + match Pin::new(&mut this.connection).poll_accept(cx)? { + Poll::Ready(Some((req, tx))) => { let (parts, body) = req.into_parts(); - let pl = crate::h2::Payload::new(body); - let pl = Payload::::H2(pl); + let payload = crate::h2::Payload::new(body); + let pl = Payload::H2 { payload }; let mut req = Request::with_payload(pl); - let head = &mut req.head_mut(); + let head = req.head_mut(); head.uri = parts.uri; head.method = parts.method; head.version = parts.version; head.headers = parts.headers.into(); head.peer_addr = this.peer_addr; - // merge on_connect_ext data into request extensions - this.on_connect_data.merge_into(&mut req); + req.conn_data = this.conn_data.as_ref().map(Rc::clone); - let svc = ServiceResponse:: { - state: ServiceResponseState::ServiceCall( - this.flow.service.call(req), - Some(res), - ), - config: this.config.clone(), - buffer: None, - _phantom: PhantomData, - }; + let fut = this.flow.service.call(req); + let config = this.config.clone(); - actix_rt::spawn(svc); + // multiplex request handling with spawn task + actix_rt::spawn(async move { + // resolve service call and send response. + let res = match fut.await { + Ok(res) => handle_response(res.into(), tx, config).await, + Err(err) => { + let res: Response = err.into(); + handle_response(res, tx, config).await + } + }; + + // log error. + if let Err(err) = res { + match err { + DispatchError::SendResponse(err) => { + trace!("Error sending HTTP/2 response: {:?}", err) + } + DispatchError::SendData(err) => warn!("{:?}", err), + DispatchError::ResponseBody(err) => { + error!("Response payload stream error: {:?}", err) + } + } + } + }); } + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => match this.ping_pong.as_mut() { + Some(ping_pong) => loop { + if ping_pong.on_flight { + // When have on flight ping pong. poll pong and and keep alive timer. + // on success pong received update keep alive timer to determine the next timing of + // ping pong. + match ping_pong.ping_pong.poll_pong(cx)? { + Poll::Ready(_) => { + ping_pong.on_flight = false; + + let dead_line = this.config.keep_alive_expire().unwrap(); + ping_pong.timer.as_mut().reset(dead_line); + } + Poll::Pending => { + return ping_pong.timer.as_mut().poll(cx).map(|_| Ok(())) + } + } + } else { + // When there is no on flight ping pong. keep alive timer is used to wait for next + // timing of ping pong. Therefore at this point it serves as an interval instead. + ready!(ping_pong.timer.as_mut().poll(cx)); + + ping_pong.ping_pong.send_ping(Ping::opaque())?; + + let dead_line = this.config.keep_alive_expire().unwrap(); + ping_pong.timer.as_mut().reset(dead_line); + + ping_pong.on_flight = true; + } + }, + None => return Poll::Pending, + }, } } } } -#[pin_project::pin_project] -struct ServiceResponse { - #[pin] - state: ServiceResponseState, +enum DispatchError { + SendResponse(h2::Error), + SendData(h2::Error), + ResponseBody(Box), +} + +async fn handle_response( + res: Response, + mut tx: SendResponse, config: ServiceConfig, - buffer: Option, - _phantom: PhantomData<(I, E)>, -} - -#[pin_project::pin_project(project = ServiceResponseStateProj)] -enum ServiceResponseState { - ServiceCall(#[pin] F, Option>), - SendPayload(SendStream, #[pin] ResponseBody), -} - -impl ServiceResponse +) -> Result<(), DispatchError> where - F: Future>, - E: Into, - I: Into>, B: MessageBody, { - fn prepare_response( - &self, - head: &ResponseHead, - size: &mut BodySize, - ) -> http::Response<()> { - let mut has_date = false; - let mut skip_len = size != &BodySize::Stream; + let (res, body) = res.replace_body(()); - let mut res = http::Response::new(()); - *res.status_mut() = head.status; - *res.version_mut() = http::Version::HTTP_2; + // prepare response. + let mut size = body.size(); + let res = prepare_response(config, res.head(), &mut size); + let eof = size.is_eof(); - // Content length - match head.status { - http::StatusCode::NO_CONTENT - | http::StatusCode::CONTINUE - | http::StatusCode::PROCESSING => *size = BodySize::None, - http::StatusCode::SWITCHING_PROTOCOLS => { - skip_len = true; - *size = BodySize::Stream; + // send response head and return on eof. + let mut stream = tx + .send_response(res, eof) + .map_err(DispatchError::SendResponse)?; + + if eof { + return Ok(()); + } + + // poll response body and send chunks to client + actix_rt::pin!(body); + + while let Some(res) = poll_fn(|cx| body.as_mut().poll_next(cx)).await { + let mut chunk = res.map_err(|err| DispatchError::ResponseBody(err.into()))?; + + 'send: loop { + let chunk_size = cmp::min(chunk.len(), CHUNK_SIZE); + + // reserve enough space and wait for stream ready. + stream.reserve_capacity(chunk_size); + + match poll_fn(|cx| stream.poll_capacity(cx)).await { + // No capacity left. drop body and return. + None => return Ok(()), + + Some(Err(err)) => return Err(DispatchError::SendData(err)), + + Some(Ok(cap)) => { + // split chunk to writeable size and send to client + let len = chunk.len(); + let bytes = chunk.split_to(cmp::min(len, cap)); + + stream + .send_data(bytes, false) + .map_err(DispatchError::SendData)?; + + // Current chuck completely sent. break send loop and poll next one. + if chunk.is_empty() { + break 'send; + } + } } + } + } + + // response body streaming finished. send end of stream and return. + stream + .send_data(Bytes::new(), true) + .map_err(DispatchError::SendData)?; + + Ok(()) +} + +fn prepare_response( + config: ServiceConfig, + head: &ResponseHead, + size: &mut BodySize, +) -> http::Response<()> { + let mut has_date = false; + let mut skip_len = size != &BodySize::Stream; + + let mut res = http::Response::new(()); + *res.status_mut() = head.status; + *res.version_mut() = http::Version::HTTP_2; + + // Content length + match head.status { + http::StatusCode::NO_CONTENT + | http::StatusCode::CONTINUE + | http::StatusCode::PROCESSING => *size = BodySize::None, + http::StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + *size = BodySize::Stream; + } + _ => {} + } + + let _ = match size { + BodySize::None | BodySize::Stream => None, + + BodySize::Sized(0) => { + #[allow(clippy::declare_interior_mutable_const)] + const HV_ZERO: HeaderValue = HeaderValue::from_static("0"); + res.headers_mut().insert(CONTENT_LENGTH, HV_ZERO) + } + + BodySize::Sized(len) => { + let mut buf = itoa::Buffer::new(); + + res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::from_str(buf.format(*len)).unwrap(), + ) + } + }; + + // copy headers + for (key, value) in head.headers.iter() { + match *key { + // TODO: consider skipping other headers according to: + // https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2 + // omit HTTP/1.x only headers + CONNECTION | TRANSFER_ENCODING => continue, + CONTENT_LENGTH if skip_len => continue, + DATE => has_date = true, _ => {} } - let _ = match size { - BodySize::None | BodySize::Stream => None, - BodySize::Empty => res - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => res.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - }; - - // copy headers - for (key, value) in head.headers.iter() { - match *key { - // omit HTTP/1 only headers - CONNECTION | TRANSFER_ENCODING => continue, - CONTENT_LENGTH if skip_len => continue, - DATE => has_date = true, - _ => {} - } - - res.headers_mut().append(key, value.clone()); - } - - // set date header - if !has_date { - let mut bytes = BytesMut::with_capacity(29); - self.config.set_date_header(&mut bytes); - res.headers_mut().insert( - DATE, - // SAFETY: serialized date-times are known ASCII strings - unsafe { HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) }, - ); - } - - res + res.headers_mut().append(key, value.clone()); } -} -impl Future for ServiceResponse -where - F: Future>, - E: Into, - I: Into>, - B: MessageBody, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); - - match this.state.project() { - ServiceResponseStateProj::ServiceCall(call, send) => { - match ready!(call.poll(cx)) { - Ok(res) => { - let (res, body) = res.into().replace_body(()); - - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = - self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); - - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending HTTP/2 response: {:?}", e); - return Poll::Ready(()); - } - Ok(stream) => stream, - }; - - if size.is_eof() { - Poll::Ready(()) - } else { - this.state - .set(ServiceResponseState::SendPayload(stream, body)); - self.poll(cx) - } - } - - Err(e) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - - let mut send = send.take().unwrap(); - let mut size = body.size(); - let h2_res = - self.as_mut().prepare_response(res.head(), &mut size); - this = self.as_mut().project(); - - let stream = match send.send_response(h2_res, size.is_eof()) { - Err(e) => { - trace!("Error sending HTTP/2 response: {:?}", e); - return Poll::Ready(()); - } - Ok(stream) => stream, - }; - - if size.is_eof() { - Poll::Ready(()) - } else { - this.state.set(ServiceResponseState::SendPayload( - stream, - body.into_body(), - )); - self.poll(cx) - } - } - } - } - - ServiceResponseStateProj::SendPayload(ref mut stream, ref mut body) => { - loop { - loop { - match this.buffer { - Some(ref mut buffer) => { - match ready!(stream.poll_capacity(cx)) { - None => return Poll::Ready(()), - - Some(Ok(cap)) => { - let len = buffer.len(); - let bytes = buffer.split_to(cmp::min(cap, len)); - - if let Err(e) = stream.send_data(bytes, false) { - warn!("{:?}", e); - return Poll::Ready(()); - } else if !buffer.is_empty() { - let cap = cmp::min(buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - this.buffer.take(); - } - } - - Some(Err(e)) => { - warn!("{:?}", e); - return Poll::Ready(()); - } - } - } - - None => match ready!(body.as_mut().poll_next(cx)) { - None => { - if let Err(e) = stream.send_data(Bytes::new(), true) - { - warn!("{:?}", e); - } - return Poll::Ready(()); - } - - Some(Ok(chunk)) => { - stream.reserve_capacity(cmp::min( - chunk.len(), - CHUNK_SIZE, - )); - *this.buffer = Some(chunk); - } - - Some(Err(e)) => { - error!("Response payload stream error: {:?}", e); - return Poll::Ready(()); - } - }, - } - } - } - } - } + // set date header + if !has_date { + let mut bytes = BytesMut::with_capacity(29); + config.set_date_header(&mut bytes); + res.headers_mut().insert( + DATE, + // SAFETY: serialized date-times are known ASCII strings + unsafe { HeaderValue::from_maybe_shared_unchecked(bytes.freeze()) }, + ); } + + res } diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index 7eff44ac1..47d51b420 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,20 +1,30 @@ //! HTTP/2 protocol. use std::{ + future::Future, pin::Pin, task::{Context, Poll}, }; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::Sleep; use bytes::Bytes; use futures_core::{ready, Stream}; -use h2::RecvStream; +use h2::{ + server::{handshake, Connection, Handshake}, + RecvStream, +}; mod dispatcher; mod service; pub use self::dispatcher::Dispatcher; pub use self::service::H2Service; -use crate::error::PayloadError; + +use crate::{ + config::ServiceConfig, + error::{DispatchError, PayloadError}, +}; /// HTTP/2 peer stream. pub struct Payload { @@ -30,10 +40,7 @@ impl Payload { impl Stream for Payload { type Item = Result; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); match ready!(Pin::new(&mut this.stream).poll_data(cx)) { @@ -50,3 +57,55 @@ impl Stream for Payload { } } } + +pub(crate) fn handshake_with_timeout( + io: T, + config: &ServiceConfig, +) -> HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + HandshakeWithTimeout { + handshake: handshake(io), + timer: config.client_timer().map(Box::pin), + } +} + +pub(crate) struct HandshakeWithTimeout { + handshake: Handshake, + timer: Option>>, +} + +impl Future for HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(Connection, Option>>), DispatchError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match Pin::new(&mut this.handshake).poll(cx)? { + // return the timer on success handshake. It can be re-used for h2 ping-pong. + Poll::Ready(conn) => Poll::Ready(Ok((conn, this.timer.take()))), + Poll::Pending => match this.timer.as_mut() { + Some(timer) => { + ready!(timer.as_mut().poll(cx)); + Poll::Ready(Err(DispatchError::SlowRequestTimeout)) + } + None => Poll::Pending, + }, + } + } +} + +#[cfg(test)] +mod tests { + use std::panic::{RefUnwindSafe, UnwindSafe}; + + use static_assertions::assert_impl_all; + + use super::*; + + assert_impl_all!(Payload: Unpin, Send, Sync, UnwindSafe, RefUnwindSafe); +} diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index e00c8d968..469648054 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,30 +1,30 @@ -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{net, rc::Rc}; +use std::{ + future::Future, + marker::PhantomData, + mem, net, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::net::TcpStream; use actix_service::{ - fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service, - ServiceFactory, + fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, }; -use bytes::Bytes; -use futures_core::ready; -use futures_util::future::ok; -use h2::server::{self, Handshake}; +use actix_utils::future::ready; +use futures_core::{future::LocalBoxFuture, ready}; use log::error; -use crate::body::MessageBody; -use crate::config::ServiceConfig; -use crate::error::{DispatchError, Error}; -use crate::request::Request; -use crate::response::Response; -use crate::service::HttpFlow; -use crate::{ConnectCallback, OnConnectData}; +use crate::{ + body::{BoxBody, MessageBody}, + config::ServiceConfig, + error::DispatchError, + service::HttpFlow, + ConnectCallback, OnConnectData, Request, Response, +}; -use super::dispatcher::Dispatcher; +use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout}; /// `ServiceFactory` implementation for HTTP/2 transport pub struct H2Service { @@ -37,9 +37,10 @@ pub struct H2Service { impl H2Service where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, { /// Create new `H2Service` instance with config. @@ -65,9 +66,11 @@ where impl H2Service where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, { /// Create plain TCP based service @@ -80,33 +83,40 @@ where Error = DispatchError, InitError = S::InitError, > { - pipeline_factory(fn_factory(|| async { - Ok::<_, S::InitError>(fn_service(|io: TcpStream| { + fn_factory(|| { + ready(Ok::<_, S::InitError>(fn_service(|io: TcpStream| { let peer_addr = io.peer_addr().ok(); - ok::<_, DispatchError>((io, peer_addr)) - })) - })) + ready(Ok::<_, DispatchError>((io, peer_addr))) + }))) + }) .and_then(self) } } #[cfg(feature = "openssl")] mod openssl { - use actix_service::{fn_factory, fn_service, ServiceFactoryExt}; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, + TlsError, + }; use super::*; - impl H2Service, S, B> + impl H2Service, S, B> where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, { - /// Create OpenSSL based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -117,39 +127,43 @@ mod openssl { Error = TlsError, InitError = S::InitError, > { - pipeline_factory( - Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(fn_factory(|| { - ok::<_, S::InitError>(fn_service(|io: SslStream| { + Acceptor::new(acceptor) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { let peer_addr = io.get_ref().peer_addr().ok(); - ok((io, peer_addr)) - })) - })) - .and_then(self.map_err(TlsError::Service)) + (io, peer_addr) + }) + .and_then(self.map_err(TlsError::Service)) } } } #[cfg(feature = "rustls")] mod rustls { - use super::*; - use actix_service::ServiceFactoryExt; - use actix_tls::accept::rustls::{Acceptor, ServerConfig, TlsStream}; - use actix_tls::accept::TlsError; use std::io; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, + TlsError, + }; + + use super::*; + impl H2Service, S, B> where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, { - /// Create Rustls based service + /// Create Rustls based service. pub fn rustls( self, mut config: ServerConfig, @@ -160,32 +174,34 @@ mod rustls { Error = TlsError, InitError = S::InitError, > { - let protos = vec!["h2".to_string().into()]; - config.set_protocols(&protos); + let mut protos = vec![b"h2".to_vec()]; + protos.extend_from_slice(&config.alpn_protocols); + config.alpn_protocols = protos; - pipeline_factory( - Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(fn_factory(|| { - ok::<_, S::InitError>(fn_service(|io: TlsStream| { + Acceptor::new(config) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { let peer_addr = io.get_ref().0.peer_addr().ok(); - ok((io, peer_addr)) - })) - })) - .and_then(self.map_err(TlsError::Service)) + (io, peer_addr) + }) + .and_then(self.map_err(TlsError::Service)) } } } impl ServiceFactory<(T, Option)> for H2Service where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, + S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, { type Response = (); @@ -193,52 +209,16 @@ where type Config = (); type Service = H2ServiceHandler; type InitError = S::InitError; - type Future = H2ServiceResponse; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - H2ServiceResponse { - fut: self.srv.new_service(()), - cfg: Some(self.cfg.clone()), - on_connect_ext: self.on_connect_ext.clone(), - _phantom: PhantomData, - } - } -} + let service = self.srv.new_service(()); + let cfg = self.cfg.clone(); + let on_connect_ext = self.on_connect_ext.clone(); -#[doc(hidden)] -#[pin_project::pin_project] -pub struct H2ServiceResponse -where - S: ServiceFactory, -{ - #[pin] - fut: S::Future, - cfg: Option, - on_connect_ext: Option>>, - _phantom: PhantomData, -} - -impl Future for H2ServiceResponse -where - T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, - S::Error: Into + 'static, - S::Response: Into> + 'static, - >::Future: 'static, - B: MessageBody + 'static, -{ - type Output = Result, S::InitError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - this.fut.poll(cx).map_ok(|service| { - let this = self.as_mut().project(); - H2ServiceHandler::new( - this.cfg.take().unwrap(), - this.on_connect_ext.clone(), - service, - ) + Box::pin(async move { + let service = service.await?; + Ok(H2ServiceHandler::new(cfg, on_connect_ext, service)) }) } } @@ -257,7 +237,7 @@ where impl H2ServiceHandler where S: Service, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -280,7 +260,7 @@ impl Service<(T, Option)> for H2ServiceHandler, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -290,16 +270,15 @@ where type Future = H2ServiceHandlerResponse; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.flow.service.poll_ready(cx).map_err(|e| { - let e = e.into(); - error!("Service readiness error: {:?}", e); - DispatchError::Service(e) + self.flow.service.poll_ready(cx).map_err(|err| { + let err = err.into(); + error!("Service readiness error: {:?}", err); + DispatchError::Service(err) }) } fn call(&self, (io, addr): (T, Option)) -> Self::Future { - let on_connect_data = - OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); + let on_connect_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); H2ServiceHandlerResponse { state: State::Handshake( @@ -307,7 +286,7 @@ where Some(self.cfg.clone()), addr, on_connect_data, - server::handshake(io), + handshake_with_timeout(io, &self.cfg), ), } } @@ -318,21 +297,21 @@ where T: AsyncRead + AsyncWrite + Unpin, S::Future: 'static, { - Incoming(Dispatcher), Handshake( Option>>, Option, Option, OnConnectData, - Handshake, + HandshakeWithTimeout, ), + Established(Dispatcher), } pub struct H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -344,7 +323,7 @@ impl Future for H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, @@ -353,31 +332,35 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.state { - State::Incoming(ref mut disp) => Pin::new(disp).poll(cx), State::Handshake( ref mut srv, ref mut config, ref peer_addr, - ref mut on_connect_data, + ref mut conn_data, ref mut handshake, ) => match ready!(Pin::new(handshake).poll(cx)) { - Ok(conn) => { - let on_connect_data = std::mem::take(on_connect_data); - self.state = State::Incoming(Dispatcher::new( - srv.take().unwrap(), + Ok((conn, timer)) => { + let on_connect_data = mem::take(conn_data); + + self.state = State::Established(Dispatcher::new( conn, - on_connect_data, + srv.take().unwrap(), config.take().unwrap(), - None, *peer_addr, + on_connect_data, + timer, )); + self.poll(cx) } + Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } }, + + State::Established(ref mut disp) => Pin::new(disp).poll(cx), } } } diff --git a/actix-http/src/header/as_name.rs b/actix-http/src/header/as_name.rs index af81ff7f2..a895010b1 100644 --- a/actix-http/src/header/as_name.rs +++ b/actix-http/src/header/as_name.rs @@ -1,47 +1,55 @@ -//! Helper trait for types that can be effectively borrowed as a [HeaderValue]. -//! -//! [HeaderValue]: crate::http::HeaderValue +//! Sealed [`AsHeaderName`] trait and implementations. -use std::{borrow::Cow, str::FromStr}; +use std::{borrow::Cow, str::FromStr as _}; use http::header::{HeaderName, InvalidHeaderName}; +/// Sealed trait implemented for types that can be effectively borrowed as a [`HeaderValue`]. +/// +/// [`HeaderValue`]: super::HeaderValue pub trait AsHeaderName: Sealed {} +pub struct Seal; + pub trait Sealed { - fn try_as_name(&self) -> Result, InvalidHeaderName>; + fn try_as_name(&self, seal: Seal) -> Result, InvalidHeaderName>; } impl Sealed for HeaderName { - fn try_as_name(&self) -> Result, InvalidHeaderName> { + #[inline] + fn try_as_name(&self, _: Seal) -> Result, InvalidHeaderName> { Ok(Cow::Borrowed(self)) } } impl AsHeaderName for HeaderName {} impl Sealed for &HeaderName { - fn try_as_name(&self) -> Result, InvalidHeaderName> { + #[inline] + fn try_as_name(&self, _: Seal) -> Result, InvalidHeaderName> { Ok(Cow::Borrowed(*self)) } } impl AsHeaderName for &HeaderName {} impl Sealed for &str { - fn try_as_name(&self) -> Result, InvalidHeaderName> { + #[inline] + fn try_as_name(&self, _: Seal) -> Result, InvalidHeaderName> { HeaderName::from_str(self).map(Cow::Owned) } } impl AsHeaderName for &str {} impl Sealed for String { - fn try_as_name(&self) -> Result, InvalidHeaderName> { + #[inline] + fn try_as_name(&self, _: Seal) -> Result, InvalidHeaderName> { HeaderName::from_str(self).map(Cow::Owned) } } impl AsHeaderName for String {} impl Sealed for &String { - fn try_as_name(&self) -> Result, InvalidHeaderName> { + #[inline] + fn try_as_name(&self, _: Seal) -> Result, InvalidHeaderName> { HeaderName::from_str(self).map(Cow::Owned) } } diff --git a/actix-http/src/header/common/accept_charset.rs b/actix-http/src/header/common/accept_charset.rs deleted file mode 100644 index db530a8bc..000000000 --- a/actix-http/src/header/common/accept_charset.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::header::{Charset, QualityItem, ACCEPT_CHARSET}; - -header! { - /// `Accept-Charset` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.3) - /// - /// The `Accept-Charset` header field can be sent by a user agent to - /// indicate what charsets are acceptable in textual response content. - /// This field allows user agents capable of understanding more - /// comprehensive or special-purpose charsets to signal that capability - /// to an origin server that is capable of representing information in - /// those charsets. - /// - /// # ABNF - /// - /// ```text - /// Accept-Charset = 1#( ( charset / "*" ) [ weight ] ) - /// ``` - /// - /// # Example values - /// * `iso-8859-5, unicode-1-1;q=0.8` - /// - /// # Examples - /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// AcceptCharset(vec![qitem(Charset::Us_Ascii)]) - /// ); - /// ``` - /// - /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{AcceptCharset, Charset, q, QualityItem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// AcceptCharset(vec![ - /// QualityItem::new(Charset::Us_Ascii, q(900)), - /// QualityItem::new(Charset::Iso_8859_10, q(200)), - /// ]) - /// ); - /// ``` - /// - /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// AcceptCharset(vec![qitem(Charset::Ext("utf-8".to_owned()))]) - /// ); - /// ``` - (AcceptCharset, ACCEPT_CHARSET) => (QualityItem)+ - - test_accept_charset { - // Test case from RFC - test_header!(test1, vec![b"iso-8859-5, unicode-1-1;q=0.8"]); - } -} diff --git a/actix-http/src/header/common/accept_encoding.rs b/actix-http/src/header/common/accept_encoding.rs deleted file mode 100644 index c90f529bc..000000000 --- a/actix-http/src/header/common/accept_encoding.rs +++ /dev/null @@ -1,72 +0,0 @@ -use header::{Encoding, QualityItem}; - -header! { - /// `Accept-Encoding` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.4) - /// - /// The `Accept-Encoding` header field can be used by user agents to - /// indicate what response content-codings are - /// acceptable in the response. An `identity` token is used as a synonym - /// for "no encoding" in order to communicate when no encoding is - /// preferred. - /// - /// # ABNF - /// - /// ```text - /// Accept-Encoding = #( codings [ weight ] ) - /// codings = content-coding / "identity" / "*" - /// ``` - /// - /// # Example values - /// * `compress, gzip` - /// * `` - /// * `*` - /// * `compress;q=0.5, gzip;q=1` - /// * `gzip;q=1.0, identity; q=0.5, *;q=0` - /// - /// # Examples - /// ``` - /// use hyper::header::{Headers, AcceptEncoding, Encoding, qitem}; - /// - /// let mut headers = Headers::new(); - /// headers.set( - /// AcceptEncoding(vec![qitem(Encoding::Chunked)]) - /// ); - /// ``` - /// ``` - /// use hyper::header::{Headers, AcceptEncoding, Encoding, qitem}; - /// - /// let mut headers = Headers::new(); - /// headers.set( - /// AcceptEncoding(vec![ - /// qitem(Encoding::Chunked), - /// qitem(Encoding::Gzip), - /// qitem(Encoding::Deflate), - /// ]) - /// ); - /// ``` - /// ``` - /// use hyper::header::{Headers, AcceptEncoding, Encoding, QualityItem, q, qitem}; - /// - /// let mut headers = Headers::new(); - /// headers.set( - /// AcceptEncoding(vec![ - /// qitem(Encoding::Chunked), - /// QualityItem::new(Encoding::Gzip, q(600)), - /// QualityItem::new(Encoding::EncodingExt("*".to_owned()), q(0)), - /// ]) - /// ); - /// ``` - (AcceptEncoding, "Accept-Encoding") => (QualityItem)* - - test_accept_encoding { - // From the RFC - test_header!(test1, vec![b"compress, gzip"]); - test_header!(test2, vec![b""], Some(AcceptEncoding(vec![]))); - test_header!(test3, vec![b"*"]); - // Note: Removed quality 1 from gzip - test_header!(test4, vec![b"compress;q=0.5, gzip"]); - // Note: Removed quality 1 from gzip - test_header!(test5, vec![b"gzip, identity; q=0.5, *;q=0"]); - } -} diff --git a/actix-http/src/header/common/accept_language.rs b/actix-http/src/header/common/accept_language.rs deleted file mode 100644 index a7ad00863..000000000 --- a/actix-http/src/header/common/accept_language.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::header::{QualityItem, ACCEPT_LANGUAGE}; -use language_tags::LanguageTag; - -header! { - /// `Accept-Language` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.5) - /// - /// The `Accept-Language` header field can be used by user agents to - /// indicate the set of natural languages that are preferred in the - /// response. - /// - /// # ABNF - /// - /// ```text - /// Accept-Language = 1#( language-range [ weight ] ) - /// language-range = - /// ``` - /// - /// # Example values - /// * `da, en-gb;q=0.8, en;q=0.7` - /// * `en-us;q=1.0, en;q=0.5, fr` - /// - /// # Examples - /// - /// ``` - /// use language_tags::langtag; - /// use actix_http::Response; - /// use actix_http::http::header::{AcceptLanguage, LanguageTag, qitem}; - /// - /// let mut builder = Response::Ok(); - /// let mut langtag: LanguageTag = Default::default(); - /// langtag.language = Some("en".to_owned()); - /// langtag.region = Some("US".to_owned()); - /// builder.insert_header( - /// AcceptLanguage(vec![ - /// qitem(langtag), - /// ]) - /// ); - /// ``` - /// - /// ``` - /// use language_tags::langtag; - /// use actix_http::Response; - /// use actix_http::http::header::{AcceptLanguage, QualityItem, q, qitem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// AcceptLanguage(vec![ - /// qitem(langtag!(da)), - /// QualityItem::new(langtag!(en;;;GB), q(800)), - /// QualityItem::new(langtag!(en), q(700)), - /// ]) - /// ); - /// ``` - (AcceptLanguage, ACCEPT_LANGUAGE) => (QualityItem)+ - - test_accept_language { - // From the RFC - test_header!(test1, vec![b"da, en-gb;q=0.8, en;q=0.7"]); - // Own test - test_header!( - test2, vec![b"en-US, en; q=0.5, fr"], - Some(AcceptLanguage(vec![ - qitem("en-US".parse().unwrap()), - QualityItem::new("en".parse().unwrap(), q(500)), - qitem("fr".parse().unwrap()), - ]))); - } -} diff --git a/actix-http/src/header/common/cache_control.rs b/actix-http/src/header/common/cache_control.rs deleted file mode 100644 index 94ce9a750..000000000 --- a/actix-http/src/header/common/cache_control.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::fmt::{self, Write}; -use std::str::FromStr; - -use http::header; - -use crate::header::{ - fmt_comma_delimited, from_comma_delimited, Header, IntoHeaderValue, Writer, -}; - -/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) -/// -/// The `Cache-Control` header field is used to specify directives for -/// caches along the request/response chain. Such cache directives are -/// unidirectional in that the presence of a directive in a request does -/// not imply that the same directive is to be given in the response. -/// -/// # ABNF -/// -/// ```text -/// Cache-Control = 1#cache-directive -/// cache-directive = token [ "=" ( token / quoted-string ) ] -/// ``` -/// -/// # Example values -/// -/// * `no-cache` -/// * `private, community="UCI"` -/// * `max-age=30` -/// -/// # Examples -/// ``` -/// use actix_http::Response; -/// use actix_http::http::header::{CacheControl, CacheDirective}; -/// -/// let mut builder = Response::Ok(); -/// builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); -/// ``` -/// -/// ```rust -/// use actix_http::Response; -/// use actix_http::http::header::{CacheControl, CacheDirective}; -/// -/// let mut builder = Response::Ok(); -/// builder.insert_header(CacheControl(vec![ -/// CacheDirective::NoCache, -/// CacheDirective::Private, -/// CacheDirective::MaxAge(360u32), -/// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), -/// ])); -/// ``` -#[derive(PartialEq, Clone, Debug)] -pub struct CacheControl(pub Vec); - -__hyper__deref!(CacheControl => Vec); - -//TODO: this could just be the header! macro -impl Header for CacheControl { - fn name() -> header::HeaderName { - header::CACHE_CONTROL - } - - #[inline] - fn parse(msg: &T) -> Result - where - T: crate::HttpMessage, - { - let directives = from_comma_delimited(msg.headers().get_all(&Self::name()))?; - if !directives.is_empty() { - Ok(CacheControl(directives)) - } else { - Err(crate::error::ParseError::Header) - } - } -} - -impl fmt::Display for CacheControl { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt_comma_delimited(f, &self[..]) - } -} - -impl IntoHeaderValue for CacheControl { - type Error = header::InvalidHeaderValue; - - fn try_into_value(self) -> Result { - let mut writer = Writer::new(); - let _ = write!(&mut writer, "{}", self); - header::HeaderValue::from_maybe_shared(writer.take()) - } -} - -/// `CacheControl` contains a list of these directives. -#[derive(PartialEq, Clone, Debug)] -pub enum CacheDirective { - /// "no-cache" - NoCache, - /// "no-store" - NoStore, - /// "no-transform" - NoTransform, - /// "only-if-cached" - OnlyIfCached, - - // request directives - /// "max-age=delta" - MaxAge(u32), - /// "max-stale=delta" - MaxStale(u32), - /// "min-fresh=delta" - MinFresh(u32), - - // response directives - /// "must-revalidate" - MustRevalidate, - /// "public" - Public, - /// "private" - Private, - /// "proxy-revalidate" - ProxyRevalidate, - /// "s-maxage=delta" - SMaxAge(u32), - - /// Extension directives. Optionally include an argument. - Extension(String, Option), -} - -impl fmt::Display for CacheDirective { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use self::CacheDirective::*; - fmt::Display::fmt( - match *self { - NoCache => "no-cache", - NoStore => "no-store", - NoTransform => "no-transform", - OnlyIfCached => "only-if-cached", - - MaxAge(secs) => return write!(f, "max-age={}", secs), - MaxStale(secs) => return write!(f, "max-stale={}", secs), - MinFresh(secs) => return write!(f, "min-fresh={}", secs), - - MustRevalidate => "must-revalidate", - Public => "public", - Private => "private", - ProxyRevalidate => "proxy-revalidate", - SMaxAge(secs) => return write!(f, "s-maxage={}", secs), - - Extension(ref name, None) => &name[..], - Extension(ref name, Some(ref arg)) => { - return write!(f, "{}={}", name, arg); - } - }, - f, - ) - } -} - -impl FromStr for CacheDirective { - type Err = Option<::Err>; - fn from_str(s: &str) -> Result::Err>> { - use self::CacheDirective::*; - match s { - "no-cache" => Ok(NoCache), - "no-store" => Ok(NoStore), - "no-transform" => Ok(NoTransform), - "only-if-cached" => Ok(OnlyIfCached), - "must-revalidate" => Ok(MustRevalidate), - "public" => Ok(Public), - "private" => Ok(Private), - "proxy-revalidate" => Ok(ProxyRevalidate), - "" => Err(None), - _ => match s.find('=') { - Some(idx) if idx + 1 < s.len() => { - match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) { - ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some), - ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), - ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), - ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), - (left, right) => { - Ok(Extension(left.to_owned(), Some(right.to_owned()))) - } - } - } - Some(_) => Err(None), - None => Ok(Extension(s.to_owned(), None)), - }, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::header::Header; - use crate::test::TestRequest; - - #[test] - fn test_parse_multiple_headers() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "no-cache, private")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::NoCache, - CacheDirective::Private, - ])) - ) - } - - #[test] - fn test_parse_argument() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "max-age=100, private")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::MaxAge(100), - CacheDirective::Private, - ])) - ) - } - - #[test] - fn test_parse_quote_form() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "max-age=\"200\"")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![CacheDirective::MaxAge(200)])) - ) - } - - #[test] - fn test_parse_extension() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "foo, bar=baz")) - .finish(); - let cache = Header::parse(&req); - assert_eq!( - cache.ok(), - Some(CacheControl(vec![ - CacheDirective::Extension("foo".to_owned(), None), - CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), - ])) - ) - } - - #[test] - fn test_parse_bad_syntax() { - let req = TestRequest::default() - .insert_header((header::CACHE_CONTROL, "foo=")) - .finish(); - let cache: Result = Header::parse(&req); - assert_eq!(cache.ok(), None) - } -} diff --git a/actix-http/src/header/common/content_encoding.rs b/actix-http/src/header/common/content_encoding.rs deleted file mode 100644 index b93d66101..000000000 --- a/actix-http/src/header/common/content_encoding.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::{convert::Infallible, str::FromStr}; - -use http::header::InvalidHeaderValue; - -use crate::{ - error::ParseError, - header::{self, from_one_raw_str, Header, HeaderName, HeaderValue, IntoHeaderValue}, - HttpMessage, -}; - -/// Represents a supported content encoding. -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum ContentEncoding { - /// Automatically select encoding based on encoding negotiation. - Auto, - - /// A format using the Brotli algorithm. - Br, - - /// A format using the zlib structure with deflate algorithm. - Deflate, - - /// Gzip algorithm. - Gzip, - - /// Indicates the identity function (i.e. no compression, nor modification). - Identity, -} - -impl ContentEncoding { - /// Is the content compressed? - #[inline] - pub fn is_compression(self) -> bool { - matches!(self, ContentEncoding::Identity | ContentEncoding::Auto) - } - - /// Convert content encoding to string - #[inline] - pub fn as_str(self) -> &'static str { - match self { - ContentEncoding::Br => "br", - ContentEncoding::Gzip => "gzip", - ContentEncoding::Deflate => "deflate", - ContentEncoding::Identity | ContentEncoding::Auto => "identity", - } - } - - /// Default Q-factor (quality) value. - #[inline] - pub fn quality(self) -> f64 { - match self { - ContentEncoding::Br => 1.1, - ContentEncoding::Gzip => 1.0, - ContentEncoding::Deflate => 0.9, - ContentEncoding::Identity | ContentEncoding::Auto => 0.1, - } - } -} - -impl Default for ContentEncoding { - fn default() -> Self { - Self::Identity - } -} - -impl FromStr for ContentEncoding { - type Err = Infallible; - - fn from_str(val: &str) -> Result { - Ok(Self::from(val)) - } -} - -impl From<&str> for ContentEncoding { - fn from(val: &str) -> ContentEncoding { - let val = val.trim(); - - if val.eq_ignore_ascii_case("br") { - ContentEncoding::Br - } else if val.eq_ignore_ascii_case("gzip") { - ContentEncoding::Gzip - } else if val.eq_ignore_ascii_case("deflate") { - ContentEncoding::Deflate - } else { - ContentEncoding::default() - } - } -} - -impl IntoHeaderValue for ContentEncoding { - type Error = InvalidHeaderValue; - - fn try_into_value(self) -> Result { - Ok(HeaderValue::from_static(self.as_str())) - } -} - -impl Header for ContentEncoding { - fn name() -> HeaderName { - header::CONTENT_ENCODING - } - - fn parse(msg: &T) -> Result { - from_one_raw_str(msg.headers().get(Self::name())) - } -} diff --git a/actix-http/src/header/common/content_language.rs b/actix-http/src/header/common/content_language.rs deleted file mode 100644 index e9be67a1b..000000000 --- a/actix-http/src/header/common/content_language.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::header::{QualityItem, CONTENT_LANGUAGE}; -use language_tags::LanguageTag; - -header! { - /// `Content-Language` header, defined in - /// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2) - /// - /// The `Content-Language` header field describes the natural language(s) - /// of the intended audience for the representation. Note that this - /// might not be equivalent to all the languages used within the - /// representation. - /// - /// # ABNF - /// - /// ```text - /// Content-Language = 1#language-tag - /// ``` - /// - /// # Example values - /// - /// * `da` - /// * `mi, en` - /// - /// # Examples - /// - /// ``` - /// use language_tags::langtag; - /// use actix_http::Response; - /// use actix_http::http::header::{ContentLanguage, qitem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// ContentLanguage(vec![ - /// qitem(langtag!(en)), - /// ]) - /// ); - /// ``` - /// - /// ``` - /// use language_tags::langtag; - /// use actix_http::Response; - /// use actix_http::http::header::{ContentLanguage, qitem}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// ContentLanguage(vec![ - /// qitem(langtag!(da)), - /// qitem(langtag!(en;;;GB)), - /// ]) - /// ); - /// ``` - (ContentLanguage, CONTENT_LANGUAGE) => (QualityItem)+ - - test_content_language { - test_header!(test1, vec![b"da"]); - test_header!(test2, vec![b"mi, en"]); - } -} diff --git a/actix-http/src/header/common/etag.rs b/actix-http/src/header/common/etag.rs deleted file mode 100644 index 4c1e8d262..000000000 --- a/actix-http/src/header/common/etag.rs +++ /dev/null @@ -1,100 +0,0 @@ -use crate::header::{EntityTag, ETAG}; - -header! { - /// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3) - /// - /// The `ETag` header field in a response provides the current entity-tag - /// for the selected representation, as determined at the conclusion of - /// handling the request. An entity-tag is an opaque validator for - /// differentiating between multiple representations of the same - /// resource, regardless of whether those multiple representations are - /// due to resource state changes over time, content negotiation - /// resulting in multiple representations being valid at the same time, - /// or both. An entity-tag consists of an opaque quoted string, possibly - /// prefixed by a weakness indicator. - /// - /// # ABNF - /// - /// ```text - /// ETag = entity-tag - /// ``` - /// - /// # Example values - /// - /// * `"xyzzy"` - /// * `W/"xyzzy"` - /// * `""` - /// - /// # Examples - /// - /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{ETag, EntityTag}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// ETag(EntityTag::new(false, "xyzzy".to_owned())) - /// ); - /// ``` - /// - /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{ETag, EntityTag}; - /// - /// let mut builder = Response::Ok(); - /// builder.insert_header( - /// ETag(EntityTag::new(true, "xyzzy".to_owned())) - /// ); - /// ``` - (ETag, ETAG) => [EntityTag] - - test_etag { - // From the RFC - test_header!(test1, - vec![b"\"xyzzy\""], - Some(ETag(EntityTag::new(false, "xyzzy".to_owned())))); - test_header!(test2, - vec![b"W/\"xyzzy\""], - Some(ETag(EntityTag::new(true, "xyzzy".to_owned())))); - test_header!(test3, - vec![b"\"\""], - Some(ETag(EntityTag::new(false, "".to_owned())))); - // Own tests - test_header!(test4, - vec![b"\"foobar\""], - Some(ETag(EntityTag::new(false, "foobar".to_owned())))); - test_header!(test5, - vec![b"\"\""], - Some(ETag(EntityTag::new(false, "".to_owned())))); - test_header!(test6, - vec![b"W/\"weak-etag\""], - Some(ETag(EntityTag::new(true, "weak-etag".to_owned())))); - test_header!(test7, - vec![b"W/\"\x65\x62\""], - Some(ETag(EntityTag::new(true, "\u{0065}\u{0062}".to_owned())))); - test_header!(test8, - vec![b"W/\"\""], - Some(ETag(EntityTag::new(true, "".to_owned())))); - test_header!(test9, - vec![b"no-dquotes"], - None::); - test_header!(test10, - vec![b"w/\"the-first-w-is-case-sensitive\""], - None::); - test_header!(test11, - vec![b""], - None::); - test_header!(test12, - vec![b"\"unmatched-dquotes1"], - None::); - test_header!(test13, - vec![b"unmatched-dquotes2\""], - None::); - test_header!(test14, - vec![b"matched-\"dquotes\""], - None::); - test_header!(test15, - vec![b"\""], - None::); - } -} diff --git a/actix-http/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs deleted file mode 100644 index 90e0a855e..000000000 --- a/actix-http/src/header/common/mod.rs +++ /dev/null @@ -1,355 +0,0 @@ -//! A Collection of Header implementations for common HTTP Headers. -//! -//! ## Mime -//! -//! Several header fields use MIME values for their contents. Keeping with the -//! strongly-typed theme, the [mime] crate -//! is used, such as `ContentType(pub Mime)`. -#![cfg_attr(rustfmt, rustfmt_skip)] - -pub use self::accept_charset::AcceptCharset; -//pub use self::accept_encoding::AcceptEncoding; -pub use self::accept::Accept; -pub use self::accept_language::AcceptLanguage; -pub use self::allow::Allow; -pub use self::cache_control::{CacheControl, CacheDirective}; -pub use self::content_disposition::{ - ContentDisposition, DispositionParam, DispositionType, -}; -pub use self::content_language::ContentLanguage; -pub use self::content_range::{ContentRange, ContentRangeSpec}; -pub use self::content_encoding::{ContentEncoding}; -pub use self::content_type::ContentType; -pub use self::date::Date; -pub use self::etag::ETag; -pub use self::expires::Expires; -pub use self::if_match::IfMatch; -pub use self::if_modified_since::IfModifiedSince; -pub use self::if_none_match::IfNoneMatch; -pub use self::if_range::IfRange; -pub use self::if_unmodified_since::IfUnmodifiedSince; -pub use self::last_modified::LastModified; -//pub use self::range::{Range, ByteRangeSpec}; - -#[doc(hidden)] -#[macro_export] -macro_rules! __hyper__deref { - ($from:ty => $to:ty) => { - impl ::std::ops::Deref for $from { - type Target = $to; - - #[inline] - fn deref(&self) -> &$to { - &self.0 - } - } - - impl ::std::ops::DerefMut for $from { - #[inline] - fn deref_mut(&mut self) -> &mut $to { - &mut self.0 - } - } - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __hyper__tm { - ($id:ident, $tm:ident{$($tf:item)*}) => { - #[allow(unused_imports)] - #[cfg(test)] - mod $tm{ - use std::str; - use http::Method; - use mime::*; - use $crate::header::*; - use super::$id as HeaderField; - $($tf)* - } - - } -} - -#[doc(hidden)] -#[macro_export] -macro_rules! test_header { - ($id:ident, $raw:expr) => { - #[test] - fn $id() { - use super::*; - use $crate::test; - - let raw = $raw; - let a: Vec> = raw.iter().map(|x| x.to_vec()).collect(); - let mut req = test::TestRequest::default(); - for item in a { - req = req.insert_header((HeaderField::name(), item)).take(); - } - let req = req.finish(); - let value = HeaderField::parse(&req); - let result = format!("{}", value.unwrap()); - let expected = String::from_utf8(raw[0].to_vec()).unwrap(); - let result_cmp: Vec = result - .to_ascii_lowercase() - .split(' ') - .map(|x| x.to_owned()) - .collect(); - let expected_cmp: Vec = expected - .to_ascii_lowercase() - .split(' ') - .map(|x| x.to_owned()) - .collect(); - assert_eq!(result_cmp.concat(), expected_cmp.concat()); - } - }; - ($id:ident, $raw:expr, $typed:expr) => { - #[test] - fn $id() { - use $crate::test; - - let a: Vec> = $raw.iter().map(|x| x.to_vec()).collect(); - let mut req = test::TestRequest::default(); - for item in a { - req.insert_header((HeaderField::name(), item)); - } - let req = req.finish(); - let val = HeaderField::parse(&req); - let typed: Option = $typed; - // Test parsing - assert_eq!(val.ok(), typed); - // Test formatting - if typed.is_some() { - let raw = &($raw)[..]; - let mut iter = raw.iter().map(|b| str::from_utf8(&b[..]).unwrap()); - let mut joined = String::new(); - joined.push_str(iter.next().unwrap()); - for s in iter { - joined.push_str(", "); - joined.push_str(s); - } - assert_eq!(format!("{}", typed.unwrap()), joined); - } - } - }; -} - -#[macro_export] -macro_rules! header { - // $a:meta: Attributes associated with the header item (usually docs) - // $id:ident: Identifier of the header - // $n:expr: Lowercase name of the header - // $nn:expr: Nice name of the header - - // List header, zero or more items - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)*) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] - pub struct $id(pub Vec<$item>); - __hyper__deref!($id => Vec<$item>); - impl $crate::http::header::Header for $id { - #[inline] - fn name() -> $crate::http::header::HeaderName { - $name - } - #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name())).map($id) - } - } - impl std::fmt::Display for $id { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result { - $crate::http::header::fmt_comma_delimited(f, &self.0[..]) - } - } - impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValue; - - fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; - let mut writer = $crate::http::header::Writer::new(); - let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) - } - } - }; - // List header, one or more items - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)+) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] - pub struct $id(pub Vec<$item>); - __hyper__deref!($id => Vec<$item>); - impl $crate::http::header::Header for $id { - #[inline] - fn name() -> $crate::http::header::HeaderName { - $name - } - #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name())).map($id) - } - } - impl std::fmt::Display for $id { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - $crate::http::header::fmt_comma_delimited(f, &self.0[..]) - } - } - impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValue; - - fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; - let mut writer = $crate::http::header::Writer::new(); - let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) - } - } - }; - // Single value header - ($(#[$a:meta])*($id:ident, $name:expr) => [$value:ty]) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] - pub struct $id(pub $value); - __hyper__deref!($id => $value); - impl $crate::http::header::Header for $id { - #[inline] - fn name() -> $crate::http::header::HeaderName { - $name - } - #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - $crate::http::header::from_one_raw_str( - msg.headers().get(Self::name())).map($id) - } - } - impl std::fmt::Display for $id { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self.0, f) - } - } - impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValue; - - fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - self.0.try_into_value() - } - } - }; - // List header, one or more items with "*" option - ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+}) => { - $(#[$a])* - #[derive(Clone, Debug, PartialEq)] - pub enum $id { - /// Any value is a match - Any, - /// Only the listed items are a match - Items(Vec<$item>), - } - impl $crate::http::header::Header for $id { - #[inline] - fn name() -> $crate::http::header::HeaderName { - $name - } - #[inline] - fn parse(msg: &T) -> Result - where T: $crate::HttpMessage - { - let any = msg.headers().get(Self::name()).and_then(|hdr| { - hdr.to_str().ok().and_then(|hdr| Some(hdr.trim() == "*"))}); - - if let Some(true) = any { - Ok($id::Any) - } else { - Ok($id::Items( - $crate::http::header::from_comma_delimited( - msg.headers().get_all(Self::name()))?)) - } - } - } - impl std::fmt::Display for $id { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self { - $id::Any => f.write_str("*"), - $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( - f, &fields[..]) - } - } - } - impl $crate::http::header::IntoHeaderValue for $id { - type Error = $crate::http::header::InvalidHeaderValue; - - fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { - use std::fmt::Write; - let mut writer = $crate::http::header::Writer::new(); - let _ = write!(&mut writer, "{}", self); - $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) - } - } - }; - - // optional test module - ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)* $tm:ident{$($tf:item)*}) => { - header! { - $(#[$a])* - ($id, $name) => ($item)* - } - - __hyper__tm! { $id, $tm { $($tf)* }} - }; - ($(#[$a:meta])*($id:ident, $n:expr) => ($item:ty)+ $tm:ident{$($tf:item)*}) => { - header! { - $(#[$a])* - ($id, $n) => ($item)+ - } - - __hyper__tm! { $id, $tm { $($tf)* }} - }; - ($(#[$a:meta])*($id:ident, $name:expr) => [$item:ty] $tm:ident{$($tf:item)*}) => { - header! { - $(#[$a])* ($id, $name) => [$item] - } - - __hyper__tm! { $id, $tm { $($tf)* }} - }; - ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+} $tm:ident{$($tf:item)*}) => { - header! { - $(#[$a])* - ($id, $name) => {Any / ($item)+} - } - - __hyper__tm! { $id, $tm { $($tf)* }} - }; -} - -mod accept_charset; -// mod accept_encoding; -mod accept; -mod accept_language; -mod allow; -mod cache_control; -mod content_disposition; -mod content_language; -mod content_encoding; -mod content_range; -mod content_type; -mod date; -mod etag; -mod expires; -mod if_match; -mod if_modified_since; -mod if_none_match; -mod if_range; -mod if_unmodified_since; -mod last_modified; diff --git a/actix-http/src/header/common/range.rs b/actix-http/src/header/common/range.rs deleted file mode 100644 index f9e203bb2..000000000 --- a/actix-http/src/header/common/range.rs +++ /dev/null @@ -1,410 +0,0 @@ -use std::fmt::{self, Display}; -use std::str::FromStr; - -use header::parsing::from_one_raw_str; -use header::{Header, Raw}; - -/// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) -/// -/// The "Range" header field on a GET request modifies the method -/// semantics to request transfer of only one or more sub-ranges of the -/// selected representation data, rather than the entire selected -/// representation data. -/// -/// # ABNF -/// -/// ```text -/// Range = byte-ranges-specifier / other-ranges-specifier -/// other-ranges-specifier = other-range-unit "=" other-range-set -/// other-range-set = 1*VCHAR -/// -/// bytes-unit = "bytes" -/// -/// byte-ranges-specifier = bytes-unit "=" byte-range-set -/// byte-range-set = 1#(byte-range-spec / suffix-byte-range-spec) -/// byte-range-spec = first-byte-pos "-" [last-byte-pos] -/// first-byte-pos = 1*DIGIT -/// last-byte-pos = 1*DIGIT -/// ``` -/// -/// # Example values -/// -/// * `bytes=1000-` -/// * `bytes=-2000` -/// * `bytes=0-1,30-40` -/// * `bytes=0-10,20-90,-100` -/// * `custom_unit=0-123` -/// * `custom_unit=xxx-yyy` -/// -/// # Examples -/// -/// ``` -/// use hyper::header::{Headers, Range, ByteRangeSpec}; -/// -/// let mut headers = Headers::new(); -/// headers.set(Range::Bytes( -/// vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)] -/// )); -/// -/// headers.clear(); -/// headers.set(Range::Unregistered("letters".to_owned(), "a-f".to_owned())); -/// ``` -/// -/// ``` -/// use hyper::header::{Headers, Range}; -/// -/// let mut headers = Headers::new(); -/// headers.set(Range::bytes(1, 100)); -/// -/// headers.clear(); -/// headers.set(Range::bytes_multi(vec![(1, 100), (200, 300)])); -/// ``` -#[derive(PartialEq, Clone, Debug)] -pub enum Range { - /// Byte range - Bytes(Vec), - /// Custom range, with unit not registered at IANA - /// (`other-range-unit`: String , `other-range-set`: String) - Unregistered(String, String), -} - -/// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`. -/// Each `ByteRangeSpec` defines a range of bytes to fetch -#[derive(PartialEq, Clone, Debug)] -pub enum ByteRangeSpec { - /// Get all bytes between x and y ("x-y") - FromTo(u64, u64), - /// Get all bytes starting from x ("x-") - AllFrom(u64), - /// Get last x bytes ("-x") - Last(u64), -} - -impl ByteRangeSpec { - /// Given the full length of the entity, attempt to normalize the byte range - /// into an satisfiable end-inclusive (from, to) range. - /// - /// The resulting range is guaranteed to be a satisfiable range within the - /// bounds of `0 <= from <= to < full_length`. - /// - /// If the byte range is deemed unsatisfiable, `None` is returned. - /// An unsatisfiable range is generally cause for a server to either reject - /// the client request with a `416 Range Not Satisfiable` status code, or to - /// simply ignore the range header and serve the full entity using a `200 - /// OK` status code. - /// - /// This function closely follows [RFC 7233][1] section 2.1. - /// As such, it considers ranges to be satisfiable if they meet the - /// following conditions: - /// - /// > If a valid byte-range-set includes at least one byte-range-spec with - /// a first-byte-pos that is less than the current length of the - /// representation, or at least one suffix-byte-range-spec with a - /// non-zero suffix-length, then the byte-range-set is satisfiable. - /// Otherwise, the byte-range-set is unsatisfiable. - /// - /// The function also computes remainder ranges based on the RFC: - /// - /// > If the last-byte-pos value is - /// absent, or if the value is greater than or equal to the current - /// length of the representation data, the byte range is interpreted as - /// the remainder of the representation (i.e., the server replaces the - /// value of last-byte-pos with a value that is one less than the current - /// length of the selected representation). - /// - /// [1]: https://tools.ietf.org/html/rfc7233 - pub fn to_satisfiable_range(&self, full_length: u64) -> Option<(u64, u64)> { - // If the full length is zero, there is no satisfiable end-inclusive range. - if full_length == 0 { - return None; - } - match self { - &ByteRangeSpec::FromTo(from, to) => { - if from < full_length && from <= to { - Some((from, ::std::cmp::min(to, full_length - 1))) - } else { - None - } - } - &ByteRangeSpec::AllFrom(from) => { - if from < full_length { - Some((from, full_length - 1)) - } else { - None - } - } - &ByteRangeSpec::Last(last) => { - if last > 0 { - // From the RFC: If the selected representation is shorter - // than the specified suffix-length, - // the entire representation is used. - if last > full_length { - Some((0, full_length - 1)) - } else { - Some((full_length - last, full_length - 1)) - } - } else { - None - } - } - } - } -} - -impl Range { - /// Get the most common byte range header ("bytes=from-to") - pub fn bytes(from: u64, to: u64) -> Range { - Range::Bytes(vec![ByteRangeSpec::FromTo(from, to)]) - } - - /// Get byte range header with multiple subranges - /// ("bytes=from1-to1,from2-to2,fromX-toX") - pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { - Range::Bytes( - ranges - .iter() - .map(|r| ByteRangeSpec::FromTo(r.0, r.1)) - .collect(), - ) - } -} - -impl fmt::Display for ByteRangeSpec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - ByteRangeSpec::FromTo(from, to) => write!(f, "{}-{}", from, to), - ByteRangeSpec::Last(pos) => write!(f, "-{}", pos), - ByteRangeSpec::AllFrom(pos) => write!(f, "{}-", pos), - } - } -} - -impl fmt::Display for Range { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Range::Bytes(ref ranges) => { - write!(f, "bytes=")?; - - for (i, range) in ranges.iter().enumerate() { - if i != 0 { - f.write_str(",")?; - } - Display::fmt(range, f)?; - } - Ok(()) - } - Range::Unregistered(ref unit, ref range_str) => { - write!(f, "{}={}", unit, range_str) - } - } - } -} - -impl FromStr for Range { - type Err = ::Error; - - fn from_str(s: &str) -> ::Result { - let mut iter = s.splitn(2, '='); - - match (iter.next(), iter.next()) { - (Some("bytes"), Some(ranges)) => { - let ranges = from_comma_delimited(ranges); - if ranges.is_empty() { - return Err(::Error::Header); - } - Ok(Range::Bytes(ranges)) - } - (Some(unit), Some(range_str)) if unit != "" && range_str != "" => { - Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())) - } - _ => Err(::Error::Header), - } - } -} - -impl FromStr for ByteRangeSpec { - type Err = ::Error; - - fn from_str(s: &str) -> ::Result { - let mut parts = s.splitn(2, '-'); - - match (parts.next(), parts.next()) { - (Some(""), Some(end)) => end - .parse() - .or(Err(::Error::Header)) - .map(ByteRangeSpec::Last), - (Some(start), Some("")) => start - .parse() - .or(Err(::Error::Header)) - .map(ByteRangeSpec::AllFrom), - (Some(start), Some(end)) => match (start.parse(), end.parse()) { - (Ok(start), Ok(end)) if start <= end => { - Ok(ByteRangeSpec::FromTo(start, end)) - } - _ => Err(::Error::Header), - }, - _ => Err(::Error::Header), - } - } -} - -fn from_comma_delimited(s: &str) -> Vec { - s.split(',') - .filter_map(|x| match x.trim() { - "" => None, - y => Some(y), - }) - .filter_map(|x| x.parse().ok()) - .collect() -} - -impl Header for Range { - fn header_name() -> &'static str { - static NAME: &'static str = "Range"; - NAME - } - - fn parse_header(raw: &Raw) -> ::Result { - from_one_raw_str(raw) - } - - fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { - f.fmt_line(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_bytes_range_valid() { - let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap(); - let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap(); - let r3 = Range::bytes(1, 100); - assert_eq!(r, r2); - assert_eq!(r2, r3); - - let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap(); - let r2: Range = - Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); - let r3 = Range::Bytes(vec![ - ByteRangeSpec::FromTo(1, 100), - ByteRangeSpec::AllFrom(200), - ]); - assert_eq!(r, r2); - assert_eq!(r2, r3); - - let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap(); - let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap(); - let r3 = Range::Bytes(vec![ - ByteRangeSpec::FromTo(1, 100), - ByteRangeSpec::Last(100), - ]); - assert_eq!(r, r2); - assert_eq!(r2, r3); - - let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); - assert_eq!(r, r2); - } - - #[test] - fn test_parse_unregistered_range_valid() { - let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); - assert_eq!(r, r2); - - let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); - assert_eq!(r, r2); - - let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap(); - let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); - assert_eq!(r, r2); - } - - #[test] - fn test_parse_invalid() { - let r: ::Result = Header::parse_header(&"bytes=1-a,-".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"bytes=1-2-3".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"abc".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"bytes=1-100=".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"bytes=".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"custom=".into()); - assert_eq!(r.ok(), None); - - let r: ::Result = Header::parse_header(&"=1-100".into()); - assert_eq!(r.ok(), None); - } - - #[test] - fn test_fmt() { - use header::Headers; - - let mut headers = Headers::new(); - - headers.set(Range::Bytes(vec![ - ByteRangeSpec::FromTo(0, 1000), - ByteRangeSpec::AllFrom(2000), - ])); - assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n"); - - headers.clear(); - headers.set(Range::Bytes(vec![])); - - assert_eq!(&headers.to_string(), "Range: bytes=\r\n"); - - headers.clear(); - headers.set(Range::Unregistered("custom".to_owned(), "1-xxx".to_owned())); - - assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n"); - } - - #[test] - fn test_byte_range_spec_to_satisfiable_range() { - assert_eq!( - Some((0, 0)), - ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3) - ); - assert_eq!( - Some((1, 2)), - ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3) - ); - assert_eq!( - Some((1, 2)), - ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3) - ); - assert_eq!(None, ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)); - - assert_eq!( - Some((0, 2)), - ByteRangeSpec::AllFrom(0).to_satisfiable_range(3) - ); - assert_eq!( - Some((2, 2)), - ByteRangeSpec::AllFrom(2).to_satisfiable_range(3) - ); - assert_eq!(None, ByteRangeSpec::AllFrom(3).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(5).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::AllFrom(0).to_satisfiable_range(0)); - - assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3)); - assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3)); - assert_eq!(Some((0, 2)), ByteRangeSpec::Last(5).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3)); - assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0)); - } -} diff --git a/actix-http/src/header/into_pair.rs b/actix-http/src/header/into_pair.rs index d0d6e7324..91c3e6640 100644 --- a/actix-http/src/header/into_pair.rs +++ b/actix-http/src/header/into_pair.rs @@ -1,17 +1,20 @@ -use std::convert::TryFrom; +//! [`TryIntoHeaderPair`] trait and implementations. -use http::{ - header::{HeaderName, InvalidHeaderName, InvalidHeaderValue}, - Error as HttpError, HeaderValue, +use std::convert::TryFrom as _; + +use super::{ + Header, HeaderName, HeaderValue, InvalidHeaderName, InvalidHeaderValue, TryIntoHeaderValue, }; +use crate::error::HttpError; -use super::{Header, IntoHeaderValue}; - -/// Transforms structures into header K/V pairs for inserting into `HeaderMap`s. -pub trait IntoHeaderPair: Sized { +/// An interface for types that can be converted into a [`HeaderName`] + [`HeaderValue`] pair for +/// insertion into a [`HeaderMap`]. +/// +/// [`HeaderMap`]: super::HeaderMap +pub trait TryIntoHeaderPair: Sized { type Error: Into; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error>; + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error>; } #[derive(Debug)] @@ -29,14 +32,14 @@ impl From for HttpError { } } -impl IntoHeaderPair for (HeaderName, V) +impl TryIntoHeaderPair for (HeaderName, V) where - V: IntoHeaderValue, + V: TryIntoHeaderValue, V::Error: Into, { type Error = InvalidHeaderPart; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { let (name, value) = self; let value = value .try_into_value() @@ -45,14 +48,14 @@ where } } -impl IntoHeaderPair for (&HeaderName, V) +impl TryIntoHeaderPair for (&HeaderName, V) where - V: IntoHeaderValue, + V: TryIntoHeaderValue, V::Error: Into, { type Error = InvalidHeaderPart; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { let (name, value) = self; let value = value .try_into_value() @@ -61,14 +64,14 @@ where } } -impl IntoHeaderPair for (&[u8], V) +impl TryIntoHeaderPair for (&[u8], V) where - V: IntoHeaderValue, + V: TryIntoHeaderValue, V::Error: Into, { type Error = InvalidHeaderPart; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { let (name, value) = self; let name = HeaderName::try_from(name).map_err(InvalidHeaderPart::Name)?; let value = value @@ -78,14 +81,14 @@ where } } -impl IntoHeaderPair for (&str, V) +impl TryIntoHeaderPair for (&str, V) where - V: IntoHeaderValue, + V: TryIntoHeaderValue, V::Error: Into, { type Error = InvalidHeaderPart; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { let (name, value) = self; let name = HeaderName::try_from(name).map_err(InvalidHeaderPart::Name)?; let value = value @@ -95,23 +98,25 @@ where } } -impl IntoHeaderPair for (String, V) +impl TryIntoHeaderPair for (String, V) where - V: IntoHeaderValue, + V: TryIntoHeaderValue, V::Error: Into, { type Error = InvalidHeaderPart; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + #[inline] + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { let (name, value) = self; - (name.as_str(), value).try_into_header_pair() + (name.as_str(), value).try_into_pair() } } -impl IntoHeaderPair for T { - type Error = ::Error; +impl TryIntoHeaderPair for T { + type Error = ::Error; - fn try_into_header_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + #[inline] + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { Ok((T::name(), self.try_into_value()?)) } } diff --git a/actix-http/src/header/into_value.rs b/actix-http/src/header/into_value.rs index 4ba58e726..6d369ee65 100644 --- a/actix-http/src/header/into_value.rs +++ b/actix-http/src/header/into_value.rs @@ -1,11 +1,13 @@ -use std::convert::TryFrom; +//! [`TryIntoHeaderValue`] trait and implementations. + +use std::convert::TryFrom as _; use bytes::Bytes; use http::{header::InvalidHeaderValue, Error as HttpError, HeaderValue}; use mime::Mime; -/// A trait for any object that can be Converted to a `HeaderValue` -pub trait IntoHeaderValue: Sized { +/// An interface for types that can be converted into a [`HeaderValue`]. +pub trait TryIntoHeaderValue: Sized { /// The type returned in the event of a conversion error. type Error: Into; @@ -13,7 +15,7 @@ pub trait IntoHeaderValue: Sized { fn try_into_value(self) -> Result; } -impl IntoHeaderValue for HeaderValue { +impl TryIntoHeaderValue for HeaderValue { type Error = InvalidHeaderValue; #[inline] @@ -22,7 +24,7 @@ impl IntoHeaderValue for HeaderValue { } } -impl IntoHeaderValue for &HeaderValue { +impl TryIntoHeaderValue for &HeaderValue { type Error = InvalidHeaderValue; #[inline] @@ -31,7 +33,7 @@ impl IntoHeaderValue for &HeaderValue { } } -impl IntoHeaderValue for &str { +impl TryIntoHeaderValue for &str { type Error = InvalidHeaderValue; #[inline] @@ -40,7 +42,7 @@ impl IntoHeaderValue for &str { } } -impl IntoHeaderValue for &[u8] { +impl TryIntoHeaderValue for &[u8] { type Error = InvalidHeaderValue; #[inline] @@ -49,7 +51,7 @@ impl IntoHeaderValue for &[u8] { } } -impl IntoHeaderValue for Bytes { +impl TryIntoHeaderValue for Bytes { type Error = InvalidHeaderValue; #[inline] @@ -58,7 +60,7 @@ impl IntoHeaderValue for Bytes { } } -impl IntoHeaderValue for Vec { +impl TryIntoHeaderValue for Vec { type Error = InvalidHeaderValue; #[inline] @@ -67,7 +69,7 @@ impl IntoHeaderValue for Vec { } } -impl IntoHeaderValue for String { +impl TryIntoHeaderValue for String { type Error = InvalidHeaderValue; #[inline] @@ -76,7 +78,7 @@ impl IntoHeaderValue for String { } } -impl IntoHeaderValue for usize { +impl TryIntoHeaderValue for usize { type Error = InvalidHeaderValue; #[inline] @@ -85,7 +87,7 @@ impl IntoHeaderValue for usize { } } -impl IntoHeaderValue for i64 { +impl TryIntoHeaderValue for i64 { type Error = InvalidHeaderValue; #[inline] @@ -94,7 +96,7 @@ impl IntoHeaderValue for i64 { } } -impl IntoHeaderValue for u64 { +impl TryIntoHeaderValue for u64 { type Error = InvalidHeaderValue; #[inline] @@ -103,7 +105,7 @@ impl IntoHeaderValue for u64 { } } -impl IntoHeaderValue for i32 { +impl TryIntoHeaderValue for i32 { type Error = InvalidHeaderValue; #[inline] @@ -112,7 +114,7 @@ impl IntoHeaderValue for i32 { } } -impl IntoHeaderValue for u32 { +impl TryIntoHeaderValue for u32 { type Error = InvalidHeaderValue; #[inline] @@ -121,7 +123,7 @@ impl IntoHeaderValue for u32 { } } -impl IntoHeaderValue for Mime { +impl TryIntoHeaderValue for Mime { type Error = InvalidHeaderValue; #[inline] diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs index 106e44edb..33fb262c4 100644 --- a/actix-http/src/header/map.rs +++ b/actix-http/src/header/map.rs @@ -1,12 +1,12 @@ //! A multi-value [`HeaderMap`] and its iterators. -use std::{borrow::Cow, collections::hash_map, ops}; +use std::{borrow::Cow, collections::hash_map, iter, ops}; use ahash::AHashMap; use http::header::{HeaderName, HeaderValue}; use smallvec::{smallvec, SmallVec}; -use crate::header::AsHeaderName; +use super::AsHeaderName; /// A multi-map of HTTP headers. /// @@ -14,7 +14,7 @@ use crate::header::AsHeaderName; /// /// # Examples /// ``` -/// use actix_http::http::{header, HeaderMap, HeaderValue}; +/// use actix_http::header::{self, HeaderMap, HeaderValue}; /// /// let mut map = HeaderMap::new(); /// @@ -75,7 +75,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::HeaderMap; + /// # use actix_http::header::HeaderMap; /// let map = HeaderMap::new(); /// /// assert!(map.is_empty()); @@ -92,7 +92,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::HeaderMap; + /// # use actix_http::header::HeaderMap; /// let map = HeaderMap::with_capacity(16); /// /// assert!(map.is_empty()); @@ -123,12 +123,11 @@ impl HeaderMap { let mut map = HeaderMap::with_capacity(capacity); map.append(first_name.clone(), first_value); - let (map, _) = - drain.fold((map, first_name), |(mut map, prev_name), (name, value)| { - let name = name.unwrap_or(prev_name); - map.append(name.clone(), value); - (map, name) - }); + let (map, _) = drain.fold((map, first_name), |(mut map, prev_name), (name, value)| { + let name = name.unwrap_or(prev_name); + map.append(name.clone(), value); + (map, name) + }); map } @@ -139,7 +138,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// assert_eq!(map.len(), 0); /// @@ -162,7 +161,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// assert_eq!(map.len_keys(), 0); /// @@ -181,7 +180,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// assert!(map.is_empty()); /// @@ -198,7 +197,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); @@ -213,7 +212,7 @@ impl HeaderMap { } fn get_value(&self, key: impl AsHeaderName) -> Option<&Value> { - match key.try_as_name().ok()? { + match key.try_as_name(super::as_name::Seal).ok()? { Cow::Borrowed(name) => self.inner.get(name), Cow::Owned(name) => self.inner.get(&name), } @@ -231,7 +230,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); @@ -249,7 +248,7 @@ impl HeaderMap { /// assert!(map.get("INVALID HEADER NAME").is_none()); /// ``` pub fn get(&self, key: impl AsHeaderName) -> Option<&HeaderValue> { - self.get_value(key).map(|val| val.first()) + self.get_value(key).map(Value::first) } /// Returns a mutable reference to the _first_ value associated a header name. @@ -264,7 +263,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.insert(header::SET_COOKIE, HeaderValue::from_static("one=1")); @@ -279,21 +278,21 @@ impl HeaderMap { /// assert!(map.get("INVALID HEADER NAME").is_none()); /// ``` pub fn get_mut(&mut self, key: impl AsHeaderName) -> Option<&mut HeaderValue> { - match key.try_as_name().ok()? { - Cow::Borrowed(name) => self.inner.get_mut(name).map(|v| v.first_mut()), - Cow::Owned(name) => self.inner.get_mut(&name).map(|v| v.first_mut()), + match key.try_as_name(super::as_name::Seal).ok()? { + Cow::Borrowed(name) => self.inner.get_mut(name).map(Value::first_mut), + Cow::Owned(name) => self.inner.get_mut(&name).map(Value::first_mut), } } /// Returns an iterator over all values associated with a header name. /// /// The returned iterator does not incur any allocations and will yield no items if there are no - /// values associated with the key. Iteration order is **not** guaranteed to be the same as + /// values associated with the key. Iteration order is guaranteed to be the same as /// insertion order. /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// let mut none_iter = map.get_all(header::ORIGIN); @@ -307,8 +306,11 @@ impl HeaderMap { /// assert_eq!(set_cookies_iter.next().unwrap(), "two=2"); /// assert!(set_cookies_iter.next().is_none()); /// ``` - pub fn get_all(&self, key: impl AsHeaderName) -> GetAll<'_> { - GetAll::new(self.get_value(key)) + pub fn get_all(&self, key: impl AsHeaderName) -> std::slice::Iter<'_, HeaderValue> { + match self.get_value(key) { + Some(value) => value.iter(), + None => (&[]).iter(), + } } // TODO: get_all_mut ? @@ -319,7 +321,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// assert!(!map.contains_key(header::ACCEPT)); /// @@ -327,14 +329,14 @@ impl HeaderMap { /// assert!(map.contains_key(header::ACCEPT)); /// ``` pub fn contains_key(&self, key: impl AsHeaderName) -> bool { - match key.try_as_name() { + match key.try_as_name(super::as_name::Seal) { Ok(Cow::Borrowed(name)) => self.inner.contains_key(name), Ok(Cow::Owned(name)) => self.inner.contains_key(&name), Err(_) => false, } } - /// Inserts a name-value pair into the map. + /// Inserts (overrides) a name-value pair in the map. /// /// If the map already contained this key, the new value is associated with the key and all /// previous values are removed and returned as a `Removed` iterator. The key is not updated; @@ -342,7 +344,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); @@ -355,12 +357,25 @@ impl HeaderMap { /// /// assert_eq!(map.len(), 1); /// ``` + /// + /// A convenience method is provided on the returned iterator to check if the insertion replaced + /// any values. + /// ``` + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/plain")); + /// assert!(removed.is_empty()); + /// + /// let removed = map.insert(header::ACCEPT, HeaderValue::from_static("text/html")); + /// assert!(!removed.is_empty()); + /// ``` pub fn insert(&mut self, key: HeaderName, val: HeaderValue) -> Removed { let value = self.inner.insert(key, Value::one(val)); Removed::new(value) } - /// Inserts a name-value pair into the map. + /// Appends a name-value pair to the map. /// /// If the map already contained this key, the new value is added to the list of values /// currently associated with the key. The key is not updated; this matters for types that can @@ -368,7 +383,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.append(header::HOST, HeaderValue::from_static("example.com")); @@ -393,9 +408,12 @@ impl HeaderMap { /// Removes all headers for a particular header name from the map. /// + /// Providing an invalid header names (as a string argument) will have no effect and return + /// without error. + /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// map.append(header::SET_COOKIE, HeaderValue::from_static("one=1")); @@ -409,8 +427,23 @@ impl HeaderMap { /// assert!(removed.next().is_none()); /// /// assert!(map.is_empty()); + /// ``` + /// + /// A convenience method is provided on the returned iterator to check if the `remove` call + /// actually removed any values. + /// ``` + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; + /// let mut map = HeaderMap::new(); + /// + /// let removed = map.remove("accept"); + /// assert!(removed.is_empty()); + /// + /// map.insert(header::ACCEPT, HeaderValue::from_static("text/html")); + /// let removed = map.remove("accept"); + /// assert!(!removed.is_empty()); + /// ``` pub fn remove(&mut self, key: impl AsHeaderName) -> Removed { - let value = match key.try_as_name() { + let value = match key.try_as_name(super::as_name::Seal) { Ok(Cow::Borrowed(name)) => self.inner.remove(name), Ok(Cow::Owned(name)) => self.inner.remove(&name), Err(_) => None, @@ -428,7 +461,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::HeaderMap; + /// # use actix_http::header::HeaderMap; /// let map = HeaderMap::with_capacity(16); /// /// assert!(map.is_empty()); @@ -448,7 +481,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::HeaderMap; + /// # use actix_http::header::HeaderMap; /// let mut map = HeaderMap::with_capacity(2); /// assert!(map.capacity() >= 2); /// @@ -468,7 +501,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// let mut iter = map.iter(); @@ -500,7 +533,7 @@ impl HeaderMap { /// /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// let mut iter = map.keys(); @@ -528,7 +561,7 @@ impl HeaderMap { /// Keeps the allocated memory for reuse. /// # Examples /// ``` - /// # use actix_http::http::{header, HeaderMap, HeaderValue}; + /// # use actix_http::header::{self, HeaderMap, HeaderValue}; /// let mut map = HeaderMap::new(); /// /// let mut iter = map.drain(); @@ -550,7 +583,8 @@ impl HeaderMap { } } -/// Note that this implementation will clone a [HeaderName] for each value. +/// Note that this implementation will clone a [HeaderName] for each value. Consider using +/// [`drain`](Self::drain) to control header name cloning. impl IntoIterator for HeaderMap { type Item = (HeaderName, HeaderValue); type IntoIter = IntoIter; @@ -571,60 +605,39 @@ impl<'a> IntoIterator for &'a HeaderMap { } } -/// Iterator for all values with the same header name. +/// Convert `http::HeaderMap` to our `HeaderMap`. +impl From for HeaderMap { + fn from(mut map: http::HeaderMap) -> HeaderMap { + HeaderMap::from_drain(map.drain()) + } +} + +/// Iterator over removed, owned values with the same associated name. /// -/// See [`HeaderMap::get_all`]. -#[derive(Debug)] -pub struct GetAll<'a> { - idx: usize, - value: Option<&'a Value>, -} - -impl<'a> GetAll<'a> { - fn new(value: Option<&'a Value>) -> Self { - Self { idx: 0, value } - } -} - -impl<'a> Iterator for GetAll<'a> { - type Item = &'a HeaderValue; - - fn next(&mut self) -> Option { - let val = self.value?; - - match val.get(self.idx) { - Some(val) => { - self.idx += 1; - Some(val) - } - None => { - // current index is none; remove value to fast-path future next calls - self.value = None; - None - } - } - } - - fn size_hint(&self) -> (usize, Option) { - match self.value { - Some(val) => (val.len(), Some(val.len())), - None => (0, Some(0)), - } - } -} - -/// Iterator for owned [`HeaderValue`]s with the same associated [`HeaderName`] returned from methods -/// on [`HeaderMap`] that remove or replace items. +/// Returned from methods that remove or replace items. See [`HeaderMap::insert`] +/// and [`HeaderMap::remove`]. #[derive(Debug)] pub struct Removed { inner: Option>, } -impl<'a> Removed { +impl Removed { fn new(value: Option) -> Self { let inner = value.map(|value| value.inner.into_iter()); Self { inner } } + + /// Returns true if iterator contains no elements, without consuming it. + /// + /// If called immediately after [`HeaderMap::insert`] or [`HeaderMap::remove`], it will indicate + /// wether any items were actually replaced or removed, respectively. + pub fn is_empty(&self) -> bool { + match self.inner { + // size hint lower bound of smallvec is the correct length + Some(ref iter) => iter.size_hint().0 == 0, + None => true, + } + } } impl Iterator for Removed { @@ -644,7 +657,11 @@ impl Iterator for Removed { } } -/// Iterator over all [`HeaderName`]s in the map. +impl ExactSizeIterator for Removed {} + +impl iter::FusedIterator for Removed {} + +/// Iterator over all names in the map. #[derive(Debug)] pub struct Keys<'a>(hash_map::Keys<'a, HeaderName, Value>); @@ -662,6 +679,11 @@ impl<'a> Iterator for Keys<'a> { } } +impl ExactSizeIterator for Keys<'_> {} + +impl iter::FusedIterator for Keys<'_> {} + +/// Iterator over borrowed name-value pairs. #[derive(Debug)] pub struct Iter<'a> { inner: hash_map::Iter<'a, HeaderName, Value>, @@ -684,7 +706,7 @@ impl<'a> Iterator for Iter<'a> { fn next(&mut self) -> Option { // handle in-progress multi value lists first - if let Some((ref name, ref mut vals)) = self.multi_inner { + if let Some((name, ref mut vals)) = self.multi_inner { match vals.get(self.multi_idx) { Some(val) => { self.multi_idx += 1; @@ -713,6 +735,10 @@ impl<'a> Iterator for Iter<'a> { } } +impl ExactSizeIterator for Iter<'_> {} + +impl iter::FusedIterator for Iter<'_> {} + /// Iterator over drained name-value pairs. /// /// Iterator items are `(Option, HeaderValue)` to avoid cloning. @@ -764,6 +790,10 @@ impl<'a> Iterator for Drain<'a> { } } +impl ExactSizeIterator for Drain<'_> {} + +impl iter::FusedIterator for Drain<'_> {} + /// Iterator over owned name-value pairs. /// /// Implementation necessarily clones header names for each value. @@ -814,12 +844,27 @@ impl Iterator for IntoIter { } } +impl ExactSizeIterator for IntoIter {} + +impl iter::FusedIterator for IntoIter {} + #[cfg(test)] mod tests { + use std::iter::FusedIterator; + use http::header; + use static_assertions::assert_impl_all; use super::*; + assert_impl_all!(HeaderMap: IntoIterator); + assert_impl_all!(Keys<'_>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(std::slice::Iter<'_, HeaderValue>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Removed: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Iter<'_>: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(IntoIter: Iterator, ExactSizeIterator, FusedIterator); + assert_impl_all!(Drain<'_>: Iterator, ExactSizeIterator, FusedIterator); + #[test] fn create() { let map = HeaderMap::new(); @@ -945,6 +990,56 @@ mod tests { assert_eq!(vals.next(), removed.next().as_ref()); } + #[test] + fn get_all_iteration_order_matches_insertion_order() { + let mut map = HeaderMap::new(); + + let mut vals = map.get_all(header::COOKIE); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("1")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("2")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert_eq!(vals.next().unwrap().as_bytes(), b"2"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("3")); + map.append(header::COOKIE, HeaderValue::from_static("4")); + map.append(header::COOKIE, HeaderValue::from_static("5")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"1"); + assert_eq!(vals.next().unwrap().as_bytes(), b"2"); + assert_eq!(vals.next().unwrap().as_bytes(), b"3"); + assert_eq!(vals.next().unwrap().as_bytes(), b"4"); + assert_eq!(vals.next().unwrap().as_bytes(), b"5"); + assert!(vals.next().is_none()); + + let _ = map.insert(header::COOKIE, HeaderValue::from_static("6")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"6"); + assert!(vals.next().is_none()); + + let _ = map.insert(header::COOKIE, HeaderValue::from_static("7")); + let _ = map.insert(header::COOKIE, HeaderValue::from_static("8")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"8"); + assert!(vals.next().is_none()); + + map.append(header::COOKIE, HeaderValue::from_static("9")); + let mut vals = map.get_all(header::COOKIE); + assert_eq!(vals.next().unwrap().as_bytes(), b"8"); + assert_eq!(vals.next().unwrap().as_bytes(), b"9"); + assert!(vals.next().is_none()); + + // check for fused-ness + assert!(vals.next().is_none()); + } + fn owned_pair<'a>( (name, val): (&'a HeaderName, &'a HeaderValue), ) -> (HeaderName, HeaderValue) { diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs index 1100a959d..5f352fc12 100644 --- a/actix-http/src/header/mod.rs +++ b/actix-http/src/header/mod.rs @@ -1,83 +1,64 @@ -//! Typed HTTP headers, pre-defined `HeaderName`s, traits for parsing and conversion, and other -//! header utility methods. +//! Pre-defined `HeaderName`s, traits for parsing and conversion, and other header utility methods. -use std::fmt; - -use bytes::{Bytes, BytesMut}; use percent_encoding::{AsciiSet, CONTROLS}; -pub use http::header::*; +// re-export from http except header map related items +pub use http::header::{ + HeaderName, HeaderValue, InvalidHeaderName, InvalidHeaderValue, ToStrError, +}; -use crate::error::ParseError; -use crate::HttpMessage; +// re-export const header names +pub use http::header::{ + ACCEPT, ACCEPT_CHARSET, ACCEPT_ENCODING, ACCEPT_LANGUAGE, ACCEPT_RANGES, + ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, + ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS, + ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, AGE, + ALLOW, ALT_SVC, AUTHORIZATION, CACHE_CONTROL, CONNECTION, CONTENT_DISPOSITION, + CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_LOCATION, CONTENT_RANGE, + CONTENT_SECURITY_POLICY, CONTENT_SECURITY_POLICY_REPORT_ONLY, CONTENT_TYPE, COOKIE, DATE, + DNT, ETAG, EXPECT, EXPIRES, FORWARDED, FROM, HOST, IF_MATCH, IF_MODIFIED_SINCE, + IF_NONE_MATCH, IF_RANGE, IF_UNMODIFIED_SINCE, LAST_MODIFIED, LINK, LOCATION, MAX_FORWARDS, + ORIGIN, PRAGMA, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, PUBLIC_KEY_PINS, + PUBLIC_KEY_PINS_REPORT_ONLY, RANGE, REFERER, REFERRER_POLICY, REFRESH, RETRY_AFTER, + SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, + SEC_WEBSOCKET_VERSION, SERVER, SET_COOKIE, STRICT_TRANSPORT_SECURITY, TE, TRAILER, + TRANSFER_ENCODING, UPGRADE, UPGRADE_INSECURE_REQUESTS, USER_AGENT, VARY, VIA, WARNING, + WWW_AUTHENTICATE, X_CONTENT_TYPE_OPTIONS, X_DNS_PREFETCH_CONTROL, X_FRAME_OPTIONS, + X_XSS_PROTECTION, +}; + +use crate::{error::ParseError, HttpMessage}; mod as_name; mod into_pair; mod into_value; +pub mod map; +mod shared; mod utils; -mod common; -pub(crate) mod map; -mod shared; - -pub use self::common::*; -#[doc(hidden)] -pub use self::shared::*; - pub use self::as_name::AsHeaderName; -pub use self::into_pair::IntoHeaderPair; -pub use self::into_value::IntoHeaderValue; -#[doc(hidden)] -pub use self::map::GetAll; +pub use self::into_pair::TryIntoHeaderPair; +pub use self::into_value::TryIntoHeaderValue; pub use self::map::HeaderMap; -pub use self::utils::*; +pub use self::shared::{ + parse_extended_value, q, Charset, ContentEncoding, ExtendedValue, HttpDate, LanguageTag, + Quality, QualityItem, +}; +pub use self::utils::{ + fmt_comma_delimited, from_comma_delimited, from_one_raw_str, http_percent_encode, +}; -/// A trait for any object that already represents a valid header field and value. -pub trait Header: IntoHeaderValue { - /// Returns the name of the header field +/// An interface for types that already represent a valid header. +pub trait Header: TryIntoHeaderValue { + /// Returns the name of the header field. fn name() -> HeaderName; - /// Parse a header - fn parse(msg: &T) -> Result; -} - -#[derive(Debug, Default)] -pub(crate) struct Writer { - buf: BytesMut, -} - -impl Writer { - fn new() -> Writer { - Writer::default() - } - - fn take(&mut self) -> Bytes { - self.buf.split().freeze() - } -} - -impl fmt::Write for Writer { - #[inline] - fn write_str(&mut self, s: &str) -> fmt::Result { - self.buf.extend_from_slice(s.as_bytes()); - Ok(()) - } - - #[inline] - fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { - fmt::write(self, args) - } -} - -/// Convert `http::HeaderMap` to our `HeaderMap`. -impl From for HeaderMap { - fn from(mut map: http::HeaderMap) -> HeaderMap { - HeaderMap::from_drain(map.drain()) - } + /// Parse the header from a HTTP message. + fn parse(msg: &M) -> Result; } /// This encode set is used for HTTP header values and is defined at -/// https://tools.ietf.org/html/rfc5987#section-3.2. +/// . pub(crate) const HTTP_VALUE: &AsciiSet = &CONTROLS .add(b' ') .add(b'"') diff --git a/actix-http/src/header/shared/charset.rs b/actix-http/src/header/shared/charset.rs index 36bdbf7e2..1e77e1be8 100644 --- a/actix-http/src/header/shared/charset.rs +++ b/actix-http/src/header/shared/charset.rs @@ -1,14 +1,13 @@ -use std::fmt::{self, Display}; -use std::str::FromStr; +use std::{fmt, str}; use self::Charset::*; -/// A Mime charset. +/// A MIME character set. /// /// The string representation is normalized to upper case. /// /// See . -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum Charset { /// US ASCII @@ -88,23 +87,23 @@ impl Charset { Iso_8859_8_E => "ISO-8859-8-E", Iso_8859_8_I => "ISO-8859-8-I", Gb2312 => "GB2312", - Big5 => "big5", + Big5 => "Big5", Koi8_R => "KOI8-R", Ext(ref s) => s, } } } -impl Display for Charset { +impl fmt::Display for Charset { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.label()) } } -impl FromStr for Charset { +impl str::FromStr for Charset { type Err = crate::Error; - fn from_str(s: &str) -> crate::Result { + fn from_str(s: &str) -> Result { Ok(match s.to_ascii_uppercase().as_ref() { "US-ASCII" => Us_Ascii, "ISO-8859-1" => Iso_8859_1, @@ -128,7 +127,7 @@ impl FromStr for Charset { "ISO-8859-8-E" => Iso_8859_8_E, "ISO-8859-8-I" => Iso_8859_8_I, "GB2312" => Gb2312, - "big5" => Big5, + "BIG5" => Big5, "KOI8-R" => Koi8_R, s => Ext(s.to_owned()), }) diff --git a/actix-http/src/header/shared/content_encoding.rs b/actix-http/src/header/shared/content_encoding.rs new file mode 100644 index 000000000..bd25de704 --- /dev/null +++ b/actix-http/src/header/shared/content_encoding.rs @@ -0,0 +1,123 @@ +use std::{convert::TryFrom, str::FromStr}; + +use derive_more::{Display, Error}; +use http::header::InvalidHeaderValue; + +use crate::{ + error::ParseError, + header::{self, from_one_raw_str, Header, HeaderName, HeaderValue, TryIntoHeaderValue}, + HttpMessage, +}; + +/// Error returned when a content encoding is unknown. +#[derive(Debug, Display, Error)] +#[display(fmt = "unsupported content encoding")] +pub struct ContentEncodingParseError; + +/// Represents a supported content encoding. +/// +/// Includes a commonly-used subset of media types appropriate for use as HTTP content encodings. +/// See [IANA HTTP Content Coding Registry]. +/// +/// [IANA HTTP Content Coding Registry]: https://www.iana.org/assignments/http-parameters/http-parameters.xhtml +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum ContentEncoding { + /// Indicates the no-op identity encoding. + /// + /// I.e., no compression or modification. + Identity, + + /// A format using the Brotli algorithm. + Brotli, + + /// A format using the zlib structure with deflate algorithm. + Deflate, + + /// Gzip algorithm. + Gzip, + + /// Zstd algorithm. + Zstd, +} + +impl ContentEncoding { + /// Convert content encoding to string. + #[inline] + pub const fn as_str(self) -> &'static str { + match self { + ContentEncoding::Brotli => "br", + ContentEncoding::Gzip => "gzip", + ContentEncoding::Deflate => "deflate", + ContentEncoding::Zstd => "zstd", + ContentEncoding::Identity => "identity", + } + } + + /// Convert content encoding to header value. + #[inline] + pub const fn to_header_value(self) -> HeaderValue { + match self { + ContentEncoding::Brotli => HeaderValue::from_static("br"), + ContentEncoding::Gzip => HeaderValue::from_static("gzip"), + ContentEncoding::Deflate => HeaderValue::from_static("deflate"), + ContentEncoding::Zstd => HeaderValue::from_static("zstd"), + ContentEncoding::Identity => HeaderValue::from_static("identity"), + } + } +} + +impl Default for ContentEncoding { + #[inline] + fn default() -> Self { + Self::Identity + } +} + +impl FromStr for ContentEncoding { + type Err = ContentEncodingParseError; + + fn from_str(enc: &str) -> Result { + let enc = enc.trim(); + + if enc.eq_ignore_ascii_case("br") { + Ok(ContentEncoding::Brotli) + } else if enc.eq_ignore_ascii_case("gzip") { + Ok(ContentEncoding::Gzip) + } else if enc.eq_ignore_ascii_case("deflate") { + Ok(ContentEncoding::Deflate) + } else if enc.eq_ignore_ascii_case("identity") { + Ok(ContentEncoding::Identity) + } else if enc.eq_ignore_ascii_case("zstd") { + Ok(ContentEncoding::Zstd) + } else { + Err(ContentEncodingParseError) + } + } +} + +impl TryFrom<&str> for ContentEncoding { + type Error = ContentEncodingParseError; + + fn try_from(val: &str) -> Result { + val.parse() + } +} + +impl TryIntoHeaderValue for ContentEncoding { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + Ok(HeaderValue::from_static(self.as_str())) + } +} + +impl Header for ContentEncoding { + fn name() -> HeaderName { + header::CONTENT_ENCODING + } + + fn parse(msg: &T) -> Result { + from_one_raw_str(msg.headers().get(Self::name())) + } +} diff --git a/actix-http/src/header/shared/encoding.rs b/actix-http/src/header/shared/encoding.rs deleted file mode 100644 index aa49dea45..000000000 --- a/actix-http/src/header/shared/encoding.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::{fmt, str}; - -pub use self::Encoding::{ - Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip, Identity, Trailers, -}; - -/// A value to represent an encoding used in `Transfer-Encoding` -/// or `Accept-Encoding` header. -#[derive(Clone, PartialEq, Debug)] -pub enum Encoding { - /// The `chunked` encoding. - Chunked, - /// The `br` encoding. - Brotli, - /// The `gzip` encoding. - Gzip, - /// The `deflate` encoding. - Deflate, - /// The `compress` encoding. - Compress, - /// The `identity` encoding. - Identity, - /// The `trailers` encoding. - Trailers, - /// Some other encoding that is less common, can be any String. - EncodingExt(String), -} - -impl fmt::Display for Encoding { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match *self { - Chunked => "chunked", - Brotli => "br", - Gzip => "gzip", - Deflate => "deflate", - Compress => "compress", - Identity => "identity", - Trailers => "trailers", - EncodingExt(ref s) => s.as_ref(), - }) - } -} - -impl str::FromStr for Encoding { - type Err = crate::error::ParseError; - fn from_str(s: &str) -> Result { - match s { - "chunked" => Ok(Chunked), - "br" => Ok(Brotli), - "deflate" => Ok(Deflate), - "gzip" => Ok(Gzip), - "compress" => Ok(Compress), - "identity" => Ok(Identity), - "trailers" => Ok(Trailers), - _ => Ok(EncodingExt(s.to_owned())), - } - } -} diff --git a/actix-http/src/header/shared/extended.rs b/actix-http/src/header/shared/extended.rs index 6bdcb7922..1af9ca20e 100644 --- a/actix-http/src/header/shared/extended.rs +++ b/actix-http/src/header/shared/extended.rs @@ -1,17 +1,17 @@ +//! Originally taken from `hyper::header::parsing`. + use std::{fmt, str::FromStr}; use language_tags::LanguageTag; use crate::header::{Charset, HTTP_VALUE}; -// From hyper v0.11.27 src/header/parsing.rs - /// The value part of an extended parameter consisting of three parts: /// - The REQUIRED character set name (`charset`). /// - The OPTIONAL language information (`language_tag`). /// - A character sequence representing the actual value (`value`), separated by single quotes. /// -/// It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// It is defined in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2). #[derive(Clone, Debug, PartialEq)] pub struct ExtendedValue { /// The character set that is used to encode the `value` to a string. @@ -24,17 +24,17 @@ pub struct ExtendedValue { pub value: Vec, } -/// Parses extended header parameter values (`ext-value`), as defined in -/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// Parses extended header parameter values (`ext-value`), as defined +/// in [RFC 5987 §3.2](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2). /// /// Extended values are denoted by parameter names that end with `*`. /// /// ## ABNF /// -/// ```text +/// ```plain /// ext-value = charset "'" [ language ] "'" value-chars /// ; like RFC 2231's -/// ; (see [RFC2231], Section 7) +/// ; (see [RFC 2231 §7]) /// /// charset = "UTF-8" / "ISO-8859-1" / mime-charset /// @@ -43,25 +43,27 @@ pub struct ExtendedValue { /// / "!" / "#" / "$" / "%" / "&" /// / "+" / "-" / "^" / "_" / "`" /// / "{" / "}" / "~" -/// ; as in Section 2.3 of [RFC2978] +/// ; as in [RFC 2978 §2.3] /// ; except that the single quote is not included /// ; SHOULD be registered in the IANA charset registry /// -/// language = +/// language = /// /// value-chars = *( pct-encoded / attr-char ) /// /// pct-encoded = "%" HEXDIG HEXDIG -/// ; see [RFC3986], Section 2.1 +/// ; see [RFC 3986 §2.1] /// /// attr-char = ALPHA / DIGIT /// / "!" / "#" / "$" / "&" / "+" / "-" / "." /// / "^" / "_" / "`" / "|" / "~" /// ; token except ( "*" / "'" / "%" ) /// ``` -pub fn parse_extended_value( - val: &str, -) -> Result { +/// +/// [RFC 2231 §7]: https://datatracker.ietf.org/doc/html/rfc2231#section-7 +/// [RFC 2978 §2.3]: https://datatracker.ietf.org/doc/html/rfc2978#section-2.3 +/// [RFC 3986 §2.1]: https://datatracker.ietf.org/doc/html/rfc5646#section-2.1 +pub fn parse_extended_value(val: &str) -> Result { // Break into three pieces separated by the single-quote character let mut parts = val.splitn(3, '\''); @@ -88,16 +90,15 @@ pub fn parse_extended_value( }; Ok(ExtendedValue { - value, charset, language_tag, + value, }) } impl fmt::Display for ExtendedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let encoded_value = - percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); + let encoded_value = percent_encoding::percent_encode(&self.value[..], HTTP_VALUE); if let Some(ref lang) = self.language_tag { write!(f, "{}'{}'{}", self.charset, lang, encoded_value) } else { @@ -139,8 +140,8 @@ mod tests { assert!(extended_value.language_tag.is_none()); assert_eq!( vec![ - 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', - b't', b'e', b's', + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', b't', + b'e', b's', ], extended_value.value ); @@ -181,8 +182,8 @@ mod tests { charset: Charset::Ext("UTF-8".to_string()), language_tag: None, value: vec![ - 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', - b't', b'e', b's', + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', b't', + b'e', b's', ], }; assert_eq!( diff --git a/actix-http/src/header/shared/http_date.rs b/actix-http/src/header/shared/http_date.rs new file mode 100644 index 000000000..473d6cad0 --- /dev/null +++ b/actix-http/src/header/shared/http_date.rs @@ -0,0 +1,82 @@ +use std::{fmt, io::Write, str::FromStr, time::SystemTime}; + +use bytes::BytesMut; +use http::header::{HeaderValue, InvalidHeaderValue}; + +use crate::{ + config::DATE_VALUE_LENGTH, error::ParseError, header::TryIntoHeaderValue, + helpers::MutWriter, +}; + +/// A timestamp with HTTP-style formatting and parsing. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct HttpDate(SystemTime); + +impl FromStr for HttpDate { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match httpdate::parse_http_date(s) { + Ok(sys_time) => Ok(HttpDate(sys_time)), + Err(_) => Err(ParseError::Header), + } + } +} + +impl fmt::Display for HttpDate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let date_str = httpdate::fmt_http_date(self.0); + f.write_str(&date_str) + } +} + +impl TryIntoHeaderValue for HttpDate { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + let mut buf = BytesMut::with_capacity(DATE_VALUE_LENGTH); + let mut wrt = MutWriter(&mut buf); + + // unwrap: date output is known to be well formed and of known length + write!(wrt, "{}", httpdate::fmt_http_date(self.0)).unwrap(); + + HeaderValue::from_maybe_shared(buf.split().freeze()) + } +} + +impl From for HttpDate { + fn from(sys_time: SystemTime) -> HttpDate { + HttpDate(sys_time) + } +} + +impl From for SystemTime { + fn from(HttpDate(sys_time): HttpDate) -> SystemTime { + sys_time + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn date_header() { + macro_rules! assert_parsed_date { + ($case:expr, $exp:expr) => { + assert_eq!($case.parse::().unwrap(), $exp); + }; + } + + // 784198117 = SystemTime::from(datetime!(1994-11-07 08:48:37).assume_utc()).duration_since(SystemTime::UNIX_EPOCH)); + let nov_07 = HttpDate(SystemTime::UNIX_EPOCH + Duration::from_secs(784198117)); + + assert_parsed_date!("Mon, 07 Nov 1994 08:48:37 GMT", nov_07); + assert_parsed_date!("Monday, 07-Nov-94 08:48:37 GMT", nov_07); + assert_parsed_date!("Mon Nov 7 08:48:37 1994", nov_07); + + assert!("this-is-no-date".parse::().is_err()); + } +} diff --git a/actix-http/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs deleted file mode 100644 index 72a225589..000000000 --- a/actix-http/src/header/shared/httpdate.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::fmt::{self, Display}; -use std::io::Write; -use std::str::FromStr; -use std::time::{SystemTime, UNIX_EPOCH}; - -use bytes::buf::BufMut; -use bytes::BytesMut; -use http::header::{HeaderValue, InvalidHeaderValue}; -use time::{offset, OffsetDateTime, PrimitiveDateTime}; - -use crate::error::ParseError; -use crate::header::IntoHeaderValue; -use crate::time_parser; - -/// A timestamp with HTTP formatting and parsing -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct HttpDate(OffsetDateTime); - -impl FromStr for HttpDate { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match time_parser::parse_http_date(s) { - Some(t) => Ok(HttpDate(t.assume_utc())), - None => Err(ParseError::Header), - } - } -} - -impl Display for HttpDate { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0.format("%a, %d %b %Y %H:%M:%S GMT"), f) - } -} - -impl From for HttpDate { - fn from(dt: OffsetDateTime) -> HttpDate { - HttpDate(dt) - } -} - -impl From for HttpDate { - fn from(sys: SystemTime) -> HttpDate { - HttpDate(PrimitiveDateTime::from(sys).assume_utc()) - } -} - -impl IntoHeaderValue for HttpDate { - type Error = InvalidHeaderValue; - - fn try_into_value(self) -> Result { - let mut wrt = BytesMut::with_capacity(29).writer(); - write!( - wrt, - "{}", - self.0 - .to_offset(offset!(UTC)) - .format("%a, %d %b %Y %H:%M:%S GMT") - ) - .unwrap(); - HeaderValue::from_maybe_shared(wrt.get_mut().split().freeze()) - } -} - -impl From for SystemTime { - fn from(date: HttpDate) -> SystemTime { - let dt = date.0; - let epoch = OffsetDateTime::unix_epoch(); - - UNIX_EPOCH + (dt - epoch) - } -} - -#[cfg(test)] -mod tests { - use super::HttpDate; - use time::{date, time, PrimitiveDateTime}; - - #[test] - fn test_date() { - let nov_07 = HttpDate( - PrimitiveDateTime::new(date!(1994 - 11 - 07), time!(8:48:37)).assume_utc(), - ); - - assert_eq!( - "Sun, 07 Nov 1994 08:48:37 GMT".parse::().unwrap(), - nov_07 - ); - assert_eq!( - "Sunday, 07-Nov-94 08:48:37 GMT" - .parse::() - .unwrap(), - nov_07 - ); - assert_eq!( - "Sun Nov 7 08:48:37 1994".parse::().unwrap(), - nov_07 - ); - assert!("this-is-no-date".parse::().is_err()); - } -} diff --git a/actix-http/src/header/shared/mod.rs b/actix-http/src/header/shared/mod.rs index 72161e46b..257e54d7a 100644 --- a/actix-http/src/header/shared/mod.rs +++ b/actix-http/src/header/shared/mod.rs @@ -1,16 +1,16 @@ //! Originally taken from `hyper::header::shared`. mod charset; -mod encoding; -mod entity; +mod content_encoding; mod extended; -mod httpdate; +mod http_date; +mod quality; mod quality_item; pub use self::charset::Charset; -pub use self::encoding::Encoding; -pub use self::entity::EntityTag; +pub use self::content_encoding::ContentEncoding; pub use self::extended::{parse_extended_value, ExtendedValue}; -pub use self::httpdate::HttpDate; -pub use self::quality_item::{q, qitem, Quality, QualityItem}; +pub use self::http_date::HttpDate; +pub use self::quality::{q, Quality}; +pub use self::quality_item::QualityItem; pub use language_tags::LanguageTag; diff --git a/actix-http/src/header/shared/quality.rs b/actix-http/src/header/shared/quality.rs new file mode 100644 index 000000000..c80dd0a8e --- /dev/null +++ b/actix-http/src/header/shared/quality.rs @@ -0,0 +1,225 @@ +use std::{ + convert::{TryFrom, TryInto}, + fmt, +}; + +use derive_more::{Display, Error}; + +const MAX_QUALITY_INT: u16 = 1000; +const MAX_QUALITY_FLOAT: f32 = 1.0; + +/// Represents a quality used in q-factor values. +/// +/// The default value is equivalent to `q=1.0` (the [max](Self::MAX) value). +/// +/// # Implementation notes +/// The quality value is defined as a number between 0.0 and 1.0 with three decimal places. +/// This means there are 1001 possible values. Since floating point numbers are not exact and the +/// smallest floating point data type (`f32`) consumes four bytes, we use an `u16` value to store +/// the quality internally. +/// +/// [RFC 7231 §5.3.1] gives more information on quality values in HTTP header fields. +/// +/// # Examples +/// ``` +/// use actix_http::header::{Quality, q}; +/// assert_eq!(q(1.0), Quality::MAX); +/// +/// assert_eq!(q(0.42).to_string(), "0.42"); +/// assert_eq!(q(1.0).to_string(), "1"); +/// assert_eq!(Quality::MIN.to_string(), "0.001"); +/// assert_eq!(Quality::ZERO.to_string(), "0"); +/// ``` +/// +/// [RFC 7231 §5.3.1]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1 +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Quality(pub(super) u16); + +impl Quality { + /// The maximum quality value, equivalent to `q=1.0`. + pub const MAX: Quality = Quality(MAX_QUALITY_INT); + + /// The minimum, non-zero quality value, equivalent to `q=0.001`. + pub const MIN: Quality = Quality(1); + + /// The zero quality value, equivalent to `q=0.0`. + pub const ZERO: Quality = Quality(0); + + /// Converts a float in the range 0.0–1.0 to a `Quality`. + /// + /// Intentionally private. External uses should rely on the `TryFrom` impl. + /// + /// # Panics + /// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0. + fn from_f32(value: f32) -> Self { + // Check that `value` is within range should be done before calling this method. + // Just in case, this debug_assert should catch if we were forgetful. + debug_assert!( + (0.0..=MAX_QUALITY_FLOAT).contains(&value), + "q value must be between 0.0 and 1.0" + ); + + Quality((value * MAX_QUALITY_INT as f32) as u16) + } +} + +/// The default value is [`Quality::MAX`]. +impl Default for Quality { + fn default() -> Quality { + Quality::MAX + } +} + +impl fmt::Display for Quality { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + 0 => f.write_str("0"), + MAX_QUALITY_INT => f.write_str("1"), + + // some number in the range 1–999 + x => { + f.write_str("0.")?; + + // This implementation avoids string allocation for removing trailing zeroes. + // In benchmarks it is twice as fast as approach using something like + // `format!("{}").trim_end_matches('0')` for non-fast-path quality values. + + if x < 10 { + // x in is range 1–9 + + f.write_str("00")?; + + // 0 is already handled so it's not possible to have a trailing 0 in this range + // we can just write the integer + itoa_fmt(f, x) + } else if x < 100 { + // x in is range 10–99 + + f.write_str("0")?; + + if x % 10 == 0 { + // trailing 0, divide by 10 and write + itoa_fmt(f, x / 10) + } else { + itoa_fmt(f, x) + } + } else { + // x is in range 100–999 + + if x % 100 == 0 { + // two trailing 0s, divide by 100 and write + itoa_fmt(f, x / 100) + } else if x % 10 == 0 { + // one trailing 0, divide by 10 and write + itoa_fmt(f, x / 10) + } else { + itoa_fmt(f, x) + } + } + } + } + } +} + +/// Write integer to a `fmt::Write`. +pub fn itoa_fmt(mut wr: W, value: V) -> fmt::Result { + let mut buf = itoa::Buffer::new(); + wr.write_str(buf.format(value)) +} + +#[derive(Debug, Clone, Display, Error)] +#[display(fmt = "quality out of bounds")] +#[non_exhaustive] +pub struct QualityOutOfBounds; + +impl TryFrom for Quality { + type Error = QualityOutOfBounds; + + #[inline] + fn try_from(value: f32) -> Result { + if (0.0..=MAX_QUALITY_FLOAT).contains(&value) { + Ok(Quality::from_f32(value)) + } else { + Err(QualityOutOfBounds) + } + } +} + +/// Convenience function to create a [`Quality`] from an `f32` (0.0–1.0). +/// +/// Not recommended for use with user input. Rely on the `TryFrom` impls where possible. +/// +/// # Panics +/// Panics if value is out of range. +/// +/// # Examples +/// ``` +/// # use actix_http::header::{q, Quality}; +/// let q1 = q(1.0); +/// assert_eq!(q1, Quality::MAX); +/// +/// let q2 = q(0.001); +/// assert_eq!(q2, Quality::MIN); +/// +/// let q3 = q(0.0); +/// assert_eq!(q3, Quality::ZERO); +/// +/// let q4 = q(0.42); +/// ``` +/// +/// An out-of-range `f32` quality will panic. +/// ```should_panic +/// # use actix_http::header::q; +/// let _q2 = q(1.42); +/// ``` +#[inline] +pub fn q(quality: T) -> Quality +where + T: TryInto, + T::Error: fmt::Debug, +{ + quality.try_into().expect("quality value was out of bounds") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn q_helper() { + assert_eq!(q(0.5), Quality(500)); + } + + #[test] + fn display_output() { + assert_eq!(Quality::ZERO.to_string(), "0"); + assert_eq!(Quality::MIN.to_string(), "0.001"); + assert_eq!(Quality::MAX.to_string(), "1"); + + assert_eq!(q(0.0).to_string(), "0"); + assert_eq!(q(1.0).to_string(), "1"); + assert_eq!(q(0.001).to_string(), "0.001"); + assert_eq!(q(0.5).to_string(), "0.5"); + assert_eq!(q(0.22).to_string(), "0.22"); + assert_eq!(q(0.123).to_string(), "0.123"); + assert_eq!(q(0.999).to_string(), "0.999"); + + for x in 0..=1000 { + // if trailing zeroes are handled correctly, we would not expect the serialized length + // to ever exceed "0." + 3 decimal places = 5 in length + assert!(q(x as f32 / 1000.0).to_string().len() <= 5); + } + } + + #[test] + #[should_panic] + fn negative_quality() { + q(-1.0); + } + + #[test] + #[should_panic] + fn quality_out_of_bounds() { + q(2.0); + } +} diff --git a/actix-http/src/header/shared/quality_item.rs b/actix-http/src/header/shared/quality_item.rs index 01a3b988a..71a3bdd53 100644 --- a/actix-http/src/header/shared/quality_item.rs +++ b/actix-http/src/header/shared/quality_item.rs @@ -1,101 +1,70 @@ -use std::{ - cmp, - convert::{TryFrom, TryInto}, - fmt, str, -}; +use std::{cmp, convert::TryFrom as _, fmt, str}; -use derive_more::{Display, Error}; +use crate::error::ParseError; -const MAX_QUALITY: u16 = 1000; -const MAX_FLOAT_QUALITY: f32 = 1.0; +use super::Quality; -/// Represents a quality used in quality values. +/// Represents an item with a quality value as defined +/// in [RFC 7231 §5.3.1](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.1). /// -/// Can be created with the [`q`] function. +/// # Parsing and Formatting +/// This wrapper be used to parse header value items that have a q-factor annotation as well as +/// serialize items with a their q-factor. /// -/// # Implementation notes +/// # Ordering +/// Since this context of use for this type is header value items, ordering is defined for +/// `QualityItem`s but _only_ considers the item's quality. Order of appearance should be used as +/// the secondary sorting parameter; i.e., a stable sort over the quality values will produce a +/// correctly sorted sequence. /// -/// The quality value is defined as a number between 0 and 1 with three decimal -/// places. This means there are 1001 possible values. Since floating point -/// numbers are not exact and the smallest floating point data type (`f32`) -/// consumes four bytes, hyper uses an `u16` value to store the -/// quality internally. For performance reasons you may set quality directly to -/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality -/// `q=0.532`. +/// # Examples +/// ``` +/// # use actix_http::header::{QualityItem, q}; +/// let q_item: QualityItem = "hello;q=0.3".parse().unwrap(); +/// assert_eq!(&q_item.item, "hello"); +/// assert_eq!(q_item.quality, q(0.3)); /// -/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1) -/// gives more information on quality values in HTTP header fields. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct Quality(u16); - -impl Quality { - /// # Panics - /// Panics in debug mode when value is not in the range 0.0 <= n <= 1.0. - fn from_f32(value: f32) -> Self { - // Check that `value` is within range should be done before calling this method. - // Just in case, this debug_assert should catch if we were forgetful. - debug_assert!( - (0.0f32..=1.0f32).contains(&value), - "q value must be between 0.0 and 1.0" - ); - - Quality((value * MAX_QUALITY as f32) as u16) - } -} - -impl Default for Quality { - fn default() -> Quality { - Quality(MAX_QUALITY) - } -} - -#[derive(Debug, Clone, Display, Error)] -pub struct QualityOutOfBounds; - -impl TryFrom for Quality { - type Error = QualityOutOfBounds; - - fn try_from(value: u16) -> Result { - if (0..=MAX_QUALITY).contains(&value) { - Ok(Quality(value)) - } else { - Err(QualityOutOfBounds) - } - } -} - -impl TryFrom for Quality { - type Error = QualityOutOfBounds; - - fn try_from(value: f32) -> Result { - if (0.0..=MAX_FLOAT_QUALITY).contains(&value) { - Ok(Quality::from_f32(value)) - } else { - Err(QualityOutOfBounds) - } - } -} - -/// Represents an item with a quality value as defined in -/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-5.3.1). -#[derive(Clone, PartialEq, Debug)] +/// // note that format is normalized compared to parsed item +/// assert_eq!(q_item.to_string(), "hello; q=0.3"); +/// +/// // item with q=0.3 is greater than item with q=0.1 +/// let q_item_fallback: QualityItem = "abc;q=0.1".parse().unwrap(); +/// assert!(q_item > q_item_fallback); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct QualityItem { - /// The actual contents of the field. + /// The wrapped contents of the field. pub item: T, + /// The quality (client or server preference) for the value. pub quality: Quality, } impl QualityItem { - /// Creates a new `QualityItem` from an item and a quality. - /// The item can be of any type. - /// The quality should be a value in the range [0, 1]. - pub fn new(item: T, quality: Quality) -> QualityItem { + /// Constructs a new `QualityItem` from an item and a quality value. + /// + /// The item can be of any type. The quality should be a value in the range [0, 1]. + pub fn new(item: T, quality: Quality) -> Self { QualityItem { item, quality } } + + /// Constructs a new `QualityItem` from an item, using the maximum q-value. + pub fn max(item: T) -> Self { + Self::new(item, Quality::MAX) + } + + /// Constructs a new `QualityItem` from an item, using the minimum, non-zero q-value. + pub fn min(item: T) -> Self { + Self::new(item, Quality::MIN) + } + + /// Constructs a new `QualityItem` from an item, using zero q-value of zero. + pub fn zero(item: T) -> Self { + Self::new(item, Quality::ZERO) + } } -impl cmp::PartialOrd for QualityItem { +impl PartialOrd for QualityItem { fn partial_cmp(&self, other: &QualityItem) -> Option { self.quality.partial_cmp(&other.quality) } @@ -105,109 +74,139 @@ impl fmt::Display for QualityItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.item, f)?; - match self.quality.0 { - MAX_QUALITY => Ok(()), - 0 => f.write_str("; q=0"), - x => write!(f, "; q=0.{}", format!("{:03}", x).trim_end_matches('0')), + match self.quality { + // q-factor value is implied for max value + Quality::MAX => Ok(()), + + // fast path for zero + Quality::ZERO => f.write_str("; q=0"), + + // quality formatting is already using itoa + q => write!(f, "; q={}", q), } } } impl str::FromStr for QualityItem { - type Err = crate::error::ParseError; + type Err = ParseError; - fn from_str(qitem_str: &str) -> Result, crate::error::ParseError> { - if !qitem_str.is_ascii() { - return Err(crate::error::ParseError::Header); + fn from_str(q_item_str: &str) -> Result { + if !q_item_str.is_ascii() { + return Err(ParseError::Header); } - // Set defaults used if parsing fails. - let mut raw_item = qitem_str; - let mut quality = 1f32; + // set defaults used if quality-item parsing fails, i.e., item has no q attribute + let mut raw_item = q_item_str; + let mut quality = Quality::MAX; - let parts: Vec<_> = qitem_str.rsplitn(2, ';').map(str::trim).collect(); + let parts = q_item_str + .rsplit_once(';') + .map(|(item, q_attr)| (item.trim(), q_attr.trim())); - if parts.len() == 2 { + if let Some((val, q_attr)) = parts { // example for item with q-factor: // - // gzip; q=0.65 - // ^^^^^^ parts[0] - // ^^ start - // ^^^^ q_val - // ^^^^ parts[1] + // gzip;q=0.65 + // ^^^^ val + // ^^^^^^ q_attr + // ^^ q + // ^^^^ q_val - if parts[0].len() < 2 { + if q_attr.len() < 2 { // Can't possibly be an attribute since an attribute needs at least a name followed // by an equals sign. And bare identifiers are forbidden. - return Err(crate::error::ParseError::Header); + return Err(ParseError::Header); } - let start = &parts[0][0..2]; + let q = &q_attr[0..2]; - if start == "q=" || start == "Q=" { - let q_val = &parts[0][2..]; + if q == "q=" || q == "Q=" { + let q_val = &q_attr[2..]; if q_val.len() > 5 { // longer than 5 indicates an over-precise q-factor - return Err(crate::error::ParseError::Header); + return Err(ParseError::Header); } - let q_value = q_val - .parse::() - .map_err(|_| crate::error::ParseError::Header)?; + let q_value = q_val.parse::().map_err(|_| ParseError::Header)?; + let q_value = Quality::try_from(q_value).map_err(|_| ParseError::Header)?; - if (0f32..=1f32).contains(&q_value) { - quality = q_value; - raw_item = parts[1]; - } else { - return Err(crate::error::ParseError::Header); - } + quality = q_value; + raw_item = val; } } - let item = raw_item - .parse::() - .map_err(|_| crate::error::ParseError::Header)?; + let item = raw_item.parse::().map_err(|_| ParseError::Header)?; - // we already checked above that the quality is within range - Ok(QualityItem::new(item, Quality::from_f32(quality))) + Ok(QualityItem::new(item, quality)) } } -/// Convenience function to wrap a value in a `QualityItem` -/// Sets `q` to the default 1.0 -pub fn qitem(item: T) -> QualityItem { - QualityItem::new(item, Quality::default()) -} - -/// Convenience function to create a `Quality` from a float or integer. -/// -/// Implemented for `u16` and `f32`. Panics if value is out of range. -pub fn q(val: T) -> Quality -where - T: TryInto, - T::Error: fmt::Debug, -{ - // TODO: on next breaking change, handle unwrap differently - val.try_into().unwrap() -} - #[cfg(test)] mod tests { - use super::super::encoding::*; use super::*; + // copy of encoding from actix-web headers + #[allow(clippy::enum_variant_names)] // allow Encoding prefix on EncodingExt + #[derive(Clone, PartialEq, Debug)] + pub enum Encoding { + Chunked, + Brotli, + Gzip, + Deflate, + Compress, + Identity, + Trailers, + EncodingExt(String), + } + + impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Encoding::*; + f.write_str(match *self { + Chunked => "chunked", + Brotli => "br", + Gzip => "gzip", + Deflate => "deflate", + Compress => "compress", + Identity => "identity", + Trailers => "trailers", + EncodingExt(ref s) => s.as_ref(), + }) + } + } + + impl str::FromStr for Encoding { + type Err = crate::error::ParseError; + fn from_str(s: &str) -> Result { + use Encoding::*; + match s { + "chunked" => Ok(Chunked), + "br" => Ok(Brotli), + "deflate" => Ok(Deflate), + "gzip" => Ok(Gzip), + "compress" => Ok(Compress), + "identity" => Ok(Identity), + "trailers" => Ok(Trailers), + _ => Ok(EncodingExt(s.to_owned())), + } + } + } + #[test] fn test_quality_item_fmt_q_1() { - let x = qitem(Chunked); + use Encoding::*; + let x = QualityItem::max(Chunked); assert_eq!(format!("{}", x), "chunked"); } #[test] fn test_quality_item_fmt_q_0001() { + use Encoding::*; let x = QualityItem::new(Chunked, Quality(1)); assert_eq!(format!("{}", x), "chunked; q=0.001"); } #[test] fn test_quality_item_fmt_q_05() { + use Encoding::*; // Custom value let x = QualityItem { item: EncodingExt("identity".to_owned()), @@ -218,6 +217,7 @@ mod tests { #[test] fn test_quality_item_fmt_q_0() { + use Encoding::*; // Custom value let x = QualityItem { item: EncodingExt("identity".to_owned()), @@ -228,6 +228,7 @@ mod tests { #[test] fn test_quality_item_from_str1() { + use Encoding::*; let x: Result, _> = "chunked".parse(); assert_eq!( x.unwrap(), @@ -237,8 +238,10 @@ mod tests { } ); } + #[test] fn test_quality_item_from_str2() { + use Encoding::*; let x: Result, _> = "chunked; q=1".parse(); assert_eq!( x.unwrap(), @@ -248,8 +251,10 @@ mod tests { } ); } + #[test] fn test_quality_item_from_str3() { + use Encoding::*; let x: Result, _> = "gzip; q=0.5".parse(); assert_eq!( x.unwrap(), @@ -259,8 +264,10 @@ mod tests { } ); } + #[test] fn test_quality_item_from_str4() { + use Encoding::*; let x: Result, _> = "gzip; q=0.273".parse(); assert_eq!( x.unwrap(), @@ -270,39 +277,25 @@ mod tests { } ); } + #[test] fn test_quality_item_from_str5() { let x: Result, _> = "gzip; q=0.2739999".parse(); assert!(x.is_err()); } + #[test] fn test_quality_item_from_str6() { let x: Result, _> = "gzip; q=2".parse(); assert!(x.is_err()); } + #[test] fn test_quality_item_ordering() { let x: QualityItem = "gzip; q=0.5".parse().ok().unwrap(); let y: QualityItem = "gzip; q=0.273".parse().ok().unwrap(); - let comparision_result: bool = x.gt(&y); - assert!(comparision_result) - } - - #[test] - fn test_quality() { - assert_eq!(q(0.5), Quality(500)); - } - - #[test] - #[should_panic] - fn test_quality_invalid() { - q(-1.0); - } - - #[test] - #[should_panic] - fn test_quality_invalid2() { - q(2.0); + let comparison_result: bool = x.gt(&y); + assert!(comparison_result) } #[test] diff --git a/actix-http/src/header/utils.rs b/actix-http/src/header/utils.rs index e232d462f..f4f34d347 100644 --- a/actix-http/src/header/utils.rs +++ b/actix-http/src/header/utils.rs @@ -1,7 +1,8 @@ +//! Header parsing utilities. + use std::{fmt, str::FromStr}; -use http::HeaderValue; - +use super::HeaderValue; use crate::{error::ParseError, header::HTTP_VALUE}; /// Reads a comma-delimited raw header into a Vec. @@ -11,9 +12,12 @@ where I: Iterator + 'a, T: FromStr, { - let mut result = Vec::new(); + let size_guess = all.size_hint().1.unwrap_or(2); + let mut result = Vec::with_capacity(size_guess); + for h in all { let s = h.to_str().map_err(|_| ParseError::Header)?; + result.extend( s.split(',') .filter_map(|x| match x.trim() { @@ -23,6 +27,7 @@ where .filter_map(|x| x.trim().parse().ok()), ) } + Ok(result) } @@ -31,10 +36,12 @@ where pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { if let Some(line) = val { let line = line.to_str().map_err(|_| ParseError::Header)?; + if !line.is_empty() { return T::from_str(line).or(Err(ParseError::Header)); } } + Err(ParseError::Header) } @@ -45,19 +52,53 @@ where T: fmt::Display, { let mut iter = parts.iter(); + if let Some(part) = iter.next() { fmt::Display::fmt(part, f)?; } + for part in iter { f.write_str(", ")?; fmt::Display::fmt(part, f)?; } + Ok(()) } -/// Percent encode a sequence of bytes with a character set defined in -/// +/// Percent encode a sequence of bytes with a character set defined in [RFC 5987 §3.2]. +/// +/// [RFC 5987 §3.2]: https://datatracker.ietf.org/doc/html/rfc5987#section-3.2 +#[inline] pub fn http_percent_encode(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { let encoded = percent_encoding::percent_encode(bytes, HTTP_VALUE); fmt::Display::fmt(&encoded, f) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn comma_delimited_parsing() { + let headers = vec![]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![0; 0]); + + let headers = vec![ + HeaderValue::from_static("1, 2"), + HeaderValue::from_static("3,4"), + ]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![1, 2, 3, 4]); + + let headers = vec![ + HeaderValue::from_static(""), + HeaderValue::from_static(","), + HeaderValue::from_static(" "), + HeaderValue::from_static("1 ,"), + HeaderValue::from_static(""), + ]; + let res: Vec = from_comma_delimited(headers.iter()).unwrap(); + assert_eq!(res, vec![1]); + } +} diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs index 13195f7db..cba94d9b8 100644 --- a/actix-http/src/helpers.rs +++ b/actix-http/src/helpers.rs @@ -1,15 +1,15 @@ use std::io; -use bytes::{BufMut, BytesMut}; +use bytes::BufMut; use http::Version; const DIGITS_START: u8 = b'0'; -pub(crate) fn write_status_line(version: Version, n: u16, bytes: &mut BytesMut) { +pub(crate) fn write_status_line(version: Version, n: u16, buf: &mut B) { match version { - Version::HTTP_11 => bytes.put_slice(b"HTTP/1.1 "), - Version::HTTP_10 => bytes.put_slice(b"HTTP/1.0 "), - Version::HTTP_09 => bytes.put_slice(b"HTTP/0.9 "), + Version::HTTP_11 => buf.put_slice(b"HTTP/1.1 "), + Version::HTTP_10 => buf.put_slice(b"HTTP/1.0 "), + Version::HTTP_09 => buf.put_slice(b"HTTP/0.9 "), _ => { // other HTTP version handlers do not use this method } @@ -19,33 +19,44 @@ pub(crate) fn write_status_line(version: Version, n: u16, bytes: &mut BytesMut) let d10 = ((n / 10) % 10) as u8; let d1 = (n % 10) as u8; - bytes.put_u8(DIGITS_START + d100); - bytes.put_u8(DIGITS_START + d10); - bytes.put_u8(DIGITS_START + d1); + buf.put_u8(DIGITS_START + d100); + buf.put_u8(DIGITS_START + d10); + buf.put_u8(DIGITS_START + d1); // trailing space before reason - bytes.put_u8(b' '); + buf.put_u8(b' '); } -/// NOTE: bytes object has to contain enough space -pub fn write_content_length(n: u64, bytes: &mut BytesMut) { +/// Write out content length header. +/// +/// Buffer must to contain enough space or be implicitly extendable. +pub fn write_content_length(n: u64, buf: &mut B) { if n == 0 { - bytes.put_slice(b"\r\ncontent-length: 0\r\n"); + buf.put_slice(b"\r\ncontent-length: 0\r\n"); return; } - let mut buf = itoa::Buffer::new(); + let mut buffer = itoa::Buffer::new(); - bytes.put_slice(b"\r\ncontent-length: "); - bytes.put_slice(buf.format(n).as_bytes()); - bytes.put_slice(b"\r\n"); + buf.put_slice(b"\r\ncontent-length: "); + buf.put_slice(buffer.format(n).as_bytes()); + buf.put_slice(b"\r\n"); } -pub(crate) struct Writer<'a>(pub &'a mut BytesMut); +/// An `io::Write`r that only requires mutable reference and assumes that there is space available +/// in the buffer for every write operation or that it can be extended implicitly (like +/// `bytes::BytesMut`, for example). +/// +/// This is slightly faster (~10%) than `bytes::buf::Writer` in such cases because it does not +/// perform a remaining length check before writing. +pub(crate) struct MutWriter<'a, B>(pub(crate) &'a mut B); -impl<'a> io::Write for Writer<'a> { +impl<'a, B> io::Write for MutWriter<'a, B> +where + B: BufMut, +{ fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.extend_from_slice(buf); + self.0.put_slice(buf); Ok(buf.len()) } @@ -58,6 +69,8 @@ impl<'a> io::Write for Writer<'a> { mod tests { use std::str::from_utf8; + use bytes::BytesMut; + use super::*; #[test] diff --git a/actix-http/src/http_message.rs b/actix-http/src/http_message.rs index b1f04e50d..ccaa320fa 100644 --- a/actix-http/src/http_message.rs +++ b/actix-http/src/http_message.rs @@ -1,19 +1,18 @@ -use std::cell::{Ref, RefMut}; -use std::str; +use std::{ + cell::{Ref, RefMut}, + str, +}; use encoding_rs::{Encoding, UTF_8}; use http::header; use mime::Mime; -use crate::error::{ContentTypeError, ParseError}; -use crate::extensions::Extensions; -use crate::header::{Header, HeaderMap}; -use crate::payload::Payload; -#[cfg(feature = "cookies")] -use crate::{cookie::Cookie, error::CookieParseError}; - -#[cfg(feature = "cookies")] -struct Cookies(Vec>); +use crate::{ + error::{ContentTypeError, ParseError}, + header::{Header, HeaderMap}, + payload::Payload, + Extensions, +}; /// Trait that implements general purpose operations on HTTP messages. pub trait HttpMessage: Sized { @@ -104,41 +103,6 @@ pub trait HttpMessage: Sized { Ok(false) } } - - /// Load request cookies. - #[cfg(feature = "cookies")] - fn cookies(&self) -> Result>>, CookieParseError> { - if self.extensions().get::().is_none() { - let mut cookies = Vec::new(); - for hdr in self.headers().get_all(header::COOKIE) { - let s = - str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; - for cookie_str in s.split(';').map(|s| s.trim()) { - if !cookie_str.is_empty() { - cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); - } - } - } - self.extensions_mut().insert(Cookies(cookies)); - } - - Ok(Ref::map(self.extensions(), |ext| { - &ext.get::().unwrap().0 - })) - } - - /// Return request cookie. - #[cfg(feature = "cookies")] - fn cookie(&self, name: &str) -> Option> { - if let Ok(cookies) = self.cookies() { - for cookie in cookies.iter() { - if cookie.name() == name { - return Some(cookie.to_owned()); - } - } - } - None - } } impl<'a, T> HttpMessage for &'a mut T diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 574d4ef68..f2b415790 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -1,25 +1,24 @@ //! HTTP primitives for the Actix ecosystem. //! //! ## Crate Features -//! | Feature | Functionality | -//! | ---------------- | ----------------------------------------------------- | -//! | `openssl` | TLS support via [OpenSSL]. | -//! | `rustls` | TLS support via [rustls]. | -//! | `compress` | Payload compression support. (Deflate, Gzip & Brotli) | -//! | `cookies` | Support for cookies backed by the [cookie] crate. | -//! | `secure-cookies` | Adds for secure cookies. Enables `cookies` feature. | -//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | +//! | Feature | Functionality | +//! | ------------------- | ------------------------------------------- | +//! | `openssl` | TLS support via [OpenSSL]. | +//! | `rustls` | TLS support via [rustls]. | +//! | `compress-brotli` | Payload compression support: Brotli. | +//! | `compress-gzip` | Payload compression support: Deflate, Gzip. | +//! | `compress-zstd` | Payload compression support: Zstd. | +//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | //! //! [OpenSSL]: https://crates.io/crates/openssl //! [rustls]: https://crates.io/crates/rustls -//! [cookie]: https://crates.io/crates/cookie //! [trust-dns]: https://crates.io/crates/trust-dns #![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] #![allow( clippy::type_complexity, clippy::too_many_arguments, - clippy::new_without_default, clippy::borrow_interior_mutable_const )] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] @@ -28,68 +27,43 @@ #[macro_use] extern crate log; -#[macro_use] -mod macros; +pub use ::http::{uri, uri::Uri}; +pub use ::http::{Method, StatusCode, Version}; pub mod body; mod builder; -pub mod client; mod config; -#[cfg(feature = "compress")] +#[cfg(feature = "__compress")] pub mod encoding; +pub mod error; mod extensions; -mod header; +pub mod h1; +pub mod h2; +pub mod header; mod helpers; -mod http_codes; mod http_message; mod message; mod payload; -mod request; -mod response; +mod requests; +mod responses; mod service; -mod time_parser; - -pub mod error; -pub mod h1; -pub mod h2; pub mod test; pub mod ws; -#[cfg(feature = "cookies")] -pub use cookie; - pub use self::builder::HttpServiceBuilder; pub use self::config::{KeepAlive, ServiceConfig}; -pub use self::error::{Error, ResponseError, Result}; +pub use self::error::Error; pub use self::extensions::Extensions; +pub use self::header::ContentEncoding; pub use self::http_message::HttpMessage; -pub use self::message::{Message, RequestHead, RequestHeadType, ResponseHead}; -pub use self::payload::{Payload, PayloadStream}; -pub use self::request::Request; -pub use self::response::{Response, ResponseBuilder}; +pub use self::message::ConnectionType; +pub use self::message::Message; +#[allow(deprecated)] +pub use self::payload::{BoxedPayloadStream, Payload, PayloadStream}; +pub use self::requests::{Request, RequestHead, RequestHeadType}; +pub use self::responses::{Response, ResponseBuilder, ResponseHead}; pub use self::service::HttpService; -pub mod http { - //! Various HTTP related types. - - // re-exports - pub use http::header::{HeaderName, HeaderValue}; - pub use http::uri::PathAndQuery; - pub use http::{uri, Error, Uri}; - pub use http::{Method, StatusCode, Version}; - - #[cfg(feature = "cookies")] - pub use crate::cookie::{Cookie, CookieBuilder}; - pub use crate::header::HeaderMap; - - /// A collection of HTTP headers and helpers. - pub mod header { - pub use crate::header::*; - } - pub use crate::header::ContentEncoding; - pub use crate::message::ConnectionType; -} - /// A major HTTP protocol version. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[non_exhaustive] @@ -105,34 +79,18 @@ type ConnectCallback = dyn Fn(&IO, &mut Extensions); /// /// # Implementation Details /// Uses Option to reduce necessary allocations when merging with request extensions. +#[derive(Default)] pub(crate) struct OnConnectData(Option); -impl Default for OnConnectData { - fn default() -> Self { - Self(None) - } -} - impl OnConnectData { /// Construct by calling the on-connect callback with the underlying transport I/O. - pub(crate) fn from_io( - io: &T, - on_connect_ext: Option<&ConnectCallback>, - ) -> Self { + pub(crate) fn from_io(io: &T, on_connect_ext: Option<&ConnectCallback>) -> Self { let ext = on_connect_ext.map(|handler| { - let mut extensions = Extensions::new(); + let mut extensions = Extensions::default(); handler(io, &mut extensions); extensions }); Self(ext) } - - /// Merge self into given request's extensions. - #[inline] - pub(crate) fn merge_into(&mut self, req: &mut Request) { - if let Some(ref mut ext) = self.0 { - req.head.extensions.get_mut().drain_from(ext); - } - } } diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs index 6438ccba0..ecd08fbb3 100644 --- a/actix-http/src/message.rs +++ b/actix-http/src/message.rs @@ -1,13 +1,7 @@ -use std::cell::{Ref, RefCell, RefMut}; -use std::net; -use std::rc::Rc; +use std::{cell::RefCell, ops, rc::Rc}; use bitflags::bitflags; -use crate::extensions::Extensions; -use crate::header::HeaderMap; -use crate::http::{header, Method, StatusCode, Uri, Version}; - /// Represents various types of connection #[derive(Copy, Clone, PartialEq, Debug)] pub enum ConnectionType { @@ -41,331 +35,29 @@ pub trait Head: Default + 'static { F: FnOnce(&MessagePool) -> R; } -#[derive(Debug)] -pub struct RequestHead { - pub uri: Uri, - pub method: Method, - pub version: Version, - pub headers: HeaderMap, - pub extensions: RefCell, - pub peer_addr: Option, - flags: Flags, -} - -impl Default for RequestHead { - fn default() -> RequestHead { - RequestHead { - uri: Uri::default(), - method: Method::default(), - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - flags: Flags::empty(), - peer_addr: None, - extensions: RefCell::new(Extensions::new()), - } - } -} - -impl Head for RequestHead { - fn clear(&mut self) { - self.flags = Flags::empty(); - self.headers.clear(); - self.extensions.get_mut().clear(); - } - - fn with_pool(f: F) -> R - where - F: FnOnce(&MessagePool) -> R, - { - REQUEST_POOL.with(|p| f(p)) - } -} - -impl RequestHead { - /// Message extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - self.extensions.borrow() - } - - /// Mutable reference to a the message's extensions - #[inline] - pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.extensions.borrow_mut() - } - - /// Read the message headers. - pub fn headers(&self) -> &HeaderMap { - &self.headers - } - - /// Mutable reference to the message headers. - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers - } - - /// Is to uppercase headers with Camel-Case. - /// Default is `false` - #[inline] - pub fn camel_case_headers(&self) -> bool { - self.flags.contains(Flags::CAMEL_CASE) - } - - /// Set `true` to send headers which are formatted as Camel-Case. - #[inline] - pub fn set_camel_case_headers(&mut self, val: bool) { - if val { - self.flags.insert(Flags::CAMEL_CASE); - } else { - self.flags.remove(Flags::CAMEL_CASE); - } - } - - #[inline] - /// Set connection type of the message - pub fn set_connection_type(&mut self, ctype: ConnectionType) { - match ctype { - ConnectionType::Close => self.flags.insert(Flags::CLOSE), - ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), - ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), - } - } - - #[inline] - /// Connection type - pub fn connection_type(&self) -> ConnectionType { - if self.flags.contains(Flags::CLOSE) { - ConnectionType::Close - } else if self.flags.contains(Flags::KEEP_ALIVE) { - ConnectionType::KeepAlive - } else if self.flags.contains(Flags::UPGRADE) { - ConnectionType::Upgrade - } else if self.version < Version::HTTP_11 { - ConnectionType::Close - } else { - ConnectionType::KeepAlive - } - } - - /// Connection upgrade status - pub fn upgrade(&self) -> bool { - if let Some(hdr) = self.headers().get(header::CONNECTION) { - if let Ok(s) = hdr.to_str() { - s.to_ascii_lowercase().contains("upgrade") - } else { - false - } - } else { - false - } - } - - #[inline] - /// Get response body chunking state - pub fn chunked(&self) -> bool { - !self.flags.contains(Flags::NO_CHUNKING) - } - - #[inline] - pub fn no_chunking(&mut self, val: bool) { - if val { - self.flags.insert(Flags::NO_CHUNKING); - } else { - self.flags.remove(Flags::NO_CHUNKING); - } - } - - #[inline] - /// Request contains `EXPECT` header - pub fn expect(&self) -> bool { - self.flags.contains(Flags::EXPECT) - } - - #[inline] - pub(crate) fn set_expect(&mut self) { - self.flags.insert(Flags::EXPECT); - } -} - -#[derive(Debug)] -pub enum RequestHeadType { - Owned(RequestHead), - Rc(Rc, Option), -} - -impl RequestHeadType { - pub fn extra_headers(&self) -> Option<&HeaderMap> { - match self { - RequestHeadType::Owned(_) => None, - RequestHeadType::Rc(_, headers) => headers.as_ref(), - } - } -} - -impl AsRef for RequestHeadType { - fn as_ref(&self) -> &RequestHead { - match self { - RequestHeadType::Owned(head) => &head, - RequestHeadType::Rc(head, _) => head.as_ref(), - } - } -} - -impl From for RequestHeadType { - fn from(head: RequestHead) -> Self { - RequestHeadType::Owned(head) - } -} - -#[derive(Debug)] -pub struct ResponseHead { - pub version: Version, - pub status: StatusCode, - pub headers: HeaderMap, - pub reason: Option<&'static str>, - pub(crate) extensions: RefCell, - flags: Flags, -} - -impl ResponseHead { - /// Create new instance of `ResponseHead` type - #[inline] - pub fn new(status: StatusCode) -> ResponseHead { - ResponseHead { - status, - version: Version::default(), - headers: HeaderMap::with_capacity(12), - reason: None, - flags: Flags::empty(), - extensions: RefCell::new(Extensions::new()), - } - } - - /// Message extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - self.extensions.borrow() - } - - /// Mutable reference to a the message's extensions - #[inline] - pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.extensions.borrow_mut() - } - - #[inline] - /// Read the message headers. - pub fn headers(&self) -> &HeaderMap { - &self.headers - } - - #[inline] - /// Mutable reference to the message headers. - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.headers - } - - #[inline] - /// Set connection type of the message - pub fn set_connection_type(&mut self, ctype: ConnectionType) { - match ctype { - ConnectionType::Close => self.flags.insert(Flags::CLOSE), - ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), - ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), - } - } - - #[inline] - pub fn connection_type(&self) -> ConnectionType { - if self.flags.contains(Flags::CLOSE) { - ConnectionType::Close - } else if self.flags.contains(Flags::KEEP_ALIVE) { - ConnectionType::KeepAlive - } else if self.flags.contains(Flags::UPGRADE) { - ConnectionType::Upgrade - } else if self.version < Version::HTTP_11 { - ConnectionType::Close - } else { - ConnectionType::KeepAlive - } - } - - #[inline] - /// Check if keep-alive is enabled - pub fn keep_alive(&self) -> bool { - self.connection_type() == ConnectionType::KeepAlive - } - - #[inline] - /// Check upgrade status of this message - pub fn upgrade(&self) -> bool { - self.connection_type() == ConnectionType::Upgrade - } - - /// Get custom reason for the response - #[inline] - pub fn reason(&self) -> &str { - if let Some(reason) = self.reason { - reason - } else { - self.status - .canonical_reason() - .unwrap_or("") - } - } - - #[inline] - pub(crate) fn ctype(&self) -> Option { - if self.flags.contains(Flags::CLOSE) { - Some(ConnectionType::Close) - } else if self.flags.contains(Flags::KEEP_ALIVE) { - Some(ConnectionType::KeepAlive) - } else if self.flags.contains(Flags::UPGRADE) { - Some(ConnectionType::Upgrade) - } else { - None - } - } - - #[inline] - /// Get response body chunking state - pub fn chunked(&self) -> bool { - !self.flags.contains(Flags::NO_CHUNKING) - } - - #[inline] - /// Set no chunking for payload - pub fn no_chunking(&mut self, val: bool) { - if val { - self.flags.insert(Flags::NO_CHUNKING); - } else { - self.flags.remove(Flags::NO_CHUNKING); - } - } -} - pub struct Message { - // Rc here should not be cloned by anyone. - // It's used to reuse allocation of T and no shared ownership is allowed. + /// Rc here should not be cloned by anyone. + /// It's used to reuse allocation of T and no shared ownership is allowed. head: Rc, } impl Message { /// Get new message from the pool of objects + #[allow(clippy::new_without_default)] pub fn new() -> Self { - T::with_pool(|p| p.get_message()) + T::with_pool(MessagePool::get_message) } } -impl std::ops::Deref for Message { +impl ops::Deref for Message { type Target = T; fn deref(&self) -> &Self::Target { - &self.head.as_ref() + self.head.as_ref() } } -impl std::ops::DerefMut for Message { +impl ops::DerefMut for Message { fn deref_mut(&mut self) -> &mut Self::Target { Rc::get_mut(&mut self.head).expect("Multiple copies exist") } @@ -377,59 +69,12 @@ impl Drop for Message { } } -pub(crate) struct BoxedResponseHead { - head: Option>, -} - -impl BoxedResponseHead { - /// Get new message from the pool of objects - pub fn new(status: StatusCode) -> Self { - RESPONSE_POOL.with(|p| p.get_message(status)) - } - - pub(crate) fn take(&mut self) -> Self { - BoxedResponseHead { - head: self.head.take(), - } - } -} - -impl std::ops::Deref for BoxedResponseHead { - type Target = ResponseHead; - - fn deref(&self) -> &Self::Target { - self.head.as_ref().unwrap() - } -} - -impl std::ops::DerefMut for BoxedResponseHead { - fn deref_mut(&mut self) -> &mut Self::Target { - self.head.as_mut().unwrap() - } -} - -impl Drop for BoxedResponseHead { - fn drop(&mut self) { - if let Some(head) = self.head.take() { - RESPONSE_POOL.with(move |p| p.release(head)) - } - } -} - #[doc(hidden)] /// Request's objects pool pub struct MessagePool(RefCell>>); -#[doc(hidden)] -#[allow(clippy::vec_box)] -/// Request's objects pool -pub struct BoxedResponsePool(RefCell>>); - -thread_local!(static REQUEST_POOL: MessagePool = MessagePool::::create()); -thread_local!(static RESPONSE_POOL: BoxedResponsePool = BoxedResponsePool::create()); - impl MessagePool { - fn create() -> MessagePool { + pub(crate) fn create() -> MessagePool { MessagePool(RefCell::new(Vec::with_capacity(128))) } @@ -451,43 +96,11 @@ impl MessagePool { } #[inline] - /// Release request instance + /// Release message instance fn release(&self, msg: Rc) { - let v = &mut self.0.borrow_mut(); - if v.len() < 128 { - v.push(msg); - } - } -} - -impl BoxedResponsePool { - fn create() -> BoxedResponsePool { - BoxedResponsePool(RefCell::new(Vec::with_capacity(128))) - } - - /// Get message from the pool - #[inline] - fn get_message(&self, status: StatusCode) -> BoxedResponseHead { - if let Some(mut head) = self.0.borrow_mut().pop() { - head.reason = None; - head.status = status; - head.headers.clear(); - head.flags = Flags::empty(); - BoxedResponseHead { head: Some(head) } - } else { - BoxedResponseHead { - head: Some(Box::new(ResponseHead::new(status))), - } - } - } - - #[inline] - /// Release request instance - fn release(&self, mut msg: Box) { - let v = &mut self.0.borrow_mut(); - if v.len() < 128 { - msg.extensions.get_mut().clear(); - v.push(msg); + let pool = &mut self.0.borrow_mut(); + if pool.len() < 128 { + pool.push(msg); } } } diff --git a/actix-http/src/payload.rs b/actix-http/src/payload.rs index 54de6ed93..c9f338c7d 100644 --- a/actix-http/src/payload.rs +++ b/actix-http/src/payload.rs @@ -1,70 +1,89 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + mem, + pin::Pin, + task::{Context, Poll}, +}; use bytes::Bytes; use futures_core::Stream; -use h2::RecvStream; use crate::error::PayloadError; -/// Type represent boxed payload -pub type PayloadStream = Pin>>>; +/// A boxed payload stream. +pub type BoxedPayloadStream = Pin>>>; -/// Type represent streaming payload -pub enum Payload { - None, - H1(crate::h1::Payload), - H2(crate::h2::Payload), - Stream(S), +#[deprecated(since = "4.0.0", note = "Renamed to `BoxedPayloadStream`.")] +pub type PayloadStream = BoxedPayloadStream; + +pin_project_lite::pin_project! { + /// A streaming payload. + #[project = PayloadProj] + pub enum Payload { + None, + H1 { payload: crate::h1::Payload }, + H2 { payload: crate::h2::Payload }, + Stream { #[pin] payload: S }, + } } impl From for Payload { - fn from(v: crate::h1::Payload) -> Self { - Payload::H1(v) + fn from(payload: crate::h1::Payload) -> Self { + Payload::H1 { payload } } } impl From for Payload { - fn from(v: crate::h2::Payload) -> Self { - Payload::H2(v) + fn from(payload: crate::h2::Payload) -> Self { + Payload::H2 { payload } } } -impl From for Payload { - fn from(v: RecvStream) -> Self { - Payload::H2(crate::h2::Payload::new(v)) +impl From for Payload { + fn from(stream: h2::RecvStream) -> Self { + Payload::H2 { + payload: crate::h2::Payload::new(stream), + } } } -impl From for Payload { - fn from(pl: PayloadStream) -> Self { - Payload::Stream(pl) +impl From for Payload { + fn from(payload: BoxedPayloadStream) -> Self { + Payload::Stream { payload } } } impl Payload { /// Takes current payload and replaces it with `None` value pub fn take(&mut self) -> Payload { - std::mem::replace(self, Payload::None) + mem::replace(self, Payload::None) } } impl Stream for Payload where - S: Stream> + Unpin, + S: Stream>, { type Item = Result; #[inline] - fn poll_next( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - Payload::None => Poll::Ready(None), - Payload::H1(ref mut pl) => pl.readany(cx), - Payload::H2(ref mut pl) => Pin::new(pl).poll_next(cx), - Payload::Stream(ref mut pl) => Pin::new(pl).poll_next(cx), + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + PayloadProj::None => Poll::Ready(None), + PayloadProj::H1 { payload } => Pin::new(payload).poll_next(cx), + PayloadProj::H2 { payload } => Pin::new(payload).poll_next(cx), + PayloadProj::Stream { payload } => payload.poll_next(cx), } } } + +#[cfg(test)] +mod tests { + use std::panic::{RefUnwindSafe, UnwindSafe}; + + use static_assertions::{assert_impl_all, assert_not_impl_any}; + + use super::*; + + assert_impl_all!(Payload: Unpin); + assert_not_impl_any!(Payload: Send, Sync, UnwindSafe, RefUnwindSafe); +} diff --git a/actix-http/src/requests/head.rs b/actix-http/src/requests/head.rs new file mode 100644 index 000000000..524075b61 --- /dev/null +++ b/actix-http/src/requests/head.rs @@ -0,0 +1,174 @@ +use std::{net, rc::Rc}; + +use crate::{ + header::{self, HeaderMap}, + message::{Flags, Head, MessagePool}, + ConnectionType, Method, Uri, Version, +}; + +thread_local! { + static REQUEST_POOL: MessagePool = MessagePool::::create() +} + +#[derive(Debug, Clone)] +pub struct RequestHead { + pub method: Method, + pub uri: Uri, + pub version: Version, + pub headers: HeaderMap, + pub peer_addr: Option, + flags: Flags, +} + +impl Default for RequestHead { + fn default() -> RequestHead { + RequestHead { + method: Method::default(), + uri: Uri::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + peer_addr: None, + flags: Flags::empty(), + } + } +} + +impl Head for RequestHead { + fn clear(&mut self) { + self.flags = Flags::empty(); + self.headers.clear(); + } + + fn with_pool(f: F) -> R + where + F: FnOnce(&MessagePool) -> R, + { + REQUEST_POOL.with(|p| f(p)) + } +} + +impl RequestHead { + /// Read the message headers. + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Mutable reference to the message headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// Is to uppercase headers with Camel-Case. + /// Default is `false` + #[inline] + pub fn camel_case_headers(&self) -> bool { + self.flags.contains(Flags::CAMEL_CASE) + } + + /// Set `true` to send headers which are formatted as Camel-Case. + #[inline] + pub fn set_camel_case_headers(&mut self, val: bool) { + if val { + self.flags.insert(Flags::CAMEL_CASE); + } else { + self.flags.remove(Flags::CAMEL_CASE); + } + } + + #[inline] + /// Set connection type of the message + pub fn set_connection_type(&mut self, ctype: ConnectionType) { + match ctype { + ConnectionType::Close => self.flags.insert(Flags::CLOSE), + ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), + ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), + } + } + + #[inline] + /// Connection type + pub fn connection_type(&self) -> ConnectionType { + if self.flags.contains(Flags::CLOSE) { + ConnectionType::Close + } else if self.flags.contains(Flags::KEEP_ALIVE) { + ConnectionType::KeepAlive + } else if self.flags.contains(Flags::UPGRADE) { + ConnectionType::Upgrade + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + /// Connection upgrade status + pub fn upgrade(&self) -> bool { + self.headers() + .get(header::CONNECTION) + .map(|hdr| { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("upgrade") + } else { + false + } + }) + .unwrap_or(false) + } + + #[inline] + /// Get response body chunking state + pub fn chunked(&self) -> bool { + !self.flags.contains(Flags::NO_CHUNKING) + } + + #[inline] + pub fn no_chunking(&mut self, val: bool) { + if val { + self.flags.insert(Flags::NO_CHUNKING); + } else { + self.flags.remove(Flags::NO_CHUNKING); + } + } + + #[inline] + /// Request contains `EXPECT` header + pub fn expect(&self) -> bool { + self.flags.contains(Flags::EXPECT) + } + + #[inline] + pub(crate) fn set_expect(&mut self) { + self.flags.insert(Flags::EXPECT); + } +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum RequestHeadType { + Owned(RequestHead), + Rc(Rc, Option), +} + +impl RequestHeadType { + pub fn extra_headers(&self) -> Option<&HeaderMap> { + match self { + RequestHeadType::Owned(_) => None, + RequestHeadType::Rc(_, headers) => headers.as_ref(), + } + } +} + +impl AsRef for RequestHeadType { + fn as_ref(&self) -> &RequestHead { + match self { + RequestHeadType::Owned(head) => head, + RequestHeadType::Rc(head, _) => head.as_ref(), + } + } +} + +impl From for RequestHeadType { + fn from(head: RequestHead) -> Self { + RequestHeadType::Owned(head) + } +} diff --git a/actix-http/src/requests/mod.rs b/actix-http/src/requests/mod.rs new file mode 100644 index 000000000..fc35da65a --- /dev/null +++ b/actix-http/src/requests/mod.rs @@ -0,0 +1,7 @@ +//! HTTP requests. + +mod head; +mod request; + +pub use self::head::{RequestHead, RequestHeadType}; +pub use self::request::Request; diff --git a/actix-http/src/request.rs b/actix-http/src/requests/request.rs similarity index 70% rename from actix-http/src/request.rs rename to actix-http/src/requests/request.rs index 197ec11c6..4eaaba8e1 100644 --- a/actix-http/src/request.rs +++ b/actix-http/src/requests/request.rs @@ -1,22 +1,25 @@ //! HTTP requests. use std::{ - cell::{Ref, RefMut}, - fmt, net, + cell::{Ref, RefCell, RefMut}, + fmt, mem, net, + rc::Rc, + str, }; use http::{header, Method, Uri, Version}; -use crate::extensions::Extensions; -use crate::header::HeaderMap; -use crate::message::{Message, RequestHead}; -use crate::payload::{Payload, PayloadStream}; -use crate::HttpMessage; +use crate::{ + header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Message, Payload, + RequestHead, +}; -/// Request -pub struct Request

{ +/// An HTTP request. +pub struct Request

{ pub(crate) payload: Payload

, pub(crate) head: Message, + pub(crate) conn_data: Option>, + pub(crate) req_data: RefCell, } impl

HttpMessage for Request

{ @@ -28,37 +31,42 @@ impl

HttpMessage for Request

{ } fn take_payload(&mut self) -> Payload

{ - std::mem::replace(&mut self.payload, Payload::None) + mem::replace(&mut self.payload, Payload::None) } /// Request extensions #[inline] fn extensions(&self) -> Ref<'_, Extensions> { - self.head.extensions() + self.req_data.borrow() } /// Mutable reference to a the request's extensions #[inline] fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.head.extensions_mut() + self.req_data.borrow_mut() } } -impl From> for Request { +impl From> for Request { fn from(head: Message) -> Self { Request { head, payload: Payload::None, + req_data: RefCell::new(Extensions::default()), + conn_data: None, } } } -impl Request { +impl Request { /// Create new Request instance - pub fn new() -> Request { + #[allow(clippy::new_without_default)] + pub fn new() -> Request { Request { head: Message::new(), payload: Payload::None, + req_data: RefCell::new(Extensions::default()), + conn_data: None, } } } @@ -69,16 +77,21 @@ impl

Request

{ Request { payload, head: Message::new(), + req_data: RefCell::new(Extensions::default()), + conn_data: None, } } /// Create new Request instance pub fn replace_payload(self, payload: Payload) -> (Request, Payload

) { let pl = self.payload; + ( Request { payload, head: self.head, + req_data: self.req_data, + conn_data: self.conn_data, }, pl, ) @@ -91,7 +104,7 @@ impl

Request

{ /// Get request's payload pub fn take_payload(&mut self) -> Payload

{ - std::mem::replace(&mut self.payload, Payload::None) + mem::replace(&mut self.payload, Payload::None) } /// Split request into request head and payload @@ -114,7 +127,7 @@ impl

Request

{ /// Mutable reference to the message's headers. pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.head_mut().headers + &mut self.head.headers } /// Request's uri. @@ -126,7 +139,7 @@ impl

Request

{ /// Mutable reference to the request's uri. #[inline] pub fn uri_mut(&mut self) -> &mut Uri { - &mut self.head_mut().uri + &mut self.head.uri } /// Read the Request method. @@ -168,6 +181,31 @@ impl

Request

{ pub fn peer_addr(&self) -> Option { self.head().peer_addr } + + /// Returns a reference a piece of connection data set in an [on-connect] callback. + /// + /// ```ignore + /// let opt_t = req.conn_data::(); + /// ``` + /// + /// [on-connect]: crate::HttpServiceBuilder::on_connect_ext + pub fn conn_data(&self) -> Option<&T> { + self.conn_data + .as_deref() + .and_then(|container| container.get::()) + } + + /// Returns the connection data container if an [on-connect] callback was registered. + /// + /// [on-connect]: crate::HttpServiceBuilder::on_connect_ext + pub fn take_conn_data(&mut self) -> Option> { + self.conn_data.take() + } + + /// Returns the request data container, leaving an empty one in it's place. + pub fn take_req_data(&mut self) -> Extensions { + mem::take(self.req_data.get_mut()) + } } impl

fmt::Debug for Request

{ diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs deleted file mode 100644 index 471dacd28..000000000 --- a/actix-http/src/response.rs +++ /dev/null @@ -1,1165 +0,0 @@ -//! HTTP responses. - -use std::{ - cell::{Ref, RefMut}, - convert::TryInto, - fmt, - future::Future, - ops, - pin::Pin, - str, - task::{Context, Poll}, -}; - -use bytes::{Bytes, BytesMut}; -use futures_core::Stream; -use serde::Serialize; - -use crate::body::{Body, BodyStream, MessageBody, ResponseBody}; -use crate::error::Error; -use crate::extensions::Extensions; -use crate::header::{IntoHeaderPair, IntoHeaderValue}; -use crate::http::header::{self, HeaderName}; -use crate::http::{Error as HttpError, HeaderMap, StatusCode}; -use crate::message::{BoxedResponseHead, ConnectionType, ResponseHead}; -#[cfg(feature = "cookies")] -use crate::{ - cookie::{Cookie, CookieJar}, - http::header::HeaderValue, -}; - -/// An HTTP Response -pub struct Response { - head: BoxedResponseHead, - body: ResponseBody, - error: Option, -} - -impl Response { - /// Create HTTP response builder with specific status. - #[inline] - pub fn build(status: StatusCode) -> ResponseBuilder { - ResponseBuilder::new(status) - } - - /// Create HTTP response builder - #[inline] - pub fn build_from>(source: T) -> ResponseBuilder { - source.into() - } - - /// Constructs a response - #[inline] - pub fn new(status: StatusCode) -> Response { - Response { - head: BoxedResponseHead::new(status), - body: ResponseBody::Body(Body::Empty), - error: None, - } - } - - /// Constructs an error response - #[inline] - pub fn from_error(error: Error) -> Response { - let mut resp = error.as_response_error().error_response(); - if resp.head.status == StatusCode::INTERNAL_SERVER_ERROR { - error!("Internal Server Error: {:?}", error); - } - resp.error = Some(error); - resp - } - - /// Convert response to response with body - pub fn into_body(self) -> Response { - let b = match self.body { - ResponseBody::Body(b) => b, - ResponseBody::Other(b) => b, - }; - Response { - head: self.head, - error: self.error, - body: ResponseBody::Other(b), - } - } -} - -impl Response { - /// Constructs a response with body - #[inline] - pub fn with_body(status: StatusCode, body: B) -> Response { - Response { - head: BoxedResponseHead::new(status), - body: ResponseBody::Body(body), - error: None, - } - } - - #[inline] - /// Http message part of the response - pub fn head(&self) -> &ResponseHead { - &*self.head - } - - #[inline] - /// Mutable reference to a HTTP message part of the response - pub fn head_mut(&mut self) -> &mut ResponseHead { - &mut *self.head - } - - /// The source `error` for this response - #[inline] - pub fn error(&self) -> Option<&Error> { - self.error.as_ref() - } - - /// Get the response status code - #[inline] - pub fn status(&self) -> StatusCode { - self.head.status - } - - /// Set the `StatusCode` for this response - #[inline] - pub fn status_mut(&mut self) -> &mut StatusCode { - &mut self.head.status - } - - /// Get the headers from the response - #[inline] - pub fn headers(&self) -> &HeaderMap { - &self.head.headers - } - - /// Get a mutable reference to the headers - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.head.headers - } - - /// Get an iterator for the cookies set by this response - #[cfg(feature = "cookies")] - #[inline] - pub fn cookies(&self) -> CookieIter<'_> { - CookieIter { - iter: self.head.headers.get_all(header::SET_COOKIE), - } - } - - /// Add a cookie to this response - #[cfg(feature = "cookies")] - #[inline] - pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { - let h = &mut self.head.headers; - HeaderValue::from_str(&cookie.to_string()) - .map(|c| { - h.append(header::SET_COOKIE, c); - }) - .map_err(|e| e.into()) - } - - /// Remove all cookies with the given name from this response. Returns - /// the number of cookies removed. - #[cfg(feature = "cookies")] - #[inline] - pub fn del_cookie(&mut self, name: &str) -> usize { - let h = &mut self.head.headers; - let vals: Vec = h - .get_all(header::SET_COOKIE) - .map(|v| v.to_owned()) - .collect(); - h.remove(header::SET_COOKIE); - - let mut count: usize = 0; - for v in vals { - if let Ok(s) = v.to_str() { - if let Ok(c) = Cookie::parse_encoded(s) { - if c.name() == name { - count += 1; - continue; - } - } - } - h.append(header::SET_COOKIE, v); - } - count - } - - /// Connection upgrade status - #[inline] - pub fn upgrade(&self) -> bool { - self.head.upgrade() - } - - /// Keep-alive status for this connection - pub fn keep_alive(&self) -> bool { - self.head.keep_alive() - } - - /// Responses extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - self.head.extensions.borrow() - } - - /// Mutable reference to a the response's extensions - #[inline] - pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { - self.head.extensions.borrow_mut() - } - - /// Get body of this response - #[inline] - pub fn body(&self) -> &ResponseBody { - &self.body - } - - /// Set a body - pub fn set_body(self, body: B2) -> Response { - Response { - head: self.head, - body: ResponseBody::Body(body), - error: None, - } - } - - /// Split response and body - pub fn into_parts(self) -> (Response<()>, ResponseBody) { - ( - Response { - head: self.head, - body: ResponseBody::Body(()), - error: self.error, - }, - self.body, - ) - } - - /// Drop request's body - pub fn drop_body(self) -> Response<()> { - Response { - head: self.head, - body: ResponseBody::Body(()), - error: None, - } - } - - /// Set a body and return previous body value - pub(crate) fn replace_body(self, body: B2) -> (Response, ResponseBody) { - ( - Response { - head: self.head, - body: ResponseBody::Body(body), - error: self.error, - }, - self.body, - ) - } - - /// Set a body and return previous body value - pub fn map_body(mut self, f: F) -> Response - where - F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, - { - let body = f(&mut self.head, self.body); - - Response { - body, - head: self.head, - error: self.error, - } - } - - /// Extract response body - pub fn take_body(&mut self) -> ResponseBody { - self.body.take_body() - } -} - -impl fmt::Debug for Response { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let res = writeln!( - f, - "\nResponse {:?} {}{}", - self.head.version, - self.head.status, - self.head.reason.unwrap_or(""), - ); - let _ = writeln!(f, " headers:"); - for (key, val) in self.head.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); - } - let _ = writeln!(f, " body: {:?}", self.body.size()); - res - } -} - -impl Future for Response { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(Ok(Response { - head: self.head.take(), - body: self.body.take_body(), - error: self.error.take(), - })) - } -} - -#[cfg(feature = "cookies")] -pub struct CookieIter<'a> { - iter: header::GetAll<'a>, -} - -#[cfg(feature = "cookies")] -impl<'a> Iterator for CookieIter<'a> { - type Item = Cookie<'a>; - - #[inline] - fn next(&mut self) -> Option> { - for v in self.iter.by_ref() { - if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { - return Some(c); - } - } - None - } -} - -/// An HTTP response builder. -/// -/// This type can be used to construct an instance of `Response` through a builder-like pattern. -pub struct ResponseBuilder { - head: Option, - err: Option, - #[cfg(feature = "cookies")] - cookies: Option, -} - -impl ResponseBuilder { - #[inline] - /// Create response builder - pub fn new(status: StatusCode) -> Self { - ResponseBuilder { - head: Some(BoxedResponseHead::new(status)), - err: None, - #[cfg(feature = "cookies")] - cookies: None, - } - } - - /// Set HTTP status code of this response. - #[inline] - pub fn status(&mut self, status: StatusCode) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.status = status; - } - self - } - - /// Insert a header, replacing any that were set with an equivalent field name. - /// - /// ```rust - /// # use actix_http::Response; - /// use actix_http::http::header::ContentType; - /// - /// Response::Ok() - /// .insert_header(ContentType(mime::APPLICATION_JSON)) - /// .insert_header(("X-TEST", "value")) - /// .finish(); - /// ``` - pub fn insert_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match header.try_into_header_pair() { - Ok((key, value)) => { - parts.headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - }; - } - - self - } - - /// Append a header, keeping any that were set with an equivalent field name. - /// - /// ```rust - /// # use actix_http::Response; - /// use actix_http::http::header::ContentType; - /// - /// Response::Ok() - /// .append_header(ContentType(mime::APPLICATION_JSON)) - /// .append_header(("X-TEST", "value1")) - /// .append_header(("X-TEST", "value2")) - /// .finish(); - /// ``` - pub fn append_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match header.try_into_header_pair() { - Ok((key, value)) => parts.headers.append(key, value), - Err(e) => self.err = Some(e.into()), - }; - } - - self - } - - /// Replaced with [`Self::insert_header()`]. - #[deprecated = "Replaced with `insert_header((key, value))`."] - pub fn set_header(&mut self, key: K, value: V) -> &mut Self - where - K: TryInto, - K::Error: Into, - V: IntoHeaderValue, - { - if self.err.is_some() { - return self; - } - - match (key.try_into(), value.try_into_value()) { - (Ok(name), Ok(value)) => return self.insert_header((name, value)), - (Err(err), _) => self.err = Some(err.into()), - (_, Err(err)) => self.err = Some(err.into()), - } - - self - } - - /// Replaced with [`Self::append_header()`]. - #[deprecated = "Replaced with `append_header((key, value))`."] - pub fn header(&mut self, key: K, value: V) -> &mut Self - where - K: TryInto, - K::Error: Into, - V: IntoHeaderValue, - { - if self.err.is_some() { - return self; - } - - match (key.try_into(), value.try_into_value()) { - (Ok(name), Ok(value)) => return self.append_header((name, value)), - (Err(err), _) => self.err = Some(err.into()), - (_, Err(err)) => self.err = Some(err.into()), - } - - self - } - - /// Set the custom reason for the response. - #[inline] - pub fn reason(&mut self, reason: &'static str) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.reason = Some(reason); - } - self - } - - /// Set connection type to KeepAlive - #[inline] - pub fn keep_alive(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::KeepAlive); - } - self - } - - /// Set connection type to Upgrade - #[inline] - pub fn upgrade(&mut self, value: V) -> &mut Self - where - V: IntoHeaderValue, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::Upgrade); - } - - if let Ok(value) = value.try_into_value() { - self.insert_header((header::UPGRADE, value)); - } - - self - } - - /// Force close connection, even if it is marked as keep-alive - #[inline] - pub fn force_close(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::Close); - } - self - } - - /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. - #[inline] - pub fn no_chunking(&mut self, len: u64) -> &mut Self { - self.insert_header((header::CONTENT_LENGTH, len)); - - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.no_chunking(true); - } - self - } - - /// Set response content type. - #[inline] - pub fn content_type(&mut self, value: V) -> &mut Self - where - V: IntoHeaderValue, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match value.try_into_value() { - Ok(value) => { - parts.headers.insert(header::CONTENT_TYPE, value); - } - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Set a cookie - /// - /// ```rust - /// use actix_http::{http, Request, Response}; - /// - /// fn index(req: Request) -> Response { - /// Response::Ok() - /// .cookie( - /// http::Cookie::build("name", "value") - /// .domain("www.rust-lang.org") - /// .path("/") - /// .secure(true) - /// .http_only(true) - /// .finish(), - /// ) - /// .finish() - /// } - /// ``` - #[cfg(feature = "cookies")] - pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { - if self.cookies.is_none() { - let mut jar = CookieJar::new(); - jar.add(cookie.into_owned()); - self.cookies = Some(jar) - } else { - self.cookies.as_mut().unwrap().add(cookie.into_owned()); - } - self - } - - /// Remove cookie - /// - /// ```rust - /// use actix_http::{http, Request, Response, HttpMessage}; - /// - /// fn index(req: Request) -> Response { - /// let mut builder = Response::Ok(); - /// - /// if let Some(ref cookie) = req.cookie("name") { - /// builder.del_cookie(cookie); - /// } - /// - /// builder.finish() - /// } - /// ``` - #[cfg(feature = "cookies")] - pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { - if self.cookies.is_none() { - self.cookies = Some(CookieJar::new()) - } - let jar = self.cookies.as_mut().unwrap(); - let cookie = cookie.clone().into_owned(); - jar.add_original(cookie.clone()); - jar.remove(cookie); - self - } - - /// This method calls provided closure with builder reference if value is `true`. - #[doc(hidden)] - #[deprecated = "Use an if statement."] - pub fn if_true(&mut self, value: bool, f: F) -> &mut Self - where - F: FnOnce(&mut ResponseBuilder), - { - if value { - f(self); - } - self - } - - /// This method calls provided closure with builder reference if value is `Some`. - #[doc(hidden)] - #[deprecated = "Use an if-let construction."] - pub fn if_some(&mut self, value: Option, f: F) -> &mut Self - where - F: FnOnce(T, &mut ResponseBuilder), - { - if let Some(val) = value { - f(val, self); - } - self - } - - /// Responses extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - let head = self.head.as_ref().expect("cannot reuse response builder"); - head.extensions.borrow() - } - - /// Mutable reference to a the response's extensions - #[inline] - pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { - let head = self.head.as_ref().expect("cannot reuse response builder"); - head.extensions.borrow_mut() - } - - #[inline] - /// Set a body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - pub fn body>(&mut self, body: B) -> Response { - self.message_body(body.into()) - } - - /// Set a body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - pub fn message_body(&mut self, body: B) -> Response { - if let Some(e) = self.err.take() { - return Response::from(Error::from(e)).into_body(); - } - - // allow unused mut when cookies feature is disabled - #[allow(unused_mut)] - let mut response = self.head.take().expect("cannot reuse response builder"); - - #[cfg(feature = "cookies")] - if let Some(ref jar) = self.cookies { - for cookie in jar.delta() { - match HeaderValue::from_str(&cookie.to_string()) { - Ok(val) => response.headers.append(header::SET_COOKIE, val), - Err(e) => return Response::from(Error::from(e)).into_body(), - }; - } - } - - Response { - head: response, - body: ResponseBody::Body(body), - error: None, - } - } - - #[inline] - /// Set a streaming body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - pub fn streaming(&mut self, stream: S) -> Response - where - S: Stream> + Unpin + 'static, - E: Into + 'static, - { - self.body(Body::from_message(BodyStream::new(stream))) - } - - /// Set a json body and generate `Response` - /// - /// `ResponseBuilder` can not be used after this call. - pub fn json(&mut self, value: T) -> Response - where - T: ops::Deref, - T::Target: Serialize, - { - match serde_json::to_string(&*value) { - Ok(body) => { - let contains = if let Some(parts) = parts(&mut self.head, &self.err) { - parts.headers.contains_key(header::CONTENT_TYPE) - } else { - true - }; - - if !contains { - self.insert_header(header::ContentType(mime::APPLICATION_JSON)); - } - - self.body(Body::from(body)) - } - Err(e) => Error::from(e).into(), - } - } - - #[inline] - /// Set an empty body and generate `Response` - /// - /// `ResponseBuilder` can not be used after this call. - pub fn finish(&mut self) -> Response { - self.body(Body::Empty) - } - - /// This method construct new `ResponseBuilder` - pub fn take(&mut self) -> ResponseBuilder { - ResponseBuilder { - head: self.head.take(), - err: self.err.take(), - #[cfg(feature = "cookies")] - cookies: self.cookies.take(), - } - } -} - -#[inline] -fn parts<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut ResponseHead> { - if err.is_some() { - return None; - } - parts.as_mut().map(|r| &mut **r) -} - -/// Convert `Response` to a `ResponseBuilder`. Body get dropped. -impl From> for ResponseBuilder { - fn from(res: Response) -> ResponseBuilder { - #[cfg(feature = "cookies")] - let jar = { - // If this response has cookies, load them into a jar - let mut jar: Option = None; - - for c in res.cookies() { - if let Some(ref mut j) = jar { - j.add_original(c.into_owned()); - } else { - let mut j = CookieJar::new(); - j.add_original(c.into_owned()); - jar = Some(j); - } - } - - jar - }; - - ResponseBuilder { - head: Some(res.head), - err: None, - #[cfg(feature = "cookies")] - cookies: jar, - } - } -} - -/// Convert `ResponseHead` to a `ResponseBuilder` -impl<'a> From<&'a ResponseHead> for ResponseBuilder { - fn from(head: &'a ResponseHead) -> ResponseBuilder { - let mut msg = BoxedResponseHead::new(head.status); - msg.version = head.version; - msg.reason = head.reason; - - for (k, v) in head.headers.iter() { - msg.headers.append(k.clone(), v.clone()); - } - - msg.no_chunking(!head.chunked()); - - #[cfg(feature = "cookies")] - let jar = { - // If this response has cookies, load them into a jar - let mut jar: Option = None; - - let cookies = CookieIter { - iter: head.headers.get_all(header::SET_COOKIE), - }; - - for c in cookies { - if let Some(ref mut j) = jar { - j.add_original(c.into_owned()); - } else { - let mut j = CookieJar::new(); - j.add_original(c.into_owned()); - jar = Some(j); - } - } - - jar - }; - - ResponseBuilder { - head: Some(msg), - err: None, - #[cfg(feature = "cookies")] - cookies: jar, - } - } -} - -impl Future for ResponseBuilder { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(Ok(self.finish())) - } -} - -impl fmt::Debug for ResponseBuilder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let head = self.head.as_ref().unwrap(); - - let res = writeln!( - f, - "\nResponseBuilder {:?} {}{}", - head.version, - head.status, - head.reason.unwrap_or(""), - ); - let _ = writeln!(f, " headers:"); - for (key, val) in head.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); - } - res - } -} - -/// Helper converters -impl, E: Into> From> for Response { - fn from(res: Result) -> Self { - match res { - Ok(val) => val.into(), - Err(err) => err.into().into(), - } - } -} - -impl From for Response { - fn from(mut builder: ResponseBuilder) -> Self { - builder.finish() - } -} - -impl From<&'static str> for Response { - fn from(val: &'static str) -> Self { - Response::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) - } -} - -impl From<&'static [u8]> for Response { - fn from(val: &'static [u8]) -> Self { - Response::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) - } -} - -impl From for Response { - fn from(val: String) -> Self { - Response::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) - } -} - -impl<'a> From<&'a String> for Response { - fn from(val: &'a String) -> Self { - Response::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(val) - } -} - -impl From for Response { - fn from(val: Bytes) -> Self { - Response::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) - } -} - -impl From for Response { - fn from(val: BytesMut) -> Self { - Response::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(val) - } -} - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - use crate::body::Body; - use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE, SET_COOKIE}; - use crate::HttpMessage; - - #[test] - fn test_debug() { - let resp = Response::Ok() - .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) - .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) - .finish(); - let dbg = format!("{:?}", resp); - assert!(dbg.contains("Response")); - } - - #[test] - fn test_response_cookies() { - let req = crate::test::TestRequest::default() - .append_header((COOKIE, "cookie1=value1")) - .append_header((COOKIE, "cookie2=value2")) - .finish(); - let cookies = req.cookies().unwrap(); - - let resp = Response::Ok() - .cookie( - crate::http::Cookie::build("name", "value") - .domain("www.rust-lang.org") - .path("/test") - .http_only(true) - .max_age(time::Duration::days(1)) - .finish(), - ) - .del_cookie(&cookies[0]) - .finish(); - - let mut val = resp - .headers() - .get_all(SET_COOKIE) - .map(|v| v.to_str().unwrap().to_owned()) - .collect::>(); - val.sort(); - - // the .del_cookie call - assert!(val[0].starts_with("cookie1=; Max-Age=0;")); - - // the .cookie call - assert_eq!( - val[1], - "name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400" - ); - } - - #[test] - fn test_update_response_cookies() { - let mut r = Response::Ok() - .cookie(crate::http::Cookie::new("original", "val100")) - .finish(); - - r.add_cookie(&crate::http::Cookie::new("cookie2", "val200")) - .unwrap(); - r.add_cookie(&crate::http::Cookie::new("cookie2", "val250")) - .unwrap(); - r.add_cookie(&crate::http::Cookie::new("cookie3", "val300")) - .unwrap(); - - assert_eq!(r.cookies().count(), 4); - r.del_cookie("cookie2"); - - let mut iter = r.cookies(); - let v = iter.next().unwrap(); - assert_eq!((v.name(), v.value()), ("original", "val100")); - let v = iter.next().unwrap(); - assert_eq!((v.name(), v.value()), ("cookie3", "val300")); - } - - #[test] - fn test_basic_builder() { - let resp = Response::Ok().insert_header(("X-TEST", "value")).finish(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_upgrade() { - let resp = Response::build(StatusCode::OK) - .upgrade("websocket") - .finish(); - assert!(resp.upgrade()); - assert_eq!( - resp.headers().get(header::UPGRADE).unwrap(), - HeaderValue::from_static("websocket") - ); - } - - #[test] - fn test_force_close() { - let resp = Response::build(StatusCode::OK).force_close().finish(); - assert!(!resp.keep_alive()) - } - - #[test] - fn test_content_type() { - let resp = Response::build(StatusCode::OK) - .content_type("text/plain") - .body(Body::Empty); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") - } - - #[test] - fn test_json() { - let resp = Response::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); - } - - #[test] - fn test_json_ct() { - let resp = Response::build(StatusCode::OK) - .insert_header((CONTENT_TYPE, "text/json")) - .json(&vec!["v1", "v2", "v3"]); - let ct = resp.headers().get(CONTENT_TYPE).unwrap(); - assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); - } - - #[test] - fn test_serde_json_in_body() { - use serde_json::json; - let resp = - Response::build(StatusCode::OK).body(json!({"test-key":"test-value"})); - assert_eq!(resp.body().get_ref(), br#"{"test-key":"test-value"}"#); - } - - #[test] - fn test_into_response() { - let resp: Response = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let resp: Response = b"test".as_ref().into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let resp: Response = "test".to_owned().into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let resp: Response = (&"test".to_owned()).into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let b = Bytes::from_static(b"test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let b = Bytes::from_static(b"test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - - let b = BytesMut::from("test"); - let resp: Response = b.into(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); - } - - #[test] - fn test_into_builder() { - let mut resp: Response = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); - - resp.add_cookie(&crate::http::Cookie::new("cookie1", "val100")) - .unwrap(); - - let mut builder: ResponseBuilder = resp.into(); - let resp = builder.status(StatusCode::BAD_REQUEST).finish(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - let cookie = resp.cookies().next().unwrap(); - assert_eq!((cookie.name(), cookie.value()), ("cookie1", "val100")); - } - - #[test] - fn response_builder_header_insert_kv() { - let mut res = Response::Ok(); - res.insert_header(("Content-Type", "application/octet-stream")); - let res = res.finish(); - - assert_eq!( - res.headers().get("Content-Type"), - Some(&HeaderValue::from_static("application/octet-stream")) - ); - } - - #[test] - fn response_builder_header_insert_typed() { - let mut res = Response::Ok(); - res.insert_header(header::ContentType(mime::APPLICATION_OCTET_STREAM)); - let res = res.finish(); - - assert_eq!( - res.headers().get("Content-Type"), - Some(&HeaderValue::from_static("application/octet-stream")) - ); - } - - #[test] - fn response_builder_header_append_kv() { - let mut res = Response::Ok(); - res.append_header(("Content-Type", "application/octet-stream")); - res.append_header(("Content-Type", "application/json")); - let res = res.finish(); - - let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); - assert_eq!(headers.len(), 2); - assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); - assert!(headers.contains(&HeaderValue::from_static("application/json"))); - } - - #[test] - fn response_builder_header_append_typed() { - let mut res = Response::Ok(); - res.append_header(header::ContentType(mime::APPLICATION_OCTET_STREAM)); - res.append_header(header::ContentType(mime::APPLICATION_JSON)); - let res = res.finish(); - - let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); - assert_eq!(headers.len(), 2); - assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); - assert!(headers.contains(&HeaderValue::from_static("application/json"))); - } -} diff --git a/actix-http/src/responses/builder.rs b/actix-http/src/responses/builder.rs new file mode 100644 index 000000000..5854863de --- /dev/null +++ b/actix-http/src/responses/builder.rs @@ -0,0 +1,441 @@ +//! HTTP response builder. + +use std::{ + cell::{Ref, RefMut}, + fmt, str, +}; + +use crate::{ + body::{EitherBody, MessageBody}, + error::{Error, HttpError}, + header::{self, TryIntoHeaderPair, TryIntoHeaderValue}, + responses::{BoxedResponseHead, ResponseHead}, + ConnectionType, Extensions, Response, StatusCode, +}; + +/// An HTTP response builder. +/// +/// Used to construct an instance of `Response` using a builder pattern. Response builders are often +/// created using [`Response::build`]. +/// +/// # Examples +/// ``` +/// use actix_http::{Response, ResponseBuilder, StatusCode, body, header}; +/// +/// # actix_rt::System::new().block_on(async { +/// let mut res: Response<_> = Response::build(StatusCode::OK) +/// .content_type(mime::APPLICATION_JSON) +/// .insert_header((header::SERVER, "my-app/1.0")) +/// .append_header((header::SET_COOKIE, "a=1")) +/// .append_header((header::SET_COOKIE, "b=2")) +/// .body("1234"); +/// +/// assert_eq!(res.status(), StatusCode::OK); +/// +/// assert!(res.headers().contains_key("server")); +/// assert_eq!(res.headers().get_all("set-cookie").count(), 2); +/// +/// assert_eq!(body::to_bytes(res.into_body()).await.unwrap(), &b"1234"[..]); +/// # }) +/// ``` +pub struct ResponseBuilder { + head: Option, + err: Option, +} + +impl ResponseBuilder { + /// Create response builder + /// + /// # Examples + /// ``` + /// use actix_http::{Response, ResponseBuilder, StatusCode}; + /// let res: Response<_> = ResponseBuilder::default().finish(); + /// assert_eq!(res.status(), StatusCode::OK); + /// ``` + #[inline] + pub fn new(status: StatusCode) -> Self { + ResponseBuilder { + head: Some(BoxedResponseHead::new(status)), + err: None, + } + } + + /// Set HTTP status code of this response. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, StatusCode}; + /// let res = ResponseBuilder::default().status(StatusCode::NOT_FOUND).finish(); + /// assert_eq!(res.status(), StatusCode::NOT_FOUND); + /// ``` + #[inline] + pub fn status(&mut self, status: StatusCode) -> &mut Self { + if let Some(parts) = self.inner() { + parts.status = status; + } + self + } + + /// Insert a header, replacing any that were set with an equivalent field name. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, header}; + /// + /// let res = ResponseBuilder::default() + /// .insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) + /// .insert_header(("X-TEST", "value")) + /// .finish(); + /// + /// assert!(res.headers().contains_key("content-type")); + /// assert!(res.headers().contains_key("x-test")); + /// ``` + pub fn insert_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + if let Some(parts) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Append a header, keeping any that were set with an equivalent field name. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, header}; + /// + /// let res = ResponseBuilder::default() + /// .append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) + /// .append_header(("X-TEST", "value1")) + /// .append_header(("X-TEST", "value2")) + /// .finish(); + /// + /// assert_eq!(res.headers().get_all("content-type").count(), 1); + /// assert_eq!(res.headers().get_all("x-test").count(), 2); + /// ``` + pub fn append_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + if let Some(parts) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => parts.headers.append(key, value), + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Set the custom reason for the response. + #[inline] + pub fn reason(&mut self, reason: &'static str) -> &mut Self { + if let Some(parts) = self.inner() { + parts.reason = Some(reason); + } + self + } + + /// Set connection type to KeepAlive + #[inline] + pub fn keep_alive(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::KeepAlive); + } + self + } + + /// Set connection type to Upgrade + #[inline] + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: TryIntoHeaderValue, + { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Upgrade); + } + + if let Ok(value) = value.try_into_value() { + self.insert_header((header::UPGRADE, value)); + } + + self + } + + /// Force close connection, even if it is marked as keep-alive + #[inline] + pub fn force_close(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Close); + } + self + } + + /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. + #[inline] + pub fn no_chunking(&mut self, len: u64) -> &mut Self { + let mut buf = itoa::Buffer::new(); + self.insert_header((header::CONTENT_LENGTH, buf.format(len))); + + if let Some(parts) = self.inner() { + parts.no_chunking(true); + } + self + } + + /// Set response content type. + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + V: TryIntoHeaderValue, + { + if let Some(parts) = self.inner() { + match value.try_into_value() { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow_mut() + } + + /// Generate response with a wrapped body. + /// + /// This `ResponseBuilder` will be left in a useless state. + pub fn body(&mut self, body: B) -> Response> + where + B: MessageBody + 'static, + { + match self.message_body(body) { + Ok(res) => res.map_body(|_, body| EitherBody::left(body)), + Err(err) => Response::from(err).map_body(|_, body| EitherBody::right(body)), + } + } + + /// Generate response with a body. + /// + /// This `ResponseBuilder` will be left in a useless state. + pub fn message_body(&mut self, body: B) -> Result, Error> { + if let Some(err) = self.err.take() { + return Err(Error::new_http().with_cause(err)); + } + + let head = self.head.take().expect("cannot reuse response builder"); + Ok(Response { head, body }) + } + + /// Generate response with an empty body. + /// + /// This `ResponseBuilder` will be left in a useless state. + #[inline] + pub fn finish(&mut self) -> Response> { + self.body(()) + } + + /// Create an owned `ResponseBuilder`, leaving the original in a useless state. + pub fn take(&mut self) -> ResponseBuilder { + ResponseBuilder { + head: self.head.take(), + err: self.err.take(), + } + } + + /// Get access to the inner response head if there has been no error. + fn inner(&mut self) -> Option<&mut ResponseHead> { + if self.err.is_some() { + return None; + } + + self.head.as_deref_mut() + } +} + +impl Default for ResponseBuilder { + fn default() -> Self { + Self::new(StatusCode::OK) + } +} + +/// Convert `Response` to a `ResponseBuilder`. Body get dropped. +impl From> for ResponseBuilder { + fn from(res: Response) -> ResponseBuilder { + ResponseBuilder { + head: Some(res.head), + err: None, + } + } +} + +/// Convert `ResponseHead` to a `ResponseBuilder` +impl<'a> From<&'a ResponseHead> for ResponseBuilder { + fn from(head: &'a ResponseHead) -> ResponseBuilder { + let mut msg = BoxedResponseHead::new(head.status); + msg.version = head.version; + msg.reason = head.reason; + + for (k, v) in head.headers.iter() { + msg.headers.append(k.clone(), v.clone()); + } + + msg.no_chunking(!head.chunked()); + + ResponseBuilder { + head: Some(msg), + err: None, + } + } +} + +impl fmt::Debug for ResponseBuilder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let head = self.head.as_ref().unwrap(); + + let res = writeln!( + f, + "\nResponseBuilder {:?} {}{}", + head.version, + head.status, + head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + res + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + use crate::header::{HeaderName, HeaderValue, CONTENT_TYPE}; + + #[test] + fn test_basic_builder() { + let resp = Response::build(StatusCode::OK) + .insert_header(("X-TEST", "value")) + .finish(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_upgrade() { + let resp = Response::build(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); + } + + #[test] + fn test_force_close() { + let resp = Response::build(StatusCode::OK).force_close().finish(); + assert!(!resp.keep_alive()); + } + + #[test] + fn test_content_type() { + let resp = Response::build(StatusCode::OK) + .content_type("text/plain") + .body(Bytes::new()); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain"); + + let resp = Response::build(StatusCode::OK) + .content_type(mime::APPLICATION_JAVASCRIPT_UTF_8) + .body(Bytes::new()); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + "application/javascript; charset=utf-8" + ); + } + + #[test] + fn test_into_builder() { + let mut resp: Response<_> = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + + resp.headers_mut().insert( + HeaderName::from_static("cookie"), + HeaderValue::from_static("cookie1=val100"), + ); + + let mut builder: ResponseBuilder = resp.into(); + let resp = builder.status(StatusCode::BAD_REQUEST).finish(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let cookie = resp.headers().get_all("Cookie").next().unwrap(); + assert_eq!(cookie.to_str().unwrap(), "cookie1=val100"); + } + + #[test] + fn response_builder_header_insert_kv() { + let mut res = Response::build(StatusCode::OK); + res.insert_header(("Content-Type", "application/octet-stream")); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_insert_typed() { + let mut res = Response::build(StatusCode::OK); + res.insert_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_append_kv() { + let mut res = Response::build(StatusCode::OK); + res.append_header(("Content-Type", "application/octet-stream")); + res.append_header(("Content-Type", "application/json")); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } + + #[test] + fn response_builder_header_append_typed() { + let mut res = Response::build(StatusCode::OK); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } +} diff --git a/actix-http/src/responses/head.rs b/actix-http/src/responses/head.rs new file mode 100644 index 000000000..78d9536e5 --- /dev/null +++ b/actix-http/src/responses/head.rs @@ -0,0 +1,208 @@ +//! Response head type and caching pool. + +use std::{ + cell::{Ref, RefCell, RefMut}, + ops, +}; + +use crate::{ + header::HeaderMap, message::Flags, ConnectionType, Extensions, StatusCode, Version, +}; + +thread_local! { + static RESPONSE_POOL: BoxedResponsePool = BoxedResponsePool::create(); +} + +#[derive(Debug)] +pub struct ResponseHead { + pub version: Version, + pub status: StatusCode, + pub headers: HeaderMap, + pub reason: Option<&'static str>, + pub(crate) extensions: RefCell, + flags: Flags, +} + +impl ResponseHead { + /// Create new instance of `ResponseHead` type + #[inline] + pub fn new(status: StatusCode) -> ResponseHead { + ResponseHead { + status, + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(12), + reason: None, + flags: Flags::empty(), + extensions: RefCell::new(Extensions::new()), + } + } + + #[inline] + /// Read the message headers. + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + #[inline] + /// Mutable reference to the message headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// Message extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + self.extensions.borrow() + } + + /// Mutable reference to a the message's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { + self.extensions.borrow_mut() + } + + #[inline] + /// Set connection type of the message + pub fn set_connection_type(&mut self, ctype: ConnectionType) { + match ctype { + ConnectionType::Close => self.flags.insert(Flags::CLOSE), + ConnectionType::KeepAlive => self.flags.insert(Flags::KEEP_ALIVE), + ConnectionType::Upgrade => self.flags.insert(Flags::UPGRADE), + } + } + + #[inline] + pub fn connection_type(&self) -> ConnectionType { + if self.flags.contains(Flags::CLOSE) { + ConnectionType::Close + } else if self.flags.contains(Flags::KEEP_ALIVE) { + ConnectionType::KeepAlive + } else if self.flags.contains(Flags::UPGRADE) { + ConnectionType::Upgrade + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + /// Check if keep-alive is enabled + #[inline] + pub fn keep_alive(&self) -> bool { + self.connection_type() == ConnectionType::KeepAlive + } + + /// Check upgrade status of this message + #[inline] + pub fn upgrade(&self) -> bool { + self.connection_type() == ConnectionType::Upgrade + } + + /// Get custom reason for the response + #[inline] + pub fn reason(&self) -> &str { + self.reason.unwrap_or_else(|| { + self.status + .canonical_reason() + .unwrap_or("") + }) + } + + #[inline] + pub(crate) fn conn_type(&self) -> Option { + if self.flags.contains(Flags::CLOSE) { + Some(ConnectionType::Close) + } else if self.flags.contains(Flags::KEEP_ALIVE) { + Some(ConnectionType::KeepAlive) + } else if self.flags.contains(Flags::UPGRADE) { + Some(ConnectionType::Upgrade) + } else { + None + } + } + + #[inline] + /// Get response body chunking state + pub fn chunked(&self) -> bool { + !self.flags.contains(Flags::NO_CHUNKING) + } + + #[inline] + /// Set no chunking for payload + pub fn no_chunking(&mut self, val: bool) { + if val { + self.flags.insert(Flags::NO_CHUNKING); + } else { + self.flags.remove(Flags::NO_CHUNKING); + } + } +} + +pub(crate) struct BoxedResponseHead { + head: Option>, +} + +impl BoxedResponseHead { + /// Get new message from the pool of objects + pub fn new(status: StatusCode) -> Self { + RESPONSE_POOL.with(|p| p.get_message(status)) + } +} + +impl ops::Deref for BoxedResponseHead { + type Target = ResponseHead; + + fn deref(&self) -> &Self::Target { + self.head.as_ref().unwrap() + } +} + +impl ops::DerefMut for BoxedResponseHead { + fn deref_mut(&mut self) -> &mut Self::Target { + self.head.as_mut().unwrap() + } +} + +impl Drop for BoxedResponseHead { + fn drop(&mut self) { + if let Some(head) = self.head.take() { + RESPONSE_POOL.with(move |p| p.release(head)) + } + } +} + +/// Request's objects pool +#[doc(hidden)] +pub struct BoxedResponsePool(#[allow(clippy::vec_box)] RefCell>>); + +impl BoxedResponsePool { + fn create() -> BoxedResponsePool { + BoxedResponsePool(RefCell::new(Vec::with_capacity(128))) + } + + /// Get message from the pool + #[inline] + fn get_message(&self, status: StatusCode) -> BoxedResponseHead { + if let Some(mut head) = self.0.borrow_mut().pop() { + head.reason = None; + head.status = status; + head.headers.clear(); + head.flags = Flags::empty(); + BoxedResponseHead { head: Some(head) } + } else { + BoxedResponseHead { + head: Some(Box::new(ResponseHead::new(status))), + } + } + } + + /// Release request instance + #[inline] + fn release(&self, mut msg: Box) { + let pool = &mut self.0.borrow_mut(); + if pool.len() < 128 { + msg.extensions.get_mut().clear(); + pool.push(msg); + } + } +} diff --git a/actix-http/src/responses/mod.rs b/actix-http/src/responses/mod.rs new file mode 100644 index 000000000..899232b9f --- /dev/null +++ b/actix-http/src/responses/mod.rs @@ -0,0 +1,11 @@ +//! HTTP response. + +mod builder; +mod head; +#[allow(clippy::module_inception)] +mod response; + +pub use self::builder::ResponseBuilder; +pub(crate) use self::head::BoxedResponseHead; +pub use self::head::ResponseHead; +pub use self::response::Response; diff --git a/actix-http/src/responses/response.rs b/actix-http/src/responses/response.rs new file mode 100644 index 000000000..ec9157afb --- /dev/null +++ b/actix-http/src/responses/response.rs @@ -0,0 +1,404 @@ +//! HTTP response. + +use std::{ + cell::{Ref, RefMut}, + fmt, str, +}; + +use bytes::{Bytes, BytesMut}; +use bytestring::ByteString; + +use crate::{ + body::{BoxBody, MessageBody}, + header::{self, HeaderMap, TryIntoHeaderValue}, + responses::BoxedResponseHead, + Error, Extensions, ResponseBuilder, ResponseHead, StatusCode, +}; + +/// An HTTP response. +pub struct Response { + pub(crate) head: BoxedResponseHead, + pub(crate) body: B, +} + +impl Response { + /// Constructs a new response with default body. + #[inline] + pub fn new(status: StatusCode) -> Self { + Response { + head: BoxedResponseHead::new(status), + body: BoxBody::new(()), + } + } + + /// Constructs a new response builder. + #[inline] + pub fn build(status: StatusCode) -> ResponseBuilder { + ResponseBuilder::new(status) + } + + // just a couple frequently used shortcuts + // this list should not grow larger than a few + + /// Constructs a new response with status 200 OK. + #[inline] + pub fn ok() -> Self { + Response::new(StatusCode::OK) + } + + /// Constructs a new response with status 400 Bad Request. + #[inline] + pub fn bad_request() -> Self { + Response::new(StatusCode::BAD_REQUEST) + } + + /// Constructs a new response with status 404 Not Found. + #[inline] + pub fn not_found() -> Self { + Response::new(StatusCode::NOT_FOUND) + } + + /// Constructs a new response with status 500 Internal Server Error. + #[inline] + pub fn internal_server_error() -> Self { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + + // end shortcuts +} + +impl Response { + /// Constructs a new response with given body. + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Response { + Response { + head: BoxedResponseHead::new(status), + body, + } + } + + /// Returns a reference to the head of this response. + #[inline] + pub fn head(&self) -> &ResponseHead { + &*self.head + } + + /// Returns a mutable reference to the head of this response. + #[inline] + pub fn head_mut(&mut self) -> &mut ResponseHead { + &mut *self.head + } + + /// Returns the status code of this response. + #[inline] + pub fn status(&self) -> StatusCode { + self.head.status + } + + /// Returns a mutable reference the status code of this response. + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + &mut self.head.status + } + + /// Returns a reference to response headers. + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Returns a mutable reference to response headers. + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head.headers + } + + /// Returns true if connection upgrade is enabled. + #[inline] + pub fn upgrade(&self) -> bool { + self.head.upgrade() + } + + /// Returns true if keep-alive is enabled. + pub fn keep_alive(&self) -> bool { + self.head.keep_alive() + } + + /// Returns a reference to the extensions of this response. + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + self.head.extensions.borrow() + } + + /// Returns a mutable reference to the extensions of this response. + #[inline] + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + self.head.extensions.borrow_mut() + } + + /// Returns a reference to the body of this response. + #[inline] + pub fn body(&self) -> &B { + &self.body + } + + /// Sets new body. + pub fn set_body(self, body: B2) -> Response { + Response { + head: self.head, + body, + } + } + + /// Drops body and returns new response. + pub fn drop_body(self) -> Response<()> { + self.set_body(()) + } + + /// Sets new body, returning new response and previous body value. + pub(crate) fn replace_body(self, body: B2) -> (Response, B) { + ( + Response { + head: self.head, + body, + }, + self.body, + ) + } + + /// Returns split head and body. + /// + /// # Implementation Notes + /// Due to internal performance optimizations, the first element of the returned tuple is a + /// `Response` as well but only contains the head of the response this was called on. + pub fn into_parts(self) -> (Response<()>, B) { + self.replace_body(()) + } + + /// Returns new response with mapped body. + pub fn map_body(mut self, f: F) -> Response + where + F: FnOnce(&mut ResponseHead, B) -> B2, + { + let body = f(&mut self.head, self.body); + + Response { + head: self.head, + body, + } + } + + #[inline] + pub fn map_into_boxed_body(self) -> Response + where + B: MessageBody + 'static, + { + self.map_body(|_, body| body.boxed()) + } + + /// Returns body, consuming this response. + pub fn into_body(self) -> B { + self.body + } +} + +impl fmt::Debug for Response +where + B: MessageBody, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let res = writeln!( + f, + "\nResponse {:?} {}{}", + self.head.version, + self.head.status, + self.head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in self.head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + let _ = writeln!(f, " body: {:?}", self.body.size()); + res + } +} + +impl Default for Response { + #[inline] + fn default() -> Response { + Response::with_body(StatusCode::default(), B::default()) + } +} + +impl>, E: Into> From> for Response { + fn from(res: Result) -> Self { + match res { + Ok(val) => val.into(), + Err(err) => Response::from(err.into()), + } + } +} + +impl From for Response { + fn from(mut builder: ResponseBuilder) -> Self { + builder.finish().map_into_boxed_body() + } +} + +impl From for Response { + fn from(val: std::convert::Infallible) -> Self { + match val {} + } +} + +impl From<&'static str> for Response<&'static str> { + fn from(val: &'static str) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From<&'static [u8]> for Response<&'static [u8]> { + fn from(val: &'static [u8]) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From for Response { + fn from(val: String) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From<&String> for Response { + fn from(val: &String) -> Self { + let mut res = Response::with_body(StatusCode::OK, val.clone()); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From for Response { + fn from(val: Bytes) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From for Response { + fn from(val: BytesMut) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::APPLICATION_OCTET_STREAM.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +impl From for Response { + fn from(val: ByteString) -> Self { + let mut res = Response::with_body(StatusCode::OK, val); + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + res + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + body::to_bytes, + header::{HeaderValue, CONTENT_TYPE, COOKIE}, + }; + + #[test] + fn test_debug() { + let resp = Response::build(StatusCode::OK) + .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) + .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) + .finish(); + let dbg = format!("{:?}", resp); + assert!(dbg.contains("Response")); + } + + #[actix_rt::test] + async fn test_into_response() { + let res = Response::from("test"); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let res = Response::from(b"test".as_ref()); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let res = Response::from("test".to_owned()); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let res = Response::from("test".to_owned()); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let b = Bytes::from_static(b"test"); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let b = Bytes::from_static(b"test"); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + + let b = BytesMut::from("test"); + let res = Response::from(b); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(to_bytes(res.into_body()).await.unwrap(), &b"test"[..]); + } +} diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index fee26dcc3..cd2efe678 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,23 +1,28 @@ -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, net, rc::Rc}; +use std::{ + fmt, + future::Future, + marker::PhantomData, + net, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; -use actix_service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; -use bytes::Bytes; -use futures_core::{ready, Future}; -use h2::server::{self, Handshake}; -use pin_project::pin_project; +use actix_service::{ + fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, +}; +use futures_core::{future::LocalBoxFuture, ready}; +use pin_project_lite::pin_project; -use crate::body::MessageBody; -use crate::builder::HttpServiceBuilder; -use crate::config::{KeepAlive, ServiceConfig}; -use crate::error::{DispatchError, Error}; -use crate::request::Request; -use crate::response::Response; -use crate::{h1, h2::Dispatcher, ConnectCallback, OnConnectData, Protocol}; +use crate::{ + body::{BoxBody, MessageBody}, + builder::HttpServiceBuilder, + config::{KeepAlive, ServiceConfig}, + error::DispatchError, + h1, h2, ConnectCallback, OnConnectData, Protocol, Request, Response, +}; /// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. pub struct HttpService { @@ -32,7 +37,7 @@ pub struct HttpService { impl HttpService where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, @@ -47,7 +52,7 @@ where impl HttpService where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, @@ -86,7 +91,7 @@ where impl HttpService where S: ServiceFactory, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, @@ -100,9 +105,8 @@ where pub fn expect(self, expect: X1) -> HttpService where X1: ServiceFactory, - X1::Error: Into, + X1::Error: Into>, X1::InitError: fmt::Debug, - >::Future: 'static, { HttpService { expect, @@ -123,7 +127,6 @@ where U1: ServiceFactory<(Request, Framed), Config = (), Response = ()>, U1::Error: fmt::Display, U1::InitError: fmt::Debug, - )>>::Future: 'static, { HttpService { upgrade, @@ -145,23 +148,23 @@ where impl HttpService where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, - >::Future: 'static, - U: ServiceFactory< - (Request, Framed), - Config = (), - Response = (), - >, - U::Error: fmt::Display + Into, + + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, - )>>::Future: 'static, { /// Create simple tcp stream service pub fn tcp( @@ -173,7 +176,7 @@ where Error = DispatchError, InitError = (), > { - pipeline_factory(|io: TcpStream| async { + fn_service(|io: TcpStream| async { let peer_addr = io.peer_addr().ok(); Ok((io, Protocol::Http1, peer_addr)) }) @@ -183,33 +186,43 @@ where #[cfg(feature = "openssl")] mod openssl { - use super::*; - use actix_service::ServiceFactoryExt; - use actix_tls::accept::openssl::{Acceptor, SslAcceptor, SslError, SslStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + openssl::{ + reexports::{Error as SslError, SslAcceptor}, + Acceptor, TlsStream, + }, + TlsError, + }; - impl HttpService, S, B, X, U> + use super::*; + + impl HttpService, S, B, X, U> where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, - >::Future: 'static, + U: ServiceFactory< - (Request, Framed, h1::Codec>), + (Request, Framed, h1::Codec>), Config = (), Response = (), >, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, - , h1::Codec>)>>::Future: 'static, { - /// Create openssl based service + /// Create OpenSSL based service. pub fn openssl( self, acceptor: SslAcceptor, @@ -220,25 +233,26 @@ mod openssl { Error = TlsError, InitError = (), > { - pipeline_factory( - Acceptor::new(acceptor) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(|io: SslStream| async { - let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { - if protos.windows(2).any(|window| window == b"h2") { - Protocol::Http2 + Acceptor::new(acceptor) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .map(|io: TlsStream| { + let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { + if protos.windows(2).any(|window| window == b"h2") { + Protocol::Http2 + } else { + Protocol::Http1 + } } else { Protocol::Http1 - } - } else { - Protocol::Http1 - }; - let peer_addr = io.get_ref().peer_addr().ok(); - Ok((io, proto, peer_addr)) - }) - .and_then(self.map_err(TlsError::Service)) + }; + + let peer_addr = io.get_ref().peer_addr().ok(); + (io, proto, peer_addr) + }) + .and_then(self.map_err(TlsError::Service)) } } } @@ -247,34 +261,40 @@ mod openssl { mod rustls { use std::io; - use actix_tls::accept::rustls::{Acceptor, ServerConfig, Session, TlsStream}; - use actix_tls::accept::TlsError; + use actix_service::ServiceFactoryExt as _; + use actix_tls::accept::{ + rustls::{reexports::ServerConfig, Acceptor, TlsStream}, + TlsError, + }; use super::*; - use actix_service::ServiceFactoryExt; impl HttpService, S, B, X, U> where S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, - >::Future: 'static, + U: ServiceFactory< (Request, Framed, h1::Codec>), Config = (), Response = (), >, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, - , h1::Codec>)>>::Future: 'static, { - /// Create openssl based service + /// Create Rustls based service. pub fn rustls( self, mut config: ServerConfig, @@ -285,28 +305,29 @@ mod rustls { Error = TlsError, InitError = (), > { - let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; - config.set_protocols(&protos); + let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + protos.extend_from_slice(&config.alpn_protocols); + config.alpn_protocols = protos; - pipeline_factory( - Acceptor::new(config) - .map_err(TlsError::Tls) - .map_init_err(|_| panic!()), - ) - .and_then(|io: TlsStream| async { - let proto = if let Some(protos) = io.get_ref().1.get_alpn_protocol() { - if protos.windows(2).any(|window| window == b"h2") { - Protocol::Http2 + Acceptor::new(config) + .map_init_err(|_| { + unreachable!("TLS acceptor service factory does not error on init") + }) + .map_err(TlsError::into_service_error) + .and_then(|io: TlsStream| async { + let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() { + if protos.windows(2).any(|window| window == b"h2") { + Protocol::Http2 + } else { + Protocol::Http1 + } } else { Protocol::Http1 - } - } else { - Protocol::Http1 - }; - let peer_addr = io.get_ref().0.peer_addr().ok(); - Ok((io, proto, peer_addr)) - }) - .and_then(self.map_err(TlsError::Service)) + }; + let peer_addr = io.get_ref().0.peer_addr().ok(); + Ok((io, proto, peer_addr)) + }) + .and_then(self.map_err(TlsError::Service)) } } } @@ -314,137 +335,124 @@ mod rustls { impl ServiceFactory<(T, Protocol, Option)> for HttpService where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, + S: ServiceFactory, - S::Error: Into + 'static, + S::Future: 'static, + S::Error: Into> + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, + B: MessageBody + 'static, + X: ServiceFactory, - X::Error: Into, + X::Future: 'static, + X::Error: Into>, X::InitError: fmt::Debug, - >::Future: 'static, + U: ServiceFactory<(Request, Framed), Config = (), Response = ()>, - U::Error: fmt::Display + Into, + U::Future: 'static, + U::Error: fmt::Display + Into>, U::InitError: fmt::Debug, - )>>::Future: 'static, { type Response = (); type Error = DispatchError; type Config = (); type Service = HttpServiceHandler; type InitError = (); - type Future = HttpServiceResponse; + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - HttpServiceResponse { - fut: self.srv.new_service(()), - fut_ex: Some(self.expect.new_service(())), - fut_upg: self.upgrade.as_ref().map(|f| f.new_service(())), - expect: None, - upgrade: None, - on_connect_ext: self.on_connect_ext.clone(), - cfg: self.cfg.clone(), - _phantom: PhantomData, - } - } -} + let service = self.srv.new_service(()); + let expect = self.expect.new_service(()); + let upgrade = self.upgrade.as_ref().map(|s| s.new_service(())); + let on_connect_ext = self.on_connect_ext.clone(); + let cfg = self.cfg.clone(); -#[doc(hidden)] -#[pin_project] -pub struct HttpServiceResponse -where - S: ServiceFactory, - X: ServiceFactory, - U: ServiceFactory<(Request, Framed)>, -{ - #[pin] - fut: S::Future, - #[pin] - fut_ex: Option, - #[pin] - fut_upg: Option, - expect: Option, - upgrade: Option, - on_connect_ext: Option>>, - cfg: ServiceConfig, - _phantom: PhantomData, -} + Box::pin(async move { + let expect = expect + .await + .map_err(|e| log::error!("Init http expect service error: {:?}", e))?; -impl Future for HttpServiceResponse -where - T: AsyncRead + AsyncWrite + Unpin, - S: ServiceFactory, - S::Error: Into + 'static, - S::InitError: fmt::Debug, - S::Response: Into> + 'static, - >::Future: 'static, - B: MessageBody + 'static, - X: ServiceFactory, - X::Error: Into, - X::InitError: fmt::Debug, - >::Future: 'static, - U: ServiceFactory<(Request, Framed), Response = ()>, - U::Error: fmt::Display, - U::InitError: fmt::Debug, - )>>::Future: 'static, -{ - type Output = - Result, ()>; + let upgrade = match upgrade { + Some(upgrade) => { + let upgrade = upgrade + .await + .map_err(|e| log::error!("Init http upgrade service error: {:?}", e))?; + Some(upgrade) + } + None => None, + }; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); + let service = service + .await + .map_err(|e| log::error!("Init http service error: {:?}", e))?; - if let Some(fut) = this.fut_ex.as_pin_mut() { - let expect = ready!(fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e)))?; - this = self.as_mut().project(); - *this.expect = Some(expect); - this.fut_ex.set(None); - } - - if let Some(fut) = this.fut_upg.as_pin_mut() { - let upgrade = ready!(fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e)))?; - this = self.as_mut().project(); - *this.upgrade = Some(upgrade); - this.fut_upg.set(None); - } - - let result = ready!(this - .fut - .poll(cx) - .map_err(|e| log::error!("Init http service error: {:?}", e))); - - Poll::Ready(result.map(|service| { - let this = self.as_mut().project(); - HttpServiceHandler::new( - this.cfg.clone(), + Ok(HttpServiceHandler::new( + cfg, service, - this.expect.take().unwrap(), - this.upgrade.take(), - this.on_connect_ext.clone(), - ) - })) + expect, + upgrade, + on_connect_ext, + )) + }) } } -/// `Service` implementation for HTTP transport +/// `Service` implementation for HTTP/1 and HTTP/2 transport pub struct HttpServiceHandler where S: Service, X: Service, U: Service<(Request, Framed)>, { - flow: Rc>, - cfg: ServiceConfig, - on_connect_ext: Option>>, + pub(super) flow: Rc>, + pub(super) cfg: ServiceConfig, + pub(super) on_connect_ext: Option>>, _phantom: PhantomData, } +impl HttpServiceHandler +where + S: Service, + S::Error: Into>, + X: Service, + X::Error: Into>, + U: Service<(Request, Framed)>, + U::Error: Into>, +{ + pub(super) fn new( + cfg: ServiceConfig, + service: S, + expect: X, + upgrade: Option, + on_connect_ext: Option>>, + ) -> HttpServiceHandler { + HttpServiceHandler { + cfg, + on_connect_ext, + flow: HttpFlow::new(service, expect, upgrade), + _phantom: PhantomData, + } + } + + pub(super) fn _poll_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { + ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?; + + ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?; + + if let Some(ref upg) = self.flow.upgrade { + ready!(upg.poll_ready(cx).map_err(Into::into))?; + }; + + Poll::Ready(Ok(())) + } +} + /// A collection of services that describe an HTTP request flow. pub(super) struct HttpFlow { pub(super) service: S, @@ -462,122 +470,64 @@ impl HttpFlow { } } -impl HttpServiceHandler -where - S: Service, - S::Error: Into + 'static, - S::Future: 'static, - S::Response: Into> + 'static, - B: MessageBody + 'static, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - fn new( - cfg: ServiceConfig, - service: S, - expect: X, - upgrade: Option, - on_connect_ext: Option>>, - ) -> HttpServiceHandler { - HttpServiceHandler { - cfg, - on_connect_ext, - flow: HttpFlow::new(service, expect, upgrade), - _phantom: PhantomData, - } - } -} - impl Service<(T, Protocol, Option)> for HttpServiceHandler where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, + B: MessageBody + 'static, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display + Into, + U::Error: fmt::Display + Into>, { type Response = (); type Error = DispatchError; type Future = HttpServiceHandlerResponse; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let ready = self - .flow - .expect - .poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready(); - - let ready = self - .flow - .service - .poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready() - && ready; - - let ready = if let Some(ref upg) = self.flow.upgrade { - upg.poll_ready(cx) - .map_err(|e| { - let e = e.into(); - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(e) - })? - .is_ready() - && ready - } else { - ready - }; - - if ready { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + self._poll_ready(cx).map_err(|err| { + log::error!("HTTP service readiness error: {:?}", err); + DispatchError::Service(err) + }) } fn call( &self, (io, proto, peer_addr): (T, Protocol, Option), ) -> Self::Future { - let on_connect_data = - OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); + let conn_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); match proto { Protocol::Http2 => HttpServiceHandlerResponse { - state: State::H2Handshake(Some(( - server::handshake(io), - self.cfg.clone(), - self.flow.clone(), - on_connect_data, - peer_addr, - ))), + state: State::H2Handshake { + handshake: Some(( + h2::handshake_with_timeout(io, &self.cfg), + self.cfg.clone(), + self.flow.clone(), + conn_data, + peer_addr, + )), + }, }, Protocol::Http1 => HttpServiceHandlerResponse { - state: State::H1(h1::Dispatcher::new( - io, - self.cfg.clone(), - self.flow.clone(), - on_connect_data, - peer_addr, - )), + state: State::H1 { + dispatcher: h1::Dispatcher::new( + io, + self.flow.clone(), + self.cfg.clone(), + peer_addr, + conn_data, + ), + }, }, proto => unimplemented!("Unsupported HTTP version: {:?}.", proto), @@ -585,60 +535,81 @@ where } } -#[pin_project(project = StateProj)] -enum State -where - S: Service, - S::Future: 'static, - S::Error: Into, - T: AsyncRead + AsyncWrite + Unpin, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - H1(#[pin] h1::Dispatcher), - H2(#[pin] Dispatcher), - H2Handshake( - Option<( - Handshake, - ServiceConfig, - Rc>, - OnConnectData, - Option, - )>, - ), +pin_project! { + #[project = StateProj] + enum State + where + T: AsyncRead, + T: AsyncWrite, + T: Unpin, + + S: Service, + S::Future: 'static, + S::Error: Into>, + + B: MessageBody, + + X: Service, + X::Error: Into>, + + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + H1 { #[pin] dispatcher: h1::Dispatcher }, + H2 { #[pin] dispatcher: h2::Dispatcher }, + H2Handshake { + handshake: Option<( + h2::HandshakeWithTimeout, + ServiceConfig, + Rc>, + OnConnectData, + Option, + )>, + }, + } } -#[pin_project] -pub struct HttpServiceHandlerResponse -where - T: AsyncRead + AsyncWrite + Unpin, - S: Service, - S::Error: Into + 'static, - S::Future: 'static, - S::Response: Into> + 'static, - B: MessageBody + 'static, - X: Service, - X::Error: Into, - U: Service<(Request, Framed), Response = ()>, - U::Error: fmt::Display, -{ - #[pin] - state: State, +pin_project! { + pub struct HttpServiceHandlerResponse + where + T: AsyncRead, + T: AsyncWrite, + T: Unpin, + + S: Service, + S::Error: Into>, + S::Error: 'static, + S::Future: 'static, + S::Response: Into>, + S::Response: 'static, + + B: MessageBody, + + X: Service, + X::Error: Into>, + + U: Service<(Request, Framed), Response = ()>, + U::Error: fmt::Display, + { + #[pin] + state: State, + } } impl Future for HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, + S: Service, - S::Error: Into + 'static, + S::Error: Into> + 'static, S::Future: 'static, S::Response: Into> + 'static, - B: MessageBody, + + B: MessageBody + 'static, + X: Service, - X::Error: Into, + X::Error: Into>, + U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { @@ -646,26 +617,23 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project().state.project() { - StateProj::H1(disp) => disp.poll(cx), - StateProj::H2(disp) => disp.poll(cx), - StateProj::H2Handshake(data) => { + StateProj::H1 { dispatcher } => dispatcher.poll(cx), + StateProj::H2 { dispatcher } => dispatcher.poll(cx), + StateProj::H2Handshake { handshake: data } => { match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) { - Ok(conn) => { - let (_, cfg, srv, on_connect_data, peer_addr) = - data.take().unwrap(); - self.as_mut().project().state.set(State::H2(Dispatcher::new( - srv, - conn, - on_connect_data, - cfg, - None, - peer_addr, - ))); + Ok((conn, timer)) => { + let (_, config, flow, conn_data, peer_addr) = data.take().unwrap(); + + self.as_mut().project().state.set(State::H2 { + dispatcher: h2::Dispatcher::new( + conn, flow, config, peer_addr, conn_data, timer, + ), + }); self.poll(cx) } Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } } } diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index 870a656df..1f76498ef 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -13,20 +13,15 @@ use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; use bytes::{Bytes, BytesMut}; use http::{Method, Uri, Version}; -#[cfg(feature = "cookies")] use crate::{ - cookie::{Cookie, CookieJar}, - header::{self, HeaderValue}, -}; -use crate::{ - header::{HeaderMap, IntoHeaderPair}, + header::{HeaderMap, TryIntoHeaderPair}, payload::Payload, Request, }; /// Test `Request` builder /// -/// ```rust,ignore +/// ```ignore /// # use http::{header, StatusCode}; /// # use actix_web::*; /// use actix_web::test::TestRequest; @@ -54,8 +49,6 @@ struct Inner { method: Method, uri: Uri, headers: HeaderMap, - #[cfg(feature = "cookies")] - cookies: CookieJar, payload: Option, } @@ -66,8 +59,6 @@ impl Default for TestRequest { uri: Uri::from_str("/").unwrap(), version: Version::HTTP_11, headers: HeaderMap::new(), - #[cfg(feature = "cookies")] - cookies: CookieJar::new(), payload: None, })) } @@ -101,11 +92,8 @@ impl TestRequest { } /// Insert a header, replacing any that were set with an equivalent field name. - pub fn insert_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - match header.try_into_header_pair() { + pub fn insert_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + match header.try_into_pair() { Ok((key, value)) => { parts(&mut self.0).headers.insert(key, value); } @@ -118,11 +106,8 @@ impl TestRequest { } /// Append a header, keeping any that were set with an equivalent field name. - pub fn append_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - match header.try_into_header_pair() { + pub fn append_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + match header.try_into_pair() { Ok((key, value)) => { parts(&mut self.0).headers.append(key, value); } @@ -134,15 +119,8 @@ impl TestRequest { self } - /// Set cookie for this request. - #[cfg(feature = "cookies")] - pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self { - parts(&mut self.0).cookies.add(cookie.into_owned()); - self - } - /// Set request payload. - pub fn set_payload>(&mut self, data: B) -> &mut Self { + pub fn set_payload(&mut self, data: impl Into) -> &mut Self { let mut payload = crate::h1::Payload::empty(); payload.unread_data(data.into()); parts(&mut self.0).payload = Some(payload.into()); @@ -169,22 +147,6 @@ impl TestRequest { head.version = inner.version; head.headers = inner.headers; - #[cfg(feature = "cookies")] - { - let cookie: String = inner - .cookies - .delta() - // ensure only name=value is written to cookie header - .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string()) - .collect::>() - .join("; "); - - if !cookie.is_empty() { - head.headers - .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap()); - } - } - req } } @@ -302,7 +264,7 @@ impl TestSeqBuffer { /// Create new empty `TestBuffer` instance. pub fn empty() -> Self { - Self::new("") + Self::new(BytesMut::new()) } pub fn read_buf(&self) -> Ref<'_, BytesMut> { diff --git a/actix-http/src/time_parser.rs b/actix-http/src/time_parser.rs deleted file mode 100644 index 46bf73037..000000000 --- a/actix-http/src/time_parser.rs +++ /dev/null @@ -1,72 +0,0 @@ -use time::{Date, OffsetDateTime, PrimitiveDateTime}; - -/// Attempt to parse a `time` string as one of either RFC 1123, RFC 850, or asctime. -pub fn parse_http_date(time: &str) -> Option { - try_parse_rfc_1123(time) - .or_else(|| try_parse_rfc_850(time)) - .or_else(|| try_parse_asctime(time)) -} - -/// Attempt to parse a `time` string as a RFC 1123 formatted date time string. -/// -/// Eg: `Fri, 12 Feb 2021 00:14:29 GMT` -fn try_parse_rfc_1123(time: &str) -> Option { - time::parse(time, "%a, %d %b %Y %H:%M:%S").ok() -} - -/// Attempt to parse a `time` string as a RFC 850 formatted date time string. -/// -/// Eg: `Wednesday, 11-Jan-21 13:37:41 UTC` -fn try_parse_rfc_850(time: &str) -> Option { - let dt = PrimitiveDateTime::parse(time, "%A, %d-%b-%y %H:%M:%S").ok()?; - - // If the `time` string contains a two-digit year, then as per RFC 2616 § 19.3, - // we consider the year as part of this century if it's within the next 50 years, - // otherwise we consider as part of the previous century. - - let now = OffsetDateTime::now_utc(); - let century_start_year = (now.year() / 100) * 100; - let mut expanded_year = century_start_year + dt.year(); - - if expanded_year > now.year() + 50 { - expanded_year -= 100; - } - - let date = Date::try_from_ymd(expanded_year, dt.month(), dt.day()).ok()?; - Some(PrimitiveDateTime::new(date, dt.time())) -} - -/// Attempt to parse a `time` string using ANSI C's `asctime` format. -/// -/// Eg: `Wed Feb 13 15:46:11 2013` -fn try_parse_asctime(time: &str) -> Option { - time::parse(time, "%a %b %_d %H:%M:%S %Y").ok() -} - -#[cfg(test)] -mod tests { - use time::{date, time}; - - use super::*; - - #[test] - fn test_rfc_850_year_shift() { - let date = try_parse_rfc_850("Friday, 19-Nov-82 16:14:55 EST").unwrap(); - assert_eq!(date, date!(1982 - 11 - 19).with_time(time!(16:14:55))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-62 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2062 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-21 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2021 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-23 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2023 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-99 13:37:41 EST").unwrap(); - assert_eq!(date, date!(1999 - 01 - 11).with_time(time!(13:37:41))); - - let date = try_parse_rfc_850("Wednesday, 11-Jan-00 13:37:41 EST").unwrap(); - assert_eq!(date, date!(2000 - 01 - 11).with_time(time!(13:37:41))); - } -} diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 54d850854..f5b755eec 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -63,8 +63,8 @@ pub enum Item { Last(Bytes), } -#[derive(Debug, Copy, Clone)] /// WebSocket protocol codec. +#[derive(Debug, Clone)] pub struct Codec { flags: Flags, max_size: usize, @@ -80,7 +80,7 @@ bitflags! { impl Codec { /// Create new WebSocket frames decoder. - pub fn new() -> Codec { + pub const fn new() -> Codec { Codec { max_size: 65_536, flags: Flags::SERVER, @@ -89,7 +89,8 @@ impl Codec { /// Set max frame size. /// - /// By default max size is set to 64kB. + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Codec, without modifying the original."] pub fn max_size(mut self, size: usize) -> Self { self.max_size = size; self @@ -98,12 +99,19 @@ impl Codec { /// Set decoder to client mode. /// /// By default decoder works in server mode. + #[must_use = "This returns the a new Codec, without modifying the original."] pub fn client_mode(mut self) -> Self { self.flags.remove(Flags::SERVER); self } } +impl Default for Codec { + fn default() -> Self { + Self::new() + } +} + impl Encoder for Codec { type Error = ProtocolError; @@ -216,9 +224,7 @@ impl Decoder for Codec { OpCode::Continue => { if self.flags.contains(Flags::CONTINUATION) { Ok(Some(Frame::Continuation(Item::Continue( - payload - .map(|pl| pl.freeze()) - .unwrap_or_else(Bytes::new), + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -228,9 +234,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstBinary( - payload - .map(|pl| pl.freeze()) - .unwrap_or_else(Bytes::new), + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -240,9 +244,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstText( - payload - .map(|pl| pl.freeze()) - .unwrap_or_else(Bytes::new), + payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) diff --git a/actix-http/src/ws/dispatcher.rs b/actix-http/src/ws/dispatcher.rs index 7be7cf637..f12ae1b1a 100644 --- a/actix-http/src/ws/dispatcher.rs +++ b/actix-http/src/ws/dispatcher.rs @@ -4,18 +4,21 @@ use std::task::{Context, Poll}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; -use actix_utils::dispatcher::{Dispatcher as InnerDispatcher, DispatcherError}; +use pin_project_lite::pin_project; use super::{Codec, Frame, Message}; -#[pin_project::pin_project] -pub struct Dispatcher -where - S: Service + 'static, - T: AsyncRead + AsyncWrite, -{ - #[pin] - inner: InnerDispatcher, +pin_project! { + pub struct Dispatcher + where + S: Service, + S: 'static, + T: AsyncRead, + T: AsyncWrite, + { + #[pin] + inner: inner::Dispatcher, + } } impl Dispatcher @@ -27,13 +30,13 @@ where { pub fn new>(io: T, service: F) -> Self { Dispatcher { - inner: InnerDispatcher::new(Framed::new(io, Codec::new()), service), + inner: inner::Dispatcher::new(Framed::new(io, Codec::new()), service), } } pub fn with>(framed: Framed, service: F) -> Self { Dispatcher { - inner: InnerDispatcher::new(framed, service), + inner: inner::Dispatcher::new(framed, service), } } } @@ -45,9 +48,391 @@ where S::Future: 'static, S::Error: 'static, { - type Output = Result<(), DispatcherError>; + type Output = Result<(), inner::DispatcherError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().inner.poll(cx) } } + +/// Framed dispatcher service and related utilities. +mod inner { + // allow dead code since this mod was ripped from actix-utils + #![allow(dead_code)] + + use core::{ + fmt, + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, + }; + + use actix_service::{IntoService, Service}; + use futures_core::stream::Stream; + use local_channel::mpsc; + use log::debug; + use pin_project_lite::pin_project; + + use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; + + use crate::{body::BoxBody, Response}; + + /// Framed transport errors + pub enum DispatcherError + where + U: Encoder + Decoder, + { + /// Inner service error. + Service(E), + + /// Frame encoding error. + Encoder(>::Error), + + /// Frame decoding error. + Decoder(::Error), + } + + impl From for DispatcherError + where + U: Encoder + Decoder, + { + fn from(err: E) -> Self { + DispatcherError::Service(err) + } + } + + impl fmt::Debug for DispatcherError + where + E: fmt::Debug, + U: Encoder + Decoder, + >::Error: fmt::Debug, + ::Error: fmt::Debug, + { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DispatcherError::Service(ref e) => { + write!(fmt, "DispatcherError::Service({:?})", e) + } + DispatcherError::Encoder(ref e) => { + write!(fmt, "DispatcherError::Encoder({:?})", e) + } + DispatcherError::Decoder(ref e) => { + write!(fmt, "DispatcherError::Decoder({:?})", e) + } + } + } + } + + impl fmt::Display for DispatcherError + where + E: fmt::Display, + U: Encoder + Decoder, + >::Error: fmt::Debug, + ::Error: fmt::Debug, + { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DispatcherError::Service(ref e) => write!(fmt, "{}", e), + DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e), + DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e), + } + } + } + + impl From> for Response + where + E: fmt::Debug + fmt::Display, + U: Encoder + Decoder, + >::Error: fmt::Debug, + ::Error: fmt::Debug, + { + fn from(err: DispatcherError) -> Self { + Response::internal_server_error().set_body(BoxBody::new(err.to_string())) + } + } + + /// Message type wrapper for signalling end of message stream. + pub enum Message { + /// Message item. + Item(T), + + /// Signal from service to flush all messages and stop processing. + Close, + } + + pin_project! { + /// A future that reads frames from a [`Framed`] object and passes them to a [`Service`]. + pub struct Dispatcher + where + S: Service<::Item, Response = I>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead, + T: AsyncWrite, + U: Encoder, + U: Decoder, + I: 'static, + >::Error: fmt::Debug, + { + service: S, + state: State, + #[pin] + framed: Framed, + rx: mpsc::Receiver, S::Error>>, + tx: mpsc::Sender, S::Error>>, + } + } + + enum State + where + S: Service<::Item>, + U: Encoder + Decoder, + { + Processing, + Error(DispatcherError), + FramedError(DispatcherError), + FlushAndStop, + Stopping, + } + + impl State + where + S: Service<::Item>, + U: Encoder + Decoder, + { + fn take_error(&mut self) -> DispatcherError { + match mem::replace(self, State::Processing) { + State::Error(err) => err, + _ => panic!(), + } + } + + fn take_framed_error(&mut self) -> DispatcherError { + match mem::replace(self, State::Processing) { + State::FramedError(err) => err, + _ => panic!(), + } + } + } + + impl Dispatcher + where + S: Service<::Item, Response = I>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + I: 'static, + ::Error: fmt::Debug, + >::Error: fmt::Debug, + { + /// Create new `Dispatcher`. + pub fn new(framed: Framed, service: F) -> Self + where + F: IntoService::Item>, + { + let (tx, rx) = mpsc::channel(); + Dispatcher { + framed, + rx, + tx, + service: service.into_service(), + state: State::Processing, + } + } + + /// Construct new `Dispatcher` instance with customer `mpsc::Receiver` + pub fn with_rx( + framed: Framed, + service: F, + rx: mpsc::Receiver, S::Error>>, + ) -> Self + where + F: IntoService::Item>, + { + let tx = rx.sender(); + Dispatcher { + framed, + rx, + tx, + service: service.into_service(), + state: State::Processing, + } + } + + /// Get sender handle. + pub fn tx(&self) -> mpsc::Sender, S::Error>> { + self.tx.clone() + } + + /// Get reference to a service wrapped by `Dispatcher` instance. + pub fn service(&self) -> &S { + &self.service + } + + /// Get mutable reference to a service wrapped by `Dispatcher` instance. + pub fn service_mut(&mut self) -> &mut S { + &mut self.service + } + + /// Get reference to a framed instance wrapped by `Dispatcher` instance. + pub fn framed(&self) -> &Framed { + &self.framed + } + + /// Get mutable reference to a framed instance wrapped by `Dispatcher` instance. + pub fn framed_mut(&mut self) -> &mut Framed { + &mut self.framed + } + + /// Read from framed object. + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool + where + S: Service<::Item, Response = I>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + I: 'static, + >::Error: fmt::Debug, + { + loop { + let this = self.as_mut().project(); + match this.service.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + let item = match this.framed.next_item(cx) { + Poll::Ready(Some(Ok(el))) => el, + Poll::Ready(Some(Err(err))) => { + *this.state = State::FramedError(DispatcherError::Decoder(err)); + return true; + } + Poll::Pending => return false, + Poll::Ready(None) => { + *this.state = State::Stopping; + return true; + } + }; + + let tx = this.tx.clone(); + let fut = this.service.call(item); + actix_rt::spawn(async move { + let item = fut.await; + let _ = tx.send(item.map(Message::Item)); + }); + } + Poll::Pending => return false, + Poll::Ready(Err(err)) => { + *this.state = State::Error(DispatcherError::Service(err)); + return true; + } + } + } + } + + /// Write to framed object. + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool + where + S: Service<::Item, Response = I>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + I: 'static, + >::Error: fmt::Debug, + { + loop { + let mut this = self.as_mut().project(); + while !this.framed.is_write_buf_full() { + match Pin::new(&mut this.rx).poll_next(cx) { + Poll::Ready(Some(Ok(Message::Item(msg)))) => { + if let Err(err) = this.framed.as_mut().write(msg) { + *this.state = State::FramedError(DispatcherError::Encoder(err)); + return true; + } + } + Poll::Ready(Some(Ok(Message::Close))) => { + *this.state = State::FlushAndStop; + return true; + } + Poll::Ready(Some(Err(err))) => { + *this.state = State::Error(DispatcherError::Service(err)); + return true; + } + Poll::Ready(None) | Poll::Pending => break, + } + } + + if !this.framed.is_write_buf_empty() { + match this.framed.flush(cx) { + Poll::Pending => break, + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(err)) => { + debug!("Error sending data: {:?}", err); + *this.state = State::FramedError(DispatcherError::Encoder(err)); + return true; + } + } + } else { + break; + } + } + + false + } + } + + impl Future for Dispatcher + where + S: Service<::Item, Response = I>, + S::Error: 'static, + S::Future: 'static, + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, + I: 'static, + >::Error: fmt::Debug, + ::Error: fmt::Debug, + { + type Output = Result<(), DispatcherError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let this = self.as_mut().project(); + + return match this.state { + State::Processing => { + if self.as_mut().poll_read(cx) || self.as_mut().poll_write(cx) { + continue; + } else { + Poll::Pending + } + } + State::Error(_) => { + // flush write buffer + if !this.framed.is_write_buf_empty() + && this.framed.flush(cx).is_pending() + { + return Poll::Pending; + } + Poll::Ready(Err(this.state.take_error())) + } + State::FlushAndStop => { + if !this.framed.is_write_buf_empty() { + this.framed.flush(cx).map(|res| { + if let Err(err) = res { + debug!("Error sending data: {:?}", err); + } + + Ok(()) + }) + } else { + Poll::Ready(Ok(())) + } + } + State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())), + State::Stopping => Poll::Ready(Ok(())), + }; + } + } + } +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 46edf5d85..b58ef7362 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -16,8 +16,7 @@ impl Parser { src: &[u8], server: bool, max_size: usize, - ) -> Result)>, ProtocolError> - { + ) -> Result)>, ProtocolError> { let chunk_len = src.len(); let mut idx = 2; @@ -228,15 +227,11 @@ mod tests { payload: Bytes, } - fn is_none( - frm: &Result)>, ProtocolError>, - ) -> bool { + fn is_none(frm: &Result)>, ProtocolError>) -> bool { matches!(*frm, Ok(None)) } - fn extract( - frm: Result)>, ProtocolError>, - ) -> F { + fn extract(frm: Result)>, ProtocolError>) -> F { match frm { Ok(Some((finished, opcode, payload))) => F { finished, diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs index 276ca4a85..20b4372a0 100644 --- a/actix-http/src/ws/mask.rs +++ b/actix-http/src/ws/mask.rs @@ -25,8 +25,8 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { // // un aligned prefix and suffix would be mask/unmask per byte. // proper aligned middle slice goes into fast path and operates on 4-byte blocks. - let (mut prefix, words, mut suffix) = unsafe { buf.align_to_mut::() }; - apply_mask_fallback(&mut prefix, mask); + let (prefix, words, suffix) = unsafe { buf.align_to_mut::() }; + apply_mask_fallback(prefix, mask); let head = prefix.len() & 3; let mask_u32 = if head > 0 { if cfg!(target_endian = "big") { @@ -40,7 +40,7 @@ pub fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) { for word in words.iter_mut() { *word ^= mask_u32; } - apply_mask_fallback(&mut suffix, mask_u32.to_ne_bytes()); + apply_mask_fallback(suffix, mask_u32.to_ne_bytes()); } #[cfg(test)] @@ -54,8 +54,8 @@ mod tests { let mask = [0x6d, 0xb6, 0xb2, 0x80]; let unmasked = vec![ - 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, - 0x74, 0xf9, 0x12, 0x03, + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, + 0x12, 0x03, ]; // Check masking with proper alignment. @@ -85,8 +85,8 @@ mod tests { fn test_apply_mask() { let mask = [0x6d, 0xb6, 0xb2, 0x80]; let unmasked = vec![ - 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, - 0x74, 0xf9, 0x12, 0x03, + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, + 0x12, 0x03, ]; for data_len in 0..=unmasked.len() { diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 0490163d5..568d801a2 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -1,4 +1,4 @@ -//! WebSocket protocol. +//! WebSocket protocol implementation. //! //! To setup a WebSocket, first perform the WebSocket handshake then on success convert `Payload` into a //! `WsStream` stream and then use `WsWriter` to communicate with the peer. @@ -8,9 +8,8 @@ use std::io; use derive_more::{Display, Error, From}; use http::{header, Method, StatusCode}; -use crate::error::ResponseError; -use crate::message::RequestHead; -use crate::response::{Response, ResponseBuilder}; +use crate::body::BoxBody; +use crate::{header::HeaderValue, RequestHead, Response, ResponseBuilder}; mod codec; mod dispatcher; @@ -24,7 +23,7 @@ pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; /// WebSocket protocol errors. -#[derive(Debug, Display, From, Error)] +#[derive(Debug, Display, Error, From)] pub enum ProtocolError { /// Received an unmasked frame from client. #[display(fmt = "Received an unmasked frame from client.")] @@ -67,10 +66,8 @@ pub enum ProtocolError { Io(io::Error), } -impl ResponseError for ProtocolError {} - /// WebSocket handshake errors -#[derive(PartialEq, Debug, Display)] +#[derive(Debug, Clone, Copy, PartialEq, Display, Error)] pub enum HandshakeError { /// Only get method is allowed. #[display(fmt = "Method not allowed.")] @@ -89,7 +86,7 @@ pub enum HandshakeError { NoVersionHeader, /// Unsupported WebSocket version. - #[display(fmt = "Unsupported version.")] + #[display(fmt = "Unsupported WebSocket version.")] UnsupportedVersion, /// WebSocket key is not set or wrong. @@ -97,36 +94,56 @@ pub enum HandshakeError { BadWebsocketKey, } -impl ResponseError for HandshakeError { - fn error_response(&self) -> Response { - match self { - HandshakeError::GetMethodRequired => Response::MethodNotAllowed() - .insert_header((header::ALLOW, "GET")) - .finish(), +impl From for Response { + fn from(err: HandshakeError) -> Self { + match err { + HandshakeError::GetMethodRequired => { + let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED); + #[allow(clippy::declare_interior_mutable_const)] + const HV_GET: HeaderValue = HeaderValue::from_static("GET"); + res.headers_mut().insert(header::ALLOW, HV_GET); + res + } - HandshakeError::NoWebsocketUpgrade => Response::BadRequest() - .reason("No WebSocket UPGRADE header found") - .finish(), + HandshakeError::NoWebsocketUpgrade => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("No WebSocket Upgrade header found"); + res + } - HandshakeError::NoConnectionUpgrade => Response::BadRequest() - .reason("No CONNECTION upgrade") - .finish(), + HandshakeError::NoConnectionUpgrade => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("No Connection upgrade"); + res + } - HandshakeError::NoVersionHeader => Response::BadRequest() - .reason("Websocket version header is required") - .finish(), + HandshakeError::NoVersionHeader => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("WebSocket version header is required"); + res + } - HandshakeError::UnsupportedVersion => Response::BadRequest() - .reason("Unsupported version") - .finish(), + HandshakeError::UnsupportedVersion => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("Unsupported WebSocket version"); + res + } HandshakeError::BadWebsocketKey => { - Response::BadRequest().reason("Handshake error").finish() + let mut res = Response::bad_request(); + res.head_mut().reason = Some("Handshake error"); + res } } } } +impl From<&HandshakeError> for Response { + fn from(err: &HandshakeError) -> Self { + (*err).into() + } +} + /// Verify WebSocket handshake request and create handshake response. pub fn handshake(req: &RequestHead) -> Result { verify_handshake(req)?; @@ -192,16 +209,20 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder { Response::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") - .insert_header((header::TRANSFER_ENCODING, "chunked")) - .insert_header((header::SEC_WEBSOCKET_ACCEPT, key)) + .insert_header(( + header::SEC_WEBSOCKET_ACCEPT, + // key is known to be header value safe ascii + HeaderValue::from_bytes(&key).unwrap(), + )) .take() } #[cfg(test)] mod tests { + use crate::{header, Method}; + use super::*; use crate::test::TestRequest; - use http::{header, Method}; #[test] fn test_handshake() { @@ -314,18 +335,18 @@ mod tests { } #[test] - fn test_wserror_http_response() { - let resp: Response = HandshakeError::GetMethodRequired.error_response(); + fn test_ws_error_http_response() { + let resp: Response = HandshakeError::GetMethodRequired.into(); assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let resp: Response = HandshakeError::NoWebsocketUpgrade.error_response(); + let resp: Response = HandshakeError::NoWebsocketUpgrade.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::NoConnectionUpgrade.error_response(); + let resp: Response = HandshakeError::NoConnectionUpgrade.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::NoVersionHeader.error_response(); + let resp: Response = HandshakeError::NoVersionHeader.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::UnsupportedVersion.error_response(); + let resp: Response = HandshakeError::UnsupportedVersion.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: Response = HandshakeError::BadWebsocketKey.error_response(); + let resp: Response = HandshakeError::BadWebsocketKey.into(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index 1e8bf7af3..4227f221d 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -1,7 +1,11 @@ -use std::convert::{From, Into}; -use std::fmt; +use std::{ + convert::{From, Into}, + fmt, +}; -/// Operation codes as part of RFC6455. +/// Operation codes defined in [RFC 6455 §11.8]. +/// +/// [RFC 6455]: https://datatracker.ietf.org/doc/html/rfc6455#section-11.8 #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum OpCode { /// Indicates a continuation frame of a fragmented message. @@ -28,8 +32,9 @@ pub enum OpCode { impl fmt::Display for OpCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use self::OpCode::*; - match *self { + use OpCode::*; + + match self { Continue => write!(f, "CONTINUE"), Text => write!(f, "TEXT"), Binary => write!(f, "BINARY"), @@ -44,6 +49,7 @@ impl fmt::Display for OpCode { impl From for u8 { fn from(op: OpCode) -> u8 { use self::OpCode::*; + match op { Continue => 0, Text => 1, @@ -62,6 +68,7 @@ impl From for u8 { impl From for OpCode { fn from(byte: u8) -> OpCode { use self::OpCode::*; + match byte { 0 => Continue, 1 => Text, @@ -77,63 +84,66 @@ impl From for OpCode { /// Status code used to indicate why an endpoint is closing the WebSocket connection. #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum CloseCode { - /// Indicates a normal closure, meaning that the purpose for - /// which the connection was established has been fulfilled. + /// Indicates a normal closure, meaning that the purpose for which the connection was + /// established has been fulfilled. Normal, - /// Indicates that an endpoint is "going away", such as a server - /// going down or a browser having navigated away from a page. + + /// Indicates that an endpoint is "going away", such as a server going down or a browser having + /// navigated away from a page. Away, - /// Indicates that an endpoint is terminating the connection due - /// to a protocol error. + + /// Indicates that an endpoint is terminating the connection due to a protocol error. Protocol, - /// Indicates that an endpoint is terminating the connection - /// because it has received a type of data it cannot accept (e.g., an - /// endpoint that understands only text data MAY send this if it + + /// Indicates that an endpoint is terminating the connection because it has received a type of + /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if it /// receives a binary message). Unsupported, - /// Indicates an abnormal closure. If the abnormal closure was due to an - /// error, this close code will not be used. Instead, the `on_error` method - /// of the handler will be called with the error. However, if the connection - /// is simply dropped, without an error, this close code will be sent to the - /// handler. + + /// Indicates an abnormal closure. If the abnormal closure was due to an error, this close code + /// will not be used. Instead, the `on_error` method of the handler will be called with + /// the error. However, if the connection is simply dropped, without an error, this close code + /// will be sent to the handler. Abnormal, - /// Indicates that an endpoint is terminating the connection - /// because it has received data within a message that was not - /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] + + /// Indicates that an endpoint is terminating the connection because it has received data within + /// a message that was not consistent with the type of the message (e.g., non-UTF-8 \[RFC 3629\] /// data within a text message). Invalid, - /// Indicates that an endpoint is terminating the connection - /// because it has received a message that violates its policy. This - /// is a generic status code that can be returned when there is no - /// other more suitable status code (e.g., Unsupported or Size) or if there - /// is a need to hide specific details about the policy. + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that violates its policy. This is a generic status code that can be returned when there is + /// no other more suitable status code (e.g., Unsupported or Size) or if there is a need to hide + /// specific details about the policy. Policy, - /// Indicates that an endpoint is terminating the connection - /// because it has received a message that is too big for it to - /// process. + + /// Indicates that an endpoint is terminating the connection because it has received a message + /// that is too big for it to process. Size, - /// Indicates that an endpoint (client) is terminating the - /// connection because it has expected the server to negotiate one or - /// more extension, but the server didn't return them in the response - /// message of the WebSocket handshake. The list of extensions that - /// are needed should be given as the reason for closing. - /// Note that this status code is not used by the server, because it - /// can fail the WebSocket handshake instead. + + /// Indicates that an endpoint (client) is terminating the connection because it has expected + /// the server to negotiate one or more extension, but the server didn't return them in the + /// response message of the WebSocket handshake. The list of extensions that are needed should + /// be given as the reason for closing. Note that this status code is not used by the server, + /// because it can fail the WebSocket handshake instead. Extension, - /// Indicates that a server is terminating the connection because - /// it encountered an unexpected condition that prevented it from - /// fulfilling the request. + + /// Indicates that a server is terminating the connection because it encountered an unexpected + /// condition that prevented it from fulfilling the request. Error, - /// Indicates that the server is restarting. A client may choose to - /// reconnect, and if it does, it should use a randomized delay of 5-30 - /// seconds between attempts. + + /// Indicates that the server is restarting. A client may choose to reconnect, and if it does, + /// it should use a randomized delay of 5-30 seconds between attempts. Restart, - /// Indicates that the server is overloaded and the client should either - /// connect to a different IP (when multiple targets exist), or - /// reconnect to the same IP when a user has performed an action. + + /// Indicates that the server is overloaded and the client should either connect to a different + /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed + /// an action. Again, + #[doc(hidden)] Tls, + #[doc(hidden)] Other(u16), } @@ -141,6 +151,7 @@ pub enum CloseCode { impl From for u16 { fn from(code: CloseCode) -> u16 { use self::CloseCode::*; + match code { Normal => 1000, Away => 1001, @@ -163,6 +174,7 @@ impl From for u16 { impl From for CloseCode { fn from(code: u16) -> CloseCode { use self::CloseCode::*; + match code { 1000 => Normal, 1001 => Away, @@ -210,17 +222,30 @@ impl> From<(CloseCode, T)> for CloseReason { } } -static WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +/// The WebSocket GUID as stated in the spec. +/// See . +static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; -// TODO: hash is always same size, we don't need String -pub fn hash_key(key: &[u8]) -> String { - use sha1::Digest; - let mut hasher = sha1::Sha1::new(); +/// Hashes the `Sec-WebSocket-Key` header according to the WebSocket spec. +/// +/// Result is a Base64 encoded byte array. `base64(sha1(input))` is always 28 bytes. +pub fn hash_key(key: &[u8]) -> [u8; 28] { + let hash = { + use sha1::Digest as _; - hasher.update(key); - hasher.update(WS_GUID.as_bytes()); + let mut hasher = sha1::Sha1::new(); - base64::encode(&hasher.finalize()) + hasher.update(key); + hasher.update(WS_GUID); + + hasher.finalize() + }; + + let mut hash_b64 = [0; 28]; + let n = base64::encode_config_slice(&hash, base64::STANDARD, &mut hash_b64); + assert_eq!(n, 28); + + hash_b64 } #[cfg(test)] @@ -288,11 +313,11 @@ mod test { #[test] fn test_hash_key() { let hash = hash_key(b"hello actix-web"); - assert_eq!(&hash, "cR1dlyUUJKp0s/Bel25u5TgvC3E="); + assert_eq!(&hash, b"cR1dlyUUJKp0s/Bel25u5TgvC3E="); } #[test] - fn closecode_from_u16() { + fn close_code_from_u16() { assert_eq!(CloseCode::from(1000u16), CloseCode::Normal); assert_eq!(CloseCode::from(1001u16), CloseCode::Away); assert_eq!(CloseCode::from(1002u16), CloseCode::Protocol); @@ -310,7 +335,7 @@ mod test { } #[test] - fn closecode_into_u16() { + fn close_code_into_u16() { assert_eq!(1000u16, Into::::into(CloseCode::Normal)); assert_eq!(1001u16, Into::::into(CloseCode::Away)); assert_eq!(1002u16, Into::::into(CloseCode::Protocol)); diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index 91b2412f4..a3adcdfd6 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,8 +1,12 @@ -use actix_http::{http, HttpService, Request, Response}; +use std::convert::Infallible; + +use actix_http::{body::BoxBody, HttpMessage, HttpService, Request, Response, StatusCode}; use actix_http_test::test_server; use actix_service::ServiceFactoryExt; +use actix_utils::future; use bytes::Bytes; -use futures_util::future::{self, ok}; +use derive_more::{Display, Error}; +use futures_util::StreamExt as _; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -30,7 +34,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h1_v2() { let srv = test_server(move || { HttpService::build() - .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| future::ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -58,7 +62,7 @@ async fn test_h1_v2() { async fn test_connection_close() { let srv = test_server(move || { HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| future::ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() .map(|_| ()) }) @@ -72,11 +76,11 @@ async fn test_connection_close() { async fn test_with_query_parameter() { let srv = test_server(move || { HttpService::build() - .finish(|req: Request| { + .finish(|req: Request| async move { if req.uri().query().unwrap().contains("qp=") { - ok::<_, ()>(Response::Ok().finish()) + Ok::<_, Infallible>(Response::ok()) } else { - ok::<_, ()>(Response::BadRequest().finish()) + Ok(Response::bad_request()) } }) .tcp() @@ -88,3 +92,65 @@ async fn test_with_query_parameter() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); } + +#[derive(Debug, Display, Error)] +#[display(fmt = "expect failed")] +struct ExpectFailed; + +impl From for Response { + fn from(_: ExpectFailed) -> Self { + Response::new(StatusCode::EXPECTATION_FAILED) + } +} + +#[actix_rt::test] +async fn test_h1_expect() { + let srv = test_server(move || { + HttpService::build() + .expect(|req: Request| async { + if req.headers().contains_key("AUTH") { + Ok(req) + } else { + Err(ExpectFailed) + } + }) + .h1(|req: Request| async move { + let (_, mut body) = req.into_parts(); + let mut buf = Vec::new(); + while let Some(Ok(chunk)) = body.next().await { + buf.extend_from_slice(&chunk); + } + let str = std::str::from_utf8(&buf).unwrap(); + assert_eq!(str, "expect body"); + + Ok::<_, Infallible>(Response::ok()) + }) + .tcp() + }) + .await; + + // test expect without payload. + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")); + + let response = request.send().await; + assert!(response.is_err()); + + // test expect would fail to continue + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")); + + let response = request.send_body("expect body").await.unwrap(); + assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED); + + // test expect would continue + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")) + .insert_header(("AUTH", "996")); + + let response = request.send_body("expect body").await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_h2_timer.rs b/actix-http/tests/test_h2_timer.rs new file mode 100644 index 000000000..2b9c26e4a --- /dev/null +++ b/actix-http/tests/test_h2_timer.rs @@ -0,0 +1,153 @@ +use std::io; + +use actix_http::{error::Error, HttpService, Response}; +use actix_server::Server; +use tokio::io::AsyncWriteExt; + +#[actix_rt::test] +async fn h2_ping_pong() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(3) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for h2 client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + let (mut tx, conn) = h2::client::handshake(stream).await.unwrap(); + + tokio::spawn(async move { conn.await.unwrap() }); + + let (res, _) = tx.send_request(::http::Request::new(()), true).unwrap(); + let res = res.await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it can not answer ping pong. + std::thread::sleep(std::time::Duration::from_secs(1000)); + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from keep alive and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +} + +#[actix_rt::test] +async fn h2_handshake_timeout() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(30) + // set first request timeout to 5 seconds. + // this is the timeout used for http2 handshake. + .client_timeout(5000) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for tcp client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + // do not send the last new line intentionally. + // This should hang the server handshake + let malicious_buf = b"PRI * HTTP/2.0\r\n\r\nSM\r\n"; + stream.write_all(malicious_buf).await.unwrap(); + stream.flush().await.unwrap(); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it sit idle and not do handshake. + std::thread::sleep(std::time::Duration::from_secs(1000)); + + drop(stream) + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from handshake timeout and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +} diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index f44968baa..1e371473f 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -2,18 +2,21 @@ extern crate tls_openssl as openssl; -use std::io; +use std::{convert::Infallible, io}; -use actix_http::error::{ErrorBadRequest, PayloadError}; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::HttpMessage; -use actix_http::{body, Error, HttpService, Request, Response}; +use actix_http::{ + body::{BodyStream, BoxBody, SizedStream}, + error::PayloadError, + header::{self, HeaderValue}, + Error, HttpService, Method, Request, Response, StatusCode, Version, +}; use actix_http_test::test_server; use actix_service::{fn_service, ServiceFactoryExt}; +use actix_utils::future::{err, ok, ready}; use bytes::{Bytes, BytesMut}; -use futures_util::future::{err, ok, ready}; -use futures_util::stream::{once, Stream, StreamExt}; +use derive_more::{Display, Error}; +use futures_core::Stream; +use futures_util::stream::{once, StreamExt as _}; use openssl::{ pkey::PKey, ssl::{SslAcceptor, SslMethod}, @@ -66,7 +69,7 @@ fn tls_config() -> SslAcceptor { async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .h2(|_| ok::<_, Error>(Response::ok())) .openssl(tls_config()) .map_err(|_| ()) }) @@ -84,7 +87,7 @@ async fn test_h2_1() -> io::Result<()> { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_2); - ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .openssl(tls_config()) .map_err(|_| ()) @@ -98,12 +101,12 @@ async fn test_h2_1() -> io::Result<()> { #[actix_rt::test] async fn test_h2_body() -> io::Result<()> { - let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); // 640 KiB let mut srv = test_server(move || { HttpService::build() .h2(|mut req: Request<_>| async move { let body = load_body(req.take_payload()).await?; - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body)) }) .openssl(tls_config()) .map_err(|_| ()) @@ -123,46 +126,39 @@ async fn test_h2_content_length() { let srv = test_server(move || { HttpService::build() .h2(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); + let idx: usize = req.uri().path()[1..].parse().unwrap(); let statuses = [ - StatusCode::NO_CONTENT, StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, + StatusCode::NO_CONTENT, StatusCode::OK, StatusCode::NOT_FOUND, ]; - ok::<_, ()>(Response::new(statuses[indx])) + ok::<_, Infallible>(Response::new(statuses[idx])) }) .openssl(tls_config()) .map_err(|_| ()) }) .await; - let header = HeaderName::from_static("content-length"); - let value = HeaderValue::from_static("0"); + static VALUE: HeaderValue = HeaderValue::from_static("0"); { - for i in 0..4 { + let req = srv.request(Method::HEAD, srv.surl("/0")).send(); + req.await.expect_err("should timeout on recv 1xx frame"); + + let req = srv.request(Method::GET, srv.surl("/0")).send(); + req.await.expect_err("should timeout on recv 1xx frame"); + + let req = srv.request(Method::GET, srv.surl("/1")).send(); + let response = req.await.unwrap(); + assert!(response.headers().get("content-length").is_none()); + + for &i in &[2, 3] { let req = srv .request(Method::GET, srv.surl(&format!("/{}", i))) .send(); let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), None); - } - - for i in 4..6 { - let req = srv - .request(Method::GET, srv.surl(&format!("/{}", i))) - .send(); - let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), Some(&value)); + assert_eq!(response.headers().get("content-length"), Some(&VALUE)); } } } @@ -174,10 +170,11 @@ async fn test_h2_headers() { let mut srv = test_server(move || { let data = data.clone(); - HttpService::build().h2(move |_| { - let mut builder = Response::Ok(); - for idx in 0..90 { - builder.insert_header( + HttpService::build() + .h2(move |_| { + let mut builder = Response::build(StatusCode::OK); + for idx in 0..90 { + builder.insert_header( (format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -193,12 +190,13 @@ async fn test_h2_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", )); - } - ok::<_, ()>(builder.body(data.clone())) - }) + } + ok::<_, Infallible>(builder.body(data.clone())) + }) .openssl(tls_config()) - .map_err(|_| ()) - }).await; + .map_err(|_| ()) + }) + .await; let response = srv.sget("/").send().await.unwrap(); assert!(response.status().is_success()); @@ -234,7 +232,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -252,7 +250,7 @@ async fn test_h2_body2() { async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -276,7 +274,7 @@ async fn test_h2_head_empty() { async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -299,7 +297,7 @@ async fn test_h2_head_binary() { async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -318,10 +316,12 @@ async fn test_h2_head_binary2() { async fn test_h2_body_length() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| { - let body = once(ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + .h2(|_| async { + let body = + once(async { Ok::<_, Infallible>(Bytes::from_static(STR.as_ref())) }); + + Ok::<_, Infallible>( + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .openssl(tls_config()) @@ -343,10 +343,10 @@ async fn test_h2_body_chunked_explicit() { HttpService::build() .h2(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .openssl(tls_config()) @@ -371,8 +371,8 @@ async fn test_h2_response_http_error_handling() { HttpService::build() .h2(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((header::CONTENT_TYPE, broken_header)) .body(STR), ) @@ -387,14 +387,29 @@ async fn test_h2_response_http_error_handling() { // read response let bytes = srv.load_body(response).await.unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + assert_eq!( + bytes, + Bytes::from_static(b"error processing HTTP: failed to parse header value") + ); +} + +#[derive(Debug, Display, Error)] +#[display(fmt = "error")] +struct BadRequest; + +impl From for Response { + fn from(err: BadRequest) -> Self { + Response::build(StatusCode::BAD_REQUEST) + .body(err.to_string()) + .map_into_boxed_body() + } } #[actix_rt::test] async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| err::(ErrorBadRequest("error"))) + .h2(|_| err::, _>(BadRequest)) .openssl(tls_config()) .map_err(|_| ()) }) @@ -416,8 +431,8 @@ async fn test_h2_on_connect() { data.insert(20isize); }) .h2(|req: Request| { - assert!(req.extensions().contains::()); - ok::<_, ()>(Response::Ok().finish()) + assert!(req.conn_data::().is_some()); + ok::<_, Infallible>(Response::ok()) }) .openssl(tls_config()) .map_err(|_| ()) diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs index a36400910..51fefae72 100644 --- a/actix-http/tests/test_rustls.rs +++ b/actix-http/tests/test_rustls.rs @@ -2,33 +2,52 @@ extern crate tls_rustls as rustls; -use actix_http::error::PayloadError; -use actix_http::http::header::{self, HeaderName, HeaderValue}; -use actix_http::http::{Method, StatusCode, Version}; -use actix_http::{body, error, Error, HttpService, Request, Response}; -use actix_http_test::test_server; -use actix_service::{fn_factory_with_config, fn_service}; - -use bytes::{Bytes, BytesMut}; -use futures_util::future::{self, err, ok}; -use futures_util::stream::{once, Stream, StreamExt}; -use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, +use std::{ + convert::{Infallible, TryFrom}, + io::{self, BufReader, Write}, + net::{SocketAddr, TcpStream as StdTcpStream}, + sync::Arc, + task::Poll, }; -use std::fs::File; -use std::io::{self, BufReader}; +use actix_http::{ + body::{BodyStream, BoxBody, SizedStream}, + error::PayloadError, + header::{self, HeaderName, HeaderValue}, + Error, HttpService, Method, Request, Response, StatusCode, Version, +}; +use actix_http_test::test_server; +use actix_rt::pin; +use actix_service::{fn_factory_with_config, fn_service}; +use actix_tls::connect::rustls::webpki_roots_cert_store; +use actix_utils::future::{err, ok, poll_fn}; +use bytes::{Bytes, BytesMut}; +use derive_more::{Display, Error}; +use futures_core::{ready, Stream}; +use futures_util::stream::once; +use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig, ServerName}; +use rustls_pemfile::{certs, pkcs8_private_keys}; -async fn load_body(mut stream: S) -> Result +async fn load_body(stream: S) -> Result where - S: Stream> + Unpin, + S: Stream>, { - let mut body = BytesMut::new(); - while let Some(item) = stream.next().await { - body.extend_from_slice(&item?) - } - Ok(body) + let mut buf = BytesMut::new(); + + pin!(stream); + + poll_fn(|cx| loop { + let body = stream.as_mut(); + + match ready!(body.poll_next(cx)) { + Some(Ok(bytes)) => buf.extend_from_slice(&*bytes), + None => return Poll::Ready(Ok(())), + Some(Err(err)) => return Poll::Ready(Err(err)), + } + }) + .await?; + + Ok(buf) } fn tls_config() -> RustlsServerConfig { @@ -36,22 +55,61 @@ fn tls_config() -> RustlsServerConfig { let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = RustlsServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let mut config = RustlsServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap(); + + config.alpn_protocols.push(HTTP1_1_ALPN_PROTOCOL.to_vec()); + config.alpn_protocols.push(H2_ALPN_PROTOCOL.to_vec()); config } +pub fn get_negotiated_alpn_protocol( + addr: SocketAddr, + client_alpn_protocol: &[u8], +) -> Option> { + let mut config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + + config.alpn_protocols.push(client_alpn_protocol.to_vec()); + + let mut sess = rustls::ClientConnection::new( + Arc::new(config), + ServerName::try_from("localhost").unwrap(), + ) + .unwrap(); + + let mut sock = StdTcpStream::connect(addr).unwrap(); + let mut stream = rustls::Stream::new(&mut sess, &mut sock); + + // The handshake will fails because the client will not be able to verify the server + // certificate, but it doesn't matter here as we are just interested in the negotiated ALPN + // protocol + let _ = stream.flush(); + + sess.alpn_protocol().map(|proto| proto.to_vec()) +} + #[actix_rt::test] async fn test_h1() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h1(|_| future::ok::<_, Error>(Response::Ok().finish())) + .h1(|_| ok::<_, Error>(Response::ok())) .rustls(tls_config()) }) .await; @@ -65,7 +123,7 @@ async fn test_h1() -> io::Result<()> { async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) + .h2(|_| ok::<_, Error>(Response::ok())) .rustls(tls_config()) }) .await; @@ -82,7 +140,7 @@ async fn test_h1_1() -> io::Result<()> { .h1(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_11); - future::ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .rustls(tls_config()) }) @@ -100,7 +158,7 @@ async fn test_h2_1() -> io::Result<()> { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_2); - future::ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .rustls(tls_config()) }) @@ -118,7 +176,7 @@ async fn test_h2_body1() -> io::Result<()> { HttpService::build() .h2(|mut req: Request<_>| async move { let body = load_body(req.take_payload()).await?; - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body)) }) .rustls(tls_config()) }) @@ -139,14 +197,12 @@ async fn test_h2_content_length() { .h2(|req: Request| { let indx: usize = req.uri().path()[1..].parse().unwrap(); let statuses = [ - StatusCode::NO_CONTENT, StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, + StatusCode::NO_CONTENT, StatusCode::OK, StatusCode::NOT_FOUND, ]; - future::ok::<_, ()>(Response::new(statuses[indx])) + ok::<_, Infallible>(Response::new(statuses[indx])) }) .rustls(tls_config()) }) @@ -154,22 +210,31 @@ async fn test_h2_content_length() { let header = HeaderName::from_static("content-length"); let value = HeaderValue::from_static("0"); + { - for i in 0..4 { + for &i in &[0] { + let req = srv + .request(Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let _response = req.await.expect_err("should timeout on recv 1xx frame"); + // assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let _response = req.await.expect_err("should timeout on recv 1xx frame"); + // assert_eq!(response.headers().get(&header), None); + } + + for &i in &[1] { let req = srv .request(Method::GET, srv.surl(&format!("/{}", i))) .send(); let response = req.await.unwrap(); assert_eq!(response.headers().get(&header), None); - - let req = srv - .request(Method::HEAD, srv.surl(&format!("/{}", i))) - .send(); - let response = req.await.unwrap(); - assert_eq!(response.headers().get(&header), None); } - for i in 4..6 { + for &i in &[2, 3] { let req = srv .request(Method::GET, srv.surl(&format!("/{}", i))) .send(); @@ -186,10 +251,11 @@ async fn test_h2_headers() { let mut srv = test_server(move || { let data = data.clone(); - HttpService::build().h2(move |_| { - let mut config = Response::Ok(); - for idx in 0..90 { - config.insert_header(( + HttpService::build() + .h2(move |_| { + let mut config = Response::build(StatusCode::OK); + for idx in 0..90 { + config.insert_header(( format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -205,11 +271,12 @@ async fn test_h2_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", )); - } - future::ok::<_, ()>(config.body(data.clone())) - }) + } + ok::<_, Infallible>(config.body(data.clone())) + }) .rustls(tls_config()) - }).await; + }) + .await; let response = srv.sget("/").send().await.unwrap(); assert!(response.status().is_success()); @@ -245,7 +312,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -262,7 +329,7 @@ async fn test_h2_body2() { async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -288,7 +355,7 @@ async fn test_h2_head_empty() { async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -313,7 +380,7 @@ async fn test_h2_head_binary() { async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -335,9 +402,9 @@ async fn test_h2_body_length() { let mut srv = test_server(move || { HttpService::build() .h2(|_| { - let body = once(ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + let body = once(ok::<_, Infallible>(Bytes::from_static(STR.as_ref()))); + ok::<_, Infallible>( + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .rustls(tls_config()) @@ -358,10 +425,10 @@ async fn test_h2_body_chunked_explicit() { HttpService::build() .h2(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .rustls(tls_config()) @@ -384,10 +451,10 @@ async fn test_h2_response_http_error_handling() { let mut srv = test_server(move || { HttpService::build() .h2(fn_factory_with_config(|_: ()| { - ok::<_, ()>(fn_service(|_| { + ok::<_, Infallible>(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) @@ -402,14 +469,27 @@ async fn test_h2_response_http_error_handling() { // read response let bytes = srv.load_body(response).await.unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + assert_eq!( + bytes, + Bytes::from_static(b"error processing HTTP: failed to parse header value") + ); +} + +#[derive(Debug, Display, Error)] +#[display(fmt = "error")] +struct BadRequest; + +impl From for Response { + fn from(_: BadRequest) -> Self { + Response::bad_request().set_body(BoxBody::new("error")) + } } #[actix_rt::test] async fn test_h2_service_error() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| err::(error::ErrorBadRequest("error"))) + .h2(|_| err::, _>(BadRequest)) .rustls(tls_config()) }) .await; @@ -426,7 +506,7 @@ async fn test_h2_service_error() { async fn test_h1_service_error() { let mut srv = test_server(move || { HttpService::build() - .h1(|_| err::(error::ErrorBadRequest("error"))) + .h1(|_| err::, _>(BadRequest)) .rustls(tls_config()) }) .await; @@ -438,3 +518,85 @@ async fn test_h1_service_error() { let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"error")); } + +const H2_ALPN_PROTOCOL: &[u8] = b"h2"; +const HTTP1_1_ALPN_PROTOCOL: &[u8] = b"http/1.1"; +const CUSTOM_ALPN_PROTOCOL: &[u8] = b"custom"; + +#[actix_rt::test] +async fn test_alpn_h1() -> io::Result<()> { + let srv = test_server(move || { + let mut config = tls_config(); + config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec()); + HttpService::build() + .h1(|_| ok::<_, Error>(Response::ok())) + .rustls(config) + }) + .await; + + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), CUSTOM_ALPN_PROTOCOL), + Some(CUSTOM_ALPN_PROTOCOL.to_vec()) + ); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + Ok(()) +} + +#[actix_rt::test] +async fn test_alpn_h2() -> io::Result<()> { + let srv = test_server(move || { + let mut config = tls_config(); + config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec()); + HttpService::build() + .h2(|_| ok::<_, Error>(Response::ok())) + .rustls(config) + }) + .await; + + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), H2_ALPN_PROTOCOL), + Some(H2_ALPN_PROTOCOL.to_vec()) + ); + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), CUSTOM_ALPN_PROTOCOL), + Some(CUSTOM_ALPN_PROTOCOL.to_vec()) + ); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + Ok(()) +} + +#[actix_rt::test] +async fn test_alpn_h2_1() -> io::Result<()> { + let srv = test_server(move || { + let mut config = tls_config(); + config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec()); + HttpService::build() + .finish(|_| ok::<_, Error>(Response::ok())) + .rustls(config) + }) + .await; + + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), H2_ALPN_PROTOCOL), + Some(H2_ALPN_PROTOCOL.to_vec()) + ); + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), HTTP1_1_ALPN_PROTOCOL), + Some(HTTP1_1_ALPN_PROTOCOL.to_vec()) + ); + assert_eq!( + get_negotiated_alpn_protocol(srv.addr(), CUSTOM_ALPN_PROTOCOL), + Some(CUSTOM_ALPN_PROTOCOL.to_vec()) + ); + + let response = srv.sget("/").send().await.unwrap(); + assert!(response.status().is_success()); + + Ok(()) +} diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 910fa81f2..1bb574fd6 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -1,30 +1,36 @@ -use std::io::{Read, Write}; -use std::time::Duration; -use std::{net, thread}; +use std::{ + convert::Infallible, + io::{Read, Write}, + net, thread, + time::Duration, +}; +use actix_http::{ + body::{self, BodyStream, BoxBody, SizedStream}, + header, Error, HttpService, KeepAlive, Request, Response, StatusCode, +}; use actix_http_test::test_server; use actix_rt::time::sleep; use actix_service::fn_service; +use actix_utils::future::{err, ok, ready}; use bytes::Bytes; -use futures_util::future::{self, err, ok, ready, FutureExt}; -use futures_util::stream::{once, StreamExt}; -use regex::Regex; - -use actix_http::HttpMessage; -use actix_http::{ - body, error, http, http::header, Error, HttpService, KeepAlive, Request, Response, +use derive_more::{Display, Error}; +use futures_util::{ + stream::{once, StreamExt as _}, + FutureExt as _, }; +use regex::Regex; #[actix_rt::test] async fn test_h1() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) .client_disconnect(1000) .h1(|req: Request| { assert!(req.peer_addr().is_some()); - future::ok::<_, ()>(Response::Ok().finish()) + ok::<_, Infallible>(Response::ok()) }) .tcp() }) @@ -32,11 +38,13 @@ async fn test_h1() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_2() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) @@ -44,7 +52,7 @@ async fn test_h1_2() { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), http::Version::HTTP_11); - future::ok::<_, ()>(Response::Ok().finish()) + ok::<_, Infallible>(Response::ok()) }) .tcp() }) @@ -52,20 +60,32 @@ async fn test_h1_2() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; +} + +#[derive(Debug, Display, Error)] +#[display(fmt = "expect failed")] +struct ExpectFailed; + +impl From for Response { + fn from(_: ExpectFailed) -> Self { + Response::new(StatusCode::EXPECTATION_FAILED) + } } #[actix_rt::test] async fn test_expect_continue() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { if req.head().uri.query() == Some("yes=") { ok(req) } else { - err(error::ErrorPreconditionFailed("error")) + err(ExpectFailed) } })) - .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .finish(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -74,29 +94,31 @@ async fn test_expect_continue() { let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); let mut data = String::new(); let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + assert!(data.starts_with("HTTP/1.1 417 Expectation Failed\r\ncontent-length")); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + srv.stop().await; } #[actix_rt::test] async fn test_expect_continue_h1() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { sleep(Duration::from_millis(20)).then(move |_| { if req.head().uri.query() == Some("yes=") { ok(req) } else { - err(error::ErrorPreconditionFailed("error")) + err(ExpectFailed) } }) })) - .h1(fn_service(|_| future::ok::<_, ()>(Response::Ok().finish()))) + .h1(fn_service(|_| ok::<_, Infallible>(Response::ok()))) .tcp() }) .await; @@ -105,13 +127,15 @@ async fn test_expect_continue_h1() { let _ = stream.write_all(b"GET /test HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); let mut data = String::new(); let _ = stream.read_to_string(&mut data); - assert!(data.starts_with("HTTP/1.1 412 Precondition Failed\r\ncontent-length")); + assert!(data.starts_with("HTTP/1.1 417 Expectation Failed\r\ncontent-length")); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + srv.stop().await; } #[actix_rt::test] @@ -119,18 +143,18 @@ async fn test_chunked_payload() { let chunk_sizes = vec![32768, 32, 32768]; let total_size: usize = chunk_sizes.iter().sum(); - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(fn_service(|mut request: Request| { request .take_payload() .map(|res| match res { Ok(pl) => pl, - Err(e) => panic!(format!("Error reading payload: {}", e)), + Err(e) => panic!("Error reading payload: {}", e), }) .fold(0usize, |acc, chunk| ready(acc + chunk.len())) .map(|req_size| { - Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size))) + Ok::<_, Error>(Response::ok().set_body(format!("size={}", req_size))) }) })) .tcp() @@ -139,8 +163,7 @@ async fn test_chunked_payload() { let returned_size = { let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream - .write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + let _ = stream.write_all(b"POST /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); for chunk_size in chunk_sizes.iter() { let mut bytes = Vec::new(); @@ -162,20 +185,23 @@ async fn test_chunked_payload() { let re = Regex::new(r"size=(\d+)").unwrap(); let size: usize = match re.captures(&data) { Some(caps) => caps.get(1).unwrap().as_str().parse().unwrap(), - None => panic!(format!("Failed to find size in HTTP Response: {}", data)), + None => panic!("Failed to find size in HTTP Response: {}", data), }; + size }; assert_eq!(returned_size, total_size); + + srv.stop().await; } #[actix_rt::test] async fn test_slow_request() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .client_timeout(100) - .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + .finish(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -185,13 +211,15 @@ async fn test_slow_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_malformed_request() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -201,13 +229,15 @@ async fn test_http1_malformed_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -222,14 +252,16 @@ async fn test_http1_keepalive() { let mut data = vec![0; 1024]; let _ = stream.read(&mut data); assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_timeout() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(1) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -244,20 +276,21 @@ async fn test_http1_keepalive_timeout() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_close() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = - stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); let mut data = vec![0; 1024]; let _ = stream.read(&mut data); assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); @@ -265,13 +298,15 @@ async fn test_http1_keepalive_close() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http10_keepalive_default_close() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -285,20 +320,22 @@ async fn test_http10_keepalive_default_close() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http10_keepalive() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); - let _ = stream - .write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n"); + let _ = + stream.write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n"); let mut data = vec![0; 1024]; let _ = stream.read(&mut data); assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); @@ -312,14 +349,16 @@ async fn test_http10_keepalive() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_http1_keepalive_disabled() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) - .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, Infallible>(Response::ok())) .tcp() }) .await; @@ -333,16 +372,18 @@ async fn test_http1_keepalive_disabled() { let mut data = vec![0; 1024]; let res = stream.read(&mut data).unwrap(); assert_eq!(res, 0); + + srv.stop().await; } #[actix_rt::test] async fn test_content_length() { - use actix_http::http::{ + use actix_http::{ header::{HeaderName, HeaderValue}, StatusCode, }; - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .h1(|req: Request| { let indx: usize = req.uri().path()[1..].parse().unwrap(); @@ -354,7 +395,7 @@ async fn test_content_length() { StatusCode::OK, StatusCode::NOT_FOUND, ]; - future::ok::<_, ()>(Response::new(statuses[indx])) + ok::<_, Infallible>(Response::new(statuses[indx])) }) .tcp() }) @@ -380,6 +421,8 @@ async fn test_content_length() { assert_eq!(response.headers().get(&header), Some(&value)); } } + + srv.stop().await; } #[actix_rt::test] @@ -389,10 +432,11 @@ async fn test_h1_headers() { let mut srv = test_server(move || { let data = data.clone(); - HttpService::build().h1(move |_| { - let mut builder = Response::Ok(); - for idx in 0..90 { - builder.insert_header(( + HttpService::build() + .h1(move |_| { + let mut builder = Response::build(StatusCode::OK); + for idx in 0..90 { + builder.insert_header(( format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -408,10 +452,12 @@ async fn test_h1_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", )); - } - future::ok::<_, ()>(builder.body(data.clone())) - }).tcp() - }).await; + } + ok::<_, Infallible>(builder.body(data.clone())) + }) + .tcp() + }) + .await; let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); @@ -419,6 +465,8 @@ async fn test_h1_headers() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from(data2)); + + srv.stop().await; } const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -447,7 +495,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h1_body() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -458,13 +506,15 @@ async fn test_h1_body() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_head_empty() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -483,13 +533,15 @@ async fn test_h1_head_empty() { // read response let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_head_binary() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -508,13 +560,15 @@ async fn test_h1_head_binary() { // read response let bytes = srv.load_body(response).await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_head_binary2() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, Infallible>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -529,6 +583,8 @@ async fn test_h1_head_binary2() { .unwrap(); assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); } + + srv.stop().await; } #[actix_rt::test] @@ -536,9 +592,9 @@ async fn test_h1_body_length() { let mut srv = test_server(|| { HttpService::build() .h1(|_| { - let body = once(ok(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + let body = once(ok::<_, Infallible>(Bytes::from_static(STR.as_ref()))); + ok::<_, Infallible>( + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .tcp() @@ -551,6 +607,8 @@ async fn test_h1_body_length() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -559,10 +617,10 @@ async fn test_h1_body_chunked_explicit() { HttpService::build() .h1(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) - .streaming(body), + .body(BodyStream::new(body)), ) }) .tcp() @@ -586,6 +644,8 @@ async fn test_h1_body_chunked_explicit() { // decode assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -594,7 +654,7 @@ async fn test_h1_body_chunked_implicit() { HttpService::build() .h1(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>(Response::Ok().streaming(body)) + ok::<_, Infallible>(Response::build(StatusCode::OK).body(BodyStream::new(body))) }) .tcp() }) @@ -615,6 +675,8 @@ async fn test_h1_body_chunked_implicit() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -623,8 +685,8 @@ async fn test_h1_response_http_error_handling() { HttpService::build() .h1(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, ()>( - Response::Ok() + ok::<_, Infallible>( + Response::build(StatusCode::OK) .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) @@ -638,14 +700,29 @@ async fn test_h1_response_http_error_handling() { // read response let bytes = srv.load_body(response).await.unwrap(); - assert_eq!(bytes, Bytes::from_static(b"failed to parse header value")); + assert_eq!( + bytes, + Bytes::from_static(b"error processing HTTP: failed to parse header value") + ); + + srv.stop().await; +} + +#[derive(Debug, Display, Error)] +#[display(fmt = "error")] +struct BadRequest; + +impl From for Response { + fn from(_: BadRequest) -> Self { + Response::bad_request().set_body(BoxBody::new("error")) + } } #[actix_rt::test] async fn test_h1_service_error() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| future::err::(error::ErrorBadRequest("error"))) + .h1(|_| err::, _>(BadRequest)) .tcp() }) .await; @@ -656,18 +733,20 @@ async fn test_h1_service_error() { // read response let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"error")); + + srv.stop().await; } #[actix_rt::test] async fn test_h1_on_connect() { - let srv = test_server(|| { + let mut srv = test_server(|| { HttpService::build() .on_connect_ext(|_, data| { data.insert(20isize); }) .h1(|req: Request| { - assert!(req.extensions().contains::()); - future::ok::<_, ()>(Response::Ok().finish()) + assert!(req.conn_data::().is_some()); + ok::<_, Infallible>(Response::ok()) }) .tcp() }) @@ -675,4 +754,87 @@ async fn test_h1_on_connect() { let response = srv.get("/").send().await.unwrap(); assert!(response.status().is_success()); + + srv.stop().await; +} + +/// Tests compliance with 304 Not Modified spec in RFC 7232 §4.1. +/// https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 +#[actix_rt::test] +async fn test_not_modified_spec_h1() { + // TODO: this test needing a few seconds to complete reveals some weirdness with either the + // dispatcher or the client, though similar hangs occur on other tests in this file, only + // succeeding, it seems, because of the keepalive timer + + static CL: header::HeaderName = header::CONTENT_LENGTH; + + let mut srv = test_server(|| { + HttpService::build() + .h1(|req: Request| { + let res: Response = match req.path() { + // with no content-length + "/none" => Response::with_body(StatusCode::NOT_MODIFIED, body::None::new()) + .map_into_boxed_body(), + + // with no content-length + "/body" => Response::with_body(StatusCode::NOT_MODIFIED, "1234") + .map_into_boxed_body(), + + // with manual content-length header and specific None body + "/cl-none" => { + let mut res = + Response::with_body(StatusCode::NOT_MODIFIED, body::None::new()); + res.headers_mut() + .insert(CL.clone(), header::HeaderValue::from_static("24")); + res.map_into_boxed_body() + } + + // with manual content-length header and ignore-able body + "/cl-body" => { + let mut res = Response::with_body(StatusCode::NOT_MODIFIED, "1234"); + res.headers_mut() + .insert(CL.clone(), header::HeaderValue::from_static("4")); + res.map_into_boxed_body() + } + + _ => panic!("unknown route"), + }; + + ok::<_, Infallible>(res) + }) + .tcp() + }) + .await; + + let res = srv.get("/none").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!(res.headers().get(&CL), None); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/body").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!(res.headers().get(&CL), None); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/cl-none").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!( + res.headers().get(&CL), + Some(&header::HeaderValue::from_static("24")), + ); + assert!(srv.load_body(res).await.unwrap().is_empty()); + + let res = srv.get("/cl-body").send().await.unwrap(); + assert_eq!(res.status(), http::StatusCode::NOT_MODIFIED); + assert_eq!( + res.headers().get(&CL), + Some(&header::HeaderValue::from_static("4")), + ); + // server does not prevent payload from being sent but clients may choose not to read it + // TODO: this is probably a bug, especially since CL header can differ in length from the body + assert!(!srv.load_body(res).await.unwrap().is_empty()); + + // TODO: add stream response tests + + srv.stop().await; } diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 7ed9b0df1..ed8c61fd6 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -1,192 +1,195 @@ -use std::cell::Cell; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::{ + cell::Cell, + convert::Infallible, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; +use actix_http::{ + body::{BodySize, BoxBody}, + h1, + ws::{self, CloseCode, Frame, Item, Message}, + Error, HttpService, Request, Response, +}; use actix_http_test::test_server; use actix_service::{fn_factory, Service}; -use actix_utils::dispatcher::Dispatcher; use bytes::Bytes; -use futures_util::future; -use futures_util::task::{Context, Poll}; -use futures_util::{SinkExt, StreamExt}; +use derive_more::{Display, Error, From}; +use futures_core::future::LocalBoxFuture; +use futures_util::{SinkExt as _, StreamExt as _}; -struct WsService(Arc, Cell)>>); +#[derive(Clone)] +struct WsService(Cell); -impl WsService { +impl WsService { fn new() -> Self { - WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false))))) + WsService(Cell::new(false)) } fn set_polled(&self) { - *self.0.lock().unwrap().1.get_mut() = true; + self.0.set(true); } fn was_polled(&self) -> bool { - self.0.lock().unwrap().1.get() + self.0.get() } } -impl Clone for WsService { - fn clone(&self) -> Self { - WsService(self.0.clone()) +#[derive(Debug, Display, Error, From)] +enum WsServiceError { + #[display(fmt = "http error")] + Http(actix_http::Error), + + #[display(fmt = "ws handshake error")] + Ws(actix_http::ws::HandshakeError), + + #[display(fmt = "io error")] + Io(std::io::Error), + + #[display(fmt = "dispatcher error")] + Dispatcher, +} + +impl From for Response { + fn from(err: WsServiceError) -> Self { + match err { + WsServiceError::Http(err) => err.into(), + WsServiceError::Ws(err) => err.into(), + WsServiceError::Io(_err) => unreachable!(), + WsServiceError::Dispatcher => { + Response::internal_server_error().set_body(BoxBody::new(format!("{}", err))) + } + } } } -impl Service<(Request, Framed)> for WsService +impl Service<(Request, Framed)> for WsService where T: AsyncRead + AsyncWrite + Unpin + 'static, { type Response = (); - type Error = Error; - type Future = Pin>>>; + type Error = WsServiceError; + type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&self, _ctx: &mut Context<'_>) -> Poll> { + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { self.set_polled(); Poll::Ready(Ok(())) } fn call(&self, (req, mut framed): (Request, Framed)) -> Self::Future { - let fut = async move { - let res = ws::handshake(req.head()).unwrap().message_body(()); + assert!(self.was_polled()); - framed - .send((res, body::BodySize::None).into()) + Box::pin(async move { + let res = ws::handshake(req.head())?.message_body(())?; + + framed.send((res, BodySize::None).into()).await?; + + let framed = framed.replace_codec(ws::Codec::new()); + + ws::Dispatcher::with(framed, service) .await - .unwrap(); + .map_err(|_| WsServiceError::Dispatcher)?; - Dispatcher::new(framed.replace_codec(ws::Codec::new()), service) - .await - .map_err(|_| panic!()) - }; - - Box::pin(fut) + Ok(()) + }) } } -async fn service(msg: ws::Frame) -> Result { +async fn service(msg: Frame) -> Result { let msg = match msg { - ws::Frame::Ping(msg) => ws::Message::Pong(msg), - ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text).into_owned().into()) - } - ws::Frame::Binary(bin) => ws::Message::Binary(bin), - ws::Frame::Continuation(item) => ws::Message::Continuation(item), - ws::Frame::Close(reason) => ws::Message::Close(reason), - _ => panic!(), + Frame::Ping(msg) => Message::Pong(msg), + Frame::Text(text) => Message::Text(String::from_utf8_lossy(&text).into_owned().into()), + Frame::Binary(bin) => Message::Binary(bin), + Frame::Continuation(item) => Message::Continuation(item), + Frame::Close(reason) => Message::Close(reason), + _ => return Err(ws::ProtocolError::BadOpCode.into()), }; + Ok(msg) } #[actix_rt::test] async fn test_simple() { - let ws_service = WsService::new(); - let mut srv = test_server({ - let ws_service = ws_service.clone(); - move || { - let ws_service = ws_service.clone(); - HttpService::build() - .upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone()))) - .finish(|_| future::ok::<_, ()>(Response::NotFound())) - .tcp() - } + let mut srv = test_server(|| { + HttpService::build() + .upgrade(fn_factory(|| async { + Ok::<_, Infallible>(WsService::new()) + })) + .finish(|_| async { Ok::<_, Infallible>(Response::not_found()) }) + .tcp() }) .await; // client service let mut framed = srv.ws().await.unwrap(); - framed.send(ws::Message::Text("text".into())).await.unwrap(); - let (item, mut framed) = framed.into_future().await; - assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Text(Bytes::from_static(b"text")) - ); + framed.send(Message::Text("text".into())).await.unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, Frame::Text(Bytes::from_static(b"text"))); + + framed.send(Message::Binary("text".into())).await.unwrap(); + + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, Frame::Binary(Bytes::from_static(&b"text"[..]))); + + framed.send(Message::Ping("text".into())).await.unwrap(); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, Frame::Pong("text".to_string().into())); framed - .send(ws::Message::Binary("text".into())) + .send(Message::Continuation(Item::FirstText("text".into()))) .await .unwrap(); - let (item, mut framed) = framed.into_future().await; + let item = framed.next().await.unwrap().unwrap(); assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Binary(Bytes::from_static(&b"text"[..])) - ); - - framed.send(ws::Message::Ping("text".into())).await.unwrap(); - let (item, mut framed) = framed.into_future().await; - assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Pong("text".to_string().into()) - ); - - framed - .send(ws::Message::Continuation(ws::Item::FirstText( - "text".into(), - ))) - .await - .unwrap(); - let (item, mut framed) = framed.into_future().await; - assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text"))) + item, + Frame::Continuation(Item::FirstText(Bytes::from_static(b"text"))) ); assert!(framed - .send(ws::Message::Continuation(ws::Item::FirstText( - "text".into() - ))) + .send(Message::Continuation(Item::FirstText("text".into()))) .await .is_err()); assert!(framed - .send(ws::Message::Continuation(ws::Item::FirstBinary( - "text".into() - ))) + .send(Message::Continuation(Item::FirstBinary("text".into()))) .await .is_err()); framed - .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .send(Message::Continuation(Item::Continue("text".into()))) .await .unwrap(); - let (item, mut framed) = framed.into_future().await; + let item = framed.next().await.unwrap().unwrap(); assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text"))) + item, + Frame::Continuation(Item::Continue(Bytes::from_static(b"text"))) ); framed - .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .send(Message::Continuation(Item::Last("text".into()))) .await .unwrap(); - let (item, mut framed) = framed.into_future().await; + let item = framed.next().await.unwrap().unwrap(); assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text"))) + item, + Frame::Continuation(Item::Last(Bytes::from_static(b"text"))) ); assert!(framed - .send(ws::Message::Continuation(ws::Item::Continue("text".into()))) + .send(Message::Continuation(Item::Continue("text".into()))) .await .is_err()); assert!(framed - .send(ws::Message::Continuation(ws::Item::Last("text".into()))) + .send(Message::Continuation(Item::Last("text".into()))) .await .is_err()); framed - .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + .send(Message::Close(Some(CloseCode::Normal.into()))) .await .unwrap(); - let (item, _framed) = framed.into_future().await; - assert_eq!( - item.unwrap().unwrap(), - ws::Frame::Close(Some(ws::CloseCode::Normal.into())) - ); - - assert!(ws_service.was_polled()); + let item = framed.next().await.unwrap().unwrap(); + assert_eq!(item, Frame::Close(Some(CloseCode::Normal.into()))); } diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index 2142ebf4b..92feade3b 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -3,79 +3,128 @@ ## Unreleased - 2021-xx-xx +## 0.4.0-beta.12 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 0.4.0-beta.11 - 2021-12-27 +- No significant changes since `0.4.0-beta.10`. + + +## 0.4.0-beta.10 - 2021-12-11 +- No significant changes since `0.4.0-beta.9`. + + +## 0.4.0-beta.9 - 2021-12-01 +- Polling `Field` after dropping `Multipart` now fails immediately instead of hanging forever. [#2463] + +[#2463]: https://github.com/actix/actix-web/pull/2463 + + +## 0.4.0-beta.8 - 2021-11-22 +- Ensure a correct Content-Disposition header is included in every part of a multipart message. [#2451] +- Added `MultipartError::NoContentDisposition` variant. [#2451] +- Since Content-Disposition is now ensured, `Field::content_disposition` is now infallible. [#2451] +- Added `Field::name` method for getting the field name. [#2451] +- `MultipartError` now marks variants with inner errors as the source. [#2451] +- `MultipartError` is now marked as non-exhaustive. [#2451] + +[#2451]: https://github.com/actix/actix-web/pull/2451 + + +## 0.4.0-beta.7 - 2021-10-20 +- Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.4.0-beta.6 - 2021-09-09 +- Minimum supported Rust version (MSRV) is now 1.51. + + +## 0.4.0-beta.5 - 2021-06-17 +- No notable changes. + + +## 0.4.0-beta.4 - 2021-04-02 +- No notable changes. + + +## 0.4.0-beta.3 - 2021-03-09 +- No notable changes. + + ## 0.4.0-beta.2 - 2021-02-10 -* No notable changes. +- No notable changes. ## 0.4.0-beta.1 - 2021-01-07 -* Fix multipart consuming payload before header checks. [#1513] -* Update `bytes` to `1.0`. [#1813] +- Fix multipart consuming payload before header checks. [#1513] +- Update `bytes` to `1.0`. [#1813] [#1813]: https://github.com/actix/actix-web/pull/1813 [#1513]: https://github.com/actix/actix-web/pull/1513 ## 0.3.0 - 2020-09-11 -* No significant changes from `0.3.0-beta.2`. +- No significant changes from `0.3.0-beta.2`. ## 0.3.0-beta.2 - 2020-09-10 -* Update `actix-*` dependencies to latest versions. +- Update `actix-*` dependencies to latest versions. ## 0.3.0-beta.1 - 2020-07-15 -* Update `actix-web` to 3.0.0-beta.1 +- Update `actix-web` to 3.0.0-beta.1 ## 0.3.0-alpha.1 - 2020-05-25 -* Update `actix-web` to 3.0.0-alpha.3 -* Bump minimum supported Rust version to 1.40 -* Minimize `futures` dependencies -* Remove the unused `time` dependency -* Fix missing `std::error::Error` implement for `MultipartError`. +- Update `actix-web` to 3.0.0-alpha.3 +- Bump minimum supported Rust version to 1.40 +- Minimize `futures` dependencies +- Remove the unused `time` dependency +- Fix missing `std::error::Error` implement for `MultipartError`. ## [0.2.0] - 2019-12-20 -* Release +- Release ## [0.2.0-alpha.4] - 2019-12-xx -* Multipart handling now handles Pending during read of boundary #1205 +- Multipart handling now handles Pending during read of boundary #1205 ## [0.2.0-alpha.2] - 2019-12-03 -* Migrate to `std::future` +- Migrate to `std::future` ## [0.1.4] - 2019-09-12 -* Multipart handling now parses requests which do not end in CRLF #1038 +- Multipart handling now parses requests which do not end in CRLF #1038 ## [0.1.3] - 2019-08-18 -* Fix ring dependency from actix-web default features for #741. +- Fix ring dependency from actix-web default features for #741. ## [0.1.2] - 2019-06-02 -* Fix boundary parsing #876 +- Fix boundary parsing #876 ## [0.1.1] - 2019-05-25 -* Fix disconnect handling #834 +- Fix disconnect handling #834 ## [0.1.0] - 2019-05-18 -* Release +- Release ## [0.1.0-beta.4] - 2019-05-12 -* Handle cancellation of uploads #736 +- Handle cancellation of uploads #736 -* Upgrade to actix-web 1.0.0-beta.4 +- Upgrade to actix-web 1.0.0-beta.4 ## [0.1.0-beta.1] - 2019-04-21 -* Do not support nested multipart +- Do not support nested multipart -* Split multipart support to separate crate +- Split multipart support to separate crate -* Optimize multipart handling #634, #769 +- Optimize multipart handling #634, #769 diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index 67b2698bd..4f41caf44 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,13 +1,11 @@ [package] name = "actix-multipart" -version = "0.4.0-beta.2" +version = "0.4.0-beta.12" authors = ["Nikolay Kim "] description = "Multipart form support for Actix Web" -readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-multipart/" license = "MIT OR Apache-2.0" edition = "2018" @@ -16,17 +14,21 @@ name = "actix_multipart" path = "src/lib.rs" [dependencies] -actix-web = { version = "4.0.0-beta.3", default-features = false } -actix-utils = "3.0.0-beta.2" +actix-utils = "3.0.0" +actix-web = { version = "4.0.0-beta.19", default-features = false } bytes = "1" derive_more = "0.99.5" +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } httparse = "1.3" -futures-util = { version = "0.3.7", default-features = false } +local-waker = "0.1" log = "0.4" mime = "0.3" twoway = "0.2" [dev-dependencies] -actix-rt = "2" -actix-http = "3.0.0-beta.3" +actix-rt = "2.2" +actix-http = "3.0.0-beta.18" +futures-util = { version = "0.3.7", default-features = false, features = ["alloc"] } +tokio = { version = "1.8.4", features = ["sync"] } +tokio-stream = "0.1" diff --git a/actix-multipart/README.md b/actix-multipart/README.md index defcf7828..91cd8a6e9 100644 --- a/actix-multipart/README.md +++ b/actix-multipart/README.md @@ -3,16 +3,15 @@ > Multipart form support for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-multipart?label=latest)](https://crates.io/crates/actix-multipart) -[![Documentation](https://docs.rs/actix-multipart/badge.svg?version=0.4.0-beta.2)](https://docs.rs/actix-multipart/0.4.0-beta.2) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-multipart/badge.svg?version=0.4.0-beta.12)](https://docs.rs/actix-multipart/0.4.0-beta.12) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-multipart.svg)
-[![dependency status](https://deps.rs/crate/actix-multipart/0.4.0-beta.2/status.svg)](https://deps.rs/crate/actix-multipart/0.4.0-beta.2) +[![dependency status](https://deps.rs/crate/actix-multipart/0.4.0-beta.12/status.svg)](https://deps.rs/crate/actix-multipart/0.4.0-beta.12) [![Download](https://img.shields.io/crates/d/actix-multipart.svg)](https://crates.io/crates/actix-multipart) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-multipart) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- Minimum Supported Rust Version (MSRV): 1.54 diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs index cdbb5d395..7d0da35e0 100644 --- a/actix-multipart/src/error.rs +++ b/actix-multipart/src/error.rs @@ -2,39 +2,52 @@ use actix_web::error::{ParseError, PayloadError}; use actix_web::http::StatusCode; use actix_web::ResponseError; -use derive_more::{Display, From}; +use derive_more::{Display, Error, From}; /// A set of errors that can occur during parsing multipart streams -#[derive(Debug, Display, From)] +#[non_exhaustive] +#[derive(Debug, Display, From, Error)] pub enum MultipartError { + /// Content-Disposition header is not found or is not equal to "form-data". + /// + /// According to [RFC 7578 §4.2](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) a + /// Content-Disposition header must always be present and equal to "form-data". + #[display(fmt = "No Content-Disposition `form-data` header")] + NoContentDisposition, + /// Content-Type header is not found - #[display(fmt = "No Content-type header found")] + #[display(fmt = "No Content-Type header found")] NoContentType, + /// Can not parse Content-Type header #[display(fmt = "Can not parse Content-Type header")] ParseContentType, + /// Multipart boundary is not found #[display(fmt = "Multipart boundary is not found")] Boundary, + /// Nested multipart is not supported #[display(fmt = "Nested multipart is not supported")] Nested, + /// Multipart stream is incomplete #[display(fmt = "Multipart stream is incomplete")] Incomplete, + /// Error during field parsing #[display(fmt = "{}", _0)] Parse(ParseError), + /// Payload error #[display(fmt = "{}", _0)] Payload(PayloadError), + /// Not consumed #[display(fmt = "Multipart stream is not consumed")] NotConsumed, } -impl std::error::Error for MultipartError {} - /// Return `BadRequest` for `MultipartError` impl ResponseError for MultipartError { fn status_code(&self) -> StatusCode { @@ -45,11 +58,10 @@ impl ResponseError for MultipartError { #[cfg(test)] mod tests { use super::*; - use actix_web::HttpResponse; #[test] fn test_multipart_error() { - let resp: HttpResponse = MultipartError::Boundary.error_response(); + let resp = MultipartError::Boundary.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index 6aaa415c4..1ad1f203d 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -1,21 +1,22 @@ //! Multipart payload support + +use actix_utils::future::{ready, Ready}; use actix_web::{dev::Payload, Error, FromRequest, HttpRequest}; -use futures_util::future::{ok, Ready}; use crate::server::Multipart; -/// Get request's payload as multipart stream +/// Get request's payload as multipart stream. /// /// Content-type: multipart/form-data; /// /// ## Server example /// -/// ```rust -/// use futures_util::stream::{Stream, StreamExt}; +/// ``` /// use actix_web::{web, HttpResponse, Error}; -/// use actix_multipart as mp; +/// use actix_multipart::Multipart; +/// use futures_util::stream::StreamExt as _; /// -/// async fn index(mut payload: mp::Multipart) -> Result { +/// async fn index(mut payload: Multipart) -> Result { /// // iterate over multipart stream /// while let Some(item) = payload.next().await { /// let mut field = item?; @@ -25,20 +26,19 @@ use crate::server::Multipart; /// println!("-- CHUNK: \n{:?}", std::str::from_utf8(&chunk?)); /// } /// } +/// /// Ok(HttpResponse::Ok().into()) /// } -/// # fn main() {} /// ``` impl FromRequest for Multipart { type Error = Error; type Future = Ready>; - type Config = (); #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - ok(match Multipart::boundary(req.headers()) { + ready(Ok(match Multipart::boundary(req.headers()) { Ok(boundary) => Multipart::from_boundary(boundary, payload.take()), Err(err) => Multipart::from_error(err), - }) + })) } } diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs index 38a24e28f..3d536e08d 100644 --- a/actix-multipart/src/lib.rs +++ b/actix-multipart/src/lib.rs @@ -1,6 +1,7 @@ //! Multipart form support for Actix Web. -#![deny(rust_2018_idioms)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] #![allow(clippy::borrow_interior_mutable_const)] mod error; diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index 8cd1c8e0c..239f7f905 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -1,19 +1,23 @@ -//! Multipart payload support +//! Multipart response payload support. -use std::cell::{Cell, RefCell, RefMut}; -use std::convert::TryFrom; -use std::marker::PhantomData; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{cmp, fmt}; +use std::{ + cell::{Cell, RefCell, RefMut}, + cmp, + convert::TryFrom, + fmt, + marker::PhantomData, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; +use actix_web::{ + error::{ParseError, PayloadError}, + http::header::{self, ContentDisposition, HeaderMap, HeaderName, HeaderValue}, +}; use bytes::{Bytes, BytesMut}; -use futures_util::stream::{LocalBoxStream, Stream, StreamExt}; - -use actix_utils::task::LocalWaker; -use actix_web::error::{ParseError, PayloadError}; -use actix_web::http::header::{self, ContentDisposition, HeaderMap, HeaderName, HeaderValue}; +use futures_core::stream::{LocalBoxStream, Stream}; +use local_waker::LocalWaker; use crate::error::MultipartError; @@ -28,7 +32,7 @@ const MAX_HEADERS: usize = 32; pub struct Multipart { safety: Safety, error: Option, - inner: Option>>, + inner: Option, } enum InnerMultipartItem { @@ -40,10 +44,13 @@ enum InnerMultipartItem { enum InnerState { /// Stream eof Eof, + /// Skip data until first boundary FirstBoundary, + /// Reading boundary Boundary, + /// Reading Headers, Headers, } @@ -59,7 +66,7 @@ impl Multipart { /// Create multipart instance for boundary. pub fn new(headers: &HeaderMap, stream: S) -> Multipart where - S: Stream> + Unpin + 'static, + S: Stream> + 'static, { match Self::boundary(headers) { Ok(boundary) => Multipart::from_boundary(boundary, stream), @@ -69,39 +76,32 @@ impl Multipart { /// Extract boundary info from headers. pub(crate) fn boundary(headers: &HeaderMap) -> Result { - if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - if let Some(boundary) = ct.get_param(mime::BOUNDARY) { - Ok(boundary.as_str().to_owned()) - } else { - Err(MultipartError::Boundary) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::ParseContentType) - } - } else { - Err(MultipartError::NoContentType) - } + headers + .get(&header::CONTENT_TYPE) + .ok_or(MultipartError::NoContentType)? + .to_str() + .ok() + .and_then(|content_type| content_type.parse::().ok()) + .ok_or(MultipartError::ParseContentType)? + .get_param(mime::BOUNDARY) + .map(|boundary| boundary.as_str().to_owned()) + .ok_or(MultipartError::Boundary) } /// Create multipart instance for given boundary and stream pub(crate) fn from_boundary(boundary: String, stream: S) -> Multipart where - S: Stream> + Unpin + 'static, + S: Stream> + 'static, { Multipart { error: None, safety: Safety::new(), - inner: Some(Rc::new(RefCell::new(InnerMultipart { + inner: Some(InnerMultipart { boundary, - payload: PayloadRef::new(PayloadBuffer::new(Box::new(stream))), + payload: PayloadRef::new(PayloadBuffer::new(stream)), state: InnerState::FirstBoundary, item: InnerMultipartItem::None, - }))), + }), } } @@ -118,20 +118,27 @@ impl Multipart { impl Stream for Multipart { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(err) = self.error.take() { - Poll::Ready(Some(Err(err))) - } else if self.safety.current() { - let this = self.get_mut(); - let mut inner = this.inner.as_mut().unwrap().borrow_mut(); - if let Some(mut payload) = inner.payload.get_mut(&this.safety) { - payload.poll_stream(cx)?; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + match this.inner.as_mut() { + Some(inner) => { + if let Some(mut buffer) = inner.payload.get_mut(&this.safety) { + // check safety and poll read payload to buffer. + buffer.poll_stream(cx)?; + } else if !this.safety.is_clean() { + // safety violation + return Poll::Ready(Some(Err(MultipartError::NotConsumed))); + } else { + return Poll::Pending; + } + + inner.poll(&this.safety, cx) } - inner.poll(&this.safety, cx) - } else if !self.safety.is_clean() { - Poll::Ready(Some(Err(MultipartError::NotConsumed))) - } else { - Poll::Pending + None => Poll::Ready(Some(Err(this + .error + .take() + .expect("Multipart polled after finish")))), } } } @@ -152,17 +159,15 @@ impl InnerMultipart { Ok(httparse::Status::Complete((_, hdrs))) => { // convert headers let mut headers = HeaderMap::with_capacity(hdrs.len()); + for h in hdrs { - if let Ok(name) = HeaderName::try_from(h.name) { - if let Ok(value) = HeaderValue::try_from(h.value) { - headers.append(name, value); - } else { - return Err(ParseError::Header.into()); - } - } else { - return Err(ParseError::Header.into()); - } + let name = + HeaderName::try_from(h.name).map_err(|_| ParseError::Header)?; + let value = HeaderValue::try_from(h.value) + .map_err(|_| ParseError::Header)?; + headers.append(name, value); } + Ok(Some(headers)) } Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), @@ -332,31 +337,55 @@ impl InnerMultipart { return Poll::Pending; }; - // content type - let mut mt = mime::APPLICATION_OCTET_STREAM; - if let Some(content_type) = headers.get(&header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if let Ok(ct) = content_type.parse::() { - mt = ct; - } - } - } + // According to RFC 7578 §4.2, a Content-Disposition header must always be present and + // set to "form-data". + + let content_disposition = headers + .get(&header::CONTENT_DISPOSITION) + .and_then(|cd| ContentDisposition::from_raw(cd).ok()) + .filter(|content_disposition| { + let is_form_data = + content_disposition.disposition == header::DispositionType::FormData; + + let has_field_name = content_disposition + .parameters + .iter() + .any(|param| matches!(param, header::DispositionParam::Name(_))); + + is_form_data && has_field_name + }); + + let cd = if let Some(content_disposition) = content_disposition { + content_disposition + } else { + return Poll::Ready(Some(Err(MultipartError::NoContentDisposition))); + }; + + let ct: mime::Mime = headers + .get(&header::CONTENT_TYPE) + .and_then(|ct| ct.to_str().ok()) + .and_then(|ct| ct.parse().ok()) + .unwrap_or(mime::APPLICATION_OCTET_STREAM); self.state = InnerState::Boundary; - // nested multipart stream - if mt.type_() == mime::MULTIPART { - Poll::Ready(Some(Err(MultipartError::Nested))) - } else { - let field = Rc::new(RefCell::new(InnerField::new( - self.payload.clone(), - self.boundary.clone(), - &headers, - )?)); - self.item = InnerMultipartItem::Field(Rc::clone(&field)); - - Poll::Ready(Some(Ok(Field::new(safety.clone(cx), headers, mt, field)))) + // nested multipart stream is not supported + if ct.type_() == mime::MULTIPART { + return Poll::Ready(Some(Err(MultipartError::Nested))); } + + let field = + InnerField::new_in_rc(self.payload.clone(), self.boundary.clone(), &headers)?; + + self.item = InnerMultipartItem::Field(Rc::clone(&field)); + + Poll::Ready(Some(Ok(Field::new( + safety.clone(cx), + headers, + ct, + cd, + field, + )))) } } } @@ -371,6 +400,7 @@ impl Drop for InnerMultipart { /// A single field in a multipart stream pub struct Field { ct: mime::Mime, + cd: ContentDisposition, headers: HeaderMap, inner: Rc>, safety: Safety, @@ -381,35 +411,52 @@ impl Field { safety: Safety, headers: HeaderMap, ct: mime::Mime, + cd: ContentDisposition, inner: Rc>, ) -> Self { Field { ct, + cd, headers, inner, safety, } } - /// Get a map of headers + /// Returns a reference to the field's header map. pub fn headers(&self) -> &HeaderMap { &self.headers } - /// Get the content type of the field + /// Returns a reference to the field's content (mime) type. pub fn content_type(&self) -> &mime::Mime { &self.ct } - /// Get the content disposition of the field, if it exists - pub fn content_disposition(&self) -> Option { - // RFC 7578: 'Each part MUST contain a Content-Disposition header field - // where the disposition type is "form-data".' - if let Some(content_disposition) = self.headers.get(&header::CONTENT_DISPOSITION) { - ContentDisposition::from_raw(content_disposition).ok() - } else { - None - } + /// Returns the field's Content-Disposition. + /// + /// Per [RFC 7578 §4.2]: "Each part MUST contain a Content-Disposition header field where the + /// disposition type is `form-data`. The Content-Disposition header field MUST also contain an + /// additional parameter of `name`; the value of the `name` parameter is the original field name + /// from the form." + /// + /// This crate validates that it exists before returning a `Field`. As such, it is safe to + /// unwrap `.content_disposition().get_name()`. The [name](Self::name) method is provided as + /// a convenience. + /// + /// [RFC 7578 §4.2]: https://datatracker.ietf.org/doc/html/rfc7578#section-4.2 + pub fn content_disposition(&self) -> &ContentDisposition { + &self.cd + } + + /// Returns the field's name. + /// + /// See [content_disposition](Self::content_disposition) regarding guarantees about existence of + /// the name field. + pub fn name(&self) -> &str { + self.content_disposition() + .get_name() + .expect("field name should be guaranteed to exist in multipart form-data") } } @@ -417,17 +464,19 @@ impl Stream for Field { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.safety.current() { - let mut inner = self.inner.borrow_mut(); - if let Some(mut payload) = inner.payload.as_ref().unwrap().get_mut(&self.safety) { - payload.poll_stream(cx)?; - } - inner.poll(&self.safety) - } else if !self.safety.is_clean() { - Poll::Ready(Some(Err(MultipartError::NotConsumed))) + let this = self.get_mut(); + let mut inner = this.inner.borrow_mut(); + if let Some(mut buffer) = inner.payload.as_ref().unwrap().get_mut(&this.safety) { + // check safety and poll read payload to buffer. + buffer.poll_stream(cx)?; + } else if !this.safety.is_clean() { + // safety violation + return Poll::Ready(Some(Err(MultipartError::NotConsumed))); } else { - Poll::Pending + return Poll::Pending; } + + inner.poll(&this.safety) } } @@ -451,20 +500,23 @@ struct InnerField { } impl InnerField { + fn new_in_rc( + payload: PayloadRef, + boundary: String, + headers: &HeaderMap, + ) -> Result>, PayloadError> { + Self::new(payload, boundary, headers).map(|this| Rc::new(RefCell::new(this))) + } + fn new( payload: PayloadRef, boundary: String, headers: &HeaderMap, ) -> Result { let len = if let Some(len) = headers.get(&header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - Some(len) - } else { - return Err(PayloadError::Incomplete(None)); - } - } else { - return Err(PayloadError::Incomplete(None)); + match len.to_str().ok().and_then(|len| len.parse::().ok()) { + Some(len) => Some(len), + None => return Err(PayloadError::Incomplete(None)), } } else { None @@ -638,10 +690,7 @@ impl PayloadRef { } } - fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option> - where - 'a: 'b, - { + fn get_mut(&self, s: &Safety) -> Option> { if s.current() { Some(self.payload.borrow_mut()) } else { @@ -658,9 +707,11 @@ impl Clone for PayloadRef { } } -/// Counter. It tracks of number of clones of payloads and give access to -/// payload only to top most task panics if Safety get destroyed and it not top -/// most task. +/// Counter. It tracks of number of clones of payloads and give access to payload only to top most. +/// * When dropped, parent task is awakened. This is to support the case where Field is +/// dropped in a separate task than Multipart. +/// * Assumes that parent owners don't move to different tasks; only the top-most is allowed to. +/// * If dropped and is not top most owner, is_clean flag is set to false. #[derive(Debug)] struct Safety { task: LocalWaker, @@ -703,15 +754,16 @@ impl Safety { impl Drop for Safety { fn drop(&mut self) { - // parent task is dead if Rc::strong_count(&self.payload) != self.level { - self.clean.set(true); + // Multipart dropped leaving a Field + self.clean.set(false); } + self.task.wake(); } } -/// Payload buffer +/// Payload buffer. struct PayloadBuffer { eof: bool, buf: BytesMut, @@ -719,7 +771,7 @@ struct PayloadBuffer { } impl PayloadBuffer { - /// Create new `PayloadBuffer` instance + /// Constructs new `PayloadBuffer` instance. fn new(stream: S) -> Self where S: Stream> + 'static, @@ -727,7 +779,7 @@ impl PayloadBuffer { PayloadBuffer { eof: false, buf: BytesMut::new(), - stream: stream.boxed_local(), + stream: Box::pin(stream), } } @@ -767,7 +819,7 @@ impl PayloadBuffer { } /// Read until specified ending - pub fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { + fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { let res = twoway::find_bytes(&self.buf, line) .map(|idx| self.buf.split_to(idx + line.len()).freeze()); @@ -779,12 +831,12 @@ impl PayloadBuffer { } /// Read bytes until new line delimiter - pub fn readline(&mut self) -> Result, MultipartError> { + fn readline(&mut self) -> Result, MultipartError> { self.read_until(b"\n") } /// Read bytes until new line delimiter or eof - pub fn readline_or_eof(&mut self) -> Result, MultipartError> { + fn readline_or_eof(&mut self) -> Result, MultipartError> { match self.readline() { Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), line => line, @@ -792,7 +844,7 @@ impl PayloadBuffer { } /// Put unprocessed data back to the buffer - pub fn unprocessed(&mut self, data: Bytes) { + fn unprocessed(&mut self, data: Bytes) { let buf = BytesMut::from(data.as_ref()); let buf = std::mem::replace(&mut self.buf, buf); self.buf.extend_from_slice(&buf); @@ -804,12 +856,15 @@ mod tests { use super::*; use actix_http::h1::Payload; - use actix_utils::mpsc; use actix_web::http::header::{DispositionParam, DispositionType}; + use actix_web::rt; use actix_web::test::TestRequest; use actix_web::FromRequest; use bytes::Bytes; - use futures_util::future::lazy; + use futures_util::{future::lazy, StreamExt}; + use std::time::Duration; + use tokio::sync::mpsc; + use tokio_stream::wrappers::UnboundedReceiverStream; #[actix_rt::test] async fn test_boundary() { @@ -855,13 +910,17 @@ mod tests { } fn create_stream() -> ( - mpsc::Sender>, + mpsc::UnboundedSender>, impl Stream>, ) { - let (tx, rx) = mpsc::channel(); + let (tx, rx) = mpsc::unbounded_channel(); - (tx, rx.map(|res| res.map_err(|_| panic!()))) + ( + tx, + UnboundedReceiverStream::new(rx).map(|res| res.map_err(|_| panic!())), + ) } + // Stream that returns from a Bytes, one char at a time and Pending every other poll() struct SlowStream { bytes: Bytes, @@ -889,9 +948,11 @@ mod tests { cx.waker().wake_by_ref(); return Poll::Pending; } + if this.pos == this.bytes.len() { return Poll::Ready(None); } + let res = Poll::Ready(Some(Ok(this.bytes.slice(this.pos..(this.pos + 1))))); this.pos += 1; this.ready = false; @@ -907,6 +968,7 @@ mod tests { Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ test\r\n\ --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\ Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ data\r\n\ --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n", @@ -958,7 +1020,7 @@ mod tests { let mut multipart = Multipart::new(&headers, payload); match multipart.next().await { Some(Ok(mut field)) => { - let cd = field.content_disposition().unwrap(); + let cd = field.content_disposition(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -1020,7 +1082,7 @@ mod tests { let mut multipart = Multipart::new(&headers, payload); match multipart.next().await.unwrap() { Ok(mut field) => { - let cd = field.content_disposition().unwrap(); + let cd = field.content_disposition(); assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!(cd.parameters[0], DispositionParam::Name("file".into())); @@ -1171,8 +1233,103 @@ mod tests { // and should not consume the payload match payload { - actix_web::dev::Payload::H1(_) => {} //expected + actix_web::dev::Payload::H1 { .. } => {} //expected _ => unreachable!(), } } + + #[actix_rt::test] + async fn no_content_disposition() { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + let payload = SlowStream::new(bytes); + + let mut multipart = Multipart::new(&headers, payload); + let res = multipart.next().await.unwrap(); + assert!(res.is_err()); + assert!(matches!( + res.unwrap_err(), + MultipartError::NoContentDisposition, + )); + } + + #[actix_rt::test] + async fn no_name_in_content_disposition() { + let bytes = Bytes::from( + "testasdadsad\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\ + Content-Disposition: form-data; filename=\"fn.txt\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\ + test\r\n\ + --abbc761f78ff4d7cb7573b5a23f96ef0\r\n", + ); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static( + "multipart/mixed; boundary=\"abbc761f78ff4d7cb7573b5a23f96ef0\"", + ), + ); + let payload = SlowStream::new(bytes); + + let mut multipart = Multipart::new(&headers, payload); + let res = multipart.next().await.unwrap(); + assert!(res.is_err()); + assert!(matches!( + res.unwrap_err(), + MultipartError::NoContentDisposition, + )); + } + + #[actix_rt::test] + async fn test_drop_multipart_dont_hang() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + sender.send(Ok(bytes)).unwrap(); + drop(sender); // eof + + let mut multipart = Multipart::new(&headers, payload); + let mut field = multipart.next().await.unwrap().unwrap(); + + drop(multipart); + + // should fail immediately + match field.next().await { + Some(Err(MultipartError::NotConsumed)) => {} + _ => panic!(), + }; + } + + #[actix_rt::test] + async fn test_drop_field_awaken_multipart() { + let (sender, payload) = create_stream(); + let (bytes, headers) = create_simple_request_with_header(); + sender.send(Ok(bytes)).unwrap(); + drop(sender); // eof + + let mut multipart = Multipart::new(&headers, payload); + let mut field = multipart.next().await.unwrap().unwrap(); + + let task = rt::spawn(async move { + rt::time::sleep(Duration::from_secs(1)).await; + assert_eq!(field.next().await.unwrap().unwrap(), "test"); + drop(field); + }); + + // dropping field should awaken current task + let _ = multipart.next().await.unwrap().unwrap(); + task.await.unwrap(); + } } diff --git a/actix-router/CHANGES.md b/actix-router/CHANGES.md new file mode 100644 index 000000000..3034ed794 --- /dev/null +++ b/actix-router/CHANGES.md @@ -0,0 +1,146 @@ +# Changes + +## Unreleased - 2021-xx-xx +- `Resource` trait now have an associated type, `Path`, instead of the generic parameter. [#2568] +- `Resource` is now implemented for `&mut Path<_>` and `RefMut>`. [#2568] + +[#2568]: https://github.com/actix/actix-web/pull/2568 + + +## 0.5.0-beta.4 - 2022-01-04 +- `PathDeserializer` now decodes all percent encoded characters in dynamic segments. [#2566] +- Minimum supported Rust version (MSRV) is now 1.54. + +[#2566]: https://github.com/actix/actix-net/pull/2566 + + +## 0.5.0-beta.3 - 2021-12-17 +- Minimum supported Rust version (MSRV) is now 1.52. + + +## 0.5.0-beta.2 - 2021-09-09 +- Introduce `ResourceDef::join`. [#380] +- Disallow prefix routes with tail segments. [#379] +- Enforce path separators on dynamic prefixes. [#378] +- Improve malformed path error message. [#384] +- Prefix segments now always end with with a segment delimiter or end-of-input. [#2355] +- Prefix segments with trailing slashes define a trailing empty segment. [#2355] +- Support multi-pattern prefixes and joins. [#2356] +- `ResourceDef::pattern` now returns the first pattern in multi-pattern resources. [#2356] +- Support `build_resource_path` on multi-pattern resources. [#2356] +- Minimum supported Rust version (MSRV) is now 1.51. + +[#378]: https://github.com/actix/actix-net/pull/378 +[#379]: https://github.com/actix/actix-net/pull/379 +[#380]: https://github.com/actix/actix-net/pull/380 +[#384]: https://github.com/actix/actix-net/pull/384 +[#2355]: https://github.com/actix/actix-web/pull/2355 +[#2356]: https://github.com/actix/actix-web/pull/2356 + + +## 0.5.0-beta.1 - 2021-07-20 +- Fix a bug in multi-patterns where static patterns are interpreted as regex. [#366] +- Introduce `ResourceDef::pattern_iter` to get an iterator over all patterns in a multi-pattern resource. [#373] +- Fix segment interpolation leaving `Path` in unintended state after matching. [#368] +- Fix `ResourceDef` `PartialEq` implementation. [#373] +- Re-work `IntoPatterns` trait, adding a `Patterns` enum. [#372] +- Implement `IntoPatterns` for `bytestring::ByteString`. [#372] +- Rename `Path::{len => segment_count}` to be more descriptive of it's purpose. [#370] +- Rename `ResourceDef::{resource_path => resource_path_from_iter}`. [#371] +- `ResourceDef::resource_path_from_iter` now takes an `IntoIterator`. [#373] +- Rename `ResourceDef::{resource_path_named => resource_path_from_map}`. [#371] +- Rename `ResourceDef::{is_prefix_match => find_match}`. [#373] +- Rename `ResourceDef::{match_path => capture_match_info}`. [#373] +- Rename `ResourceDef::{match_path_checked => capture_match_info_fn}`. [#373] +- Remove `ResourceDef::name_mut` and introduce `ResourceDef::set_name`. [#373] +- Rename `Router::{*_checked => *_fn}`. [#373] +- Return type of `ResourceDef::name` is now `Option<&str>`. [#373] +- Return type of `ResourceDef::pattern` is now `Option<&str>`. [#373] + +[#368]: https://github.com/actix/actix-net/pull/368 +[#366]: https://github.com/actix/actix-net/pull/366 +[#368]: https://github.com/actix/actix-net/pull/368 +[#370]: https://github.com/actix/actix-net/pull/370 +[#371]: https://github.com/actix/actix-net/pull/371 +[#372]: https://github.com/actix/actix-net/pull/372 +[#373]: https://github.com/actix/actix-net/pull/373 + + +## 0.4.0 - 2021-06-06 +- When matching path parameters, `%25` is now kept in the percent-encoded form; no longer decoded to `%`. [#357] +- Path tail patterns now match new lines (`\n`) in request URL. [#360] +- Fixed a safety bug where `Path` could return a malformed string after percent decoding. [#359] +- Methods `Path::{add, add_static}` now take `impl Into>`. [#345] + +[#345]: https://github.com/actix/actix-net/pull/345 +[#357]: https://github.com/actix/actix-net/pull/357 +[#359]: https://github.com/actix/actix-net/pull/359 +[#360]: https://github.com/actix/actix-net/pull/360 + + +## 0.3.0 - 2019-12-31 +- Version was yanked previously. See https://crates.io/crates/actix-router/0.3.0 + + +## 0.2.7 - 2021-02-06 +- Add `Router::recognize_checked` [#247] + +[#247]: https://github.com/actix/actix-net/pull/247 + + +## 0.2.6 - 2021-01-09 +- Use `bytestring` version range compatible with Bytes v1.0. [#246] + +[#246]: https://github.com/actix/actix-net/pull/246 + + +## 0.2.5 - 2020-09-20 +- Fix `from_hex()` method + + +## 0.2.4 - 2019-12-31 +- Add `ResourceDef::resource_path_named()` path generation method + + +## 0.2.3 - 2019-12-25 +- Add impl `IntoPattern` for `&String` + + +## 0.2.2 - 2019-12-25 +- Use `IntoPattern` for `RouterBuilder::path()` + + +## 0.2.1 - 2019-12-25 +- Add `IntoPattern` trait +- Add multi-pattern resources + + +## 0.2.0 - 2019-12-07 +- Update http to 0.2 +- Update regex to 1.3 +- Use bytestring instead of string + + +## 0.1.5 - 2019-05-15 +- Remove debug prints + + +## 0.1.4 - 2019-05-15 +- Fix checked resource match + + +## 0.1.3 - 2019-04-22 +- Added support for `remainder match` (i.e "/path/{tail}*") + + +## 0.1.2 - 2019-04-07 +- Export `Quoter` type +- Allow to reset `Path` instance + + +## 0.1.1 - 2019-04-03 +- Get dynamic segment by name instead of iterator. + + +## 0.1.0 - 2019-03-09 +- Initial release diff --git a/actix-router/Cargo.toml b/actix-router/Cargo.toml new file mode 100644 index 000000000..801613568 --- /dev/null +++ b/actix-router/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "actix-router" +version = "0.5.0-beta.4" +authors = [ + "Nikolay Kim ", + "Ali MJ Al-Nasrawy ", + "Rob Ede ", +] +description = "Resource path matching and router" +keywords = ["actix", "router", "routing"] +repository = "https://github.com/actix/actix-web.git" +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +name = "actix_router" +path = "src/lib.rs" + +[features] +default = ["http"] + +[dependencies] +bytestring = ">=0.1.5, <2" +firestorm = "0.5" +http = { version = "0.2.3", optional = true } +log = "0.4" +regex = "1.5" +serde = "1" + +[dev-dependencies] +criterion = { version = "0.3", features = ["html_reports"] } +firestorm = { version = "0.5", features = ["enable_system_time"] } +http = "0.2.5" +serde = { version = "1", features = ["derive"] } + +[[bench]] +name = "router" +harness = false diff --git a/actix-router/LICENSE-APACHE b/actix-router/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-router/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-router/LICENSE-MIT b/actix-router/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-router/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-router/benches/router.rs b/actix-router/benches/router.rs new file mode 100644 index 000000000..a428b9f13 --- /dev/null +++ b/actix-router/benches/router.rs @@ -0,0 +1,194 @@ +//! Based on https://github.com/ibraheemdev/matchit/blob/master/benches/bench.rs + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +macro_rules! register { + (colon) => {{ + register!(finish => ":p1", ":p2", ":p3", ":p4") + }}; + (brackets) => {{ + register!(finish => "{p1}", "{p2}", "{p3}", "{p4}") + }}; + (regex) => {{ + register!(finish => "(.*)", "(.*)", "(.*)", "(.*)") + }}; + (finish => $p1:literal, $p2:literal, $p3:literal, $p4:literal) => {{ + let arr = [ + concat!("/authorizations"), + concat!("/authorizations/", $p1), + concat!("/applications/", $p1, "/tokens/", $p2), + concat!("/events"), + concat!("/repos/", $p1, "/", $p2, "/events"), + concat!("/networks/", $p1, "/", $p2, "/events"), + concat!("/orgs/", $p1, "/events"), + concat!("/users/", $p1, "/received_events"), + concat!("/users/", $p1, "/received_events/public"), + concat!("/users/", $p1, "/events"), + concat!("/users/", $p1, "/events/public"), + concat!("/users/", $p1, "/events/orgs/", $p2), + concat!("/feeds"), + concat!("/notifications"), + concat!("/repos/", $p1, "/", $p2, "/notifications"), + concat!("/notifications/threads/", $p1), + concat!("/notifications/threads/", $p1, "/subscription"), + concat!("/repos/", $p1, "/", $p2, "/stargazers"), + concat!("/users/", $p1, "/starred"), + concat!("/user/starred"), + concat!("/user/starred/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/subscribers"), + concat!("/users/", $p1, "/subscriptions"), + concat!("/user/subscriptions"), + concat!("/repos/", $p1, "/", $p2, "/subscription"), + concat!("/user/subscriptions/", $p1, "/", $p2), + concat!("/users/", $p1, "/gists"), + concat!("/gists"), + concat!("/gists/", $p1), + concat!("/gists/", $p1, "/star"), + concat!("/repos/", $p1, "/", $p2, "/git/blobs/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/refs"), + concat!("/repos/", $p1, "/", $p2, "/git/tags/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/trees/", $p3), + concat!("/issues"), + concat!("/user/issues"), + concat!("/orgs/", $p1, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3), + concat!("/repos/", $p1, "/", $p2, "/assignees"), + concat!("/repos/", $p1, "/", $p2, "/assignees/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/events"), + concat!("/repos/", $p1, "/", $p2, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/labels/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3), + concat!("/emojis"), + concat!("/gitignore/templates"), + concat!("/gitignore/templates/", $p1), + concat!("/meta"), + concat!("/rate_limit"), + concat!("/users/", $p1, "/orgs"), + concat!("/user/orgs"), + concat!("/orgs/", $p1), + concat!("/orgs/", $p1, "/members"), + concat!("/orgs/", $p1, "/members", $p2), + concat!("/orgs/", $p1, "/public_members"), + concat!("/orgs/", $p1, "/public_members/", $p2), + concat!("/orgs/", $p1, "/teams"), + concat!("/teams/", $p1), + concat!("/teams/", $p1, "/members"), + concat!("/teams/", $p1, "/members", $p2), + concat!("/teams/", $p1, "/repos"), + concat!("/teams/", $p1, "/repos/", $p2, "/", $p3), + concat!("/user/teams"), + concat!("/repos/", $p1, "/", $p2, "/pulls"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/files"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/merge"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/comments"), + concat!("/user/repos"), + concat!("/users/", $p1, "/repos"), + concat!("/orgs/", $p1, "/repos"), + concat!("/repositories"), + concat!("/repos/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/contributors"), + concat!("/repos/", $p1, "/", $p2, "/languages"), + concat!("/repos/", $p1, "/", $p2, "/teams"), + concat!("/repos/", $p1, "/", $p2, "/tags"), + concat!("/repos/", $p1, "/", $p2, "/branches"), + concat!("/repos/", $p1, "/", $p2, "/branches/", $p3), + concat!("/repos/", $p1, "/", $p2, "/collaborators"), + concat!("/repos/", $p1, "/", $p2, "/collaborators/", $p3), + concat!("/repos/", $p1, "/", $p2, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/readme"), + concat!("/repos/", $p1, "/", $p2, "/keys"), + concat!("/repos/", $p1, "/", $p2, "/keys", $p3), + concat!("/repos/", $p1, "/", $p2, "/downloads"), + concat!("/repos/", $p1, "/", $p2, "/downloads", $p3), + concat!("/repos/", $p1, "/", $p2, "/forks"), + concat!("/repos/", $p1, "/", $p2, "/hooks"), + concat!("/repos/", $p1, "/", $p2, "/hooks", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases"), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3, "/assets"), + concat!("/repos/", $p1, "/", $p2, "/stats/contributors"), + concat!("/repos/", $p1, "/", $p2, "/stats/commit_activity"), + concat!("/repos/", $p1, "/", $p2, "/stats/code_frequency"), + concat!("/repos/", $p1, "/", $p2, "/stats/participation"), + concat!("/repos/", $p1, "/", $p2, "/stats/punch_card"), + concat!("/repos/", $p1, "/", $p2, "/statuses/", $p3), + concat!("/search/repositories"), + concat!("/search/code"), + concat!("/search/issues"), + concat!("/search/users"), + concat!("/legacy/issues/search/", $p1, "/", $p2, "/", $p3, "/", $p4), + concat!("/legacy/repos/search/", $p1), + concat!("/legacy/user/search/", $p1), + concat!("/legacy/user/email/", $p1), + concat!("/users/", $p1), + concat!("/user"), + concat!("/users"), + concat!("/user/emails"), + concat!("/users/", $p1, "/followers"), + concat!("/user/followers"), + concat!("/users/", $p1, "/following"), + concat!("/user/following"), + concat!("/user/following/", $p1), + concat!("/users/", $p1, "/following", $p2), + concat!("/users/", $p1, "/keys"), + concat!("/user/keys"), + concat!("/user/keys/", $p1), + ]; + std::array::IntoIter::new(arr) + }}; +} + +fn call() -> impl Iterator { + let arr = [ + "/authorizations", + "/user/repos", + "/repos/rust-lang/rust/stargazers", + "/orgs/rust-lang/public_members/nikomatsakis", + "/repos/rust-lang/rust/releases/1.51.0", + ]; + + std::array::IntoIter::new(arr) +} + +fn compare_routers(c: &mut Criterion) { + let mut group = c.benchmark_group("Compare Routers"); + + let mut actix = actix_router::Router::::build(); + for route in register!(brackets) { + actix.path(route, true); + } + let actix = actix.finish(); + group.bench_function("actix", |b| { + b.iter(|| { + for route in call() { + let mut path = actix_router::Path::new(route); + black_box(actix.recognize(&mut path).unwrap()); + } + }); + }); + + let regex_set = regex::RegexSet::new(register!(regex)).unwrap(); + group.bench_function("regex", |b| { + b.iter(|| { + for route in call() { + black_box(regex_set.matches(route)); + } + }); + }); + + group.finish(); +} + +criterion_group!(benches, compare_routers); +criterion_main!(benches); diff --git a/actix-router/examples/flamegraph.rs b/actix-router/examples/flamegraph.rs new file mode 100644 index 000000000..798cc22d9 --- /dev/null +++ b/actix-router/examples/flamegraph.rs @@ -0,0 +1,169 @@ +macro_rules! register { + (brackets) => {{ + register!(finish => "{p1}", "{p2}", "{p3}", "{p4}") + }}; + (finish => $p1:literal, $p2:literal, $p3:literal, $p4:literal) => {{ + let arr = [ + concat!("/authorizations"), + concat!("/authorizations/", $p1), + concat!("/applications/", $p1, "/tokens/", $p2), + concat!("/events"), + concat!("/repos/", $p1, "/", $p2, "/events"), + concat!("/networks/", $p1, "/", $p2, "/events"), + concat!("/orgs/", $p1, "/events"), + concat!("/users/", $p1, "/received_events"), + concat!("/users/", $p1, "/received_events/public"), + concat!("/users/", $p1, "/events"), + concat!("/users/", $p1, "/events/public"), + concat!("/users/", $p1, "/events/orgs/", $p2), + concat!("/feeds"), + concat!("/notifications"), + concat!("/repos/", $p1, "/", $p2, "/notifications"), + concat!("/notifications/threads/", $p1), + concat!("/notifications/threads/", $p1, "/subscription"), + concat!("/repos/", $p1, "/", $p2, "/stargazers"), + concat!("/users/", $p1, "/starred"), + concat!("/user/starred"), + concat!("/user/starred/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/subscribers"), + concat!("/users/", $p1, "/subscriptions"), + concat!("/user/subscriptions"), + concat!("/repos/", $p1, "/", $p2, "/subscription"), + concat!("/user/subscriptions/", $p1, "/", $p2), + concat!("/users/", $p1, "/gists"), + concat!("/gists"), + concat!("/gists/", $p1), + concat!("/gists/", $p1, "/star"), + concat!("/repos/", $p1, "/", $p2, "/git/blobs/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/refs"), + concat!("/repos/", $p1, "/", $p2, "/git/tags/", $p3), + concat!("/repos/", $p1, "/", $p2, "/git/trees/", $p3), + concat!("/issues"), + concat!("/user/issues"), + concat!("/orgs/", $p1, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3), + concat!("/repos/", $p1, "/", $p2, "/assignees"), + concat!("/repos/", $p1, "/", $p2, "/assignees/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/events"), + concat!("/repos/", $p1, "/", $p2, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/labels/", $p3), + concat!("/repos/", $p1, "/", $p2, "/issues/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3, "/labels"), + concat!("/repos/", $p1, "/", $p2, "/milestones/"), + concat!("/repos/", $p1, "/", $p2, "/milestones/", $p3), + concat!("/emojis"), + concat!("/gitignore/templates"), + concat!("/gitignore/templates/", $p1), + concat!("/meta"), + concat!("/rate_limit"), + concat!("/users/", $p1, "/orgs"), + concat!("/user/orgs"), + concat!("/orgs/", $p1), + concat!("/orgs/", $p1, "/members"), + concat!("/orgs/", $p1, "/members", $p2), + concat!("/orgs/", $p1, "/public_members"), + concat!("/orgs/", $p1, "/public_members/", $p2), + concat!("/orgs/", $p1, "/teams"), + concat!("/teams/", $p1), + concat!("/teams/", $p1, "/members"), + concat!("/teams/", $p1, "/members", $p2), + concat!("/teams/", $p1, "/repos"), + concat!("/teams/", $p1, "/repos/", $p2, "/", $p3), + concat!("/user/teams"), + concat!("/repos/", $p1, "/", $p2, "/pulls"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/files"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/merge"), + concat!("/repos/", $p1, "/", $p2, "/pulls/", $p3, "/comments"), + concat!("/user/repos"), + concat!("/users/", $p1, "/repos"), + concat!("/orgs/", $p1, "/repos"), + concat!("/repositories"), + concat!("/repos/", $p1, "/", $p2), + concat!("/repos/", $p1, "/", $p2, "/contributors"), + concat!("/repos/", $p1, "/", $p2, "/languages"), + concat!("/repos/", $p1, "/", $p2, "/teams"), + concat!("/repos/", $p1, "/", $p2, "/tags"), + concat!("/repos/", $p1, "/", $p2, "/branches"), + concat!("/repos/", $p1, "/", $p2, "/branches/", $p3), + concat!("/repos/", $p1, "/", $p2, "/collaborators"), + concat!("/repos/", $p1, "/", $p2, "/collaborators/", $p3), + concat!("/repos/", $p1, "/", $p2, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3, "/comments"), + concat!("/repos/", $p1, "/", $p2, "/commits"), + concat!("/repos/", $p1, "/", $p2, "/commits/", $p3), + concat!("/repos/", $p1, "/", $p2, "/readme"), + concat!("/repos/", $p1, "/", $p2, "/keys"), + concat!("/repos/", $p1, "/", $p2, "/keys", $p3), + concat!("/repos/", $p1, "/", $p2, "/downloads"), + concat!("/repos/", $p1, "/", $p2, "/downloads", $p3), + concat!("/repos/", $p1, "/", $p2, "/forks"), + concat!("/repos/", $p1, "/", $p2, "/hooks"), + concat!("/repos/", $p1, "/", $p2, "/hooks", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases"), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3), + concat!("/repos/", $p1, "/", $p2, "/releases/", $p3, "/assets"), + concat!("/repos/", $p1, "/", $p2, "/stats/contributors"), + concat!("/repos/", $p1, "/", $p2, "/stats/commit_activity"), + concat!("/repos/", $p1, "/", $p2, "/stats/code_frequency"), + concat!("/repos/", $p1, "/", $p2, "/stats/participation"), + concat!("/repos/", $p1, "/", $p2, "/stats/punch_card"), + concat!("/repos/", $p1, "/", $p2, "/statuses/", $p3), + concat!("/search/repositories"), + concat!("/search/code"), + concat!("/search/issues"), + concat!("/search/users"), + concat!("/legacy/issues/search/", $p1, "/", $p2, "/", $p3, "/", $p4), + concat!("/legacy/repos/search/", $p1), + concat!("/legacy/user/search/", $p1), + concat!("/legacy/user/email/", $p1), + concat!("/users/", $p1), + concat!("/user"), + concat!("/users"), + concat!("/user/emails"), + concat!("/users/", $p1, "/followers"), + concat!("/user/followers"), + concat!("/users/", $p1, "/following"), + concat!("/user/following"), + concat!("/user/following/", $p1), + concat!("/users/", $p1, "/following", $p2), + concat!("/users/", $p1, "/keys"), + concat!("/user/keys"), + concat!("/user/keys/", $p1), + ]; + + arr.to_vec() + }}; +} + +static PATHS: [&str; 5] = [ + "/authorizations", + "/user/repos", + "/repos/rust-lang/rust/stargazers", + "/orgs/rust-lang/public_members/nikomatsakis", + "/repos/rust-lang/rust/releases/1.51.0", +]; + +fn main() { + let mut router = actix_router::Router::::build(); + + for route in register!(brackets) { + router.path(route, true); + } + + let actix = router.finish(); + + if firestorm::enabled() { + firestorm::bench("target", || { + for &route in &PATHS { + let mut path = actix_router::Path::new(route); + actix.recognize(&mut path).unwrap(); + } + }) + .unwrap(); + } +} diff --git a/actix-router/src/de.rs b/actix-router/src/de.rs new file mode 100644 index 000000000..27aa49ef2 --- /dev/null +++ b/actix-router/src/de.rs @@ -0,0 +1,805 @@ +use std::borrow::Cow; + +use serde::de::{self, Deserializer, Error as DeError, Visitor}; +use serde::forward_to_deserialize_any; + +use crate::path::{Path, PathIter}; +use crate::{Quoter, ResourcePath}; + +thread_local! { + static FULL_QUOTER: Quoter = Quoter::new(b"+/%", b""); +} + +macro_rules! unsupported_type { + ($trait_fn:ident, $name:expr) => { + fn $trait_fn(self, _: V) -> Result + where + V: Visitor<'de>, + { + Err(de::Error::custom(concat!("unsupported type: ", $name))) + } + }; +} + +macro_rules! parse_single_value { + ($trait_fn:ident) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() != 1 { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected 1", + self.path.segment_count() + ) + .as_str(), + )) + } else { + Value { + value: &self.path[0], + } + .$trait_fn(visitor) + } + } + }; +} + +macro_rules! parse_value { + ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let decoded = FULL_QUOTER + .with(|q| q.requote(self.value.as_bytes())) + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed(self.value)); + + let v = decoded.parse().map_err(|_| { + de::value::Error::custom(format!("can not parse {:?} to a {}", self.value, $tp)) + })?; + + visitor.$visit_fn(v) + } + }; +} + +pub struct PathDeserializer<'de, T: ResourcePath> { + path: &'de Path, +} + +impl<'de, T: ResourcePath + 'de> PathDeserializer<'de, T> { + pub fn new(path: &'de Path) -> Self { + PathDeserializer { path } + } +} + +impl<'de, T: ResourcePath + 'de> Deserializer<'de> for PathDeserializer<'de, T> { + type Error = de::value::Error; + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(ParamsDeserializer { + params: self.path.iter(), + current: None, + }) + } + + fn deserialize_struct( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() < len { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected {}", + self.path.segment_count(), + len + ) + .as_str(), + )) + } else { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + } + + fn deserialize_tuple_struct( + self, + _: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.path.segment_count() < len { + Err(de::value::Error::custom( + format!( + "wrong number of parameters: {} expected {}", + self.path.segment_count(), + len + ) + .as_str(), + )) + } else { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + } + + fn deserialize_enum( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.path.is_empty() { + Err(de::value::Error::custom("expected at least one parameters")) + } else { + visitor.visit_enum(ValueEnum { + value: &self.path[0], + }) + } + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(ParamsSeq { + params: self.path.iter(), + }) + } + + unsupported_type!(deserialize_any, "'any'"); + unsupported_type!(deserialize_option, "Option"); + unsupported_type!(deserialize_identifier, "identifier"); + unsupported_type!(deserialize_ignored_any, "ignored_any"); + + parse_single_value!(deserialize_bool); + parse_single_value!(deserialize_i8); + parse_single_value!(deserialize_i16); + parse_single_value!(deserialize_i32); + parse_single_value!(deserialize_i64); + parse_single_value!(deserialize_u8); + parse_single_value!(deserialize_u16); + parse_single_value!(deserialize_u32); + parse_single_value!(deserialize_u64); + parse_single_value!(deserialize_f32); + parse_single_value!(deserialize_f64); + parse_single_value!(deserialize_str); + parse_single_value!(deserialize_string); + parse_single_value!(deserialize_bytes); + parse_single_value!(deserialize_byte_buf); + parse_single_value!(deserialize_char); +} + +struct ParamsDeserializer<'de, T: ResourcePath> { + params: PathIter<'de, T>, + current: Option<(&'de str, &'de str)>, +} + +impl<'de, T: ResourcePath> de::MapAccess<'de> for ParamsDeserializer<'de, T> { + type Error = de::value::Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: de::DeserializeSeed<'de>, + { + self.current = self.params.next().map(|ref item| (item.0, item.1)); + match self.current { + Some((key, _)) => Ok(Some(seed.deserialize(Key { key })?)), + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + if let Some((_, value)) = self.current.take() { + seed.deserialize(Value { value }) + } else { + Err(de::value::Error::custom("unexpected item")) + } + } +} + +struct Key<'de> { + key: &'de str, +} + +impl<'de> Deserializer<'de> for Key<'de> { + type Error = de::value::Error; + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_str(self.key) + } + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("Unexpected")) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum ignored_any + } +} + +struct Value<'de> { + value: &'de str, +} + +impl<'de> Deserializer<'de> for Value<'de> { + type Error = de::value::Error; + + parse_value!(deserialize_bool, visit_bool, "bool"); + parse_value!(deserialize_i8, visit_i8, "i8"); + parse_value!(deserialize_i16, visit_i16, "i16"); + parse_value!(deserialize_i32, visit_i32, "i16"); + parse_value!(deserialize_i64, visit_i64, "i64"); + parse_value!(deserialize_u8, visit_u8, "u8"); + parse_value!(deserialize_u16, visit_u16, "u16"); + parse_value!(deserialize_u32, visit_u32, "u32"); + parse_value!(deserialize_u64, visit_u64, "u64"); + parse_value!(deserialize_f32, visit_f32, "f32"); + parse_value!(deserialize_f64, visit_f64, "f64"); + parse_value!(deserialize_char, visit_char, "char"); + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match FULL_QUOTER.with(|q| q.requote(self.value.as_bytes())) { + Some(s) => visitor.visit_string(s), + None => visitor.visit_borrowed_str(self.value), + } + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match FULL_QUOTER.with(|q| q.requote(self.value.as_bytes())) { + Some(s) => visitor.visit_byte_buf(s.into()), + None => visitor.visit_borrowed_bytes(self.value.as_bytes()), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_enum( + self, + _: &'static str, + _: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(ValueEnum { value: self.value }) + } + + fn deserialize_newtype_struct( + self, + _: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, _: usize, _: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: tuple")) + } + + fn deserialize_struct( + self, + _: &'static str, + _: &'static [&'static str], + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: struct")) + } + + fn deserialize_tuple_struct( + self, + _: &'static str, + _: usize, + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("unsupported type: tuple struct")) + } + + unsupported_type!(deserialize_any, "any"); + unsupported_type!(deserialize_seq, "seq"); + unsupported_type!(deserialize_map, "map"); + unsupported_type!(deserialize_identifier, "identifier"); +} + +struct ParamsSeq<'de, T: ResourcePath> { + params: PathIter<'de, T>, +} + +impl<'de, T: ResourcePath> de::SeqAccess<'de> for ParamsSeq<'de, T> { + type Error = de::value::Error; + + fn next_element_seed(&mut self, seed: U) -> Result, Self::Error> + where + U: de::DeserializeSeed<'de>, + { + match self.params.next() { + Some(item) => Ok(Some(seed.deserialize(Value { value: item.1 })?)), + None => Ok(None), + } + } +} + +struct ValueEnum<'de> { + value: &'de str, +} + +impl<'de> de::EnumAccess<'de> for ValueEnum<'de> { + type Error = de::value::Error; + type Variant = UnitVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + Ok((seed.deserialize(Key { key: self.value })?, UnitVariant)) + } +} + +struct UnitVariant; + +impl<'de> de::VariantAccess<'de> for UnitVariant { + type Error = de::value::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Err(de::value::Error::custom("not supported")) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("not supported")) + } + + fn struct_variant( + self, + _: &'static [&'static str], + _: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(de::value::Error::custom("not supported")) + } +} + +#[cfg(test)] +mod tests { + use serde::{de, Deserialize}; + + use super::*; + use crate::path::Path; + use crate::router::Router; + use crate::ResourceDef; + + #[derive(Deserialize)] + struct MyStruct { + key: String, + value: String, + } + + #[derive(Deserialize)] + struct Id { + _id: String, + } + + #[derive(Debug, Deserialize)] + struct Test1(String, u32); + + #[derive(Debug, Deserialize)] + struct Test2 { + key: String, + value: u32, + } + + #[derive(Debug, Deserialize, PartialEq)] + #[serde(rename_all = "lowercase")] + enum TestEnum { + Val1, + Val2, + } + + #[derive(Debug, Deserialize)] + struct Test3 { + val: TestEnum, + } + + #[test] + fn test_request_extract() { + let mut router = Router::<()>::build(); + router.path("/{key}/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/user1/"); + assert!(router.recognize(&mut path).is_some()); + + let s: MyStruct = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, "user1"); + + let s: (String, String) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, "user1"); + + let mut router = Router::<()>::build(); + router.path("/{key}/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/32/"); + assert!(router.recognize(&mut path).is_some()); + + let s: Test1 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); + + let s: Test2 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.key, "name"); + assert_eq!(s.value, 32); + + let s: (String, u8) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(s.0, "name"); + assert_eq!(s.1, 32); + + let res: Vec = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(res[0], "name".to_owned()); + assert_eq!(res[1], "32".to_owned()); + } + + #[test] + fn test_extract_path_single() { + let mut router = Router::<()>::build(); + router.path("/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/32/"); + assert!(router.recognize(&mut path).is_some()); + let i: i8 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, 32); + } + + #[test] + fn test_extract_enum() { + let mut router = Router::<()>::build(); + router.path("/{val}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/"); + assert!(router.recognize(&mut path).is_some()); + let i: TestEnum = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, TestEnum::Val1); + + let mut router = Router::<()>::build(); + router.path("/{val1}/{val2}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/val2/"); + assert!(router.recognize(&mut path).is_some()); + let i: (TestEnum, TestEnum) = + de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i, (TestEnum::Val1, TestEnum::Val2)); + } + + #[test] + fn test_extract_enum_value() { + let mut router = Router::<()>::build(); + router.path("/{val}/", ()); + let router = router.finish(); + + let mut path = Path::new("/val1/"); + assert!(router.recognize(&mut path).is_some()); + let i: Test3 = de::Deserialize::deserialize(PathDeserializer::new(&path)).unwrap(); + assert_eq!(i.val, TestEnum::Val1); + + let mut path = Path::new("/val3/"); + assert!(router.recognize(&mut path).is_some()); + let i: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(i.is_err()); + assert!(format!("{:?}", i).contains("unknown variant")); + } + + #[test] + fn test_extract_errors() { + let mut router = Router::<()>::build(); + router.path("/{value}/", ()); + let router = router.finish(); + + let mut path = Path::new("/name/"); + assert!(router.recognize(&mut path).is_some()); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("wrong number of parameters")); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("can not parse")); + + let s: Result<(String, String), de::value::Error> = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("wrong number of parameters")); + + let s: Result = + de::Deserialize::deserialize(PathDeserializer::new(&path)); + assert!(s.is_err()); + assert!(format!("{:?}", s).contains("can not parse")); + } + + #[test] + fn deserialize_path_decode_string() { + let rdef = ResourceDef::new("/{key}"); + + let mut path = Path::new("/%25"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let segment: String = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(segment, "%"); + + let mut path = Path::new("/%2F"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let segment: String = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(segment, "/") + } + + #[test] + fn deserialize_path_decode_seq() { + let rdef = ResourceDef::new("/{key}/{value}"); + + let mut path = Path::new("/%30%25/%30%2F"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let segment: (String, String) = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(segment.0, "0%"); + assert_eq!(segment.1, "0/"); + } + + #[test] + fn deserialize_path_decode_map() { + #[derive(Deserialize)] + struct Vals { + key: String, + value: String, + } + + let rdef = ResourceDef::new("/{key}/{value}"); + + let mut path = Path::new("/%25/%2F"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let vals: Vals = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(vals.key, "%"); + assert_eq!(vals.value, "/"); + } + + #[test] + fn deserialize_borrowed() { + #[derive(Debug, Deserialize)] + struct Params<'a> { + val: &'a str, + } + + let rdef = ResourceDef::new("/{val}"); + + let mut path = Path::new("/X"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + let params: Params<'_> = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(params.val, "X"); + let de = PathDeserializer::new(&path); + let params: &str = serde::Deserialize::deserialize(de).unwrap(); + assert_eq!(params, "X"); + + let mut path = Path::new("/%2F"); + rdef.capture_match_info(&mut path); + let de = PathDeserializer::new(&path); + assert!( as serde::Deserialize>::deserialize(de).is_err()); + let de = PathDeserializer::new(&path); + assert!(<&str as serde::Deserialize>::deserialize(de).is_err()); + } + + // #[test] + // fn test_extract_path_decode() { + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + // macro_rules! test_single_value { + // ($value:expr, $expected:expr) => {{ + // let req = TestRequest::with_uri($value).finish(); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + // assert_eq!( + // *Path::::from_request(&req, &PathConfig::default()).unwrap(), + // $expected + // ); + // }}; + // } + + // test_single_value!("/%25/", "%"); + // test_single_value!("/%40%C2%A3%24%25%5E%26%2B%3D/", "@£$%^&+="); + // test_single_value!("/%2B/", "+"); + // test_single_value!("/%252B/", "%2B"); + // test_single_value!("/%2F/", "/"); + // test_single_value!("/%252F/", "%2F"); + // test_single_value!( + // "/http%3A%2F%2Flocalhost%3A80%2Ffoo/", + // "http://localhost:80/foo" + // ); + // test_single_value!("/%2Fvar%2Flog%2Fsyslog/", "/var/log/syslog"); + // test_single_value!( + // "/http%3A%2F%2Flocalhost%3A80%2Ffile%2F%252Fvar%252Flog%252Fsyslog/", + // "http://localhost:80/file/%2Fvar%2Flog%2Fsyslog" + // ); + + // let req = TestRequest::with_uri("/%25/7/?id=test").finish(); + + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + + // let s = Path::::from_request(&req, &PathConfig::default()).unwrap(); + // assert_eq!(s.key, "%"); + // assert_eq!(s.value, 7); + + // let s = Path::<(String, String)>::from_request(&req, &PathConfig::default()).unwrap(); + // assert_eq!(s.0, "%"); + // assert_eq!(s.1, "7"); + // } + + // #[test] + // fn test_extract_path_no_decode() { + // let mut router = Router::<()>::default(); + // router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + // let req = TestRequest::with_uri("/%25/").finish(); + // let info = router.recognize(&req, &(), 0); + // let req = req.with_route_info(info); + // assert_eq!( + // *Path::::from_request(&req, &&PathConfig::default().disable_decoding()) + // .unwrap(), + // "%25" + // ); + // } +} diff --git a/actix-router/src/lib.rs b/actix-router/src/lib.rs new file mode 100644 index 000000000..22f294b9d --- /dev/null +++ b/actix-router/src/lib.rs @@ -0,0 +1,28 @@ +//! Resource path matching and router. + +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] + +mod de; +mod path; +mod pattern; +mod quoter; +mod resource; +mod resource_path; +mod router; + +#[cfg(feature = "http")] +mod url; + +pub use self::de::PathDeserializer; +pub use self::path::Path; +pub use self::pattern::{IntoPatterns, Patterns}; +pub use self::quoter::Quoter; +pub use self::resource::ResourceDef; +pub use self::resource_path::{Resource, ResourcePath}; +pub use self::router::{ResourceInfo, Router, RouterBuilder}; + +#[cfg(feature = "http")] +pub use self::url::Url; diff --git a/actix-router/src/path.rs b/actix-router/src/path.rs new file mode 100644 index 000000000..fc7bb16ac --- /dev/null +++ b/actix-router/src/path.rs @@ -0,0 +1,250 @@ +use std::borrow::Cow; +use std::ops::{DerefMut, Index}; + +use firestorm::profile_method; +use serde::de; + +use crate::{de::PathDeserializer, Resource, ResourcePath}; + +#[derive(Debug, Clone)] +pub(crate) enum PathItem { + Static(Cow<'static, str>), + Segment(u16, u16), +} + +impl Default for PathItem { + fn default() -> Self { + Self::Static(Cow::Borrowed("")) + } +} + +/// Resource path match information. +/// +/// If resource path contains variable patterns, `Path` stores them. +#[derive(Debug, Clone, Default)] +pub struct Path { + path: T, + pub(crate) skip: u16, + pub(crate) segments: Vec<(Cow<'static, str>, PathItem)>, +} + +impl Path { + pub fn new(path: T) -> Path { + Path { + path, + skip: 0, + segments: Vec::new(), + } + } + + /// Get reference to inner path instance. + #[inline] + pub fn get_ref(&self) -> &T { + &self.path + } + + /// Get mutable reference to inner path instance. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + &mut self.path + } + + /// Path. + #[inline] + pub fn path(&self) -> &str { + profile_method!(path); + + let skip = self.skip as usize; + let path = self.path.path(); + if skip <= path.len() { + &path[skip..] + } else { + "" + } + } + + /// Set new path. + #[inline] + pub fn set(&mut self, path: T) { + self.skip = 0; + self.path = path; + self.segments.clear(); + } + + /// Reset state. + #[inline] + pub fn reset(&mut self) { + self.skip = 0; + self.segments.clear(); + } + + /// Skip first `n` chars in path. + #[inline] + pub fn skip(&mut self, n: u16) { + self.skip += n; + } + + pub(crate) fn add(&mut self, name: impl Into>, value: PathItem) { + profile_method!(add); + + match value { + PathItem::Static(s) => self.segments.push((name.into(), PathItem::Static(s))), + PathItem::Segment(begin, end) => self.segments.push(( + name.into(), + PathItem::Segment(self.skip + begin, self.skip + end), + )), + } + } + + #[doc(hidden)] + pub fn add_static( + &mut self, + name: impl Into>, + value: impl Into>, + ) { + self.segments + .push((name.into(), PathItem::Static(value.into()))); + } + + /// Check if there are any matched patterns. + #[inline] + pub fn is_empty(&self) -> bool { + self.segments.is_empty() + } + + /// Returns number of interpolated segments. + #[inline] + pub fn segment_count(&self) -> usize { + self.segments.len() + } + + /// Get matched parameter by name without type conversion + pub fn get(&self, name: &str) -> Option<&str> { + profile_method!(get); + + for (seg_name, val) in self.segments.iter() { + if name == seg_name { + return match val { + PathItem::Static(ref s) => Some(s), + PathItem::Segment(s, e) => { + Some(&self.path.path()[(*s as usize)..(*e as usize)]) + } + }; + } + } + + None + } + + /// Get unprocessed part of the path + pub fn unprocessed(&self) -> &str { + &self.path.path()[(self.skip as usize)..] + } + + /// Get matched parameter by name. + /// + /// If keyed parameter is not available empty string is used as default value. + pub fn query(&self, key: &str) -> &str { + profile_method!(query); + + if let Some(s) = self.get(key) { + s + } else { + "" + } + } + + /// Return iterator to items in parameter container. + pub fn iter(&self) -> PathIter<'_, T> { + PathIter { + idx: 0, + params: self, + } + } + + /// Try to deserialize matching parameters to a specified type `U` + pub fn load<'de, U: serde::Deserialize<'de>>(&'de self) -> Result { + profile_method!(load); + de::Deserialize::deserialize(PathDeserializer::new(self)) + } +} + +#[derive(Debug)] +pub struct PathIter<'a, T> { + idx: usize, + params: &'a Path, +} + +impl<'a, T: ResourcePath> Iterator for PathIter<'a, T> { + type Item = (&'a str, &'a str); + + #[inline] + fn next(&mut self) -> Option<(&'a str, &'a str)> { + if self.idx < self.params.segment_count() { + let idx = self.idx; + let res = match self.params.segments[idx].1 { + PathItem::Static(ref s) => s, + PathItem::Segment(s, e) => &self.params.path.path()[(s as usize)..(e as usize)], + }; + self.idx += 1; + return Some((&self.params.segments[idx].0, res)); + } + None + } +} + +impl<'a, T: ResourcePath> Index<&'a str> for Path { + type Output = str; + + fn index(&self, name: &'a str) -> &str { + self.get(name) + .expect("Value for parameter is not available") + } +} + +impl Index for Path { + type Output = str; + + fn index(&self, idx: usize) -> &str { + match self.segments[idx].1 { + PathItem::Static(ref s) => s, + PathItem::Segment(s, e) => &self.path.path()[(s as usize)..(e as usize)], + } + } +} + +impl Resource for Path { + type Path = T; + + fn resource_path(&mut self) -> &mut Path { + self + } +} + +impl Resource for T +where + T: DerefMut>, + P: ResourcePath, +{ + type Path = P; + + fn resource_path(&mut self) -> &mut Path { + &mut *self + } +} + +#[cfg(test)] +mod tests { + use std::cell::RefCell; + + use super::*; + + #[test] + fn deref_impls() { + let mut foo = Path::new("/foo"); + let _ = (&mut foo).resource_path(); + + let foo = RefCell::new(foo); + let _ = foo.borrow_mut().resource_path(); + } +} diff --git a/actix-router/src/pattern.rs b/actix-router/src/pattern.rs new file mode 100644 index 000000000..78a638a78 --- /dev/null +++ b/actix-router/src/pattern.rs @@ -0,0 +1,92 @@ +/// One or many patterns. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Patterns { + Single(String), + List(Vec), +} + +impl Patterns { + pub fn is_empty(&self) -> bool { + match self { + Patterns::Single(_) => false, + Patterns::List(pats) => pats.is_empty(), + } + } +} + +/// Helper trait for type that could be converted to one or more path patterns. +pub trait IntoPatterns { + fn patterns(&self) -> Patterns; +} + +impl IntoPatterns for String { + fn patterns(&self) -> Patterns { + Patterns::Single(self.clone()) + } +} + +impl IntoPatterns for &String { + fn patterns(&self) -> Patterns { + (*self).patterns() + } +} + +impl IntoPatterns for str { + fn patterns(&self) -> Patterns { + Patterns::Single(self.to_owned()) + } +} + +impl IntoPatterns for &str { + fn patterns(&self) -> Patterns { + (*self).patterns() + } +} + +impl IntoPatterns for bytestring::ByteString { + fn patterns(&self) -> Patterns { + Patterns::Single(self.to_string()) + } +} + +impl IntoPatterns for Patterns { + fn patterns(&self) -> Patterns { + self.clone() + } +} + +impl> IntoPatterns for Vec { + fn patterns(&self) -> Patterns { + let mut patterns = self.iter().map(|v| v.as_ref().to_owned()); + + match patterns.size_hint() { + (1, _) => Patterns::Single(patterns.next().unwrap()), + _ => Patterns::List(patterns.collect()), + } + } +} + +macro_rules! array_patterns_single (($tp:ty) => { + impl IntoPatterns for [$tp; 1] { + fn patterns(&self) -> Patterns { + Patterns::Single(self[0].to_owned()) + } + } +}); + +macro_rules! array_patterns_multiple (($tp:ty, $str_fn:expr, $($num:tt) +) => { + // for each array length specified in space-separated $num + $( + impl IntoPatterns for [$tp; $num] { + fn patterns(&self) -> Patterns { + Patterns::List(self.iter().map($str_fn).collect()) + } + } + )+ +}); + +array_patterns_single!(&str); +array_patterns_multiple!(&str, |&v| v.to_owned(), 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16); + +array_patterns_single!(String); +array_patterns_multiple!(String, |v| v.clone(), 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16); diff --git a/actix-router/src/quoter.rs b/actix-router/src/quoter.rs new file mode 100644 index 000000000..26ecc92cd --- /dev/null +++ b/actix-router/src/quoter.rs @@ -0,0 +1,219 @@ +#[allow(dead_code)] +const GEN_DELIMS: &[u8] = b":/?#[]@"; + +#[allow(dead_code)] +const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,"; + +#[allow(dead_code)] +const SUB_DELIMS: &[u8] = b"!$'()*,+?=;"; + +#[allow(dead_code)] +const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;"; + +#[allow(dead_code)] +const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + 1234567890 + -._~"; + +const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz + ABCDEFGHIJKLMNOPQRSTUVWXYZ + 1234567890 + -._~ + !$'()*,"; + +const QS: &[u8] = b"+&=;b"; + +/// A quoter +pub struct Quoter { + /// Simple bit-map of safe values in the 0-127 ASCII range. + safe_table: [u8; 16], + + /// Simple bit-map of protected values in the 0-127 ASCII range. + protected_table: [u8; 16], +} + +impl Quoter { + pub fn new(safe: &[u8], protected: &[u8]) -> Quoter { + let mut quoter = Quoter { + safe_table: [0; 16], + protected_table: [0; 16], + }; + + // prepare safe table + for ch in 0..128 { + if ALLOWED.contains(&ch) { + set_bit(&mut quoter.safe_table, ch); + } + + if QS.contains(&ch) { + set_bit(&mut quoter.safe_table, ch); + } + } + + for &ch in safe { + set_bit(&mut quoter.safe_table, ch) + } + + // prepare protected table + for &ch in protected { + set_bit(&mut quoter.safe_table, ch); + set_bit(&mut quoter.protected_table, ch); + } + + quoter + } + + /// Re-quotes... ? + /// + /// Returns `None` when no modification to the original string was required. + pub fn requote(&self, val: &[u8]) -> Option { + let mut has_pct = 0; + let mut pct = [b'%', 0, 0]; + let mut idx = 0; + let mut cloned: Option> = None; + + let len = val.len(); + + while idx < len { + let ch = val[idx]; + + if has_pct != 0 { + pct[has_pct] = val[idx]; + has_pct += 1; + + if has_pct == 3 { + has_pct = 0; + let buf = cloned.as_mut().unwrap(); + + if let Some(ch) = hex_pair_to_char(pct[1], pct[2]) { + if ch < 128 { + if bit_at(&self.protected_table, ch) { + buf.extend_from_slice(&pct); + idx += 1; + continue; + } + + if bit_at(&self.safe_table, ch) { + buf.push(ch); + idx += 1; + continue; + } + } + + buf.push(ch); + } else { + buf.extend_from_slice(&pct[..]); + } + } + } else if ch == b'%' { + has_pct = 1; + + if cloned.is_none() { + let mut c = Vec::with_capacity(len); + c.extend_from_slice(&val[..idx]); + cloned = Some(c); + } + } else if let Some(ref mut cloned) = cloned { + cloned.push(ch) + } + + idx += 1; + } + + cloned.map(|data| String::from_utf8_lossy(&data).into_owned()) + } +} + +/// Converts an ASCII character in the hex-encoded set (`0-9`, `A-F`, `a-f`) to its integer +/// representation from `0x0`–`0xF`. +/// +/// - `0x30 ('0') => 0x0` +/// - `0x39 ('9') => 0x9` +/// - `0x41 ('a') => 0xA` +/// - `0x61 ('A') => 0xA` +/// - `0x46 ('f') => 0xF` +/// - `0x66 ('F') => 0xF` +fn from_ascii_hex(v: u8) -> Option { + match v { + b'0'..=b'9' => Some(v - 0x30), // ord('0') == 0x30 + b'A'..=b'F' => Some(v - 0x41 + 10), // ord('A') == 0x41 + b'a'..=b'f' => Some(v - 0x61 + 10), // ord('a') == 0x61 + _ => None, + } +} + +/// Decode a ASCII hex-encoded pair to an integer. +/// +/// Returns `None` if either portion of the decoded pair does not evaluate to a valid hex value. +/// +/// - `0x33 ('3'), 0x30 ('0') => 0x30 ('0')` +/// - `0x34 ('4'), 0x31 ('1') => 0x41 ('A')` +/// - `0x36 ('6'), 0x31 ('1') => 0x61 ('a')` +fn hex_pair_to_char(d1: u8, d2: u8) -> Option { + let (d_high, d_low) = (from_ascii_hex(d1)?, from_ascii_hex(d2)?); + + // left shift high nibble by 4 bits + Some(d_high << 4 | d_low) +} + +/// Sets bit in given bit-map to 1=true. +/// +/// # Panics +/// Panics if `ch` index is out of bounds. +fn set_bit(array: &mut [u8], ch: u8) { + array[(ch >> 3) as usize] |= 0b1 << (ch & 0b111) +} + +/// Returns true if bit to true in given bit-map. +/// +/// # Panics +/// Panics if `ch` index is out of bounds. +fn bit_at(array: &[u8], ch: u8) -> bool { + array[(ch >> 3) as usize] & (0b1 << (ch & 0b111)) != 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hex_encoding() { + let hex = b"0123456789abcdefABCDEF"; + + for i in 0..256 { + let c = i as u8; + if hex.contains(&c) { + assert!(from_ascii_hex(c).is_some()) + } else { + assert!(from_ascii_hex(c).is_none()) + } + } + + let expected = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, + ]; + for i in 0..hex.len() { + assert_eq!(from_ascii_hex(hex[i]).unwrap(), expected[i]); + } + } + + #[test] + fn custom_quoter() { + let q = Quoter::new(b"", b"+"); + assert_eq!(q.requote(b"/a%25c").unwrap(), "/a%c"); + assert_eq!(q.requote(b"/a%2Bc").unwrap(), "/a%2Bc"); + + let q = Quoter::new(b"%+", b"/"); + assert_eq!(q.requote(b"/a%25b%2Bc").unwrap(), "/a%b+c"); + assert_eq!(q.requote(b"/a%2fb").unwrap(), "/a%2fb"); + assert_eq!(q.requote(b"/a%2Fb").unwrap(), "/a%2Fb"); + assert_eq!(q.requote(b"/a%0Ab").unwrap(), "/a\nb"); + } + + #[test] + fn quoter_no_modification() { + let q = Quoter::new(b"", b""); + assert_eq!(q.requote(b"/abc/../efg"), None); + } +} diff --git a/actix-router/src/resource.rs b/actix-router/src/resource.rs new file mode 100644 index 000000000..d39a6b923 --- /dev/null +++ b/actix-router/src/resource.rs @@ -0,0 +1,1815 @@ +use std::{ + borrow::{Borrow, Cow}, + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + mem, +}; + +use firestorm::{profile_fn, profile_method, profile_section}; +use regex::{escape, Regex, RegexSet}; + +use crate::{ + path::{Path, PathItem}, + IntoPatterns, Patterns, Resource, ResourcePath, +}; + +const MAX_DYNAMIC_SEGMENTS: usize = 16; + +/// Regex flags to allow '.' in regex to match '\n' +/// +/// See the docs under: https://docs.rs/regex/1/regex/#grouping-and-flags +const REGEX_FLAGS: &str = "(?s-m)"; + +/// Describes the set of paths that match to a resource. +/// +/// `ResourceDef`s are effectively a way to transform the a custom resource pattern syntax into +/// suitable regular expressions from which to check matches with paths and capture portions of a +/// matched path into variables. Common cases are on a fast path that avoids going through the +/// regex engine. +/// +/// +/// # Pattern Format and Matching Behavior +/// Resource pattern is defined as a string of zero or more _segments_ where each segment is +/// preceded by a slash `/`. +/// +/// This means that pattern string __must__ either be empty or begin with a slash (`/`). This also +/// implies that a trailing slash in pattern defines an empty segment. For example, the pattern +/// `"/user/"` has two segments: `["user", ""]` +/// +/// A key point to understand is that `ResourceDef` matches segments, not strings. Segments are +/// matched individually. For example, the pattern `/user/` is not considered a prefix for the path +/// `/user/123/456`, because the second segment doesn't match: `["user", ""]` +/// vs `["user", "123", "456"]`. +/// +/// This definition is consistent with the definition of absolute URL path in +/// [RFC 3986 §3.3](https://datatracker.ietf.org/doc/html/rfc3986#section-3.3) +/// +/// +/// # Static Resources +/// A static resource is the most basic type of definition. Pass a pattern to [new][Self::new]. +/// Conforming paths must match the pattern exactly. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new("/home"); +/// +/// assert!(resource.is_match("/home")); +/// +/// assert!(!resource.is_match("/home/")); +/// assert!(!resource.is_match("/home/new")); +/// assert!(!resource.is_match("/homes")); +/// assert!(!resource.is_match("/search")); +/// ``` +/// +/// # Dynamic Segments +/// Also known as "path parameters". Resources can define sections of a pattern that be extracted +/// from a conforming path, if it conforms to (one of) the resource pattern(s). +/// +/// The marker for a dynamic segment is curly braces wrapping an identifier. For example, +/// `/user/{id}` would match paths like `/user/123` or `/user/james` and be able to extract the user +/// IDs "123" and "james", respectively. +/// +/// However, this resource pattern (`/user/{id}`) would, not cover `/user/123/stars` (unless +/// constructed as a prefix; see next section) since the default pattern for segments matches all +/// characters until it finds a `/` character (or the end of the path). Custom segment patterns are +/// covered further down. +/// +/// Dynamic segments do not need to be delimited by `/` characters, they can be defined within a +/// path segment. For example, `/rust-is-{opinion}` can match the paths `/rust-is-cool` and +/// `/rust-is-hard`. +/// +/// For information on capturing segment values from paths or other custom resource types, +/// see [`capture_match_info`][Self::capture_match_info] +/// and [`capture_match_info_fn`][Self::capture_match_info_fn]. +/// +/// A resource definition can contain at most 16 dynamic segments. +/// +/// ## Examples +/// ``` +/// use actix_router::{Path, ResourceDef}; +/// +/// let resource = ResourceDef::prefix("/user/{id}"); +/// +/// assert!(resource.is_match("/user/123")); +/// assert!(!resource.is_match("/user")); +/// assert!(!resource.is_match("/user/")); +/// +/// let mut path = Path::new("/user/123"); +/// resource.capture_match_info(&mut path); +/// assert_eq!(path.get("id").unwrap(), "123"); +/// ``` +/// +/// # Prefix Resources +/// A prefix resource is defined as pattern that can match just the start of a path, up to a +/// segment boundary. +/// +/// Prefix patterns with a trailing slash may have an unexpected, though correct, behavior. +/// They define and therefore require an empty segment in order to match. It is easier to understand +/// this behavior after reading the [matching behavior section]. Examples are given below. +/// +/// The empty pattern (`""`), as a prefix, matches any path. +/// +/// Prefix resources can contain dynamic segments. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::prefix("/home"); +/// assert!(resource.is_match("/home")); +/// assert!(resource.is_match("/home/new")); +/// assert!(!resource.is_match("/homes")); +/// +/// // prefix pattern with a trailing slash +/// let resource = ResourceDef::prefix("/user/{id}/"); +/// assert!(resource.is_match("/user/123/")); +/// assert!(resource.is_match("/user/123//stars")); +/// assert!(!resource.is_match("/user/123/stars")); +/// assert!(!resource.is_match("/user/123")); +/// ``` +/// +/// # Custom Regex Segments +/// Dynamic segments can be customised to only match a specific regular expression. It can be +/// helpful to do this if resource definitions would otherwise conflict and cause one to +/// be inaccessible. +/// +/// The regex used when capturing segment values can be specified explicitly using this syntax: +/// `{name:regex}`. For example, `/user/{id:\d+}` will only match paths where the user ID +/// is numeric. +/// +/// The regex could potentially match multiple segments. If this is not wanted, then care must be +/// taken to avoid matching a slash `/`. It is guaranteed, however, that the match ends at a +/// segment boundary; the pattern `r"(/|$)` is always appended to the regex. +/// +/// By default, dynamic segments use this regex: `[^/]+`. This shows why it is the case, as shown in +/// the earlier section, that segments capture a slice of the path up to the next `/` character. +/// +/// Custom regex segments can be used in static and prefix resource definition variants. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new(r"/user/{id:\d+}"); +/// assert!(resource.is_match("/user/123")); +/// assert!(resource.is_match("/user/314159")); +/// assert!(!resource.is_match("/user/abc")); +/// ``` +/// +/// # Tail Segments +/// As a shortcut to defining a custom regex for matching _all_ remaining characters (not just those +/// up until a `/` character), there is a special pattern to match (and capture) the remaining +/// path portion. +/// +/// To do this, use the segment pattern: `{name}*`. Since a tail segment also has a name, values are +/// extracted in the same way as non-tail dynamic segments. +/// +/// ## Examples +/// ``` +/// # use actix_router::{Path, ResourceDef}; +/// let resource = ResourceDef::new("/blob/{tail}*"); +/// assert!(resource.is_match("/blob/HEAD/Cargo.toml")); +/// assert!(resource.is_match("/blob/HEAD/README.md")); +/// +/// let mut path = Path::new("/blob/main/LICENSE"); +/// resource.capture_match_info(&mut path); +/// assert_eq!(path.get("tail").unwrap(), "main/LICENSE"); +/// ``` +/// +/// # Multi-Pattern Resources +/// For resources that can map to multiple distinct paths, it may be suitable to use +/// multi-pattern resources by passing an array/vec to [`new`][Self::new]. They will be combined +/// into a regex set which is usually quicker to check matches on than checking each +/// pattern individually. +/// +/// Multi-pattern resources can contain dynamic segments just like single pattern ones. +/// However, take care to use consistent and semantically-equivalent segment names; it could affect +/// expectations in the router using these definitions and cause runtime panics. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// let resource = ResourceDef::new(["/home", "/index"]); +/// assert!(resource.is_match("/home")); +/// assert!(resource.is_match("/index")); +/// ``` +/// +/// # Trailing Slashes +/// It should be noted that this library takes no steps to normalize intra-path or trailing slashes. +/// As such, all resource definitions implicitly expect a pre-processing step to normalize paths if +/// they you wish to accommodate "recoverable" path errors. Below are several examples of +/// resource-path pairs that would not be compatible. +/// +/// ## Examples +/// ``` +/// # use actix_router::ResourceDef; +/// assert!(!ResourceDef::new("/root").is_match("/root/")); +/// assert!(!ResourceDef::new("/root/").is_match("/root")); +/// assert!(!ResourceDef::prefix("/root/").is_match("/root")); +/// ``` +/// +/// [matching behavior section]: #pattern-format-and-matching-behavior +#[derive(Clone, Debug)] +pub struct ResourceDef { + id: u16, + + /// Optional name of resource. + name: Option, + + /// Pattern that generated the resource definition. + patterns: Patterns, + + is_prefix: bool, + + /// Pattern type. + pat_type: PatternType, + + /// List of segments that compose the pattern, in order. + segments: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +enum PatternSegment { + /// Literal slice of pattern. + Const(String), + + /// Name of dynamic segment. + Var(String), +} + +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +enum PatternType { + /// Single constant/literal segment. + Static(String), + + /// Single regular expression and list of dynamic segment names. + Dynamic(Regex, Vec<&'static str>), + + /// Regular expression set and list of component expressions plus dynamic segment names. + DynamicSet(RegexSet, Vec<(Regex, Vec<&'static str>)>), +} + +impl ResourceDef { + /// Constructs a new resource definition from patterns. + /// + /// Multi-pattern resources can be constructed by providing a slice (or vec) of patterns. + /// + /// # Panics + /// Panics if path pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::new("/user/{id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("/user/123/stars")); + /// assert!(!resource.is_match("user/1234")); + /// assert!(!resource.is_match("/foo")); + /// + /// let resource = ResourceDef::new(["/profile", "/user/{id}"]); + /// assert!(resource.is_match("/profile")); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("user/123")); + /// assert!(!resource.is_match("/foo")); + /// ``` + pub fn new(paths: T) -> Self { + profile_method!(new); + Self::construct(paths, false) + } + + /// Constructs a new resource definition using a pattern that performs prefix matching. + /// + /// More specifically, the regular expressions generated for matching are different when using + /// this method vs using `new`; they will not be appended with the `$` meta-character that + /// matches the end of an input. + /// + /// Although it will compile and run correctly, it is meaningless to construct a prefix + /// resource definition with a tail segment; use [`new`][Self::new] in this case. + /// + /// # Panics + /// Panics if path pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(resource.is_match("/user/123/stars")); + /// assert!(!resource.is_match("user/123")); + /// assert!(!resource.is_match("user/123/stars")); + /// assert!(!resource.is_match("/foo")); + /// ``` + pub fn prefix(paths: T) -> Self { + profile_method!(prefix); + ResourceDef::construct(paths, true) + } + + /// Constructs a new resource definition using a string pattern that performs prefix matching, + /// ensuring a leading `/` if pattern is not empty. + /// + /// # Panics + /// Panics if path pattern is malformed. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// let resource = ResourceDef::root_prefix("user/{id}"); + /// + /// assert_eq!(&resource, &ResourceDef::prefix("/user/{id}")); + /// assert_eq!(&resource, &ResourceDef::root_prefix("/user/{id}")); + /// assert_ne!(&resource, &ResourceDef::new("user/{id}")); + /// assert_ne!(&resource, &ResourceDef::new("/user/{id}")); + /// + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("user/123")); + /// ``` + pub fn root_prefix(path: &str) -> Self { + profile_method!(root_prefix); + ResourceDef::prefix(insert_slash(path).into_owned()) + } + + /// Returns a numeric resource ID. + /// + /// If not explicitly set using [`set_id`][Self::set_id], this will return `0`. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// assert_eq!(resource.id(), 0); + /// + /// resource.set_id(42); + /// assert_eq!(resource.id(), 42); + /// ``` + pub fn id(&self) -> u16 { + self.id + } + + /// Set numeric resource ID. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// resource.set_id(42); + /// assert_eq!(resource.id(), 42); + /// ``` + pub fn set_id(&mut self, id: u16) { + self.id = id; + } + + /// Returns resource definition name, if set. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// assert!(resource.name().is_none()); + /// + /// resource.set_name("root"); + /// assert_eq!(resource.name().unwrap(), "root"); + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + /// Assigns a new name to the resource. + /// + /// # Panics + /// Panics if `name` is an empty string. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// resource.set_name("root"); + /// assert_eq!(resource.name().unwrap(), "root"); + /// ``` + pub fn set_name(&mut self, name: impl Into) { + let name = name.into(); + + assert!(!name.is_empty(), "resource name should not be empty"); + + self.name = Some(name) + } + + /// Returns `true` if pattern type is prefix. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// assert!(ResourceDef::prefix("/user").is_prefix()); + /// assert!(!ResourceDef::new("/user").is_prefix()); + /// ``` + pub fn is_prefix(&self) -> bool { + self.is_prefix + } + + /// Returns the pattern string that generated the resource definition. + /// + /// If definition is constructed with multiple patterns, the first pattern is returned. To get + /// all patterns, use [`patterns_iter`][Self::pattern_iter]. If resource has 0 patterns, + /// returns `None`. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/user/{id}"); + /// assert_eq!(resource.pattern().unwrap(), "/user/{id}"); + /// + /// let mut resource = ResourceDef::new(["/profile", "/user/{id}"]); + /// assert_eq!(resource.pattern(), Some("/profile")); + pub fn pattern(&self) -> Option<&str> { + match &self.patterns { + Patterns::Single(pattern) => Some(pattern.as_str()), + Patterns::List(patterns) => patterns.first().map(AsRef::as_ref), + } + } + + /// Returns iterator of pattern strings that generated the resource definition. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut resource = ResourceDef::new("/root"); + /// let mut iter = resource.pattern_iter(); + /// assert_eq!(iter.next().unwrap(), "/root"); + /// assert!(iter.next().is_none()); + /// + /// let mut resource = ResourceDef::new(["/root", "/backup"]); + /// let mut iter = resource.pattern_iter(); + /// assert_eq!(iter.next().unwrap(), "/root"); + /// assert_eq!(iter.next().unwrap(), "/backup"); + /// assert!(iter.next().is_none()); + pub fn pattern_iter(&self) -> impl Iterator { + struct PatternIter<'a> { + patterns: &'a Patterns, + list_idx: usize, + done: bool, + } + + impl<'a> Iterator for PatternIter<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + match &self.patterns { + Patterns::Single(pattern) => { + if self.done { + return None; + } + + self.done = true; + Some(pattern.as_str()) + } + Patterns::List(patterns) if patterns.is_empty() => None, + Patterns::List(patterns) => match patterns.get(self.list_idx) { + Some(pattern) => { + self.list_idx += 1; + Some(pattern.as_str()) + } + None => { + // fast path future call + self.done = true; + None + } + }, + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.patterns { + Patterns::Single(_) => (1, Some(1)), + Patterns::List(patterns) => (patterns.len(), Some(patterns.len())), + } + } + } + + PatternIter { + patterns: &self.patterns, + list_idx: 0, + done: false, + } + } + + /// Joins two resources. + /// + /// Resulting resource is prefix if `other` is prefix. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let joined = ResourceDef::prefix("/root").join(&ResourceDef::prefix("/seg")); + /// assert_eq!(joined, ResourceDef::prefix("/root/seg")); + /// ``` + pub fn join(&self, other: &ResourceDef) -> ResourceDef { + let patterns = self + .pattern_iter() + .flat_map(move |this| other.pattern_iter().map(move |other| (this, other))) + .map(|(this, other)| [this, other].join("")) + .collect::>(); + + match patterns.len() { + 1 => ResourceDef::construct(&patterns[0], other.is_prefix()), + _ => ResourceDef::construct(patterns, other.is_prefix()), + } + } + + /// Returns `true` if `path` matches this resource. + /// + /// The behavior of this method depends on how the `ResourceDef` was constructed. For example, + /// static resources will not be able to match as many paths as dynamic and prefix resources. + /// See [`ResourceDef`] struct docs for details on resource definition types. + /// + /// This method will always agree with [`find_match`][Self::find_match] on whether the path + /// matches or not. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// // static resource + /// let resource = ResourceDef::new("/user"); + /// assert!(resource.is_match("/user")); + /// assert!(!resource.is_match("/users")); + /// assert!(!resource.is_match("/user/123")); + /// assert!(!resource.is_match("/foo")); + /// + /// // dynamic resource + /// let resource = ResourceDef::new("/user/{user_id}"); + /// assert!(resource.is_match("/user/123")); + /// assert!(!resource.is_match("/user/123/stars")); + /// + /// // prefix resource + /// let resource = ResourceDef::prefix("/root"); + /// assert!(resource.is_match("/root")); + /// assert!(resource.is_match("/root/leaf")); + /// assert!(!resource.is_match("/roots")); + /// + /// // more examples are shown in the `ResourceDef` struct docs + /// ``` + #[inline] + pub fn is_match(&self, path: &str) -> bool { + profile_method!(is_match); + + // this function could be expressed as: + // `self.find_match(path).is_some()` + // but this skips some checks and uses potentially faster regex methods + + match &self.pat_type { + PatternType::Static(pattern) => self.static_match(pattern, path).is_some(), + PatternType::Dynamic(re, _) => re.is_match(path), + PatternType::DynamicSet(re, _) => re.is_match(path), + } + } + + /// Tries to match `path` to this resource, returning the position in the path where the + /// match ends. + /// + /// This method will always agree with [`is_match`][Self::is_match] on whether the path matches + /// or not. + /// + /// # Examples + /// ``` + /// use actix_router::ResourceDef; + /// + /// // static resource + /// let resource = ResourceDef::new("/user"); + /// assert_eq!(resource.find_match("/user"), Some(5)); + /// assert!(resource.find_match("/user/").is_none()); + /// assert!(resource.find_match("/user/123").is_none()); + /// assert!(resource.find_match("/foo").is_none()); + /// + /// // constant prefix resource + /// let resource = ResourceDef::prefix("/user"); + /// assert_eq!(resource.find_match("/user"), Some(5)); + /// assert_eq!(resource.find_match("/user/"), Some(5)); + /// assert_eq!(resource.find_match("/user/123"), Some(5)); + /// + /// // dynamic prefix resource + /// let resource = ResourceDef::prefix("/user/{id}"); + /// assert_eq!(resource.find_match("/user/123"), Some(9)); + /// assert_eq!(resource.find_match("/user/1234/"), Some(10)); + /// assert_eq!(resource.find_match("/user/12345/stars"), Some(11)); + /// assert!(resource.find_match("/user/").is_none()); + /// + /// // multi-pattern resource + /// let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + /// assert_eq!(resource.find_match("/user/123"), Some(9)); + /// assert_eq!(resource.find_match("/profile/1234"), Some(13)); + /// ``` + pub fn find_match(&self, path: &str) -> Option { + profile_method!(find_match); + + match &self.pat_type { + PatternType::Static(pattern) => self.static_match(pattern, path), + + PatternType::Dynamic(re, _) => Some(re.captures(path)?[1].len()), + + PatternType::DynamicSet(re, params) => { + let idx = re.matches(path).into_iter().next()?; + let (ref pattern, _) = params[idx]; + Some(pattern.captures(path)?[1].len()) + } + } + } + + /// Collects dynamic segment values into `path`. + /// + /// Returns `true` if `path` matches this resource. + /// + /// # Examples + /// ``` + /// use actix_router::{Path, ResourceDef}; + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// let mut path = Path::new("/user/123/stars"); + /// assert!(resource.capture_match_info(&mut path)); + /// assert_eq!(path.get("id").unwrap(), "123"); + /// assert_eq!(path.unprocessed(), "/stars"); + /// + /// let resource = ResourceDef::new("/blob/{path}*"); + /// let mut path = Path::new("/blob/HEAD/Cargo.toml"); + /// assert!(resource.capture_match_info(&mut path)); + /// assert_eq!(path.get("path").unwrap(), "HEAD/Cargo.toml"); + /// assert_eq!(path.unprocessed(), ""); + /// ``` + pub fn capture_match_info(&self, path: &mut Path) -> bool { + profile_method!(capture_match_info); + self.capture_match_info_fn(path, |_, _| true, ()) + } + + /// Collects dynamic segment values into `resource` after matching paths and executing + /// check function. + /// + /// The check function is given a reference to the passed resource and optional arbitrary data. + /// This is useful if you want to conditionally match on some non-path related aspect of the + /// resource type. + /// + /// Returns `true` if resource path matches this resource definition _and_ satisfies the + /// given check function. + /// + /// # Examples + /// ``` + /// use actix_router::{Path, ResourceDef}; + /// + /// fn try_match(resource: &ResourceDef, path: &mut Path<&str>) -> bool { + /// let admin_allowed = std::env::var("ADMIN_ALLOWED").ok(); + /// + /// resource.capture_match_info_fn( + /// path, + /// // when env var is not set, reject when path contains "admin" + /// |res, admin_allowed| !res.path().contains("admin"), + /// &admin_allowed + /// ) + /// } + /// + /// let resource = ResourceDef::prefix("/user/{id}"); + /// + /// // path matches; segment values are collected into path + /// let mut path = Path::new("/user/james/stars"); + /// assert!(try_match(&resource, &mut path)); + /// assert_eq!(path.get("id").unwrap(), "james"); + /// assert_eq!(path.unprocessed(), "/stars"); + /// + /// // path matches but fails check function; no segments are collected + /// let mut path = Path::new("/user/admin/stars"); + /// assert!(!try_match(&resource, &mut path)); + /// assert_eq!(path.unprocessed(), "/user/admin/stars"); + /// ``` + pub fn capture_match_info_fn( + &self, + resource: &mut R, + check_fn: F, + user_data: U, + ) -> bool + where + R: Resource, + F: FnOnce(&R, U) -> bool, + { + profile_method!(capture_match_info_fn); + + let mut segments = <[PathItem; MAX_DYNAMIC_SEGMENTS]>::default(); + let path = resource.resource_path(); + let path_str = path.path(); + + let (matched_len, matched_vars) = match &self.pat_type { + PatternType::Static(pattern) => { + profile_section!(pattern_static_or_prefix); + + match self.static_match(pattern, path_str) { + Some(len) => (len, None), + None => return false, + } + } + + PatternType::Dynamic(re, names) => { + profile_section!(pattern_dynamic); + + let captures = { + profile_section!(pattern_dynamic_regex_exec); + + match re.captures(path.path()) { + Some(captures) => captures, + _ => return false, + } + }; + + { + profile_section!(pattern_dynamic_extract_captures); + + for (no, name) in names.iter().enumerate() { + if let Some(m) = captures.name(name) { + segments[no] = PathItem::Segment(m.start() as u16, m.end() as u16); + } else { + log::error!( + "Dynamic path match but not all segments found: {}", + name + ); + return false; + } + } + }; + + (captures[1].len(), Some(names)) + } + + PatternType::DynamicSet(re, params) => { + profile_section!(pattern_dynamic_set); + + let path = path.path(); + let (pattern, names) = match re.matches(path).into_iter().next() { + Some(idx) => ¶ms[idx], + _ => return false, + }; + + let captures = match pattern.captures(path.path()) { + Some(captures) => captures, + _ => return false, + }; + + for (no, name) in names.iter().enumerate() { + if let Some(m) = captures.name(name) { + segments[no] = PathItem::Segment(m.start() as u16, m.end() as u16); + } else { + log::error!("Dynamic path match but not all segments found: {}", name); + return false; + } + } + + (captures[1].len(), Some(names)) + } + }; + + if !check_fn(resource, user_data) { + return false; + } + + // Modify `path` to skip matched part and store matched segments + let path = resource.resource_path(); + + if let Some(vars) = matched_vars { + for i in 0..vars.len() { + path.add(vars[i], mem::take(&mut segments[i])); + } + } + + path.skip(matched_len as u16); + + true + } + + /// Assembles resource path using a closure that maps variable segment names to values. + fn build_resource_path(&self, path: &mut String, mut vars: F) -> bool + where + F: FnMut(&str) -> Option, + I: AsRef, + { + for segment in &self.segments { + match segment { + PatternSegment::Const(val) => path.push_str(val), + PatternSegment::Var(name) => match vars(name) { + Some(val) => path.push_str(val.as_ref()), + _ => return false, + }, + } + } + + true + } + + /// Assembles full resource path from iterator of dynamic segment values. + /// + /// Returns `true` on success. + /// + /// For multi-pattern resources, the first pattern is used under the assumption that it would be + /// equivalent to any other choice. + /// + /// # Examples + /// ``` + /// # use actix_router::ResourceDef; + /// let mut s = String::new(); + /// let resource = ResourceDef::new("/user/{id}/post/{title}"); + /// + /// assert!(resource.resource_path_from_iter(&mut s, &["123", "my-post"])); + /// assert_eq!(s, "/user/123/post/my-post"); + /// ``` + pub fn resource_path_from_iter(&self, path: &mut String, values: I) -> bool + where + I: IntoIterator, + I::Item: AsRef, + { + profile_method!(resource_path_from_iter); + let mut iter = values.into_iter(); + self.build_resource_path(path, |_| iter.next()) + } + + /// Assembles resource path from map of dynamic segment values. + /// + /// Returns `true` on success. + /// + /// For multi-pattern resources, the first pattern is used under the assumption that it would be + /// equivalent to any other choice. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use actix_router::ResourceDef; + /// let mut s = String::new(); + /// let resource = ResourceDef::new("/user/{id}/post/{title}"); + /// + /// let mut map = HashMap::new(); + /// map.insert("id", "123"); + /// map.insert("title", "my-post"); + /// + /// assert!(resource.resource_path_from_map(&mut s, &map)); + /// assert_eq!(s, "/user/123/post/my-post"); + /// ``` + pub fn resource_path_from_map( + &self, + path: &mut String, + values: &HashMap, + ) -> bool + where + K: Borrow + Eq + Hash, + V: AsRef, + S: BuildHasher, + { + profile_method!(resource_path_from_map); + self.build_resource_path(path, |name| values.get(name).map(AsRef::::as_ref)) + } + + /// Returns true if `prefix` acts as a proper prefix (i.e., separated by a slash) in `path`. + fn static_match(&self, pattern: &str, path: &str) -> Option { + let rem = path.strip_prefix(pattern)?; + + match self.is_prefix { + // resource is not a prefix so an exact match is needed + false if rem.is_empty() => Some(pattern.len()), + + // resource is a prefix so rem should start with a path delimiter + true if rem.is_empty() || rem.starts_with('/') => Some(pattern.len()), + + // otherwise, no match + _ => None, + } + } + + fn construct(paths: T, is_prefix: bool) -> Self { + profile_method!(construct); + + let patterns = paths.patterns(); + let (pat_type, segments) = match &patterns { + Patterns::Single(pattern) => ResourceDef::parse(pattern, is_prefix, false), + + // since zero length pattern sets are possible + // just return a useless `ResourceDef` + Patterns::List(patterns) if patterns.is_empty() => ( + PatternType::DynamicSet(RegexSet::empty(), Vec::new()), + Vec::new(), + ), + + Patterns::List(patterns) => { + let mut re_set = Vec::with_capacity(patterns.len()); + let mut pattern_data = Vec::new(); + let mut segments = None; + + for pattern in patterns { + match ResourceDef::parse(pattern, is_prefix, true) { + (PatternType::Dynamic(re, names), segs) => { + re_set.push(re.as_str().to_owned()); + pattern_data.push((re, names)); + segments.get_or_insert(segs); + } + _ => unreachable!(), + } + } + + let pattern_re_set = RegexSet::new(re_set).unwrap(); + let segments = segments.unwrap_or_else(Vec::new); + + ( + PatternType::DynamicSet(pattern_re_set, pattern_data), + segments, + ) + } + }; + + ResourceDef { + id: 0, + name: None, + patterns, + is_prefix, + pat_type, + segments, + } + } + + /// Parses a dynamic segment definition from a pattern. + /// + /// The returned tuple includes: + /// - the segment descriptor, either `Var` or `Tail` + /// - the segment's regex to check values against + /// - the remaining, unprocessed string slice + /// - whether the parsed parameter represents a tail pattern + /// + /// # Panics + /// Panics if given patterns does not contain a dynamic segment. + fn parse_param(pattern: &str) -> (PatternSegment, String, &str, bool) { + profile_method!(parse_param); + + const DEFAULT_PATTERN: &str = "[^/]+"; + const DEFAULT_PATTERN_TAIL: &str = ".*"; + + let mut params_nesting = 0usize; + let close_idx = pattern + .find(|c| match c { + '{' => { + params_nesting += 1; + false + } + '}' => { + params_nesting -= 1; + params_nesting == 0 + } + _ => false, + }) + .unwrap_or_else(|| { + panic!( + r#"pattern "{}" contains malformed dynamic segment"#, + pattern + ) + }); + + let (mut param, mut unprocessed) = pattern.split_at(close_idx + 1); + + // remove outer curly brackets + param = ¶m[1..param.len() - 1]; + + let tail = unprocessed == "*"; + + let (name, pattern) = match param.find(':') { + Some(idx) => { + assert!(!tail, "custom regex is not supported for tail match"); + + let (name, pattern) = param.split_at(idx); + (name, &pattern[1..]) + } + None => ( + param, + if tail { + unprocessed = &unprocessed[1..]; + DEFAULT_PATTERN_TAIL + } else { + DEFAULT_PATTERN + }, + ), + }; + + let segment = PatternSegment::Var(name.to_string()); + let regex = format!(r"(?P<{}>{})", &name, &pattern); + + (segment, regex, unprocessed, tail) + } + + /// Parse `pattern` using `is_prefix` and `force_dynamic` flags. + /// + /// Parameters: + /// - `is_prefix`: Use `true` if `pattern` should be treated as a prefix; i.e., a conforming + /// path will be a match even if it has parts remaining to process + /// - `force_dynamic`: Use `true` to disallow the return of static and prefix segments. + /// + /// The returned tuple includes: + /// - the pattern type detected, either `Static`, `Prefix`, or `Dynamic` + /// - a list of segment descriptors from the pattern + fn parse( + pattern: &str, + is_prefix: bool, + force_dynamic: bool, + ) -> (PatternType, Vec) { + profile_method!(parse); + + if !force_dynamic && pattern.find('{').is_none() && !pattern.ends_with('*') { + // pattern is static + return ( + PatternType::Static(pattern.to_owned()), + vec![PatternSegment::Const(pattern.to_owned())], + ); + } + + let mut unprocessed = pattern; + let mut segments = Vec::new(); + let mut re = format!("{}^", REGEX_FLAGS); + let mut dyn_segment_count = 0; + let mut has_tail_segment = false; + + while let Some(idx) = unprocessed.find('{') { + let (prefix, rem) = unprocessed.split_at(idx); + + segments.push(PatternSegment::Const(prefix.to_owned())); + re.push_str(&escape(prefix)); + + let (param_pattern, re_part, rem, tail) = Self::parse_param(rem); + + if tail { + has_tail_segment = true; + } + + segments.push(param_pattern); + re.push_str(&re_part); + + unprocessed = rem; + dyn_segment_count += 1; + } + + if is_prefix && has_tail_segment { + // tail segments in prefixes have no defined semantics + + #[cfg(not(test))] + log::warn!( + "Prefix resources should not have tail segments. \ + Use `ResourceDef::new` constructor. \ + This may become a panic in the future." + ); + + // panic in tests to make this case detectable + #[cfg(test)] + panic!("prefix resource definitions should not have tail segments"); + } + + if unprocessed.ends_with('*') { + // unnamed tail segment + + #[cfg(not(test))] + log::warn!( + "Tail segments must have names. \ + Consider `.../{{tail}}*`. \ + This may become a panic in the future." + ); + + // panic in tests to make this case detectable + #[cfg(test)] + panic!("tail segments must have names"); + } else if !has_tail_segment && !unprocessed.is_empty() { + // prevent `Const("")` element from being added after last dynamic segment + + segments.push(PatternSegment::Const(unprocessed.to_owned())); + re.push_str(&escape(unprocessed)); + } + + assert!( + dyn_segment_count <= MAX_DYNAMIC_SEGMENTS, + "Only {} dynamic segments are allowed, provided: {}", + MAX_DYNAMIC_SEGMENTS, + dyn_segment_count + ); + + // Store the pattern in capture group #1 to have context info outside it + let mut re = format!("({})", re); + + // Ensure the match ends at a segment boundary + if !has_tail_segment { + if is_prefix { + re.push_str(r"(/|$)"); + } else { + re.push('$'); + } + } + + let re = match Regex::new(&re) { + Ok(re) => re, + Err(err) => panic!("Wrong path pattern: \"{}\" {}", pattern, err), + }; + + // `Bok::leak(Box::new(name))` is an intentional memory leak. In typical applications the + // routing table is only constructed once (per worker) so leak is bounded. If you are + // constructing `ResourceDef`s more than once in your application's lifecycle you would + // expect a linear increase in leaked memory over time. + let names = re + .capture_names() + .filter_map(|name| name.map(|name| Box::leak(Box::new(name.to_owned())).as_str())) + .collect(); + + (PatternType::Dynamic(re, names), segments) + } +} + +impl Eq for ResourceDef {} + +impl PartialEq for ResourceDef { + fn eq(&self, other: &ResourceDef) -> bool { + self.patterns == other.patterns && self.is_prefix == other.is_prefix + } +} + +impl Hash for ResourceDef { + fn hash(&self, state: &mut H) { + self.patterns.hash(state); + } +} + +impl<'a> From<&'a str> for ResourceDef { + fn from(path: &'a str) -> ResourceDef { + ResourceDef::new(path) + } +} + +impl From for ResourceDef { + fn from(path: String) -> ResourceDef { + ResourceDef::new(path) + } +} + +pub(crate) fn insert_slash(path: &str) -> Cow<'_, str> { + profile_fn!(insert_slash); + + if !path.is_empty() && !path.starts_with('/') { + let mut new_path = String::with_capacity(path.len() + 1); + new_path.push('/'); + new_path.push_str(path); + Cow::Owned(new_path) + } else { + Cow::Borrowed(path) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn equivalence() { + assert_eq!( + ResourceDef::root_prefix("/root"), + ResourceDef::prefix("/root") + ); + assert_eq!( + ResourceDef::root_prefix("root"), + ResourceDef::prefix("/root") + ); + assert_eq!( + ResourceDef::root_prefix("/{id}"), + ResourceDef::prefix("/{id}") + ); + assert_eq!( + ResourceDef::root_prefix("{id}"), + ResourceDef::prefix("/{id}") + ); + + assert_eq!(ResourceDef::new("/"), ResourceDef::new(["/"])); + assert_eq!(ResourceDef::new("/"), ResourceDef::new(vec!["/"])); + + assert_ne!(ResourceDef::new(""), ResourceDef::prefix("")); + assert_ne!(ResourceDef::new("/"), ResourceDef::prefix("/")); + assert_ne!(ResourceDef::new("/{id}"), ResourceDef::prefix("/{id}")); + } + + #[test] + fn parse_static() { + let re = ResourceDef::new(""); + + assert!(!re.is_prefix()); + + assert!(re.is_match("")); + assert!(!re.is_match("/")); + assert_eq!(re.find_match(""), Some(0)); + assert_eq!(re.find_match("/"), None); + + let re = ResourceDef::new("/"); + assert!(re.is_match("/")); + assert!(!re.is_match("")); + assert!(!re.is_match("/foo")); + + let re = ResourceDef::new("/name"); + assert!(re.is_match("/name")); + assert!(!re.is_match("/name1")); + assert!(!re.is_match("/name/")); + assert!(!re.is_match("/name~")); + + let mut path = Path::new("/name"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match("/name1"), None); + assert_eq!(re.find_match("/name/"), None); + assert_eq!(re.find_match("/name~"), None); + + let re = ResourceDef::new("/name/"); + assert!(re.is_match("/name/")); + assert!(!re.is_match("/name")); + assert!(!re.is_match("/name/gs")); + + let re = ResourceDef::new("/user/profile"); + assert!(re.is_match("/user/profile")); + assert!(!re.is_match("/user/profile/profile")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + } + + #[test] + fn parse_param() { + let re = ResourceDef::new("/user/{id}"); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/user/1245125"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "1245125"); + assert_eq!(path.unprocessed(), ""); + + let re = ResourceDef::new("/v{version}/resource/{id}"); + assert!(re.is_match("/v1/resource/320120")); + assert!(!re.is_match("/v/resource/1")); + assert!(!re.is_match("/resource")); + + let mut path = Path::new("/v151/resource/adage32"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("version").unwrap(), "151"); + assert_eq!(path.get("id").unwrap(), "adage32"); + assert_eq!(path.unprocessed(), ""); + + let re = ResourceDef::new("/{id:[[:digit:]]{6}}"); + assert!(re.is_match("/012345")); + assert!(!re.is_match("/012")); + assert!(!re.is_match("/01234567")); + assert!(!re.is_match("/XXXXXX")); + + let mut path = Path::new("/012345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "012345"); + assert_eq!(path.unprocessed(), ""); + } + + #[allow(clippy::cognitive_complexity)] + #[test] + fn dynamic_set() { + let re = ResourceDef::new(vec![ + "/user/{id}", + "/v{version}/resource/{id}", + "/{id:[[:digit:]]{6}}", + "/static", + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let mut path = Path::new("/user/profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/user/1245125"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "1245125"); + assert_eq!(path.unprocessed(), ""); + + assert!(re.is_match("/v1/resource/320120")); + assert!(!re.is_match("/v/resource/1")); + assert!(!re.is_match("/resource")); + + let mut path = Path::new("/v151/resource/adage32"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("version").unwrap(), "151"); + assert_eq!(path.get("id").unwrap(), "adage32"); + + assert!(re.is_match("/012345")); + assert!(!re.is_match("/012")); + assert!(!re.is_match("/01234567")); + assert!(!re.is_match("/XXXXXX")); + + assert!(re.is_match("/static")); + assert!(!re.is_match("/a/static")); + assert!(!re.is_match("/static/a")); + + let mut path = Path::new("/012345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "012345"); + + let re = ResourceDef::new([ + "/user/{id}", + "/v{version}/resource/{id}", + "/{id:[[:digit:]]{6}}", + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + + let re = ResourceDef::new([ + "/user/{id}".to_string(), + "/v{version}/resource/{id}".to_string(), + "/{id:[[:digit:]]{6}}".to_string(), + ]); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(!re.is_match("/user/2345/")); + assert!(!re.is_match("/user/2345/sdg")); + } + + #[test] + fn dynamic_set_prefix() { + let re = ResourceDef::prefix(vec!["/u/{id}", "/{id:[[:digit:]]{3}}"]); + + assert_eq!(re.find_match("/u/abc"), Some(6)); + assert_eq!(re.find_match("/u/abc/123"), Some(6)); + assert_eq!(re.find_match("/s/user/profile"), None); + + assert_eq!(re.find_match("/123"), Some(4)); + assert_eq!(re.find_match("/123/456"), Some(4)); + assert_eq!(re.find_match("/12345"), None); + + let mut path = Path::new("/151/res"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "151"); + assert_eq!(path.unprocessed(), "/res"); + } + + #[test] + fn parse_tail() { + let re = ResourceDef::new("/user/-{id}*"); + + let mut path = Path::new("/user/-profile"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "profile"); + + let mut path = Path::new("/user/-2345"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + + let mut path = Path::new("/user/-2345/"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345/"); + + let mut path = Path::new("/user/-2345/sdg"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345/sdg"); + } + + #[test] + fn static_tail() { + let re = ResourceDef::new("/user{tail}*"); + assert!(re.is_match("/users")); + assert!(re.is_match("/user-foo")); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(re.is_match("/user/2345/")); + assert!(re.is_match("/user/2345/sdg")); + assert!(!re.is_match("/foo/profile")); + + let re = ResourceDef::new("/user/{tail}*"); + assert!(re.is_match("/user/profile")); + assert!(re.is_match("/user/2345")); + assert!(re.is_match("/user/2345/")); + assert!(re.is_match("/user/2345/sdg")); + assert!(!re.is_match("/foo/profile")); + } + + #[test] + fn dynamic_tail() { + let re = ResourceDef::new("/user/{id}/{tail}*"); + assert!(!re.is_match("/user/2345")); + let mut path = Path::new("/user/2345/sdg"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + assert_eq!(path.get("tail").unwrap(), "sdg"); + assert_eq!(path.unprocessed(), ""); + } + + #[test] + fn newline_patterns_and_paths() { + let re = ResourceDef::new("/user/a\nb"); + assert!(re.is_match("/user/a\nb")); + assert!(!re.is_match("/user/a\nb/profile")); + + let re = ResourceDef::new("/a{x}b/test/a{y}b"); + let mut path = Path::new("/a\nb/test/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("x").unwrap(), "\n"); + assert_eq!(path.get("y").unwrap(), "\n"); + + let re = ResourceDef::new("/user/{tail}*"); + assert!(re.is_match("/user/a\nb/")); + + let re = ResourceDef::new("/user/{id}*"); + let mut path = Path::new("/user/a\nb/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "a\nb/a\nb"); + + let re = ResourceDef::new("/user/{id:.*}"); + let mut path = Path::new("/user/a\nb/a\nb"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "a\nb/a\nb"); + } + + #[cfg(feature = "http")] + #[test] + fn parse_urlencoded_param() { + use std::convert::TryFrom; + + let re = ResourceDef::new("/user/{id}/test"); + + let mut path = Path::new("/user/2345/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "2345"); + + let mut path = Path::new("/user/qwe%25/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + + let uri = http::Uri::try_from("/user/qwe%25/test").unwrap(); + let mut path = Path::new(uri); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + } + + #[test] + fn prefix_static() { + let re = ResourceDef::prefix("/name"); + + assert!(re.is_prefix()); + + assert!(re.is_match("/name")); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name/test/test")); + assert!(!re.is_match("/name1")); + assert!(!re.is_match("/name~")); + + let mut path = Path::new("/name"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), ""); + + let mut path = Path::new("/name/test"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), "/test"); + + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match("/name/"), Some(5)); + assert_eq!(re.find_match("/name/test/test"), Some(5)); + assert_eq!(re.find_match("/name1"), None); + assert_eq!(re.find_match("/name~"), None); + + let re = ResourceDef::prefix("/name/"); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name//gs")); + assert!(!re.is_match("/name/gs")); + assert!(!re.is_match("/name")); + + let mut path = Path::new("/name/gs"); + assert!(!re.capture_match_info(&mut path)); + + let mut path = Path::new("/name//gs"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(path.unprocessed(), "/gs"); + + let re = ResourceDef::root_prefix("name/"); + assert!(re.is_match("/name/")); + assert!(re.is_match("/name//gs")); + assert!(!re.is_match("/name/gs")); + assert!(!re.is_match("/name")); + + let mut path = Path::new("/name/gs"); + assert!(!re.capture_match_info(&mut path)); + } + + #[test] + fn prefix_dynamic() { + let re = ResourceDef::prefix("/{name}"); + + assert!(re.is_prefix()); + + assert!(re.is_match("/name/")); + assert!(re.is_match("/name/gs")); + assert!(re.is_match("/name")); + + assert_eq!(re.find_match("/name/"), Some(5)); + assert_eq!(re.find_match("/name/gs"), Some(5)); + assert_eq!(re.find_match("/name"), Some(5)); + assert_eq!(re.find_match(""), None); + + let mut path = Path::new("/test2/"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(&path["name"], "test2"); + assert_eq!(&path[0], "test2"); + assert_eq!(path.unprocessed(), "/"); + + let mut path = Path::new("/test2/subpath1/subpath2/index.html"); + assert!(re.capture_match_info(&mut path)); + assert_eq!(&path["name"], "test2"); + assert_eq!(&path[0], "test2"); + assert_eq!(path.unprocessed(), "/subpath1/subpath2/index.html"); + + let resource = ResourceDef::prefix("/user"); + // input string shorter than prefix + assert!(resource.find_match("/foo").is_none()); + } + + #[test] + fn prefix_empty() { + let re = ResourceDef::prefix(""); + + assert!(re.is_prefix()); + + assert!(re.is_match("")); + assert!(re.is_match("/")); + assert!(re.is_match("/name/test/test")); + } + + #[test] + fn build_path_list() { + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/test"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["user1"]).iter())); + assert_eq!(s, "/user/user1/test"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}/test"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/test"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2"); + + let mut s = String::new(); + let resource = ResourceDef::new("/user/{item1}/{item2}/"); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&["item"]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["item", "item2"]).iter())); + assert_eq!(s, "/user/item/item2/"); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&["item"]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut vec!["item", "item2"].iter())); + assert_eq!(s, "/user/item/item2/"); + } + + #[test] + fn multi_pattern_build_path() { + let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut ["123"].iter())); + assert_eq!(s, "/user/123"); + } + + #[test] + fn multi_pattern_capture_segment_values() { + let resource = ResourceDef::new(["/user/{id}", "/profile/{id}"]); + + let mut path = Path::new("/user/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + + let mut path = Path::new("/profile/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + + let resource = ResourceDef::new(["/user/{id}", "/profile/{uid}"]); + + let mut path = Path::new("/user/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_some()); + assert!(path.get("uid").is_none()); + + let mut path = Path::new("/profile/123"); + assert!(resource.capture_match_info(&mut path)); + assert!(path.get("id").is_none()); + assert!(path.get("uid").is_some()); + } + + #[test] + fn dynamic_prefix_proper_segmentation() { + let resource = ResourceDef::prefix(r"/id/{id:\d{3}}"); + + assert!(resource.is_match("/id/123")); + assert!(resource.is_match("/id/123/foo")); + assert!(!resource.is_match("/id/1234")); + assert!(!resource.is_match("/id/123a")); + + assert_eq!(resource.find_match("/id/123"), Some(7)); + assert_eq!(resource.find_match("/id/123/foo"), Some(7)); + assert_eq!(resource.find_match("/id/1234"), None); + assert_eq!(resource.find_match("/id/123a"), None); + } + + #[test] + fn build_path_map() { + let resource = ResourceDef::new("/user/{item1}/{item2}/"); + + let mut map = HashMap::new(); + map.insert("item1", "item"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_map(&mut s, &map)); + + map.insert("item2", "item2"); + + let mut s = String::new(); + assert!(resource.resource_path_from_map(&mut s, &map)); + assert_eq!(s, "/user/item/item2/"); + } + + #[test] + fn build_path_tail() { + let resource = ResourceDef::new("/user/{item1}*"); + + let mut s = String::new(); + assert!(!resource.resource_path_from_iter(&mut s, &mut (&[""; 0]).iter())); + + let mut s = String::new(); + assert!(resource.resource_path_from_iter(&mut s, &mut (&["user1"]).iter())); + assert_eq!(s, "/user/user1"); + + let mut s = String::new(); + let mut map = HashMap::new(); + map.insert("item1", "item"); + assert!(resource.resource_path_from_map(&mut s, &map)); + assert_eq!(s, "/user/item"); + } + + #[test] + fn prefix_trailing_slash() { + // The prefix "/abc/" matches two segments: ["user", ""] + + // These are not prefixes + let re = ResourceDef::prefix("/abc/"); + assert_eq!(re.find_match("/abc/def"), None); + assert_eq!(re.find_match("/abc//def"), Some(5)); + + let re = ResourceDef::prefix("/{id}/"); + assert_eq!(re.find_match("/abc/def"), None); + assert_eq!(re.find_match("/abc//def"), Some(5)); + } + + #[test] + fn join() { + // test joined defs match the same paths as each component separately + + fn seq_find_match(re1: &ResourceDef, re2: &ResourceDef, path: &str) -> Option { + let len1 = re1.find_match(path)?; + let len2 = re2.find_match(&path[len1..])?; + Some(len1 + len2) + } + + macro_rules! join_test { + ($pat1:expr, $pat2:expr => $($test:expr),+) => {{ + let pat1 = $pat1; + let pat2 = $pat2; + $({ + let _path = $test; + let (re1, re2) = (ResourceDef::prefix(pat1), ResourceDef::new(pat2)); + let _seq = seq_find_match(&re1, &re2, _path); + let _join = re1.join(&re2).find_match(_path); + assert_eq!( + _seq, _join, + "patterns: prefix {:?}, {:?}; mismatch on \"{}\"; seq={:?}; join={:?}", + pat1, pat2, _path, _seq, _join + ); + assert!(!re1.join(&re2).is_prefix()); + + let (re1, re2) = (ResourceDef::prefix(pat1), ResourceDef::prefix(pat2)); + let _seq = seq_find_match(&re1, &re2, _path); + let _join = re1.join(&re2).find_match(_path); + assert_eq!( + _seq, _join, + "patterns: prefix {:?}, prefix {:?}; mismatch on \"{}\"; seq={:?}; join={:?}", + pat1, pat2, _path, _seq, _join + ); + assert!(re1.join(&re2).is_prefix()); + })+ + }} + } + + join_test!("", "" => "", "/hello", "/"); + join_test!("/user", "" => "", "/user", "/user/123", "/user11", "user", "user/123"); + join_test!("", "/user" => "", "/user", "foo", "/user11", "user", "user/123"); + join_test!("/user", "/xx" => "", "", "/", "/user", "/xx", "/userxx", "/user/xx"); + + join_test!(["/ver/{v}", "/v{v}"], ["/req/{req}", "/{req}"] => "/v1/abc", + "/ver/1/abc", "/v1/req/abc", "/ver/1/req/abc", "/v1/abc/def", + "/ver1/req/abc/def", "", "/", "/v1/"); + } + + #[test] + fn match_methods_agree() { + macro_rules! match_methods_agree { + ($pat:expr => $($test:expr),+) => {{ + match_methods_agree!(finish $pat, ResourceDef::new($pat), $($test),+); + }}; + (prefix $pat:expr => $($test:expr),+) => {{ + match_methods_agree!(finish $pat, ResourceDef::prefix($pat), $($test),+); + }}; + (finish $pat:expr, $re:expr, $($test:expr),+) => {{ + let re = $re; + $({ + let _is = re.is_match($test); + let _find = re.find_match($test).is_some(); + assert_eq!( + _is, _find, + "pattern: {:?}; mismatch on \"{}\"; is={}; find={}", + $pat, $test, _is, _find + ); + })+ + }} + } + + match_methods_agree!("" => "", "/", "/foo"); + match_methods_agree!("/" => "", "/", "/foo"); + match_methods_agree!("/user" => "user", "/user", "/users", "/user/123", "/foo"); + match_methods_agree!("/v{v}" => "v", "/v", "/v1", "/v222", "/foo"); + match_methods_agree!(["/v{v}", "/version/{v}"] => "/v", "/v1", "/version", "/version/1", "/foo"); + + match_methods_agree!("/path{tail}*" => "/path", "/path1", "/path/123"); + match_methods_agree!("/path/{tail}*" => "/path", "/path1", "/path/123"); + + match_methods_agree!(prefix "" => "", "/", "/foo"); + match_methods_agree!(prefix "/user" => "user", "/user", "/users", "/user/123", "/foo"); + match_methods_agree!(prefix r"/id/{id:\d{3}}" => "/id/123", "/id/1234"); + match_methods_agree!(["/v{v}", "/ver/{v}"] => "", "s/v", "/v1", "/v1/xx", "/ver/i3/5", "/ver/1"); + } + + #[test] + #[should_panic] + fn duplicate_segment_name() { + ResourceDef::new("/user/{id}/post/{id}"); + } + + #[test] + #[should_panic] + fn invalid_dynamic_segment_delimiter() { + ResourceDef::new("/user/{username"); + } + + #[test] + #[should_panic] + fn invalid_dynamic_segment_name() { + ResourceDef::new("/user/{}"); + } + + #[test] + #[should_panic] + fn invalid_too_many_dynamic_segments() { + // valid + ResourceDef::new("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}"); + + // panics + ResourceDef::new( + "/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}", + ); + } + + #[test] + #[should_panic] + fn invalid_custom_regex_for_tail() { + ResourceDef::new(r"/{tail:\d+}*"); + } + + #[test] + #[should_panic] + fn invalid_unnamed_tail_segment() { + ResourceDef::new("/*"); + } + + #[test] + #[should_panic] + fn prefix_plus_tail_match_disallowed() { + ResourceDef::prefix("/user/{id}*"); + } +} diff --git a/actix-router/src/resource_path.rs b/actix-router/src/resource_path.rs new file mode 100644 index 000000000..00eb0927c --- /dev/null +++ b/actix-router/src/resource_path.rs @@ -0,0 +1,39 @@ +use crate::Path; + +// TODO: this trait is necessary, document it +// see impl Resource for ServiceRequest +pub trait Resource { + /// Type of resource's path returned in `resource_path`. + type Path: ResourcePath; + + fn resource_path(&mut self) -> &mut Path; +} + +pub trait ResourcePath { + fn path(&self) -> &str; +} + +impl ResourcePath for String { + fn path(&self) -> &str { + self.as_str() + } +} + +impl<'a> ResourcePath for &'a str { + fn path(&self) -> &str { + self + } +} + +impl ResourcePath for bytestring::ByteString { + fn path(&self) -> &str { + &*self + } +} + +#[cfg(feature = "http")] +impl ResourcePath for http::Uri { + fn path(&self) -> &str { + self.path() + } +} diff --git a/actix-router/src/router.rs b/actix-router/src/router.rs new file mode 100644 index 000000000..47940708e --- /dev/null +++ b/actix-router/src/router.rs @@ -0,0 +1,278 @@ +use firestorm::profile_method; + +use crate::{IntoPatterns, Resource, ResourceDef}; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct ResourceId(pub u16); + +/// Information about current resource +#[derive(Debug, Clone)] +pub struct ResourceInfo { + #[allow(dead_code)] + resource: ResourceId, +} + +/// Resource router. +// T is the resource itself +// U is any other data needed for routing like method guards +pub struct Router { + routes: Vec<(ResourceDef, T, Option)>, +} + +impl Router { + pub fn build() -> RouterBuilder { + RouterBuilder { + resources: Vec::new(), + } + } + + pub fn recognize(&self, resource: &mut R) -> Option<(&T, ResourceId)> + where + R: Resource, + { + profile_method!(recognize); + + for item in self.routes.iter() { + if item.0.capture_match_info(resource.resource_path()) { + return Some((&item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_mut(&mut self, resource: &mut R) -> Option<(&mut T, ResourceId)> + where + R: Resource, + { + profile_method!(recognize_mut); + + for item in self.routes.iter_mut() { + if item.0.capture_match_info(resource.resource_path()) { + return Some((&mut item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_fn(&self, resource: &mut R, check: F) -> Option<(&T, ResourceId)> + where + F: Fn(&R, &Option) -> bool, + R: Resource, + { + profile_method!(recognize_checked); + + for item in self.routes.iter() { + if item.0.capture_match_info_fn(resource, &check, &item.2) { + return Some((&item.1, ResourceId(item.0.id()))); + } + } + + None + } + + pub fn recognize_mut_fn( + &mut self, + resource: &mut R, + check: F, + ) -> Option<(&mut T, ResourceId)> + where + F: Fn(&R, &Option) -> bool, + R: Resource, + { + profile_method!(recognize_mut_checked); + + for item in self.routes.iter_mut() { + if item.0.capture_match_info_fn(resource, &check, &item.2) { + return Some((&mut item.1, ResourceId(item.0.id()))); + } + } + + None + } +} + +pub struct RouterBuilder { + resources: Vec<(ResourceDef, T, Option)>, +} + +impl RouterBuilder { + /// Register resource for specified path. + pub fn path( + &mut self, + path: P, + resource: T, + ) -> &mut (ResourceDef, T, Option) { + profile_method!(path); + + self.resources + .push((ResourceDef::new(path), resource, None)); + self.resources.last_mut().unwrap() + } + + /// Register resource for specified path prefix. + pub fn prefix(&mut self, prefix: &str, resource: T) -> &mut (ResourceDef, T, Option) { + profile_method!(prefix); + + self.resources + .push((ResourceDef::prefix(prefix), resource, None)); + self.resources.last_mut().unwrap() + } + + /// Register resource for ResourceDef + pub fn rdef(&mut self, rdef: ResourceDef, resource: T) -> &mut (ResourceDef, T, Option) { + profile_method!(rdef); + + self.resources.push((rdef, resource, None)); + self.resources.last_mut().unwrap() + } + + /// Finish configuration and create router instance. + pub fn finish(self) -> Router { + Router { + routes: self.resources, + } + } +} + +#[cfg(test)] +mod tests { + use crate::path::Path; + use crate::router::{ResourceId, Router}; + + #[allow(clippy::cognitive_complexity)] + #[test] + fn test_recognizer_1() { + let mut router = Router::::build(); + router.path("/name", 10).0.set_id(0); + router.path("/name/{val}", 11).0.set_id(1); + router.path("/name/{val}/index.html", 12).0.set_id(2); + router.path("/file/{file}.{ext}", 13).0.set_id(3); + router.path("/v{val}/{val2}/index.html", 14).0.set_id(4); + router.path("/v/{tail:.*}", 15).0.set_id(5); + router.path("/test2/{test}.html", 16).0.set_id(6); + router.path("/{test}/index.html", 17).0.set_id(7); + let mut router = router.finish(); + + let mut path = Path::new("/unknown"); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/name"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + assert_eq!(info, ResourceId(0)); + assert!(path.is_empty()); + + let mut path = Path::new("/name/value"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(info, ResourceId(1)); + assert_eq!(path.get("val").unwrap(), "value"); + assert_eq!(&path["val"], "value"); + + let mut path = Path::new("/name/value2/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 12); + assert_eq!(info, ResourceId(2)); + assert_eq!(path.get("val").unwrap(), "value2"); + + let mut path = Path::new("/file/file.gz"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 13); + assert_eq!(info, ResourceId(3)); + assert_eq!(path.get("file").unwrap(), "file"); + assert_eq!(path.get("ext").unwrap(), "gz"); + + let mut path = Path::new("/vtest/ttt/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 14); + assert_eq!(info, ResourceId(4)); + assert_eq!(path.get("val").unwrap(), "test"); + assert_eq!(path.get("val2").unwrap(), "ttt"); + + let mut path = Path::new("/v/blah-blah/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 15); + assert_eq!(info, ResourceId(5)); + assert_eq!(path.get("tail").unwrap(), "blah-blah/index.html"); + + let mut path = Path::new("/test2/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 16); + assert_eq!(info, ResourceId(6)); + assert_eq!(path.get("test").unwrap(), "index"); + + let mut path = Path::new("/bbb/index.html"); + let (h, info) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 17); + assert_eq!(info, ResourceId(7)); + assert_eq!(path.get("test").unwrap(), "bbb"); + } + + #[test] + fn test_recognizer_2() { + let mut router = Router::::build(); + router.path("/index.json", 10); + router.path("/{source}.json", 11); + let mut router = router.finish(); + + let mut path = Path::new("/index.json"); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test.json"); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + } + + #[test] + fn test_recognizer_with_prefix() { + let mut router = Router::::build(); + router.path("/name", 10).0.set_id(0); + router.path("/name/{val}", 11).0.set_id(1); + let mut router = router.finish(); + + let mut path = Path::new("/name"); + path.skip(5); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test/name"); + path.skip(5); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test/name/value"); + path.skip(5); + let (h, id) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(id, ResourceId(1)); + assert_eq!(path.get("val").unwrap(), "value"); + assert_eq!(&path["val"], "value"); + + // same patterns + let mut router = Router::::build(); + router.path("/name", 10); + router.path("/name/{val}", 11); + let mut router = router.finish(); + + let mut path = Path::new("/name"); + path.skip(6); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test2/name"); + path.skip(6); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 10); + + let mut path = Path::new("/test2/name-test"); + path.skip(6); + assert!(router.recognize_mut(&mut path).is_none()); + + let mut path = Path::new("/test2/name/ttt"); + path.skip(6); + let (h, _) = router.recognize_mut(&mut path).unwrap(); + assert_eq!(*h, 11); + assert_eq!(&path["val"], "ttt"); + } +} diff --git a/actix-router/src/url.rs b/actix-router/src/url.rs new file mode 100644 index 000000000..c5a3508aa --- /dev/null +++ b/actix-router/src/url.rs @@ -0,0 +1,140 @@ +use crate::ResourcePath; + +use crate::Quoter; + +thread_local! { + static DEFAULT_QUOTER: Quoter = Quoter::new(b"@:", b"%/+"); +} + +#[derive(Debug, Clone, Default)] +pub struct Url { + uri: http::Uri, + path: Option, +} + +impl Url { + #[inline] + pub fn new(uri: http::Uri) -> Url { + let path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); + Url { uri, path } + } + + #[inline] + pub fn new_with_quoter(uri: http::Uri, quoter: &Quoter) -> Url { + Url { + path: quoter.requote(uri.path().as_bytes()), + uri, + } + } + + /// Returns URI. + #[inline] + pub fn uri(&self) -> &http::Uri { + &self.uri + } + + /// Returns path. + #[inline] + pub fn path(&self) -> &str { + match self.path { + Some(ref path) => path, + _ => self.uri.path(), + } + } + + #[inline] + pub fn update(&mut self, uri: &http::Uri) { + self.uri = uri.clone(); + self.path = DEFAULT_QUOTER.with(|q| q.requote(uri.path().as_bytes())); + } + + #[inline] + pub fn update_with_quoter(&mut self, uri: &http::Uri, quoter: &Quoter) { + self.uri = uri.clone(); + self.path = quoter.requote(uri.path().as_bytes()); + } +} + +impl ResourcePath for Url { + #[inline] + fn path(&self) -> &str { + self.path() + } +} + +#[cfg(test)] +mod tests { + use http::Uri; + use std::convert::TryFrom; + + use super::*; + use crate::{Path, ResourceDef}; + + const PROTECTED: &[u8] = b"%/+"; + + fn match_url(pattern: &'static str, url: impl AsRef) -> Path { + let re = ResourceDef::new(pattern); + let uri = Uri::try_from(url.as_ref()).unwrap(); + let mut path = Path::new(Url::new(uri)); + assert!(re.capture_match_info(&mut path)); + path + } + + fn percent_encode(data: &[u8]) -> String { + data.iter().map(|c| format!("%{:02X}", c)).collect() + } + + #[test] + fn parse_url() { + let re = "/user/{id}/test"; + + let path = match_url(re, "/user/2345/test"); + assert_eq!(path.get("id").unwrap(), "2345"); + } + + #[test] + fn protected_chars() { + let re = "/user/{id}/test"; + + let encoded = percent_encode(PROTECTED); + let path = match_url(re, format!("/user/{}/test", encoded)); + // characters in captured segment remain unencoded + assert_eq!(path.get("id").unwrap(), &encoded); + + // "%25" should never be decoded into '%' to guarantee the output is a valid + // percent-encoded format + let path = match_url(re, "/user/qwe%25/test"); + assert_eq!(path.get("id").unwrap(), "qwe%25"); + + let path = match_url(re, "/user/qwe%25rty/test"); + assert_eq!(path.get("id").unwrap(), "qwe%25rty"); + } + + #[test] + fn non_protected_ascii() { + let non_protected_ascii = ('\u{0}'..='\u{7F}') + .filter(|&c| c.is_ascii() && !PROTECTED.contains(&(c as u8))) + .collect::(); + let encoded = percent_encode(non_protected_ascii.as_bytes()); + let path = match_url("/user/{id}/test", format!("/user/{}/test", encoded)); + assert_eq!(path.get("id").unwrap(), &non_protected_ascii); + } + + #[test] + fn valid_utf8_multibyte() { + let test = ('\u{FF00}'..='\u{FFFF}').collect::(); + let encoded = percent_encode(test.as_bytes()); + let path = match_url("/a/{id}/b", format!("/a/{}/b", &encoded)); + assert_eq!(path.get("id").unwrap(), &test); + } + + #[test] + fn invalid_utf8() { + let invalid_utf8 = percent_encode((0x80..=0xff).collect::>().as_slice()); + let uri = Uri::try_from(format!("/{}", invalid_utf8)).unwrap(); + let path = Path::new(Url::new(uri)); + + // We should always get a valid utf8 string + assert!(String::from_utf8(path.path().as_bytes().to_owned()).is_ok()); + } +} diff --git a/actix-test/CHANGES.md b/actix-test/CHANGES.md new file mode 100644 index 000000000..32ab2344f --- /dev/null +++ b/actix-test/CHANGES.md @@ -0,0 +1,57 @@ +# Changes + +## Unreleased - 2021-xx-xx + + +## 0.1.0-beta.11 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 0.1.0-beta.10 - 2021-12-27 +- No significant changes since `0.1.0-beta.9`. + + +## 0.1.0-beta.9 - 2021-12-17 +- Re-export `actix_http::body::to_bytes`. [#2518] +- Update `actix_web::test` re-exports. [#2518] + +[#2518]: https://github.com/actix/actix-web/pull/2518 + + +## 0.1.0-beta.8 - 2021-12-11 +- No significant changes since `0.1.0-beta.7`. + + +## 0.1.0-beta.7 - 2021-11-22 +- Fix compatibility with experimental `io-uring` feature of `actix-rt`. [#2408] + +[#2408]: https://github.com/actix/actix-web/pull/2408 + + +## 0.1.0-beta.6 - 2021-11-15 +- No significant changes from `0.1.0-beta.5`. + + +## 0.1.0-beta.5 - 2021-10-20 +- Updated rustls to v0.20. [#2414] +- Minimum supported Rust version (MSRV) is now 1.52. + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 0.1.0-beta.4 - 2021-09-09 +- Minimum supported Rust version (MSRV) is now 1.51. + + +## 0.1.0-beta.3 - 2021-06-20 +- No significant changes from `0.1.0-beta.2`. + + +## 0.1.0-beta.2 - 2021-04-17 +- No significant changes from `0.1.0-beta.1`. + + +## 0.1.0-beta.1 - 2021-04-02 +- Move integration testing structs from `actix-web`. [#2112] + +[#2112]: https://github.com/actix/actix-web/pull/2112 diff --git a/actix-test/Cargo.toml b/actix-test/Cargo.toml new file mode 100644 index 000000000..89f28a5da --- /dev/null +++ b/actix-test/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "actix-test" +version = "0.1.0-beta.11" +authors = [ + "Nikolay Kim ", + "Rob Ede ", +] +description = "Integration testing tools for Actix Web applications" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +categories = [ + "network-programming", + "asynchronous", + "web-programming::http-server", + "web-programming::websocket", +] +license = "MIT OR Apache-2.0" +edition = "2018" + +[features] +default = [] + +# rustls +rustls = ["tls-rustls", "actix-http/rustls", "awc/rustls"] + +# openssl +openssl = ["tls-openssl", "actix-http/openssl", "awc/openssl"] + +[dependencies] +actix-codec = "0.4.1" +actix-http = "3.0.0-beta.18" +actix-http-test = "3.0.0-beta.11" +actix-rt = "2.1" +actix-service = "2.0.0" +actix-utils = "3.0.0" +actix-web = { version = "4.0.0-beta.19", default-features = false, features = ["cookies"] } +awc = { version = "3.0.0-beta.18", default-features = false, features = ["cookies"] } + +futures-core = { version = "0.3.7", default-features = false, features = ["std"] } +futures-util = { version = "0.3.7", default-features = false, features = [] } +log = "0.4" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +serde_urlencoded = "0.7" +tls-openssl = { package = "openssl", version = "0.10.9", optional = true } +tls-rustls = { package = "rustls", version = "0.20.0", optional = true } +tokio = { version = "1.8.4", features = ["sync"] } diff --git a/actix-test/LICENSE-APACHE b/actix-test/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-test/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-test/LICENSE-MIT b/actix-test/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-test/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-test/src/lib.rs b/actix-test/src/lib.rs new file mode 100644 index 000000000..f86120f2f --- /dev/null +++ b/actix-test/src/lib.rs @@ -0,0 +1,575 @@ +//! Integration testing tools for Actix Web applications. +//! +//! The main integration testing tool is [`TestServer`]. It spawns a real HTTP server on an +//! unused port and provides methods that use a real HTTP client. Therefore, it is much closer to +//! real-world cases than using `init_service`, which skips HTTP encoding and decoding. +//! +//! # Examples +//! ``` +//! use actix_web::{get, web, test, App, HttpResponse, Error, Responder}; +//! +//! #[get("/")] +//! async fn my_handler() -> Result { +//! Ok(HttpResponse::Ok()) +//! } +//! +//! #[actix_rt::test] +//! async fn test_example() { +//! let srv = actix_test::start(|| +//! App::new().service(my_handler) +//! ); +//! +//! let req = srv.get("/"); +//! let res = req.send().await.unwrap(); +//! +//! assert!(res.status().is_success()); +//! } +//! ``` + +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] + +#[cfg(feature = "openssl")] +extern crate tls_openssl as openssl; +#[cfg(feature = "rustls")] +extern crate tls_rustls as rustls; + +use std::{fmt, net, thread, time::Duration}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +pub use actix_http::{body::to_bytes, test::TestBuffer}; +use actix_http::{header::HeaderMap, ws, HttpService, Method, Request, Response}; +pub use actix_http_test::unused_addr; +use actix_service::{map_config, IntoServiceFactory, ServiceFactory, ServiceFactoryExt as _}; +pub use actix_web::test::{ + call_and_read_body, call_and_read_body_json, call_service, init_service, ok_service, + read_body, read_body_json, simple_service, TestRequest, +}; +use actix_web::{ + body::MessageBody, + dev::{AppConfig, Server, ServerHandle, Service}, + rt::{self, System}, + web, Error, +}; +use awc::{error::PayloadError, Client, ClientRequest, ClientResponse, Connector}; +use futures_core::Stream; +use tokio::sync::mpsc; + +/// Start default [`TestServer`]. +/// +/// # Examples +/// ``` +/// use actix_web::{get, web, test, App, HttpResponse, Error, Responder}; +/// +/// #[get("/")] +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok()) +/// } +/// +/// #[actix_web::test] +/// async fn test_example() { +/// let srv = actix_test::start(|| +/// App::new().service(my_handler) +/// ); +/// +/// let req = srv.get("/"); +/// let res = req.send().await.unwrap(); +/// +/// assert!(res.status().is_success()); +/// } +/// ``` +pub fn start(factory: F) -> TestServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + >::Future: 'static, + B: MessageBody + 'static, +{ + start_with(TestServerConfig::default(), factory) +} + +/// Start test server with custom configuration +/// +/// Check [`TestServerConfig`] docs for configuration options. +/// +/// # Examples +/// ``` +/// use actix_web::{get, web, test, App, HttpResponse, Error, Responder}; +/// +/// #[get("/")] +/// async fn my_handler() -> Result { +/// Ok(HttpResponse::Ok()) +/// } +/// +/// #[actix_web::test] +/// async fn test_example() { +/// let srv = actix_test::start_with(actix_test::config().h1(), || +/// App::new().service(my_handler) +/// ); +/// +/// let req = srv.get("/"); +/// let res = req.send().await.unwrap(); +/// +/// assert!(res.status().is_success()); +/// } +/// ``` +pub fn start_with(cfg: TestServerConfig, factory: F) -> TestServer +where + F: Fn() -> I + Send + Clone + 'static, + I: IntoServiceFactory, + S: ServiceFactory + 'static, + S::Error: Into + 'static, + S::InitError: fmt::Debug, + S::Response: Into> + 'static, + >::Future: 'static, + B: MessageBody + 'static, +{ + // for sending handles and server info back from the spawned thread + let (started_tx, started_rx) = std::sync::mpsc::channel(); + + // for signaling the shutdown of spawned server and system + let (thread_stop_tx, thread_stop_rx) = mpsc::channel(1); + + let tls = match cfg.stream { + StreamType::Tcp => false, + #[cfg(feature = "openssl")] + StreamType::Openssl(_) => true, + #[cfg(feature = "rustls")] + StreamType::Rustls(_) => true, + }; + + // run server in separate orphaned thread + thread::spawn(move || { + rt::System::new().block_on(async move { + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + let factory = factory.clone(); + let srv_cfg = cfg.clone(); + let timeout = cfg.client_timeout; + + let builder = Server::build().workers(1).disable_signals().system_exit(); + + let srv = match srv_cfg.stream { + StreamType::Tcp => match srv_cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .tcp() + }), + }, + #[cfg(feature = "openssl")] + StreamType::Openssl(acceptor) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .openssl(acceptor.clone()) + }), + }, + #[cfg(feature = "rustls")] + StreamType::Rustls(config) => match cfg.tp { + HttpVer::Http1 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h1(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Http2 => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .h2(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + HttpVer::Both => builder.listen("test", tcp, move || { + let app_cfg = AppConfig::__priv_test_new( + false, + local_addr.to_string(), + local_addr, + ); + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + HttpService::build() + .client_timeout(timeout) + .finish(map_config(fac, move |_| app_cfg.clone())) + .rustls(config.clone()) + }), + }, + } + .expect("test server could not be created"); + + let srv = srv.run(); + started_tx + .send((System::current(), srv.handle(), local_addr)) + .unwrap(); + + // drive server loop + srv.await.unwrap(); + + // notify TestServer that server and system have shut down + // all thread managed resources should be dropped at this point + }); + + let _ = thread_stop_tx.send(()); + }); + + let (system, server, addr) = started_rx.recv().unwrap(); + + let client = { + let connector = { + #[cfg(feature = "openssl")] + { + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) + .openssl(builder.build()) + } + #[cfg(not(feature = "openssl"))] + { + Connector::new() + .conn_lifetime(Duration::from_secs(0)) + .timeout(Duration::from_millis(30000)) + } + }; + + Client::builder().connector(connector).finish() + }; + + TestServer { + server, + thread_stop_rx, + client, + system, + addr, + tls, + } +} + +#[derive(Debug, Clone)] +enum HttpVer { + Http1, + Http2, + Both, +} + +#[derive(Clone)] +enum StreamType { + Tcp, + #[cfg(feature = "openssl")] + Openssl(openssl::ssl::SslAcceptor), + #[cfg(feature = "rustls")] + Rustls(rustls::ServerConfig), +} + +/// Create default test server config. +pub fn config() -> TestServerConfig { + TestServerConfig::default() +} + +#[derive(Clone)] +pub struct TestServerConfig { + tp: HttpVer, + stream: StreamType, + client_timeout: u64, +} + +impl Default for TestServerConfig { + fn default() -> Self { + TestServerConfig::new() + } +} + +impl TestServerConfig { + /// Create default server configuration + pub(crate) fn new() -> TestServerConfig { + TestServerConfig { + tp: HttpVer::Both, + stream: StreamType::Tcp, + client_timeout: 5000, + } + } + + /// Accept HTTP/1.1 only. + pub fn h1(mut self) -> Self { + self.tp = HttpVer::Http1; + self + } + + /// Accept HTTP/2 only. + pub fn h2(mut self) -> Self { + self.tp = HttpVer::Http2; + self + } + + /// Accept secure connections via OpenSSL. + #[cfg(feature = "openssl")] + pub fn openssl(mut self, acceptor: openssl::ssl::SslAcceptor) -> Self { + self.stream = StreamType::Openssl(acceptor); + self + } + + /// Accept secure connections via Rustls. + #[cfg(feature = "rustls")] + pub fn rustls(mut self, config: rustls::ServerConfig) -> Self { + self.stream = StreamType::Rustls(config); + self + } + + /// Set client timeout in milliseconds for first request. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } +} + +/// A basic HTTP server controller that simplifies the process of writing integration tests for +/// Actix Web applications. +/// +/// See [`start`] for usage example. +pub struct TestServer { + server: ServerHandle, + thread_stop_rx: mpsc::Receiver<()>, + client: awc::Client, + system: rt::System, + addr: net::SocketAddr, + tls: bool, +} + +impl TestServer { + /// Construct test server url + pub fn addr(&self) -> net::SocketAddr { + self.addr + } + + /// Construct test server url + pub fn url(&self, uri: &str) -> String { + let scheme = if self.tls { "https" } else { "http" }; + + if uri.starts_with('/') { + format!("{}://localhost:{}{}", scheme, self.addr.port(), uri) + } else { + format!("{}://localhost:{}/{}", scheme, self.addr.port(), uri) + } + } + + /// Create `GET` request. + pub fn get(&self, path: impl AsRef) -> ClientRequest { + self.client.get(self.url(path.as_ref()).as_str()) + } + + /// Create `POST` request. + pub fn post(&self, path: impl AsRef) -> ClientRequest { + self.client.post(self.url(path.as_ref()).as_str()) + } + + /// Create `HEAD` request. + pub fn head(&self, path: impl AsRef) -> ClientRequest { + self.client.head(self.url(path.as_ref()).as_str()) + } + + /// Create `PUT` request. + pub fn put(&self, path: impl AsRef) -> ClientRequest { + self.client.put(self.url(path.as_ref()).as_str()) + } + + /// Create `PATCH` request. + pub fn patch(&self, path: impl AsRef) -> ClientRequest { + self.client.patch(self.url(path.as_ref()).as_str()) + } + + /// Create `DELETE` request. + pub fn delete(&self, path: impl AsRef) -> ClientRequest { + self.client.delete(self.url(path.as_ref()).as_str()) + } + + /// Create `OPTIONS` request. + pub fn options(&self, path: impl AsRef) -> ClientRequest { + self.client.options(self.url(path.as_ref()).as_str()) + } + + /// Connect request with given method and path. + pub fn request(&self, method: Method, path: impl AsRef) -> ClientRequest { + self.client.request(method, path.as_ref()) + } + + pub async fn load_body( + &mut self, + mut response: ClientResponse, + ) -> Result + where + S: Stream> + Unpin + 'static, + { + response.body().limit(10_485_760).await + } + + /// Connect to WebSocket server at a given path. + pub async fn ws_at( + &mut self, + path: &str, + ) -> Result, awc::error::WsClientError> { + let url = self.url(path); + let connect = self.client.ws(url).connect(); + connect.await.map(|(_, framed)| framed) + } + + /// Connect to a WebSocket server. + pub async fn ws( + &mut self, + ) -> Result, awc::error::WsClientError> { + self.ws_at("/").await + } + + /// Get default HeaderMap of Client. + /// + /// Returns Some(&mut HeaderMap) when Client object is unique + /// (No other clone of client exists at the same time). + pub fn client_headers(&mut self) -> Option<&mut HeaderMap> { + self.client.headers() + } + + /// Stop HTTP server. + /// + /// Waits for spawned `Server` and `System` to shutdown (force) shutdown. + pub async fn stop(mut self) { + // signal server to stop + self.server.stop(false).await; + + // also signal system to stop + // though this is handled by `ServerBuilder::exit_system` too + self.system.stop(); + + // wait for thread to be stopped but don't care about result + let _ = self.thread_stop_rx.recv().await; + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + // calls in this Drop impl should be enough to shut down the server, system, and thread + // without needing to await anything + + // signal server to stop + let _ = self.server.stop(true); + + // signal system to stop + self.system.stop(); + } +} diff --git a/actix-web-actors/CHANGES.md b/actix-web-actors/CHANGES.md index acd9ceada..74ab3c785 100644 --- a/actix-web-actors/CHANGES.md +++ b/actix-web-actors/CHANGES.md @@ -3,76 +3,114 @@ ## Unreleased - 2021-xx-xx +## 4.0.0-beta.10 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 4.0.0-beta.9 - 2021-12-27 +- No significant changes since `4.0.0-beta.8`. + + +## 4.0.0-beta.8 - 2021-12-11 +- Add `ws:WsResponseBuilder` for building WebSocket session response. [#1920] +- Deprecate `ws::{start_with_addr, start_with_protocols}`. [#1920] +- Minimum supported Rust version (MSRV) is now 1.52. + +[#1920]: https://github.com/actix/actix-web/pull/1920 + + +## 4.0.0-beta.7 - 2021-09-09 +- Minimum supported Rust version (MSRV) is now 1.51. + + +## 4.0.0-beta.6 - 2021-06-26 +- Update `actix` to `0.12`. [#2277] + +[#2277]: https://github.com/actix/actix-web/pull/2277 + + +## 4.0.0-beta.5 - 2021-06-17 +- No notable changes. + + +## 4.0.0-beta.4 - 2021-04-02 +- No notable changes. + + +## 4.0.0-beta.3 - 2021-03-09 +- No notable changes. + + ## 4.0.0-beta.2 - 2021-02-10 -* No notable changes. +- No notable changes. ## 4.0.0-beta.1 - 2021-01-07 -* Update `pin-project` to `1.0`. -* Update `bytes` to `1.0`. [#1813] -* `WebsocketContext::text` now takes an `Into`. [#1864] +- Update `pin-project` to `1.0`. +- Update `bytes` to `1.0`. [#1813] +- `WebsocketContext::text` now takes an `Into`. [#1864] [#1813]: https://github.com/actix/actix-web/pull/1813 [#1864]: https://github.com/actix/actix-web/pull/1864 ## 3.0.0 - 2020-09-11 -* No significant changes from `3.0.0-beta.2`. +- No significant changes from `3.0.0-beta.2`. ## 3.0.0-beta.2 - 2020-09-10 -* Update `actix-*` dependencies to latest versions. +- Update `actix-*` dependencies to latest versions. ## [3.0.0-beta.1] - 2020-xx-xx -* Update `actix-web` & `actix-http` dependencies to beta.1 -* Bump minimum supported Rust version to 1.40 +- Update `actix-web` & `actix-http` dependencies to beta.1 +- Bump minimum supported Rust version to 1.40 ## [3.0.0-alpha.1] - 2020-05-08 -* Update the actix-web dependency to 3.0.0-alpha.1 -* Update the actix dependency to 0.10.0-alpha.2 -* Update the actix-http dependency to 2.0.0-alpha.3 +- Update the actix-web dependency to 3.0.0-alpha.1 +- Update the actix dependency to 0.10.0-alpha.2 +- Update the actix-http dependency to 2.0.0-alpha.3 ## [2.0.0] - 2019-12-20 -* Release +- Release ## [2.0.0-alpha.1] - 2019-12-15 -* Migrate to actix-web 2.0.0 +- Migrate to actix-web 2.0.0 ## [1.0.4] - 2019-12-07 -* Allow comma-separated websocket subprotocols without spaces (#1172) +- Allow comma-separated websocket subprotocols without spaces (#1172) ## [1.0.3] - 2019-11-14 -* Update actix-web and actix-http dependencies +- Update actix-web and actix-http dependencies ## [1.0.2] - 2019-07-20 -* Add `ws::start_with_addr()`, returning the address of the created actor, along +- Add `ws::start_with_addr()`, returning the address of the created actor, along with the `HttpResponse`. -* Add support for specifying protocols on websocket handshake #835 +- Add support for specifying protocols on websocket handshake #835 ## [1.0.1] - 2019-06-28 -* Allow to use custom ws codec with `WebsocketContext` #925 +- Allow to use custom ws codec with `WebsocketContext` #925 ## [1.0.0] - 2019-05-29 -* Update actix-http and actix-web +- Update actix-http and actix-web ## [0.1.0-alpha.3] - 2019-04-02 -* Update actix-http and actix-web +- Update actix-http and actix-web ## [0.1.0-alpha.2] - 2019-03-29 -* Update actix-http and actix-web +- Update actix-http and actix-web ## [0.1.0-alpha.1] - 2019-03-28 -* Initial impl +- Initial impl diff --git a/actix-web-actors/Cargo.toml b/actix-web-actors/Cargo.toml index 698ff9420..25b6ea538 100644 --- a/actix-web-actors/Cargo.toml +++ b/actix-web-actors/Cargo.toml @@ -1,13 +1,11 @@ [package] name = "actix-web-actors" -version = "4.0.0-beta.2" +version = "4.0.0-beta.10" authors = ["Nikolay Kim "] description = "Actix actors support for Actix Web" -readme = "README.md" keywords = ["actix", "http", "web", "framework", "async"] homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/actix-web-actors/" +repository = "https://github.com/actix/actix-web" license = "MIT OR Apache-2.0" edition = "2018" @@ -16,18 +14,21 @@ name = "actix_web_actors" path = "src/lib.rs" [dependencies] -actix = { version = "0.11.0-beta.2", default-features = false } -actix-codec = "0.4.0-beta.1" -actix-http = "3.0.0-beta.3" -actix-web = { version = "4.0.0-beta.3", default-features = false } +actix = { version = "0.12.0", default-features = false } +actix-codec = "0.4.1" +actix-http = "3.0.0-beta.18" +actix-web = { version = "4.0.0-beta.19", default-features = false } bytes = "1" bytestring = "1" futures-core = { version = "0.3.7", default-features = false } -pin-project = "1.0.0" -tokio = { version = "1", features = ["sync"] } +pin-project-lite = "0.2" +tokio = { version = "1.8.4", features = ["sync"] } [dev-dependencies] -actix-rt = "2" -env_logger = "0.8" +actix-rt = "2.2" +actix-test = "0.1.0-beta.11" +awc = { version = "3.0.0-beta.18", default-features = false } + +env_logger = "0.9" futures-util = { version = "0.3.7", default-features = false } diff --git a/actix-web-actors/README.md b/actix-web-actors/README.md index c9b588153..60e6a9bd9 100644 --- a/actix-web-actors/README.md +++ b/actix-web-actors/README.md @@ -3,16 +3,15 @@ > Actix actors support for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-web-actors?label=latest)](https://crates.io/crates/actix-web-actors) -[![Documentation](https://docs.rs/actix-web-actors/badge.svg?version=0.5.0)](https://docs.rs/actix-web-actors/0.5.0) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web-actors/badge.svg?version=4.0.0-beta.10)](https://docs.rs/actix-web-actors/4.0.0-beta.10) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![License](https://img.shields.io/crates/l/actix-web-actors.svg)
-[![dependency status](https://deps.rs/crate/actix-web-actors/0.5.0/status.svg)](https://deps.rs/crate/actix-web-actors/0.5.0) +[![dependency status](https://deps.rs/crate/actix-web-actors/4.0.0-beta.10/status.svg)](https://deps.rs/crate/actix-web-actors/4.0.0-beta.10) [![Download](https://img.shields.io/crates/d/actix-web-actors.svg)](https://crates.io/crates/actix-web-actors) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-web-actors) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum supported Rust version: 1.46 or later +- Minimum Supported Rust Version (MSRV): 1.54 diff --git a/actix-web-actors/src/context.rs b/actix-web-actors/src/context.rs index bd6635cba..d7459aea4 100644 --- a/actix-web-actors/src/context.rs +++ b/actix-web-actors/src/context.rs @@ -44,7 +44,7 @@ where #[inline] fn spawn(&mut self, fut: F) -> SpawnHandle where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.spawn(fut) } @@ -52,7 +52,7 @@ where #[inline] fn wait(&mut self, fut: F) where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.wait(fut) } diff --git a/actix-web-actors/src/lib.rs b/actix-web-actors/src/lib.rs index 7a4823d91..70c957020 100644 --- a/actix-web-actors/src/lib.rs +++ b/actix-web-actors/src/lib.rs @@ -1,7 +1,7 @@ //! Actix actors support for Actix Web. -#![deny(rust_2018_idioms)] -#![allow(clippy::borrow_interior_mutable_const)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] mod context; pub mod ws; diff --git a/actix-web-actors/src/ws.rs b/actix-web-actors/src/ws.rs index 1ab4cfce5..6fde10192 100644 --- a/actix-web-actors/src/ws.rs +++ b/actix-web-actors/src/ws.rs @@ -1,34 +1,222 @@ //! Websocket integration. -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{collections::VecDeque, convert::TryFrom}; - -use actix::dev::{ - AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope, +use std::{ + collections::VecDeque, + convert::TryFrom, + future::Future, + io, mem, + pin::Pin, + task::{Context, Poll}, }; -use actix::fut::ActorFuture; + use actix::{ + dev::{ + AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, + ToEnvelope, + }, + fut::ActorFuture, Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage, SpawnHandle, }; -use actix_codec::{Decoder, Encoder}; +use actix_codec::{Decoder as _, Encoder as _}; use actix_http::ws::{hash_key, Codec}; pub use actix_http::ws::{ CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError, }; -use actix_web::dev::HttpResponseBuilder; -use actix_web::error::{Error, PayloadError}; -use actix_web::http::{header, Method, StatusCode}; -use actix_web::{HttpRequest, HttpResponse}; +use actix_web::{ + error::{Error, PayloadError}, + http::{ + header::{self, HeaderValue}, + Method, StatusCode, + }, + HttpRequest, HttpResponse, HttpResponseBuilder, +}; use bytes::{Bytes, BytesMut}; use bytestring::ByteString; use futures_core::Stream; -use tokio::sync::oneshot::Sender; +use pin_project_lite::pin_project; +use tokio::sync::oneshot; + +/// Builder for Websocket session response. +/// +/// # Examples +/// +/// Create a Websocket session response with default configuration. +/// ```ignore +/// WsResponseBuilder::new(WsActor, &req, stream).start() +/// ``` +/// +/// Create a Websocket session with a specific max frame size, [`Codec`], and protocols. +/// ```ignore +/// const MAX_FRAME_SIZE: usize = 16_384; // 16KiB +/// +/// ws::WsResponseBuilder::new(WsActor, &req, stream) +/// .codec(Codec::new()) +/// .protocols(&["A", "B"]) +/// .frame_size(MAX_FRAME_SIZE) +/// .start() +/// ``` +pub struct WsResponseBuilder<'a, A, T> +where + A: Actor> + StreamHandler>, + T: Stream> + 'static, +{ + actor: A, + req: &'a HttpRequest, + stream: T, + codec: Option, + protocols: Option<&'a [&'a str]>, + frame_size: Option, +} + +impl<'a, A, T> WsResponseBuilder<'a, A, T> +where + A: Actor> + StreamHandler>, + T: Stream> + 'static, +{ + /// Construct a new `WsResponseBuilder` with actor, request, and payload stream. + /// + /// For usage example, see docs on [`WsResponseBuilder`] struct. + pub fn new(actor: A, req: &'a HttpRequest, stream: T) -> Self { + WsResponseBuilder { + actor, + req, + stream, + codec: None, + protocols: None, + frame_size: None, + } + } + + /// Set the protocols for the session. + pub fn protocols(mut self, protocols: &'a [&'a str]) -> Self { + self.protocols = Some(protocols); + self + } + + /// Set the max frame size for each message (in bytes). + /// + /// **Note**: This will override any given [`Codec`]'s max frame size. + pub fn frame_size(mut self, frame_size: usize) -> Self { + self.frame_size = Some(frame_size); + self + } + + /// Set the [`Codec`] for the session. If [`Self::frame_size`] is also set, the given + /// [`Codec`]'s max frame size will be overridden. + pub fn codec(mut self, codec: Codec) -> Self { + self.codec = Some(codec); + self + } + + fn handshake_resp(&self) -> Result { + match self.protocols { + Some(protocols) => handshake_with_protocols(self.req, protocols), + None => handshake(self.req), + } + } + + fn set_frame_size(&mut self) { + if let Some(frame_size) = self.frame_size { + match &mut self.codec { + Some(codec) => { + // modify existing codec's max frame size + let orig_codec = mem::take(codec); + *codec = orig_codec.max_size(frame_size); + } + + None => { + // create a new codec with the given size + self.codec = Some(Codec::new().max_size(frame_size)); + } + } + } + } + + /// Create a new Websocket context from an actor, request stream, and codec. + /// + /// Returns a pair, where the first item is an addr for the created actor, and the second item + /// is a stream intended to be set as part of the response + /// via [`HttpResponseBuilder::streaming()`]. + fn create_with_codec_addr( + actor: A, + stream: S, + codec: Codec, + ) -> (Addr
, impl Stream>) + where + A: StreamHandler>, + S: Stream> + 'static, + { + let mb = Mailbox::default(); + let mut ctx = WebsocketContext { + inner: ContextParts::new(mb.sender_producer()), + messages: VecDeque::new(), + }; + ctx.add_stream(WsStream::new(stream, codec.clone())); + + let addr = ctx.address(); + + (addr, WebsocketContextFut::new(ctx, actor, mb, codec)) + } + + /// Perform WebSocket handshake and start actor. + /// + /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change. + /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a + /// stream of the body request. + /// + /// If there is a problem with the handshake, an error is returned. + /// + /// If successful, consume the [`WsResponseBuilder`] and return a [`HttpResponse`] wrapped in + /// a [`Result`]. + pub fn start(mut self) -> Result { + let mut res = self.handshake_resp()?; + self.set_frame_size(); + + match self.codec { + Some(codec) => { + let out_stream = WebsocketContext::with_codec(self.actor, self.stream, codec); + Ok(res.streaming(out_stream)) + } + None => { + let out_stream = WebsocketContext::create(self.actor, self.stream); + Ok(res.streaming(out_stream)) + } + } + } + + /// Perform WebSocket handshake and start actor. + /// + /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change. + /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a + /// stream of the body request. + /// + /// If there is a problem with the handshake, an error is returned. + /// + /// If successful, returns a pair where the first item is an address for the created actor and + /// the second item is the [`HttpResponse`] that should be returned from the websocket request. + pub fn start_with_addr(mut self) -> Result<(Addr, HttpResponse), Error> { + let mut res = self.handshake_resp()?; + self.set_frame_size(); + + match self.codec { + Some(codec) => { + let (addr, out_stream) = + Self::create_with_codec_addr(self.actor, self.stream, codec); + Ok((addr, res.streaming(out_stream))) + } + None => { + let (addr, out_stream) = + WebsocketContext::create_with_addr(self.actor, self.stream); + Ok((addr, res.streaming(out_stream))) + } + } + } +} /// Perform WebSocket handshake and start actor. +/// +/// To customize options, see [`WsResponseBuilder`]. pub fn start(actor: A, req: &HttpRequest, stream: T) -> Result where A: Actor> + StreamHandler>, @@ -40,15 +228,15 @@ where /// Perform WebSocket handshake and start actor. /// -/// `req` is an HTTP Request that should be requesting a websocket protocol -/// change. `stream` should be a `Bytes` stream (such as -/// `actix_web::web::Payload`) that contains a stream of the body request. +/// `req` is an HTTP Request that should be requesting a websocket protocol change. `stream` should +/// be a `Bytes` stream (such as `actix_web::web::Payload`) that contains a stream of the +/// body request. /// /// If there is a problem with the handshake, an error is returned. /// -/// If successful, returns a pair where the first item is an address for the -/// created actor and the second item is the response that should be returned -/// from the WebSocket request. +/// If successful, returns a pair where the first item is an address for the created actor and the +/// second item is the response that should be returned from the WebSocket request. +#[deprecated(since = "4.0.0", note = "Prefer `WsResponseBuilder::start_with_addr`.")] pub fn start_with_addr( actor: A, req: &HttpRequest, @@ -66,6 +254,10 @@ where /// Do WebSocket handshake and start ws actor. /// /// `protocols` is a sequence of known protocols. +#[deprecated( + since = "4.0.0", + note = "Prefer `WsResponseBuilder` for setting protocols." +)] pub fn start_with_protocols( actor: A, protocols: &[&str], @@ -82,20 +274,19 @@ where /// Prepare WebSocket handshake response. /// -/// This function returns handshake `HttpResponse`, ready to send to peer. -/// It does not perform any IO. +/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform +/// any IO. pub fn handshake(req: &HttpRequest) -> Result { handshake_with_protocols(req, &[]) } /// Prepare WebSocket handshake response. /// -/// This function returns handshake `HttpResponse`, ready to send to peer. -/// It does not perform any IO. +/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform +/// any IO. /// -/// `protocols` is a sequence of known protocols. On successful handshake, -/// the returned response headers contain the first protocol in this list -/// which the server also knows. +/// `protocols` is a sequence of known protocols. On successful handshake, the returned response +/// headers contain the first protocol in this list which the server also knows. pub fn handshake_with_protocols( req: &HttpRequest, protocols: &[&str], @@ -162,7 +353,11 @@ pub fn handshake_with_protocols( let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) .upgrade("websocket") - .insert_header((header::SEC_WEBSOCKET_ACCEPT, key)) + .insert_header(( + header::SEC_WEBSOCKET_ACCEPT, + // key is known to be header value safe ascii + HeaderValue::from_bytes(&key).unwrap(), + )) .take(); if let Some(protocol) = protocol { @@ -204,14 +399,14 @@ where { fn spawn(&mut self, fut: F) -> SpawnHandle where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.spawn(fut) } fn wait(&mut self, fut: F) where - F: ActorFuture + 'static, + F: ActorFuture + 'static, { self.inner.wait(fut) } @@ -238,8 +433,8 @@ impl WebsocketContext where A: Actor, { + /// Create a new Websocket context from a request and an actor. #[inline] - /// Create a new Websocket context from a request and an actor pub fn create(actor: A, stream: S) -> impl Stream> where A: StreamHandler>, @@ -249,12 +444,11 @@ where stream } - #[inline] /// Create a new Websocket context from a request and an actor. /// - /// Returns a pair, where the first item is an addr for the created actor, - /// and the second item is a stream intended to be set as part of the - /// response via `HttpResponseBuilder::streaming()`. + /// Returns a pair, where the first item is an addr for the created actor, and the second item + /// is a stream intended to be set as part of the response + /// via [`HttpResponseBuilder::streaming()`]. pub fn create_with_addr( actor: A, stream: S, @@ -275,7 +469,6 @@ where (addr, WebsocketContextFut::new(ctx, actor, mb, Codec::new())) } - #[inline] /// Create a new Websocket context from a request, an actor, and a codec pub fn with_codec( actor: A, @@ -291,7 +484,7 @@ where inner: ContextParts::new(mb.sender_producer()), messages: VecDeque::new(), }; - ctx.add_stream(WsStream::new(stream, codec)); + ctx.add_stream(WsStream::new(stream, codec.clone())); WebsocketContextFut::new(ctx, actor, mb, codec) } @@ -449,18 +642,20 @@ where M: ActixMessage + Send + 'static, M::Result: Send, { - fn pack(msg: M, tx: Option>) -> Envelope { + fn pack(msg: M, tx: Option>) -> Envelope { Envelope::new(msg, tx) } } -#[pin_project::pin_project] -struct WsStream { - #[pin] - stream: S, - decoder: Codec, - buf: BytesMut, - closed: bool, +pin_project! { + #[derive(Debug)] + struct WsStream { + #[pin] + stream: S, + decoder: Codec, + buf: BytesMut, + closed: bool, + } } impl WsStream @@ -488,7 +683,6 @@ where if !*this.closed { loop { - this = self.as_mut().project(); match Pin::new(&mut this.stream).poll_next(cx) { Poll::Ready(Some(Ok(chunk))) => { this.buf.extend_from_slice(&chunk[..]); @@ -540,9 +734,12 @@ where #[cfg(test)] mod tests { + use actix_web::{ + http::{header, Method}, + test::TestRequest, + }; + use super::*; - use actix_web::http::{header, Method}; - use actix_web::test::TestRequest; #[test] fn test_handshake() { diff --git a/actix-web-actors/tests/test_ws.rs b/actix-web-actors/tests/test_ws.rs index 912480ae4..a9eb37699 100644 --- a/actix-web-actors/tests/test_ws.rs +++ b/actix-web-actors/tests/test_ws.rs @@ -1,6 +1,7 @@ use actix::prelude::*; -use actix_web::{test, web, App, HttpRequest}; -use actix_web_actors::*; +use actix_http::ws::Codec; +use actix_web::{web, App, HttpRequest}; +use actix_web_actors::ws; use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; @@ -12,37 +13,34 @@ impl Actor for Ws { impl StreamHandler> for Ws { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - match msg.unwrap() { - ws::Message::Ping(msg) => ctx.pong(&msg), - ws::Message::Text(text) => ctx.text(text), - ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Close(reason) => ctx.close(reason), - _ => {} + match msg { + Ok(ws::Message::Ping(msg)) => ctx.pong(&msg), + Ok(ws::Message::Text(text)) => ctx.text(text), + Ok(ws::Message::Binary(bin)) => ctx.binary(bin), + Ok(ws::Message::Close(reason)) => ctx.close(reason), + _ => ctx.close(Some(ws::CloseCode::Error.into())), } } } -#[actix_rt::test] -async fn test_simple() { - let mut srv = test::start(|| { - App::new().service(web::resource("/").to( - |req: HttpRequest, stream: web::Payload| async move { ws::start(Ws, &req, stream) }, - )) - }); +const MAX_FRAME_SIZE: usize = 10_000; +const DEFAULT_FRAME_SIZE: usize = 10; +async fn common_test_code(mut srv: actix_test::TestServer, frame_size: usize) { // client service let mut framed = srv.ws().await.unwrap(); - framed.send(ws::Message::Text("text".into())).await.unwrap(); + framed.send(ws::Message::Text("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); + let bytes = Bytes::from(vec![0; frame_size]); framed - .send(ws::Message::Binary("text".into())) + .send(ws::Message::Binary(bytes.clone())) .await .unwrap(); let item = framed.next().await.unwrap().unwrap(); - assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); + assert_eq!(item, ws::Frame::Binary(bytes)); framed.send(ws::Message::Ping("text".into())).await.unwrap(); let item = framed.next().await.unwrap().unwrap(); @@ -52,7 +50,137 @@ async fn test_simple() { .send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .await .unwrap(); - let item = framed.next().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); } + +#[actix_rt::test] +async fn simple_builder() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream).start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_frame_size() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + common_test_code(srv, MAX_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_frame_size_exceeded() { + const MAX_FRAME_SIZE: usize = 64; + + let mut srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + // client service + let mut framed = srv.ws().await.unwrap(); + + // create a request with a frame size larger than expected + let bytes = Bytes::from(vec![0; MAX_FRAME_SIZE + 1]); + framed.send(ws::Message::Binary(bytes)).await.unwrap(); + + let frame = framed.next().await.unwrap().unwrap(); + let close_reason = match frame { + ws::Frame::Close(Some(reason)) => reason, + _ => panic!("close frame expected"), + }; + assert_eq!(close_reason.code, ws::CloseCode::Error); +} + +#[actix_rt::test] +async fn builder_with_codec() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .codec(Codec::new()) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_protocols() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .protocols(&["A", "B"]) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_with_codec_and_frame_size() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .codec(Codec::new()) + .frame_size(MAX_FRAME_SIZE) + .start() + }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn builder_full() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { + ws::WsResponseBuilder::new(Ws, &req, stream) + .frame_size(MAX_FRAME_SIZE) + .codec(Codec::new()) + .protocols(&["A", "B"]) + .start() + }, + )) + }); + + common_test_code(srv, MAX_FRAME_SIZE).await; +} + +#[actix_rt::test] +async fn simple_start() { + let srv = actix_test::start(|| { + App::new().service(web::resource("/").to( + |req: HttpRequest, stream: web::Payload| async move { ws::start(Ws, &req, stream) }, + )) + }); + + common_test_code(srv, DEFAULT_FRAME_SIZE).await; +} diff --git a/actix-web-codegen/CHANGES.md b/actix-web-codegen/CHANGES.md index 2ce728aad..c044ff74d 100644 --- a/actix-web-codegen/CHANGES.md +++ b/actix-web-codegen/CHANGES.md @@ -3,70 +3,106 @@ ## Unreleased - 2021-xx-xx +## 0.5.0-rc.1 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 0.5.0-beta.6 - 2021-12-11 +- No significant changes since `0.5.0-beta.5`. + + +## 0.5.0-beta.5 - 2021-10-20 +- Improve error recovery potential when macro input is invalid. [#2410] +- Add `#[actix_web::test]` macro for setting up tests with a runtime. [#2409] +- Minimum supported Rust version (MSRV) is now 1.52. + +[#2410]: https://github.com/actix/actix-web/pull/2410 +[#2409]: https://github.com/actix/actix-web/pull/2409 + + +## 0.5.0-beta.4 - 2021-09-09 +- In routing macros, paths are now validated at compile time. [#2350] +- Minimum supported Rust version (MSRV) is now 1.51. + +[#2350]: https://github.com/actix/actix-web/pull/2350 + + +## 0.5.0-beta.3 - 2021-06-17 +- No notable changes. + + +## 0.5.0-beta.2 - 2021-03-09 +- Preserve doc comments when using route macros. [#2022] +- Add `name` attribute to `route` macro. [#1934] + +[#2022]: https://github.com/actix/actix-web/pull/2022 +[#1934]: https://github.com/actix/actix-web/pull/1934 + + ## 0.5.0-beta.1 - 2021-02-10 -* Use new call signature for `System::new`. +- Use new call signature for `System::new`. ## 0.4.0 - 2020-09-20 -* Added compile success and failure testing. [#1677] -* Add `route` macro for supporting multiple HTTP methods guards. [#1674] +- Added compile success and failure testing. [#1677] +- Add `route` macro for supporting multiple HTTP methods guards. [#1674] [#1677]: https://github.com/actix/actix-web/pull/1677 [#1674]: https://github.com/actix/actix-web/pull/1674 ## 0.3.0 - 2020-09-11 -* No significant changes from `0.3.0-beta.1`. +- No significant changes from `0.3.0-beta.1`. ## 0.3.0-beta.1 - 2020-07-14 -* Add main entry-point macro that uses re-exported runtime. [#1559] +- Add main entry-point macro that uses re-exported runtime. [#1559] [#1559]: https://github.com/actix/actix-web/pull/1559 ## 0.2.2 - 2020-05-23 -* Add resource middleware on actix-web-codegen [#1467] +- Add resource middleware on actix-web-codegen [#1467] [#1467]: https://github.com/actix/actix-web/pull/1467 ## 0.2.1 - 2020-02-25 -* Add `#[allow(missing_docs)]` attribute to generated structs [#1368] -* Allow the handler function to be named as `config` [#1290] +- Add `#[allow(missing_docs)]` attribute to generated structs [#1368] +- Allow the handler function to be named as `config` [#1290] [#1368]: https://github.com/actix/actix-web/issues/1368 [#1290]: https://github.com/actix/actix-web/issues/1290 ## 0.2.0 - 2019-12-13 -* Generate code for actix-web 2.0 +- Generate code for actix-web 2.0 ## 0.1.3 - 2019-10-14 -* Bump up `syn` & `quote` to 1.0 -* Provide better error message +- Bump up `syn` & `quote` to 1.0 +- Provide better error message ## 0.1.2 - 2019-06-04 -* Add macros for head, options, trace, connect and patch http methods +- Add macros for head, options, trace, connect and patch http methods ## 0.1.1 - 2019-06-01 -* Add syn "extra-traits" feature +- Add syn "extra-traits" feature ## 0.1.0 - 2019-05-18 -* Release +- Release ## 0.1.0-beta.1 - 2019-04-20 -* Gen code for actix-web 1.0.0-beta.1 +- Gen code for actix-web 1.0.0-beta.1 ## 0.1.0-alpha.6 - 2019-04-14 -* Gen code for actix-web 1.0.0-alpha.6 +- Gen code for actix-web 1.0.0-alpha.6 ## 0.1.0-alpha.1 - 2019-03-28 -* Initial impl +- Initial impl diff --git a/actix-web-codegen/Cargo.toml b/actix-web-codegen/Cargo.toml index 886d9ac3e..9b1887012 100644 --- a/actix-web-codegen/Cargo.toml +++ b/actix-web-codegen/Cargo.toml @@ -1,12 +1,13 @@ [package] name = "actix-web-codegen" -version = "0.5.0-beta.1" +version = "0.5.0-rc.1" description = "Routing and runtime macros for Actix Web" -readme = "README.md" homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web" -documentation = "https://docs.rs/actix-web-codegen" -authors = ["Nikolay Kim "] +repository = "https://github.com/actix/actix-web.git" +authors = [ + "Nikolay Kim ", + "Rob Ede ", +] license = "MIT OR Apache-2.0" edition = "2018" @@ -14,13 +15,18 @@ edition = "2018" proc-macro = true [dependencies] +actix-router = "0.5.0-beta.4" +proc-macro2 = "1" quote = "1" syn = { version = "1", features = ["full", "parsing"] } -proc-macro2 = "1" [dev-dependencies] -actix-rt = "2" -actix-web = "4.0.0-beta.3" -futures-util = { version = "0.3.7", default-features = false } +actix-macros = "0.2.3" +actix-rt = "2.2" +actix-test = "0.1.0-beta.11" +actix-utils = "3.0.0" +actix-web = "4.0.0-beta.19" + +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } trybuild = "1" rustversion = "1" diff --git a/actix-web-codegen/README.md b/actix-web-codegen/README.md index 5820bb443..1fd97184c 100644 --- a/actix-web-codegen/README.md +++ b/actix-web-codegen/README.md @@ -3,19 +3,18 @@ > Routing and runtime macros for Actix Web. [![crates.io](https://img.shields.io/crates/v/actix-web-codegen?label=latest)](https://crates.io/crates/actix-web-codegen) -[![Documentation](https://docs.rs/actix-web-codegen/badge.svg?version=0.5.0-beta.1)](https://docs.rs/actix-web-codegen/0.5.0-beta.1) -[![Version](https://img.shields.io/badge/rustc-1.46+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.46.html) +[![Documentation](https://docs.rs/actix-web-codegen/badge.svg?version=0.5.0-rc.1)](https://docs.rs/actix-web-codegen/0.5.0-rc.1) +[![Version](https://img.shields.io/badge/rustc-1.54+-ab6000.svg)](https://blog.rust-lang.org/2021/05/06/Rust-1.54.0.html) ![License](https://img.shields.io/crates/l/actix-web-codegen.svg)
-[![dependency status](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.1/status.svg)](https://deps.rs/crate/actix-web-codegen/0.5.0-beta.1) +[![dependency status](https://deps.rs/crate/actix-web-codegen/0.5.0-rc.1/status.svg)](https://deps.rs/crate/actix-web-codegen/0.5.0-rc.1) [![Download](https://img.shields.io/crates/d/actix-web-codegen.svg)](https://crates.io/crates/actix-web-codegen) -[![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/actix-web-codegen) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum supported Rust version: 1.46 or later. +- Minimum Supported Rust Version (MSRV): 1.54 ## Compile Testing diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index 670d82ce9..52cfc0d8f 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -57,20 +57,24 @@ //! [DELETE]: macro@delete #![recursion_limit = "512"] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] use proc_macro::TokenStream; +use quote::quote; mod route; /// Creates resource handler, allowing multiple HTTP method guards. /// /// # Syntax -/// ```text +/// ```plain /// #[route("path", method="HTTP_METHOD"[, attributes])] /// ``` /// /// # Attributes /// - `"path"` - Raw literal string with path for which to register handler. +/// - `name="resource_name"` - Specifies resource name for the handler. If not set, the function name of handler is used. /// - `method="HTTP_METHOD"` - Registers HTTP method to provide guard for. Upper-case string, "GET", "POST" for example. /// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` /// - `wrap="Middleware"` - Registers a resource middleware. @@ -81,7 +85,7 @@ mod route; /// /// # Example /// -/// ```rust +/// ``` /// # use actix_web::HttpResponse; /// # use actix_web_codegen::route; /// #[route("/test", method="GET", method="HEAD")] @@ -110,12 +114,13 @@ concat!(" Creates route handler with `actix_web::guard::", stringify!($variant), "`. # Syntax -```text +```plain #[", stringify!($method), r#"("path"[, attributes])] ``` # Attributes - `"path"` - Raw literal string with path for which to register handler. +- `name="resource_name"` - Specifies resource name for the handler. If not set, the function name of handler is used. - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard`. - `wrap="Middleware"` - Registers a resource middleware. @@ -125,7 +130,7 @@ code, e.g `my_guard` or `my_module::my_guard`. # Example -```rust +``` # use actix_web::HttpResponse; # use actix_web_codegen::"#, stringify!($method), "; #[", stringify!($method), r#"("/")] @@ -155,41 +160,41 @@ method_macro! { } /// Marks async main function as the actix system entry-point. -/// -/// # Actix Web Re-export -/// This macro can be applied with `#[actix_web::main]` when used in Actix Web applications. -/// + /// # Examples -/// ```rust -/// #[actix_web_codegen::main] +/// ``` +/// #[actix_web::main] /// async fn main() { /// async { println!("Hello world"); }.await /// } /// ``` #[proc_macro_attribute] pub fn main(_: TokenStream, item: TokenStream) -> TokenStream { - use quote::quote; - - let mut input = syn::parse_macro_input!(item as syn::ItemFn); - let attrs = &input.attrs; - let vis = &input.vis; - let sig = &mut input.sig; - let body = &input.block; - - if sig.asyncness.is_none() { - return syn::Error::new_spanned(sig.fn_token, "only async fn is supported") - .to_compile_error() - .into(); - } - - sig.asyncness = None; - - (quote! { - #(#attrs)* - #vis #sig { - actix_web::rt::System::new() - .block_on(async move { #body }) - } + let mut output: TokenStream = (quote! { + #[::actix_web::rt::main(system = "::actix_web::rt::System")] }) - .into() + .into(); + + output.extend(item); + output +} + +/// Marks async test functions to use the actix system entry-point. +/// +/// # Examples +/// ``` +/// #[actix_web::test] +/// async fn test() { +/// assert_eq!(async { "Hello world" }.await, "Hello world"); +/// } +/// ``` +#[proc_macro_attribute] +pub fn test(_: TokenStream, item: TokenStream) -> TokenStream { + let mut output: TokenStream = (quote! { + #[::actix_web::rt::test(system = "::actix_web::rt::System")] + }) + .into(); + + output.extend(item); + output } diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index ddbd42454..a4472efd2 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -1,12 +1,10 @@ -extern crate proc_macro; - -use std::collections::HashSet; -use std::convert::TryFrom; +use std::{collections::HashSet, convert::TryFrom}; +use actix_router::ResourceDef; use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{format_ident, quote, ToTokens, TokenStreamExt}; -use syn::{parse_macro_input, AttributeArgs, Ident, NestedMeta}; +use syn::{parse_macro_input, AttributeArgs, Ident, LitStr, NestedMeta}; enum ResourceType { Async, @@ -78,6 +76,7 @@ impl TryFrom<&syn::LitStr> for MethodType { struct Args { path: syn::LitStr, + resource_name: Option, guards: Vec, wrappers: Vec, methods: HashSet, @@ -86,6 +85,7 @@ struct Args { impl Args { fn new(args: AttributeArgs, method: Option) -> syn::Result { let mut path = None; + let mut resource_name = None; let mut guards = Vec::new(); let mut wrappers = Vec::new(); let mut methods = HashSet::new(); @@ -99,6 +99,7 @@ impl Args { match arg { NestedMeta::Lit(syn::Lit::Str(lit)) => match path { None => { + let _ = ResourceDef::new(lit.value()); path = Some(lit); } _ => { @@ -109,7 +110,16 @@ impl Args { } }, NestedMeta::Meta(syn::Meta::NameValue(nv)) => { - if nv.path.is_ident("guard") { + if nv.path.is_ident("name") { + if let syn::Lit::Str(lit) = nv.lit { + resource_name = Some(lit); + } else { + return Err(syn::Error::new_spanned( + nv.lit, + "Attribute name expects literal string!", + )); + } + } else if nv.path.is_ident("guard") { if let syn::Lit::Str(lit) = nv.lit { guards.push(Ident::new(&lit.value(), Span::call_site())); } else { @@ -164,6 +174,7 @@ impl Args { } Ok(Args { path: path.unwrap(), + resource_name, guards, wrappers, methods, @@ -176,6 +187,9 @@ pub struct Route { args: Args, ast: syn::ItemFn, resource_type: ResourceType, + + /// The doc comment attributes to copy to generated struct, if any. + doc_attributes: Vec, } fn guess_resource_type(typ: &syn::Type) -> ResourceType { @@ -203,7 +217,7 @@ fn guess_resource_type(typ: &syn::Type) -> ResourceType { impl Route { pub fn new( args: AttributeArgs, - input: TokenStream, + ast: syn::ItemFn, method: Option, ) -> syn::Result { if args.is_empty() { @@ -212,15 +226,23 @@ impl Route { format!( r#"invalid service definition, expected #[{}("")]"#, method - .map(|it| it.as_str()) - .unwrap_or("route") + .map_or("route", |it| it.as_str()) .to_ascii_lowercase() ), )); } - let ast: syn::ItemFn = syn::parse(input)?; + let name = ast.sig.ident.clone(); + // Try and pull out the doc comments so that we can reapply them to the generated struct. + // Note that multi line doc comments are converted to multiple doc attributes. + let doc_attributes = ast + .attrs + .iter() + .filter(|attr| attr.path.is_ident("doc")) + .cloned() + .collect(); + let args = Args::new(args, method)?; if args.methods.is_empty() { return Err(syn::Error::new( @@ -248,6 +270,7 @@ impl Route { args, ast, resource_type, + doc_attributes, }) } } @@ -260,13 +283,17 @@ impl ToTokens for Route { args: Args { path, + resource_name, guards, wrappers, methods, }, resource_type, + doc_attributes, } = self; - let resource_name = name.to_string(); + let resource_name = resource_name + .as_ref() + .map_or_else(|| name.to_string(), LitStr::value); let method_guards = { let mut others = methods.iter(); // unwrapping since length is checked to be at least one @@ -287,6 +314,7 @@ impl ToTokens for Route { }; let stream = quote! { + #(#doc_attributes)* #[allow(non_camel_case_types, missing_docs)] pub struct #name; @@ -315,8 +343,28 @@ pub(crate) fn with_method( input: TokenStream, ) -> TokenStream { let args = parse_macro_input!(args as syn::AttributeArgs); - match Route::new(args, input, method) { + + let ast = match syn::parse::(input.clone()) { + Ok(ast) => ast, + // on parse error, make IDEs happy; see fn docs + Err(err) => return input_and_compile_error(input, err), + }; + + match Route::new(args, ast, method) { Ok(route) => route.into_token_stream().into(), - Err(err) => err.to_compile_error().into(), + // on macro related error, make IDEs happy; see fn docs + Err(err) => input_and_compile_error(input, err), } } + +/// Converts the error to a token stream and appends it to the original input. +/// +/// Returning the original input in addition to the error is good for IDEs which can gracefully +/// recover and show more precise errors within the macro body. +/// +/// See for more info. +fn input_and_compile_error(mut item: TokenStream, err: syn::Error) -> TokenStream { + let compile_err = TokenStream::from(err.to_compile_error()); + item.extend(compile_err); + item +} diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index 34cdad649..769cf2bc3 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -1,11 +1,17 @@ use std::future::Future; -use std::task::{Context, Poll}; -use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; -use actix_web::http::header::{HeaderName, HeaderValue}; -use actix_web::{http, test, web::Path, App, Error, HttpResponse, Responder}; +use actix_utils::future::{ok, Ready}; +use actix_web::{ + dev::{Service, ServiceRequest, ServiceResponse, Transform}, + http::{ + self, + header::{HeaderName, HeaderValue}, + StatusCode, + }, + web, App, Error, HttpResponse, Responder, +}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, route, trace}; -use futures_util::future::{self, LocalBoxFuture}; +use futures_core::future::LocalBoxFuture; // Make sure that we can name function as 'config' #[get("/config")] @@ -55,26 +61,26 @@ async fn trace_test() -> impl Responder { #[get("/test")] fn auto_async() -> impl Future> { - future::ok(HttpResponse::Ok().finish()) + ok(HttpResponse::Ok().finish()) } #[get("/test")] fn auto_sync() -> impl Future> { - future::ok(HttpResponse::Ok().finish()) + ok(HttpResponse::Ok().finish()) } #[put("/test/{param}")] -async fn put_param_test(_: Path) -> impl Responder { +async fn put_param_test(_: web::Path) -> impl Responder { HttpResponse::Created() } #[delete("/test/{param}")] -async fn delete_param_test(_: Path) -> impl Responder { +async fn delete_param_test(_: web::Path) -> impl Responder { HttpResponse::NoContent() } #[get("/test/{param}")] -async fn get_param_test(_: Path) -> impl Responder { +async fn get_param_test(_: web::Path) -> impl Responder { HttpResponse::Ok() } @@ -83,6 +89,13 @@ async fn route_test() -> impl Responder { HttpResponse::Ok() } +#[get("/custom_resource_name", name = "custom")] +async fn custom_resource_name_test<'a>(req: actix_web::HttpRequest) -> impl Responder { + assert!(req.url_for_static("custom").is_ok()); + assert!(req.url_for_static("custom_resource_name_test").is_err()); + HttpResponse::Ok() +} + pub struct ChangeStatusCode; impl Transform for ChangeStatusCode @@ -95,10 +108,10 @@ where type Error = Error; type Transform = ChangeStatusCodeMiddleware; type InitError = (); - type Future = future::Ready>; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - future::ok(ChangeStatusCodeMiddleware { service }) + ok(ChangeStatusCodeMiddleware { service }) } } @@ -116,9 +129,7 @@ where type Error = Error; type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_web::dev::forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let fut = self.service.call(req); @@ -135,13 +146,14 @@ where } #[get("/test/wrap", wrap = "ChangeStatusCode")] -async fn get_wrap(_: Path) -> impl Responder { +async fn get_wrap(_: web::Path) -> impl Responder { + // panic!("actually never gets called because path failed to extract"); HttpResponse::Ok() } #[actix_rt::test] async fn test_params() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new() .service(get_param_test) .service(put_param_test) @@ -163,7 +175,7 @@ async fn test_params() { #[actix_rt::test] async fn test_body() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new() .service(post_test) .service(put_test) @@ -174,6 +186,7 @@ async fn test_body() { .service(patch_test) .service(test_handler) .service(route_test) + .service(custom_resource_name_test) }); let request = srv.request(http::Method::GET, srv.url("/test")); let response = request.send().await.unwrap(); @@ -228,22 +241,30 @@ async fn test_body() { let request = srv.request(http::Method::PATCH, srv.url("/multi")); let response = request.send().await.unwrap(); assert!(!response.status().is_success()); + + let request = srv.request(http::Method::GET, srv.url("/custom_resource_name")); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); } #[actix_rt::test] async fn test_auto_async() { - let srv = test::start(|| App::new().service(auto_async)); + let srv = actix_test::start(|| App::new().service(auto_async)); let request = srv.request(http::Method::GET, srv.url("/test")); let response = request.send().await.unwrap(); assert!(response.status().is_success()); } -#[actix_rt::test] +#[actix_web::test] async fn test_wrap() { - let srv = test::start(|| App::new().service(get_wrap)); + let srv = actix_test::start(|| App::new().service(get_wrap)); let request = srv.request(http::Method::GET, srv.url("/test/wrap")); - let response = request.send().await.unwrap(); + let mut response = request.send().await.unwrap(); + assert_eq!(response.status(), StatusCode::NOT_FOUND); assert!(response.headers().contains_key("custom-header")); + let body = response.body().await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert!(body.contains("wrong number of parameters")); } diff --git a/actix-web-codegen/tests/trybuild.rs b/actix-web-codegen/tests/trybuild.rs index d2d8a38f5..b2d9ce186 100644 --- a/actix-web-codegen/tests/trybuild.rs +++ b/actix-web-codegen/tests/trybuild.rs @@ -1,3 +1,4 @@ +#[rustversion::stable(1.54)] // MSRV #[test] fn compile_macros() { let t = trybuild::TestCases::new(); @@ -9,12 +10,9 @@ fn compile_macros() { t.compile_fail("tests/trybuild/route-missing-method-fail.rs"); t.compile_fail("tests/trybuild/route-duplicate-method-fail.rs"); t.compile_fail("tests/trybuild/route-unexpected-method-fail.rs"); + t.compile_fail("tests/trybuild/route-malformed-path-fail.rs"); + + t.pass("tests/trybuild/docstring-ok.rs"); + + t.pass("tests/trybuild/test-runtime.rs"); } - -// #[rustversion::not(nightly)] -// fn skip_on_nightly(t: &trybuild::TestCases) { -// -// } - -// #[rustversion::nightly] -// fn skip_on_nightly(_t: &trybuild::TestCases) {} diff --git a/actix-web-codegen/tests/trybuild/docstring-ok.rs b/actix-web-codegen/tests/trybuild/docstring-ok.rs new file mode 100644 index 000000000..4cf310be5 --- /dev/null +++ b/actix-web-codegen/tests/trybuild/docstring-ok.rs @@ -0,0 +1,17 @@ +use actix_web::{Responder, HttpResponse, App}; +use actix_web_codegen::*; + +/// doc comments shouldn't break anything +#[get("/")] +async fn index() -> impl Responder { + HttpResponse::Ok() +} + +#[actix_web::main] +async fn main() { + let srv = actix_test::start(|| App::new().service(index)); + + let request = srv.get("/"); + let response = request.send().await.unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.rs b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.rs index 9a38050f7..9322b4895 100644 --- a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.rs +++ b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.rs @@ -7,9 +7,9 @@ async fn index() -> String { #[actix_web::main] async fn main() { - use actix_web::{App, test}; + use actix_web::App; - let srv = test::start(|| App::new().service(index)); + let srv = actix_test::start(|| App::new().service(index)); let request = srv.get("/"); let response = request.send().await.unwrap(); diff --git a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr index f3eda68af..90cff1b1c 100644 --- a/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-duplicate-method-fail.stderr @@ -4,8 +4,8 @@ error: HTTP method defined more than once: `GET` 3 | #[route("/", method="GET", method="GET")] | ^^^^^ -error[E0425]: cannot find value `index` in this scope - --> $DIR/route-duplicate-method-fail.rs:12:49 +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied + --> $DIR/route-duplicate-method-fail.rs:12:55 | -12 | let srv = test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope +12 | let srv = actix_test::start(|| App::new().service(index)); + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs new file mode 100644 index 000000000..1258a6f2f --- /dev/null +++ b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.rs @@ -0,0 +1,33 @@ +use actix_web_codegen::get; + +#[get("/{")] +async fn zero() -> &'static str { + "malformed resource def" +} + +#[get("/{foo")] +async fn one() -> &'static str { + "malformed resource def" +} + +#[get("/{}")] +async fn two() -> &'static str { + "malformed resource def" +} + +#[get("/*")] +async fn three() -> &'static str { + "malformed resource def" +} + +#[get("/{tail:\\d+}*")] +async fn four() -> &'static str { + "malformed resource def" +} + +#[get("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}")] +async fn five() -> &'static str { + "malformed resource def" +} + +fn main() {} diff --git a/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr new file mode 100644 index 000000000..93c510109 --- /dev/null +++ b/actix-web-codegen/tests/trybuild/route-malformed-path-fail.stderr @@ -0,0 +1,42 @@ +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:3:1 + | +3 | #[get("/{")] + | ^^^^^^^^^^^^ + | + = help: message: pattern "{" contains malformed dynamic segment + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:8:1 + | +8 | #[get("/{foo")] + | ^^^^^^^^^^^^^^^ + | + = help: message: pattern "{foo" contains malformed dynamic segment + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:13:1 + | +13 | #[get("/{}")] + | ^^^^^^^^^^^^^ + | + = help: message: Wrong path pattern: "/{}" regex parse error: + ((?s-m)^/(?P<>[^/]+))$ + ^ + error: empty capture group name + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:23:1 + | +23 | #[get("/{tail:\\d+}*")] + | ^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: custom regex is not supported for tail match + +error: custom attribute panicked + --> $DIR/route-malformed-path-fail.rs:28:1 + | +28 | #[get("/{a}/{b}/{c}/{d}/{e}/{f}/{g}/{h}/{i}/{j}/{k}/{l}/{m}/{n}/{o}/{p}/{q}")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Only 16 dynamic segments are allowed, provided: 17 diff --git a/actix-web-codegen/tests/trybuild/route-missing-method-fail.rs b/actix-web-codegen/tests/trybuild/route-missing-method-fail.rs index ce87a55a4..cd43c0669 100644 --- a/actix-web-codegen/tests/trybuild/route-missing-method-fail.rs +++ b/actix-web-codegen/tests/trybuild/route-missing-method-fail.rs @@ -7,9 +7,9 @@ async fn index() -> String { #[actix_web::main] async fn main() { - use actix_web::{App, test}; + use actix_web::App; - let srv = test::start(|| App::new().service(index)); + let srv = actix_test::start(|| App::new().service(index)); let request = srv.get("/"); let response = request.send().await.unwrap(); diff --git a/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr index 0518a61ed..b1cefafde 100644 --- a/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-missing-method-fail.stderr @@ -1,13 +1,13 @@ error: The #[route(..)] macro requires at least one `method` attribute - --> $DIR/route-missing-method-fail.rs:3:1 + --> tests/trybuild/route-missing-method-fail.rs:3:1 | 3 | #[route("/")] | ^^^^^^^^^^^^^ | - = note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the attribute macro `route` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0425]: cannot find value `index` in this scope - --> $DIR/route-missing-method-fail.rs:12:49 +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied + --> tests/trybuild/route-missing-method-fail.rs:12:55 | -12 | let srv = test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope +12 | let srv = actix_test::start(|| App::new().service(index)); + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/route-ok.rs b/actix-web-codegen/tests/trybuild/route-ok.rs index c4f679604..e1082e88e 100644 --- a/actix-web-codegen/tests/trybuild/route-ok.rs +++ b/actix-web-codegen/tests/trybuild/route-ok.rs @@ -7,9 +7,9 @@ async fn index() -> String { #[actix_web::main] async fn main() { - use actix_web::{App, test}; + use actix_web::App; - let srv = test::start(|| App::new().service(index)); + let srv = actix_test::start(|| App::new().service(index)); let request = srv.get("/"); let response = request.send().await.unwrap(); diff --git a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.rs b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.rs index 28cd1344c..1a50e01bc 100644 --- a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.rs +++ b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.rs @@ -7,9 +7,9 @@ async fn index() -> String { #[actix_web::main] async fn main() { - use actix_web::{App, test}; + use actix_web::App; - let srv = test::start(|| App::new().service(index)); + let srv = actix_test::start(|| App::new().service(index)); let request = srv.get("/"); let response = request.send().await.unwrap(); diff --git a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr index 9d87f310b..dda366067 100644 --- a/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr +++ b/actix-web-codegen/tests/trybuild/route-unexpected-method-fail.stderr @@ -4,8 +4,8 @@ error: Unexpected HTTP method: `UNEXPECTED` 3 | #[route("/", method="UNEXPECTED")] | ^^^^^^^^^^^^ -error[E0425]: cannot find value `index` in this scope - --> $DIR/route-unexpected-method-fail.rs:12:49 +error[E0277]: the trait bound `fn() -> impl std::future::Future {index}: HttpServiceFactory` is not satisfied + --> $DIR/route-unexpected-method-fail.rs:12:55 | -12 | let srv = test::start(|| App::new().service(index)); - | ^^^^^ not found in this scope +12 | let srv = actix_test::start(|| App::new().service(index)); + | ^^^^^ the trait `HttpServiceFactory` is not implemented for `fn() -> impl std::future::Future {index}` diff --git a/actix-web-codegen/tests/trybuild/simple.rs b/actix-web-codegen/tests/trybuild/simple.rs index 761b04905..8170edbfa 100644 --- a/actix-web-codegen/tests/trybuild/simple.rs +++ b/actix-web-codegen/tests/trybuild/simple.rs @@ -1,4 +1,4 @@ -use actix_web::{Responder, HttpResponse, App, test}; +use actix_web::{Responder, HttpResponse, App}; use actix_web_codegen::*; #[get("/config")] @@ -8,7 +8,7 @@ async fn config() -> impl Responder { #[actix_web::main] async fn main() { - let srv = test::start(|| App::new().service(config)); + let srv = actix_test::start(|| App::new().service(config)); let request = srv.get("/config"); let response = request.send().await.unwrap(); diff --git a/actix-web-codegen/tests/trybuild/test-runtime.rs b/actix-web-codegen/tests/trybuild/test-runtime.rs new file mode 100644 index 000000000..0b901b258 --- /dev/null +++ b/actix-web-codegen/tests/trybuild/test-runtime.rs @@ -0,0 +1,6 @@ +#[actix_web::test] +async fn my_test() { + assert!(async { 1 }.await, 1); +} + +fn main() {} diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 9224f414d..f2c81ef25 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,26 +1,149 @@ # Changes ## Unreleased - 2021-xx-xx -### Changed -* Feature `cookies` is now optional and enabled by default. [#1981] + +## 3.0.0-beta.18 - 2022-01-04 +- Minimum supported Rust version (MSRV) is now 1.54. + + +## 3.0.0-beta.17 - 2021-12-29 +### Changed +- Update `cookie` dependency (re-exported) to `0.16`. [#2555] + +### Security +- `cookie` upgrade addresses [`RUSTSEC-2020-0071`]. + +[#2555]: https://github.com/actix/actix-web/pull/2555 +[`RUSTSEC-2020-0071`]: https://rustsec.org/advisories/RUSTSEC-2020-0071.html + + +## 3.0.0-beta.16 - 2021-12-29 +- `*::send_json` and `*::send_form` methods now receive `impl Serialize`. [#2553] +- `FrozenClientRequest::extra_header` now uses receives an `impl TryIntoHeaderPair`. [#2553] +- Remove unnecessary `Unpin` bounds on `*::send_stream`. [#2553] + +[#2553]: https://github.com/actix/actix-web/pull/2553 + + +## 3.0.0-beta.15 - 2021-12-27 +- Rename `Connector::{ssl => openssl}`. [#2503] +- Improve `Client` instantiation efficiency when using `openssl` by only building connectors once. [#2503] +- `ClientRequest::send_body` now takes an `impl MessageBody`. [#2546] +- Rename `MessageBody => ResponseBody` to avoid conflicts with `MessageBody` trait. [#2546] +- `impl Future` for `ResponseBody` no longer requires the body type be `Unpin`. [#2546] +- `impl Future` for `JsonBody` no longer requires the body type be `Unpin`. [#2546] +- `impl Stream` for `ClientResponse` no longer requires the body type be `Unpin`. [#2546] + +[#2503]: https://github.com/actix/actix-web/pull/2503 +[#2546]: https://github.com/actix/actix-web/pull/2546 + + +## 3.0.0-beta.14 - 2021-12-17 +- Add `ClientBuilder::add_default_header` and deprecate `ClientBuilder::header`. [#2510] + +[#2510]: https://github.com/actix/actix-web/pull/2510 + + +## 3.0.0-beta.13 - 2021-12-11 +- No significant changes since `3.0.0-beta.12`. + + +## 3.0.0-beta.12 - 2021-11-30 +- Update `actix-tls` to `3.0.0-rc.1`. [#2474] + +[#2474]: https://github.com/actix/actix-web/pull/2474 + + +## 3.0.0-beta.11 - 2021-11-22 +- No significant changes from `3.0.0-beta.10`. + + +## 3.0.0-beta.10 - 2021-11-15 +- No significant changes from `3.0.0-beta.9`. + + +## 3.0.0-beta.9 - 2021-10-20 +- Updated rustls to v0.20. [#2414] + +[#2414]: https://github.com/actix/actix-web/pull/2414 + + +## 3.0.0-beta.8 - 2021-09-09 +### Changed +- Send headers within the redirect requests. [#2310] + +[#2310]: https://github.com/actix/actix-web/pull/2310 + + +## 3.0.0-beta.7 - 2021-06-26 +### Changed +- Change compression algorithm features flags. [#2250] + +[#2250]: https://github.com/actix/actix-web/pull/2250 + + +## 3.0.0-beta.6 - 2021-06-17 +- No significant changes since 3.0.0-beta.5. + + +## 3.0.0-beta.5 - 2021-04-17 +### Removed +- Deprecated methods on `ClientRequest`: `if_true`, `if_some`. [#2148] + +[#2148]: https://github.com/actix/actix-web/pull/2148 + + +## 3.0.0-beta.4 - 2021-04-02 +### Added +- Add `Client::headers` to get default mut reference of `HeaderMap` of client object. [#2114] + +### Changed +- `ConnectorService` type is renamed to `BoxConnectorService`. [#2081] +- Fix http/https encoding when enabling `compress` feature. [#2116] +- Rename `TestResponse::header` to `append_header`, `set` to `insert_header`. `TestResponse` header + methods now take `TryIntoHeaderPair` tuples. [#2094] + +[#2081]: https://github.com/actix/actix-web/pull/2081 +[#2094]: https://github.com/actix/actix-web/pull/2094 +[#2114]: https://github.com/actix/actix-web/pull/2114 +[#2116]: https://github.com/actix/actix-web/pull/2116 + + +## 3.0.0-beta.3 - 2021-03-08 +### Added +- `ClientResponse::timeout` for set the timeout of collecting response body. [#1931] +- `ClientBuilder::local_address` for bind to a local ip address for this client. [#2024] + +### Changed +- Feature `cookies` is now optional and enabled by default. [#1981] +- `ClientBuilder::connector` method would take `actix_http::client::Connector` type. [#2008] +- Basic auth password now takes blank passwords as an empty string instead of Option. [#2050] + +### Removed +- `ClientBuilder::default` function [#2008] + +[#1931]: https://github.com/actix/actix-web/pull/1931 [#1981]: https://github.com/actix/actix-web/pull/1981 +[#2008]: https://github.com/actix/actix-web/pull/2008 +[#2024]: https://github.com/actix/actix-web/pull/2024 +[#2050]: https://github.com/actix/actix-web/pull/2050 ## 3.0.0-beta.2 - 2021-02-10 ### Added -* `ClientRequest::insert_header` method which allows using typed headers. [#1869] -* `ClientRequest::append_header` method which allows using typed headers. [#1869] -* `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] +- `ClientRequest::insert_header` method which allows using typed headers. [#1869] +- `ClientRequest::append_header` method which allows using typed headers. [#1869] +- `trust-dns` optional feature to enable `trust-dns-resolver` as client dns resolver. [#1969] ### Changed -* Relax default timeout for `Connector` to 5 seconds(original 1 second). [#1905] +- Relax default timeout for `Connector` to 5 seconds(original 1 second). [#1905] ### Removed -* `ClientRequest::set`; use `ClientRequest::insert_header`. [#1869] -* `ClientRequest::set_header`; use `ClientRequest::insert_header`. [#1869] -* `ClientRequest::set_header_if_none`; use `ClientRequest::insert_header_if_none`. [#1869] -* `ClientRequest::header`; use `ClientRequest::append_header`. [#1869] +- `ClientRequest::set`; use `ClientRequest::insert_header`. [#1869] +- `ClientRequest::set_header`; use `ClientRequest::insert_header`. [#1869] +- `ClientRequest::set_header_if_none`; use `ClientRequest::insert_header_if_none`. [#1869] +- `ClientRequest::header`; use `ClientRequest::append_header`. [#1869] [#1869]: https://github.com/actix/actix-web/pull/1869 [#1905]: https://github.com/actix/actix-web/pull/1905 @@ -29,32 +152,32 @@ ## 3.0.0-beta.1 - 2021-01-07 ### Changed -* Update `rand` to `0.8` -* Update `bytes` to `1.0`. [#1813] -* Update `rust-tls` to `0.19`. [#1813] +- Update `rand` to `0.8` +- Update `bytes` to `1.0`. [#1813] +- Update `rust-tls` to `0.19`. [#1813] [#1813]: https://github.com/actix/actix-web/pull/1813 ## 2.0.3 - 2020-11-29 ### Fixed -* Ensure `actix-http` dependency uses same `serde_urlencoded`. +- Ensure `actix-http` dependency uses same `serde_urlencoded`. ## 2.0.2 - 2020-11-25 ### Changed -* Upgrade `serde_urlencoded` to `0.7`. [#1773] +- Upgrade `serde_urlencoded` to `0.7`. [#1773] [#1773]: https://github.com/actix/actix-web/pull/1773 ## 2.0.1 - 2020-10-30 ### Changed -* Upgrade `base64` to `0.13`. [#1744] -* Deprecate `ClientRequest::{if_some, if_true}`. [#1760] +- Upgrade `base64` to `0.13`. [#1744] +- Deprecate `ClientRequest::{if_some, if_true}`. [#1760] ### Fixed -* Use `Accept-Encoding: identity` instead of `Accept-Encoding: br` when no compression feature +- Use `Accept-Encoding: identity` instead of `Accept-Encoding: br` when no compression feature is enabled [#1737] [#1737]: https://github.com/actix/actix-web/pull/1737 @@ -64,209 +187,209 @@ ## 2.0.0 - 2020-09-11 ### Changed -* `Client::build` was renamed to `Client::builder`. +- `Client::build` was renamed to `Client::builder`. ## 2.0.0-beta.4 - 2020-09-09 ### Changed -* Update actix-codec & actix-tls dependencies. +- Update actix-codec & actix-tls dependencies. ## 2.0.0-beta.3 - 2020-08-17 ### Changed -* Update `rustls` to 0.18 +- Update `rustls` to 0.18 ## 2.0.0-beta.2 - 2020-07-21 ### Changed -* Update `actix-http` dependency to 2.0.0-beta.2 +- Update `actix-http` dependency to 2.0.0-beta.2 ## [2.0.0-beta.1] - 2020-07-14 ### Changed -* Update `actix-http` dependency to 2.0.0-beta.1 +- Update `actix-http` dependency to 2.0.0-beta.1 ## [2.0.0-alpha.2] - 2020-05-21 ### Changed -* Implement `std::error::Error` for our custom errors [#1422] -* Bump minimum supported Rust version to 1.40 -* Update `base64` dependency to 0.12 +- Implement `std::error::Error` for our custom errors [#1422] +- Bump minimum supported Rust version to 1.40 +- Update `base64` dependency to 0.12 [#1422]: https://github.com/actix/actix-web/pull/1422 ## [2.0.0-alpha.1] - 2020-03-11 -* Update `actix-http` dependency to 2.0.0-alpha.2 -* Update `rustls` dependency to 0.17 -* ClientBuilder accepts initial_window_size and initial_connection_window_size HTTP2 configuration -* ClientBuilder allowing to set max_http_version to limit HTTP version to be used +- Update `actix-http` dependency to 2.0.0-alpha.2 +- Update `rustls` dependency to 0.17 +- ClientBuilder accepts initial_window_size and initial_connection_window_size HTTP2 configuration +- ClientBuilder allowing to set max_http_version to limit HTTP version to be used ## [1.0.1] - 2019-12-15 -* Fix compilation with default features off +- Fix compilation with default features off ## [1.0.0] - 2019-12-13 -* Release +- Release ## [1.0.0-alpha.3] -* Migrate to `std::future` +- Migrate to `std::future` ## [0.2.8] - 2019-11-06 -* Add support for setting query from Serialize type for client request. +- Add support for setting query from Serialize type for client request. ## [0.2.7] - 2019-09-25 ### Added -* Remaining getter methods for `ClientRequest`'s private `head` field #1101 +- Remaining getter methods for `ClientRequest`'s private `head` field #1101 ## [0.2.6] - 2019-09-12 ### Added -* Export frozen request related types. +- Export frozen request related types. ## [0.2.5] - 2019-09-11 ### Added -* Add `FrozenClientRequest` to support retries for sending HTTP requests +- Add `FrozenClientRequest` to support retries for sending HTTP requests ### Changed -* Ensure that the `Host` header is set when initiating a WebSocket client connection. +- Ensure that the `Host` header is set when initiating a WebSocket client connection. ## [0.2.4] - 2019-08-13 ### Changed -* Update percent-encoding to "2.1" +- Update percent-encoding to "2.1" -* Update serde_urlencoded to "0.6.1" +- Update serde_urlencoded to "0.6.1" ## [0.2.3] - 2019-08-01 ### Added -* Add `rustls` support +- Add `rustls` support ## [0.2.2] - 2019-07-01 ### Changed -* Always append a colon after username in basic auth +- Always append a colon after username in basic auth -* Upgrade `rand` dependency version to 0.7 +- Upgrade `rand` dependency version to 0.7 ## [0.2.1] - 2019-06-05 ### Added -* Add license files +- Add license files ## [0.2.0] - 2019-05-12 ### Added -* Allow to send headers in `Camel-Case` form. +- Allow to send headers in `Camel-Case` form. ### Changed -* Upgrade actix-http dependency. +- Upgrade actix-http dependency. ## [0.1.1] - 2019-04-19 ### Added -* Allow to specify server address for http and ws requests. +- Allow to specify server address for http and ws requests. ### Changed -* `ClientRequest::if_true()` and `ClientRequest::if_some()` use instance instead of ref +- `ClientRequest::if_true()` and `ClientRequest::if_some()` use instance instead of ref ## [0.1.0] - 2019-04-16 -* No changes +- No changes ## [0.1.0-alpha.6] - 2019-04-14 ### Changed -* Do not set default headers for websocket request +- Do not set default headers for websocket request ## [0.1.0-alpha.5] - 2019-04-12 ### Changed -* Do not set any default headers +- Do not set any default headers ### Added -* Add Debug impl for BoxedSocket +- Add Debug impl for BoxedSocket ## [0.1.0-alpha.4] - 2019-04-08 ### Changed -* Update actix-http dependency +- Update actix-http dependency ## [0.1.0-alpha.3] - 2019-04-02 ### Added -* Export `MessageBody` type +- Export `MessageBody` type -* `ClientResponse::json()` - Loads and parse `application/json` encoded body +- `ClientResponse::json()` - Loads and parse `application/json` encoded body ### Changed -* `ClientRequest::json()` accepts reference instead of object. +- `ClientRequest::json()` accepts reference instead of object. -* `ClientResponse::body()` does not consume response object. +- `ClientResponse::body()` does not consume response object. -* Renamed `ClientRequest::close_connection()` to `ClientRequest::force_close()` +- Renamed `ClientRequest::close_connection()` to `ClientRequest::force_close()` ## [0.1.0-alpha.2] - 2019-03-29 ### Added -* Per request and session wide request timeout. +- Per request and session wide request timeout. -* Session wide headers. +- Session wide headers. -* Session wide basic and bearer auth. +- Session wide basic and bearer auth. -* Re-export `actix_http::client::Connector`. +- Re-export `actix_http::client::Connector`. ### Changed -* Allow to override request's uri +- Allow to override request's uri -* Export `ws` sub-module with websockets related types +- Export `ws` sub-module with websockets related types ## [0.1.0-alpha.1] - 2019-03-28 -* Initial impl +- Initial impl diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 9beecc6d4..036c74da9 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -1,19 +1,20 @@ [package] name = "awc" -version = "3.0.0-beta.2" -authors = ["Nikolay Kim "] +version = "3.0.0-beta.18" +authors = [ + "Nikolay Kim ", + "fakeshadow <24548779@qq.com>", +] description = "Async HTTP and WebSocket client library built on the Actix ecosystem" -readme = "README.md" keywords = ["actix", "http", "framework", "async", "web"] -homepage = "https://actix.rs" -repository = "https://github.com/actix/actix-web.git" -documentation = "https://docs.rs/awc/" categories = [ "network-programming", "asynchronous", "web-programming::http-client", "web-programming::websocket", ] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" license = "MIT OR Apache-2.0" edition = "2018" @@ -23,64 +24,93 @@ path = "src/lib.rs" [package.metadata.docs.rs] # features that docs.rs will build with -features = ["openssl", "rustls", "compress", "cookies"] +features = ["openssl", "rustls", "compress-brotli", "compress-gzip", "compress-zstd", "cookies"] [features] -default = ["compress", "cookies"] +default = ["compress-brotli", "compress-gzip", "compress-zstd", "cookies"] # openssl -openssl = ["tls-openssl", "actix-http/openssl"] +openssl = ["tls-openssl", "actix-tls/openssl"] # rustls -rustls = ["tls-rustls", "actix-http/rustls"] +rustls = ["tls-rustls", "actix-tls/rustls"] -# content-encoding support -compress = ["actix-http/compress"] +# Brotli algorithm content-encoding support +compress-brotli = ["actix-http/compress-brotli", "__compress"] +# Gzip and deflate algorithms content-encoding support +compress-gzip = ["actix-http/compress-gzip", "__compress"] +# Zstd algorithm content-encoding support +compress-zstd = ["actix-http/compress-zstd", "__compress"] # cookie parsing and cookie jar -cookies = ["actix-http/cookies"] +cookies = ["cookie"] # trust-dns as dns resolver -trust-dns = ["actix-http/trust-dns"] +trust-dns = ["trust-dns-resolver"] + +# Internal (PRIVATE!) features used to aid testing and cheking feature status. +# Don't rely on these whatsoever. They may disappear at anytime. +__compress = [] + +# Enable dangerous feature for testing and local network usage: +# - HTTP/2 over TCP(No Tls). +# DO NOT enable this over any internet use case. +dangerous-h2c = [] [dependencies] -actix-codec = "0.4.0-beta.1" -actix-service = "2.0.0-beta.4" -actix-http = "3.0.0-beta.3" -actix-rt = "2" +actix-codec = "0.4.1" +actix-service = "2.0.0" +actix-http = "3.0.0-beta.18" +actix-rt = { version = "2.1", default-features = false } +actix-tls = { version = "3.0.0", features = ["connect", "uri"] } +actix-utils = "3.0.0" +ahash = "0.7" base64 = "0.13" bytes = "1" -cfg-if = "1.0" +cfg-if = "1" derive_more = "0.99.5" -futures-core = { version = "0.3.7", default-features = false } +futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } +futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] } +h2 = "0.3.9" +http = "0.2.5" +itoa = "1" log =" 0.4" mime = "0.3" percent-encoding = "2.1" +pin-project-lite = "0.2" rand = "0.8" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" -tls-openssl = { version = "0.10.9", package = "openssl", optional = true } -tls-rustls = { version = "0.19.0", package = "rustls", optional = true, features = ["dangerous_configuration"] } +tokio = { version = "1.8.4", features = ["sync"] } -[target.'cfg(windows)'.dependencies.tls-openssl] -version = "0.10.9" -package = "openssl" -features = ["vendored"] -optional = true +cookie = { version = "0.16", features = ["percent-encode"], optional = true } + +tls-openssl = { package = "openssl", version = "0.10.9", optional = true } +tls-rustls = { package = "rustls", version = "0.20.0", optional = true, features = ["dangerous_configuration"] } + +trust-dns-resolver = { version = "0.20.0", optional = true } [dev-dependencies] -actix-web = { version = "4.0.0-beta.3", features = ["openssl"] } -actix-http = { version = "3.0.0-beta.3", features = ["openssl"] } -actix-http-test = { version = "3.0.0-beta.2", features = ["openssl"] } -actix-utils = "3.0.0-beta.1" -actix-server = "2.0.0-beta.3" -actix-tls = { version = "3.0.0-beta.3", features = ["openssl", "rustls"] } +actix-http = { version = "3.0.0-beta.18", features = ["openssl"] } +actix-http-test = { version = "3.0.0-beta.11", features = ["openssl"] } +actix-server = "2.0.0-rc.2" +actix-test = { version = "0.1.0-beta.11", features = ["openssl", "rustls"] } +actix-tls = { version = "3.0.0", features = ["openssl", "rustls"] } +actix-utils = "3.0.0" +actix-web = { version = "4.0.0-beta.19", features = ["openssl"] } brotli2 = "0.3.2" +const-str = "0.3" +env_logger = "0.9" flate2 = "1.0.13" futures-util = { version = "0.3.7", default-features = false } -env_logger = "0.8" +static_assertions = "1.1" rcgen = "0.8" -webpki = "0.21" +rustls-pemfile = "0.2" +zstd = "0.9" + +[[example]] +name = "client" +required-features = ["rustls"] diff --git a/awc/README.md b/awc/README.md index 043ae6a41..6a68ac05a 100644 --- a/awc/README.md +++ b/awc/README.md @@ -3,19 +3,19 @@ > Async HTTP and WebSocket client library. [![crates.io](https://img.shields.io/crates/v/awc?label=latest)](https://crates.io/crates/awc) -[![Documentation](https://docs.rs/awc/badge.svg?version=3.0.0-beta.2)](https://docs.rs/awc/3.0.0-beta.2) +[![Documentation](https://docs.rs/awc/badge.svg?version=3.0.0-beta.18)](https://docs.rs/awc/3.0.0-beta.18) ![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/awc) -[![Dependency Status](https://deps.rs/crate/awc/3.0.0-beta.2/status.svg)](https://deps.rs/crate/awc/3.0.0-beta.2) -[![Join the chat at https://gitter.im/actix/actix-web](https://badges.gitter.im/actix/actix-web.svg)](https://gitter.im/actix/actix-web?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Dependency Status](https://deps.rs/crate/awc/3.0.0-beta.18/status.svg)](https://deps.rs/crate/awc/3.0.0-beta.18) +[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x) ## Documentation & Resources - [API Documentation](https://docs.rs/awc) -- [Example Project](https://github.com/actix/examples/tree/HEAD/awc_https) -- [Chat on Gitter](https://gitter.im/actix/actix-web) -- Minimum Supported Rust Version (MSRV): 1.46.0 +- [Example Project](https://github.com/actix/examples/tree/HEAD/security/awc_https) +- Minimum Supported Rust Version (MSRV): 1.54 ## Example + ```rust use actix_rt::System; use awc::Client; @@ -26,7 +26,7 @@ fn main() { let res = client .get("http://www.rust-lang.org") // <- Create request builder - .header("User-Agent", "Actix-web") + .insert_header(("User-Agent", "Actix-web")) .send() // <- Send http request .await; diff --git a/examples/client.rs b/awc/examples/client.rs similarity index 53% rename from examples/client.rs rename to awc/examples/client.rs index b9574590d..653cb226f 100644 --- a/examples/client.rs +++ b/awc/examples/client.rs @@ -1,18 +1,20 @@ -use actix_http::Error; +use std::error::Error as StdError; #[actix_web::main] -async fn main() -> Result<(), Error> { - std::env::set_var("RUST_LOG", "actix_http=trace"); +async fn main() -> Result<(), Box> { + std::env::set_var("RUST_LOG", "client=trace,awc=trace,actix_http=trace"); env_logger::init(); let client = awc::Client::new(); // Create request builder, configure request and send - let mut response = client + let request = client .get("https://www.rust-lang.org/") - .append_header(("User-Agent", "Actix-web")) - .send() - .await?; + .append_header(("User-Agent", "Actix-web")); + + println!("Request: {:?}", request); + + let mut response = request.send().await?; // server http response println!("Response: {:?}", response); diff --git a/awc/src/any_body.rs b/awc/src/any_body.rs new file mode 100644 index 000000000..d09a943ab --- /dev/null +++ b/awc/src/any_body.rs @@ -0,0 +1,192 @@ +use std::{ + fmt, mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use pin_project_lite::pin_project; + +use actix_http::body::{BodySize, BoxBody, MessageBody}; + +pin_project! { + /// Represents various types of HTTP message body. + #[derive(Clone)] + #[project = AnyBodyProj] + pub enum AnyBody { + /// Empty response. `Content-Length` header is not set. + None, + + /// Complete, in-memory response body. + Bytes { body: Bytes }, + + /// Generic / Other message body. + Body { #[pin] body: B }, + } +} + +impl AnyBody { + /// Constructs a "body" representing an empty response. + pub fn none() -> Self { + Self::None + } + + /// Constructs a new, 0-length body. + pub fn empty() -> Self { + Self::Bytes { body: Bytes::new() } + } + + /// Create boxed body from generic message body. + pub fn new_boxed(body: B) -> Self + where + B: MessageBody + 'static, + { + Self::Body { body: body.boxed() } + } + + /// Constructs new `AnyBody` instance from a slice of bytes by copying it. + /// + /// If your bytes container is owned, it may be cheaper to use a `From` impl. + pub fn copy_from_slice(s: &[u8]) -> Self { + Self::Bytes { + body: Bytes::copy_from_slice(s), + } + } + + #[doc(hidden)] + #[deprecated(since = "4.0.0", note = "Renamed to `copy_from_slice`.")] + pub fn from_slice(s: &[u8]) -> Self { + Self::Bytes { + body: Bytes::copy_from_slice(s), + } + } +} + +impl AnyBody { + /// Create body from generic message body. + pub fn new(body: B) -> Self { + Self::Body { body } + } +} + +impl AnyBody +where + B: MessageBody + 'static, +{ + /// Converts a [`MessageBody`] type into the best possible representation. + /// + /// Checks size for `None` and tries to convert to `Bytes`. Otherwise, uses the `Body` variant. + pub fn from_message_body(body: B) -> Self + where + B: MessageBody, + { + if matches!(body.size(), BodySize::None) { + return Self::None; + } + + match body.try_into_bytes() { + Ok(body) => Self::Bytes { body }, + Err(body) => Self::new(body), + } + } + + pub fn into_boxed(self) -> AnyBody { + match self { + Self::None => AnyBody::None, + Self::Bytes { body } => AnyBody::Bytes { body }, + Self::Body { body } => AnyBody::new_boxed(body), + } + } +} + +impl MessageBody for AnyBody +where + B: MessageBody, +{ + type Error = crate::BoxError; + + fn size(&self) -> BodySize { + match self { + AnyBody::None => BodySize::None, + AnyBody::Bytes { ref body } => BodySize::Sized(body.len() as u64), + AnyBody::Body { ref body } => body.size(), + } + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + AnyBodyProj::None => Poll::Ready(None), + AnyBodyProj::Bytes { body } => { + let len = body.len(); + if len == 0 { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(mem::take(body)))) + } + } + + AnyBodyProj::Body { body } => body.poll_next(cx).map_err(|err| err.into()), + } + } +} + +impl PartialEq for AnyBody { + fn eq(&self, other: &AnyBody) -> bool { + match self { + AnyBody::None => matches!(*other, AnyBody::None), + AnyBody::Bytes { body } => match other { + AnyBody::Bytes { body: b2 } => body == b2, + _ => false, + }, + AnyBody::Body { .. } => false, + } + } +} + +impl fmt::Debug for AnyBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + AnyBody::None => write!(f, "AnyBody::None"), + AnyBody::Bytes { ref body } => write!(f, "AnyBody::Bytes({:?})", body), + AnyBody::Body { ref body } => write!(f, "AnyBody::Message({:?})", body), + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomPinned; + + use static_assertions::{assert_impl_all, assert_not_impl_all}; + + use super::*; + + struct PinType(PhantomPinned); + + impl MessageBody for PinType { + type Error = crate::BoxError; + + fn size(&self) -> BodySize { + unimplemented!() + } + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + unimplemented!() + } + } + + assert_impl_all!(AnyBody<()>: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody>: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody: MessageBody, fmt::Debug, Send, Sync, Unpin); + assert_impl_all!(AnyBody: MessageBody, fmt::Debug, Unpin); + assert_impl_all!(AnyBody: MessageBody); + + assert_not_impl_all!(AnyBody: Send, Sync, Unpin); + assert_not_impl_all!(AnyBody: Send, Sync, Unpin); +} diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 94ffb8a71..16a4e9cb5 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -1,62 +1,95 @@ -use std::convert::TryFrom; -use std::fmt; -use std::rc::Rc; -use std::time::Duration; +use std::{convert::TryFrom, fmt, net::IpAddr, rc::Rc, time::Duration}; -use actix_http::client::{Connect as HttpConnect, ConnectError, Connection, Connector}; -use actix_http::http::{self, header, Error as HttpError, HeaderMap, HeaderName}; -use actix_service::Service; +use actix_http::{ + error::HttpError, + header::{self, HeaderMap, HeaderName, TryIntoHeaderPair}, + Uri, +}; +use actix_rt::net::{ActixStream, TcpStream}; +use actix_service::{boxed, Service}; -use crate::connect::{Connect, ConnectorWrapper}; -use crate::{Client, ClientConfig}; +use crate::{ + client::{ + ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection, + }, + connect::DefaultConnector, + error::SendRequestError, + middleware::{NestTransform, Redirect, Transform}, + Client, ConnectRequest, ConnectResponse, +}; /// An HTTP Client builder /// /// This type can be used to construct an instance of `Client` through a /// builder-like pattern. -pub struct ClientBuilder { - default_headers: bool, - allow_redirects: bool, - max_redirects: usize, +pub struct ClientBuilder { max_http_version: Option, stream_window_size: Option, conn_window_size: Option, - headers: HeaderMap, + fundamental_headers: bool, + default_headers: HeaderMap, timeout: Option, - connector: Option>, -} - -impl Default for ClientBuilder { - fn default() -> Self { - Self::new() - } + connector: Connector, + middleware: M, + local_address: Option, + max_redirects: u8, } impl ClientBuilder { - pub fn new() -> Self { + #[allow(clippy::new_ret_no_self)] + pub fn new() -> ClientBuilder< + impl Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, + (), + > { ClientBuilder { - default_headers: true, - allow_redirects: true, - max_redirects: 10, - headers: HeaderMap::new(), - timeout: Some(Duration::from_secs(5)), - connector: None, max_http_version: None, stream_window_size: None, conn_window_size: None, + fundamental_headers: true, + default_headers: HeaderMap::new(), + timeout: Some(Duration::from_secs(5)), + connector: Connector::new(), + middleware: (), + local_address: None, + max_redirects: 10, } } +} +impl ClientBuilder +where + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + + 'static, + Io: ActixStream + fmt::Debug + 'static, +{ /// Use custom connector service. - pub fn connector(mut self, connector: T) -> Self + pub fn connector(self, connector: Connector) -> ClientBuilder where - T: Service + 'static, - T::Response: Connection, - ::Future: 'static, - T::Future: 'static, + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + + 'static, + Io1: ActixStream + fmt::Debug + 'static, { - self.connector = Some(Box::new(ConnectorWrapper(connector))); - self + ClientBuilder { + middleware: self.middleware, + fundamental_headers: self.fundamental_headers, + default_headers: self.default_headers, + timeout: self.timeout, + local_address: self.local_address, + connector, + max_http_version: self.max_http_version, + stream_window_size: self.stream_window_size, + conn_window_size: self.conn_window_size, + max_redirects: self.max_redirects, + } } /// Set request timeout @@ -74,11 +107,9 @@ impl ClientBuilder { self } - /// Do not follow redirects. - /// - /// Redirects are allowed by default. - pub fn disable_redirects(mut self) -> Self { - self.allow_redirects = false; + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.local_address = Some(addr); self } @@ -90,6 +121,22 @@ impl ClientBuilder { self } + /// Do not follow redirects. + /// + /// Redirects are allowed by default. + pub fn disable_redirects(mut self) -> Self { + self.max_redirects = 0; + self + } + + /// Set max number of redirects. + /// + /// Max redirects is set to 10 by default. + pub fn max_redirects(mut self, num: u8) -> Self { + self.max_redirects = num; + self + } + /// Indicates the initial window size (in octets) for /// HTTP2 stream-level flow control for received data. /// @@ -108,55 +155,63 @@ impl ClientBuilder { self } - /// Set max number of redirects. + /// Do not add fundamental default request headers. /// - /// Max redirects is set to 10 by default. - pub fn max_redirects(mut self, num: usize) -> Self { - self.max_redirects = num; - self - } - - /// Do not add default request headers. /// By default `Date` and `User-Agent` headers are set. pub fn no_default_headers(mut self) -> Self { - self.default_headers = false; + self.fundamental_headers = false; self } - /// Add default header. Headers added by this method - /// get added to every request. + /// Add default header. + /// + /// Headers added by this method get added to every request unless overriden by . + /// + /// # Panics + /// Panics if header name or value is invalid. + pub fn add_default_header(mut self, header: impl TryIntoHeaderPair) -> Self { + match header.try_into_pair() { + Ok((key, value)) => self.default_headers.append(key, value), + Err(err) => panic!("Header error: {:?}", err.into()), + } + + self + } + + #[doc(hidden)] + #[deprecated(since = "3.0.0", note = "Prefer `add_default_header((key, value))`.")] pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: fmt::Debug + Into, - V: header::IntoHeaderValue, + V: header::TryIntoHeaderValue, V::Error: fmt::Debug, { match HeaderName::try_from(key) { Ok(key) => match value.try_into_value() { Ok(value) => { - self.headers.append(key, value); + self.default_headers.append(key, value); } - Err(e) => log::error!("Header value error: {:?}", e), + Err(err) => log::error!("Header value error: {:?}", err), }, - Err(e) => log::error!("Header name error: {:?}", e), + Err(err) => log::error!("Header name error: {:?}", err), } self } /// Set client wide HTTP basic authorization header - pub fn basic_auth(self, username: U, password: Option<&str>) -> Self + pub fn basic_auth(self, username: N, password: Option<&str>) -> Self where - U: fmt::Display, + N: fmt::Display, { let auth = match password { Some(password) => format!("{}:{}", username, password), None => format!("{}:", username), }; - self.header( + self.add_default_header(( header::AUTHORIZATION, format!("Basic {}", base64::encode(&auth)), - ) + )) } /// Set client wide HTTP bearer authentication header @@ -164,32 +219,80 @@ impl ClientBuilder { where T: fmt::Display, { - self.header(header::AUTHORIZATION, format!("Bearer {}", token)) + self.add_default_header((header::AUTHORIZATION, format!("Bearer {}", token))) + } + + /// Registers middleware, in the form of a middleware component (type), that runs during inbound + /// and/or outbound processing in the request life-cycle (request -> response), + /// modifying request/response as necessary, across all requests managed by the `Client`. + pub fn wrap( + self, + mw: M1, + ) -> ClientBuilder> + where + M: Transform, + M1: Transform, + { + ClientBuilder { + middleware: NestTransform::new(self.middleware, mw), + fundamental_headers: self.fundamental_headers, + max_http_version: self.max_http_version, + stream_window_size: self.stream_window_size, + conn_window_size: self.conn_window_size, + default_headers: self.default_headers, + timeout: self.timeout, + connector: self.connector, + local_address: self.local_address, + max_redirects: self.max_redirects, + } } /// Finish build process and create `Client` instance. - pub fn finish(self) -> Client { - let connector = if let Some(connector) = self.connector { - connector + pub fn finish(self) -> Client + where + M: Transform>, ConnectRequest> + 'static, + M::Transform: + Service, + { + let max_redirects = self.max_redirects; + + if max_redirects > 0 { + self.wrap(Redirect::new().max_redirect_times(max_redirects)) + ._finish() } else { - let mut connector = Connector::new(); - if let Some(val) = self.max_http_version { - connector = connector.max_http_version(val) - }; - if let Some(val) = self.conn_window_size { - connector = connector.initial_connection_window_size(val) - }; - if let Some(val) = self.stream_window_size { - connector = connector.initial_window_size(val) - }; - Box::new(ConnectorWrapper(connector.finish())) as _ + self._finish() + } + } + + fn _finish(self) -> Client + where + M: Transform>, ConnectRequest> + 'static, + M::Transform: + Service, + { + let mut connector = self.connector; + + if let Some(val) = self.max_http_version { + connector = connector.max_http_version(val); }; - let config = ClientConfig { - headers: self.headers, + if let Some(val) = self.conn_window_size { + connector = connector.initial_connection_window_size(val) + }; + if let Some(val) = self.stream_window_size { + connector = connector.initial_window_size(val) + }; + if let Some(val) = self.local_address { + connector = connector.local_address(val); + } + + let connector = DefaultConnector::new(connector.finish()); + let connector = boxed::rc_service(self.middleware.new_transform(connector)); + + Client(ClientConfig { + default_headers: Rc::new(self.default_headers), timeout: self.timeout, connector, - }; - Client(Rc::new(config)) + }) } } @@ -202,7 +305,7 @@ mod tests { let client = ClientBuilder::new().basic_auth("username", Some("password")); assert_eq!( client - .headers + .default_headers .get(header::AUTHORIZATION) .unwrap() .to_str() @@ -213,7 +316,7 @@ mod tests { let client = ClientBuilder::new().basic_auth("username", None); assert_eq!( client - .headers + .default_headers .get(header::AUTHORIZATION) .unwrap() .to_str() @@ -227,7 +330,7 @@ mod tests { let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n"); assert_eq!( client - .headers + .default_headers .get(header::AUTHORIZATION) .unwrap() .to_str() diff --git a/actix-http/src/client/config.rs b/awc/src/client/config.rs similarity index 83% rename from actix-http/src/client/config.rs rename to awc/src/client/config.rs index fad902d04..530c1e03b 100644 --- a/actix-http/src/client/config.rs +++ b/awc/src/client/config.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{net::IpAddr, time::Duration}; const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB @@ -7,24 +7,28 @@ const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB #[derive(Clone)] pub(crate) struct ConnectorConfig { pub(crate) timeout: Duration, + pub(crate) handshake_timeout: Duration, pub(crate) conn_lifetime: Duration, pub(crate) conn_keep_alive: Duration, pub(crate) disconnect_timeout: Option, pub(crate) limit: usize, pub(crate) conn_window_size: u32, pub(crate) stream_window_size: u32, + pub(crate) local_address: Option, } impl Default for ConnectorConfig { fn default() -> Self { Self { timeout: Duration::from_secs(5), + handshake_timeout: Duration::from_secs(5), conn_lifetime: Duration::from_secs(75), conn_keep_alive: Duration::from_secs(15), disconnect_timeout: Some(Duration::from_millis(3000)), limit: 100, conn_window_size: DEFAULT_H2_CONN_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW, + local_address: None, } } } diff --git a/awc/src/client/connection.rs b/awc/src/client/connection.rs new file mode 100644 index 000000000..456f119aa --- /dev/null +++ b/awc/src/client/connection.rs @@ -0,0 +1,454 @@ +use std::{ + io, + ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, + time, +}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; +use actix_rt::task::JoinHandle; +use bytes::Bytes; +use futures_core::future::LocalBoxFuture; +use h2::client::SendRequest; + +use actix_http::{body::MessageBody, h1::ClientCodec, Payload, RequestHeadType, ResponseHead}; + +use crate::BoxError; + +use super::error::SendRequestError; +use super::pool::Acquired; +use super::{h1proto, h2proto}; + +/// Trait alias for types impl [tokio::io::AsyncRead] and [tokio::io::AsyncWrite]. +pub trait ConnectionIo: AsyncRead + AsyncWrite + Unpin + 'static {} + +impl ConnectionIo for T {} + +/// HTTP client connection +pub struct H1Connection { + io: Option, + created: time::Instant, + acquired: Acquired, +} + +impl H1Connection { + /// close or release the connection to pool based on flag input + pub(super) fn on_release(&mut self, keep_alive: bool) { + if keep_alive { + self.release(); + } else { + self.close(); + } + } + + /// Close connection + fn close(&mut self) { + let io = self.io.take().unwrap(); + self.acquired.close(ConnectionInnerType::H1(io)); + } + + /// Release this connection to the connection pool + fn release(&mut self) { + let io = self.io.take().unwrap(); + self.acquired + .release(ConnectionInnerType::H1(io), self.created); + } + + fn io_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io> { + Pin::new(self.get_mut().io.as_mut().unwrap()) + } +} + +impl AsyncRead for H1Connection { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.io_pin_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for H1Connection { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.io_pin_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.io_pin_mut().poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.io_pin_mut().poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.io_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.io.as_ref().unwrap().is_write_vectored() + } +} + +/// HTTP2 client connection +pub struct H2Connection { + io: Option, + created: time::Instant, + acquired: Acquired, +} + +impl Deref for H2Connection { + type Target = SendRequest; + + fn deref(&self) -> &Self::Target { + &self.io.as_ref().unwrap().sender + } +} + +impl DerefMut for H2Connection { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.io.as_mut().unwrap().sender + } +} + +impl H2Connection { + /// close or release the connection to pool based on flag input + pub(super) fn on_release(&mut self, close: bool) { + if close { + self.close(); + } else { + self.release(); + } + } + + /// Close connection + fn close(&mut self) { + let io = self.io.take().unwrap(); + self.acquired.close(ConnectionInnerType::H2(io)); + } + + /// Release this connection to the connection pool + fn release(&mut self) { + let io = self.io.take().unwrap(); + self.acquired + .release(ConnectionInnerType::H2(io), self.created); + } +} + +/// `H2ConnectionInner` has two parts: `SendRequest` and `Connection`. +/// +/// `Connection` is spawned as an async task on runtime and `H2ConnectionInner` holds a handle +/// for this task. Therefore, it can wake up and quit the task when SendRequest is dropped. +pub(super) struct H2ConnectionInner { + handle: JoinHandle<()>, + sender: SendRequest, +} + +impl H2ConnectionInner { + pub(super) fn new( + sender: SendRequest, + connection: h2::client::Connection, + ) -> Self { + let handle = actix_rt::spawn(async move { + let _ = connection.await; + }); + + Self { handle, sender } + } +} + +/// Cancel spawned connection task on drop. +impl Drop for H2ConnectionInner { + fn drop(&mut self) { + // TODO: this can end up sending extraneous requests; see if there is a better way to handle + if self + .sender + .send_request(http::Request::new(()), true) + .is_err() + { + self.handle.abort(); + } + } +} + +/// Unified connection type cover HTTP/1 Plain/TLS and HTTP/2 protocols. +#[allow(dead_code)] +pub enum Connection> +where + A: ConnectionIo, + B: ConnectionIo, +{ + Tcp(ConnectionType
), + Tls(ConnectionType), +} + +/// Unified connection type cover Http1/2 protocols +pub enum ConnectionType { + H1(H1Connection), + H2(H2Connection), +} + +/// Helper type for storing connection types in pool. +pub(super) enum ConnectionInnerType { + H1(Io), + H2(H2ConnectionInner), +} + +impl ConnectionType { + pub(super) fn from_pool( + inner: ConnectionInnerType, + created: time::Instant, + acquired: Acquired, + ) -> Self { + match inner { + ConnectionInnerType::H1(io) => Self::from_h1(io, created, acquired), + ConnectionInnerType::H2(io) => Self::from_h2(io, created, acquired), + } + } + + pub(super) fn from_h1(io: Io, created: time::Instant, acquired: Acquired) -> Self { + Self::H1(H1Connection { + io: Some(io), + created, + acquired, + }) + } + + pub(super) fn from_h2( + io: H2ConnectionInner, + created: time::Instant, + acquired: Acquired, + ) -> Self { + Self::H2(H2Connection { + io: Some(io), + created, + acquired, + }) + } +} + +impl Connection +where + A: ConnectionIo, + B: ConnectionIo, +{ + /// Send a request through connection. + pub fn send_request( + self, + head: H, + body: RB, + ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + where + H: Into + 'static, + RB: MessageBody + 'static, + RB::Error: Into, + { + Box::pin(async move { + match self { + Connection::Tcp(ConnectionType::H1(conn)) => { + h1proto::send_request(conn, head.into(), body).await + } + Connection::Tls(ConnectionType::H1(conn)) => { + h1proto::send_request(conn, head.into(), body).await + } + Connection::Tls(ConnectionType::H2(conn)) => { + h2proto::send_request(conn, head.into(), body).await + } + _ => { + unreachable!("Plain TCP connection can be used only with HTTP/1.1 protocol") + } + } + }) + } + + /// Send request, returns Response and Framed tunnel. + pub fn open_tunnel + 'static>( + self, + head: H, + ) -> LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed, ClientCodec>), SendRequestError>, + > { + Box::pin(async move { + match self { + Connection::Tcp(ConnectionType::H1(ref _conn)) => { + let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; + Ok((head, framed)) + } + Connection::Tls(ConnectionType::H1(ref _conn)) => { + let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; + Ok((head, framed)) + } + Connection::Tls(ConnectionType::H2(mut conn)) => { + conn.release(); + Err(SendRequestError::TunnelNotSupported) + } + Connection::Tcp(ConnectionType::H2(_)) => { + unreachable!("Plain Tcp connection can be used only in Http1 protocol") + } + } + }) + } +} + +impl AsyncRead for Connection +where + A: ConnectionIo, + B: ConnectionIo, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), + _ => unreachable!("H2Connection can not impl AsyncRead trait"), + } + } +} + +const H2_UNREACHABLE_WRITE: &str = "H2Connection can not impl AsyncWrite trait"; + +impl AsyncWrite for Connection +where + A: ConnectionIo, + B: ConnectionIo, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), + _ => unreachable!(H2_UNREACHABLE_WRITE), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx), + _ => unreachable!(H2_UNREACHABLE_WRITE), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), + Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), + _ => unreachable!(H2_UNREACHABLE_WRITE), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Connection::Tcp(ConnectionType::H1(conn)) => { + Pin::new(conn).poll_write_vectored(cx, bufs) + } + Connection::Tls(ConnectionType::H1(conn)) => { + Pin::new(conn).poll_write_vectored(cx, bufs) + } + _ => unreachable!(H2_UNREACHABLE_WRITE), + } + } + + fn is_write_vectored(&self) -> bool { + match *self { + Connection::Tcp(ConnectionType::H1(ref conn)) => conn.is_write_vectored(), + Connection::Tls(ConnectionType::H1(ref conn)) => conn.is_write_vectored(), + _ => unreachable!(H2_UNREACHABLE_WRITE), + } + } +} + +#[cfg(test)] +mod test { + use std::{ + future::Future, + net, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, + }; + + use actix_rt::{ + net::TcpStream, + time::{interval, Interval}, + }; + + use super::*; + + #[actix_rt::test] + async fn test_h2_connection_drop() { + let addr = "127.0.0.1:0".parse::().unwrap(); + let listener = net::TcpListener::bind(addr).unwrap(); + let local = listener.local_addr().unwrap(); + + std::thread::spawn(move || while listener.accept().is_ok() {}); + + let tcp = TcpStream::connect(local).await.unwrap(); + let (sender, connection) = h2::client::handshake(tcp).await.unwrap(); + let conn = H2ConnectionInner::new(sender.clone(), connection); + + assert!(sender.clone().ready().await.is_ok()); + assert!(h2::client::SendRequest::clone(&conn.sender) + .ready() + .await + .is_ok()); + + drop(conn); + + struct DropCheck { + sender: h2::client::SendRequest, + interval: Interval, + start_from: Instant, + } + + impl Future for DropCheck { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + match futures_core::ready!(this.sender.poll_ready(cx)) { + Ok(()) => { + if this.start_from.elapsed() > Duration::from_secs(10) { + panic!("connection should be gone and can not be ready"); + } else { + let _ = this.interval.poll_tick(cx); + Poll::Pending + } + } + Err(_) => Poll::Ready(()), + } + } + } + + DropCheck { + sender, + interval: interval(Duration::from_millis(100)), + start_from: Instant::now(), + } + .await; + } +} diff --git a/awc/src/client/connector.rs b/awc/src/client/connector.rs new file mode 100644 index 000000000..423f656a8 --- /dev/null +++ b/awc/src/client/connector.rs @@ -0,0 +1,901 @@ +use std::{ + fmt, + future::Future, + net::IpAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, + time::Duration, +}; + +use actix_http::Protocol; +use actix_rt::{ + net::{ActixStream, TcpStream}, + time::{sleep, Sleep}, +}; +use actix_service::Service; +use actix_tls::connect::{ + ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, + Connector as TcpConnector, Resolver, +}; +use futures_core::{future::LocalBoxFuture, ready}; +use http::Uri; +use pin_project_lite::pin_project; + +use super::{ + config::ConnectorConfig, + connection::{Connection, ConnectionIo}, + error::ConnectError, + pool::ConnectionPool, + Connect, +}; + +enum OurTlsConnector { + #[allow(dead_code)] // only dead when no TLS feature is enabled + None, + + #[cfg(feature = "openssl")] + Openssl(actix_tls::connect::openssl::reexports::SslConnector), + + /// Provided because building the OpenSSL context on newer versions can be very slow. + /// This prevents unnecessary calls to `.build()` while constructing the client connector. + #[cfg(feature = "openssl")] + #[allow(dead_code)] // false positive; used in build_ssl + OpensslBuilder(actix_tls::connect::openssl::reexports::SslConnectorBuilder), + + #[cfg(feature = "rustls")] + Rustls(std::sync::Arc), +} + +/// Manages HTTP client network connectivity. +/// +/// The `Connector` type uses a builder-like combinator pattern for service +/// construction that finishes by calling the `.finish()` method. +/// +/// ```ignore +/// use std::time::Duration; +/// use actix_http::client::Connector; +/// +/// let connector = Connector::new() +/// .timeout(Duration::from_secs(5)) +/// .finish(); +/// ``` +pub struct Connector { + connector: T, + config: ConnectorConfig, + + #[allow(dead_code)] // only dead when no TLS feature is enabled + tls: OurTlsConnector, +} + +impl Connector<()> { + #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] + pub fn new() -> Connector< + impl Service< + ConnectInfo, + Response = TcpConnection, + Error = actix_tls::connect::ConnectError, + > + Clone, + > { + Connector { + connector: TcpConnector::new(resolver::resolver()).service(), + config: ConnectorConfig::default(), + tls: Self::build_ssl(vec![b"h2".to_vec(), b"http/1.1".to_vec()]), + } + } + + /// Provides an empty TLS connector when no TLS feature is enabled. + #[cfg(not(any(feature = "openssl", feature = "rustls")))] + fn build_ssl(_: Vec>) -> OurTlsConnector { + OurTlsConnector::None + } + + /// Build TLS connector with rustls, based on supplied ALPN protocols + /// + /// Note that if both `openssl` and `rustls` features are enabled, rustls will be used. + #[cfg(feature = "rustls")] + fn build_ssl(protocols: Vec>) -> OurTlsConnector { + use actix_tls::connect::rustls::{reexports::ClientConfig, webpki_roots_cert_store}; + + let mut config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + + config.alpn_protocols = protocols; + + OurTlsConnector::Rustls(std::sync::Arc::new(config)) + } + + /// Build TLS connector with openssl, based on supplied ALPN protocols + #[cfg(all(feature = "openssl", not(feature = "rustls")))] + fn build_ssl(protocols: Vec>) -> OurTlsConnector { + use actix_tls::connect::openssl::reexports::{SslConnector, SslMethod}; + use bytes::{BufMut, BytesMut}; + + let mut alpn = BytesMut::with_capacity(20); + for proto in &protocols { + alpn.put_u8(proto.len() as u8); + alpn.put(proto.as_slice()); + } + + let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap(); + if let Err(err) = ssl.set_alpn_protos(&alpn) { + log::error!("Can not set ALPN protocol: {:?}", err); + } + + OurTlsConnector::OpensslBuilder(ssl) + } +} + +impl Connector { + /// Use custom connector. + pub fn connector(self, connector: S1) -> Connector + where + Io1: ActixStream + fmt::Debug + 'static, + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, + { + Connector { + connector, + config: self.config, + tls: self.tls, + } + } +} + +impl Connector +where + // Note: + // Input Io type is bound to ActixStream trait but internally in client module they + // are bound to ConnectionIo trait alias. And latter is the trait exposed to public + // in the form of Box type. + // + // This remap is to hide ActixStream's trait methods. They are not meant to be called + // from user code. + IO: ActixStream + fmt::Debug + 'static, + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + + 'static, +{ + /// Tcp connection timeout, i.e. max time to connect to remote host including dns name + /// resolution. Set to 5 second by default. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.config.timeout = timeout; + self + } + + /// Tls handshake timeout, i.e. max time to do tls handshake with remote host after tcp + /// connection established. Set to 5 second by default. + pub fn handshake_timeout(mut self, timeout: Duration) -> Self { + self.config.handshake_timeout = timeout; + self + } + + /// Use custom OpenSSL `SslConnector` instance. + #[cfg(feature = "openssl")] + pub fn openssl( + mut self, + connector: actix_tls::connect::openssl::reexports::SslConnector, + ) -> Self { + self.tls = OurTlsConnector::Openssl(connector); + self + } + + /// See docs for [`Connector::openssl`]. + #[doc(hidden)] + #[cfg(feature = "openssl")] + #[deprecated(since = "3.0.0", note = "Renamed to `Connector::openssl`.")] + pub fn ssl( + mut self, + connector: actix_tls::connect::openssl::reexports::SslConnector, + ) -> Self { + self.tls = OurTlsConnector::Openssl(connector); + self + } + + /// Use custom Rustls `ClientConfig` instance. + #[cfg(feature = "rustls")] + pub fn rustls( + mut self, + connector: std::sync::Arc, + ) -> Self { + self.tls = OurTlsConnector::Rustls(connector); + self + } + + /// Maximum supported HTTP major version. + /// + /// Supported versions are HTTP/1.1 and HTTP/2. + pub fn max_http_version(mut self, val: http::Version) -> Self { + let versions = match val { + http::Version::HTTP_11 => vec![b"http/1.1".to_vec()], + http::Version::HTTP_2 => vec![b"h2".to_vec(), b"http/1.1".to_vec()], + _ => { + unimplemented!("actix-http client only supports versions http/1.1 & http/2") + } + }; + self.tls = Connector::build_ssl(versions); + self + } + + /// Indicates the initial window size (in octets) for + /// HTTP2 stream-level flow control for received data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_window_size(mut self, size: u32) -> Self { + self.config.stream_window_size = size; + self + } + + /// Indicates the initial window size (in octets) for + /// HTTP2 connection-level flow control for received data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_connection_window_size(mut self, size: u32) -> Self { + self.config.conn_window_size = size; + self + } + + /// Set total number of simultaneous connections per type of scheme. + /// + /// If limit is 0, the connector has no limit. + /// The default limit size is 100. + pub fn limit(mut self, limit: usize) -> Self { + self.config.limit = limit; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.config.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.config.conn_lifetime = dur; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the socket get dropped. This timeout affects only secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn disconnect_timeout(mut self, dur: Duration) -> Self { + self.config.disconnect_timeout = Some(dur); + self + } + + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.config.local_address = Some(addr); + self + } + + /// Finish configuration process and create connector service. + /// + /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain. + pub fn finish(self) -> ConnectorService { + let local_address = self.config.local_address; + let timeout = self.config.timeout; + + let tcp_service_inner = + TcpConnectorInnerService::new(self.connector, timeout, local_address); + + #[allow(clippy::redundant_clone)] + let tcp_service = TcpConnectorService { + service: tcp_service_inner.clone(), + }; + + let tls = match self.tls { + #[cfg(feature = "openssl")] + OurTlsConnector::OpensslBuilder(builder) => { + OurTlsConnector::Openssl(builder.build()) + } + tls => tls, + }; + + let tls_service = match tls { + OurTlsConnector::None => { + #[cfg(not(feature = "dangerous-h2c"))] + { + None + } + + #[cfg(feature = "dangerous-h2c")] + { + use std::io; + + use actix_tls::connect::Connection; + use actix_utils::future::{ready, Ready}; + + impl IntoConnectionIo for TcpConnection> { + fn into_connection_io(self) -> (Box, Protocol) { + let io = self.into_parts().0; + (io, Protocol::Http2) + } + } + + /// With the `dangerous-h2c` feature enabled, this connector uses a no-op TLS + /// connection service that passes through plain TCP as a TLS connection. + /// + /// The protocol version of this fake TLS connection is set to be HTTP/2. + #[derive(Clone)] + struct NoOpTlsConnectorService; + + impl Service> for NoOpTlsConnectorService + where + IO: ActixStream + 'static, + { + type Response = Connection>; + type Error = io::Error; + type Future = Ready>; + + actix_service::always_ready!(); + + fn call(&self, connection: Connection) -> Self::Future { + let (io, connection) = connection.replace_io(()); + let (_, connection) = connection.replace_io(Box::new(io) as _); + + ready(Ok(connection)) + } + } + + let handshake_timeout = self.config.handshake_timeout; + + let tls_service = TlsConnectorService { + tcp_service: tcp_service_inner, + tls_service: NoOpTlsConnectorService, + timeout: handshake_timeout, + }; + + Some(actix_service::boxed::rc_service(tls_service)) + } + } + + #[cfg(feature = "openssl")] + OurTlsConnector::Openssl(tls) => { + const H2: &[u8] = b"h2"; + + use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector}; + + impl IntoConnectionIo for TcpConnection> { + fn into_connection_io(self) -> (Box, Protocol) { + let sock = self.into_parts().0; + let h2 = sock + .ssl() + .selected_alpn_protocol() + .map_or(false, |protos| protos.windows(2).any(|w| w == H2)); + if h2 { + (Box::new(sock), Protocol::Http2) + } else { + (Box::new(sock), Protocol::Http1) + } + } + } + + let handshake_timeout = self.config.handshake_timeout; + + let tls_service = TlsConnectorService { + tcp_service: tcp_service_inner, + tls_service: TlsConnector::service(tls), + timeout: handshake_timeout, + }; + + Some(actix_service::boxed::rc_service(tls_service)) + } + + #[cfg(feature = "openssl")] + OurTlsConnector::OpensslBuilder(_) => { + unreachable!("OpenSSL builder is built before this match."); + } + + #[cfg(feature = "rustls")] + OurTlsConnector::Rustls(tls) => { + const H2: &[u8] = b"h2"; + + use actix_tls::connect::rustls::{reexports::AsyncTlsStream, TlsConnector}; + + impl IntoConnectionIo for TcpConnection> { + fn into_connection_io(self) -> (Box, Protocol) { + let sock = self.into_parts().0; + let h2 = sock + .get_ref() + .1 + .alpn_protocol() + .map_or(false, |protos| protos.windows(2).any(|w| w == H2)); + if h2 { + (Box::new(sock), Protocol::Http2) + } else { + (Box::new(sock), Protocol::Http1) + } + } + } + + let handshake_timeout = self.config.handshake_timeout; + + let tls_service = TlsConnectorService { + tcp_service: tcp_service_inner, + tls_service: TlsConnector::service(tls), + timeout: handshake_timeout, + }; + + Some(actix_service::boxed::rc_service(tls_service)) + } + }; + + let tcp_config = self.config.no_disconnect_timeout(); + + let tcp_pool = ConnectionPool::new(tcp_service, tcp_config); + + let tls_config = self.config; + let tls_pool = + tls_service.map(move |tls_service| ConnectionPool::new(tls_service, tls_config)); + + ConnectorServicePriv { tcp_pool, tls_pool } + } +} + +/// tcp service for map `TcpConnection` type to `(Io, Protocol)` +#[derive(Clone)] +pub struct TcpConnectorService { + service: S, +} + +impl Service for TcpConnectorService +where + S: Service, Error = ConnectError> + + Clone + + 'static, +{ + type Response = (Io, Protocol); + type Error = ConnectError; + type Future = TcpConnectorFuture; + + actix_service::forward_ready!(service); + + fn call(&self, req: Connect) -> Self::Future { + TcpConnectorFuture { + fut: self.service.call(req), + } + } +} + +pin_project! { + #[project = TcpConnectorFutureProj] + pub struct TcpConnectorFuture { + #[pin] + fut: Fut, + } +} + +impl Future for TcpConnectorFuture +where + Fut: Future, ConnectError>>, +{ + type Output = Result<(Io, Protocol), ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .fut + .poll(cx) + .map_ok(|res| (res.into_parts().0, Protocol::Http1)) + } +} + +/// service for establish tcp connection and do client tls handshake. +/// operation is canceled when timeout limit reached. +struct TlsConnectorService { + /// TCP connection is canceled on `TcpConnectorInnerService`'s timeout setting. + tcp_service: Tcp, + + /// TLS connection is canceled on `TlsConnectorService`'s timeout setting. + tls_service: Tls, + + timeout: Duration, +} + +impl Service for TlsConnectorService +where + Tcp: Service, Error = ConnectError> + + Clone + + 'static, + Tls: Service, Error = std::io::Error> + Clone + 'static, + Tls::Response: IntoConnectionIo, + IO: ConnectionIo, +{ + type Response = (Box, Protocol); + type Error = ConnectError; + type Future = TlsConnectorFuture; + + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + ready!(self.tcp_service.poll_ready(cx))?; + ready!(self.tls_service.poll_ready(cx))?; + Poll::Ready(Ok(())) + } + + fn call(&self, req: Connect) -> Self::Future { + let fut = self.tcp_service.call(req); + let tls_service = self.tls_service.clone(); + let timeout = self.timeout; + + TlsConnectorFuture::TcpConnect { + fut, + tls_service: Some(tls_service), + timeout, + } + } +} + +pin_project! { + #[project = TlsConnectorProj] + #[allow(clippy::large_enum_variant)] + enum TlsConnectorFuture { + TcpConnect { + #[pin] + fut: Fut1, + tls_service: Option, + timeout: Duration, + }, + TlsConnect { + #[pin] + fut: Fut2, + #[pin] + timeout: Sleep, + }, + } + +} +/// helper trait for generic over different TlsStream types between tls crates. +trait IntoConnectionIo { + fn into_connection_io(self) -> (Box, Protocol); +} + +impl Future for TlsConnectorFuture +where + S: Service, Response = Res, Error = std::io::Error, Future = Fut2>, + S::Response: IntoConnectionIo, + Fut1: Future, ConnectError>>, + Fut2: Future>, + Io: ConnectionIo, +{ + type Output = Result<(Box, Protocol), ConnectError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + TlsConnectorProj::TcpConnect { + fut, + tls_service, + timeout, + } => { + let res = ready!(fut.poll(cx))?; + let fut = tls_service + .take() + .expect("TlsConnectorFuture polled after complete") + .call(res); + let timeout = sleep(*timeout); + self.set(TlsConnectorFuture::TlsConnect { fut, timeout }); + self.poll(cx) + } + TlsConnectorProj::TlsConnect { fut, timeout } => match fut.poll(cx)? { + Poll::Ready(res) => Poll::Ready(Ok(res.into_connection_io())), + Poll::Pending => timeout.poll(cx).map(|_| Err(ConnectError::Timeout)), + }, + } + } +} + +/// service for establish tcp connection. +/// operation is canceled when timeout limit reached. +#[derive(Clone)] +pub struct TcpConnectorInnerService { + service: S, + timeout: Duration, + local_address: Option, +} + +impl TcpConnectorInnerService { + fn new(service: S, timeout: Duration, local_address: Option) -> Self { + Self { + service, + timeout, + local_address, + } + } +} + +impl Service for TcpConnectorInnerService +where + S: Service, Response = TcpConnection, Error = TcpConnectError> + + Clone + + 'static, +{ + type Response = S::Response; + type Error = ConnectError; + type Future = TcpConnectorInnerFuture; + + actix_service::forward_ready!(service); + + fn call(&self, req: Connect) -> Self::Future { + let mut req = ConnectInfo::new(req.uri).set_addr(req.addr); + + if let Some(local_addr) = self.local_address { + req = req.set_local_addr(local_addr); + } + + TcpConnectorInnerFuture { + fut: self.service.call(req), + timeout: sleep(self.timeout), + } + } +} + +pin_project! { + #[project = TcpConnectorInnerFutureProj] + pub struct TcpConnectorInnerFuture { + #[pin] + fut: Fut, + #[pin] + timeout: Sleep, + } +} + +impl Future for TcpConnectorInnerFuture +where + Fut: Future, TcpConnectError>>, +{ + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.fut.poll(cx) { + Poll::Ready(res) => Poll::Ready(res.map_err(ConnectError::from)), + Poll::Pending => this.timeout.poll(cx).map(|_| Err(ConnectError::Timeout)), + } + } +} + +/// Connector service for pooled Plain/Tls Tcp connections. +pub type ConnectorService = ConnectorServicePriv< + TcpConnectorService>, + Rc< + dyn Service< + Connect, + Response = (Box, Protocol), + Error = ConnectError, + Future = LocalBoxFuture< + 'static, + Result<(Box, Protocol), ConnectError>, + >, + >, + >, + IO, + Box, +>; + +pub struct ConnectorServicePriv +where + S1: Service, + S2: Service, + Io1: ConnectionIo, + Io2: ConnectionIo, +{ + tcp_pool: ConnectionPool, + tls_pool: Option>, +} + +impl Service for ConnectorServicePriv +where + S1: Service + Clone + 'static, + S2: Service + Clone + 'static, + Io1: ConnectionIo, + Io2: ConnectionIo, +{ + type Response = Connection; + type Error = ConnectError; + type Future = ConnectorServiceFuture; + + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + ready!(self.tcp_pool.poll_ready(cx))?; + if let Some(ref tls_pool) = self.tls_pool { + ready!(tls_pool.poll_ready(cx))?; + } + Poll::Ready(Ok(())) + } + + fn call(&self, req: Connect) -> Self::Future { + match req.uri.scheme_str() { + Some("https") | Some("wss") => match self.tls_pool { + None => ConnectorServiceFuture::SslIsNotSupported, + Some(ref pool) => ConnectorServiceFuture::Tls { + fut: pool.call(req), + }, + }, + _ => ConnectorServiceFuture::Tcp { + fut: self.tcp_pool.call(req), + }, + } + } +} + +pin_project! { + #[project = ConnectorServiceFutureProj] + pub enum ConnectorServiceFuture + where + S1: Service, + S1: Clone, + S1: 'static, + S2: Service, + S2: Clone, + S2: 'static, + Io1: ConnectionIo, + Io2: ConnectionIo, + { + Tcp { + #[pin] + fut: as Service>::Future + }, + Tls { + #[pin] + fut: as Service>::Future + }, + SslIsNotSupported + } +} + +impl Future for ConnectorServiceFuture +where + S1: Service + Clone + 'static, + S2: Service + Clone + 'static, + Io1: ConnectionIo, + Io2: ConnectionIo, +{ + type Output = Result, ConnectError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + ConnectorServiceFutureProj::Tcp { fut } => fut.poll(cx).map_ok(Connection::Tcp), + ConnectorServiceFutureProj::Tls { fut } => fut.poll(cx).map_ok(Connection::Tls), + ConnectorServiceFutureProj::SslIsNotSupported => { + Poll::Ready(Err(ConnectError::SslIsNotSupported)) + } + } + } +} + +#[cfg(not(feature = "trust-dns"))] +mod resolver { + use super::*; + + pub(super) fn resolver() -> Resolver { + Resolver::default() + } +} + +#[cfg(feature = "trust-dns")] +mod resolver { + use std::{cell::RefCell, net::SocketAddr}; + + use actix_tls::connect::Resolve; + use futures_core::future::LocalBoxFuture; + use trust_dns_resolver::{ + config::{ResolverConfig, ResolverOpts}, + system_conf::read_system_conf, + TokioAsyncResolver, + }; + + use super::*; + + pub(super) fn resolver() -> Resolver { + // new type for impl Resolve trait for TokioAsyncResolver. + struct TrustDnsResolver(TokioAsyncResolver); + + impl Resolve for TrustDnsResolver { + fn lookup<'a>( + &'a self, + host: &'a str, + port: u16, + ) -> LocalBoxFuture<'a, Result, Box>> + { + Box::pin(async move { + let res = self + .0 + .lookup_ip(host) + .await? + .iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect(); + Ok(res) + }) + } + } + + // resolver struct is cached in thread local so new clients can reuse the existing instance + thread_local! { + static TRUST_DNS_RESOLVER: RefCell> = RefCell::new(None); + } + + // get from thread local or construct a new trust-dns resolver. + TRUST_DNS_RESOLVER.with(|local| { + let resolver = local.borrow().as_ref().map(Clone::clone); + + match resolver { + Some(resolver) => resolver, + + None => { + let (cfg, opts) = match read_system_conf() { + Ok((cfg, opts)) => (cfg, opts), + Err(e) => { + log::error!("TRust-DNS can not load system config: {}", e); + (ResolverConfig::default(), ResolverOpts::default()) + } + }; + + let resolver = TokioAsyncResolver::tokio(cfg, opts).unwrap(); + + // box trust dns resolver and put it in thread local. + let resolver = Resolver::custom(TrustDnsResolver(resolver)); + *local.borrow_mut() = Some(resolver.clone()); + + resolver + } + } + }) + } +} + +#[cfg(feature = "dangerous-h2c")] +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use actix_http::{HttpService, Request, Response, Version}; + use actix_http_test::test_server; + use actix_service::ServiceFactoryExt as _; + + use super::*; + use crate::Client; + + #[actix_rt::test] + async fn h2c_connector() { + let mut srv = test_server(|| { + HttpService::build() + .h2(|_req: Request| async { Ok::<_, Infallible>(Response::ok()) }) + .tcp() + .map_err(|_| ()) + }) + .await; + + let connector = Connector { + connector: TcpConnector::new(resolver::resolver()).service(), + config: ConnectorConfig::default(), + tls: OurTlsConnector::None, + }; + + let client = Client::builder().connector(connector).finish(); + + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), Version::HTTP_2); + + srv.stop().await; + } +} diff --git a/actix-http/src/client/error.rs b/awc/src/client/error.rs similarity index 79% rename from actix-http/src/client/error.rs rename to awc/src/client/error.rs index 7768462b8..9f290c5c0 100644 --- a/actix-http/src/client/error.rs +++ b/awc/src/client/error.rs @@ -1,15 +1,17 @@ -use std::io; +use std::{fmt, io}; use derive_more::{Display, From}; -#[cfg(feature = "openssl")] -use actix_tls::accept::openssl::SslError; +use actix_http::error::{HttpError, ParseError}; -use crate::error::{Error, ParseError, ResponseError}; -use crate::http::{Error as HttpError, StatusCode}; +#[cfg(feature = "openssl")] +use actix_tls::accept::openssl::reexports::Error as OpensslError; + +use crate::BoxError; /// A set of errors that can occur while connecting to an HTTP host #[derive(Debug, Display, From)] +#[non_exhaustive] pub enum ConnectError { /// SSL feature is not enabled #[display(fmt = "SSL is not supported")] @@ -18,14 +20,14 @@ pub enum ConnectError { /// SSL error #[cfg(feature = "openssl")] #[display(fmt = "{}", _0)] - SslError(SslError), + SslError(OpensslError), /// Failed to resolve the hostname #[display(fmt = "Failed resolving hostname: {}", _0)] Resolver(Box), /// No dns records - #[display(fmt = "No dns records found for the input")] + #[display(fmt = "No DNS records found for the input")] NoRecords, /// Http2 error @@ -64,6 +66,7 @@ impl From for ConnectError { } #[derive(Debug, Display, From)] +#[non_exhaustive] pub enum InvalidUrl { #[display(fmt = "Missing URL scheme")] MissingScheme, @@ -82,6 +85,7 @@ impl std::error::Error for InvalidUrl {} /// A set of errors that can occur during request sending and response reading #[derive(Debug, Display, From)] +#[non_exhaustive] pub enum SendRequestError { /// Invalid URL #[display(fmt = "Invalid URL: {}", _0)] @@ -114,26 +118,18 @@ pub enum SendRequestError { TunnelNotSupported, /// Error sending request body - Body(Error), + Body(BoxError), + + /// Other errors that can occur after submitting a request. + #[display(fmt = "{:?}: {}", _1, _0)] + Custom(BoxError, Box), } impl std::error::Error for SendRequestError {} -/// Convert `SendRequestError` to a server `Response` -impl ResponseError for SendRequestError { - fn status_code(&self) -> StatusCode { - match *self { - SendRequestError::Connect(ConnectError::Timeout) => { - StatusCode::GATEWAY_TIMEOUT - } - SendRequestError::Connect(_) => StatusCode::BAD_REQUEST, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - /// A set of errors that can occur during freezing a request #[derive(Debug, Display, From)] +#[non_exhaustive] pub enum FreezeRequestError { /// Invalid URL #[display(fmt = "Invalid URL: {}", _0)] @@ -142,15 +138,20 @@ pub enum FreezeRequestError { /// HTTP error #[display(fmt = "{}", _0)] Http(HttpError), + + /// Other errors that can occur after submitting a request. + #[display(fmt = "{:?}: {}", _1, _0)] + Custom(BoxError, Box), } impl std::error::Error for FreezeRequestError {} impl From for SendRequestError { - fn from(e: FreezeRequestError) -> Self { - match e { - FreezeRequestError::Url(e) => e.into(), - FreezeRequestError::Http(e) => e.into(), + fn from(err: FreezeRequestError) -> Self { + match err { + FreezeRequestError::Url(err) => err.into(), + FreezeRequestError::Http(err) => err.into(), + FreezeRequestError::Custom(err, msg) => SendRequestError::Custom(err, msg), } } } diff --git a/awc/src/client/h1proto.rs b/awc/src/client/h1proto.rs new file mode 100644 index 000000000..cf716db72 --- /dev/null +++ b/awc/src/client/h1proto.rs @@ -0,0 +1,233 @@ +use std::{ + io::Write, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_codec::Framed; +use actix_http::{ + body::{BodySize, MessageBody}, + error::PayloadError, + h1, + header::{HeaderMap, TryIntoHeaderValue, EXPECT, HOST}, + Payload, RequestHeadType, ResponseHead, StatusCode, +}; +use actix_utils::future::poll_fn; +use bytes::{buf::BufMut, Bytes, BytesMut}; +use futures_core::{ready, Stream}; +use futures_util::SinkExt as _; +use pin_project_lite::pin_project; + +use crate::BoxError; + +use super::{ + connection::{ConnectionIo, H1Connection}, + error::{ConnectError, SendRequestError}, +}; + +pub(crate) async fn send_request( + io: H1Connection, + mut head: RequestHeadType, + body: B, +) -> Result<(ResponseHead, Payload), SendRequestError> +where + Io: ConnectionIo, + B: MessageBody, + B::Error: Into, +{ + // set request host header + if !head.as_ref().headers.contains_key(HOST) + && !head.extra_headers().iter().any(|h| h.contains_key(HOST)) + { + if let Some(host) = head.as_ref().uri.host() { + let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); + + match head.as_ref().uri.port_u16() { + None | Some(80) | Some(443) => write!(wrt, "{}", host)?, + Some(port) => write!(wrt, "{}:{}", host, port)?, + }; + + match wrt.get_mut().split().freeze().try_into_value() { + Ok(value) => match head { + RequestHeadType::Owned(ref mut head) => { + head.headers.insert(HOST, value); + } + RequestHeadType::Rc(_, ref mut extra_headers) => { + let headers = extra_headers.get_or_insert(HeaderMap::new()); + headers.insert(HOST, value); + } + }, + Err(e) => log::error!("Can not set HOST header {}", e), + } + } + } + + // create Framed and prepare sending request + let mut framed = Framed::new(io, h1::ClientCodec::default()); + + // Check EXPECT header and enable expect handle flag accordingly. + // See https://datatracker.ietf.org/doc/html/rfc7231#section-5.1.1 + let is_expect = if head.as_ref().headers.contains_key(EXPECT) { + match body.size() { + BodySize::None | BodySize::Sized(0) => { + let keep_alive = framed.codec_ref().keepalive(); + framed.io_mut().on_release(keep_alive); + + // TODO: use a new variant or a new type better describing error violate + // `Requirements for clients` session of above RFC + return Err(SendRequestError::Connect(ConnectError::Disconnected)); + } + _ => true, + } + } else { + false + }; + + framed.send((head, body.size()).into()).await?; + + let mut pin_framed = Pin::new(&mut framed); + + // special handle for EXPECT request. + let (do_send, mut res_head) = if is_expect { + let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + // return response head in case status code is not continue + // and current head would be used as final response head. + (head.status == StatusCode::CONTINUE, Some(head)) + } else { + (true, None) + }; + + if do_send { + // send request body + match body.size() { + BodySize::None | BodySize::Sized(0) => {} + _ => send_body(body, pin_framed.as_mut()).await?, + }; + + // read response and init read body + let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + res_head = Some(head); + } + + let head = res_head.unwrap(); + + match pin_framed.codec_ref().message_type() { + h1::MessageType::None => { + let keep_alive = pin_framed.codec_ref().keepalive(); + pin_framed.io_mut().on_release(keep_alive); + + Ok((head, Payload::None)) + } + _ => Ok(( + head, + Payload::Stream { + payload: Box::pin(PlStream::new(framed)), + }, + )), + } +} + +pub(crate) async fn open_tunnel( + io: Io, + head: RequestHeadType, +) -> Result<(ResponseHead, Framed), SendRequestError> +where + Io: ConnectionIo, +{ + // create Framed and send request. + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, BodySize::None).into()).await?; + + // read response head. + let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + Ok((head, framed)) +} + +/// send request body to the peer +pub(crate) async fn send_body( + body: B, + mut framed: Pin<&mut Framed>, +) -> Result<(), SendRequestError> +where + Io: ConnectionIo, + B: MessageBody, + B::Error: Into, +{ + actix_rt::pin!(body); + + let mut eof = false; + while !eof { + while !eof && !framed.as_ref().is_write_buf_full() { + match poll_fn(|cx| body.as_mut().poll_next(cx)).await { + Some(Ok(chunk)) => { + framed.as_mut().write(h1::Message::Chunk(Some(chunk)))?; + } + Some(Err(err)) => return Err(SendRequestError::Body(err.into())), + None => { + eof = true; + framed.as_mut().write(h1::Message::Chunk(None))?; + } + } + } + + if !framed.as_ref().is_write_buf_empty() { + poll_fn(|cx| match framed.as_mut().flush(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + if !framed.as_ref().is_write_buf_full() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + }) + .await?; + } + } + + framed.get_mut().flush().await?; + Ok(()) +} + +pin_project! { + pub(crate) struct PlStream { + #[pin] + framed: Framed, h1::ClientPayloadCodec>, + } +} + +impl PlStream { + fn new(framed: Framed, h1::ClientCodec>) -> Self { + let framed = framed.into_map_codec(|codec| codec.into_payload_codec()); + + PlStream { framed } + } +} + +impl Stream for PlStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + match ready!(this.framed.as_mut().next_item(cx)?) { + Some(Some(chunk)) => Poll::Ready(Some(Ok(chunk))), + Some(None) => { + let keep_alive = this.framed.codec_ref().keepalive(); + this.framed.io_mut().on_release(keep_alive); + Poll::Ready(None) + } + None => Poll::Ready(None), + } + } +} diff --git a/actix-http/src/client/h2proto.rs b/awc/src/client/h2proto.rs similarity index 61% rename from actix-http/src/client/h2proto.rs rename to awc/src/client/h2proto.rs index a70bc1738..709896ddd 100644 --- a/actix-http/src/client/h2proto.rs +++ b/awc/src/client/h2proto.rs @@ -1,47 +1,44 @@ -use std::convert::TryFrom; use std::future::Future; -use std::time; -use actix_codec::{AsyncRead, AsyncWrite}; +use actix_utils::future::poll_fn; use bytes::Bytes; -use futures_util::future::poll_fn; -use futures_util::pin_mut; use h2::{ client::{Builder, Connection, SendRequest}, SendStream, }; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, Method, Version}; +use log::trace; -use crate::body::{BodySize, MessageBody}; -use crate::header::HeaderMap; -use crate::message::{RequestHeadType, ResponseHead}; -use crate::payload::Payload; +use actix_http::{ + body::{BodySize, MessageBody}, + header::HeaderMap, + Payload, RequestHeadType, ResponseHead, +}; -use super::config::ConnectorConfig; -use super::connection::{ConnectionType, IoConnection}; -use super::error::SendRequestError; -use super::pool::Acquired; -use crate::client::connection::H2Connection; +use crate::BoxError; -pub(crate) async fn send_request( - mut io: H2Connection, +use super::{ + config::ConnectorConfig, + connection::{ConnectionIo, H2Connection}, + error::SendRequestError, +}; + +pub(crate) async fn send_request( + mut io: H2Connection, head: RequestHeadType, body: B, - created: time::Instant, - pool: Option>, ) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + Unpin + 'static, + Io: ConnectionIo, B: MessageBody, + B::Error: Into, { trace!("Sending client request: {:?} {:?}", head, body.size()); + let head_req = head.as_ref().method == Method::HEAD; let length = body.size(); - let eof = matches!( - length, - BodySize::None | BodySize::Empty | BodySize::Sized(0) - ); + let eof = matches!(length, BodySize::None | BodySize::Sized(0)); let mut req = Request::new(()); *req.uri_mut() = head.as_ref().uri.clone(); @@ -54,17 +51,26 @@ where // Content length let _ = match length { BodySize::None => None, + + BodySize::Sized(0) => { + #[allow(clippy::declare_interior_mutable_const)] + const HV_ZERO: HeaderValue = HeaderValue::from_static("0"); + req.headers_mut().insert(CONTENT_LENGTH, HV_ZERO) + } + + BodySize::Sized(len) => { + let mut buf = itoa::Buffer::new(); + + req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::from_str(buf.format(len)).unwrap(), + ) + } + BodySize::Stream => { skip_len = false; None } - BodySize::Empty => req - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), }; // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. @@ -87,7 +93,10 @@ where // copy headers for (key, value) in headers { match *key { - CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + // TODO: consider skipping other headers according to: + // https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2 + // omit HTTP/1.x only headers + CONNECTION | TRANSFER_ENCODING => continue, CONTENT_LENGTH if skip_len => continue, // DATE => has_date = true, _ => {} @@ -97,13 +106,13 @@ where let res = poll_fn(|cx| io.poll_ready(cx)).await; if let Err(e) = res { - release(io, pool, created, e.is_io()); + io.on_release(e.is_io()); return Err(SendRequestError::from(e)); } let resp = match io.send_request(req, eof) { Ok((fut, send)) => { - release(io, pool, created, false); + io.on_release(false); if !eof { send_body(body, send).await?; @@ -111,7 +120,7 @@ where fut.await.map_err(SendRequestError::from)? } Err(e) => { - release(io, pool, created, e.is_io()); + io.on_release(e.is_io()); return Err(e.into()); } }; @@ -125,12 +134,15 @@ where Ok((head, payload)) } -async fn send_body( - body: B, - mut send: SendStream, -) -> Result<(), SendRequestError> { +async fn send_body(body: B, mut send: SendStream) -> Result<(), SendRequestError> +where + B: MessageBody, + B::Error: Into, +{ let mut buf = None; - pin_mut!(body); + + actix_rt::pin!(body); + loop { if buf.is_none() { match poll_fn(|cx| body.as_mut().poll_next(cx)).await { @@ -138,10 +150,10 @@ async fn send_body( send.reserve_capacity(b.len()); buf = Some(b); } - Some(Err(e)) => return Err(e.into()), + Some(Err(err)) => return Err(SendRequestError::Body(err.into())), None => { - if let Err(e) = send.send_data(Bytes::new(), true) { - return Err(e.into()); + if let Err(err) = send.send_data(Bytes::new(), true) { + return Err(err.into()); } send.reserve_capacity(0); return Ok(()); @@ -158,43 +170,23 @@ async fn send_body( if let Err(e) = send.send_data(bytes, false) { return Err(e.into()); - } else { - if !b.is_empty() { - send.reserve_capacity(b.len()); - } else { - buf = None; - } - continue; } + if !b.is_empty() { + send.reserve_capacity(b.len()); + } else { + buf = None; + } + continue; } Some(Err(e)) => return Err(e.into()), } } } -/// release SendRequest object -fn release( - io: H2Connection, - pool: Option>, - created: time::Instant, - close: bool, -) { - if let Some(mut pool) = pool { - if close { - pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); - } else { - pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); - } - } -} - -pub(crate) fn handshake( +pub(crate) fn handshake( io: Io, config: &ConnectorConfig, -) -> impl Future, Connection), h2::Error>> -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +) -> impl Future, Connection), h2::Error>> { let mut builder = Builder::new(); builder .initial_window_size(config.stream_window_size) diff --git a/awc/src/client/mod.rs b/awc/src/client/mod.rs new file mode 100644 index 000000000..d5854d83e --- /dev/null +++ b/awc/src/client/mod.rs @@ -0,0 +1,203 @@ +//! HTTP client. + +use std::{convert::TryFrom, rc::Rc, time::Duration}; + +use actix_http::{error::HttpError, header::HeaderMap, Method, RequestHead, Uri}; +use actix_rt::net::TcpStream; +use actix_service::Service; +pub use actix_tls::connect::{ + ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, +}; + +use crate::{ws, BoxConnectorService, ClientBuilder, ClientRequest}; + +mod config; +mod connection; +mod connector; +mod error; +mod h1proto; +mod h2proto; +mod pool; + +pub use self::connection::{Connection, ConnectionIo}; +pub use self::connector::{Connector, ConnectorService}; +pub use self::error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; + +#[derive(Clone)] +pub struct Connect { + pub uri: Uri, + pub addr: Option, +} + +/// An asynchronous HTTP and WebSocket client. +/// +/// You should take care to create, at most, one `Client` per thread. Otherwise, expect higher CPU +/// and memory usage. +/// +/// # Examples +/// ``` +/// use awc::Client; +/// +/// #[actix_rt::main] +/// async fn main() { +/// let mut client = Client::default(); +/// +/// let res = client.get("http://www.rust-lang.org") +/// .insert_header(("User-Agent", "my-app/1.2")) +/// .send() +/// .await; +/// +/// println!("Response: {:?}", res); +/// } +/// ``` +#[derive(Clone)] +pub struct Client(pub(crate) ClientConfig); + +#[derive(Clone)] +pub(crate) struct ClientConfig { + pub(crate) connector: BoxConnectorService, + pub(crate) default_headers: Rc, + pub(crate) timeout: Option, +} + +impl Default for Client { + fn default() -> Self { + ClientBuilder::new().finish() + } +} + +impl Client { + /// Create new client instance with default settings. + pub fn new() -> Client { + Client::default() + } + + /// Create `Client` builder. + /// This function is equivalent of `ClientBuilder::new()`. + pub fn builder() -> ClientBuilder< + impl Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, + > { + ClientBuilder::new() + } + + /// Construct HTTP request. + pub fn request(&self, method: Method, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + let mut req = ClientRequest::new(method, url, self.0.clone()); + + for header in self.0.default_headers.iter() { + // header map is empty + // TODO: probably append instead + req = req.insert_header_if_none(header); + } + req + } + + /// Create `ClientRequest` from `RequestHead` + /// + /// It is useful for proxy requests. This implementation + /// copies all headers and the method. + pub fn request_from(&self, url: U, head: &RequestHead) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + let mut req = self.request(head.method.clone(), url); + for header in head.headers.iter() { + req = req.insert_header_if_none(header); + } + req + } + + /// Construct HTTP *GET* request. + pub fn get(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::GET, url) + } + + /// Construct HTTP *HEAD* request. + pub fn head(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::HEAD, url) + } + + /// Construct HTTP *PUT* request. + pub fn put(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::PUT, url) + } + + /// Construct HTTP *POST* request. + pub fn post(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::POST, url) + } + + /// Construct HTTP *PATCH* request. + pub fn patch(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::PATCH, url) + } + + /// Construct HTTP *DELETE* request. + pub fn delete(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::DELETE, url) + } + + /// Construct HTTP *OPTIONS* request. + pub fn options(&self, url: U) -> ClientRequest + where + Uri: TryFrom, + >::Error: Into, + { + self.request(Method::OPTIONS, url) + } + + /// Initialize a WebSocket connection. + /// Returns a WebSocket connection builder. + pub fn ws(&self, url: U) -> ws::WebsocketsRequest + where + Uri: TryFrom, + >::Error: Into, + { + let mut req = ws::WebsocketsRequest::new(url, self.0.clone()); + for (key, value) in self.0.default_headers.iter() { + req.head.headers.insert(key.clone(), value.clone()); + } + req + } + + /// Get default HeaderMap of Client. + /// + /// Returns Some(&mut HeaderMap) when Client object is unique + /// (No other clone of client exists at the same time). + pub fn headers(&mut self) -> Option<&mut HeaderMap> { + Rc::get_mut(&mut self.0.default_headers) + } +} diff --git a/awc/src/client/pool.rs b/awc/src/client/pool.rs new file mode 100644 index 000000000..9d130412b --- /dev/null +++ b/awc/src/client/pool.rs @@ -0,0 +1,658 @@ +//! Client connection pooling keyed on the authority part of the connection URI. + +use std::{ + cell::RefCell, + collections::VecDeque, + future::Future, + io, + ops::Deref, + pin::Pin, + rc::Rc, + sync::Arc, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use actix_codec::{AsyncRead, AsyncWrite, ReadBuf}; +use actix_http::Protocol; +use actix_rt::time::{sleep, Sleep}; +use actix_service::Service; +use ahash::AHashMap; +use futures_core::future::LocalBoxFuture; +use futures_util::FutureExt; +use http::uri::Authority; +use pin_project_lite::pin_project; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +use super::config::ConnectorConfig; +use super::connection::{ConnectionInnerType, ConnectionIo, ConnectionType, H2ConnectionInner}; +use super::error::ConnectError; +use super::h2proto::handshake; +use super::Connect; + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +pub struct Key { + authority: Authority, +} + +impl From for Key { + fn from(authority: Authority) -> Key { + Key { authority } + } +} + +#[doc(hidden)] +/// Connections pool for reuse Io type for certain [`http::uri::Authority`] as key. +pub struct ConnectionPool +where + Io: AsyncWrite + Unpin + 'static, +{ + connector: S, + inner: ConnectionPoolInner, +} + +/// wrapper type for check the ref count of Rc. +pub struct ConnectionPoolInner(Rc>) +where + Io: AsyncWrite + Unpin + 'static; + +impl ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + fn new(config: ConnectorConfig) -> Self { + let permits = Arc::new(Semaphore::new(config.limit)); + let available = RefCell::new(AHashMap::default()); + + Self(Rc::new(ConnectionPoolInnerPriv { + config, + available, + permits, + })) + } + + /// spawn a async for graceful shutdown h1 Io type with a timeout. + fn close(&self, conn: ConnectionInnerType) { + if let Some(timeout) = self.config.disconnect_timeout { + if let ConnectionInnerType::H1(io) = conn { + actix_rt::spawn(CloseConnection::new(io, timeout)); + } + } + } +} + +impl Clone for ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + fn clone(&self) -> Self { + Self(Rc::clone(&self.0)) + } +} + +impl Deref for ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + type Target = ConnectionPoolInnerPriv; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl Drop for ConnectionPoolInner +where + Io: AsyncWrite + Unpin + 'static, +{ + fn drop(&mut self) { + // When strong count is one it means the pool is dropped + // remove and drop all Io types. + if Rc::strong_count(&self.0) == 1 { + self.permits.close(); + std::mem::take(&mut *self.available.borrow_mut()) + .into_iter() + .for_each(|(_, conns)| { + conns.into_iter().for_each(|pooled| self.close(pooled.conn)) + }); + } + } +} + +pub struct ConnectionPoolInnerPriv +where + Io: AsyncWrite + Unpin + 'static, +{ + config: ConnectorConfig, + available: RefCell>>>, + permits: Arc, +} + +impl ConnectionPool +where + Io: AsyncWrite + Unpin + 'static, +{ + /// Construct a new connection pool. + /// + /// [`super::config::ConnectorConfig`]'s `limit` is used as the max permits allowed for + /// in-flight connections. + /// + /// The pool can only have equal to `limit` amount of requests spawning/using Io type + /// concurrently. + /// + /// Any requests beyond limit would be wait in fifo order and get notified in async manner + /// by [`tokio::sync::Semaphore`] + pub(crate) fn new(connector: S, config: ConnectorConfig) -> Self { + let inner = ConnectionPoolInner::new(config); + + Self { connector, inner } + } +} + +impl Service for ConnectionPool +where + S: Service + Clone + 'static, + Io: ConnectionIo, +{ + type Response = ConnectionType; + type Error = ConnectError; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::forward_ready!(connector); + + fn call(&self, req: Connect) -> Self::Future { + let connector = self.connector.clone(); + let inner = self.inner.clone(); + + Box::pin(async move { + let key = if let Some(authority) = req.uri.authority() { + authority.clone().into() + } else { + return Err(ConnectError::Unresolved); + }; + + // acquire an owned permit and carry it with connection + let permit = inner.permits.clone().acquire_owned().await.map_err(|_| { + ConnectError::Io(io::Error::new( + io::ErrorKind::Other, + "failed to acquire semaphore on client connection pool", + )) + })?; + + let conn = { + let mut conn = None; + + // check if there is idle connection for given key. + let mut map = inner.available.borrow_mut(); + + if let Some(conns) = map.get_mut(&key) { + let now = Instant::now(); + + while let Some(mut c) = conns.pop_front() { + let config = &inner.config; + let idle_dur = now - c.used; + let age = now - c.created; + let conn_ineligible = + idle_dur > config.conn_keep_alive || age > config.conn_lifetime; + + if conn_ineligible { + // drop connections that are too old + inner.close(c.conn); + } else { + // check if the connection is still usable + if let ConnectionInnerType::H1(ref mut io) = c.conn { + let check = ConnectionCheckFuture { io }; + match check.now_or_never().expect("ConnectionCheckFuture must never yield with Poll::Pending.") { + ConnectionState::Tainted => { + inner.close(c.conn); + continue; + } + ConnectionState::Skip => continue, + ConnectionState::Live => conn = Some(c), + } + } else { + conn = Some(c); + } + + break; + } + } + }; + + conn + }; + + // construct acquired. It's used to put Io type back to pool/ close the Io type. + // permit is carried with the whole lifecycle of Acquired. + let acquired = Acquired { key, inner, permit }; + + // match the connection and spawn new one if did not get anything. + match conn { + Some(conn) => Ok(ConnectionType::from_pool(conn.conn, conn.created, acquired)), + None => { + let (io, proto) = connector.call(req).await?; + + // TODO: remove when http3 is added in support. + assert!(proto != Protocol::Http3); + + if proto == Protocol::Http1 { + Ok(ConnectionType::from_h1(io, Instant::now(), acquired)) + } else { + let config = &acquired.inner.config; + let (sender, connection) = handshake(io, config).await?; + let inner = H2ConnectionInner::new(sender, connection); + Ok(ConnectionType::from_h2(inner, Instant::now(), acquired)) + } + } + } + }) + } +} + +/// Type for check the connection and determine if it's usable. +struct ConnectionCheckFuture<'a, Io> { + io: &'a mut Io, +} + +enum ConnectionState { + /// IO is pending and a new request would wake it. + Live, + + /// IO unexpectedly has unread data and should be dropped. + Tainted, + + /// IO should be skipped but not dropped. + Skip, +} + +impl Future for ConnectionCheckFuture<'_, Io> +where + Io: AsyncRead + Unpin, +{ + type Output = ConnectionState; + + // this future is only used to get access to Context. + // It should never return Poll::Pending. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + let mut buf = [0; 2]; + let mut read_buf = ReadBuf::new(&mut buf); + + let state = match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) if !read_buf.filled().is_empty() => ConnectionState::Tainted, + + Poll::Pending => ConnectionState::Live, + _ => ConnectionState::Skip, + }; + + Poll::Ready(state) + } +} + +struct PooledConnection { + conn: ConnectionInnerType, + used: Instant, + created: Instant, +} + +pin_project! { + #[project = CloseConnectionProj] + struct CloseConnection { + io: Io, + #[pin] + timeout: Sleep, + } +} + +impl CloseConnection +where + Io: AsyncWrite + Unpin, +{ + fn new(io: Io, timeout: Duration) -> Self { + CloseConnection { + io, + timeout: sleep(timeout), + } + } +} + +impl Future for CloseConnection +where + Io: AsyncWrite + Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); + + match this.timeout.poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Pin::new(this.io).poll_shutdown(cx).map(|_| ()), + } + } +} + +pub struct Acquired +where + Io: AsyncWrite + Unpin + 'static, +{ + /// authority key for identify connection. + key: Key, + /// handle to connection pool. + inner: ConnectionPoolInner, + /// permit for limit concurrent in-flight connection for a Client object. + permit: OwnedSemaphorePermit, +} + +impl Acquired { + /// Close the IO. + pub(super) fn close(&self, conn: ConnectionInnerType) { + self.inner.close(conn); + } + + /// Release IO back into pool. + pub(super) fn release(&self, conn: ConnectionInnerType, created: Instant) { + let Acquired { key, inner, .. } = self; + + inner + .available + .borrow_mut() + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(PooledConnection { + conn, + created, + used: Instant::now(), + }); + + let _ = &self.permit; + } +} + +#[cfg(test)] +mod test { + use std::{cell::Cell, io}; + + use http::Uri; + + use super::*; + use crate::client::connection::ConnectionType; + + /// A stream type that always returns pending on async read. + /// + /// Mocks an idle TCP stream that is ready to be used for client connections. + struct TestStream(Rc>); + + impl Drop for TestStream { + fn drop(&mut self) { + self.0.set(self.0.get() - 1); + } + } + + impl AsyncRead for TestStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } + } + + impl AsyncWrite for TestStream { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &[u8], + ) -> Poll> { + unimplemented!() + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + unimplemented!() + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Clone)] + struct TestPoolConnector { + generated: Rc>, + } + + impl Service for TestPoolConnector { + type Response = (TestStream, Protocol); + type Error = ConnectError; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::always_ready!(); + + fn call(&self, _: Connect) -> Self::Future { + self.generated.set(self.generated.get() + 1); + let generated = self.generated.clone(); + Box::pin(async { Ok((TestStream(generated), Protocol::Http1)) }) + } + } + + fn release(conn: ConnectionType) + where + T: AsyncRead + AsyncWrite + Unpin + 'static, + { + match conn { + ConnectionType::H1(mut conn) => conn.on_release(true), + ConnectionType::H2(mut conn) => conn.on_release(false), + } + } + + #[actix_rt::test] + async fn test_pool_limit() { + let connector = TestPoolConnector { + generated: Rc::new(Cell::new(0)), + }; + + let config = ConnectorConfig { + limit: 1, + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + + let waiting = Rc::new(Cell::new(true)); + + let waiting_clone = waiting.clone(); + actix_rt::spawn(async move { + actix_rt::time::sleep(Duration::from_millis(100)).await; + waiting_clone.set(false); + drop(conn); + }); + + assert!(waiting.get()); + + let now = Instant::now(); + let conn = pool.call(req).await.unwrap(); + + release(conn); + assert!(!waiting.get()); + assert!(now.elapsed() >= Duration::from_millis(100)); + } + + #[actix_rt::test] + async fn test_pool_keep_alive() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig { + conn_keep_alive: Duration::from_secs(1), + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + actix_rt::time::sleep(Duration::from_millis(1500)).await; + actix_rt::task::yield_now().await; + + let conn = pool.call(req).await.unwrap(); + // Note: spawned recycle connection is not ran yet. + // This is tokio current thread runtime specific behavior. + assert_eq!(2, generated_clone.get()); + + // yield task so the old connection is properly dropped. + actix_rt::task::yield_now().await; + assert_eq!(1, generated_clone.get()); + + release(conn); + } + + #[actix_rt::test] + async fn test_pool_lifetime() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig { + conn_lifetime: Duration::from_secs(1), + ..Default::default() + }; + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("http://localhost"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + actix_rt::time::sleep(Duration::from_millis(1500)).await; + actix_rt::task::yield_now().await; + + let conn = pool.call(req).await.unwrap(); + // Note: spawned recycle connection is not ran yet. + // This is tokio current thread runtime specific behavior. + assert_eq!(2, generated_clone.get()); + + // yield task so the old connection is properly dropped. + actix_rt::task::yield_now().await; + assert_eq!(1, generated_clone.get()); + + release(conn); + } + + #[actix_rt::test] + async fn test_pool_authority_key() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig::default(); + + let pool = super::ConnectionPool::new(connector, config); + + let req = Connect { + uri: Uri::from_static("https://crates.io"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let conn = pool.call(req).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let req = Connect { + uri: Uri::from_static("https://google.com"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + let conn = pool.call(req).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + } + + #[actix_rt::test] + async fn test_pool_drop() { + let generated = Rc::new(Cell::new(0)); + let generated_clone = generated.clone(); + + let connector = TestPoolConnector { generated }; + + let config = ConnectorConfig::default(); + + let pool = Rc::new(super::ConnectionPool::new(connector, config)); + + let req = Connect { + uri: Uri::from_static("https://crates.io"), + addr: None, + }; + + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(1, generated_clone.get()); + release(conn); + + let req = Connect { + uri: Uri::from_static("https://google.com"), + addr: None, + }; + let conn = pool.call(req.clone()).await.unwrap(); + assert_eq!(2, generated_clone.get()); + release(conn); + + let clone1 = pool.clone(); + let clone2 = clone1.clone(); + + drop(clone2); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(2, generated_clone.get()); + + drop(clone1); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(2, generated_clone.get()); + + drop(pool); + for _ in 0..2 { + actix_rt::task::yield_now().await; + } + assert_eq!(0, generated_clone.get()); + } +} diff --git a/awc/src/connect.rs b/awc/src/connect.rs index a9b8f9f83..f93014a67 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,147 +1,170 @@ use std::{ - fmt, io, net, + future::Future, + net, pin::Pin, + rc::Rc, task::{Context, Poll}, }; -use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; -use actix_http::{ - body::Body, - client::{Connect as ClientConnect, ConnectError, Connection, SendRequestError}, - h1::ClientCodec, - RequestHead, RequestHeadType, ResponseHead, -}; +use actix_codec::Framed; +use actix_http::{h1::ClientCodec, Payload, RequestHead, RequestHeadType, ResponseHead}; use actix_service::Service; -use futures_core::future::LocalBoxFuture; +use futures_core::{future::LocalBoxFuture, ready}; -use crate::response::ClientResponse; +use crate::{ + any_body::AnyBody, + client::{ + Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, + }, + ClientResponse, +}; -pub(crate) struct ConnectorWrapper(pub T); +pub type BoxConnectorService = Rc< + dyn Service< + ConnectRequest, + Response = ConnectResponse, + Error = SendRequestError, + Future = LocalBoxFuture<'static, Result>, + >, +>; -type TunnelResponse = (ResponseHead, Framed); +pub type BoxedSocket = Box; -pub(crate) trait Connect { - fn send_request( - &self, - head: RequestHeadType, - body: Body, - addr: Option, - ) -> LocalBoxFuture<'static, Result>; - - /// Send request, returns Response and Framed - fn open_tunnel( - &self, - head: RequestHead, - addr: Option, - ) -> LocalBoxFuture<'static, Result>; +pub enum ConnectRequest { + Client(RequestHeadType, AnyBody, Option), + Tunnel(RequestHead, Option), } -impl Connect for ConnectorWrapper +pub enum ConnectResponse { + Client(ClientResponse), + Tunnel(ResponseHead, Framed), +} + +impl ConnectResponse { + pub fn into_client_response(self) -> ClientResponse { + match self { + ConnectResponse::Client(res) => res, + _ => panic!( + "ClientResponse only reachable with ConnectResponse::ClientResponse variant" + ), + } + } + + pub fn into_tunnel_response(self) -> (ResponseHead, Framed) { + match self { + ConnectResponse::Tunnel(head, framed) => (head, framed), + _ => panic!( + "TunnelResponse only reachable with ConnectResponse::TunnelResponse variant" + ), + } + } +} + +pub struct DefaultConnector { + connector: S, +} + +impl DefaultConnector { + pub(crate) fn new(connector: S) -> Self { + Self { connector } + } +} + +impl Service for DefaultConnector where - T: Service, - T::Response: Connection, - ::Io: 'static, - ::Future: 'static, - ::TunnelFuture: 'static, - T::Future: 'static, + S: Service>, + Io: ConnectionIo, { - fn send_request( - &self, - head: RequestHeadType, - body: Body, - addr: Option, - ) -> LocalBoxFuture<'static, Result> { + type Response = ConnectResponse; + type Error = SendRequestError; + type Future = ConnectRequestFuture; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.as_ref().uri.clone(), - addr, - }); + let fut = match req { + ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { + uri: head.as_ref().uri.clone(), + addr, + }), + ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { + uri: head.uri.clone(), + addr, + }), + }; - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, payload) = connection.send_request(head, body).await?; - - Ok(ClientResponse::new(head, payload)) - }) - } - - fn open_tunnel( - &self, - head: RequestHead, - addr: Option, - ) -> LocalBoxFuture<'static, Result> { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); - - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, framed) = connection.open_tunnel(RequestHeadType::from(head)).await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok((head, framed)) - }) + ConnectRequestFuture::Connection { + fut, + req: Some(req), + } } } -trait AsyncSocket { - fn as_read(&self) -> &(dyn AsyncRead + Unpin); - fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); - fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); -} - -struct Socket(T); - -impl AsyncSocket for Socket { - fn as_read(&self) -> &(dyn AsyncRead + Unpin) { - &self.0 - } - fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { - &mut self.0 - } - fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { - &mut self.0 +pin_project_lite::pin_project! { + #[project = ConnectRequestProj] + pub enum ConnectRequestFuture + where + Io: ConnectionIo + { + Connection { + #[pin] + fut: Fut, + req: Option + }, + Client { + fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + }, + Tunnel { + fut: LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed, ClientCodec>), SendRequestError>, + >, + } } } -pub struct BoxedSocket(Box); +impl Future for ConnectRequestFuture +where + Fut: Future, ConnectError>>, + Io: ConnectionIo, +{ + type Output = Result; -impl fmt::Debug for BoxedSocket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BoxedSocket") - } -} - -impl AsyncRead for BoxedSocket { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) - } -} - -impl AsyncWrite for BoxedSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_flush(cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + ConnectRequestProj::Connection { fut, req } => { + let connection = ready!(fut.poll(cx))?; + let req = req.take().unwrap(); + match req { + ConnectRequest::Client(head, body, ..) => { + // send request + let fut = ConnectRequestFuture::Client { + fut: connection.send_request(head, body), + }; + self.set(fut); + } + ConnectRequest::Tunnel(head, ..) => { + // send request + let fut = ConnectRequestFuture::Tunnel { + fut: connection.open_tunnel(RequestHeadType::from(head)), + }; + self.set(fut); + } + } + self.poll(cx) + } + ConnectRequestProj::Client { fut } => { + let (head, payload) = ready!(fut.as_mut().poll(cx))?; + Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new( + head, payload, + )))) + } + ConnectRequestProj::Tunnel { fut } => { + let (head, framed) = ready!(fut.as_mut().poll(cx))?; + let framed = framed.into_map_io(|io| Box::new(io) as _); + Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) + } + } } } diff --git a/awc/src/error.rs b/awc/src/error.rs index c60339f76..aa9dc4d99 100644 --- a/awc/src/error.rs +++ b/awc/src/error.rs @@ -1,15 +1,19 @@ -//! Http client errors -pub use actix_http::client::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; -pub use actix_http::error::PayloadError; -pub use actix_http::http::Error as HttpError; -pub use actix_http::ws::HandshakeError as WsHandshakeError; -pub use actix_http::ws::ProtocolError as WsProtocolError; +//! HTTP client errors -use actix_http::ResponseError; +// TODO: figure out how best to expose http::Error vs actix_http::Error +pub use actix_http::{ + error::{HttpError, PayloadError}, + header::HeaderValue, + ws::{HandshakeError as WsHandshakeError, ProtocolError as WsProtocolError}, + StatusCode, +}; + +use derive_more::{Display, From}; use serde_json::error::Error as JsonError; -use actix_http::http::{header::HeaderValue, StatusCode}; -use derive_more::{Display, From}; +pub use crate::client::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}; + +// TODO: address display, error, and from impls /// Websocket client error #[derive(Debug, Display, From)] @@ -17,24 +21,31 @@ pub enum WsClientError { /// Invalid response status #[display(fmt = "Invalid response status")] InvalidResponseStatus(StatusCode), + /// Invalid upgrade header #[display(fmt = "Invalid upgrade header")] InvalidUpgradeHeader, + /// Invalid connection header #[display(fmt = "Invalid connection header")] InvalidConnectionHeader(HeaderValue), - /// Missing CONNECTION header - #[display(fmt = "Missing CONNECTION header")] + + /// Missing Connection header + #[display(fmt = "Missing Connection header")] MissingConnectionHeader, - /// Missing SEC-WEBSOCKET-ACCEPT header - #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] + + /// Missing Sec-Websocket-Accept header + #[display(fmt = "Missing Sec-Websocket-Accept header")] MissingWebSocketAcceptHeader, + /// Invalid challenge response #[display(fmt = "Invalid challenge response")] - InvalidChallengeResponse(String, HeaderValue), + InvalidChallengeResponse([u8; 28], HeaderValue), + /// Protocol error #[display(fmt = "{}", _0)] Protocol(WsProtocolError), + /// Send request error #[display(fmt = "{}", _0)] SendRequest(SendRequestError), @@ -69,6 +80,3 @@ pub enum JsonPayloadError { } impl std::error::Error for JsonPayloadError {} - -/// Return `InternalServerError` for `JsonPayloadError` -impl ResponseError for JsonPayloadError {} diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index 46b4063a0..4023bd1c8 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -1,21 +1,24 @@ -use std::convert::TryFrom; -use std::net; -use std::rc::Rc; -use std::time::Duration; +use std::{net, rc::Rc, time::Duration}; use bytes::Bytes; use futures_core::Stream; use serde::Serialize; -use actix_http::body::Body; -use actix_http::http::header::IntoHeaderValue; -use actix_http::http::{Error as HttpError, HeaderMap, HeaderName, Method, Uri}; -use actix_http::{Error, RequestHead}; +use actix_http::{ + body::MessageBody, + error::HttpError, + header::{HeaderMap, TryIntoHeaderPair}, + Method, RequestHead, Uri, +}; -use crate::sender::{RequestSender, SendClientRequest}; -use crate::ClientConfig; +use crate::{ + client::ClientConfig, + sender::{RequestSender, SendClientRequest}, + BoxError, +}; -/// `FrozenClientRequest` struct represents clonable client request. +/// `FrozenClientRequest` struct represents cloneable client request. +/// /// It could be used to send same request multiple times. #[derive(Clone)] pub struct FrozenClientRequest { @@ -23,7 +26,7 @@ pub struct FrozenClientRequest { pub(crate) addr: Option, pub(crate) response_decompress: bool, pub(crate) timeout: Option, - pub(crate) config: Rc, + pub(crate) config: ClientConfig, } impl FrozenClientRequest { @@ -45,13 +48,13 @@ impl FrozenClientRequest { /// Send a body. pub fn send_body(&self, body: B) -> SendClientRequest where - B: Into, + B: MessageBody + 'static, { RequestSender::Rc(self.head.clone(), None).send_body( self.addr, self.response_decompress, self.timeout, - self.config.as_ref(), + &self.config, body, ) } @@ -62,7 +65,7 @@ impl FrozenClientRequest { self.addr, self.response_decompress, self.timeout, - self.config.as_ref(), + &self.config, value, ) } @@ -73,7 +76,7 @@ impl FrozenClientRequest { self.addr, self.response_decompress, self.timeout, - self.config.as_ref(), + &self.config, value, ) } @@ -81,14 +84,14 @@ impl FrozenClientRequest { /// Send a streaming body. pub fn send_stream(&self, stream: S) -> SendClientRequest where - S: Stream> + Unpin + 'static, - E: Into + 'static, + S: Stream> + 'static, + E: Into + 'static, { RequestSender::Rc(self.head.clone(), None).send_stream( self.addr, self.response_decompress, self.timeout, - self.config.as_ref(), + &self.config, stream, ) } @@ -99,24 +102,18 @@ impl FrozenClientRequest { self.addr, self.response_decompress, self.timeout, - self.config.as_ref(), + &self.config, ) } - /// Create a `FrozenSendBuilder` with extra headers + /// Clones this `FrozenClientRequest`, returning a new one with extra headers added. pub fn extra_headers(&self, extra_headers: HeaderMap) -> FrozenSendBuilder { FrozenSendBuilder::new(self.clone(), extra_headers) } - /// Create a `FrozenSendBuilder` with an extra header - pub fn extra_header(&self, key: K, value: V) -> FrozenSendBuilder - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - self.extra_headers(HeaderMap::new()) - .extra_header(key, value) + /// Clones this `FrozenClientRequest`, returning a new one with the extra header added. + pub fn extra_header(&self, header: impl TryIntoHeaderPair) -> FrozenSendBuilder { + self.extra_headers(HeaderMap::new()).extra_header(header) } } @@ -137,29 +134,20 @@ impl FrozenSendBuilder { } /// Insert a header, it overrides existing header in `FrozenClientRequest`. - pub fn extra_header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into_value() { - Ok(value) => { - self.extra_headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), + pub fn extra_header(mut self, header: impl TryIntoHeaderPair) -> Self { + match header.try_into_pair() { + Ok((key, value)) => { + self.extra_headers.insert(key, value); + } + + Err(err) => self.err = Some(err.into()), } + self } /// Complete request construction and send a body. - pub fn send_body(self, body: B) -> SendClientRequest - where - B: Into, - { + pub fn send_body(self, body: impl MessageBody + 'static) -> SendClientRequest { if let Some(e) = self.err { return e.into(); } @@ -168,28 +156,28 @@ impl FrozenSendBuilder { self.req.addr, self.req.response_decompress, self.req.timeout, - self.req.config.as_ref(), + &self.req.config, body, ) } /// Complete request construction and send a json body. - pub fn send_json(self, value: &T) -> SendClientRequest { - if let Some(e) = self.err { - return e.into(); + pub fn send_json(self, value: impl Serialize) -> SendClientRequest { + if let Some(err) = self.err { + return err.into(); } RequestSender::Rc(self.req.head, Some(self.extra_headers)).send_json( self.req.addr, self.req.response_decompress, self.req.timeout, - self.req.config.as_ref(), + &self.req.config, value, ) } /// Complete request construction and send an urlencoded body. - pub fn send_form(self, value: &T) -> SendClientRequest { + pub fn send_form(self, value: impl Serialize) -> SendClientRequest { if let Some(e) = self.err { return e.into(); } @@ -198,7 +186,7 @@ impl FrozenSendBuilder { self.req.addr, self.req.response_decompress, self.req.timeout, - self.req.config.as_ref(), + &self.req.config, value, ) } @@ -206,8 +194,8 @@ impl FrozenSendBuilder { /// Complete request construction and send a streaming body. pub fn send_stream(self, stream: S) -> SendClientRequest where - S: Stream> + Unpin + 'static, - E: Into + 'static, + S: Stream> + 'static, + E: Into + 'static, { if let Some(e) = self.err { return e.into(); @@ -217,7 +205,7 @@ impl FrozenSendBuilder { self.req.addr, self.req.response_decompress, self.req.timeout, - self.req.config.as_ref(), + &self.req.config, stream, ) } @@ -232,7 +220,7 @@ impl FrozenSendBuilder { self.req.addr, self.req.response_decompress, self.req.timeout, - self.req.config.as_ref(), + &self.req.config, ) } } diff --git a/awc/src/lib.rs b/awc/src/lib.rs index e50c19c8c..970ca2d92 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -1,7 +1,6 @@ //! `awc` is a HTTP and WebSocket client library built on the Actix ecosystem. //! -//! ## Making a GET request -//! +//! # Making a GET request //! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { @@ -16,10 +15,8 @@ //! # } //! ``` //! -//! ## Making POST requests -//! -//! ### Raw body contents -//! +//! # Making POST requests +//! ## Raw body contents //! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { @@ -31,8 +28,7 @@ //! # } //! ``` //! -//! ### Forms -//! +//! ## Forms //! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { @@ -46,8 +42,7 @@ //! # } //! ``` //! -//! ### JSON -//! +//! ## JSON //! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), awc::error::SendRequestError> { @@ -64,8 +59,24 @@ //! # } //! ``` //! -//! ## WebSocket support +//! # Response Compression +//! All [official][iana-encodings] and common content encoding codecs are supported, optionally. //! +//! The `Accept-Encoding` header will automatically be populated with enabled codecs and added to +//! outgoing requests, allowing servers to select their `Content-Encoding` accordingly. +//! +//! Feature flags enable these codecs according to the table below. By default, all `compress-*` +//! features are enabled. +//! +//! | Feature | Codecs | +//! | ----------------- | ------------- | +//! | `compress-brotli` | brotli | +//! | `compress-gzip` | gzip, deflate | +//! | `compress-zstd` | zstd | +//! +//! [iana-encodings]: https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding +//! +//! # WebSocket support //! ```no_run //! # #[actix_rt::main] //! # async fn main() -> Result<(), Box> { @@ -84,7 +95,8 @@ //! # } //! ``` -#![deny(rust_2018_idioms)] +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] #![allow( clippy::type_complexity, clippy::borrow_interior_mutable_const, @@ -93,190 +105,40 @@ #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] -use std::convert::TryFrom; -use std::rc::Rc; -use std::time::Duration; +pub use actix_http::body; #[cfg(feature = "cookies")] -pub use actix_http::cookie; -pub use actix_http::{client::Connector, http}; - -use actix_http::http::{Error as HttpError, HeaderMap, Method, Uri}; -use actix_http::RequestHead; +pub use cookie; +mod any_body; mod builder; +mod client; mod connect; pub mod error; mod frozen; +pub mod middleware; mod request; -mod response; +mod responses; mod sender; pub mod test; pub mod ws; +pub mod http { + //! Various HTTP related types. + + // TODO: figure out how best to expose http::Error vs actix_http::Error + pub use actix_http::{ + header, uri, ConnectionType, Error, Method, StatusCode, Uri, Version, + }; +} + pub use self::builder::ClientBuilder; -pub use self::connect::BoxedSocket; +pub use self::client::{Client, Connector}; +pub use self::connect::{BoxConnectorService, BoxedSocket, ConnectRequest, ConnectResponse}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; -pub use self::response::{ClientResponse, JsonBody, MessageBody}; +#[allow(deprecated)] +pub use self::responses::{ClientResponse, JsonBody, MessageBody, ResponseBody}; pub use self::sender::SendClientRequest; -use self::connect::{Connect, ConnectorWrapper}; - -/// An asynchronous HTTP and WebSocket client. -/// -/// ## Examples -/// -/// ```rust -/// use awc::Client; -/// -/// #[actix_rt::main] -/// async fn main() { -/// let mut client = Client::default(); -/// -/// let res = client.get("http://www.rust-lang.org") // <- Create request builder -/// .insert_header(("User-Agent", "Actix-web")) -/// .send() // <- Send HTTP request -/// .await; // <- send request and wait for response -/// -/// println!("Response: {:?}", res); -/// } -/// ``` -#[derive(Clone)] -pub struct Client(Rc); - -pub(crate) struct ClientConfig { - pub(crate) connector: Box, - pub(crate) headers: HeaderMap, - pub(crate) timeout: Option, -} - -impl Default for Client { - fn default() -> Self { - Client(Rc::new(ClientConfig { - connector: Box::new(ConnectorWrapper(Connector::new().finish())), - headers: HeaderMap::new(), - timeout: Some(Duration::from_secs(5)), - })) - } -} - -impl Client { - /// Create new client instance with default settings. - pub fn new() -> Client { - Client::default() - } - - /// Create `Client` builder. - /// This function is equivalent of `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { - ClientBuilder::new() - } - - /// Construct HTTP request. - pub fn request(&self, method: Method, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - let mut req = ClientRequest::new(method, url, self.0.clone()); - - for header in self.0.headers.iter() { - req = req.insert_header_if_none(header); - } - req - } - - /// Create `ClientRequest` from `RequestHead` - /// - /// It is useful for proxy requests. This implementation - /// copies all headers and the method. - pub fn request_from(&self, url: U, head: &RequestHead) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - let mut req = self.request(head.method.clone(), url); - for header in head.headers.iter() { - req = req.insert_header_if_none(header); - } - req - } - - /// Construct HTTP *GET* request. - pub fn get(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::GET, url) - } - - /// Construct HTTP *HEAD* request. - pub fn head(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::HEAD, url) - } - - /// Construct HTTP *PUT* request. - pub fn put(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::PUT, url) - } - - /// Construct HTTP *POST* request. - pub fn post(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::POST, url) - } - - /// Construct HTTP *PATCH* request. - pub fn patch(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::PATCH, url) - } - - /// Construct HTTP *DELETE* request. - pub fn delete(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::DELETE, url) - } - - /// Construct HTTP *OPTIONS* request. - pub fn options(&self, url: U) -> ClientRequest - where - Uri: TryFrom, - >::Error: Into, - { - self.request(Method::OPTIONS, url) - } - - /// Initialize a WebSocket connection. - /// Returns a WebSocket connection builder. - pub fn ws(&self, url: U) -> ws::WebsocketsRequest - where - Uri: TryFrom, - >::Error: Into, - { - let mut req = ws::WebsocketsRequest::new(url, self.0.clone()); - for (key, value) in self.0.headers.iter() { - req.head.headers.insert(key.clone(), value.clone()); - } - req - } -} +pub(crate) type BoxError = Box; diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs new file mode 100644 index 000000000..330e3b7fe --- /dev/null +++ b/awc/src/middleware/mod.rs @@ -0,0 +1,71 @@ +mod redirect; + +pub use self::redirect::Redirect; + +use std::marker::PhantomData; + +use actix_service::Service; + +/// Trait for transform a type to another one. +/// Both the input and output type should impl [actix_service::Service] trait. +pub trait Transform { + type Transform: Service; + + /// Creates and returns a new Transform component. + fn new_transform(self, service: S) -> Self::Transform; +} + +#[doc(hidden)] +/// Helper struct for constructing Nested types that would call `Transform::new_transform` +/// in a chain. +/// +/// The child field would be called first and the output `Service` type is +/// passed to parent as input type. +pub struct NestTransform +where + T1: Transform, + T2: Transform, +{ + child: T1, + parent: T2, + _service: PhantomData<(S, Req)>, +} + +impl NestTransform +where + T1: Transform, + T2: Transform, +{ + pub(crate) fn new(child: T1, parent: T2) -> Self { + NestTransform { + child, + parent, + _service: PhantomData, + } + } +} + +impl Transform for NestTransform +where + T1: Transform, + T2: Transform, +{ + type Transform = T2::Transform; + + fn new_transform(self, service: S) -> Self::Transform { + let service = self.child.new_transform(service); + self.parent.new_transform(service) + } +} + +/// Dummy impl for kick start `NestTransform` type in `ClientBuilder` type +impl Transform for () +where + S: Service, +{ + type Transform = S; + + fn new_transform(self, service: S) -> Self::Transform { + service + } +} diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs new file mode 100644 index 000000000..704d2d79d --- /dev/null +++ b/awc/src/middleware/redirect.rs @@ -0,0 +1,587 @@ +use std::{ + convert::TryFrom, + future::Future, + net::SocketAddr, + pin::Pin, + rc::Rc, + task::{Context, Poll}, +}; + +use actix_http::{header, Method, RequestHead, RequestHeadType, StatusCode, Uri}; +use actix_service::Service; +use bytes::Bytes; +use futures_core::ready; + +use super::Transform; +use crate::{ + any_body::AnyBody, + client::{InvalidUrl, SendRequestError}, + connect::{ConnectRequest, ConnectResponse}, + ClientResponse, +}; + +pub struct Redirect { + max_redirect_times: u8, +} + +impl Default for Redirect { + fn default() -> Self { + Self::new() + } +} + +impl Redirect { + pub fn new() -> Self { + Self { + max_redirect_times: 10, + } + } + + pub fn max_redirect_times(mut self, times: u8) -> Self { + self.max_redirect_times = times; + self + } +} + +impl Transform for Redirect +where + S: Service + 'static, +{ + type Transform = RedirectService; + + fn new_transform(self, service: S) -> Self::Transform { + RedirectService { + max_redirect_times: self.max_redirect_times, + connector: Rc::new(service), + } + } +} + +pub struct RedirectService { + max_redirect_times: u8, + connector: Rc, +} + +impl Service for RedirectService +where + S: Service + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = RedirectServiceFuture; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { + match req { + ConnectRequest::Tunnel(head, addr) => { + let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + RedirectServiceFuture::Tunnel { fut } + } + ConnectRequest::Client(head, body, addr) => { + let connector = self.connector.clone(); + let max_redirect_times = self.max_redirect_times; + + // backup the uri and method for reuse schema and authority. + let (uri, method, headers) = match head { + RequestHeadType::Owned(ref head) => { + (head.uri.clone(), head.method.clone(), head.headers.clone()) + } + RequestHeadType::Rc(ref head, ..) => { + (head.uri.clone(), head.method.clone(), head.headers.clone()) + } + }; + + let body_opt = match body { + AnyBody::Bytes { ref body } => Some(body.clone()), + _ => None, + }; + + let fut = connector.call(ConnectRequest::Client(head, body, addr)); + + RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(uri), + method: Some(method), + headers: Some(headers), + body: body_opt, + addr, + connector: Some(connector), + } + } + } + } +} + +pin_project_lite::pin_project! { + #[project = RedirectServiceProj] + pub enum RedirectServiceFuture + where + S: Service, + S: 'static + { + Tunnel { #[pin] fut: S::Future }, + Client { + #[pin] + fut: S::Future, + max_redirect_times: u8, + uri: Option, + method: Option, + headers: Option, + body: Option, + addr: Option, + connector: Option>, + } + } +} + +impl Future for RedirectServiceFuture +where + S: Service + 'static, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project() { + RedirectServiceProj::Tunnel { fut } => fut.poll(cx), + RedirectServiceProj::Client { + fut, + max_redirect_times, + uri, + method, + headers, + body, + addr, + connector, + } => match ready!(fut.poll(cx))? { + ConnectResponse::Client(res) => match res.head().status { + StatusCode::MOVED_PERMANENTLY + | StatusCode::FOUND + | StatusCode::SEE_OTHER + | StatusCode::TEMPORARY_REDIRECT + | StatusCode::PERMANENT_REDIRECT + if *max_redirect_times > 0 => + { + let is_redirect = res.head().status == StatusCode::TEMPORARY_REDIRECT + || res.head().status == StatusCode::PERMANENT_REDIRECT; + + let prev_uri = uri.take().unwrap(); + + // rebuild uri from the location header value. + let next_uri = build_next_uri(&res, &prev_uri)?; + + // take ownership of states that could be reused + let addr = addr.take(); + let connector = connector.take(); + + // reset method + let method = if is_redirect { + method.take().unwrap() + } else { + let method = method.take().unwrap(); + match method { + Method::GET | Method::HEAD => method, + _ => Method::GET, + } + }; + + let mut body = body.take(); + let body_new = if is_redirect { + // try to reuse body + match body { + Some(ref bytes) => AnyBody::Bytes { + body: bytes.clone(), + }, + // TODO: should this be AnyBody::Empty or AnyBody::None. + _ => AnyBody::empty(), + } + } else { + body = None; + // remove body + AnyBody::None + }; + + let mut headers = headers.take().unwrap(); + + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + + // use a new request head. + let mut head = RequestHead::default(); + head.uri = next_uri.clone(); + head.method = method.clone(); + head.headers = headers.clone(); + + let head = RequestHeadType::Owned(head); + + let mut max_redirect_times = *max_redirect_times; + max_redirect_times -= 1; + + let fut = connector + .as_ref() + .unwrap() + .call(ConnectRequest::Client(head, body_new, addr)); + + self.set(RedirectServiceFuture::Client { + fut, + max_redirect_times, + uri: Some(next_uri), + method: Some(method), + headers: Some(headers), + body, + addr, + connector, + }); + + self.poll(cx) + } + _ => Poll::Ready(Ok(ConnectResponse::Client(res))), + }, + _ => unreachable!("ConnectRequest::Tunnel is not handled by Redirect"), + }, + } + } +} + +fn build_next_uri(res: &ClientResponse, prev_uri: &Uri) -> Result { + let uri = res + .headers() + .get(header::LOCATION) + .map(|value| { + // try to parse the location to a full uri + let uri = Uri::try_from(value.as_bytes()) + .map_err(|e| SendRequestError::Url(InvalidUrl::HttpError(e.into())))?; + if uri.scheme().is_none() || uri.authority().is_none() { + let uri = Uri::builder() + .scheme(prev_uri.scheme().cloned().unwrap()) + .authority(prev_uri.authority().cloned().unwrap()) + .path_and_query(value.as_bytes()) + .build()?; + Ok::<_, SendRequestError>(uri) + } else { + Ok(uri) + } + }) + // TODO: this error type is wrong. + .ok_or(SendRequestError::Url(InvalidUrl::MissingScheme))??; + + Ok(uri) +} + +fn remove_sensitive_headers(headers: &mut header::HeaderMap, prev_uri: &Uri, next_uri: &Uri) { + if next_uri.host() != prev_uri.host() + || next_uri.port() != prev_uri.port() + || next_uri.scheme() != prev_uri.scheme() + { + headers.remove(header::COOKIE); + headers.remove(header::AUTHORIZATION); + headers.remove(header::PROXY_AUTHORIZATION); + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use actix_web::{web, App, Error, HttpRequest, HttpResponse}; + + use super::*; + use crate::{http::header::HeaderValue, ClientBuilder}; + + #[actix_rt::test] + async fn test_basic_redirect() { + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Redirect::new().max_redirect_times(10)) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 400); + } + + #[actix_rt::test] + async fn test_redirect_limit() { + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Redirect::new().max_redirect_times(1)) + .connector(crate::Connector::new()) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test")) + .finish(), + ) + }))) + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Found() + .append_header(("location", "/test2")) + .finish(), + ) + }))) + .service(web::resource("/test2").route(web::to(|| async { + Ok::<_, Error>(HttpResponse::BadRequest()) + }))) + }); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 302); + } + + #[actix_rt::test] + async fn test_redirect_status_kind_307_308() { + let srv = actix_test::start(|| { + async fn root() -> HttpResponse { + HttpResponse::TemporaryRedirect() + .append_header(("location", "/test")) + .finish() + } + + async fn test(req: HttpRequest, body: Bytes) -> HttpResponse { + if req.method() == Method::POST && !body.is_empty() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let res = srv.post("/").send_body("Hello").await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_status_kind_301_302_303() { + let srv = actix_test::start(|| { + async fn root() -> HttpResponse { + HttpResponse::Found() + .append_header(("location", "/test")) + .finish() + } + + async fn test(req: HttpRequest, body: Bytes) -> HttpResponse { + if (req.method() == Method::GET || req.method() == Method::HEAD) + && body.is_empty() + { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let res = srv.post("/").send_body("Hello").await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + let res = srv.post("/").send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_headers() { + let srv = actix_test::start(|| { + async fn root(req: HttpRequest) -> HttpResponse { + if req + .headers() + .get("custom") + .unwrap_or(&HeaderValue::from_str("").unwrap()) + == "value" + { + HttpResponse::Found() + .append_header(("location", "/test")) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test(req: HttpRequest) -> HttpResponse { + if req + .headers() + .get("custom") + .unwrap_or(&HeaderValue::from_str("").unwrap()) + == "value" + { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test").route(web::to(test))) + }); + + let client = ClientBuilder::new() + .add_default_header(("custom", "value")) + .disable_redirects() + .finish(); + let res = client.get(srv.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 302); + + let client = ClientBuilder::new() + .add_default_header(("custom", "value")) + .finish(); + let res = client.get(srv.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + let client = ClientBuilder::new().finish(); + let res = client + .get(srv.url("/")) + .insert_header(("custom", "value")) + .send() + .await + .unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_redirect_cross_origin_headers() { + // defining two services to have two different origins + let srv2 = actix_test::start(|| { + async fn root(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_none() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new().service(web::resource("/").route(web::to(root))) + }); + let srv2_port: u16 = srv2.addr().port(); + + let srv1 = actix_test::start(move || { + async fn root(req: HttpRequest) -> HttpResponse { + let port = *req.app_data::().unwrap(); + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Found() + .append_header(( + "location", + format!("http://localhost:{}/", port).as_str(), + )) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test1(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Found() + .append_header(("location", "/test2")) + .finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + async fn test2(req: HttpRequest) -> HttpResponse { + if req.headers().get(header::AUTHORIZATION).is_some() { + HttpResponse::Ok().finish() + } else { + HttpResponse::InternalServerError().finish() + } + } + + App::new() + .app_data(srv2_port) + .service(web::resource("/").route(web::to(root))) + .service(web::resource("/test1").route(web::to(test1))) + .service(web::resource("/test2").route(web::to(test2))) + }); + + // send a request to different origins, http://srv1/ then http://srv2/. So it should remove the header + let client = ClientBuilder::new() + .add_default_header((header::AUTHORIZATION, "auth_key_value")) + .finish(); + let res = client.get(srv1.url("/")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + + // send a request to same origin, http://srv1/test1 then http://srv1/test2. So it should NOT remove any header + let res = client.get(srv1.url("/test1")).send().await.unwrap(); + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_remove_sensitive_headers() { + fn gen_headers() -> header::HeaderMap { + let mut headers = header::HeaderMap::new(); + headers.insert(header::USER_AGENT, HeaderValue::from_str("value").unwrap()); + headers.insert( + header::AUTHORIZATION, + HeaderValue::from_str("value").unwrap(), + ); + headers.insert( + header::PROXY_AUTHORIZATION, + HeaderValue::from_str("value").unwrap(), + ); + headers.insert(header::COOKIE, HeaderValue::from_str("value").unwrap()); + headers + } + + // Same origin + let prev_uri = Uri::from_str("https://host/path1").unwrap(); + let next_uri = Uri::from_str("https://host/path2").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 4); + + // different schema + let prev_uri = Uri::from_str("http://host/").unwrap(); + let next_uri = Uri::from_str("https://host/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different host + let prev_uri = Uri::from_str("https://host1/").unwrap(); + let next_uri = Uri::from_str("https://host2/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different port + let prev_uri = Uri::from_str("https://host:12/").unwrap(); + let next_uri = Uri::from_str("https://host:23/").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + + // different everything! + let prev_uri = Uri::from_str("http://host1:12/path1").unwrap(); + let next_uri = Uri::from_str("https://host2:23/path2").unwrap(); + let mut headers = gen_headers(); + remove_sensitive_headers(&mut headers, &prev_uri, &next_uri); + assert_eq!(headers.len(), 1); + } +} diff --git a/awc/src/request.rs b/awc/src/request.rs index 812c0e805..8bcf1ee01 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -1,55 +1,46 @@ -use std::convert::TryFrom; -use std::rc::Rc; -use std::time::Duration; -use std::{fmt, net}; +use std::{convert::TryFrom, fmt, net, rc::Rc, time::Duration}; use bytes::Bytes; use futures_core::Stream; use serde::Serialize; -use actix_http::body::Body; -#[cfg(feature = "cookies")] -use actix_http::cookie::{Cookie, CookieJar}; -use actix_http::http::header::{self, IntoHeaderPair}; -use actix_http::http::{ - uri, ConnectionType, Error as HttpError, HeaderMap, HeaderValue, Method, Uri, Version, +use actix_http::{ + body::MessageBody, + error::HttpError, + header::{self, HeaderMap, HeaderValue, TryIntoHeaderPair}, + ConnectionType, Method, RequestHead, Uri, Version, }; -use actix_http::{Error, RequestHead}; -use crate::error::{FreezeRequestError, InvalidUrl}; -use crate::frozen::FrozenClientRequest; -use crate::sender::{PrepForSendingError, RequestSender, SendClientRequest}; -use crate::ClientConfig; +use crate::{ + client::ClientConfig, + error::{FreezeRequestError, InvalidUrl}, + frozen::FrozenClientRequest, + sender::{PrepForSendingError, RequestSender, SendClientRequest}, + BoxError, +}; -cfg_if::cfg_if! { - if #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] { - const HTTPS_ENCODING: &str = "br, gzip, deflate"; - } else if #[cfg(feature = "compress")] { - const HTTPS_ENCODING: &str = "br"; - } else { - const HTTPS_ENCODING: &str = "identity"; - } -} +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, CookieJar}; /// An HTTP Client request builder /// /// This type can be used to construct an instance of `ClientRequest` through a /// builder-like pattern. /// -/// ```rust -/// #[actix_rt::main] -/// async fn main() { -/// let response = awc::Client::new() -/// .get("http://www.rust-lang.org") // <- Create request builder -/// .insert_header(("User-Agent", "Actix-web")) -/// .send() // <- Send HTTP request -/// .await; +/// ```no_run +/// # #[actix_rt::main] +/// # async fn main() { +/// let response = awc::Client::new() +/// .get("http://www.rust-lang.org") // <- Create request builder +/// .insert_header(("User-Agent", "Actix-web")) +/// .send() // <- Send HTTP request +/// .await; /// -/// response.and_then(|response| { // <- server HTTP response -/// println!("Response: {:?}", response); -/// Ok(()) -/// }); -/// } +/// response.and_then(|response| { // <- server HTTP response +/// println!("Response: {:?}", response); +/// Ok(()) +/// }); +/// # } /// ``` pub struct ClientRequest { pub(crate) head: RequestHead, @@ -57,7 +48,7 @@ pub struct ClientRequest { addr: Option, response_decompress: bool, timeout: Option, - config: Rc, + config: ClientConfig, #[cfg(feature = "cookies")] cookies: Option, @@ -65,7 +56,7 @@ pub struct ClientRequest { impl ClientRequest { /// Create new client request builder. - pub(crate) fn new(method: Method, uri: U, config: Rc) -> Self + pub(crate) fn new(method: Method, uri: U, config: ClientConfig) -> Self where Uri: TryFrom, >::Error: Into, @@ -124,10 +115,10 @@ impl ClientRequest { &self.head.method } - #[doc(hidden)] /// Set HTTP version of this request. /// /// By default requests's HTTP version depends on network stream + #[doc(hidden)] #[inline] pub fn version(mut self, version: Version) -> Self { self.head.version = version; @@ -157,11 +148,8 @@ impl ClientRequest { } /// Insert a header, replacing any that were set with an equivalent field name. - pub fn insert_header(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - match header.try_into_header_pair() { + pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self { + match header.try_into_pair() { Ok((key, value)) => { self.head.headers.insert(key, value); } @@ -172,11 +160,8 @@ impl ClientRequest { } /// Insert a header only if it is not yet set. - pub fn insert_header_if_none(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - match header.try_into_header_pair() { + pub fn insert_header_if_none(mut self, header: impl TryIntoHeaderPair) -> Self { + match header.try_into_pair() { Ok((key, value)) => { if !self.head.headers.contains_key(&key) { self.head.headers.insert(key, value); @@ -190,23 +175,16 @@ impl ClientRequest { /// Append a header, keeping any that were set with an equivalent field name. /// - /// ```rust - /// # #[actix_rt::main] - /// # async fn main() { - /// # use awc::Client; - /// use awc::http::header::ContentType; + /// ```no_run + /// use awc::{http::header, Client}; /// /// Client::new() /// .get("http://www.rust-lang.org") /// .insert_header(("X-TEST", "value")) - /// .insert_header(ContentType(mime::APPLICATION_JSON)); - /// # } + /// .insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); /// ``` - pub fn append_header(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - match header.try_into_header_pair() { + pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self { + match header.try_into_pair() { Ok((key, value)) => self.head.headers.append(key, value), Err(e) => self.err = Some(e.into()), }; @@ -248,51 +226,41 @@ impl ClientRequest { /// Set content length #[inline] pub fn content_length(self, len: u64) -> Self { - self.append_header((header::CONTENT_LENGTH, len)) + let mut buf = itoa::Buffer::new(); + self.insert_header((header::CONTENT_LENGTH, buf.format(len))) } - /// Set HTTP basic authorization header - pub fn basic_auth(self, username: U, password: Option<&str>) -> Self - where - U: fmt::Display, - { - let auth = match password { - Some(password) => format!("{}:{}", username, password), - None => format!("{}:", username), - }; - self.append_header(( + /// Set HTTP basic authorization header. + /// + /// If no password is needed, just provide an empty string. + pub fn basic_auth(self, username: impl fmt::Display, password: impl fmt::Display) -> Self { + let auth = format!("{}:{}", username, password); + + self.insert_header(( header::AUTHORIZATION, format!("Basic {}", base64::encode(&auth)), )) } /// Set HTTP bearer authentication header - pub fn bearer_auth(self, token: T) -> Self - where - T: fmt::Display, - { - self.append_header((header::AUTHORIZATION, format!("Bearer {}", token))) + pub fn bearer_auth(self, token: impl fmt::Display) -> Self { + self.insert_header((header::AUTHORIZATION, format!("Bearer {}", token))) } /// Set a cookie /// - /// ```rust - /// #[actix_rt::main] - /// async fn main() { - /// let resp = awc::Client::new().get("https://www.rust-lang.org") - /// .cookie( - /// awc::http::Cookie::build("name", "value") - /// .domain("www.rust-lang.org") - /// .path("/") - /// .secure(true) - /// .http_only(true) - /// .finish(), - /// ) - /// .send() - /// .await; + /// ```no_run + /// use awc::{cookie::Cookie, Client}; /// - /// println!("Response: {:?}", resp); - /// } + /// # #[actix_rt::main] + /// # async fn main() { + /// let res = Client::new().get("https://httpbin.org/cookies") + /// .cookie(Cookie::new("name", "value")) + /// .send() + /// .await; + /// + /// println!("Response: {:?}", res); + /// # } /// ``` #[cfg(feature = "cookies")] pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { @@ -321,34 +289,6 @@ impl ClientRequest { self } - /// This method calls provided closure with builder reference if value is `true`. - #[doc(hidden)] - #[deprecated = "Use an if statement."] - pub fn if_true(self, value: bool, f: F) -> Self - where - F: FnOnce(ClientRequest) -> ClientRequest, - { - if value { - f(self) - } else { - self - } - } - - /// This method calls provided closure with builder reference if value is `Some`. - #[doc(hidden)] - #[deprecated = "Use an if-let construction."] - pub fn if_some(self, value: Option, f: F) -> Self - where - F: FnOnce(T, ClientRequest) -> ClientRequest, - { - if let Some(val) = value { - f(val, self) - } else { - self - } - } - /// Sets the query part of the request pub fn query( mut self, @@ -392,7 +332,7 @@ impl ClientRequest { /// Complete request construction and send body. pub fn send_body(self, body: B) -> SendClientRequest where - B: Into, + B: MessageBody + 'static, { let slf = match self.prep_for_sending() { Ok(slf) => slf, @@ -403,7 +343,7 @@ impl ClientRequest { slf.addr, slf.response_decompress, slf.timeout, - slf.config.as_ref(), + &slf.config, body, ) } @@ -419,7 +359,7 @@ impl ClientRequest { slf.addr, slf.response_decompress, slf.timeout, - slf.config.as_ref(), + &slf.config, value, ) } @@ -437,7 +377,7 @@ impl ClientRequest { slf.addr, slf.response_decompress, slf.timeout, - slf.config.as_ref(), + &slf.config, value, ) } @@ -445,8 +385,8 @@ impl ClientRequest { /// Set an streaming body and generate `ClientRequest`. pub fn send_stream(self, stream: S) -> SendClientRequest where - S: Stream> + Unpin + 'static, - E: Into + 'static, + S: Stream> + 'static, + E: Into + 'static, { let slf = match self.prep_for_sending() { Ok(slf) => slf, @@ -457,7 +397,7 @@ impl ClientRequest { slf.addr, slf.response_decompress, slf.timeout, - slf.config.as_ref(), + &slf.config, stream, ) } @@ -473,7 +413,7 @@ impl ClientRequest { slf.addr, slf.response_decompress, slf.timeout, - slf.config.as_ref(), + &slf.config, ) } @@ -504,7 +444,7 @@ impl ClientRequest { let cookie: String = jar .delta() // ensure only name=value is written to cookie header - .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string()) + .map(|c| c.stripped().encoded().to_string()) .collect::>() .join("; "); @@ -517,22 +457,44 @@ impl ClientRequest { let mut slf = self; + // Set Accept-Encoding HTTP header depending on enabled feature. + // If decompress is not ask, then we are not able to find which encoding is + // supported, so we cannot guess Accept-Encoding HTTP header. if slf.response_decompress { - let https = slf - .head - .uri - .scheme() - .map(|s| s == &uri::Scheme::HTTPS) - .unwrap_or(true); + // Set Accept-Encoding with compression algorithm awc is built with. + #[allow(clippy::vec_init_then_push)] + #[cfg(feature = "__compress")] + let accept_encoding = { + let mut encoding = vec![]; - if https { - slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, HTTPS_ENCODING)) - } else { - #[cfg(any(feature = "flate2-zlib", feature = "flate2-rust"))] + #[cfg(feature = "compress-brotli")] { - slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, "gzip, deflate")) + encoding.push("br"); } + + #[cfg(feature = "compress-gzip")] + { + encoding.push("gzip"); + encoding.push("deflate"); + } + + #[cfg(feature = "compress-zstd")] + encoding.push("zstd"); + + assert!( + !encoding.is_empty(), + "encoding can not be empty unless __compress feature has been explicitly enabled" + ); + + encoding.join(", ") }; + + // Otherwise tell the server, we do not support any compression algorithm. + // So we clearly indicate that we do want identity encoding. + #[cfg(not(feature = "__compress"))] + let accept_encoding = "identity"; + + slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, accept_encoding)); } Ok(slf) @@ -558,6 +520,8 @@ impl fmt::Debug for ClientRequest { mod tests { use std::time::SystemTime; + use actix_http::header::HttpDate; + use super::*; use crate::Client; @@ -574,7 +538,7 @@ mod tests { let req = Client::new() .put("/") .version(Version::HTTP_2) - .insert_header(header::Date(SystemTime::now().into())) + .insert_header((header::DATE, HttpDate::from(SystemTime::now()))) .content_type("plain/text") .append_header((header::SERVER, "awc")); @@ -607,7 +571,7 @@ mod tests { #[actix_rt::test] async fn test_client_header() { let req = Client::builder() - .header(header::CONTENT_TYPE, "111") + .add_default_header((header::CONTENT_TYPE, "111")) .finish() .get("/"); @@ -625,7 +589,7 @@ mod tests { #[actix_rt::test] async fn test_client_header_override() { let req = Client::builder() - .header(header::CONTENT_TYPE, "111") + .add_default_header((header::CONTENT_TYPE, "111")) .finish() .get("/") .insert_header((header::CONTENT_TYPE, "222")); @@ -643,9 +607,7 @@ mod tests { #[actix_rt::test] async fn client_basic_auth() { - let req = Client::new() - .get("/") - .basic_auth("username", Some("password")); + let req = Client::new().get("/").basic_auth("username", "password"); assert_eq!( req.head .headers @@ -656,7 +618,7 @@ mod tests { "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" ); - let req = Client::new().get("/").basic_auth("username", None); + let req = Client::new().get("/").basic_auth("username", ""); assert_eq!( req.head .headers diff --git a/awc/src/response.rs b/awc/src/response.rs deleted file mode 100644 index cf687329d..000000000 --- a/awc/src/response.rs +++ /dev/null @@ -1,465 +0,0 @@ -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{ - cell::{Ref, RefMut}, - mem, -}; - -use bytes::{Bytes, BytesMut}; -use futures_core::{ready, Stream}; - -use actix_http::error::PayloadError; -use actix_http::http::header; -use actix_http::http::{HeaderMap, StatusCode, Version}; -use actix_http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; -use serde::de::DeserializeOwned; - -#[cfg(feature = "cookies")] -use actix_http::{cookie::Cookie, error::CookieParseError}; - -use crate::error::JsonPayloadError; - -/// Client Response -pub struct ClientResponse { - pub(crate) head: ResponseHead, - pub(crate) payload: Payload, -} - -impl HttpMessage for ClientResponse { - type Stream = S; - - fn headers(&self) -> &HeaderMap { - &self.head.headers - } - - fn extensions(&self) -> Ref<'_, Extensions> { - self.head.extensions() - } - - fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.head.extensions_mut() - } - - fn take_payload(&mut self) -> Payload { - mem::replace(&mut self.payload, Payload::None) - } - - /// Load request cookies. - #[cfg(feature = "cookies")] - fn cookies(&self) -> Result>>, CookieParseError> { - struct Cookies(Vec>); - - if self.extensions().get::().is_none() { - let mut cookies = Vec::new(); - for hdr in self.headers().get_all(&header::SET_COOKIE) { - let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; - cookies.push(Cookie::parse_encoded(s)?.into_owned()); - } - self.extensions_mut().insert(Cookies(cookies)); - } - Ok(Ref::map(self.extensions(), |ext| { - &ext.get::().unwrap().0 - })) - } -} - -impl ClientResponse { - /// Create new Request instance - pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { - ClientResponse { head, payload } - } - - #[inline] - pub(crate) fn head(&self) -> &ResponseHead { - &self.head - } - - /// Read the Request Version. - #[inline] - pub fn version(&self) -> Version { - self.head().version - } - - /// Get the status from the server. - #[inline] - pub fn status(&self) -> StatusCode { - self.head().status - } - - #[inline] - /// Returns request's headers. - pub fn headers(&self) -> &HeaderMap { - &self.head().headers - } - - /// Set a body and return previous body value - pub fn map_body(mut self, f: F) -> ClientResponse - where - F: FnOnce(&mut ResponseHead, Payload) -> Payload, - { - let payload = f(&mut self.head, self.payload); - - ClientResponse { - payload, - head: self.head, - } - } -} - -impl ClientResponse -where - S: Stream>, -{ - /// Loads HTTP response's body. - pub fn body(&mut self) -> MessageBody { - MessageBody::new(self) - } - - /// Loads and parse `application/json` encoded body. - /// Return `JsonBody` future. It resolves to a `T` value. - /// - /// Returns error: - /// - /// * content type is not `application/json` - /// * content length is greater than 256k - pub fn json(&mut self) -> JsonBody { - JsonBody::new(self) - } -} - -impl Stream for ClientResponse -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().payload).poll_next(cx) - } -} - -impl fmt::Debug for ClientResponse { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; - writeln!(f, " headers:")?; - for (key, val) in self.headers().iter() { - writeln!(f, " {:?}: {:?}", key, val)?; - } - Ok(()) - } -} - -/// Future that resolves to a complete HTTP message body. -pub struct MessageBody { - length: Option, - err: Option, - fut: Option>, -} - -impl MessageBody -where - S: Stream>, -{ - /// Create `MessageBody` for request. - pub fn new(res: &mut ClientResponse) -> MessageBody { - let mut len = None; - if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } else { - return Self::err(PayloadError::UnknownLength); - } - } else { - return Self::err(PayloadError::UnknownLength); - } - } - - MessageBody { - length: len, - err: None, - fut: Some(ReadBody::new(res.take_payload(), 262_144)), - } - } - - /// Change max size of payload. By default max size is 256kB - pub fn limit(mut self, limit: usize) -> Self { - if let Some(ref mut fut) = self.fut { - fut.limit = limit; - } - self - } - - fn err(e: PayloadError) -> Self { - MessageBody { - fut: None, - err: Some(e), - length: None, - } - } -} - -impl Future for MessageBody -where - S: Stream> + Unpin, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - if let Some(err) = this.err.take() { - return Poll::Ready(Err(err)); - } - - if let Some(len) = this.length.take() { - if len > this.fut.as_ref().unwrap().limit { - return Poll::Ready(Err(PayloadError::Overflow)); - } - } - - Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) - } -} - -/// Response's payload json parser, it resolves to a deserialized `T` value. -/// -/// Returns error: -/// -/// * content type is not `application/json` -/// * content length is greater than 64k -pub struct JsonBody { - length: Option, - err: Option, - fut: Option>, - _phantom: PhantomData, -} - -impl JsonBody -where - S: Stream>, - U: DeserializeOwned, -{ - /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { - // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { - mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - } else { - false - }; - if !json { - return JsonBody { - length: None, - fut: None, - err: Some(JsonPayloadError::ContentType), - _phantom: PhantomData, - }; - } - - let mut len = None; - if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } - } - } - - JsonBody { - length: len, - err: None, - fut: Some(ReadBody::new(req.take_payload(), 65536)), - _phantom: PhantomData, - } - } - - /// Change max size of payload. By default max size is 64kB - pub fn limit(mut self, limit: usize) -> Self { - if let Some(ref mut fut) = self.fut { - fut.limit = limit; - } - self - } -} - -impl Unpin for JsonBody -where - T: Stream> + Unpin, - U: DeserializeOwned, -{ -} - -impl Future for JsonBody -where - T: Stream> + Unpin, - U: DeserializeOwned, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(err) = self.err.take() { - return Poll::Ready(Err(err)); - } - - if let Some(len) = self.length.take() { - if len > self.fut.as_ref().unwrap().limit { - return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); - } - } - - let body = ready!(Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx))?; - Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) - } -} - -struct ReadBody { - stream: Payload, - buf: BytesMut, - limit: usize, -} - -impl ReadBody { - fn new(stream: Payload, limit: usize) -> Self { - Self { - stream, - buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)), - limit, - } - } -} - -impl Future for ReadBody -where - S: Stream> + Unpin, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - loop { - return match Pin::new(&mut this.stream).poll_next(cx)? { - Poll::Ready(Some(chunk)) => { - if (this.buf.len() + chunk.len()) > this.limit { - Poll::Ready(Err(PayloadError::Overflow)) - } else { - this.buf.extend_from_slice(&chunk); - continue; - } - } - Poll::Ready(None) => Poll::Ready(Ok(this.buf.split().freeze())), - Poll::Pending => Poll::Pending, - }; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde::{Deserialize, Serialize}; - - use crate::{http::header, test::TestResponse}; - - #[actix_rt::test] - async fn test_body() { - let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish(); - match req.body().await.err().unwrap() { - PayloadError::UnknownLength => {} - _ => unreachable!("error"), - } - - let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); - match req.body().await.err().unwrap() { - PayloadError::Overflow => {} - _ => unreachable!("error"), - } - - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"test")) - .finish(); - assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); - - let mut req = TestResponse::default() - .set_payload(Bytes::from_static(b"11111111111111")) - .finish(); - match req.body().limit(5).await.err().unwrap() { - PayloadError::Overflow => {} - _ => unreachable!("error"), - } - } - - #[derive(Serialize, Deserialize, PartialEq, Debug)] - struct MyObject { - name: String, - } - - fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { - match err { - JsonPayloadError::Payload(PayloadError::Overflow) => { - matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) - } - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) - } - _ => false, - } - } - - #[actix_rt::test] - async fn test_json_body() { - let mut req = TestResponse::default().finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/text"), - ) - .finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); - - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("10000"), - ) - .finish(); - - let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; - assert!(json_eq( - json.err().unwrap(), - JsonPayloadError::Payload(PayloadError::Overflow) - )); - - let mut req = TestResponse::default() - .header( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ) - .header( - header::CONTENT_LENGTH, - header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .finish(); - - let json = JsonBody::<_, MyObject>::new(&mut req).await; - assert_eq!( - json.ok().unwrap(), - MyObject { - name: "test".to_owned() - } - ); - } -} diff --git a/awc/src/responses/json_body.rs b/awc/src/responses/json_body.rs new file mode 100644 index 000000000..3912324b6 --- /dev/null +++ b/awc/src/responses/json_body.rs @@ -0,0 +1,192 @@ +use std::{ + future::Future, + marker::PhantomData, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, header, HttpMessage}; +use bytes::Bytes; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; + +use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT}; +use crate::{error::JsonPayloadError, ClientResponse}; + +pin_project! { + /// A `Future` that reads a body stream, parses JSON, resolving to a deserialized `T`. + /// + /// # Errors + /// `Future` implementation returns error if: + /// - content type is not `application/json`; + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). + pub struct JsonBody { + #[pin] + body: Option>, + length: Option, + timeout: ResponseTimeout, + err: Option, + _phantom: PhantomData, + } +} + +impl JsonBody +where + S: Stream>, + T: DeserializeOwned, +{ + /// Creates a JSON body stream reader from a response by taking its payload. + pub fn new(res: &mut ClientResponse) -> Self { + // check content-type + let json = if let Ok(Some(mime)) = res.mime_type() { + mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) + } else { + false + }; + + if !json { + return JsonBody { + length: None, + body: None, + timeout: ResponseTimeout::default(), + err: Some(JsonPayloadError::ContentType), + _phantom: PhantomData, + }; + } + + let length = res + .headers() + .get(&header::CONTENT_LENGTH) + .and_then(|len_hdr| len_hdr.to_str().ok()) + .and_then(|len_str| len_str.parse::().ok()); + + JsonBody { + body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), + length, + timeout: mem::take(&mut res.timeout), + err: None, + _phantom: PhantomData, + } + } + + /// Change max size of payload. Default limit is 2 MiB. + pub fn limit(mut self, limit: usize) -> Self { + if let Some(ref mut fut) = self.body { + fut.limit = limit; + } + + self + } +} + +impl Future for JsonBody +where + S: Stream>, + T: DeserializeOwned, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = this.length.take() { + let body = Option::as_ref(&this.body).unwrap(); + if len > body.limit { + return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow))); + } + } + + this.timeout + .poll_timeout(cx) + .map_err(JsonPayloadError::Payload)?; + + let body = ready!(this.body.as_pin_mut().unwrap().poll(cx))?; + Poll::Ready(serde_json::from_slice::(&body).map_err(JsonPayloadError::from)) + } +} + +#[cfg(test)] +mod tests { + use actix_http::BoxedPayloadStream; + use serde::{Deserialize, Serialize}; + use static_assertions::assert_impl_all; + + use super::*; + use crate::{http::header, test::TestResponse}; + + assert_impl_all!(JsonBody: Unpin); + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct MyObject { + name: String, + } + + fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { + match err { + JsonPayloadError::Payload(PayloadError::Overflow) => { + matches!(other, JsonPayloadError::Payload(PayloadError::Overflow)) + } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), + _ => false, + } + } + + #[actix_rt::test] + async fn read_json_body() { + let mut req = TestResponse::default().finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/text"), + )) + .finish(); + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("10000"), + )) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Payload(PayloadError::Overflow) + )); + + let mut req = TestResponse::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + )) + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .finish(); + + let json = JsonBody::<_, MyObject>::new(&mut req).await; + assert_eq!( + json.ok().unwrap(), + MyObject { + name: "test".to_owned() + } + ); + } +} diff --git a/awc/src/responses/mod.rs b/awc/src/responses/mod.rs new file mode 100644 index 000000000..588ce014c --- /dev/null +++ b/awc/src/responses/mod.rs @@ -0,0 +1,49 @@ +use std::{future::Future, io, pin::Pin, task::Context}; + +use actix_http::error::PayloadError; +use actix_rt::time::Sleep; + +mod json_body; +mod read_body; +mod response; +mod response_body; + +pub use self::json_body::JsonBody; +pub use self::response::ClientResponse; +#[allow(deprecated)] +pub use self::response_body::{MessageBody, ResponseBody}; + +/// Default body size limit: 2 MiB +const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024; + +/// Helper enum with reusable sleep passed from `SendClientResponse`. +/// +/// See [`ClientResponse::_timeout`] for reason. +pub(crate) enum ResponseTimeout { + Disabled(Option>>), + Enabled(Pin>), +} + +impl Default for ResponseTimeout { + fn default() -> Self { + Self::Disabled(None) + } +} + +impl ResponseTimeout { + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { + match *self { + Self::Enabled(ref mut timeout) => { + if timeout.as_mut().poll(cx).is_ready() { + Err(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Response Payload IO timed out", + ))) + } else { + Ok(()) + } + } + Self::Disabled(_) => Ok(()), + } + } +} diff --git a/awc/src/responses/read_body.rs b/awc/src/responses/read_body.rs new file mode 100644 index 000000000..a32bbb984 --- /dev/null +++ b/awc/src/responses/read_body.rs @@ -0,0 +1,61 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, Payload}; +use bytes::{Bytes, BytesMut}; +use futures_core::{ready, Stream}; +use pin_project_lite::pin_project; + +pin_project! { + pub(crate) struct ReadBody { + #[pin] + pub(crate) stream: Payload, + pub(crate) buf: BytesMut, + pub(crate) limit: usize, + } +} + +impl ReadBody { + pub(crate) fn new(stream: Payload, limit: usize) -> Self { + Self { + stream, + buf: BytesMut::new(), + limit, + } + } +} + +impl Future for ReadBody +where + S: Stream>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + while let Some(chunk) = ready!(this.stream.as_mut().poll_next(cx)?) { + if (this.buf.len() + chunk.len()) > *this.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + + this.buf.extend_from_slice(&chunk); + } + + Poll::Ready(Ok(this.buf.split().freeze())) + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::any_body::AnyBody; + + assert_impl_all!(ReadBody<()>: Unpin); + assert_impl_all!(ReadBody: Unpin); +} diff --git a/awc/src/responses/response.rs b/awc/src/responses/response.rs new file mode 100644 index 000000000..0197265f1 --- /dev/null +++ b/awc/src/responses/response.rs @@ -0,0 +1,258 @@ +use std::{ + cell::{Ref, RefMut}, + fmt, mem, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use actix_http::{ + error::PayloadError, header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, + Payload, ResponseHead, StatusCode, Version, +}; +use actix_rt::time::{sleep, Sleep}; +use bytes::Bytes; +use futures_core::Stream; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; + +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, ParseError as CookieParseError}; + +use super::{JsonBody, ResponseBody, ResponseTimeout}; + +pin_project! { + /// Client Response + pub struct ClientResponse { + pub(crate) head: ResponseHead, + #[pin] + pub(crate) payload: Payload, + pub(crate) timeout: ResponseTimeout, + } +} + +impl ClientResponse { + /// Create new Request instance + pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { + ClientResponse { + head, + payload, + timeout: ResponseTimeout::default(), + } + } + + #[inline] + pub(crate) fn head(&self) -> &ResponseHead { + &self.head + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + /// Get the status from the server. + #[inline] + pub fn status(&self) -> StatusCode { + self.head().status + } + + #[inline] + /// Returns request's headers. + pub fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Set a body and return previous body value + pub fn map_body(mut self, f: F) -> ClientResponse + where + F: FnOnce(&mut ResponseHead, Payload) -> Payload, + { + let payload = f(&mut self.head, self.payload); + + ClientResponse { + payload, + head: self.head, + timeout: self.timeout, + } + } + + /// Set a timeout duration for [`ClientResponse`](self::ClientResponse). + /// + /// This duration covers the duration of processing the response body stream + /// and would end it as timeout error when deadline met. + /// + /// Disabled by default. + pub fn timeout(self, dur: Duration) -> Self { + let timeout = match self.timeout { + ResponseTimeout::Disabled(Some(mut timeout)) + | ResponseTimeout::Enabled(mut timeout) => match Instant::now().checked_add(dur) { + Some(deadline) => { + timeout.as_mut().reset(deadline.into()); + ResponseTimeout::Enabled(timeout) + } + None => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }, + _ => ResponseTimeout::Enabled(Box::pin(sleep(dur))), + }; + + Self { + payload: self.payload, + head: self.head, + timeout, + } + } + + /// This method does not enable timeout. It's used to pass the boxed `Sleep` from + /// `SendClientRequest` and reuse it's heap allocation together with it's slot in + /// timer wheel. + pub(crate) fn _timeout(mut self, timeout: Option>>) -> Self { + self.timeout = ResponseTimeout::Disabled(timeout); + self + } + + /// Load request cookies. + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> Result>>, CookieParseError> { + struct Cookies(Vec>); + + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(&actix_http::header::SET_COOKIE) { + let s = std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + cookies.push(Cookie::parse_encoded(s)?.into_owned()); + } + self.extensions_mut().insert(Cookies(cookies)); + } + + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } + + /// Return request cookie. + #[cfg(feature = "cookies")] + pub fn cookie(&self, name: &str) -> Option> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies.iter() { + if cookie.name() == name { + return Some(cookie.to_owned()); + } + } + } + None + } +} + +impl ClientResponse +where + S: Stream>, +{ + /// Returns a [`Future`] that consumes the body stream and resolves to [`Bytes`]. + /// + /// # Errors + /// `Future` implementation returns error if: + /// - content type is not `application/json` + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB) + /// + /// # Examples + /// ```no_run + /// # use awc::Client; + /// # use bytes::Bytes; + /// # #[actix_rt::main] + /// # async fn async_ctx() -> Result<(), Box> { + /// let client = Client::default(); + /// let mut res = client.get("https://httpbin.org/robots.txt").send().await?; + /// let body: Bytes = res.body().await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`Future`]: std::future::Future + pub fn body(&mut self) -> ResponseBody { + ResponseBody::new(self) + } + + /// Returns a [`Future`] consumes the body stream, parses JSON, and resolves to a deserialized + /// `T` value. + /// + /// # Errors + /// Future returns error if: + /// - content type is not `application/json`; + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). + /// + /// # Examples + /// ```no_run + /// # use awc::Client; + /// # #[actix_rt::main] + /// # async fn async_ctx() -> Result<(), Box> { + /// let client = Client::default(); + /// let mut res = client.get("https://httpbin.org/json").send().await?; + /// let val = res.json::().await?; + /// assert!(val.is_object()); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`Future`]: std::future::Future + pub fn json(&mut self) -> JsonBody { + JsonBody::new(self) + } +} + +impl fmt::Debug for ClientResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +impl HttpMessage for ClientResponse { + type Stream = S; + + fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + fn take_payload(&mut self) -> Payload { + mem::replace(&mut self.payload, Payload::None) + } + + fn extensions(&self) -> Ref<'_, Extensions> { + self.head.extensions() + } + + fn extensions_mut(&self) -> RefMut<'_, Extensions> { + self.head.extensions_mut() + } +} + +impl Stream for ClientResponse +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.timeout.poll_timeout(cx)?; + this.payload.poll_next(cx) + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::any_body::AnyBody; + + assert_impl_all!(ClientResponse: Unpin); + assert_impl_all!(ClientResponse<()>: Unpin); + assert_impl_all!(ClientResponse: Unpin); +} diff --git a/awc/src/responses/response_body.rs b/awc/src/responses/response_body.rs new file mode 100644 index 000000000..8d9d1274a --- /dev/null +++ b/awc/src/responses/response_body.rs @@ -0,0 +1,144 @@ +use std::{ + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{error::PayloadError, header, HttpMessage}; +use bytes::Bytes; +use futures_core::Stream; +use pin_project_lite::pin_project; + +use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT}; +use crate::ClientResponse; + +pin_project! { + /// A `Future` that reads a body stream, resolving as [`Bytes`]. + /// + /// # Errors + /// `Future` implementation returns error if: + /// - content type is not `application/json`; + /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). + pub struct ResponseBody { + #[pin] + body: Option>, + length: Option, + timeout: ResponseTimeout, + err: Option, + } +} + +#[deprecated(since = "3.0.0", note = "Renamed to `ResponseBody`.")] +pub type MessageBody = ResponseBody; + +impl ResponseBody +where + S: Stream>, +{ + /// Creates a body stream reader from a response by taking its payload. + pub fn new(res: &mut ClientResponse) -> ResponseBody { + let length = match res.headers().get(&header::CONTENT_LENGTH) { + Some(value) => { + let len = value.to_str().ok().and_then(|s| s.parse::().ok()); + + match len { + None => return Self::err(PayloadError::UnknownLength), + len => len, + } + } + None => None, + }; + + ResponseBody { + body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), + length, + timeout: mem::take(&mut res.timeout), + err: None, + } + } + + /// Change max size limit of payload. + /// + /// The default limit is 2 MiB. + pub fn limit(mut self, limit: usize) -> Self { + if let Some(ref mut body) = self.body { + body.limit = limit; + } + + self + } + + fn err(err: PayloadError) -> Self { + ResponseBody { + body: None, + length: None, + timeout: ResponseTimeout::default(), + err: Some(err), + } + } +} + +impl Future for ResponseBody +where + S: Stream>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Some(err) = this.err.take() { + return Poll::Ready(Err(err)); + } + + if let Some(len) = this.length.take() { + let body = Option::as_ref(&this.body).unwrap(); + if len > body.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + } + + this.timeout.poll_timeout(cx)?; + + this.body.as_pin_mut().unwrap().poll(cx) + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::{http::header, test::TestResponse}; + + assert_impl_all!(ResponseBody<()>: Unpin); + + #[actix_rt::test] + async fn read_body() { + let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish(); + match req.body().await.err().unwrap() { + PayloadError::UnknownLength => {} + _ => unreachable!("error"), + } + + let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish(); + match req.body().await.err().unwrap() { + PayloadError::Overflow => {} + _ => unreachable!("error"), + } + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"test")) + .finish(); + assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); + + let mut req = TestResponse::default() + .set_payload(Bytes::from_static(b"11111111111111")) + .finish(); + match req.body().limit(5).await.err().unwrap() { + PayloadError::Overflow => {} + _ => unreachable!("error"), + } + } +} diff --git a/awc/src/sender.rs b/awc/src/sender.rs index a72b129f8..cd30e571d 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -8,12 +8,10 @@ use std::{ }; use actix_http::{ - body::{Body, BodyStream}, - http::{ - header::{self, HeaderMap, HeaderName, IntoHeaderValue}, - Error as HttpError, - }, - Error, RequestHead, RequestHeadType, + body::{BodyStream, MessageBody}, + error::HttpError, + header::{self, HeaderMap, HeaderName, TryIntoHeaderValue}, + RequestHead, RequestHeadType, }; use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; @@ -21,28 +19,35 @@ use derive_more::From; use futures_core::Stream; use serde::Serialize; -#[cfg(feature = "compress")] -use actix_http::encoding::Decoder; -#[cfg(feature = "compress")] -use actix_http::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use actix_http::{Payload, PayloadStream}; +#[cfg(feature = "__compress")] +use actix_http::{encoding::Decoder, header::ContentEncoding, Payload}; -use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; -use crate::response::ClientResponse; -use crate::ClientConfig; +use crate::{ + any_body::AnyBody, + client::ClientConfig, + error::{FreezeRequestError, InvalidUrl, SendRequestError}, + BoxError, ClientResponse, ConnectRequest, ConnectResponse, +}; #[derive(Debug, From)] pub(crate) enum PrepForSendingError { Url(InvalidUrl), Http(HttpError), + Json(serde_json::Error), + Form(serde_urlencoded::ser::Error), } impl From for FreezeRequestError { fn from(err: PrepForSendingError) -> FreezeRequestError { match err { - PrepForSendingError::Url(e) => FreezeRequestError::Url(e), - PrepForSendingError::Http(e) => FreezeRequestError::Http(e), + PrepForSendingError::Url(err) => FreezeRequestError::Url(err), + PrepForSendingError::Http(err) => FreezeRequestError::Http(err), + PrepForSendingError::Json(err) => { + FreezeRequestError::Custom(Box::new(err), Box::new("json serialization error")) + } + PrepForSendingError::Form(err) => { + FreezeRequestError::Custom(Box::new(err), Box::new("form serialization error")) + } } } } @@ -52,6 +57,12 @@ impl From for SendRequestError { match err { PrepForSendingError::Url(e) => SendRequestError::Url(e), PrepForSendingError::Http(e) => SendRequestError::Http(e), + PrepForSendingError::Json(err) => { + SendRequestError::Custom(Box::new(err), Box::new("json serialization error")) + } + PrepForSendingError::Form(err) => { + SendRequestError::Custom(Box::new(err), Box::new("form serialization error")) + } } } } @@ -60,7 +71,7 @@ impl From for SendRequestError { #[must_use = "futures do nothing unless polled"] pub enum SendClientRequest { Fut( - Pin>>>, + Pin>>>, // FIXME: use a pinned Sleep instead of box. Option>>, bool, @@ -70,7 +81,7 @@ pub enum SendClientRequest { impl SendClientRequest { pub(crate) fn new( - send: Pin>>>, + send: Pin>>>, response_decompress: bool, timeout: Option, ) -> SendClientRequest { @@ -79,30 +90,35 @@ impl SendClientRequest { } } -#[cfg(feature = "compress")] +#[cfg(feature = "__compress")] impl Future for SendClientRequest { - type Output = Result>>, SendRequestError>; + type Output = Result>, SendRequestError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, response_decompress) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - let res = futures_core::ready!(Pin::new(send).poll(cx)).map(|res| { - res.map_body(|head, payload| { - if *response_decompress { - Payload::Stream(Decoder::from_headers(payload, &head.headers)) - } else { - Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) - } - }) + let res = futures_core::ready!(send.as_mut().poll(cx)).map(|res| { + res.into_client_response()._timeout(delay.take()).map_body( + |head, payload| { + if *response_decompress { + Payload::Stream { + payload: Decoder::from_headers(payload, &head.headers), + } + } else { + Payload::Stream { + payload: Decoder::new(payload, ContentEncoding::Identity), + } + } + }, + ) }); Poll::Ready(res) @@ -115,7 +131,7 @@ impl Future for SendClientRequest { } } -#[cfg(not(feature = "compress"))] +#[cfg(not(feature = "__compress"))] impl Future for SendClientRequest { type Output = Result; @@ -123,13 +139,14 @@ impl Future for SendClientRequest { let this = self.get_mut(); match this { SendClientRequest::Fut(send, delay, _) => { - if delay.is_some() { - match Pin::new(delay.as_mut().unwrap()).poll(cx) { - Poll::Pending => {} - _ => return Poll::Ready(Err(SendRequestError::Timeout)), + if let Some(delay) = delay { + if delay.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(SendRequestError::Timeout)); } } - Pin::new(send).poll(cx) + send.as_mut() + .poll(cx) + .map_ok(|res| res.into_client_response()._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), @@ -145,12 +162,6 @@ impl From for SendClientRequest { } } -impl From for SendClientRequest { - fn from(e: Error) -> Self { - SendClientRequest::Err(Some(e.into())) - } -} - impl From for SendClientRequest { fn from(e: HttpError) -> Self { SendClientRequest::Err(Some(e.into())) @@ -170,86 +181,73 @@ pub(crate) enum RequestSender { } impl RequestSender { - pub(crate) fn send_body( + pub(crate) fn send_body( self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, - body: B, - ) -> SendClientRequest - where - B: Into, - { - let fut = match self { - RequestSender::Owned(head) => { - config - .connector - .send_request(RequestHeadType::Owned(head), body.into(), addr) - } - RequestSender::Rc(head, extra_headers) => config.connector.send_request( + body: impl MessageBody + 'static, + ) -> SendClientRequest { + let req = match self { + RequestSender::Owned(head) => ConnectRequest::Client( + RequestHeadType::Owned(head), + AnyBody::from_message_body(body).into_boxed(), + addr, + ), + RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), - body.into(), + AnyBody::from_message_body(body).into_boxed(), addr, ), }; + let fut = config.connector.call(req); + SendClientRequest::new(fut, response_decompress, timeout.or(config.timeout)) } - pub(crate) fn send_json( + pub(crate) fn send_json( mut self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, - value: &T, + value: impl Serialize, ) -> SendClientRequest { - let body = match serde_json::to_string(value) { + let body = match serde_json::to_string(&value) { Ok(body) => body, - Err(e) => return Error::from(e).into(), + Err(err) => return PrepForSendingError::Json(err).into(), }; if let Err(e) = self.set_header_if_none(header::CONTENT_TYPE, "application/json") { return e.into(); } - self.send_body( - addr, - response_decompress, - timeout, - config, - Body::Bytes(Bytes::from(body)), - ) + self.send_body(addr, response_decompress, timeout, config, body) } - pub(crate) fn send_form( + pub(crate) fn send_form( mut self, addr: Option, response_decompress: bool, timeout: Option, config: &ClientConfig, - value: &T, + value: impl Serialize, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { Ok(body) => body, - Err(e) => return Error::from(e).into(), + Err(err) => return PrepForSendingError::Form(err).into(), }; // set content-type - if let Err(e) = + if let Err(err) = self.set_header_if_none(header::CONTENT_TYPE, "application/x-www-form-urlencoded") { - return e.into(); + return err.into(); } - self.send_body( - addr, - response_decompress, - timeout, - config, - Body::Bytes(Bytes::from(body)), - ) + self.send_body(addr, response_decompress, timeout, config, body) } pub(crate) fn send_stream( @@ -261,15 +259,15 @@ impl RequestSender { stream: S, ) -> SendClientRequest where - S: Stream> + Unpin + 'static, - E: Into + 'static, + S: Stream> + 'static, + E: Into + 'static, { self.send_body( addr, response_decompress, timeout, config, - Body::from_message(BodyStream::new(stream)), + BodyStream::new(stream), ) } @@ -280,12 +278,12 @@ impl RequestSender { timeout: Option, config: &ClientConfig, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, Body::Empty) + self.send_body(addr, response_decompress, timeout, config, ()) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> where - V: IntoHeaderValue, + V: TryIntoHeaderValue, { match self { RequestSender::Owned(head) => { diff --git a/awc/src/test.rs b/awc/src/test.rs index 97bbb9c3d..96ae1f0a1 100644 --- a/awc/src/test.rs +++ b/awc/src/test.rs @@ -1,16 +1,10 @@ //! Test helpers for actix http client to use during testing. -use std::convert::TryFrom; -use actix_http::http::header::{Header, IntoHeaderValue}; -use actix_http::http::{Error as HttpError, HeaderName, StatusCode, Version}; -#[cfg(feature = "cookies")] -use actix_http::{ - cookie::{Cookie, CookieJar}, - http::header::{self, HeaderValue}, -}; -use actix_http::{h1, Payload, ResponseHead}; +use actix_http::{h1, header::TryIntoHeaderPair, Payload, ResponseHead, StatusCode, Version}; use bytes::Bytes; +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, CookieJar}; use crate::ClientResponse; /// Test `ClientResponse` builder @@ -34,13 +28,8 @@ impl Default for TestResponse { impl TestResponse { /// Create TestResponse and set header - pub fn with_header(key: K, value: V) -> Self - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - Self::default().header(key, value) + pub fn with_header(header: impl TryIntoHeaderPair) -> Self { + Self::default().insert_header(header) } /// Set HTTP version of this response @@ -49,27 +38,20 @@ impl TestResponse { self } - /// Set a header - pub fn set(mut self, hdr: H) -> Self { - if let Ok(value) = hdr.try_into_value() { - self.head.headers.append(H::name(), value); + /// Insert a header + pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self { + if let Ok((key, value)) = header.try_into_pair() { + self.head.headers.insert(key, value); return self; } panic!("Can not set header"); } /// Append a header - pub fn header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - >::Error: Into, - V: IntoHeaderValue, - { - if let Ok(key) = HeaderName::try_from(key) { - if let Ok(value) = value.try_into_value() { - self.head.headers.append(key, value); - return self; - } + pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self { + if let Ok((key, value)) = header.try_into_pair() { + self.head.headers.append(key, value); + return self; } panic!("Can not create header"); } @@ -83,7 +65,7 @@ impl TestResponse { /// Set response's payload pub fn set_payload>(mut self, data: B) -> Self { - let mut payload = h1::Payload::empty(); + let (_, mut payload) = h1::Payload::create(true); payload.unread_data(data.into()); self.payload = Some(payload.into()); self @@ -97,6 +79,8 @@ impl TestResponse { #[cfg(feature = "cookies")] for cookie in self.cookies.delta() { + use actix_http::header::{self, HeaderValue}; + head.headers.insert( header::SET_COOKIE, HeaderValue::from_str(&cookie.encoded().to_string()).unwrap(), @@ -106,7 +90,8 @@ impl TestResponse { if let Some(pl) = self.payload { ClientResponse::new(head, pl) } else { - ClientResponse::new(head, h1::Payload::empty().into()) + let (_, payload) = h1::Payload::create(true); + ClientResponse::new(head, payload.into()) } } } @@ -115,6 +100,8 @@ impl TestResponse { mod tests { use std::time::SystemTime; + use actix_http::header::HttpDate; + use super::*; use crate::{cookie, http::header}; @@ -122,7 +109,7 @@ mod tests { fn test_basics() { let res = TestResponse::default() .version(Version::HTTP_2) - .set(header::Date(SystemTime::now().into())) + .insert_header((header::DATE, HttpDate::from(SystemTime::now()))) .cookie(cookie::Cookie::build("name", "value").finish()) .finish(); assert!(res.headers().contains_key(header::SET_COOKIE)); diff --git a/awc/src/ws.rs b/awc/src/ws.rs index d5528595d..f3ee02d43 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -6,7 +6,7 @@ //! //! ```no_run //! use awc::{Client, ws}; -//! use futures_util::{sink::SinkExt, stream::StreamExt}; +//! use futures_util::{sink::SinkExt as _, stream::StreamExt as _}; //! //! #[actix_rt::main] //! async fn main() { @@ -26,25 +26,28 @@ //! } //! ``` -use std::convert::TryFrom; -use std::net::SocketAddr; -use std::rc::Rc; -use std::{fmt, str}; +use std::{convert::TryFrom, fmt, net::SocketAddr, str}; use actix_codec::Framed; -#[cfg(feature = "cookies")] -use actix_http::cookie::{Cookie, CookieJar}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; +use actix_service::Service as _; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; -use crate::connect::BoxedSocket; -use crate::error::{InvalidUrl, SendRequestError, WsClientError}; -use crate::http::header::{self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION}; -use crate::http::{ConnectionType, Error as HttpError, Method, StatusCode, Uri, Version}; -use crate::response::ClientResponse; -use crate::ClientConfig; +use crate::{ + client::ClientConfig, + connect::{BoxedSocket, ConnectRequest}, + error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, + http::{ + header::{self, HeaderName, HeaderValue, TryIntoHeaderValue, AUTHORIZATION}, + ConnectionType, Method, StatusCode, Uri, Version, + }, + ClientResponse, +}; + +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, CookieJar}; /// WebSocket connection. pub struct WebsocketsRequest { @@ -55,7 +58,7 @@ pub struct WebsocketsRequest { addr: Option, max_size: usize, server_mode: bool, - config: Rc, + config: ClientConfig, #[cfg(feature = "cookies")] cookies: Option, @@ -63,7 +66,7 @@ pub struct WebsocketsRequest { impl WebsocketsRequest { /// Create new WebSocket connection - pub(crate) fn new(uri: U, config: Rc) -> Self + pub(crate) fn new(uri: U, config: ClientConfig) -> Self where Uri: TryFrom, >::Error: Into, @@ -168,7 +171,7 @@ impl WebsocketsRequest { where HeaderName: TryFrom, >::Error: Into, - V: IntoHeaderValue, + V: TryIntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => match value.try_into_value() { @@ -187,7 +190,7 @@ impl WebsocketsRequest { where HeaderName: TryFrom, >::Error: Into, - V: IntoHeaderValue, + V: TryIntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => match value.try_into_value() { @@ -206,7 +209,7 @@ impl WebsocketsRequest { where HeaderName: TryFrom, >::Error: Into, - V: IntoHeaderValue, + V: TryIntoHeaderValue, { match HeaderName::try_from(key) { Ok(key) => { @@ -280,7 +283,7 @@ impl WebsocketsRequest { let cookie: String = jar .delta() // ensure only name=value is written to cookie header - .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string()) + .map(|c| c.stripped().encoded().to_string()) .collect::>() .join("; "); @@ -297,13 +300,16 @@ impl WebsocketsRequest { } self.head.set_connection_type(ConnectionType::Upgrade); + + #[allow(clippy::declare_interior_mutable_const)] + const HV_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + self.head.headers.insert(header::UPGRADE, HV_WEBSOCKET); + + #[allow(clippy::declare_interior_mutable_const)] + const HV_THIRTEEN: HeaderValue = HeaderValue::from_static("13"); self.head .headers - .insert(header::UPGRADE, HeaderValue::from_static("websocket")); - self.head.headers.insert( - header::SEC_WEBSOCKET_VERSION, - HeaderValue::from_static("13"), - ); + .insert(header::SEC_WEBSOCKET_VERSION, HV_THIRTEEN); if let Some(protocols) = self.protocols.take() { self.head.headers.insert( @@ -312,9 +318,8 @@ impl WebsocketsRequest { ); } - // Generate a random key for the `Sec-WebSocket-Key` header. - // a base64-encoded (see Section 4 of [RFC4648]) value that, - // when decoded, is 16 bytes in length (RFC 6455) + // Generate a random key for the `Sec-WebSocket-Key` header which is a base64-encoded + // (see RFC 4648 §4) value that, when decoded, is 16 bytes in length (RFC 6455 §1.3). let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); @@ -327,18 +332,21 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let fut = self.config.connector.open_tunnel(head, self.addr); + let req = ConnectRequest::Tunnel(head, self.addr); + + let fut = self.config.connector.call(req); // set request timeout - let (head, framed) = if let Some(to) = self.config.timeout { + let res = if let Some(to) = self.config.timeout { timeout(to, fut) .await - .map_err(|_| SendRequestError::Timeout) - .and_then(|res| res)? + .map_err(|_| SendRequestError::Timeout)?? } else { fut.await? }; + let (head, framed) = res.into_tunnel_response(); + // verify response if head.status != StatusCode::SWITCHING_PROTOCOLS { return Err(WsClientError::InvalidResponseStatus(head.status)); @@ -377,12 +385,14 @@ impl WebsocketsRequest { if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { let encoded = ws::hash_key(key.as_ref()); - if hdr_key.as_bytes() != encoded.as_bytes() { + + if hdr_key.as_bytes() != encoded { log::trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, + "Invalid challenge response: expected: {:?} received: {:?}", + &encoded, key ); + return Err(WsClientError::InvalidChallengeResponse( encoded, hdr_key.clone(), @@ -438,7 +448,7 @@ mod tests { #[actix_rt::test] async fn test_header_override() { let req = Client::builder() - .header(header::CONTENT_TYPE, "111") + .add_default_header((header::CONTENT_TYPE, "111")) .finish() .ws("/") .set_header(header::CONTENT_TYPE, "222"); @@ -512,7 +522,7 @@ mod tests { "test-origin" ); assert_eq!(req.max_size, 100); - assert_eq!(req.server_mode, true); + assert!(req.server_mode); assert_eq!(req.protocols, Some("v1,v2".to_string())); assert_eq!( req.head.headers.get(header::CONTENT_TYPE).unwrap(), diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index 7e74d226e..c1378157b 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -1,56 +1,35 @@ -use std::collections::HashMap; -use std::io::{Read, Write}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Duration; +use std::{ + collections::HashMap, + convert::Infallible, + io::{Read, Write}, + net::{IpAddr, Ipv4Addr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; -use brotli2::write::BrotliEncoder; +use actix_utils::future::ok; use bytes::Bytes; -use flate2::read::GzDecoder; -use flate2::write::GzEncoder; -use flate2::Compression; -use futures_util::{future::ok, stream}; +use cookie::Cookie; +use futures_util::stream; use rand::Rng; -use actix_http::{ - http::{self, StatusCode}, - HttpService, -}; +use actix_http::{HttpService, StatusCode}; use actix_http_test::test_server; -use actix_service::{map_config, pipeline_factory}; -use actix_web::{ - dev::{AppConfig, BodyEncoding}, - http::{header, Cookie}, - middleware::Compress, - test, web, App, Error, HttpMessage, HttpRequest, HttpResponse, -}; -use awc::error::SendRequestError; +use actix_service::{fn_service, map_config, ServiceFactoryExt as _}; +use actix_web::{dev::AppConfig, http::header, web, App, Error, HttpRequest, HttpResponse}; +use awc::error::{JsonPayloadError, PayloadError, SendRequestError}; -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; +mod utils; + +const S: &str = "Hello World "; +const STR: &str = const_str::repeat!(S, 100); #[actix_rt::test] async fn test_simple() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); @@ -76,7 +55,7 @@ async fn test_simple() { #[actix_rt::test] async fn test_json() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service( web::resource("/").route(web::to(|_: web::Json| HttpResponse::Ok())), ) @@ -92,7 +71,7 @@ async fn test_json() { #[actix_rt::test] async fn test_form() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to( |_: web::Form>| HttpResponse::Ok(), ))) @@ -111,7 +90,7 @@ async fn test_form() { #[actix_rt::test] async fn test_timeout() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| async { actix_rt::time::sleep(Duration::from_millis(200)).await; Ok::<_, Error>(HttpResponse::Ok().body(STR)) @@ -119,9 +98,8 @@ async fn test_timeout() { }); let connector = awc::Connector::new() - .connector(actix_tls::connect::default_connector()) - .timeout(Duration::from_secs(15)) - .finish(); + .connector(actix_tls::connect::ConnectorService::default()) + .timeout(Duration::from_secs(15)); let client = awc::Client::builder() .connector(connector) @@ -137,7 +115,7 @@ async fn test_timeout() { #[actix_rt::test] async fn test_timeout_override() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| async { actix_rt::time::sleep(Duration::from_millis(200)).await; Ok::<_, Error>(HttpResponse::Ok().body(STR)) @@ -157,6 +135,79 @@ async fn test_timeout_override() { } } +#[actix_rt::test] +async fn test_response_timeout() { + use futures_util::stream::{once, StreamExt as _}; + + let srv = actix_test::start(|| { + App::new().service(web::resource("/").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .content_type("application/json") + .streaming(Box::pin(once(async { + actix_rt::time::sleep(Duration::from_millis(200)).await; + Ok::<_, Error>(Bytes::from(STR)) + }))), + ) + }))) + }); + + let client = awc::Client::new(); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(500)) + .body() + .await + .unwrap(); + assert_eq!(std::str::from_utf8(res.as_ref()).unwrap(), STR); + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .next() + .await + .unwrap(); + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .body() + .await; + match res { + Err(PayloadError::Io(e)) => assert_eq!(e.kind(), std::io::ErrorKind::TimedOut), + _ => panic!("Response error type is not matched"), + } + + let res = client + .get(srv.url("/")) + .send() + .await + .unwrap() + .timeout(Duration::from_millis(100)) + .json::>() + .await; + match res { + Err(JsonPayloadError::Payload(PayloadError::Io(e))) => { + assert_eq!(e.kind(), std::io::ErrorKind::TimedOut) + } + _ => panic!("Response error type is not matched"), + } +} + #[actix_rt::test] async fn test_connection_reuse() { let num = Arc::new(AtomicUsize::new(0)); @@ -164,7 +215,7 @@ async fn test_connection_reuse() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -201,7 +252,7 @@ async fn test_connection_force_close() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -238,7 +289,7 @@ async fn test_connection_server_close() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -278,7 +329,7 @@ async fn test_connection_wait_queue() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -295,7 +346,7 @@ async fn test_connection_wait_queue() { .await; let client = awc::Client::builder() - .connector(awc::Connector::new().limit(1).finish()) + .connector(awc::Connector::new().limit(1)) .finish(); // req 1 @@ -326,7 +377,7 @@ async fn test_connection_wait_queue_force_close() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -344,7 +395,7 @@ async fn test_connection_wait_queue_force_close() { .await; let client = awc::Client::builder() - .connector(awc::Connector::new().limit(1).finish()) + .connector(awc::Connector::new().limit(1)) .finish(); // req 1 @@ -370,7 +421,7 @@ async fn test_connection_wait_queue_force_close() { #[actix_rt::test] async fn test_with_query_parameter() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").to(|req: HttpRequest| { if req.query_string().contains("qp") { HttpResponse::Ok() @@ -388,20 +439,18 @@ async fn test_with_query_parameter() { assert!(res.status().is_success()); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_no_decompress() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new() - .wrap(Compress::default()) - .service(web::resource("/").route(web::to(|| { - let mut res = HttpResponse::Ok().body(STR); - res.encoding(header::ContentEncoding::Gzip); - res - }))) + .wrap(actix_web::middleware::Compress::default()) + .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); let mut res = awc::Client::new() .get(srv.url("/")) + .insert_header((header::ACCEPT_ENCODING, "gzip")) .no_decompress() .send() .await @@ -410,15 +459,12 @@ async fn test_no_decompress() { // read response let bytes = res.body().await.unwrap(); - - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + assert_eq!(utils::gzip::decode(bytes), STR.as_bytes()); // POST let mut res = awc::Client::new() .post(srv.url("/")) + .insert_header((header::ACCEPT_ENCODING, "gzip")) .no_decompress() .send() .await @@ -426,23 +472,17 @@ async fn test_no_decompress() { assert!(res.status().is_success()); let bytes = res.body().await.unwrap(); - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + assert_eq!(utils::gzip::decode(bytes), STR.as_bytes()); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .insert_header(("content-encoding", "gzip")) - .body(data) + .insert_header(header::ContentEncoding::Gzip) + .body(utils::gzip::encode(STR)) }))) }); @@ -452,20 +492,17 @@ async fn test_client_gzip_encoding() { // read response let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + assert_eq!(bytes, STR); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding_large() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.repeat(10).as_ref()).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .insert_header(("content-encoding", "gzip")) - .body(data) + .insert_header(header::ContentEncoding::Gzip) + .body(utils::gzip::encode(STR.repeat(10))) }))) }); @@ -475,9 +512,10 @@ async fn test_client_gzip_encoding_large() { // read response let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(STR.repeat(10))); + assert_eq!(bytes, STR.repeat(10)); } +#[cfg(feature = "compress-gzip")] #[actix_rt::test] async fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() @@ -486,14 +524,11 @@ async fn test_client_gzip_encoding_large_random() { .map(char::from) .collect::(); - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|data: Bytes| { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); HttpResponse::Ok() - .insert_header(("content-encoding", "gzip")) - .body(data) + .insert_header(header::ContentEncoding::Gzip) + .body(utils::gzip::encode(data)) }))) }); @@ -503,19 +538,17 @@ async fn test_client_gzip_encoding_large_random() { // read response let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); + assert_eq!(bytes, data); } +#[cfg(feature = "compress-brotli")] #[actix_rt::test] async fn test_client_brotli_encoding() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|data: Bytes| { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); HttpResponse::Ok() .insert_header(("content-encoding", "br")) - .body(data) + .body(utils::brotli::encode(data)) }))) }); @@ -528,6 +561,7 @@ async fn test_client_brotli_encoding() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[cfg(feature = "compress-brotli")] #[actix_rt::test] async fn test_client_brotli_encoding_large_random() { let data = rand::thread_rng() @@ -536,14 +570,11 @@ async fn test_client_brotli_encoding_large_random() { .map(char::from) .collect::(); - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|data: Bytes| { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); HttpResponse::Ok() - .insert_header(("content-encoding", "br")) - .body(data) + .insert_header(header::ContentEncoding::Brotli) + .body(utils::brotli::encode(&data)) }))) }); @@ -553,27 +584,25 @@ async fn test_client_brotli_encoding_large_random() { // read response let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); + assert_eq!(bytes, data); } #[actix_rt::test] async fn test_client_deflate_encoding() { - let srv = test::start(|| { - App::new().default_service(web::to(|body: Bytes| { - HttpResponse::Ok() - .encoding(http::ContentEncoding::Br) - .body(body) - })) + let srv = actix_test::start(|| { + App::new().default_service(web::to(|body: Bytes| HttpResponse::Ok().body(body))) }); - let req = srv.post("/").send_body(STR); + let req = srv + .post("/") + .insert_header((header::ACCEPT_ENCODING, "gzip")) + .send_body(STR); let mut res = req.await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let bytes = res.body().await.unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + assert_eq!(bytes, STR); } #[actix_rt::test] @@ -584,15 +613,14 @@ async fn test_client_deflate_encoding_large_random() { .take(70_000) .collect::(); - let srv = test::start(|| { - App::new().default_service(web::to(|body: Bytes| { - HttpResponse::Ok() - .encoding(http::ContentEncoding::Br) - .body(body) - })) + let srv = actix_test::start(|| { + App::new().default_service(web::to(|body: Bytes| HttpResponse::Ok().body(body))) }); - let req = srv.post("/").send_body(data.clone()); + let req = srv + .post("/") + .insert_header((header::ACCEPT_ENCODING, "br")) + .send_body(data.clone()); let mut res = req.await.unwrap(); let bytes = res.body().await.unwrap(); @@ -603,17 +631,18 @@ async fn test_client_deflate_encoding_large_random() { #[actix_rt::test] async fn test_client_streaming_explicit() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().default_service(web::to(|body: web::Payload| { - HttpResponse::Ok() - .encoding(http::ContentEncoding::Identity) - .streaming(body) + HttpResponse::Ok().streaming(body) })) }); let body = stream::once(async { Ok::<_, actix_http::Error>(Bytes::from_static(STR.as_bytes())) }); - let req = srv.post("/").send_stream(Box::pin(body)); + let req = srv + .post("/") + .insert_header((header::ACCEPT_ENCODING, "identity")) + .send_stream(Box::pin(body)); let mut res = req.await.unwrap(); assert!(res.status().is_success()); @@ -624,19 +653,18 @@ async fn test_client_streaming_explicit() { #[actix_rt::test] async fn test_body_streaming_implicit() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().default_service(web::to(|| { - let body = stream::once(async { - Ok::<_, actix_http::Error>(Bytes::from_static(STR.as_bytes())) - }); - - HttpResponse::Ok() - .encoding(http::ContentEncoding::Gzip) - .streaming(Box::pin(body)) + let body = + stream::once(async { Ok::<_, Infallible>(Bytes::from_static(STR.as_bytes())) }); + HttpResponse::Ok().streaming(body) })) }); - let req = srv.get("/").send(); + let req = srv + .get("/") + .insert_header((header::ACCEPT_ENCODING, "gzip")) + .send(); let mut res = req.await.unwrap(); assert!(res.status().is_success()); @@ -660,7 +688,7 @@ async fn test_client_cookie_handling() { let cookie1b = cookie1.clone(); let cookie2b = cookie2.clone(); - let srv = test::start(move || { + let srv = actix_test::start(move || { let cookie1 = cookie1b.clone(); let cookie2 = cookie2b.clone(); @@ -716,22 +744,19 @@ async fn test_client_cookie_handling() { #[actix_rt::test] async fn client_unread_response() { - let addr = test::unused_addr(); - + let addr = actix_test::unused_addr(); let lst = std::net::TcpListener::bind(addr).unwrap(); std::thread::spawn(move || { - for stream in lst.incoming() { - let mut stream = stream.unwrap(); - let mut b = [0; 1000]; - let _ = stream.read(&mut b).unwrap(); - let _ = stream.write_all( - b"HTTP/1.1 200 OK\r\n\ + let (mut stream, _) = lst.accept().unwrap(); + let mut b = [0; 1000]; + let _ = stream.read(&mut b).unwrap(); + let _ = stream.write_all( + b"HTTP/1.1 200 OK\r\n\ connection: close\r\n\ \r\n\ welcome!", - ); - } + ); }); // client request @@ -746,7 +771,7 @@ async fn client_unread_response() { #[actix_rt::test] async fn client_basic_auth() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().route( "/", web::to(|req: HttpRequest| { @@ -756,7 +781,7 @@ async fn client_basic_auth() { .unwrap() .to_str() .unwrap() - == "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + == format!("Basic {}", base64::encode("username:password")) { HttpResponse::Ok() } else { @@ -767,14 +792,14 @@ async fn client_basic_auth() { }); // set authorization header to Basic - let request = srv.get("/").basic_auth("username", Some("password")); + let request = srv.get("/").basic_auth("username", "password"); let response = request.send().await.unwrap(); assert!(response.status().is_success()); } #[actix_rt::test] async fn client_bearer_auth() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().route( "/", web::to(|req: HttpRequest| { @@ -799,3 +824,34 @@ async fn client_bearer_auth() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); } + +#[actix_rt::test] +async fn test_local_address() { + let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + + let srv = actix_test::start(move || { + App::new().service(web::resource("/").route(web::to( + move |req: HttpRequest| async move { + assert_eq!(req.peer_addr().unwrap().ip(), ip); + Ok::<_, Error>(HttpResponse::Ok()) + }, + ))) + }); + let client = awc::Client::builder().local_address(ip).finish(); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status(), 200); + + let client = awc::Client::builder() + .connector( + // connector local address setting should always be override by client builder. + awc::Connector::new().local_address(IpAddr::V4(Ipv4Addr::new(128, 0, 0, 1))), + ) + .local_address(ip) + .finish(); + + let res = client.get(srv.url("/")).send().await.unwrap(); + + assert_eq!(res.status(), 200); +} diff --git a/awc/tests/test_connector.rs b/awc/tests/test_connector.rs index fd725506d..0f0b81414 100644 --- a/awc/tests/test_connector.rs +++ b/awc/tests/test_connector.rs @@ -39,7 +39,7 @@ fn tls_config() -> SslAcceptor { #[actix_rt::test] async fn test_connection_window_size() { - let srv = test_server(move || { + let srv = test_server(|| { HttpService::build() .h2(map_config( App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), @@ -58,7 +58,7 @@ async fn test_connection_window_size() { .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); let client = awc::Client::builder() - .connector(awc::Connector::new().ssl(builder.build()).finish()) + .connector(awc::Connector::new().openssl(builder.build())) .initial_window_size(100) .initial_connection_window_size(100) .finish(); diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index a928715a8..652997de6 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -8,44 +8,59 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, + time::SystemTime, }; use actix_http::HttpService; use actix_http_test::test_server; -use actix_service::{map_config, pipeline_factory, ServiceFactoryExt}; +use actix_service::{fn_service, map_config, ServiceFactoryExt}; +use actix_tls::connect::rustls::webpki_roots_cert_store; +use actix_utils::future::ok; use actix_web::{dev::AppConfig, http::Version, web, App, HttpResponse}; -use futures_util::future::ok; -use rustls::internal::pemfile::{certs, pkcs8_private_keys}; -use rustls::{ClientConfig, NoClientAuth, ServerConfig}; +use rustls::{ + client::{ServerCertVerified, ServerCertVerifier}, + Certificate, ClientConfig, PrivateKey, ServerConfig, ServerName, +}; +use rustls_pemfile::{certs, pkcs8_private_keys}; fn tls_config() -> ServerConfig { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = ServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - config + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap() } mod danger { + use super::*; + pub struct NoCertificateVerification; - impl rustls::ServerCertVerifier for NoCertificateVerification { + impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( &self, - _roots: &rustls::RootCertStore, - _presented_certs: &[rustls::Certificate], - _dns_name: webpki::DNSNameRef<'_>, - _ocsp: &[u8], - ) -> Result { - Ok(rustls::ServerCertVerified::assertion()) + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) } } } @@ -57,7 +72,7 @@ async fn test_connection_reuse_h2() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -73,16 +88,21 @@ async fn test_connection_reuse_h2() { }) .await; - // disable TLS verification - let mut config = ClientConfig::new(); + let mut config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - config.set_protocols(&protos); + config.alpn_protocols = protos; + + // disable TLS verification config .dangerous() .set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); let client = awc::Client::builder() - .connector(awc::Connector::new().rustls(Arc::new(config)).finish()) + .connector(awc::Connector::new().rustls(Arc::new(config))) .finish(); // req 1 diff --git a/awc/tests/test_ssl_client.rs b/awc/tests/test_ssl_client.rs index 08aa125cd..40c9ab8f0 100644 --- a/awc/tests/test_ssl_client.rs +++ b/awc/tests/test_ssl_client.rs @@ -7,10 +7,10 @@ use std::sync::Arc; use actix_http::HttpService; use actix_http_test::test_server; -use actix_service::{map_config, pipeline_factory, ServiceFactoryExt}; +use actix_service::{fn_service, map_config, ServiceFactoryExt}; +use actix_utils::future::ok; use actix_web::http::Version; use actix_web::{dev::AppConfig, web, App, HttpResponse}; -use futures_util::future::ok; use openssl::{ pkey::PKey, ssl::{SslAcceptor, SslConnector, SslMethod, SslVerifyMode}, @@ -48,7 +48,7 @@ async fn test_connection_reuse_h2() { let srv = test_server(move || { let num2 = num2.clone(); - pipeline_factory(move |io| { + fn_service(move |io| { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) @@ -72,7 +72,7 @@ async fn test_connection_reuse_h2() { .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); let client = awc::Client::builder() - .connector(awc::Connector::new().ssl(builder.build()).finish()) + .connector(awc::Connector::new().openssl(builder.build())) .finish(); // req 1 diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 1b3f780dc..bfc81afbc 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -3,9 +3,9 @@ use std::io; use actix_codec::Framed; use actix_http::{body::BodySize, h1, ws, Error, HttpService, Request, Response}; use actix_http_test::test_server; +use actix_utils::future::ok; use bytes::Bytes; -use futures_util::future::ok; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{SinkExt as _, StreamExt as _}; async fn ws_service(req: ws::Frame) -> Result { match req { @@ -36,7 +36,7 @@ async fn test_simple() { ws::Dispatcher::with(framed, ws_service).await } }) - .finish(|_| ok::<_, Error>(Response::NotFound())) + .finish(|_| ok::<_, Error>(Response::not_found())) .tcp() }) .await; diff --git a/awc/tests/utils.rs b/awc/tests/utils.rs new file mode 100644 index 000000000..9a3743d8b --- /dev/null +++ b/awc/tests/utils.rs @@ -0,0 +1,76 @@ +// compiling some tests will trigger unused function warnings even though other tests use them +#![allow(dead_code)] + +use std::io::{Read as _, Write as _}; + +pub mod gzip { + use super::*; + use flate2::{read::GzDecoder, write::GzEncoder, Compression}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = GzDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod deflate { + use super::*; + use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = ZlibDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod brotli { + use super::*; + use ::brotli2::{read::BrotliDecoder, write::BrotliEncoder}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = BrotliEncoder::new(Vec::new(), 3); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = BrotliDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod zstd { + use super::*; + use ::zstd::stream::{read::Decoder, write::Encoder}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = Encoder::new(Vec::new(), 3).unwrap(); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = Decoder::new(bytes.as_ref()).unwrap(); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} diff --git a/benches/responder.rs b/benches/responder.rs index 8cfdbd3ea..20aae3351 100644 --- a/benches/responder.rs +++ b/benches/responder.rs @@ -1,12 +1,12 @@ -use std::future::Future; -use std::time::Instant; +use std::{future::Future, time::Instant}; -use actix_http::Response; -use actix_web::http::StatusCode; -use actix_web::test::TestRequest; -use actix_web::{error, Error, HttpRequest, HttpResponse, Responder}; +use actix_http::body::BoxBody; +use actix_utils::future::{ready, Ready}; +use actix_web::{ + error, http::StatusCode, test::TestRequest, Error, HttpRequest, HttpResponse, Responder, +}; use criterion::{criterion_group, criterion_main, Criterion}; -use futures_util::future::{ready, Either, Ready}; +use futures_util::future::{join_all, Either}; // responder simulate the old responder trait. trait FutureResponder { @@ -24,11 +24,11 @@ struct StringResponder(String); impl FutureResponder for StringResponder { type Error = Error; - type Future = Ready>; + type Future = Ready>; fn future_respond_to(self, _: &HttpRequest) -> Self::Future { // this is default builder for string response in both new and old responder trait. - ready(Ok(Response::build(StatusCode::OK) + ready(Ok(HttpResponse::build(StatusCode::OK) .content_type("text/plain; charset=utf-8") .body(self.0))) } @@ -37,7 +37,7 @@ impl FutureResponder for StringResponder { impl FutureResponder for OptionResponder where T: FutureResponder, - T::Future: Future>, + T::Future: Future>, { type Error = Error; type Future = Either>>; @@ -51,18 +51,22 @@ where } impl Responder for StringResponder { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - Response::build(StatusCode::OK) + type Body = BoxBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::build(StatusCode::OK) .content_type("text/plain; charset=utf-8") .body(self.0) } } impl Responder for OptionResponder { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = BoxBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self.0 { - Some(t) => t.respond_to(req), - None => Response::from_error(error::ErrorInternalServerError("err")), + Some(t) => t.respond_to(req).map_into_boxed_body(), + None => HttpResponse::from_error(error::ErrorInternalServerError("err")), } } } @@ -79,7 +83,7 @@ fn future_responder(c: &mut Criterion) { .await }); - let futs = futures_util::future::join_all(futs); + let futs = join_all(futs); let start = Instant::now(); diff --git a/benches/server.rs b/benches/server.rs index 9dd540a73..139e24abd 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -1,4 +1,4 @@ -use actix_web::{test, web, App, HttpResponse}; +use actix_web::{web, App, HttpResponse}; use awc::Client; use criterion::{criterion_group, criterion_main, Criterion}; use futures_util::future::join_all; @@ -32,7 +32,7 @@ fn bench_async_burst(c: &mut Criterion) { let rt = actix_rt::System::new(); let srv = rt.block_on(async { - test::start(|| { + actix_test::start(|| { App::new() .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }) diff --git a/benches/service.rs b/benches/service.rs index 0d3264857..87e51f170 100644 --- a/benches/service.rs +++ b/benches/service.rs @@ -9,7 +9,7 @@ use actix_web::test::{init_service, ok_service, TestRequest}; /// Criterion Benchmark for async Service /// Should be used from within criterion group: -/// ```rust,ignore +/// ```ignore /// let mut criterion: ::criterion::Criterion<_> = /// ::criterion::Criterion::default().configure_from_args(); /// bench_async_service(&mut criterion, ok_service(), "async_service_direct"); @@ -51,9 +51,8 @@ where fut.await.unwrap(); } }); - let elapsed = start.elapsed(); // check that at least first request succeeded - elapsed + start.elapsed() }) }); } @@ -93,9 +92,8 @@ fn async_web_service(c: &mut Criterion) { fut.await.unwrap(); } }); - let elapsed = start.elapsed(); // check that at least first request succeeded - elapsed + start.elapsed() }) }); } diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 000000000..ece14b8d2 --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +msrv = "1.54" diff --git a/codecov.yml b/codecov.yml index e6bc40203..d80835c7f 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,13 +4,12 @@ coverage: status: project: default: - threshold: 10% # make CI green + threshold: 100% # make CI green patch: default: - threshold: 10% # make CI green + threshold: 100% # make CI green ignore: # ignore code coverage on following paths - "**/tests" - - "test-server" - "**/benches" - "**/examples" diff --git a/docs/graphs/net-only.dot b/docs/graphs/net-only.dot index 9488f3fe7..8a58ec2b8 100644 --- a/docs/graphs/net-only.dot +++ b/docs/graphs/net-only.dot @@ -1,21 +1,36 @@ digraph { + rankdir=TB + subgraph cluster_net { - label="actix/actix-net"; - "actix-codec" - "actix-macros" - "actix-rt" - "actix-server" - "actix-service" - "actix-threadpool" - "actix-tls" - "actix-tracing" - "actix-utils" - "actix-router" + label="actix-net" + "actix-codec" "actix-macros" "actix-rt" "actix-server" "actix-service" + "actix-tls" "actix-tracing" "actix-utils" + } + + subgraph cluster_other { + label="other actix owned crates" + { rank=same; "local-channel" "local-waker" "bytestring" } } - "actix-utils" -> { "actix-service" "actix-rt" "actix-codec" } + subgraph cluster_tokio { + label="tokio" + "tokio" "tokio-util" + } + + "actix-codec" -> { "tokio" } + "actix-codec" -> { "tokio-util" }[color=red] + "actix-utils" -> { "local-waker" } "actix-tracing" -> { "actix-service" } "actix-tls" -> { "actix-service" "actix-codec" "actix-utils" "actix-rt" } - "actix-server" -> { "actix-service" "actix-rt" "actix-codec" "actix-utils" } - "actix-rt" -> { "actix-macros" "actix-threadpool" } + "actix-tls" -> { "tokio-util" }[color="#009900"] + "actix-server" -> { "actix-service" "actix-rt" "actix-utils" "tokio" } + "actix-rt" -> { "actix-macros" "tokio" } + + "local-channel" -> { "local-waker" } + + // invisible edges to force nicer layout + edge [style=invis] + "actix-macros" -> "tokio" + "actix-service" -> "bytestring" + "actix-macros" -> "bytestring" } diff --git a/docs/graphs/web-focus.dot b/docs/graphs/web-focus.dot index ec0f7a946..63b3eaa82 100644 --- a/docs/graphs/web-focus.dot +++ b/docs/graphs/web-focus.dot @@ -1,35 +1,45 @@ digraph { subgraph cluster_web { - label="actix/actix-web" + label="actix/web" "awc" - "actix-web" - "actix-files" - "actix-http" - "actix-multipart" - "actix-web-actors" - "actix-web-codegen" - "actix-http-test" + "web" + "files" + "http" + "multipart" + "web-actors" + "web-codegen" + "http-test" + "router" + + { rank=same; "multipart" "web-actors" "http-test" }; + { rank=same; "files" "awc" "web" }; + { rank=same; "web-codegen" "http" }; } - "actix-web" -> { "actix-codec" "actix-service" "actix-utils" "actix-router" "actix-rt" "actix-server" "macros" "threadpool" "actix-tls" "actix-web-codegen" "actix-http" "awc" } - "awc" -> { "actix-codec" "actix-service" "actix-http" "actix-rt" } - "actix-web-actors" -> { "actix" "actix-web" "actix-http" "actix-codec" } - "actix-multipart" -> { "actix-web" "actix-service" "actix-utils" } - "actix-http" -> { "actix-service" "actix-codec" "actix-tls" "actix-utils" "actix-rt" "threadpool" } - "actix-http" -> { "actix-tls" }[color=blue] // optional - "actix-files" -> { "actix-web" } - "actix-http-test" -> { "actix-service" "actix-codec" "actix-tls" "actix-utils" "actix-rt" "actix-server" "awc" } + "web" -> { "codec" "service" "utils" "router" "rt" "server" "macros" "web-codegen" "http" "awc" } + "web" -> { "tls" }[color=blue] // optional + "awc" -> { "codec" "service" "http" "rt" } + "web-actors" -> { "actix" "web" "http" "codec" } + "multipart" -> { "web" "service" "utils" } + "http" -> { "service" "codec" "utils" "rt" } + "http" -> { "tls" }[color=blue] // optional + "files" -> { "web" } + "http-test" -> { "service" "codec" "utils" "rt" "server" "awc" } + "http-test" -> { "tls" }[color=blue] // optional // net - "actix-utils" -> { "actix-service" "actix-rt" "actix-codec" } - "actix-tracing" -> { "actix-service" } - "actix-tls" -> { "actix-service" "actix-codec" "actix-utils" } - "actix-server" -> { "actix-service" "actix-rt" "actix-codec" "actix-utils" } - "actix-rt" -> { "macros" "threadpool" } + "utils" -> { "service" "rt" "codec" } + "tracing" -> { "service" } + "tls" -> { "service" "codec" "utils" } + "server" -> { "service" "rt" "codec" "utils" } + "rt" -> { "macros" } + + { rank=same; "utils" "codec" }; + { rank=same; "rt" "macros" "service" }; // actix - "actix" -> { "actix-rt" } + "actix" -> { "rt" } } diff --git a/docs/graphs/web-only.dot b/docs/graphs/web-only.dot index 6f8292a3a..b27dd0943 100644 --- a/docs/graphs/web-only.dot +++ b/docs/graphs/web-only.dot @@ -9,12 +9,16 @@ digraph { "actix-web-actors" "actix-web-codegen" "actix-http-test" + "actix-test" + "actix-router" } - "actix-web" -> { "actix-web-codegen" "actix-http" "awc" } + "actix-web" -> { "actix-web-codegen" "actix-http" "actix-router" } "awc" -> { "actix-http" } + "actix-web-codegen" -> { "actix-router" } "actix-web-actors" -> { "actix" "actix-web" "actix-http" } "actix-multipart" -> { "actix-web" } "actix-files" -> { "actix-web" } "actix-http-test" -> { "awc" } + "actix-test" -> { "actix-web" "awc" "actix-http-test" } } diff --git a/examples/basic.rs b/examples/basic.rs index 99eef3ee1..598d13a40 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -22,20 +22,20 @@ async fn main() -> std::io::Result<()> { HttpServer::new(|| { App::new() - .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) + .wrap(middleware::DefaultHeaders::new().add(("X-Version", "0.2"))) .wrap(middleware::Compress::default()) .wrap(middleware::Logger::default()) .service(index) .service(no_params) .service( web::resource("/resource2/index.html") - .wrap(middleware::DefaultHeaders::new().header("X-Version-R2", "0.3")) - .default_service(web::route().to(|| HttpResponse::MethodNotAllowed())) + .wrap(middleware::DefaultHeaders::new().add(("X-Version-R2", "0.3"))) + .default_service(web::route().to(HttpResponse::MethodNotAllowed)) .route(web::get().to(index_async)), ) .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) }) - .bind("127.0.0.1:8080")? + .bind(("127.0.0.1", 8080))? .workers(1) .run() .await diff --git a/examples/on_connect.rs b/examples/on-connect.rs similarity index 60% rename from examples/on_connect.rs rename to examples/on-connect.rs index ba5a18f3f..d76e9ce56 100644 --- a/examples/on_connect.rs +++ b/examples/on-connect.rs @@ -2,12 +2,16 @@ //! properties and pass them to a handler through request-local data. //! //! For an example of extracting a client TLS certificate, see: -//! +//! use std::{any::Any, io, net::SocketAddr}; -use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; +use actix_web::{ + dev::Extensions, rt::net::TcpStream, web, App, HttpRequest, HttpResponse, HttpServer, + Responder, +}; +#[allow(dead_code)] #[derive(Debug, Clone)] struct ConnectionInfo { bind: SocketAddr, @@ -15,11 +19,16 @@ struct ConnectionInfo { ttl: Option, } -async fn route_whoami(conn_info: web::ReqData) -> String { - format!( - "Here is some info about your connection:\n\n{:#?}", - conn_info - ) +async fn route_whoami(req: HttpRequest) -> impl Responder { + match req.conn_data::() { + Some(info) => HttpResponse::Ok().body(format!( + "Here is some info about your connection:\n\n{:#?}", + info + )), + None => { + HttpResponse::InternalServerError().body("Missing expected request extension data") + } + } } fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { @@ -38,9 +47,12 @@ fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { async fn main() -> io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); + let bind = ("127.0.0.1", 8080); + log::info!("staring server at http://{}:{}", &bind.0, &bind.1); + HttpServer::new(|| App::new().default_service(web::to(route_whoami))) .on_connect(get_conn_info) - .bind(("127.0.0.1", 8080))? + .bind(bind)? .workers(1) .run() .await diff --git a/examples/uds.rs b/examples/uds.rs index 096781984..cf0ffebde 100644 --- a/examples/uds.rs +++ b/examples/uds.rs @@ -26,15 +26,15 @@ async fn main() -> std::io::Result<()> { HttpServer::new(|| { App::new() - .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) + .wrap(middleware::DefaultHeaders::new().add(("X-Version", "0.2"))) .wrap(middleware::Compress::default()) .wrap(middleware::Logger::default()) .service(index) .service(no_params) .service( web::resource("/resource2/index.html") - .wrap(middleware::DefaultHeaders::new().header("X-Version-R2", "0.3")) - .default_service(web::route().to(|| HttpResponse::MethodNotAllowed())) + .wrap(middleware::DefaultHeaders::new().add(("X-Version-R2", "0.3"))) + .default_service(web::route().to(HttpResponse::MethodNotAllowed)) .route(web::get().to(index_async)), ) .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) diff --git a/scripts/bump b/scripts/bump new file mode 100755 index 000000000..c43b92dc8 --- /dev/null +++ b/scripts/bump @@ -0,0 +1,157 @@ +#!/bin/sh + +# developed on macOS and probably doesn't work on Linux yet due to minor +# differences in flags on sed + +# requires github cli tool for automatic release draft creation + +set -euo pipefail + +DIR=$1 + +LINUX="" +MACOS="" + +if [ "$(uname)" = "Darwin" ]; then + MACOS="1" +fi + +CARGO_MANIFEST=$DIR/Cargo.toml +README_FILE=$DIR/README.md + +# determine changelog file name +if [ -f "$DIR/CHANGES.md" ]; then + CHANGELOG_FILE=$DIR/CHANGES.md +elif [ -f "$DIR/CHANGELOG.md" ]; then + CHANGELOG_FILE=$DIR/CHANGELOG.md +else + echo "No changelog file found" + exit 1 +fi + +# get current version +PACKAGE_NAME="$(sed -nE 's/^name ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST" | head -n 1)" +CURRENT_VERSION="$(sed -nE 's/^version ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST")" + +CHANGE_CHUNK_FILE="$(mktemp)" +echo saving changelog to $CHANGE_CHUNK_FILE +echo + +# get changelog chunk and save to temp file +cat "$CHANGELOG_FILE" | + # skip up to unreleased heading + sed '1,/Unreleased/ d' | + # take up to previous version heading + sed "/$CURRENT_VERSION/ q" | + # drop last line + sed '$d' \ + >"$CHANGE_CHUNK_FILE" + +# if word count of changelog chunk is 0 then insert filler changelog chunk +if [ "$(wc -w "$CHANGE_CHUNK_FILE" | awk '{ print $1 }')" = "0" ]; then + echo "- No significant changes since \`$CURRENT_VERSION\`." >"$CHANGE_CHUNK_FILE" + echo >>"$CHANGE_CHUNK_FILE" + echo >>"$CHANGE_CHUNK_FILE" +fi + +if [ -n "${2-}" ]; then + NEW_VERSION="$2" +else + echo + echo "--- Changes since $CURRENT_VERSION ----" + cat "$CHANGE_CHUNK_FILE" + echo + read -p "Update version to: " NEW_VERSION +fi + +# strip leading v from input +if [ "${NEW_VERSION:0:1}" = "v" ]; then + NEW_VERSION="${NEW_VERSION:1}" +fi + +DATE="$(date -u +"%Y-%m-%d")" +echo "updating from $CURRENT_VERSION => $NEW_VERSION ($DATE)" + +# update package.version field +sed -i.bak -E "s/^version ?= ?\"[^\"]+\"$/version = \"$NEW_VERSION\"/" "$CARGO_MANIFEST" + +# update readme +[ -f "$README_FILE" ] && sed -i.bak -E "s#$CURRENT_VERSION([/)])#$NEW_VERSION\1#g" "$README_FILE" + +# update changelog file +( + sed '/Unreleased/ q' "$CHANGELOG_FILE" # up to unreleased heading + echo # blank line + echo # blank line + echo "## $NEW_VERSION - $DATE" # new version heading + cat "$CHANGE_CHUNK_FILE" # previously unreleased changes + sed "/$CURRENT_VERSION/ q" "$CHANGELOG_FILE" | tail -n 1 # the previous version heading + sed "1,/$CURRENT_VERSION/ d" "$CHANGELOG_FILE" # everything after previous version heading +) >"$CHANGELOG_FILE.bak" +mv "$CHANGELOG_FILE.bak" "$CHANGELOG_FILE" + +# done; remove backup files +rm -f $CARGO_MANIFEST.bak +rm -f $CHANGELOG_FILE.bak +rm -f $README_FILE.bak + +echo "manifest, changelog, and readme updated" +echo +echo "check other references:" +rg --glob='**/Cargo.toml' "\ +${PACKAGE_NAME} ?= ?\"[^\"]+\"\ +|${PACKAGE_NAME} ?=.*version ?= ?\"([^\"]+)\"\ +|package ?= ?\"${PACKAGE_NAME}\".*version ?= ?\"([^\"]+)\"\ +|version ?= ?\"([^\"]+)\".*package ?= ?\"${PACKAGE_NAME}\"" || true + +echo +read -p "Update all references: (y/N) " UPDATE_REFERENCES +UPDATE_REFERENCES="${UPDATE_REFERENCES:-n}" + +if [ "$UPDATE_REFERENCES" = 'y' ] || [ "$UPDATE_REFERENCES" = 'Y' ]; then + + for f in $(fd Cargo.toml); do + sed -i.bak -E \ + "s/^(${PACKAGE_NAME} ?= ?\")[^\"]+(\")$/\1${NEW_VERSION}\2/g" $f + sed -i.bak -E \ + "s/^(${PACKAGE_NAME} ?=.*version ?= ?\")[^\"]+(\".*)$/\1${NEW_VERSION}\2/g" $f + sed -i.bak -E \ + "s/^(.*package ?= ?\"${PACKAGE_NAME}\".*version ?= ?\")[^\"]+(\".*)$/\1${NEW_VERSION}\2/g" $f + sed -i.bak -E \ + "s/^(.*version ?= ?\")[^\"]+(\".*package ?= ?\"${PACKAGE_NAME}\".*)$/\1${NEW_VERSION}\2/g" $f + + # remove backup file + rm -f $f.bak + done + +fi + +if [ $MACOS ]; then + printf "prepare $PACKAGE_NAME release $NEW_VERSION" | pbcopy +else + echo + echo "commit message:" + echo "prepare $PACKAGE_NAME release $NEW_VERSION" +fi + +SHORT_PACKAGE_NAME="$(echo $PACKAGE_NAME | sed 's/^actix-web-//' | sed 's/^actix-//')" +GIT_TAG="$(echo $SHORT_PACKAGE_NAME-v$NEW_VERSION)" +RELEASE_TITLE="$(echo $PACKAGE_NAME: v$NEW_VERSION)" + +if [ "$(echo $NEW_VERSION | grep beta)" ] || [ "$(echo $NEW_VERSION | grep rc)" ] || [ "$(echo $NEW_VERSION | grep alpha)" ]; then + PRERELEASE="--prerelease" +fi + +echo +echo "GitHub release command:" +GH_CMD="gh release create \"$GIT_TAG\" --draft --title \"$RELEASE_TITLE\" --notes-file \"$CHANGE_CHUNK_FILE\" ${PRERELEASE:-}" +echo "$GH_CMD" + +read -p "Submit draft GH release: (y/N) " GH_RELEASE +GH_RELEASE="${GH_RELEASE:-n}" + +if [ "$GH_RELEASE" = 'y' ] || [ "$GH_RELEASE" = 'Y' ]; then + eval "$GH_CMD" +fi + +echo diff --git a/scripts/ci-test b/scripts/ci-test new file mode 100755 index 000000000..567012d33 --- /dev/null +++ b/scripts/ci-test @@ -0,0 +1,28 @@ +#!/bin/sh + +# run tests matching what CI does for non-linux feature sets + +set -x + +EXIT=0 + +save_exit_code() { + eval $@ + local CMD_EXIT=$? + [ "$CMD_EXIT" = "0" ] || EXIT=$CMD_EXIT +} + +save_exit_code cargo test --lib --tests -p=actix-router --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-http --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-web --features=rustls,openssl -- --nocapture --skip=test_reading_deflate_encoding_large_random_rustls +save_exit_code cargo test --lib --tests -p=actix-web-codegen --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=awc --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-http-test --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-test --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-files -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-multipart --all-features -- --nocapture +save_exit_code cargo test --lib --tests -p=actix-web-actors --all-features -- --nocapture + +save_exit_code cargo test --workspace --doc + +exit $EXIT diff --git a/scripts/unreleased b/scripts/unreleased new file mode 100755 index 000000000..4dfa2d9ae --- /dev/null +++ b/scripts/unreleased @@ -0,0 +1,41 @@ +#!/bin/sh + +set -euo pipefail + +bold="\033[1m" +reset="\033[0m" + +unreleased_for() { + DIR=$1 + + CARGO_MANIFEST=$DIR/Cargo.toml + CHANGELOG_FILE=$DIR/CHANGES.md + + # get current version + PACKAGE_NAME="$(sed -nE 's/^name ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST" | head -n 1)" + CURRENT_VERSION="$(sed -nE 's/^version ?= ?"([^"]+)"$/\1/ p' "$CARGO_MANIFEST")" + + CHANGE_CHUNK_FILE="$(mktemp)" + + # get changelog chunk and save to temp file + cat "$CHANGELOG_FILE" | + # skip up to unreleased heading + sed '1,/Unreleased/ d' | + # take up to previous version heading + sed "/$CURRENT_VERSION/ q" | + # drop last line + sed '$d' \ + >"$CHANGE_CHUNK_FILE" + + # if word count of changelog chunk is 0 then exit + if [ "$(wc -w "$CHANGE_CHUNK_FILE" | awk '{ print $1 }')" = "0" ]; then + return 0; + fi + + echo "${bold}# ${PACKAGE_NAME}${reset} since ${bold}v$CURRENT_VERSION${reset}" + cat "$CHANGE_CHUNK_FILE" +} + +for f in $(fd --absolute-path CHANGES.md); do + unreleased_for $(dirname $f) +done diff --git a/src/app.rs b/src/app.rs index 7a26a3a89..3fddc055b 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,110 +1,132 @@ -use std::cell::RefCell; -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::rc::Rc; +use std::{cell::RefCell, fmt, future::Future, rc::Rc}; -use actix_http::body::{Body, MessageBody}; -use actix_http::{Extensions, Request}; -use actix_service::boxed::{self, BoxServiceFactory}; +use actix_http::{body::MessageBody, Extensions, Request}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, ServiceFactoryExt, Transform, + apply, apply_fn_factory, boxed, IntoServiceFactory, ServiceFactory, ServiceFactoryExt, + Transform, }; -use futures_util::future::FutureExt; +use futures_util::future::FutureExt as _; -use crate::app_service::{AppEntry, AppInit, AppRoutingFactory}; -use crate::config::ServiceConfig; -use crate::data::{Data, DataFactory, FnDataFactory}; -use crate::dev::ResourceDef; -use crate::error::Error; -use crate::resource::Resource; -use crate::route::Route; -use crate::service::{ - AppServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, ServiceRequest, - ServiceResponse, +use crate::{ + app_service::{AppEntry, AppInit, AppRoutingFactory}, + config::ServiceConfig, + data::{Data, DataFactory, FnDataFactory}, + dev::ResourceDef, + error::Error, + resource::Resource, + route::Route, + service::{ + AppServiceFactory, BoxedHttpServiceFactory, HttpServiceFactory, ServiceFactoryWrapper, + ServiceRequest, ServiceResponse, + }, }; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; - /// Application builder - structure that follows the builder pattern /// for building application instances. -pub struct App { +pub struct App { endpoint: T, services: Vec>, - default: Option>, + default: Option>, factory_ref: Rc>>, data_factories: Vec, external: Vec, extensions: Extensions, - _phantom: PhantomData, } -impl App { +impl App { /// Create application builder. Application can be configured with a builder-like pattern. #[allow(clippy::new_without_default)] pub fn new() -> Self { - let fref = Rc::new(RefCell::new(None)); + let factory_ref = Rc::new(RefCell::new(None)); + App { - endpoint: AppEntry::new(fref.clone()), + endpoint: AppEntry::new(factory_ref.clone()), data_factories: Vec::new(), services: Vec::new(), default: None, - factory_ref: fref, + factory_ref, external: Vec::new(), extensions: Extensions::new(), - _phantom: PhantomData, } } } -impl App -where - B: MessageBody, - T: ServiceFactory< - ServiceRequest, - Config = (), - Response = ServiceResponse, - Error = Error, - InitError = (), - >, -{ - /// Set application data. Application data could be accessed - /// by using `Data` extractor where `T` is data type. +impl App { + /// Set application (root level) data. /// - /// **Note**: HTTP server accepts an application factory rather than - /// an application instance. Http server constructs an application - /// instance for each thread, thus application data must be constructed - /// multiple times. If you want to share data between different - /// threads, a shared object should be used, e.g. `Arc`. Internally `Data` type - /// uses `Arc` so data could be created outside of app factory and clones could - /// be stored via `App::app_data()` method. + /// Application data stored with `App::app_data()` method is available through the + /// [`HttpRequest::app_data`](crate::HttpRequest::app_data) method at runtime. /// - /// ```rust + /// # [`Data`] + /// Any [`Data`] type added here can utilize it's extractor implementation in handlers. + /// Types not wrapped in `Data` cannot use this extractor. See [its docs](Data) for more + /// about its usage and patterns. + /// + /// ``` /// use std::cell::Cell; - /// use actix_web::{web, App, HttpResponse, Responder}; + /// use actix_web::{web, App, HttpRequest, HttpResponse, Responder}; /// /// struct MyData { - /// counter: Cell, + /// count: std::cell::Cell, /// } /// - /// async fn index(data: web::Data) -> impl Responder { - /// data.counter.set(data.counter.get() + 1); - /// HttpResponse::Ok() + /// async fn handler(req: HttpRequest, counter: web::Data) -> impl Responder { + /// // note this cannot use the Data extractor because it was not added with it + /// let incr = *req.app_data::().unwrap(); + /// assert_eq!(incr, 3); + /// + /// // update counter using other value from app data + /// counter.count.set(counter.count.get() + incr); + /// + /// HttpResponse::Ok().body(counter.count.get().to_string()) /// } /// - /// let app = App::new() - /// .data(MyData{ counter: Cell::new(0) }) - /// .service( - /// web::resource("/index.html").route( - /// web::get().to(index))); + /// let app = App::new().service( + /// web::resource("/") + /// .app_data(3usize) + /// .app_data(web::Data::new(MyData { count: Default::default() })) + /// .route(web::get().to(handler)) + /// ); /// ``` + /// + /// # Shared Mutable State + /// [`HttpServer::new`](crate::HttpServer::new) accepts an application factory rather than an + /// application instance; the factory closure is called on each worker thread independently. + /// Therefore, if you want to share a data object between different workers, a shareable object + /// needs to be created first, outside the `HttpServer::new` closure and cloned into it. + /// [`Data`] is an example of such a sharable object. + /// + /// ```ignore + /// let counter = web::Data::new(AppStateWithCounter { + /// counter: Mutex::new(0), + /// }); + /// + /// HttpServer::new(move || { + /// // move counter object into the closure and clone for each worker + /// + /// App::new() + /// .app_data(counter.clone()) + /// .route("/", web::get().to(handler)) + /// }) + /// ``` + #[doc(alias = "manage")] + pub fn app_data(mut self, ext: U) -> Self { + self.extensions.insert(ext); + self + } + + /// Add application (root) data after wrapping in `Data`. + /// + /// Deprecated in favor of [`app_data`](Self::app_data). + #[deprecated(since = "4.0.0", note = "Use `.app_data(Data::new(val))` instead.")] pub fn data(self, data: U) -> Self { self.app_data(Data::new(data)) } - /// Set application data factory. This function is - /// similar to `.data()` but it accepts data factory. Data object get - /// constructed asynchronously during application initialization. + /// Add application data factory that resolves asynchronously. + /// + /// Data items are constructed during application initialization, before the server starts + /// accepting requests. pub fn data_factory(mut self, data: F) -> Self where F: Fn() -> Out + 'static, @@ -130,18 +152,7 @@ where } .boxed_local() })); - self - } - /// Set application level arbitrary data item. - /// - /// Application data stored with `App::app_data()` method is available - /// via `HttpRequest::app_data()` method at runtime. - /// - /// This method could be used for storing `Data` as well, in that case - /// data could be accessed by using `Data` extractor. - pub fn app_data(mut self, ext: U) -> Self { - self.extensions.insert(ext); self } @@ -152,7 +163,7 @@ where /// different module or even library. For example, /// some of the resource's configuration could be moved to different module. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpResponse}; /// /// // this function could be located in different module @@ -185,18 +196,16 @@ where /// This method can be used multiple times with same path, in that case /// multiple resources with one route would be registered for same resource path. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpResponse}; /// /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new() - /// .route("/test1", web::get().to(index)) - /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())); - /// } + /// let app = App::new() + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())); /// ``` pub fn route(self, path: &str, mut route: Route) -> Self { self.service( @@ -228,37 +237,33 @@ where /// /// It is possible to use services like `Resource`, `Route`. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpResponse}; /// /// async fn index() -> &'static str { /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new() - /// .service( - /// web::resource("/index.html").route(web::get().to(index))) - /// .default_service( - /// web::route().to(|| HttpResponse::NotFound())); - /// } + /// let app = App::new() + /// .service( + /// web::resource("/index.html").route(web::get().to(index))) + /// .default_service( + /// web::route().to(|| HttpResponse::NotFound())); /// ``` /// /// It is also possible to use static files as default service. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpResponse}; /// - /// fn main() { - /// let app = App::new() - /// .service( - /// web::resource("/index.html").to(|| HttpResponse::Ok())) - /// .default_service( - /// web::to(|| HttpResponse::NotFound()) - /// ); - /// } + /// let app = App::new() + /// .service( + /// web::resource("/index.html").to(|| HttpResponse::Ok())) + /// .default_service( + /// web::to(|| HttpResponse::NotFound()) + /// ); /// ``` - pub fn default_service(mut self, f: F) -> Self + pub fn default_service(mut self, svc: F) -> Self where F: IntoServiceFactory, U: ServiceFactory< @@ -269,10 +274,12 @@ where > + 'static, U::InitError: fmt::Debug, { - // create and configure default resource - self.default = Some(Rc::new(boxed::factory(f.into_factory().map_init_err( - |e| log::error!("Can not construct default service: {:?}", e), - )))); + let svc = svc + .into_factory() + .map(|res| res.map_into_boxed_body()) + .map_init_err(|e| log::error!("Can not construct default service: {:?}", e)); + + self.default = Some(Rc::new(boxed::factory(svc))); self } @@ -283,7 +290,7 @@ where /// and are never considered for matching at request time. Calls to /// `HttpRequest::url_for()` will work as expected. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpRequest, HttpResponse, Result}; /// /// async fn index(req: HttpRequest) -> Result { @@ -305,7 +312,7 @@ where U: AsRef, { let mut rdef = ResourceDef::new(url.as_ref()); - *rdef.name_mut() = name.as_ref().to_string(); + rdef.set_name(name.as_ref()); self.external.push(rdef); self } @@ -325,10 +332,10 @@ where /// the builder chain. Consequently, the *first* middleware registered /// in the builder chain is the *last* to execute during request processing. /// - /// ```rust + /// ``` /// use actix_service::Service; /// use actix_web::{middleware, web, App}; - /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// use actix_web::http::header::{CONTENT_TYPE, HeaderValue}; /// /// async fn index() -> &'static str { /// "Welcome!" @@ -340,7 +347,7 @@ where /// .route("/index.html", web::get().to(index)); /// } /// ``` - pub fn wrap( + pub fn wrap( self, mw: M, ) -> App< @@ -351,9 +358,16 @@ where Error = Error, InitError = (), >, - B1, > where + T: ServiceFactory< + ServiceRequest, + Response = ServiceResponse, + Error = Error, + Config = (), + InitError = (), + >, + B: MessageBody, M: Transform< T::Service, ServiceRequest, @@ -371,7 +385,6 @@ where factory_ref: self.factory_ref, external: self.external, extensions: self.extensions, - _phantom: PhantomData, } } @@ -382,10 +395,10 @@ where /// /// Use middleware when you need to read or modify *every* request or response in some way. /// - /// ```rust + /// ``` /// use actix_service::Service; /// use actix_web::{web, App}; - /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// use actix_web::http::header::{CONTENT_TYPE, HeaderValue}; /// /// async fn index() -> &'static str { /// "Welcome!" @@ -406,7 +419,7 @@ where /// .route("/index.html", web::get().to(index)); /// } /// ``` - pub fn wrap_fn( + pub fn wrap_fn( self, mw: F, ) -> App< @@ -417,12 +430,19 @@ where Error = Error, InitError = (), >, - B1, > where - B1: MessageBody, + T: ServiceFactory< + ServiceRequest, + Response = ServiceResponse, + Error = Error, + Config = (), + InitError = (), + >, + B: MessageBody, F: Fn(ServiceRequest, &T::Service) -> R + Clone, R: Future, Error>>, + B1: MessageBody, { App { endpoint: apply_fn_factory(self.endpoint, mw), @@ -432,12 +452,11 @@ where factory_ref: self.factory_ref, external: self.external, extensions: self.extensions, - _phantom: PhantomData, } } } -impl IntoServiceFactory, Request> for App +impl IntoServiceFactory, Request> for App where B: MessageBody, T: ServiceFactory< @@ -464,16 +483,21 @@ where #[cfg(test)] mod tests { - use actix_service::Service; + use actix_service::Service as _; + use actix_utils::future::{err, ok}; use bytes::Bytes; - use futures_util::future::{err, ok}; use super::*; - use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::middleware::DefaultHeaders; - use crate::service::ServiceRequest; - use crate::test::{call_service, init_service, read_body, try_init_service, TestRequest}; - use crate::{web, HttpRequest, HttpResponse}; + use crate::{ + http::{ + header::{self, HeaderValue}, + Method, StatusCode, + }, + middleware::DefaultHeaders, + service::ServiceRequest, + test::{call_service, init_service, read_body, try_init_service, TestRequest}, + web, HttpRequest, HttpResponse, + }; #[actix_rt::test] async fn test_default_resource() { @@ -518,6 +542,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::CREATED); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_data_factory() { let srv = init_service( @@ -541,6 +567,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_data_factory_errors() { let srv = try_init_service( @@ -573,7 +601,7 @@ mod tests { App::new() .wrap( DefaultHeaders::new() - .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + .add((header::CONTENT_TYPE, HeaderValue::from_static("0001"))), ) .route("/test", web::get().to(HttpResponse::Ok)), ) @@ -594,7 +622,7 @@ mod tests { .route("/test", web::get().to(HttpResponse::Ok)) .wrap( DefaultHeaders::new() - .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + .add((header::CONTENT_TYPE, HeaderValue::from_static("0001"))), ), ) .await; @@ -677,4 +705,25 @@ mod tests { let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"https://youtube.com/watch/12345")); } + + #[test] + fn can_be_returned_from_fn() { + /// compile-only test for returning app type from function + pub fn my_app() -> App< + impl ServiceFactory< + ServiceRequest, + Response = ServiceResponse, + Config = (), + InitError = (), + Error = Error, + >, + > { + App::new() + // logger can be removed without affecting the return type + .wrap(crate::middleware::Logger::default()) + .route("/", web::to(|| async { "hello" })) + } + + let _ = init_service(my_app()); + } } diff --git a/src/app_service.rs b/src/app_service.rs index 9b4ae3354..56b24f0d8 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,27 +1,30 @@ -use std::cell::RefCell; -use std::rc::Rc; -use std::task::Poll; +use std::{cell::RefCell, mem, rc::Rc}; -use actix_http::{Extensions, Request, Response}; +use actix_http::Request; use actix_router::{Path, ResourceDef, Router, Url}; -use actix_service::boxed::{self, BoxService, BoxServiceFactory}; -use actix_service::{fn_service, Service, ServiceFactory}; +use actix_service::{boxed, fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; -use crate::config::{AppConfig, AppService}; -use crate::data::FnDataFactory; -use crate::error::Error; -use crate::guard::Guard; -use crate::request::{HttpRequest, HttpRequestPool}; -use crate::rmap::ResourceMap; -use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; +use crate::{ + body::BoxBody, + config::{AppConfig, AppService}, + data::FnDataFactory, + dev::Extensions, + guard::Guard, + request::{HttpRequest, HttpRequestPool}, + rmap::ResourceMap, + service::{ + AppServiceFactory, BoxedHttpService, BoxedHttpServiceFactory, ServiceRequest, + ServiceResponse, + }, + Error, HttpResponse, +}; type Guards = Vec>; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; /// Service factory to convert `Request` to a `ServiceRequest`. +/// /// It also executes data factories. pub struct AppInit where @@ -37,7 +40,7 @@ where pub(crate) extensions: RefCell>, pub(crate) async_data_factories: Rc<[FnDataFactory]>, pub(crate) services: Rc>>>, - pub(crate) default: Option>, + pub(crate) default: Option>, pub(crate) factory_ref: Rc>>, pub(crate) external: RefCell>, } @@ -65,7 +68,7 @@ where // if no user defined default service exists. let default = self.default.clone().unwrap_or_else(|| { Rc::new(boxed::factory(fn_service(|req: ServiceRequest| async { - Ok(req.into_response(Response::NotFound().finish())) + Ok(req.into_response(HttpResponse::NotFound())) }))) }); @@ -73,11 +76,11 @@ where let mut config = AppService::new(config, default.clone()); // register services - std::mem::take(&mut *self.services.borrow_mut()) + mem::take(&mut *self.services.borrow_mut()) .into_iter() .for_each(|mut srv| srv.register(&mut config)); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); let (config, services) = config.into_services(); @@ -96,13 +99,13 @@ where }); // external resources - for mut rdef in std::mem::take(&mut *self.external.borrow_mut()) { + for mut rdef in mem::take(&mut *self.external.borrow_mut()) { rmap.add(&mut rdef, None); } // complete ResourceMap tree creation let rmap = Rc::new(rmap); - rmap.finish(rmap.clone()); + ResourceMap::finish(&rmap); // construct all async data factory futures let factory_futs = join_all(self.async_data_factories.iter().map(|f| f())); @@ -129,9 +132,9 @@ where let service = endpoint_fut.await?; // populate app data container from (async) data factories. - async_data_factories.iter().for_each(|factory| { + for factory in &async_data_factories { factory.create(&mut app_data); - }); + } Ok(AppInitService { service, @@ -142,7 +145,9 @@ where } } -/// Service that takes a [`Request`] and delegates to a service that take a [`ServiceRequest`]. +/// The [`Service`] that is passed to `actix-http`'s server builder. +/// +/// Wraps a service receiving a [`ServiceRequest`] into one receiving a [`Request`]. pub struct AppInitService where T: Service, Error = Error>, @@ -164,8 +169,7 @@ impl AppInitServiceState { Rc::new(AppInitServiceState { rmap, config, - // TODO: AppConfig can be used to pass user defined HttpRequestPool - // capacity. + // TODO: AppConfig can be used to pass user defined HttpRequestPool capacity. pool: HttpRequestPool::default(), }) } @@ -196,7 +200,9 @@ where actix_service::forward_ready!(service); - fn call(&self, req: Request) -> Self::Future { + fn call(&self, mut req: Request) -> Self::Future { + let req_data = Rc::new(RefCell::new(req.take_req_data())); + let conn_data = req.take_conn_data(); let (head, payload) = req.into_parts(); let req = if let Some(mut req) = self.app_state.pool().pop() { @@ -204,6 +210,8 @@ where inner.path.get_mut().update(&head.uri); inner.path.reset(); inner.head = head; + inner.conn_data = conn_data; + inner.req_data = req_data; req } else { HttpRequest::new( @@ -211,8 +219,11 @@ where head, self.app_state.clone(), self.app_data.clone(), + conn_data, + req_data, ) }; + self.service.call(ServiceRequest::new(req, payload)) } } @@ -227,8 +238,15 @@ where } pub struct AppRoutingFactory { - services: Rc<[(ResourceDef, HttpNewService, RefCell>)]>, - default: Rc, + #[allow(clippy::type_complexity)] + services: Rc< + [( + ResourceDef, + BoxedHttpServiceFactory, + RefCell>, + )], + >, + default: Rc, } impl ServiceFactory for AppRoutingFactory { @@ -274,27 +292,31 @@ impl ServiceFactory for AppRoutingFactory { } } +/// The Actix Web router default entry point. pub struct AppRouting { - router: Router, - default: HttpService, + router: Router, + default: BoxedHttpService, } impl Service for AppRouting { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; actix_service::always_ready!(); fn call(&self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_checked(&mut req, |req, guards| { + let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); @@ -348,6 +370,8 @@ mod tests { } } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_drop_data() { let data = Arc::new(AtomicBool::new(false)); diff --git a/src/config.rs b/src/config.rs index 1d8513cdc..6c0344453 100644 --- a/src/config.rs +++ b/src/config.rs @@ -24,6 +24,7 @@ pub struct AppService { config: AppConfig, root: bool, default: Rc, + #[allow(clippy::type_complexity)] services: Vec<( ResourceDef, HttpNewService, @@ -48,6 +49,7 @@ impl AppService { self.root } + #[allow(clippy::type_complexity)] pub(crate) fn into_services( self, ) -> ( @@ -62,6 +64,8 @@ impl AppService { (self.config, self.services) } + /// Clones inner config and default service, returning new `AppService` with empty service list + /// marked as non-root. pub(crate) fn clone_config(&self) -> Self { AppService { config: self.config.clone(), @@ -71,12 +75,12 @@ impl AppService { } } - /// Service configuration + /// Returns reference to configuration. pub fn config(&self) -> &AppConfig { &self.config } - /// Default resource + /// Returns default handler factory. pub fn default_service(&self) -> Rc { self.default.clone() } @@ -92,9 +96,9 @@ impl AppService { F: IntoServiceFactory, S: ServiceFactory< ServiceRequest, - Config = (), Response = ServiceResponse, Error = Error, + Config = (), InitError = (), > + 'static, { @@ -103,8 +107,8 @@ impl AppService { } } -/// Application connection config -#[derive(Clone)] +/// Application connection config. +#[derive(Debug, Clone)] pub struct AppConfig { secure: bool, host: String, @@ -112,13 +116,19 @@ pub struct AppConfig { } impl AppConfig { - pub(crate) fn new(secure: bool, addr: SocketAddr, host: String) -> Self { - AppConfig { secure, addr, host } + pub(crate) fn new(secure: bool, host: String, addr: SocketAddr) -> Self { + AppConfig { secure, host, addr } + } + + /// Needed in actix-test crate. Semver exempt. + #[doc(hidden)] + pub fn __priv_test_new(secure: bool, host: String, addr: SocketAddr) -> Self { + AppConfig::new(secure, host, addr) } /// Server host name. /// - /// Host name is used by application router as a hostname for url generation. + /// Host name is used by application router as a hostname for URL generation. /// Check [ConnectionInfo](super::dev::ConnectionInfo::host()) /// documentation for more information. /// @@ -127,7 +137,7 @@ impl AppConfig { &self.host } - /// Returns true if connection is secure(https) + /// Returns true if connection is secure (i.e., running over `https:`). pub fn secure(&self) -> bool { self.secure } @@ -136,14 +146,19 @@ impl AppConfig { pub fn local_addr(&self) -> SocketAddr { self.addr } + + #[cfg(test)] + pub(crate) fn set_host(&mut self, host: &str) { + self.host = host.to_owned(); + } } impl Default for AppConfig { fn default() -> Self { AppConfig::new( false, - "127.0.0.1:8080".parse().unwrap(), "localhost:8080".to_owned(), + "127.0.0.1:8080".parse().unwrap(), ) } } @@ -186,6 +201,7 @@ impl ServiceConfig { /// Add shared app data item. /// /// Counterpart to [`App::data()`](crate::App::data). + #[deprecated(since = "4.0.0", note = "Use `.app_data(Data::new(val))` instead.")] pub fn data(&mut self, data: U) -> &mut Self { self.app_data(Data::new(data)); self @@ -246,7 +262,7 @@ impl ServiceConfig { U: AsRef, { let mut rdef = ResourceDef::new(url.as_ref()); - *rdef.name_mut() = name.as_ref().to_string(); + rdef.set_name(name.as_ref()); self.external.push(rdef); self } @@ -262,6 +278,8 @@ mod tests { use crate::test::{call_service, init_service, read_body, TestRequest}; use crate::{web, App, HttpRequest, HttpResponse}; + // allow deprecated `ServiceConfig::data` + #[allow(deprecated)] #[actix_rt::test] async fn test_data() { let cfg = |cfg: &mut ServiceConfig| { diff --git a/src/data.rs b/src/data.rs index 133248212..ce7b1fee6 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,14 +1,14 @@ -use std::any::type_name; -use std::ops::Deref; -use std::sync::Arc; +use std::{any::type_name, ops::Deref, sync::Arc}; -use actix_http::error::{Error, ErrorInternalServerError}; use actix_http::Extensions; -use futures_util::future::{err, ok, LocalBoxFuture, Ready}; +use actix_utils::future::{err, ok, Ready}; +use futures_core::future::LocalBoxFuture; +use serde::Serialize; -use crate::dev::Payload; -use crate::extract::FromRequest; -use crate::request::HttpRequest; +use crate::{ + dev::Payload, error::ErrorInternalServerError, extract::FromRequest, request::HttpRequest, + Error, +}; /// Data factory. pub(crate) trait DataFactory { @@ -19,49 +19,76 @@ pub(crate) trait DataFactory { pub(crate) type FnDataFactory = Box LocalBoxFuture<'static, Result, ()>>>; -/// Application data. +/// Application data wrapper and extractor. /// -/// Application level data is a piece of arbitrary data attached to the app, scope, or resource. -/// Application data is available to all routes and can be added during the application -/// configuration process via `App::data()`. +/// # Setting Data +/// Data is set using the `app_data` methods on `App`, `Scope`, and `Resource`. If data is wrapped +/// in this `Data` type for those calls, it can be used as an extractor. /// -/// Application data can be accessed by using `Data` extractor where `T` is data type. +/// Note that `Data` should be constructed _outside_ the `HttpServer::new` closure if shared, +/// potentially mutable state is desired. `Data` is cheap to clone; internally, it uses an `Arc`. /// -/// **Note**: HTTP server accepts an application factory rather than an application instance. HTTP -/// server constructs an application instance for each thread, thus application data must be -/// constructed multiple times. If you want to share data between different threads, a shareable -/// object should be used, e.g. `Send + Sync`. Application data does not need to be `Send` -/// or `Sync`. Internally `Data` uses `Arc`. +/// See also [`App::app_data`](crate::App::app_data), [`Scope::app_data`](crate::Scope::app_data), +/// and [`Resource::app_data`](crate::Resource::app_data). /// -/// If route data is not set for a handler, using `Data` extractor would cause *Internal -/// Server Error* response. +/// # Extracting `Data` +/// Since the Actix Web router layers application data, the returned object will reference the +/// "closest" instance of the type. For example, if an `App` stores a `u32`, a nested `Scope` +/// also stores a `u32`, and the delegated request handler falls within that `Scope`, then +/// extracting a `web::>` for that handler will return the `Scope`'s instance. +/// However, using the same router set up and a request that does not get captured by the `Scope`, +/// `web::>` would return the `App`'s instance. /// -/// ```rust +/// If route data is not set for a handler, using `Data` extractor would cause a `500 Internal +/// Server Error` response. +/// +/// See also [`HttpRequest::app_data`] +/// and [`ServiceRequest::app_data`](crate::dev::ServiceRequest::app_data). +/// +/// # Unsized Data +/// For types that are unsized, most commonly `dyn T`, `Data` can wrap these types by first +/// constructing an `Arc` and using the `From` implementation to convert it. +/// +/// ``` +/// # use std::{fmt::Display, sync::Arc}; +/// # use actix_web::web::Data; +/// let displayable_arc: Arc = Arc::new(42usize); +/// let displayable_data: Data = Data::from(displayable_arc); +/// ``` +/// +/// # Examples +/// ``` /// use std::sync::Mutex; -/// use actix_web::{web, App, HttpResponse, Responder}; +/// use actix_web::{App, HttpRequest, HttpResponse, Responder, web::{self, Data}}; /// /// struct MyData { /// counter: usize, /// } /// /// /// Use the `Data` extractor to access data in a handler. -/// async fn index(data: web::Data>) -> impl Responder { -/// let mut data = data.lock().unwrap(); -/// data.counter += 1; +/// async fn index(data: Data>) -> impl Responder { +/// let mut my_data = data.lock().unwrap(); +/// my_data.counter += 1; /// HttpResponse::Ok() /// } /// -/// fn main() { -/// let data = web::Data::new(Mutex::new(MyData{ counter: 0 })); -/// -/// let app = App::new() -/// // Store `MyData` in application storage. -/// .app_data(data.clone()) -/// .service( -/// web::resource("/index.html").route( -/// web::get().to(index))); +/// /// Alteratively, use the `HttpRequest::app_data` method to access data in a handler. +/// async fn index_alt(req: HttpRequest) -> impl Responder { +/// let data = req.app_data::>>().unwrap(); +/// let mut my_data = data.lock().unwrap(); +/// my_data.counter += 1; +/// HttpResponse::Ok() /// } +/// +/// let data = Data::new(Mutex::new(MyData { counter: 0 })); +/// +/// let app = App::new() +/// // Store `MyData` in application storage. +/// .app_data(Data::clone(&data)) +/// .route("/index.html", web::get().to(index)) +/// .route("/index-alt.html", web::get().to(index_alt)); /// ``` +#[doc(alias = "state")] #[derive(Debug)] pub struct Data(Arc); @@ -70,13 +97,15 @@ impl Data { pub fn new(state: T) -> Data { Data(Arc::new(state)) } +} - /// Get reference to inner app data. +impl Data { + /// Returns reference to inner `T`. pub fn get_ref(&self) -> &T { self.0.as_ref() } - /// Convert to the internal Arc + /// Unwraps to the internal `Arc` pub fn into_inner(self) -> Arc { self.0 } @@ -102,8 +131,19 @@ impl From> for Data { } } +impl Serialize for Data +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) + } +} + impl FromRequest for Data { - type Config = (); type Error = Error; type Future = Ready>; @@ -113,13 +153,16 @@ impl FromRequest for Data { ok(st.clone()) } else { log::debug!( - "Failed to construct App-level Data extractor. \ - Request path: {:?} (type: {})", - req.path(), + "Failed to extract `Data<{}>` for `{}` handler. For the Data extractor to work \ + correctly, wrap the data with `Data::new()` and pass it to `App::app_data()`. \ + Ensure that types align in both the set and retrieve calls.", type_name::(), + req.match_name().unwrap_or_else(|| req.path()) ); + err(ErrorInternalServerError( - "App data is not configured, to configure use App::data()", + "Requested application data is not configured correctly. \ + View/enable debug logs for more details.", )) } } @@ -134,14 +177,16 @@ impl DataFactory for Data { #[cfg(test)] mod tests { - use actix_service::Service; - use std::sync::atomic::{AtomicUsize, Ordering}; - use super::*; - use crate::http::StatusCode; - use crate::test::{self, init_service, TestRequest}; - use crate::{web, App, HttpResponse}; + use crate::{ + dev::Service, + http::StatusCode, + test::{init_service, TestRequest}, + web, App, HttpResponse, + }; + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_data_extractor() { let srv = init_service(App::new().data("TEST".to_string()).service( @@ -209,6 +254,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_route_data_extractor() { let srv = init_service( @@ -238,6 +285,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_override_data() { let srv = @@ -256,54 +305,11 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); } - #[actix_rt::test] - async fn test_data_drop() { - struct TestData(Arc); - - impl TestData { - fn new(inner: Arc) -> Self { - let _ = inner.fetch_add(1, Ordering::SeqCst); - Self(inner) - } - } - - impl Clone for TestData { - fn clone(&self) -> Self { - let inner = self.0.clone(); - let _ = inner.fetch_add(1, Ordering::SeqCst); - Self(inner) - } - } - - impl Drop for TestData { - fn drop(&mut self) { - let _ = self.0.fetch_sub(1, Ordering::SeqCst); - } - } - - let num = Arc::new(AtomicUsize::new(0)); - let data = TestData::new(num.clone()); - assert_eq!(num.load(Ordering::SeqCst), 1); - - let srv = test::start(move || { - let data = data.clone(); - - App::new() - .data(data) - .service(web::resource("/").to(|_data: Data| async { "ok" })) - }); - - assert!(srv.get("/").send().await.unwrap().status().is_success()); - srv.stop().await; - - assert_eq!(num.load(Ordering::SeqCst), 0); - } - #[actix_rt::test] async fn test_data_from_arc() { let data_new = Data::new(String::from("test-123")); let data_from_arc = Data::from(Arc::new(String::from("test-123"))); - assert_eq!(data_new.0, data_from_arc.0) + assert_eq!(data_new.0, data_from_arc.0); } #[actix_rt::test] @@ -325,4 +331,38 @@ mod tests { let data_arc = Data::from(dyn_arc); assert_eq!(data_arc_box.get_num(), data_arc.get_num()) } + + #[actix_rt::test] + async fn test_dyn_data_into_arc() { + trait TestTrait { + fn get_num(&self) -> i32; + } + struct A {} + impl TestTrait for A { + fn get_num(&self) -> i32 { + 42 + } + } + let dyn_arc: Arc = Arc::new(A {}); + let data_arc = Data::from(dyn_arc); + let arc_from_data = data_arc.clone().into_inner(); + assert_eq!(data_arc.get_num(), arc_from_data.get_num()) + } + + #[actix_rt::test] + async fn test_get_ref_from_dyn_data() { + trait TestTrait { + fn get_num(&self) -> i32; + } + struct A {} + impl TestTrait for A { + fn get_num(&self) -> i32 { + 42 + } + } + let dyn_arc: Arc = Arc::new(A {}); + let data_arc = Data::from(dyn_arc); + let ref_data = data_arc.get_ref(); + assert_eq!(data_arc.get_num(), ref_data.get_num()) + } } diff --git a/src/dev.rs b/src/dev.rs new file mode 100644 index 000000000..def545ec7 --- /dev/null +++ b/src/dev.rs @@ -0,0 +1,44 @@ +//! Lower-level types and re-exports. +//! +//! Most users will not have to interact with the types in this module, but it is useful for those +//! writing extractors, middleware, libraries, or interacting with the service API directly. + +pub use actix_http::{Extensions, Payload, RequestHead, Response, ResponseHead}; +pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; +pub use actix_server::{Server, ServerHandle}; +pub use actix_service::{ + always_ready, fn_factory, fn_service, forward_ready, Service, ServiceFactory, Transform, +}; + +#[cfg(feature = "__compress")] +pub use actix_http::encoding::Decoder as Decompress; + +pub use crate::config::{AppConfig, AppService}; +#[doc(hidden)] +pub use crate::handler::Handler; +pub use crate::info::{ConnectionInfo, PeerAddr}; +pub use crate::rmap::ResourceMap; +pub use crate::service::{HttpServiceFactory, ServiceRequest, ServiceResponse, WebService}; + +pub use crate::types::{JsonBody, Readlines, UrlEncoded}; + +use actix_router::Patterns; + +pub(crate) fn ensure_leading_slash(mut patterns: Patterns) -> Patterns { + match &mut patterns { + Patterns::Single(pat) => { + if !pat.is_empty() && !pat.starts_with('/') { + pat.insert(0, '/'); + }; + } + Patterns::List(pats) => { + for pat in pats { + if !pat.is_empty() && !pat.starts_with('/') { + pat.insert(0, '/'); + }; + } + } + } + + patterns +} diff --git a/src/error/error.rs b/src/error/error.rs new file mode 100644 index 000000000..be17c1962 --- /dev/null +++ b/src/error/error.rs @@ -0,0 +1,76 @@ +use std::{error::Error as StdError, fmt}; + +use actix_http::{body::BoxBody, Response}; + +use crate::{HttpResponse, ResponseError}; + +/// General purpose actix web error. +/// +/// An actix web error is used to carry errors from `std::error` +/// through actix in a convenient way. It can be created through +/// converting errors with `into()`. +/// +/// Whenever it is created from an external object a response error is created +/// for it that can be used to create an HTTP response from it this means that +/// if you have access to an actix `Error` you can always get a +/// `ResponseError` reference from it. +pub struct Error { + cause: Box, +} + +impl Error { + /// Returns the reference to the underlying `ResponseError`. + pub fn as_response_error(&self) -> &dyn ResponseError { + self.cause.as_ref() + } + + /// Similar to `as_response_error` but downcasts. + pub fn as_error(&self) -> Option<&T> { + ::downcast_ref(self.cause.as_ref()) + } + + /// Shortcut for creating an `HttpResponse`. + pub fn error_response(&self) -> HttpResponse { + self.cause.error_response() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", &self.cause) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + // TODO: populate if replacement for Box is found + None + } +} + +impl From for Error { + fn from(val: std::convert::Infallible) -> Self { + match val {} + } +} + +/// `Error` for any error that implements `ResponseError` +impl From for Error { + fn from(err: T) -> Error { + Error { + cause: Box::new(err), + } + } +} + +impl From for Response { + fn from(err: Error) -> Response { + err.error_response().into() + } +} diff --git a/src/error/internal.rs b/src/error/internal.rs new file mode 100644 index 000000000..1c6e343e3 --- /dev/null +++ b/src/error/internal.rs @@ -0,0 +1,316 @@ +use std::{cell::RefCell, fmt, io::Write as _}; + +use actix_http::{ + body::BoxBody, + header::{self, TryIntoHeaderValue as _}, + StatusCode, +}; +use bytes::{BufMut as _, BytesMut}; + +use crate::{Error, HttpRequest, HttpResponse, Responder, ResponseError}; + +/// Wraps errors to alter the generated response status code. +/// +/// In following example, the `io::Error` is wrapped into `ErrorBadRequest` which will generate a +/// response with the 400 Bad Request status code instead of the usual status code generated by +/// an `io::Error`. +/// +/// # Examples +/// ``` +/// # use std::io; +/// # use actix_web::{error, HttpRequest}; +/// async fn handler_error() -> Result { +/// let err = io::Error::new(io::ErrorKind::Other, "error"); +/// Err(error::ErrorBadRequest(err)) +/// } +/// ``` +pub struct InternalError { + cause: T, + status: InternalErrorType, +} + +enum InternalErrorType { + Status(StatusCode), + Response(RefCell>), +} + +impl InternalError { + /// Constructs an `InternalError` with given status code. + pub fn new(cause: T, status: StatusCode) -> Self { + InternalError { + cause, + status: InternalErrorType::Status(status), + } + } + + /// Constructs an `InternalError` with pre-defined response. + pub fn from_response(cause: T, response: HttpResponse) -> Self { + InternalError { + cause, + status: InternalErrorType::Response(RefCell::new(Some(response))), + } + } +} + +impl fmt::Debug for InternalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.cause.fmt(f) + } +} + +impl fmt::Display for InternalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.cause.fmt(f) + } +} + +impl ResponseError for InternalError +where + T: fmt::Debug + fmt::Display, +{ + fn status_code(&self) -> StatusCode { + match self.status { + InternalErrorType::Status(st) => st, + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow().as_ref() { + resp.head().status + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + } + } + + fn error_response(&self) -> HttpResponse { + match self.status { + InternalErrorType::Status(status) => { + let mut res = HttpResponse::new(status); + let mut buf = BytesMut::new().writer(); + let _ = write!(buf, "{}", self); + + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + + res.set_body(BoxBody::new(buf.into_inner())) + } + + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow_mut().take() { + resp + } else { + HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + } +} + +impl Responder for InternalError +where + T: fmt::Debug + fmt::Display + 'static, +{ + type Body = BoxBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::from_error(self) + } +} + +macro_rules! error_helper { + ($name:ident, $status:ident) => { + #[doc = concat!("Helper function that wraps any error and generates a `", stringify!($status), "` response.")] + #[allow(non_snake_case)] + pub fn $name(err: T) -> Error + where + T: fmt::Debug + fmt::Display + 'static, + { + InternalError::new(err, StatusCode::$status).into() + } + }; +} + +error_helper!(ErrorBadRequest, BAD_REQUEST); +error_helper!(ErrorUnauthorized, UNAUTHORIZED); +error_helper!(ErrorPaymentRequired, PAYMENT_REQUIRED); +error_helper!(ErrorForbidden, FORBIDDEN); +error_helper!(ErrorNotFound, NOT_FOUND); +error_helper!(ErrorMethodNotAllowed, METHOD_NOT_ALLOWED); +error_helper!(ErrorNotAcceptable, NOT_ACCEPTABLE); +error_helper!( + ErrorProxyAuthenticationRequired, + PROXY_AUTHENTICATION_REQUIRED +); +error_helper!(ErrorRequestTimeout, REQUEST_TIMEOUT); +error_helper!(ErrorConflict, CONFLICT); +error_helper!(ErrorGone, GONE); +error_helper!(ErrorLengthRequired, LENGTH_REQUIRED); +error_helper!(ErrorPayloadTooLarge, PAYLOAD_TOO_LARGE); +error_helper!(ErrorUriTooLong, URI_TOO_LONG); +error_helper!(ErrorUnsupportedMediaType, UNSUPPORTED_MEDIA_TYPE); +error_helper!(ErrorRangeNotSatisfiable, RANGE_NOT_SATISFIABLE); +error_helper!(ErrorImATeapot, IM_A_TEAPOT); +error_helper!(ErrorMisdirectedRequest, MISDIRECTED_REQUEST); +error_helper!(ErrorUnprocessableEntity, UNPROCESSABLE_ENTITY); +error_helper!(ErrorLocked, LOCKED); +error_helper!(ErrorFailedDependency, FAILED_DEPENDENCY); +error_helper!(ErrorUpgradeRequired, UPGRADE_REQUIRED); +error_helper!(ErrorPreconditionFailed, PRECONDITION_FAILED); +error_helper!(ErrorPreconditionRequired, PRECONDITION_REQUIRED); +error_helper!(ErrorTooManyRequests, TOO_MANY_REQUESTS); +error_helper!( + ErrorRequestHeaderFieldsTooLarge, + REQUEST_HEADER_FIELDS_TOO_LARGE +); +error_helper!( + ErrorUnavailableForLegalReasons, + UNAVAILABLE_FOR_LEGAL_REASONS +); +error_helper!(ErrorExpectationFailed, EXPECTATION_FAILED); +error_helper!(ErrorInternalServerError, INTERNAL_SERVER_ERROR); +error_helper!(ErrorNotImplemented, NOT_IMPLEMENTED); +error_helper!(ErrorBadGateway, BAD_GATEWAY); +error_helper!(ErrorServiceUnavailable, SERVICE_UNAVAILABLE); +error_helper!(ErrorGatewayTimeout, GATEWAY_TIMEOUT); +error_helper!(ErrorHttpVersionNotSupported, HTTP_VERSION_NOT_SUPPORTED); +error_helper!(ErrorVariantAlsoNegotiates, VARIANT_ALSO_NEGOTIATES); +error_helper!(ErrorInsufficientStorage, INSUFFICIENT_STORAGE); +error_helper!(ErrorLoopDetected, LOOP_DETECTED); +error_helper!(ErrorNotExtended, NOT_EXTENDED); +error_helper!( + ErrorNetworkAuthenticationRequired, + NETWORK_AUTHENTICATION_REQUIRED +); + +#[cfg(test)] +mod tests { + use actix_http::error::ParseError; + + use super::*; + + #[test] + fn test_internal_error() { + let err = InternalError::from_response(ParseError::Method, HttpResponse::Ok().finish()); + let resp: HttpResponse = err.error_response(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_error_helpers() { + let res: HttpResponse = ErrorBadRequest("err").into(); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + + let res: HttpResponse = ErrorUnauthorized("err").into(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let res: HttpResponse = ErrorPaymentRequired("err").into(); + assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED); + + let res: HttpResponse = ErrorForbidden("err").into(); + assert_eq!(res.status(), StatusCode::FORBIDDEN); + + let res: HttpResponse = ErrorNotFound("err").into(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res: HttpResponse = ErrorMethodNotAllowed("err").into(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res: HttpResponse = ErrorNotAcceptable("err").into(); + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + + let res: HttpResponse = ErrorProxyAuthenticationRequired("err").into(); + assert_eq!(res.status(), StatusCode::PROXY_AUTHENTICATION_REQUIRED); + + let res: HttpResponse = ErrorRequestTimeout("err").into(); + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + + let res: HttpResponse = ErrorConflict("err").into(); + assert_eq!(res.status(), StatusCode::CONFLICT); + + let res: HttpResponse = ErrorGone("err").into(); + assert_eq!(res.status(), StatusCode::GONE); + + let res: HttpResponse = ErrorLengthRequired("err").into(); + assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); + + let res: HttpResponse = ErrorPreconditionFailed("err").into(); + assert_eq!(res.status(), StatusCode::PRECONDITION_FAILED); + + let res: HttpResponse = ErrorPayloadTooLarge("err").into(); + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); + + let res: HttpResponse = ErrorUriTooLong("err").into(); + assert_eq!(res.status(), StatusCode::URI_TOO_LONG); + + let res: HttpResponse = ErrorUnsupportedMediaType("err").into(); + assert_eq!(res.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + + let res: HttpResponse = ErrorRangeNotSatisfiable("err").into(); + assert_eq!(res.status(), StatusCode::RANGE_NOT_SATISFIABLE); + + let res: HttpResponse = ErrorExpectationFailed("err").into(); + assert_eq!(res.status(), StatusCode::EXPECTATION_FAILED); + + let res: HttpResponse = ErrorImATeapot("err").into(); + assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); + + let res: HttpResponse = ErrorMisdirectedRequest("err").into(); + assert_eq!(res.status(), StatusCode::MISDIRECTED_REQUEST); + + let res: HttpResponse = ErrorUnprocessableEntity("err").into(); + assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let res: HttpResponse = ErrorLocked("err").into(); + assert_eq!(res.status(), StatusCode::LOCKED); + + let res: HttpResponse = ErrorFailedDependency("err").into(); + assert_eq!(res.status(), StatusCode::FAILED_DEPENDENCY); + + let res: HttpResponse = ErrorUpgradeRequired("err").into(); + assert_eq!(res.status(), StatusCode::UPGRADE_REQUIRED); + + let res: HttpResponse = ErrorPreconditionRequired("err").into(); + assert_eq!(res.status(), StatusCode::PRECONDITION_REQUIRED); + + let res: HttpResponse = ErrorTooManyRequests("err").into(); + assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS); + + let res: HttpResponse = ErrorRequestHeaderFieldsTooLarge("err").into(); + assert_eq!(res.status(), StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); + + let res: HttpResponse = ErrorUnavailableForLegalReasons("err").into(); + assert_eq!(res.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); + + let res: HttpResponse = ErrorInternalServerError("err").into(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let res: HttpResponse = ErrorNotImplemented("err").into(); + assert_eq!(res.status(), StatusCode::NOT_IMPLEMENTED); + + let res: HttpResponse = ErrorBadGateway("err").into(); + assert_eq!(res.status(), StatusCode::BAD_GATEWAY); + + let res: HttpResponse = ErrorServiceUnavailable("err").into(); + assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE); + + let res: HttpResponse = ErrorGatewayTimeout("err").into(); + assert_eq!(res.status(), StatusCode::GATEWAY_TIMEOUT); + + let res: HttpResponse = ErrorHttpVersionNotSupported("err").into(); + assert_eq!(res.status(), StatusCode::HTTP_VERSION_NOT_SUPPORTED); + + let res: HttpResponse = ErrorVariantAlsoNegotiates("err").into(); + assert_eq!(res.status(), StatusCode::VARIANT_ALSO_NEGOTIATES); + + let res: HttpResponse = ErrorInsufficientStorage("err").into(); + assert_eq!(res.status(), StatusCode::INSUFFICIENT_STORAGE); + + let res: HttpResponse = ErrorLoopDetected("err").into(); + assert_eq!(res.status(), StatusCode::LOOP_DETECTED); + + let res: HttpResponse = ErrorNotExtended("err").into(); + assert_eq!(res.status(), StatusCode::NOT_EXTENDED); + + let res: HttpResponse = ErrorNetworkAuthenticationRequired("err").into(); + assert_eq!(res.status(), StatusCode::NETWORK_AUTHENTICATION_REQUIRED); + } +} diff --git a/actix-http/src/macros.rs b/src/error/macros.rs similarity index 74% rename from actix-http/src/macros.rs rename to src/error/macros.rs index 8973aa39b..78b1ed9f6 100644 --- a/actix-http/src/macros.rs +++ b/src/error/macros.rs @@ -1,4 +1,3 @@ -#[macro_export] macro_rules! downcast_get_type_id { () => { /// A helper method to get the type ID of the type @@ -14,8 +13,13 @@ macro_rules! downcast_get_type_id { /// making it impossible for safe code to construct outside of /// this module. This ensures that safe code cannot violate /// type-safety by implementing this method. + /// + /// We also take `PrivateHelper` as a parameter, to ensure that + /// safe code cannot obtain a `PrivateHelper` instance by + /// delegating to an existing implementation of `__private_get_type_id__` #[doc(hidden)] - fn __private_get_type_id__(&self) -> (std::any::TypeId, PrivateHelper) + #[allow(dead_code)] + fn __private_get_type_id__(&self, _: PrivateHelper) -> (std::any::TypeId, PrivateHelper) where Self: 'static, { @@ -24,20 +28,23 @@ macro_rules! downcast_get_type_id { }; } -//Generate implementation for dyn $name -#[macro_export] -macro_rules! downcast { +// Generate implementation for dyn $name +macro_rules! downcast_dyn { ($name:ident) => { /// A struct with a private constructor, for use with /// `__private_get_type_id__`. Its single field is private, /// ensuring that it can only be constructed from this module #[doc(hidden)] + #[allow(dead_code)] pub struct PrivateHelper(()); impl dyn $name + 'static { /// Downcasts generic body to a specific type. + #[allow(dead_code)] pub fn downcast_ref(&self) -> Option<&T> { - if self.__private_get_type_id__().0 == std::any::TypeId::of::() { + if self.__private_get_type_id__(PrivateHelper(())).0 + == std::any::TypeId::of::() + { // SAFETY: external crates cannot override the default // implementation of `__private_get_type_id__`, since // it requires returning a private type. We can therefore @@ -50,16 +57,17 @@ macro_rules! downcast { } /// Downcasts a generic body to a mutable specific type. + #[allow(dead_code)] pub fn downcast_mut(&mut self) -> Option<&mut T> { - if self.__private_get_type_id__().0 == std::any::TypeId::of::() { + if self.__private_get_type_id__(PrivateHelper(())).0 + == std::any::TypeId::of::() + { // SAFETY: external crates cannot override the default // implementation of `__private_get_type_id__`, since // it requires returning a private type. We can therefore // rely on the returned `TypeId`, which ensures that this // case is correct. - unsafe { - Some(&mut *(self as *const dyn $name as *const T as *mut T)) - } + unsafe { Some(&mut *(self as *const dyn $name as *const T as *mut T)) } } else { None } @@ -68,14 +76,17 @@ macro_rules! downcast { }; } +pub(crate) use {downcast_dyn, downcast_get_type_id}; + #[cfg(test)] mod tests { + #![allow(clippy::upper_case_acronyms)] trait MB { downcast_get_type_id!(); } - downcast!(MB); + downcast_dyn!(MB); impl MB for String {} impl MB for () {} @@ -86,7 +97,7 @@ mod tests { let resp_body: &mut dyn MB = &mut body; let body = resp_body.downcast_ref::().unwrap(); assert_eq!(body, "hello cast"); - let body = &mut resp_body.downcast_mut::().unwrap(); + let body = resp_body.downcast_mut::().unwrap(); body.push('!'); let body = resp_body.downcast_ref::().unwrap(); assert_eq!(body, "hello cast!"); diff --git a/src/error.rs b/src/error/mod.rs similarity index 53% rename from src/error.rs rename to src/error/mod.rs index 0865257d3..64df9f553 100644 --- a/src/error.rs +++ b/src/error/mod.rs @@ -1,35 +1,60 @@ //! Error and Result module -pub use actix_http::error::*; +// This is meant to be a glob import of the whole error module except for `Error`. Rustdoc can't yet +// correctly resolve the conflicting `Error` type defined in this module, so these re-exports are +// expanded manually. +// +// See +pub use actix_http::error::{ + BlockingError, ContentTypeError, DispatchError, HttpError, ParseError, PayloadError, +}; + use derive_more::{Display, Error, From}; use serde_json::error::Error as JsonError; +use serde_urlencoded::de::Error as FormDeError; +use serde_urlencoded::ser::Error as FormError; use url::ParseError as UrlParseError; -use crate::{http::StatusCode, HttpResponse}; +use crate::http::StatusCode; + +#[allow(clippy::module_inception)] +mod error; +mod internal; +mod macros; +mod response_error; + +pub use self::error::Error; +pub use self::internal::*; +pub use self::response_error::ResponseError; +pub(crate) use macros::{downcast_dyn, downcast_get_type_id}; + +/// A convenience [`Result`](std::result::Result) for Actix Web operations. +/// +/// This type alias is generally used to avoid writing out `actix_http::Error` directly. +pub type Result = std::result::Result; /// Errors which can occur when attempting to generate resource uri. -#[derive(Debug, PartialEq, Display, From)] +#[derive(Debug, PartialEq, Display, Error, From)] +#[non_exhaustive] pub enum UrlGenerationError { - /// Resource not found + /// Resource not found. #[display(fmt = "Resource not found")] ResourceNotFound, - /// Not all path pattern covered - #[display(fmt = "Not all path pattern covered")] + /// Not all URL parameters covered. + #[display(fmt = "Not all URL parameters covered")] NotEnoughElements, - /// URL parse error + /// URL parse error. #[display(fmt = "{}", _0)] ParseError(UrlParseError), } -impl std::error::Error for UrlGenerationError {} - -/// `InternalServerError` for `UrlGeneratorError` impl ResponseError for UrlGenerationError {} /// A set of errors that can occur during parsing urlencoded payloads #[derive(Debug, Display, Error, From)] +#[non_exhaustive] pub enum UrlencodedError { /// Can not decode chunked transfer encoding. #[display(fmt = "Can not decode chunked transfer encoding.")] @@ -52,64 +77,96 @@ pub enum UrlencodedError { ContentType, /// Parse error. - #[display(fmt = "Parse error.")] - Parse, + #[display(fmt = "Parse error: {}.", _0)] + Parse(FormDeError), + + /// Encoding error. + #[display(fmt = "Encoding error.")] + Encoding, + + /// Serialize error. + #[display(fmt = "Serialize error: {}.", _0)] + Serialize(FormError), /// Payload error. #[display(fmt = "Error that occur during reading payload: {}.", _0)] Payload(PayloadError), } -/// Return `BadRequest` for `UrlencodedError` impl ResponseError for UrlencodedError { fn status_code(&self) -> StatusCode { - match *self { - UrlencodedError::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE, - UrlencodedError::UnknownLength => StatusCode::LENGTH_REQUIRED, + match self { + Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE, + Self::UnknownLength => StatusCode::LENGTH_REQUIRED, + Self::Payload(err) => err.status_code(), _ => StatusCode::BAD_REQUEST, } } } /// A set of errors that can occur during parsing json payloads -#[derive(Debug, Display, From)] +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum JsonPayloadError { - /// Payload size is bigger than allowed. (default: 32kB) - #[display(fmt = "Json payload size is bigger than allowed")] - Overflow, + /// Payload size is bigger than allowed & content length header set. (default: 2MB) + #[display( + fmt = "JSON payload ({} bytes) is larger than allowed (limit: {} bytes).", + length, + limit + )] + OverflowKnownLength { length: usize, limit: usize }, + + /// Payload size is bigger than allowed but no content length header set. (default: 2MB) + #[display(fmt = "JSON payload has exceeded limit ({} bytes).", limit)] + Overflow { limit: usize }, + /// Content type error #[display(fmt = "Content type error")] ContentType, + /// Deserialize error #[display(fmt = "Json deserialize error: {}", _0)] Deserialize(JsonError), + + /// Serialize error + #[display(fmt = "Json serialize error: {}", _0)] + Serialize(JsonError), + /// Payload error #[display(fmt = "Error that occur during reading payload: {}", _0)] Payload(PayloadError), } -impl std::error::Error for JsonPayloadError {} +impl From for JsonPayloadError { + fn from(err: PayloadError) -> Self { + Self::Payload(err) + } +} -/// Return `BadRequest` for `JsonPayloadError` impl ResponseError for JsonPayloadError { - fn error_response(&self) -> HttpResponse { - match *self { - JsonPayloadError::Overflow => HttpResponse::new(StatusCode::PAYLOAD_TOO_LARGE), - _ => HttpResponse::new(StatusCode::BAD_REQUEST), + fn status_code(&self) -> StatusCode { + match self { + Self::OverflowKnownLength { + length: _, + limit: _, + } => StatusCode::PAYLOAD_TOO_LARGE, + Self::Overflow { limit: _ } => StatusCode::PAYLOAD_TOO_LARGE, + Self::Serialize(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Payload(err) => err.status_code(), + _ => StatusCode::BAD_REQUEST, } } } /// A set of errors that can occur during parsing request paths -#[derive(Debug, Display, From)] +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum PathError { /// Deserialize error #[display(fmt = "Path deserialize error: {}", _0)] Deserialize(serde::de::value::Error), } -impl std::error::Error for PathError {} - /// Return `BadRequest` for `PathError` impl ResponseError for PathError { fn status_code(&self) -> StatusCode { @@ -119,13 +176,13 @@ impl ResponseError for PathError { /// A set of errors that can occur during parsing query strings. #[derive(Debug, Display, Error, From)] +#[non_exhaustive] pub enum QueryPayloadError { /// Query deserialize error. #[display(fmt = "Query deserialize error: {}", _0)] Deserialize(serde::de::value::Error), } -/// Return `BadRequest` for `QueryPayloadError` impl ResponseError for QueryPayloadError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST @@ -133,26 +190,26 @@ impl ResponseError for QueryPayloadError { } /// Error type returned when reading body as lines. -#[derive(From, Display, Debug)] +#[derive(Debug, Display, Error, From)] +#[non_exhaustive] pub enum ReadlinesError { - /// Error when decoding a line. #[display(fmt = "Encoding error")] /// Payload size is bigger than allowed. (default: 256kB) EncodingError, + /// Payload error. #[display(fmt = "Error that occur during reading payload: {}", _0)] Payload(PayloadError), + /// Line limit exceeded. #[display(fmt = "Line limit exceeded")] LimitOverflow, + /// ContentType error. #[display(fmt = "Content-type error")] ContentTypeError(ContentTypeError), } -impl std::error::Error for ReadlinesError {} - -/// Return `BadRequest` for `ReadlinesError` impl ResponseError for ReadlinesError { fn status_code(&self) -> StatusCode { match *self { @@ -168,26 +225,31 @@ mod tests { #[test] fn test_urlencoded_error() { - let resp: HttpResponse = - UrlencodedError::Overflow { size: 0, limit: 0 }.error_response(); + let resp = UrlencodedError::Overflow { size: 0, limit: 0 }.error_response(); assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); - let resp: HttpResponse = UrlencodedError::UnknownLength.error_response(); + let resp = UrlencodedError::UnknownLength.error_response(); assert_eq!(resp.status(), StatusCode::LENGTH_REQUIRED); - let resp: HttpResponse = UrlencodedError::ContentType.error_response(); + let resp = UrlencodedError::ContentType.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } #[test] fn test_json_payload_error() { - let resp: HttpResponse = JsonPayloadError::Overflow.error_response(); + let resp = JsonPayloadError::OverflowKnownLength { + length: 0, + limit: 0, + } + .error_response(); assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); - let resp: HttpResponse = JsonPayloadError::ContentType.error_response(); + let resp = JsonPayloadError::Overflow { limit: 0 }.error_response(); + assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); + let resp = JsonPayloadError::ContentType.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } #[test] fn test_query_payload_error() { - let resp: HttpResponse = QueryPayloadError::Deserialize( + let resp = QueryPayloadError::Deserialize( serde_urlencoded::from_str::("bad query").unwrap_err(), ) .error_response(); @@ -196,9 +258,9 @@ mod tests { #[test] fn test_readlines_error() { - let resp: HttpResponse = ReadlinesError::LimitOverflow.error_response(); + let resp = ReadlinesError::LimitOverflow.error_response(); assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); - let resp: HttpResponse = ReadlinesError::EncodingError.error_response(); + let resp = ReadlinesError::EncodingError.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/src/error/response_error.rs b/src/error/response_error.rs new file mode 100644 index 000000000..e0b4af44c --- /dev/null +++ b/src/error/response_error.rs @@ -0,0 +1,152 @@ +//! `ResponseError` trait and foreign impls. + +use std::{ + error::Error as StdError, + fmt, + io::{self, Write as _}, +}; + +use actix_http::{ + body::BoxBody, + header::{self, TryIntoHeaderValue}, + Response, StatusCode, +}; +use bytes::BytesMut; + +use crate::{ + error::{downcast_dyn, downcast_get_type_id}, + helpers, HttpResponse, +}; + +/// Errors that can generate responses. +// TODO: add std::error::Error bound when replacement for Box is found +pub trait ResponseError: fmt::Debug + fmt::Display { + /// Returns appropriate status code for error. + /// + /// A 500 Internal Server Error is used by default. If [error_response](Self::error_response) is + /// also implemented and does not call `self.status_code()`, then this will not be used. + fn status_code(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } + + /// Creates full response for error. + /// + /// By default, the generated response uses a 500 Internal Server Error status code, a + /// `Content-Type` of `text/plain`, and the body is set to `Self`'s `Display` impl. + fn error_response(&self) -> HttpResponse { + let mut res = HttpResponse::new(self.status_code()); + + let mut buf = BytesMut::new(); + let _ = write!(helpers::MutWriter(&mut buf), "{}", self); + + let mime = mime::TEXT_PLAIN_UTF_8.try_into_value().unwrap(); + res.headers_mut().insert(header::CONTENT_TYPE, mime); + + res.set_body(BoxBody::new(buf)) + } + + downcast_get_type_id!(); +} + +downcast_dyn!(ResponseError); + +impl ResponseError for Box {} + +#[cfg(feature = "openssl")] +impl ResponseError for actix_tls::accept::openssl::reexports::Error {} + +impl ResponseError for serde::de::value::Error { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl ResponseError for serde_json::Error {} + +impl ResponseError for serde_urlencoded::ser::Error {} + +impl ResponseError for std::str::Utf8Error { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl ResponseError for std::io::Error { + fn status_code(&self) -> StatusCode { + // TODO: decide if these errors should consider not found or permission errors + match self.kind() { + io::ErrorKind::NotFound => StatusCode::NOT_FOUND, + io::ErrorKind::PermissionDenied => StatusCode::FORBIDDEN, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl ResponseError for actix_http::error::HttpError {} + +impl ResponseError for actix_http::Error { + fn status_code(&self) -> StatusCode { + // TODO: map error kinds to status code better + StatusCode::INTERNAL_SERVER_ERROR + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::with_body(self.status_code(), self.to_string()).map_into_boxed_body() + } +} + +impl ResponseError for actix_http::header::InvalidHeaderValue { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl ResponseError for actix_http::error::ParseError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl ResponseError for actix_http::error::BlockingError {} + +impl ResponseError for actix_http::error::PayloadError { + fn status_code(&self) -> StatusCode { + match *self { + actix_http::error::PayloadError::Overflow => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, + } + } +} + +impl ResponseError for actix_http::ws::ProtocolError {} + +impl ResponseError for actix_http::error::ContentTypeError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +impl ResponseError for actix_http::ws::HandshakeError { + fn error_response(&self) -> HttpResponse { + Response::from(self).map_into_boxed_body().into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_casting() { + use actix_http::error::{ContentTypeError, PayloadError}; + + let err = PayloadError::Overflow; + let resp_err: &dyn ResponseError = &err; + + let err = resp_err.downcast_ref::().unwrap(); + assert_eq!(err.to_string(), "Payload reached size limit."); + + let not_err = resp_err.downcast_ref::(); + assert!(not_err.is_none()); + } +} diff --git a/src/extract.rs b/src/extract.rs index 7a677bca4..f74a0a54e 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,25 +1,65 @@ //! Request extractors use std::{ + convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, }; -use futures_util::{ - future::{ready, Ready}, - ready, -}; +use actix_http::{Method, Uri}; +use actix_utils::future::{ok, Ready}; +use futures_core::ready; +use pin_project_lite::pin_project; use crate::{dev::Payload, Error, HttpRequest}; -/// Trait implemented by types that can be extracted from request. +/// A type that implements [`FromRequest`] is called an **extractor** and can extract data from +/// the request. Some types that implement this trait are: [`Json`], [`Header`], and [`Path`]. /// -/// Types that implement this trait can be used with `Route` handlers. +/// # Configuration +/// An extractor can be customized by injecting the corresponding configuration with one of: +/// +/// - [`App::app_data()`][crate::App::app_data] +/// - [`Scope::app_data()`][crate::Scope::app_data] +/// - [`Resource::app_data()`][crate::Resource::app_data] +/// +/// Here are some built-in extractors and their corresponding configuration. +/// Please refer to the respective documentation for details. +/// +/// | Extractor | Configuration | +/// |-------------|-------------------| +/// | [`Header`] | _None_ | +/// | [`Path`] | [`PathConfig`] | +/// | [`Json`] | [`JsonConfig`] | +/// | [`Form`] | [`FormConfig`] | +/// | [`Query`] | [`QueryConfig`] | +/// | [`Bytes`] | [`PayloadConfig`] | +/// | [`String`] | [`PayloadConfig`] | +/// | [`Payload`] | [`PayloadConfig`] | +/// +/// # Implementing An Extractor +/// To reduce duplicate code in handlers where extracting certain parts of a request has a common +/// structure, you can implement `FromRequest` for your own types. +/// +/// Note that the request payload can only be consumed by one extractor. +/// +/// [`Header`]: crate::web::Header +/// [`Json`]: crate::web::Json +/// [`JsonConfig`]: crate::web::JsonConfig +/// [`Form`]: crate::web::Form +/// [`FormConfig`]: crate::web::FormConfig +/// [`Path`]: crate::web::Path +/// [`PathConfig`]: crate::web::PathConfig +/// [`Query`]: crate::web::Query +/// [`QueryConfig`]: crate::web::QueryConfig +/// [`Payload`]: crate::web::Payload +/// [`PayloadConfig`]: crate::web::PayloadConfig +/// [`String`]: FromRequest#impl-FromRequest-for-String +/// [`Bytes`]: crate::web::Bytes#impl-FromRequest +/// [`Either`]: crate::web::Either +#[doc(alias = "extract", alias = "extractor")] pub trait FromRequest: Sized { - /// Configuration for this extractor. - type Config: Default + 'static; - /// The associated error which can be returned. type Error: Into; @@ -35,27 +75,18 @@ pub trait FromRequest: Sized { fn extract(req: &HttpRequest) -> Self::Future { Self::from_request(req, &mut Payload::None) } - - /// Create and configure config instance. - fn configure(f: F) -> Self::Config - where - F: FnOnce(Self::Config) -> Self::Config, - { - f(Self::Config::default()) - } } /// Optionally extract a field from the request /// /// If the FromRequest for T fails, return None rather than returning an error response /// -/// ## Example -/// -/// ```rust +/// # Examples +/// ``` /// use actix_web::{web, dev, App, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; /// use futures_util::future::{ok, err, Ready}; -/// use serde_derive::Deserialize; +/// use serde::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -66,7 +97,6 @@ pub trait FromRequest: Sized { /// impl FromRequest for Thing { /// type Error = Error; /// type Future = Ready>; -/// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { @@ -101,7 +131,6 @@ where { type Error = Error; type Future = FromRequestOptFuture; - type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -111,10 +140,11 @@ where } } -#[pin_project::pin_project] -pub struct FromRequestOptFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct FromRequestOptFuture { + #[pin] + fut: Fut, + } } impl Future for FromRequestOptFuture @@ -141,13 +171,12 @@ where /// /// If the `FromRequest` for T fails, inject Err into handler rather than returning an error response /// -/// ## Example -/// -/// ```rust +/// # Examples +/// ``` /// use actix_web::{web, dev, App, Result, Error, HttpRequest, FromRequest}; /// use actix_web::error::ErrorBadRequest; /// use futures_util::future::{ok, err, Ready}; -/// use serde_derive::Deserialize; +/// use serde::Deserialize; /// use rand; /// /// #[derive(Debug, Deserialize)] @@ -158,7 +187,6 @@ where /// impl FromRequest for Thing { /// type Error = Error; /// type Future = Ready>; -/// type Config = (); /// /// fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { /// if rand::random() { @@ -191,7 +219,6 @@ where { type Error = Error; type Future = FromRequestResFuture; - type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { @@ -201,10 +228,11 @@ where } } -#[pin_project::pin_project] -pub struct FromRequestResFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct FromRequestResFuture { + #[pin] + fut: Fut, + } } impl Future for FromRequestResFuture @@ -220,121 +248,163 @@ where } } -#[doc(hidden)] -impl FromRequest for () { - type Error = Error; - type Future = Ready>; - type Config = (); +/// Extract the request's URI. +/// +/// # Examples +/// ``` +/// use actix_web::{http::Uri, web, App, Responder}; +/// +/// async fn handler(uri: Uri) -> impl Responder { +/// format!("Requested path: {}", uri.path()) +/// } +/// +/// let app = App::new().default_service(web::to(handler)); +/// ``` +impl FromRequest for Uri { + type Error = Infallible; + type Future = Ready>; - fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { - ready(Ok(())) + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(req.uri().clone()) } } -macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { +/// Extract the request's method. +/// +/// # Examples +/// ``` +/// use actix_web::{http::Method, web, App, Responder}; +/// +/// async fn handler(method: Method) -> impl Responder { +/// format!("Request method: {}", method) +/// } +/// +/// let app = App::new().default_service(web::to(handler)); +/// ``` +impl FromRequest for Method { + type Error = Infallible; + type Future = Ready>; - // This module is a trick to get around the inability of - // `macro_rules!` macros to make new idents. We want to make - // a new `FutWrapper` struct for each distinct invocation of - // this macro. Ideally, we would name it something like - // `FutWrapper_$fut_type`, but this can't be done in a macro_rules - // macro. - // - // Instead, we put everything in a module named `$fut_type`, thus allowing - // us to use the name `FutWrapper` without worrying about conflicts. - // This macro only exists to generate trait impls for tuples - these - // are inherently global, so users don't have to care about this - // weird trick. - #[allow(non_snake_case)] - mod $fut_type { + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(req.method().clone()) + } +} - // Bring everything into scope, so we don't need - // redundant imports - use super::*; +#[doc(hidden)] +impl FromRequest for () { + type Error = Infallible; + type Future = Ready>; - /// A helper struct to allow us to pin-project through - /// to individual fields - #[pin_project::pin_project] - struct FutWrapper<$($T: FromRequest),+>($(#[pin] $T::Future),+); + fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(()) + } +} - /// FromRequest implementation for tuple - #[doc(hidden)] - #[allow(unused_parens)] - impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) - { - type Error = Error; - type Future = $fut_type<$($T),+>; - type Config = ($($T::Config),+); +#[doc(hidden)] +#[allow(non_snake_case)] +mod tuple_from_req { + use super::*; - fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - $fut_type { - items: <($(Option<$T>,)+)>::default(), - futs: FutWrapper($($T::from_request(req, payload),)+), + macro_rules! tuple_from_req { + ($fut: ident; $($T: ident),*) => { + /// FromRequest implementation for tuple + #[allow(unused_parens)] + impl<$($T: FromRequest + 'static),+> FromRequest for ($($T,)+) + { + type Error = Error; + type Future = $fut<$($T),+>; + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + $fut { + $( + $T: ExtractFuture::Future { + fut: $T::from_request(req, payload) + }, + )+ + } } } - } - #[doc(hidden)] - #[pin_project::pin_project] - pub struct $fut_type<$($T: FromRequest),+> { - items: ($(Option<$T>,)+), - #[pin] - futs: FutWrapper<$($T,)+>, - } + pin_project! { + pub struct $fut<$($T: FromRequest),+> { + $( + #[pin] + $T: ExtractFuture<$T::Future, $T>, + )+ + } + } - impl<$($T: FromRequest),+> Future for $fut_type<$($T),+> - { - type Output = Result<($($T,)+), Error>; + impl<$($T: FromRequest),+> Future for $fut<$($T),+> + { + type Output = Result<($($T,)+), Error>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); - let mut ready = true; - $( - if this.items.$n.is_none() { - match this.futs.as_mut().project().$n.poll(cx) { - Poll::Ready(Ok(item)) => { - this.items.$n = Some(item); - } - Poll::Pending => ready = false, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + let mut ready = true; + $( + match this.$T.as_mut().project() { + ExtractProj::Future { fut } => match fut.poll(cx) { + Poll::Ready(Ok(output)) => { + let _ = this.$T.as_mut().project_replace(ExtractFuture::Done { output }); + }, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), + Poll::Pending => ready = false, + }, + ExtractProj::Done { .. } => {}, + ExtractProj::Empty => unreachable!("FromRequest polled after finished"), } - } - )+ + )+ if ready { Poll::Ready(Ok( - ($(this.items.$n.take().unwrap(),)+) + ($( + match this.$T.project_replace(ExtractFuture::Empty) { + ExtractReplaceProj::Done { output } => output, + _ => unreachable!("FromRequest polled after finished"), + }, + )+) )) } else { Poll::Pending } + } } + }; + } + + pin_project! { + #[project = ExtractProj] + #[project_replace = ExtractReplaceProj] + enum ExtractFuture { + Future { + #[pin] + fut: Fut + }, + Done { + output: Res, + }, + Empty } } -}); -#[rustfmt::skip] -mod m { - use super::*; - -tuple_from_req!(TupleFromRequest1, (0, A)); -tuple_from_req!(TupleFromRequest2, (0, A), (1, B)); -tuple_from_req!(TupleFromRequest3, (0, A), (1, B), (2, C)); -tuple_from_req!(TupleFromRequest4, (0, A), (1, B), (2, C), (3, D)); -tuple_from_req!(TupleFromRequest5, (0, A), (1, B), (2, C), (3, D), (4, E)); -tuple_from_req!(TupleFromRequest6, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); -tuple_from_req!(TupleFromRequest7, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); -tuple_from_req!(TupleFromRequest8, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); -tuple_from_req!(TupleFromRequest9, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); -tuple_from_req!(TupleFromRequest10, (0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); + tuple_from_req! { TupleFromRequest1; A } + tuple_from_req! { TupleFromRequest2; A, B } + tuple_from_req! { TupleFromRequest3; A, B, C } + tuple_from_req! { TupleFromRequest4; A, B, C, D } + tuple_from_req! { TupleFromRequest5; A, B, C, D, E } + tuple_from_req! { TupleFromRequest6; A, B, C, D, E, F } + tuple_from_req! { TupleFromRequest7; A, B, C, D, E, F, G } + tuple_from_req! { TupleFromRequest8; A, B, C, D, E, F, G, H } + tuple_from_req! { TupleFromRequest9; A, B, C, D, E, F, G, H, I } + tuple_from_req! { TupleFromRequest10; A, B, C, D, E, F, G, H, I, J } } #[cfg(test)] mod tests { - use actix_http::http::header; + use actix_http::header; use bytes::Bytes; - use serde_derive::Deserialize; + use serde::Deserialize; use super::*; use crate::test::TestRequest; @@ -415,4 +485,40 @@ mod tests { .unwrap(); assert!(r.is_err()); } + + #[actix_rt::test] + async fn test_uri() { + let req = TestRequest::default().uri("/foo/bar").to_http_request(); + let uri = Uri::extract(&req).await.unwrap(); + assert_eq!(uri.path(), "/foo/bar"); + } + + #[actix_rt::test] + async fn test_method() { + let req = TestRequest::default().method(Method::GET).to_http_request(); + let method = Method::extract(&req).await.unwrap(); + assert_eq!(method, Method::GET); + } + + #[actix_rt::test] + async fn test_concurrent() { + let (req, mut pl) = TestRequest::default() + .uri("/foo/bar") + .method(Method::GET) + .insert_header((header::CONTENT_TYPE, "application/x-www-form-urlencoded")) + .insert_header((header::CONTENT_LENGTH, "11")) + .set_payload(Bytes::from_static(b"hello=world")) + .to_http_parts(); + let (method, uri, form) = <(Method, Uri, Form)>::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(method, Method::GET); + assert_eq!(uri.path(), "/foo/bar"); + assert_eq!( + form, + Form(Info { + hello: "world".into() + }) + ); + } } diff --git a/src/guard.rs b/src/guard.rs index 5d0de58c2..f4200a382 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,160 +1,252 @@ -//! Route match guards. +//! Route guards. //! -//! Guards are one of the ways how actix-web router chooses a -//! handler service. In essence it is just a function that accepts a -//! reference to a `RequestHead` instance and returns a boolean. -//! It is possible to add guards to *scopes*, *resources* -//! and *routes*. Actix provide several guards by default, like various -//! http methods, header, etc. To become a guard, type must implement `Guard` -//! trait. Simple functions could be guards as well. +//! Guards are used during routing to help select a matching service or handler using some aspect of +//! the request; though guards should not be used for path matching since it is a built-in function +//! of the Actix Web router. //! -//! Guards can not modify the request object. But it is possible -//! to store extra attributes on a request by using the `Extensions` container. -//! Extensions containers are available via the `RequestHead::extensions()` method. +//! Guards can be used on [`Scope`]s, [`Resource`]s, [`Route`]s, and other custom services. //! -//! ```rust -//! use actix_web::{web, http, dev, guard, App, HttpResponse}; +//! Fundamentally, a guard is a predicate function that receives a reference to a request context +//! object and returns a boolean; true if the request _should_ be handled by the guarded service +//! or handler. This interface is defined by the [`Guard`] trait. //! -//! fn main() { -//! App::new().service(web::resource("/index.html").route( -//! web::route() -//! .guard(guard::Post()) -//! .guard(guard::fn_guard(|head| head.method == http::Method::GET)) -//! .to(|| HttpResponse::MethodNotAllowed())) -//! ); -//! } +//! Commonly-used guards are provided in this module as well as a way of creating a guard from a +//! closure ([`fn_guard`]). The [`Not`], [`Any`], and [`All`] guards are noteworthy, as they can be +//! used to compose other guards in a more flexible and semantic way than calling `.guard(...)` on +//! services multiple times (which might have different combining behavior than you want). +//! +//! There are shortcuts for routes with method guards in the [`web`](crate::web) module: +//! [`web::get()`](crate::web::get), [`web::post()`](crate::web::post), etc. The routes created by +//! the following calls are equivalent: +//! - `web::get()` (recommended form) +//! - `web::route().guard(guard::Get())` +//! +//! Guards can not modify anything about the request. However, it is possible to store extra +//! attributes in the request-local data container obtained with [`GuardContext::req_data_mut`]. +//! +//! Guards can prevent resource definitions from overlapping which, when only considering paths, +//! would result in inaccessible routes. See the [`Host`] guard for an example of virtual hosting. +//! +//! # Examples +//! In the following code, the `/guarded` resource has one defined route whose handler will only be +//! called if the request method is `POST` and there is a request header with name and value equal +//! to `x-guarded` and `secret`, respectively. //! ``` -#![allow(non_snake_case)] -use std::convert::TryFrom; +//! use actix_web::{web, http::Method, guard, HttpResponse}; +//! +//! web::resource("/guarded").route( +//! web::route() +//! .guard(guard::Any(guard::Get()).or(guard::Post())) +//! .guard(guard::Header("x-guarded", "secret")) +//! .to(|| HttpResponse::Ok()) +//! ); +//! ``` +//! +//! [`Scope`]: crate::Scope::guard() +//! [`Resource`]: crate::Resource::guard() +//! [`Route`]: crate::Route::guard() -use actix_http::http::{self, header, uri::Uri}; -use actix_http::RequestHead; +use std::{ + cell::{Ref, RefMut}, + convert::TryFrom, + rc::Rc, +}; -/// Trait defines resource guards. Guards are used for route selection. -/// -/// Guards can not modify the request object. But it is possible -/// to store extra attributes on a request by using the `Extensions` container. -/// Extensions containers are available via the `RequestHead::extensions()` method. -pub trait Guard { - /// Check if request matches predicate - fn check(&self, request: &RequestHead) -> bool; +use actix_http::{header, uri::Uri, Extensions, Method as HttpMethod, RequestHead}; + +use crate::{http::header::Header, service::ServiceRequest}; + +/// Provides access to request parts that are useful during routing. +#[derive(Debug)] +pub struct GuardContext<'a> { + pub(crate) req: &'a ServiceRequest, } -/// Create guard object for supplied function. +impl<'a> GuardContext<'a> { + /// Returns reference to the request head. + #[inline] + pub fn head(&self) -> &RequestHead { + self.req.head() + } + + /// Returns reference to the request-local data container. + #[inline] + pub fn req_data(&self) -> Ref<'a, Extensions> { + self.req.req_data() + } + + /// Returns mutable reference to the request-local data container. + #[inline] + pub fn req_data_mut(&self) -> RefMut<'a, Extensions> { + self.req.req_data_mut() + } + + /// Extracts a typed header from the request. + /// + /// Returns `None` if parsing `H` fails. + /// + /// # Examples + /// ``` + /// use actix_web::{guard::fn_guard, http::header}; + /// + /// let image_accept_guard = fn_guard(|ctx| { + /// match ctx.header::() { + /// Some(hdr) => hdr.preference() == "image/*", + /// None => false, + /// } + /// }); + /// ``` + #[inline] + pub fn header(&self) -> Option { + H::parse(self.req).ok() + } +} + +/// Interface for routing guards. /// -/// ```rust -/// use actix_web::{guard, web, App, HttpResponse}; +/// See [module level documentation](self) for more. +pub trait Guard { + /// Returns true if predicate condition is met for a given request. + fn check(&self, ctx: &GuardContext<'_>) -> bool; +} + +impl Guard for Rc { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (**self).check(ctx) + } +} + +/// Creates a guard using the given function. /// -/// fn main() { -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard( -/// guard::fn_guard( -/// |req| req.headers() -/// .contains_key("content-type"))) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # Examples +/// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard(guard::fn_guard(|ctx| { +/// ctx.head().headers().contains_key("content-type") +/// })) +/// .to(|| HttpResponse::Ok()); /// ``` pub fn fn_guard(f: F) -> impl Guard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { FnGuard(f) } -struct FnGuard bool>(F); +struct FnGuard) -> bool>(F); impl Guard for FnGuard where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self.0)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self.0)(ctx) } } impl Guard for F where - F: Fn(&RequestHead) -> bool, + F: Fn(&GuardContext<'_>) -> bool, { - fn check(&self, head: &RequestHead) -> bool { - (self)(head) + fn check(&self, ctx: &GuardContext<'_>) -> bool { + (self)(ctx) } } -/// Return guard that matches if any of supplied guards. +/// Creates a guard that matches if any added guards match. /// -/// ```rust -/// use actix_web::{web, guard, App, HttpResponse}; -/// -/// fn main() { -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard(guard::Any(guard::Get()).or(guard::Post())) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # Examples +/// The handler below will be called for either request method `GET` or `POST`. /// ``` +/// use actix_web::{web, guard, HttpResponse}; +/// +/// web::route() +/// .guard( +/// guard::Any(guard::Get()) +/// .or(guard::Post())) +/// .to(|| HttpResponse::Ok()); +/// ``` +#[allow(non_snake_case)] pub fn Any(guard: F) -> AnyGuard { - AnyGuard(vec![Box::new(guard)]) + AnyGuard { + guards: vec![Box::new(guard)], + } } -/// Matches any of supplied guards. -pub struct AnyGuard(Vec>); +/// A collection of guards that match if the disjunction of their `check` outcomes is true. +/// +/// That is, only one contained guard needs to match in order for the aggregate guard to match. +/// +/// Construct an `AnyGuard` using [`Any`]. +pub struct AnyGuard { + guards: Vec>, +} impl AnyGuard { - /// Add guard to a list of guards to check + /// Adds new guard to the collection of guards to check. pub fn or(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AnyGuard { - fn check(&self, req: &RequestHead) -> bool { - for p in &self.0 { - if p.check(req) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if guard.check(ctx) { return true; } } + false } } -/// Return guard that matches if all of the supplied guards. +/// Creates a guard that matches if all added guards match. /// -/// ```rust -/// use actix_web::{guard, web, App, HttpResponse}; -/// -/// fn main() { -/// App::new().service(web::resource("/index.html").route( -/// web::route() -/// .guard( -/// guard::All(guard::Get()).and(guard::Header("content-type", "text/plain"))) -/// .to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// # Examples +/// The handler below will only be called if the request method is `GET` **and** the specified +/// header name and value match exactly. /// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard( +/// guard::All(guard::Get()) +/// .and(guard::Header("accept", "text/plain")) +/// ) +/// .to(|| HttpResponse::Ok()); +/// ``` +#[allow(non_snake_case)] pub fn All(guard: F) -> AllGuard { - AllGuard(vec![Box::new(guard)]) + AllGuard { + guards: vec![Box::new(guard)], + } } -/// Matches if all of supplied guards. -pub struct AllGuard(Vec>); +/// A collection of guards that match if the conjunction of their `check` outcomes is true. +/// +/// That is, **all** contained guard needs to match in order for the aggregate guard to match. +/// +/// Construct an `AllGuard` using [`All`]. +pub struct AllGuard { + guards: Vec>, +} impl AllGuard { - /// Add new guard to the list of guards to check + /// Adds new guard to the collection of guards to check. pub fn and(mut self, guard: F) -> Self { - self.0.push(Box::new(guard)); + self.guards.push(Box::new(guard)); self } } impl Guard for AllGuard { - fn check(&self, request: &RequestHead) -> bool { - for p in &self.0 { - if !p.check(request) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + for guard in &self.guards { + if !guard.check(ctx) { return false; } } @@ -162,338 +254,430 @@ impl Guard for AllGuard { } } -/// Return guard that matches if supplied guard does not match. -pub fn Not(guard: F) -> NotGuard { - NotGuard(Box::new(guard)) -} +/// Wraps a guard and inverts the outcome of it's `Guard` implementation. +/// +/// # Examples +/// The handler below will be called for any request method apart from `GET`. +/// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard(guard::Not(guard::Get())) +/// .to(|| HttpResponse::Ok()); +/// ``` +pub struct Not(pub G); -#[doc(hidden)] -pub struct NotGuard(Box); - -impl Guard for NotGuard { - fn check(&self, request: &RequestHead) -> bool { - !self.0.check(request) +impl Guard for Not { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + !self.0.check(ctx) } } -/// HTTP method guard. -#[doc(hidden)] -pub struct MethodGuard(http::Method); - -impl Guard for MethodGuard { - fn check(&self, request: &RequestHead) -> bool { - request.method == self.0 - } -} - -/// Guard to match *GET* HTTP method. -pub fn Get() -> MethodGuard { - MethodGuard(http::Method::GET) -} - -/// Predicate to match *POST* HTTP method. -pub fn Post() -> MethodGuard { - MethodGuard(http::Method::POST) -} - -/// Predicate to match *PUT* HTTP method. -pub fn Put() -> MethodGuard { - MethodGuard(http::Method::PUT) -} - -/// Predicate to match *DELETE* HTTP method. -pub fn Delete() -> MethodGuard { - MethodGuard(http::Method::DELETE) -} - -/// Predicate to match *HEAD* HTTP method. -pub fn Head() -> MethodGuard { - MethodGuard(http::Method::HEAD) -} - -/// Predicate to match *OPTIONS* HTTP method. -pub fn Options() -> MethodGuard { - MethodGuard(http::Method::OPTIONS) -} - -/// Predicate to match *CONNECT* HTTP method. -pub fn Connect() -> MethodGuard { - MethodGuard(http::Method::CONNECT) -} - -/// Predicate to match *PATCH* HTTP method. -pub fn Patch() -> MethodGuard { - MethodGuard(http::Method::PATCH) -} - -/// Predicate to match *TRACE* HTTP method. -pub fn Trace() -> MethodGuard { - MethodGuard(http::Method::TRACE) -} - -/// Predicate to match specified HTTP method. -pub fn Method(method: http::Method) -> MethodGuard { +/// Creates a guard that matches a specified HTTP method. +#[allow(non_snake_case)] +pub fn Method(method: HttpMethod) -> impl Guard { MethodGuard(method) } -/// Return predicate that matches if request contains specified header and -/// value. -pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard { +/// HTTP method guard. +struct MethodGuard(HttpMethod); + +impl Guard for MethodGuard { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + ctx.head().method == self.0 + } +} + +macro_rules! method_guard { + ($method_fn:ident, $method_const:ident) => { + #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")] + /// + /// # Examples + #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")] + /// ``` + /// use actix_web::{guard, web, HttpResponse}; + /// + /// web::route() + #[doc = concat!(" .guard(guard::", stringify!($method_fn), "())")] + /// .to(|| HttpResponse::Ok()); + /// ``` + #[allow(non_snake_case)] + pub fn $method_fn() -> impl Guard { + MethodGuard(HttpMethod::$method_const) + } + }; +} + +method_guard!(Get, GET); +method_guard!(Post, POST); +method_guard!(Put, PUT); +method_guard!(Delete, DELETE); +method_guard!(Head, HEAD); +method_guard!(Options, OPTIONS); +method_guard!(Connect, CONNECT); +method_guard!(Patch, PATCH); +method_guard!(Trace, TRACE); + +/// Creates a guard that matches if request contains given header name and value. +/// +/// # Examples +/// The handler below will be called when the request contains an `x-guarded` header with value +/// equal to `secret`. +/// ``` +/// use actix_web::{guard, web, HttpResponse}; +/// +/// web::route() +/// .guard(guard::Header("x-guarded", "secret")) +/// .to(|| HttpResponse::Ok()); +/// ``` +#[allow(non_snake_case)] +pub fn Header(name: &'static str, value: &'static str) -> impl Guard { HeaderGuard( header::HeaderName::try_from(name).unwrap(), header::HeaderValue::from_static(value), ) } -#[doc(hidden)] -pub struct HeaderGuard(header::HeaderName, header::HeaderValue); +struct HeaderGuard(header::HeaderName, header::HeaderValue); impl Guard for HeaderGuard { - fn check(&self, req: &RequestHead) -> bool { - if let Some(val) = req.headers.get(&self.0) { + fn check(&self, ctx: &GuardContext<'_>) -> bool { + if let Some(val) = ctx.head().headers.get(&self.0) { return val == self.1; } + false } } -/// Return predicate that matches if request contains specified Host name. +/// Creates a guard that matches requests targetting a specific host. /// -/// ```rust -/// use actix_web::{web, guard::Host, App, HttpResponse}; +/// # Matching Host +/// This guard will: +/// - match against the `Host` header, if present; +/// - fall-back to matching against the request target's host, if present; +/// - return false if host cannot be determined; /// -/// fn main() { -/// App::new().service( -/// web::resource("/index.html") -/// .guard(Host("www.rust-lang.org")) -/// .to(|| HttpResponse::MethodNotAllowed()) -/// ); -/// } +/// # Matching Scheme +/// Optionally, this guard can match against the host's scheme. Set the scheme for matching using +/// `Host(host).scheme(protocol)`. If the request's scheme cannot be determined, it will not prevent +/// the guard from matching successfully. +/// +/// # Examples +/// The [module-level documentation](self) has an example of virtual hosting using `Host` guards. +/// +/// The example below additionally guards on the host URI's scheme. This could allow routing to +/// different handlers for `http:` vs `https:` visitors; to redirect, for example. /// ``` -pub fn Host>(host: H) -> HostGuard { - HostGuard(host.as_ref().to_string(), None) +/// use actix_web::{web, guard::Host, HttpResponse}; +/// +/// web::scope("/admin") +/// .guard(Host("admin.rust-lang.org").scheme("https")) +/// .default_service(web::to(|| HttpResponse::Ok().body("admin connection is secure"))); +/// ``` +/// +/// The `Host` guard can be used to set up some form of [virtual hosting] within a single app. +/// Overlapping scope prefixes are usually discouraged, but when combined with non-overlapping guard +/// definitions they become safe to use in this way. Without these host guards, only routes under +/// the first-to-be-defined scope would be accessible. You can test this locally using `127.0.0.1` +/// and `localhost` as the `Host` guards. +/// ``` +/// use actix_web::{web, http::Method, guard, App, HttpResponse}; +/// +/// App::new() +/// .service( +/// web::scope("") +/// .guard(guard::Host("www.rust-lang.org")) +/// .default_service(web::to(|| HttpResponse::Ok().body("marketing site"))), +/// ) +/// .service( +/// web::scope("") +/// .guard(guard::Host("play.rust-lang.org")) +/// .default_service(web::to(|| HttpResponse::Ok().body("playground frontend"))), +/// ); +/// ``` +/// +/// [virtual hosting]: https://en.wikipedia.org/wiki/Virtual_hosting +#[allow(non_snake_case)] +pub fn Host(host: impl AsRef) -> HostGuard { + HostGuard { + host: host.as_ref().to_string(), + scheme: None, + } } fn get_host_uri(req: &RequestHead) -> Option { - use core::str::FromStr; req.headers .get(header::HOST) .and_then(|host_value| host_value.to_str().ok()) .or_else(|| req.uri.host()) - .map(|host: &str| Uri::from_str(host).ok()) - .and_then(|host_success| host_success) + .and_then(|host| host.parse().ok()) } #[doc(hidden)] -pub struct HostGuard(String, Option); +pub struct HostGuard { + host: String, + scheme: Option, +} impl HostGuard { /// Set request scheme to match pub fn scheme>(mut self, scheme: H) -> HostGuard { - self.1 = Some(scheme.as_ref().to_string()); + self.scheme = Some(scheme.as_ref().to_string()); self } } impl Guard for HostGuard { - fn check(&self, req: &RequestHead) -> bool { - let req_host_uri = if let Some(uri) = get_host_uri(req) { - uri - } else { - return false; + fn check(&self, ctx: &GuardContext<'_>) -> bool { + // parse host URI from header or request target + let req_host_uri = match get_host_uri(ctx.head()) { + Some(uri) => uri, + + // no match if host cannot be determined + None => return false, }; - if let Some(uri_host) = req_host_uri.host() { - if self.0 != uri_host { - return false; - } - } else { - return false; + match req_host_uri.host() { + // fall through to scheme checks + Some(uri_host) if self.host == uri_host => {} + + // Either: + // - request's host does not match guard's host; + // - It was possible that the parsed URI from request target did not contain a host. + _ => return false, } - if let Some(ref scheme) = self.1 { + if let Some(ref scheme) = self.scheme { if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() { return scheme == req_host_uri_scheme; } + + // TODO: is the the correct behavior? + // falls through if scheme cannot be determined } + // all conditions passed true } } #[cfg(test)] mod tests { - use actix_http::http::{header, Method}; + use actix_http::{header, Method}; use super::*; use crate::test::TestRequest; #[test] - fn test_header() { + fn header_match() { let req = TestRequest::default() .insert_header((header::TRANSFER_ENCODING, "chunked")) - .to_http_request(); + .to_srv_request(); - let pred = Header("transfer-encoding", "chunked"); - assert!(pred.check(req.head())); + let hdr = Header("transfer-encoding", "chunked"); + assert!(hdr.check(&req.guard_ctx())); - let pred = Header("transfer-encoding", "other"); - assert!(!pred.check(req.head())); + let hdr = Header("transfer-encoding", "other"); + assert!(!hdr.check(&req.guard_ctx())); - let pred = Header("content-type", "other"); - assert!(!pred.check(req.head())); + let hdr = Header("content-type", "chunked"); + assert!(!hdr.check(&req.guard_ctx())); + + let hdr = Header("content-type", "other"); + assert!(!hdr.check(&req.guard_ctx())); } #[test] - fn test_host() { + fn host_from_header() { let req = TestRequest::default() .insert_header(( header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("crates.io"); - assert!(!pred.check(req.head())); + let host = Host("crates.io"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("localhost"); - assert!(!pred.check(req.head())); + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); } #[test] - fn test_host_scheme() { + fn host_without_header() { + let req = TestRequest::default() + .uri("www.rust-lang.org") + .to_srv_request(); + + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); + + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); + + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("crates.io"); + assert!(!host.check(&req.guard_ctx())); + + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); + } + + #[test] + fn host_scheme() { let req = TestRequest::default() .insert_header(( header::HOST, header::HeaderValue::from_static("https://www.rust-lang.org"), )) - .to_http_request(); + .to_srv_request(); - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("https"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); + let host = Host("www.rust-lang.org"); + assert!(host.check(&req.guard_ctx())); - let pred = Host("www.rust-lang.org").scheme("http"); - assert!(!pred.check(req.head())); + let host = Host("www.rust-lang.org").scheme("http"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("blog.rust-lang.org").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("crates.io").scheme("https"); - assert!(!pred.check(req.head())); + let host = Host("crates.io").scheme("https"); + assert!(!host.check(&req.guard_ctx())); - let pred = Host("localhost"); - assert!(!pred.check(req.head())); + let host = Host("localhost"); + assert!(!host.check(&req.guard_ctx())); } #[test] - fn test_host_without_header() { + fn method_guards() { + let get_req = TestRequest::get().to_srv_request(); + let post_req = TestRequest::post().to_srv_request(); + + assert!(Get().check(&get_req.guard_ctx())); + assert!(!Get().check(&post_req.guard_ctx())); + + assert!(Post().check(&post_req.guard_ctx())); + assert!(!Post().check(&get_req.guard_ctx())); + + let req = TestRequest::put().to_srv_request(); + assert!(Put().check(&req.guard_ctx())); + assert!(!Put().check(&get_req.guard_ctx())); + + let req = TestRequest::patch().to_srv_request(); + assert!(Patch().check(&req.guard_ctx())); + assert!(!Patch().check(&get_req.guard_ctx())); + + let r = TestRequest::delete().to_srv_request(); + assert!(Delete().check(&r.guard_ctx())); + assert!(!Delete().check(&get_req.guard_ctx())); + + let req = TestRequest::default().method(Method::HEAD).to_srv_request(); + assert!(Head().check(&req.guard_ctx())); + assert!(!Head().check(&get_req.guard_ctx())); + let req = TestRequest::default() - .uri("www.rust-lang.org") - .to_http_request(); - - let pred = Host("www.rust-lang.org"); - assert!(pred.check(req.head())); - - let pred = Host("www.rust-lang.org").scheme("https"); - assert!(pred.check(req.head())); - - let pred = Host("blog.rust-lang.org"); - assert!(!pred.check(req.head())); - - let pred = Host("blog.rust-lang.org").scheme("https"); - assert!(!pred.check(req.head())); - - let pred = Host("crates.io"); - assert!(!pred.check(req.head())); - - let pred = Host("localhost"); - assert!(!pred.check(req.head())); - } - - #[test] - fn test_methods() { - let req = TestRequest::default().to_http_request(); - let req2 = TestRequest::default() - .method(Method::POST) - .to_http_request(); - - assert!(Get().check(req.head())); - assert!(!Get().check(req2.head())); - assert!(Post().check(req2.head())); - assert!(!Post().check(req.head())); - - let r = TestRequest::default().method(Method::PUT).to_http_request(); - assert!(Put().check(r.head())); - assert!(!Put().check(req.head())); - - let r = TestRequest::default() - .method(Method::DELETE) - .to_http_request(); - assert!(Delete().check(r.head())); - assert!(!Delete().check(req.head())); - - let r = TestRequest::default() - .method(Method::HEAD) - .to_http_request(); - assert!(Head().check(r.head())); - assert!(!Head().check(req.head())); - - let r = TestRequest::default() .method(Method::OPTIONS) - .to_http_request(); - assert!(Options().check(r.head())); - assert!(!Options().check(req.head())); + .to_srv_request(); + assert!(Options().check(&req.guard_ctx())); + assert!(!Options().check(&get_req.guard_ctx())); - let r = TestRequest::default() + let req = TestRequest::default() .method(Method::CONNECT) - .to_http_request(); - assert!(Connect().check(r.head())); - assert!(!Connect().check(req.head())); + .to_srv_request(); + assert!(Connect().check(&req.guard_ctx())); + assert!(!Connect().check(&get_req.guard_ctx())); - let r = TestRequest::default() - .method(Method::PATCH) - .to_http_request(); - assert!(Patch().check(r.head())); - assert!(!Patch().check(req.head())); - - let r = TestRequest::default() + let req = TestRequest::default() .method(Method::TRACE) - .to_http_request(); - assert!(Trace().check(r.head())); - assert!(!Trace().check(req.head())); + .to_srv_request(); + assert!(Trace().check(&req.guard_ctx())); + assert!(!Trace().check(&get_req.guard_ctx())); } #[test] - fn test_preds() { - let r = TestRequest::default() + fn aggregate_any() { + let req = TestRequest::default() .method(Method::TRACE) - .to_http_request(); + .to_srv_request(); - assert!(Not(Get()).check(r.head())); - assert!(!Not(Trace()).check(r.head())); + assert!(Any(Trace()).check(&req.guard_ctx())); + assert!(Any(Trace()).or(Get()).check(&req.guard_ctx())); + assert!(!Any(Get()).or(Get()).check(&req.guard_ctx())); + } - assert!(All(Trace()).and(Trace()).check(r.head())); - assert!(!All(Get()).and(Trace()).check(r.head())); + #[test] + fn aggregate_all() { + let req = TestRequest::default() + .method(Method::TRACE) + .to_srv_request(); - assert!(Any(Get()).or(Trace()).check(r.head())); - assert!(!Any(Get()).or(Get()).check(r.head())); + assert!(All(Trace()).check(&req.guard_ctx())); + assert!(All(Trace()).and(Trace()).check(&req.guard_ctx())); + assert!(!All(Trace()).and(Get()).check(&req.guard_ctx())); + } + + #[test] + fn nested_not() { + let req = TestRequest::default().to_srv_request(); + + let get = Get(); + assert!(get.check(&req.guard_ctx())); + + let not_get = Not(get); + assert!(!not_get.check(&req.guard_ctx())); + + let not_not_get = Not(not_get); + assert!(not_not_get.check(&req.guard_ctx())); + } + + #[test] + fn function_guard() { + let domain = "rust-lang.org".to_owned(); + let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain)); + + let req = TestRequest::default() + .uri("blog.rust-lang.org") + .to_srv_request(); + assert!(guard.check(&req.guard_ctx())); + + let req = TestRequest::default().uri("crates.io").to_srv_request(); + assert!(!guard.check(&req.guard_ctx())); + } + + #[test] + fn mega_nesting() { + let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx)); + + let req = TestRequest::default().to_srv_request(); + assert!(!guard.check(&req.guard_ctx())); + + let req = TestRequest::default() + .method(Method::TRACE) + .to_srv_request(); + assert!(guard.check(&req.guard_ctx())); } } diff --git a/src/handler.rs b/src/handler.rs index 0016b741e..d458e22e1 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,206 +1,150 @@ use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{Context, Poll}; -use actix_http::{Error, Response}; -use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, Ready}; -use futures_util::ready; -use pin_project::pin_project; +use actix_service::{boxed, fn_service}; -use crate::extract::FromRequest; -use crate::request::HttpRequest; -use crate::responder::Responder; -use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{ + service::{BoxedHttpServiceFactory, ServiceRequest, ServiceResponse}, + FromRequest, HttpResponse, Responder, +}; -/// A request handler is an async function that accepts zero or more parameters that can be -/// extracted from a request (ie, [`impl FromRequest`](crate::FromRequest)) and returns a type that can be converted into -/// an [`HttpResponse`](crate::HttpResponse) (ie, [`impl Responder`](crate::Responder)). +/// The interface for request handlers. /// -/// If you got the error `the trait Handler<_, _, _> is not implemented`, then your function is not -/// a valid handler. See [Request Handlers](https://actix.rs/docs/handlers/) for more information. -pub trait Handler: Clone + 'static -where - R: Future, - R::Output: Responder, -{ - fn call(&self, param: T) -> R; +/// # What Is A Request Handler +/// A request handler has three requirements: +/// 1. It is an async function (or a function/closure that returns an appropriate future); +/// 1. The function accepts zero or more parameters that implement [`FromRequest`]; +/// 1. The async function (or future) resolves to a type that can be converted into an +/// [`HttpResponse`] (i.e., it implements the [`Responder`] trait). +/// +/// # Compiler Errors +/// If you get the error `the trait Handler<_> is not implemented`, then your handler does not +/// fulfill one or more of the above requirements. +/// +/// Unfortunately we cannot provide a better compile error message (while keeping the trait's +/// flexibility) unless a stable alternative to [`#[rustc_on_unimplemented]`][on_unimpl] is added +/// to Rust. +/// +/// # How Do Handlers Receive Variable Numbers Of Arguments +/// Rest assured there is no macro magic here; it's just traits. +/// +/// The first thing to note is that [`FromRequest`] is implemented for tuples (up to 12 in length). +/// +/// Secondly, the `Handler` trait is implemented for functions (up to an [arity] of 12) in a way +/// that aligns their parameter positions with a corresponding tuple of types (becoming the `Args` +/// type parameter for this trait). +/// +/// Thanks to Rust's type system, Actix Web can infer the function parameter types. During the +/// extraction step, the parameter types are described as a tuple type, [`from_request`] is run on +/// that tuple, and the `Handler::call` implementation for that particular function arity +/// destructures the tuple into it's component types and calls your handler function with them. +/// +/// In pseudo-code the process looks something like this: +/// ```ignore +/// async fn my_handler(body: String, state: web::Data) -> impl Responder { +/// ... +/// } +/// +/// // the function params above described as a tuple, names do not matter, only position +/// type InferredMyHandlerArgs = (String, web::Data); +/// +/// // create tuple of arguments to be passed to handler +/// let args = InferredMyHandlerArgs::from_request(&request, &payload).await; +/// +/// // call handler with argument tuple +/// let response = Handler::call(&my_handler, args).await; +/// +/// // which is effectively... +/// +/// let (body, state) = args; +/// let response = my_handler(body, state).await; +/// ``` +/// +/// This is the source code for the 2-parameter implementation of `Handler` to help illustrate the +/// bounds of the handler call after argument extraction: +/// ```ignore +/// impl Handler<(Arg1, Arg2), R> for Func +/// where +/// Func: Fn(Arg1, Arg2) -> R + Clone + 'static, +/// R: Future, +/// R::Output: Responder, +/// { +/// fn call(&self, (arg1, arg2): (Arg1, Arg2)) -> R { +/// (self)(arg1, arg2) +/// } +/// } +/// ``` +/// +/// [arity]: https://en.wikipedia.org/wiki/Arity +/// [`from_request`]: FromRequest::from_request +/// [on_unimpl]: https://github.com/rust-lang/rust/issues/29628 +pub trait Handler: Clone + 'static { + type Output; + type Future: Future; + + fn call(&self, args: Args) -> Self::Future; } -impl Handler<(), R> for F +pub(crate) fn handler_service(handler: F) -> BoxedHttpServiceFactory where - F: Fn() -> R + Clone + 'static, - R: Future, - R::Output: Responder, + F: Handler, + Args: FromRequest, + F::Output: Responder, { - fn call(&self, _: ()) -> R { - (self)() - } -} + boxed::factory(fn_service(move |req: ServiceRequest| { + let handler = handler.clone(); -#[doc(hidden)] -/// Extract arguments from request, run factory function and make response. -pub struct HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - hnd: F, - _phantom: PhantomData<(T, R)>, -} + async move { + let (req, mut payload) = req.into_parts(); -impl HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - pub fn new(hnd: F) -> Self { - Self { - hnd, - _phantom: PhantomData, + let res = match Args::from_request(&req, &mut payload).await { + Err(err) => HttpResponse::from_error(err), + + Ok(data) => handler + .call(data) + .await + .respond_to(&req) + .map_into_boxed_body(), + }; + + Ok(ServiceResponse::new(req, res)) } - } + })) } -impl Clone for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - fn clone(&self) -> Self { - Self { - hnd: self.hnd.clone(), - _phantom: PhantomData, - } - } -} - -impl ServiceFactory for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - type Response = ServiceResponse; - type Error = Error; - type Config = (); - type Service = Self; - type InitError = (); - type Future = Ready>; - - fn new_service(&self, _: ()) -> Self::Future { - ready(Ok(self.clone())) - } -} - -/// HandlerService is both it's ServiceFactory and Service Type. -impl Service for HandlerService -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - type Response = ServiceResponse; - type Error = Error; - type Future = HandlerServiceFuture; - - fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&self, req: ServiceRequest) -> Self::Future { - let (req, mut payload) = req.into_parts(); - let fut = T::from_request(&req, &mut payload); - HandlerServiceFuture::Extract(fut, Some(req), self.hnd.clone()) - } -} - -#[doc(hidden)] -#[pin_project(project = HandlerProj)] -pub enum HandlerServiceFuture -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - Extract(#[pin] T::Future, Option, F), - Handle(#[pin] R, Option), -} - -impl Future for HandlerServiceFuture -where - F: Handler, - T: FromRequest, - R: Future, - R::Output: Responder, -{ - // Error type in this future is a placeholder type. - // all instances of error must be converted to ServiceResponse and return in Ok. - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match self.as_mut().project() { - HandlerProj::Extract(fut, req, handle) => { - match ready!(fut.poll(cx)) { - Ok(item) => { - let fut = handle.call(item); - let state = HandlerServiceFuture::Handle(fut, req.take()); - self.as_mut().set(state); - } - Err(e) => { - let res: Response = e.into().into(); - let req = req.take().unwrap(); - return Poll::Ready(Ok(ServiceResponse::new(req, res))); - } - }; - } - HandlerProj::Handle(fut, req) => { - let res = ready!(fut.poll(cx)); - let req = req.take().unwrap(); - let res = res.respond_to(&req); - return Poll::Ready(Ok(ServiceResponse::new(req, res))); - } - } - } - } -} - -/// FromRequest trait impl for tuples -macro_rules! factory_tuple ({ $(($n:tt, $T:ident)),+} => { - impl Handler<($($T,)+), Res> for Func - where Func: Fn($($T,)+) -> Res + Clone + 'static, - Res: Future, - Res::Output: Responder, +/// Generates a [`Handler`] trait impl for N-ary functions where N is specified with a sequence of +/// space separated type parameters. +/// +/// # Examples +/// ```ignore +/// factory_tuple! {} // implements Handler for types: fn() -> R +/// factory_tuple! { A B C } // implements Handler for types: fn(A, B, C) -> R +/// ``` +macro_rules! factory_tuple ({ $($param:ident)* } => { + impl Handler<($($param,)*)> for Func + where Func: Fn($($param),*) -> Fut + Clone + 'static, + Fut: Future, { - fn call(&self, param: ($($T,)+)) -> Res { - (self)($(param.$n,)+) + type Output = Fut::Output; + type Future = Fut; + + #[inline] + #[allow(non_snake_case)] + fn call(&self, ($($param,)*): ($($param,)*)) -> Self::Future { + (self)($($param,)*) } } }); -#[rustfmt::skip] -mod m { - use super::*; - - factory_tuple!((0, A)); - factory_tuple!((0, A), (1, B)); - factory_tuple!((0, A), (1, B), (2, C)); - factory_tuple!((0, A), (1, B), (2, C), (3, D)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); - factory_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); -} +factory_tuple! {} +factory_tuple! { A } +factory_tuple! { A B } +factory_tuple! { A B C } +factory_tuple! { A B C D } +factory_tuple! { A B C D E } +factory_tuple! { A B C D E F } +factory_tuple! { A B C D E F G } +factory_tuple! { A B C D E F G H } +factory_tuple! { A B C D E F G H I } +factory_tuple! { A B C D E F G H I J } +factory_tuple! { A B C D E F G H I J K } +factory_tuple! { A B C D E F G H I J K L } diff --git a/src/helpers.rs b/src/helpers.rs new file mode 100644 index 000000000..1d2679fce --- /dev/null +++ b/src/helpers.rs @@ -0,0 +1,25 @@ +use std::io; + +use bytes::BufMut; + +/// An `io::Write`r that only requires mutable reference and assumes that there is space available +/// in the buffer for every write operation or that it can be extended implicitly (like +/// `bytes::BytesMut`, for example). +/// +/// This is slightly faster (~10%) than `bytes::buf::Writer` in such cases because it does not +/// perform a remaining length check before writing. +pub(crate) struct MutWriter<'a, B>(pub(crate) &'a mut B); + +impl<'a, B> io::Write for MutWriter<'a, B> +where + B: BufMut, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.put_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/actix-http/src/header/common/accept.rs b/src/http/header/accept.rs similarity index 56% rename from actix-http/src/header/common/accept.rs rename to src/http/header/accept.rs index 775da3394..744c9b6e8 100644 --- a/actix-http/src/header/common/accept.rs +++ b/src/http/header/accept.rs @@ -2,11 +2,12 @@ use std::cmp::Ordering; use mime::Mime; -use crate::header::{qitem, QualityItem}; +use super::{common_header, QualityItem}; use crate::http::header; -header! { - /// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2) +common_header! { + /// `Accept` header, defined + /// in [RFC 7231 §5.3.2](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2) /// /// The `Accept` header field can be used by user agents to specify /// response media types that are acceptable. Accept header fields can @@ -15,8 +16,7 @@ header! { /// in-line image /// /// # ABNF - /// - /// ```text + /// ```plain /// Accept = #( media-range [ accept-params ] ) /// /// media-range = ( "*/*" @@ -27,97 +27,94 @@ header! { /// accept-ext = OWS ";" OWS token [ "=" ( token / quoted-string ) ] /// ``` /// - /// # Example values + /// # Example Values /// * `audio/*; q=0.2, audio/basic` /// * `text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c` /// /// # Examples /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{Accept, qitem}; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{Accept, QualityItem}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::TEXT_HTML), + /// QualityItem::max(mime::TEXT_HTML), /// ]) /// ); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{Accept, qitem}; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{Accept, QualityItem}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::APPLICATION_JSON), + /// QualityItem::max(mime::APPLICATION_JSON), /// ]) /// ); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{Accept, QualityItem, q, qitem}; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{Accept, QualityItem, q}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Accept(vec![ - /// qitem(mime::TEXT_HTML), - /// qitem("application/xhtml+xml".parse().unwrap()), - /// QualityItem::new( - /// mime::TEXT_XML, - /// q(900) - /// ), - /// qitem("image/webp".parse().unwrap()), - /// QualityItem::new( - /// mime::STAR_STAR, - /// q(800) - /// ), + /// QualityItem::max(mime::TEXT_HTML), + /// QualityItem::max("application/xhtml+xml".parse().unwrap()), + /// QualityItem::new(mime::TEXT_XML, q(0.9)), + /// QualityItem::max("image/webp".parse().unwrap()), + /// QualityItem::new(mime::STAR_STAR, q(0.8)), /// ]) /// ); /// ``` - (Accept, header::ACCEPT) => (QualityItem)+ + (Accept, header::ACCEPT) => (QualityItem)* - test_accept { + test_parse_and_format { // Tests from the RFC - test_header!( + crate::http::header::common_header_test!( test1, vec![b"audio/*; q=0.2, audio/basic"], Some(Accept(vec![ - QualityItem::new("audio/*".parse().unwrap(), q(200)), - qitem("audio/basic".parse().unwrap()), + QualityItem::new("audio/*".parse().unwrap(), q(0.2)), + QualityItem::max("audio/basic".parse().unwrap()), ]))); - test_header!( + + crate::http::header::common_header_test!( test2, vec![b"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"], Some(Accept(vec![ - QualityItem::new(mime::TEXT_PLAIN, q(500)), - qitem(mime::TEXT_HTML), + QualityItem::new(mime::TEXT_PLAIN, q(0.5)), + QualityItem::max(mime::TEXT_HTML), QualityItem::new( "text/x-dvi".parse().unwrap(), - q(800)), - qitem("text/x-c".parse().unwrap()), + q(0.8)), + QualityItem::max("text/x-c".parse().unwrap()), ]))); + // Custom tests - test_header!( + crate::http::header::common_header_test!( test3, vec![b"text/plain; charset=utf-8"], Some(Accept(vec![ - qitem(mime::TEXT_PLAIN_UTF_8), + QualityItem::max(mime::TEXT_PLAIN_UTF_8), ]))); - test_header!( + crate::http::header::common_header_test!( test4, vec![b"text/plain; charset=utf-8; q=0.5"], Some(Accept(vec![ QualityItem::new(mime::TEXT_PLAIN_UTF_8, - q(500)), + q(0.5)), ]))); #[test] fn test_fuzzing1() { - use crate::test::TestRequest; - let req = TestRequest::default().insert_header((crate::header::ACCEPT, "chunk#;e")).finish(); + let req = test::TestRequest::default() + .insert_header((header::ACCEPT, "chunk#;e")) + .finish(); let header = Accept::parse(&req); assert!(header.is_ok()); } @@ -127,34 +124,71 @@ header! { impl Accept { /// Construct `Accept: */*`. pub fn star() -> Accept { - Accept(vec![qitem(mime::STAR_STAR)]) + Accept(vec![QualityItem::max(mime::STAR_STAR)]) } /// Construct `Accept: application/json`. pub fn json() -> Accept { - Accept(vec![qitem(mime::APPLICATION_JSON)]) + Accept(vec![QualityItem::max(mime::APPLICATION_JSON)]) } /// Construct `Accept: text/*`. pub fn text() -> Accept { - Accept(vec![qitem(mime::TEXT_STAR)]) + Accept(vec![QualityItem::max(mime::TEXT_STAR)]) } /// Construct `Accept: image/*`. pub fn image() -> Accept { - Accept(vec![qitem(mime::IMAGE_STAR)]) + Accept(vec![QualityItem::max(mime::IMAGE_STAR)]) } /// Construct `Accept: text/html`. pub fn html() -> Accept { - Accept(vec![qitem(mime::TEXT_HTML)]) + Accept(vec![QualityItem::max(mime::TEXT_HTML)]) + } + + // TODO: method for getting best content encoding based on q-factors, available from server side + // and if none are acceptable return None + + /// Extracts the most preferable mime type, accounting for [q-factor weighting]. + /// + /// If no q-factors are provided, the first mime type is chosen. Note that items without + /// q-factors are given the maximum preference value. + /// + /// As per the spec, will return [`mime::STAR_STAR`] (indicating no preference) if the contained + /// list is empty. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn preference(&self) -> Mime { + use actix_http::header::Quality; + + let mut max_item = None; + let mut max_pref = Quality::ZERO; + + // uses manual max lookup loop since we want the first occurrence in the case of same + // preference but `Iterator::max_by_key` would give us the last occurrence + + for pref in &self.0 { + // only change if strictly greater + // equal items, even while unsorted, still have higher preference if they appear first + if pref.quality > max_pref { + max_pref = pref.quality; + max_item = Some(pref.item.clone()); + } + } + + max_item.unwrap_or(mime::STAR_STAR) } /// Returns a sorted list of mime types from highest to lowest preference, accounting for /// [q-factor weighting] and specificity. /// - /// [q-factor weighting]: https://tools.ietf.org/html/rfc7231#section-5.3.2 - pub fn mime_precedence(&self) -> Vec { + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn ranked(&self) -> Vec { + if self.is_empty() { + return vec![]; + } + let mut types = self.0.clone(); // use stable sort so items with equal q-factor and specificity retain listed order @@ -195,42 +229,29 @@ impl Accept { types.into_iter().map(|qitem| qitem.item).collect() } - - /// Extracts the most preferable mime type, accounting for [q-factor weighting]. - /// - /// If no q-factors are provided, the first mime type is chosen. Note that items without - /// q-factors are given the maximum preference value. - /// - /// Returns `None` if contained list is empty. - /// - /// [q-factor weighting]: https://tools.ietf.org/html/rfc7231#section-5.3.2 - pub fn mime_preference(&self) -> Option { - let types = self.mime_precedence(); - types.first().cloned() - } } #[cfg(test)] mod tests { use super::*; - use crate::header::q; + use crate::http::header::q; #[test] - fn test_mime_precedence() { + fn ranking_precedence() { let test = Accept(vec![]); - assert!(test.mime_precedence().is_empty()); + assert!(test.ranked().is_empty()); - let test = Accept(vec![qitem(mime::APPLICATION_JSON)]); - assert_eq!(test.mime_precedence(), vec!(mime::APPLICATION_JSON)); + let test = Accept(vec![QualityItem::max(mime::APPLICATION_JSON)]); + assert_eq!(test.ranked(), vec![mime::APPLICATION_JSON]); let test = Accept(vec![ - qitem(mime::TEXT_HTML), + QualityItem::max(mime::TEXT_HTML), "application/xhtml+xml".parse().unwrap(), QualityItem::new("application/xml".parse().unwrap(), q(0.9)), QualityItem::new(mime::STAR_STAR, q(0.8)), ]); assert_eq!( - test.mime_precedence(), + test.ranked(), vec![ mime::TEXT_HTML, "application/xhtml+xml".parse().unwrap(), @@ -240,33 +261,33 @@ mod tests { ); let test = Accept(vec![ - qitem(mime::STAR_STAR), - qitem(mime::IMAGE_STAR), - qitem(mime::IMAGE_PNG), + QualityItem::max(mime::STAR_STAR), + QualityItem::max(mime::IMAGE_STAR), + QualityItem::max(mime::IMAGE_PNG), ]); assert_eq!( - test.mime_precedence(), + test.ranked(), vec![mime::IMAGE_PNG, mime::IMAGE_STAR, mime::STAR_STAR] ); } #[test] - fn test_mime_preference() { + fn preference_selection() { let test = Accept(vec![ - qitem(mime::TEXT_HTML), + QualityItem::max(mime::TEXT_HTML), "application/xhtml+xml".parse().unwrap(), QualityItem::new("application/xml".parse().unwrap(), q(0.9)), QualityItem::new(mime::STAR_STAR, q(0.8)), ]); - assert_eq!(test.mime_preference(), Some(mime::TEXT_HTML)); + assert_eq!(test.preference(), mime::TEXT_HTML); let test = Accept(vec![ QualityItem::new("video/*".parse().unwrap(), q(0.8)), - qitem(mime::IMAGE_PNG), + QualityItem::max(mime::IMAGE_PNG), QualityItem::new(mime::STAR_STAR, q(0.5)), - qitem(mime::IMAGE_SVG), + QualityItem::max(mime::IMAGE_SVG), QualityItem::new(mime::IMAGE_STAR, q(0.8)), ]); - assert_eq!(test.mime_preference(), Some(mime::IMAGE_PNG)); + assert_eq!(test.preference(), mime::IMAGE_PNG); } } diff --git a/src/http/header/accept_charset.rs b/src/http/header/accept_charset.rs new file mode 100644 index 000000000..c7f7e1a68 --- /dev/null +++ b/src/http/header/accept_charset.rs @@ -0,0 +1,62 @@ +use super::{common_header, Charset, QualityItem, ACCEPT_CHARSET}; + +common_header! { + /// `Accept-Charset` header, defined in [RFC 7231 §5.3.3]. + /// + /// The `Accept-Charset` header field can be sent by a user agent to + /// indicate what charsets are acceptable in textual response content. + /// This field allows user agents capable of understanding more + /// comprehensive or special-purpose charsets to signal that capability + /// to an origin server that is capable of representing information in + /// those charsets. + /// + /// # ABNF + /// ```plain + /// Accept-Charset = 1#( ( charset / "*" ) [ weight ] ) + /// ``` + /// + /// # Example Values + /// * `iso-8859-5, unicode-1-1;q=0.8` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptCharset, Charset, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptCharset(vec![QualityItem::max(Charset::Us_Ascii)]) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptCharset, Charset, q, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptCharset(vec![ + /// QualityItem::new(Charset::Us_Ascii, q(0.9)), + /// QualityItem::new(Charset::Iso_8859_10, q(0.2)), + /// ]) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptCharset, Charset, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptCharset(vec![QualityItem::max(Charset::Ext("utf-8".to_owned()))]) + /// ); + /// ``` + /// + /// [RFC 7231 §5.3.3]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.3 + (AcceptCharset, ACCEPT_CHARSET) => (QualityItem)* + + test_parse_and_format { + // Test case from RFC + common_header_test!(test1, vec![b"iso-8859-5, unicode-1-1;q=0.8"]); + } +} diff --git a/src/http/header/accept_encoding.rs b/src/http/header/accept_encoding.rs new file mode 100644 index 000000000..8c35179b6 --- /dev/null +++ b/src/http/header/accept_encoding.rs @@ -0,0 +1,429 @@ +use std::collections::HashSet; + +use super::{common_header, ContentEncoding, Encoding, Preference, Quality, QualityItem}; +use crate::http::header; + +common_header! { + /// `Accept-Encoding` header, defined + /// in [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.4) + /// + /// The `Accept-Encoding` header field can be used by user agents to indicate what response + /// content-codings are acceptable in the response. An `identity` token is used as a synonym + /// for "no encoding" in order to communicate when no encoding is preferred. + /// + /// # ABNF + /// ```plain + /// Accept-Encoding = #( codings [ weight ] ) + /// codings = content-coding / "identity" / "*" + /// ``` + /// + /// # Example Values + /// * `compress, gzip` + /// * `` + /// * `*` + /// * `compress;q=0.5, gzip;q=1` + /// * `gzip;q=1.0, identity; q=0.5, *;q=0` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptEncoding, Encoding, Preference, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptEncoding(vec![QualityItem::max(Preference::Specific(Encoding::gzip()))]) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptEncoding, Encoding, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptEncoding(vec![ + /// "gzip".parse().unwrap(), + /// "br".parse().unwrap(), + /// ]) + /// ); + /// ``` + (AcceptEncoding, header::ACCEPT_ENCODING) => (QualityItem>)* + + test_parse_and_format { + common_header_test!(no_headers, vec![b""; 0], Some(AcceptEncoding(vec![]))); + common_header_test!(empty_header, vec![b""; 1], Some(AcceptEncoding(vec![]))); + + common_header_test!( + order_of_appearance, + vec![b"br, gzip"], + Some(AcceptEncoding(vec![ + QualityItem::max(Preference::Specific(Encoding::brotli())), + QualityItem::max(Preference::Specific(Encoding::gzip())), + ])) + ); + + common_header_test!(any, vec![b"*"], Some(AcceptEncoding(vec![ + QualityItem::max(Preference::Any), + ]))); + + // Note: Removed quality 1 from gzip + common_header_test!(implicit_quality, vec![b"gzip, identity; q=0.5, *;q=0"]); + + // Note: Removed quality 1 from gzip + common_header_test!(implicit_quality_out_of_order, vec![b"compress;q=0.5, gzip"]); + + common_header_test!( + only_gzip_no_identity, + vec![b"gzip, *; q=0"], + Some(AcceptEncoding(vec![ + QualityItem::max(Preference::Specific(Encoding::gzip())), + QualityItem::zero(Preference::Any), + ])) + ); + } +} + +impl AcceptEncoding { + /// Selects the most acceptable encoding according to client preference and supported types. + /// + /// The "identity" encoding is not assumed and should be included in the `supported` iterator + /// if a non-encoded representation can be selected. + /// + /// If `None` is returned, this indicates that none of the supported encodings are acceptable to + /// the client. The caller should generate a 406 Not Acceptable response (unencoded) that + /// includes the server's supported encodings in the body plus a [`Vary`] header. + /// + /// [`Vary`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary + pub fn negotiate<'a>( + &self, + supported: impl Iterator, + ) -> Option { + // 1. If no Accept-Encoding field is in the request, any content-coding is considered + // acceptable by the user agent. + + let supported_set = supported.collect::>(); + + if supported_set.is_empty() { + return None; + } + + if self.0.is_empty() { + // though it is not recommended to encode in this case, return identity encoding + return Some(Encoding::identity()); + } + + // 2. If the representation has no content-coding, then it is acceptable by default unless + // specifically excluded by the Accept-Encoding field stating either "identity;q=0" or + // "*;q=0" without a more specific entry for "identity". + + let acceptable_items = self.ranked_items().collect::>(); + + let identity_acceptable = is_identity_acceptable(&acceptable_items); + let identity_supported = supported_set.contains(&Encoding::identity()); + + if identity_acceptable && identity_supported && supported_set.len() == 1 { + return Some(Encoding::identity()); + } + + // 3. If the representation's content-coding is one of the content-codings listed in the + // Accept-Encoding field, then it is acceptable unless it is accompanied by a qvalue of 0. + + // 4. If multiple content-codings are acceptable, then the acceptable content-coding with + // the highest non-zero qvalue is preferred. + + let matched = acceptable_items + .into_iter() + .filter(|q| q.quality > Quality::ZERO) + // search relies on item list being in descending order of quality + .find(|q| { + let enc = &q.item; + matches!(enc, Preference::Specific(enc) if supported_set.contains(enc)) + }) + .map(|q| q.item); + + match matched { + Some(Preference::Specific(enc)) => Some(enc), + + _ if identity_acceptable => Some(Encoding::identity()), + + _ => None, + } + } + + /// Extracts the most preferable encoding, accounting for [q-factor weighting]. + /// + /// If no q-factors are provided, the first encoding is chosen. Note that items without + /// q-factors are given the maximum preference value. + /// + /// As per the spec, returns [`Preference::Any`] if acceptable list is empty. Though, if this is + /// returned, it is recommended to use an un-encoded representation. + /// + /// If `None` is returned, it means that the client has signalled that no representations + /// are acceptable. This should never occur for a well behaved user-agent. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn preference(&self) -> Option> { + // empty header indicates no preference + if self.0.is_empty() { + return Some(Preference::Any); + } + + let mut max_item = None; + let mut max_pref = Quality::ZERO; + + // uses manual max lookup loop since we want the first occurrence in the case of same + // preference but `Iterator::max_by_key` would give us the last occurrence + + for pref in &self.0 { + // only change if strictly greater + // equal items, even while unsorted, still have higher preference if they appear first + if pref.quality > max_pref { + max_pref = pref.quality; + max_item = Some(pref.item.clone()); + } + } + + // Return max_item if any items were above 0 quality... + max_item.or_else(|| { + // ...or else check for "*" or "identity". We can elide quality checks since + // entering this block means all items had "q=0". + match self.0.iter().find(|pref| { + matches!( + pref.item, + Preference::Any + | Preference::Specific(Encoding::Known(ContentEncoding::Identity)) + ) + }) { + // "identity" or "*" found so no representation is acceptable + Some(_) => None, + + // implicit "identity" is acceptable + None => Some(Preference::Specific(Encoding::identity())), + } + }) + } + + /// Returns a sorted list of encodings from highest to lowest precedence, accounting + /// for [q-factor weighting]. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn ranked(&self) -> Vec> { + self.ranked_items().map(|q| q.item).collect() + } + + fn ranked_items(&self) -> impl Iterator>> { + if self.0.is_empty() { + return vec![].into_iter(); + } + + let mut types = self.0.clone(); + + // use stable sort so items with equal q-factor retain listed order + types.sort_by(|a, b| { + // sort by q-factor descending + b.quality.cmp(&a.quality) + }); + + types.into_iter() + } +} + +/// Returns true if "identity" is an acceptable encoding. +/// +/// Internal algorithm relies on item list being in descending order of quality. +fn is_identity_acceptable(items: &'_ [QualityItem>]) -> bool { + if items.is_empty() { + return true; + } + + // Loop algorithm depends on items being sorted in descending order of quality. As such, it + // is sufficient to return (q > 0) when reaching either an "identity" or "*" item. + for q in items { + match (q.quality, &q.item) { + // occurrence of "identity;q=n"; return true if quality is non-zero + (q, Preference::Specific(Encoding::Known(ContentEncoding::Identity))) => { + return q > Quality::ZERO + } + + // occurrence of "*;q=n"; return true if quality is non-zero + (q, Preference::Any) => return q > Quality::ZERO, + + _ => {} + } + } + + // implicit acceptable identity + true +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::header::*; + + macro_rules! accept_encoding { + () => { AcceptEncoding(vec![]) }; + ($($q:expr),+ $(,)?) => { AcceptEncoding(vec![$($q.parse().unwrap()),+]) }; + } + + /// Parses an encoding string. + fn enc(enc: &str) -> Preference { + enc.parse().unwrap() + } + + #[test] + fn detect_identity_acceptable() { + macro_rules! accept_encoding_ranked { + () => { accept_encoding!().ranked_items().collect::>() }; + ($($q:expr),+ $(,)?) => { accept_encoding!($($q),+).ranked_items().collect::>() }; + } + + let test = accept_encoding_ranked!(); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip"); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "br"); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "*;q=0.1"); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "identity;q=0.1"); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "identity;q=0.1", "*;q=0"); + assert!(is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "*;q=0", "identity;q=0.1"); + assert!(is_identity_acceptable(&test)); + + let test = accept_encoding_ranked!("gzip", "*;q=0"); + assert!(!is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "identity;q=0"); + assert!(!is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "identity;q=0", "*;q=0"); + assert!(!is_identity_acceptable(&test)); + let test = accept_encoding_ranked!("gzip", "*;q=0", "identity;q=0"); + assert!(!is_identity_acceptable(&test)); + } + + #[test] + fn encoding_negotiation() { + // no preference + let test = accept_encoding!(); + assert_eq!(test.negotiate([].iter()), None); + + let test = accept_encoding!(); + assert_eq!( + test.negotiate([Encoding::identity()].iter()), + Some(Encoding::identity()), + ); + + let test = accept_encoding!("identity;q=0"); + assert_eq!(test.negotiate([Encoding::identity()].iter()), None); + + let test = accept_encoding!("*;q=0"); + assert_eq!(test.negotiate([Encoding::identity()].iter()), None); + + let test = accept_encoding!(); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::identity()), + ); + + let test = accept_encoding!("gzip"); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()), + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::identity()].iter()), + Some(Encoding::identity()), + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()), + ); + + let test = accept_encoding!("gzip", "identity;q=0"); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()), + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::identity()].iter()), + None + ); + + let test = accept_encoding!("gzip", "*;q=0"); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()), + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::identity()].iter()), + None + ); + + let test = accept_encoding!("gzip", "deflate", "br"); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()), + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::identity()].iter()), + Some(Encoding::brotli()) + ); + assert_eq!( + test.negotiate([Encoding::deflate(), Encoding::identity()].iter()), + Some(Encoding::deflate()) + ); + assert_eq!( + test.negotiate( + [Encoding::gzip(), Encoding::deflate(), Encoding::identity()].iter() + ), + Some(Encoding::gzip()) + ); + assert_eq!( + test.negotiate([Encoding::gzip(), Encoding::brotli(), Encoding::identity()].iter()), + Some(Encoding::gzip()) + ); + assert_eq!( + test.negotiate([Encoding::brotli(), Encoding::gzip(), Encoding::identity()].iter()), + Some(Encoding::gzip()) + ); + } + + #[test] + fn ranking_precedence() { + let test = accept_encoding!(); + assert!(test.ranked().is_empty()); + + let test = accept_encoding!("gzip"); + assert_eq!(test.ranked(), vec![enc("gzip")]); + + let test = accept_encoding!("gzip;q=0.900", "*;q=0.700", "br;q=1.0"); + assert_eq!(test.ranked(), vec![enc("br"), enc("gzip"), enc("*")]); + + let test = accept_encoding!("br", "gzip", "*"); + assert_eq!(test.ranked(), vec![enc("br"), enc("gzip"), enc("*")]); + } + + #[test] + fn preference_selection() { + assert_eq!(accept_encoding!().preference(), Some(Preference::Any)); + + assert_eq!(accept_encoding!("identity;q=0").preference(), None); + assert_eq!(accept_encoding!("*;q=0").preference(), None); + assert_eq!(accept_encoding!("compress;q=0", "*;q=0").preference(), None); + assert_eq!(accept_encoding!("identity;q=0", "*;q=0").preference(), None); + + let test = accept_encoding!("*;q=0.5"); + assert_eq!(test.preference().unwrap(), enc("*")); + + let test = accept_encoding!("br;q=0"); + assert_eq!(test.preference().unwrap(), enc("identity")); + + let test = accept_encoding!("br;q=0.900", "gzip;q=1.0", "*;q=0.500"); + assert_eq!(test.preference().unwrap(), enc("gzip")); + + let test = accept_encoding!("br", "gzip", "*"); + assert_eq!(test.preference().unwrap(), enc("br")); + } +} diff --git a/src/http/header/accept_language.rs b/src/http/header/accept_language.rs new file mode 100644 index 000000000..9943e121f --- /dev/null +++ b/src/http/header/accept_language.rs @@ -0,0 +1,223 @@ +use language_tags::LanguageTag; + +use super::{common_header, Preference, Quality, QualityItem}; +use crate::http::header; + +common_header! { + /// `Accept-Language` header, defined + /// in [RFC 7231 §5.3.5](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.5) + /// + /// The `Accept-Language` header field can be used by user agents to indicate the set of natural + /// languages that are preferred in the response. + /// + /// The `Accept-Language` header is defined in + /// [RFC 7231 §5.3.5](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.5) using language + /// ranges defined in [RFC 4647 §2.1](https://datatracker.ietf.org/doc/html/rfc4647#section-2.1). + /// + /// # ABNF + /// ```plain + /// Accept-Language = 1#( language-range [ weight ] ) + /// language-range = (1*8ALPHA *("-" 1*8alphanum)) / "*" + /// alphanum = ALPHA / DIGIT + /// weight = OWS ";" OWS "q=" qvalue + /// qvalue = ( "0" [ "." 0*3DIGIT ] ) + /// / ( "1" [ "." 0*3("0") ] ) + /// ``` + /// + /// # Example Values + /// - `da, en-gb;q=0.8, en;q=0.7` + /// - `en-us;q=1.0, en;q=0.5, fr` + /// - `fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptLanguage, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptLanguage(vec![ + /// "en-US".parse().unwrap(), + /// ]) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{AcceptLanguage, QualityItem, q}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// AcceptLanguage(vec![ + /// "da".parse().unwrap(), + /// "en-GB;q=0.8".parse().unwrap(), + /// "en;q=0.7".parse().unwrap(), + /// ]) + /// ); + /// ``` + (AcceptLanguage, header::ACCEPT_LANGUAGE) => (QualityItem>)* + + test_parse_and_format { + common_header_test!(no_headers, vec![b""; 0], Some(AcceptLanguage(vec![]))); + + common_header_test!(empty_header, vec![b""; 1], Some(AcceptLanguage(vec![]))); + + common_header_test!( + example_from_rfc, + vec![b"da, en-gb;q=0.8, en;q=0.7"] + ); + + + common_header_test!( + not_ordered_by_weight, + vec![b"en-US, en; q=0.5, fr"], + Some(AcceptLanguage(vec![ + QualityItem::max("en-US".parse().unwrap()), + QualityItem::new("en".parse().unwrap(), q(0.5)), + QualityItem::max("fr".parse().unwrap()), + ])) + ); + + common_header_test!( + has_wildcard, + vec![b"fr-CH, fr; q=0.9, en; q=0.8, de; q=0.7, *; q=0.5"], + Some(AcceptLanguage(vec![ + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::new("fr".parse().unwrap(), q(0.9)), + QualityItem::new("en".parse().unwrap(), q(0.8)), + QualityItem::new("de".parse().unwrap(), q(0.7)), + QualityItem::new("*".parse().unwrap(), q(0.5)), + ])) + ); + } +} + +impl AcceptLanguage { + /// Extracts the most preferable language, accounting for [q-factor weighting]. + /// + /// If no q-factors are provided, the first language is chosen. Note that items without + /// q-factors are given the maximum preference value. + /// + /// As per the spec, returns [`Preference::Any`] if contained list is empty. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn preference(&self) -> Preference { + let mut max_item = None; + let mut max_pref = Quality::ZERO; + + // uses manual max lookup loop since we want the first occurrence in the case of same + // preference but `Iterator::max_by_key` would give us the last occurrence + + for pref in &self.0 { + // only change if strictly greater + // equal items, even while unsorted, still have higher preference if they appear first + if pref.quality > max_pref { + max_pref = pref.quality; + max_item = Some(pref.item.clone()); + } + } + + max_item.unwrap_or(Preference::Any) + } + + /// Returns a sorted list of languages from highest to lowest precedence, accounting + /// for [q-factor weighting]. + /// + /// [q-factor weighting]: https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + pub fn ranked(&self) -> Vec> { + if self.0.is_empty() { + return vec![]; + } + + let mut types = self.0.clone(); + + // use stable sort so items with equal q-factor retain listed order + types.sort_by(|a, b| { + // sort by q-factor descending + b.quality.cmp(&a.quality) + }); + + types.into_iter().map(|qitem| qitem.item).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::header::*; + + #[test] + fn ranking_precedence() { + let test = AcceptLanguage(vec![]); + assert!(test.ranked().is_empty()); + + let test = AcceptLanguage(vec![QualityItem::max("fr-CH".parse().unwrap())]); + assert_eq!(test.ranked(), vec!["fr-CH".parse().unwrap()]); + + let test = AcceptLanguage(vec![ + QualityItem::new("fr".parse().unwrap(), q(0.900)), + QualityItem::new("fr-CH".parse().unwrap(), q(1.0)), + QualityItem::new("en".parse().unwrap(), q(0.800)), + QualityItem::new("*".parse().unwrap(), q(0.500)), + QualityItem::new("de".parse().unwrap(), q(0.700)), + ]); + assert_eq!( + test.ranked(), + vec![ + "fr-CH".parse().unwrap(), + "fr".parse().unwrap(), + "en".parse().unwrap(), + "de".parse().unwrap(), + "*".parse().unwrap(), + ] + ); + + let test = AcceptLanguage(vec![ + QualityItem::max("fr".parse().unwrap()), + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::max("en".parse().unwrap()), + QualityItem::max("*".parse().unwrap()), + QualityItem::max("de".parse().unwrap()), + ]); + assert_eq!( + test.ranked(), + vec![ + "fr".parse().unwrap(), + "fr-CH".parse().unwrap(), + "en".parse().unwrap(), + "*".parse().unwrap(), + "de".parse().unwrap(), + ] + ); + } + + #[test] + fn preference_selection() { + let test = AcceptLanguage(vec![ + QualityItem::new("fr".parse().unwrap(), q(0.900)), + QualityItem::new("fr-CH".parse().unwrap(), q(1.0)), + QualityItem::new("en".parse().unwrap(), q(0.800)), + QualityItem::new("*".parse().unwrap(), q(0.500)), + QualityItem::new("de".parse().unwrap(), q(0.700)), + ]); + assert_eq!( + test.preference(), + Preference::Specific("fr-CH".parse().unwrap()) + ); + + let test = AcceptLanguage(vec![ + QualityItem::max("fr".parse().unwrap()), + QualityItem::max("fr-CH".parse().unwrap()), + QualityItem::max("en".parse().unwrap()), + QualityItem::max("*".parse().unwrap()), + QualityItem::max("de".parse().unwrap()), + ]); + assert_eq!( + test.preference(), + Preference::Specific("fr".parse().unwrap()) + ); + + let test = AcceptLanguage(vec![]); + assert_eq!(test.preference(), Preference::Any); + } +} diff --git a/actix-http/src/header/common/allow.rs b/src/http/header/allow.rs similarity index 65% rename from actix-http/src/header/common/allow.rs rename to src/http/header/allow.rs index 06b1efedc..d0ef96486 100644 --- a/actix-http/src/header/common/allow.rs +++ b/src/http/header/allow.rs @@ -1,42 +1,42 @@ -use http::header; -use http::Method; +use actix_http::Method; -header! { - /// `Allow` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.4.1) +use crate::http::header; + +crate::http::header::common_header! { + /// `Allow` header, defined + /// in [RFC 7231 §7.4.1](https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1) /// /// The `Allow` header field lists the set of methods advertised as - /// supported by the target resource. The purpose of this field is + /// supported by the target resource. The purpose of this field is /// strictly to inform the recipient of valid request methods associated /// with the resource. /// /// # ABNF - /// - /// ```text + /// ```plain /// Allow = #method /// ``` /// - /// # Example values + /// # Example Values /// * `GET, HEAD, PUT` /// * `OPTIONS, GET, PUT, POST, DELETE, HEAD, TRACE, CONNECT, PATCH, fOObAr` /// * `` /// /// # Examples - /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::{header::Allow, Method}; + /// use actix_web::HttpResponse; + /// use actix_web::http::{header::Allow, Method}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Allow(vec![Method::GET]) /// ); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::{header::Allow, Method}; + /// use actix_web::HttpResponse; + /// use actix_web::http::{header::Allow, Method}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Allow(vec![ /// Method::GET, @@ -47,14 +47,14 @@ header! { /// ``` (Allow, header::ALLOW) => (Method)* - test_allow { + test_parse_and_format { // From the RFC - test_header!( + crate::http::header::common_header_test!( test1, vec![b"GET, HEAD, PUT"], Some(HeaderField(vec![Method::GET, Method::HEAD, Method::PUT]))); // Own tests - test_header!( + crate::http::header::common_header_test!( test2, vec![b"OPTIONS, GET, PUT, POST, DELETE, HEAD, TRACE, CONNECT, PATCH"], Some(HeaderField(vec![ @@ -67,7 +67,7 @@ header! { Method::TRACE, Method::CONNECT, Method::PATCH]))); - test_header!( + crate::http::header::common_header_test!( test3, vec![b""], Some(HeaderField(Vec::::new()))); diff --git a/src/http/header/any_or_some.rs b/src/http/header/any_or_some.rs new file mode 100644 index 000000000..e5a37e495 --- /dev/null +++ b/src/http/header/any_or_some.rs @@ -0,0 +1,70 @@ +use std::{ + fmt::{self, Write as _}, + str, +}; + +/// A wrapper for types used in header values where wildcard (`*`) items are allowed but the +/// underlying type does not support them. +/// +/// For example, we use the `language-tags` crate for the [`AcceptLanguage`](super::AcceptLanguage) +/// typed header but it does parse `*` successfully. On the other hand, the `mime` crate, used for +/// [`Accept`](super::Accept), has first-party support for wildcard items so this wrapper is not +/// used in those header types. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)] +pub enum AnyOrSome { + /// A wildcard value. + Any, + + /// A valid `T`. + Item(T), +} + +impl AnyOrSome { + /// Returns true if item is wildcard (`*`) variant. + pub fn is_any(&self) -> bool { + matches!(self, Self::Any) + } + + /// Returns true if item is a valid item (`T`) variant. + pub fn is_item(&self) -> bool { + matches!(self, Self::Item(_)) + } + + /// Returns reference to value in `Item` variant, if it is set. + pub fn item(&self) -> Option<&T> { + match self { + AnyOrSome::Item(ref item) => Some(item), + AnyOrSome::Any => None, + } + } + + /// Consumes the container, returning the value in the `Item` variant, if it is set. + pub fn into_item(self) -> Option { + match self { + AnyOrSome::Item(item) => Some(item), + AnyOrSome::Any => None, + } + } +} + +impl fmt::Display for AnyOrSome { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AnyOrSome::Any => f.write_char('*'), + AnyOrSome::Item(item) => fmt::Display::fmt(item, f), + } + } +} + +impl str::FromStr for AnyOrSome { + type Err = T::Err; + + #[inline] + fn from_str(s: &str) -> Result { + match s.trim() { + "*" => Ok(Self::Any), + other => other.parse().map(AnyOrSome::Item), + } + } +} diff --git a/src/http/header/cache_control.rs b/src/http/header/cache_control.rs new file mode 100644 index 000000000..490d36558 --- /dev/null +++ b/src/http/header/cache_control.rs @@ -0,0 +1,192 @@ +use std::{fmt, str}; + +use super::common_header; +use crate::http::header; + +common_header! { + /// `Cache-Control` header, defined + /// in [RFC 7234 §5.2](https://datatracker.ietf.org/doc/html/rfc7234#section-5.2). + /// + /// The `Cache-Control` header field is used to specify directives for + /// caches along the request/response chain. Such cache directives are + /// unidirectional in that the presence of a directive in a request does + /// not imply that the same directive is to be given in the response. + /// + /// # ABNF + /// ```text + /// Cache-Control = 1#cache-directive + /// cache-directive = token [ "=" ( token / quoted-string ) ] + /// ``` + /// + /// # Example Values + /// * `no-cache` + /// * `private, community="UCI"` + /// * `max-age=30` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{CacheControl, CacheDirective}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{CacheControl, CacheDirective}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header(CacheControl(vec![ + /// CacheDirective::NoCache, + /// CacheDirective::Private, + /// CacheDirective::MaxAge(360u32), + /// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), + /// ])); + /// ``` + (CacheControl, header::CACHE_CONTROL) => (CacheDirective)+ + + test_parse_and_format { + common_header_test!(no_headers, vec![b""; 0], None); + common_header_test!(empty_header, vec![b""; 1], None); + common_header_test!(bad_syntax, vec![b"foo="], None); + + common_header_test!( + multiple_headers, + vec![&b"no-cache"[..], &b"private"[..]], + Some(CacheControl(vec![ + CacheDirective::NoCache, + CacheDirective::Private, + ])) + ); + + common_header_test!( + argument, + vec![b"max-age=100, private"], + Some(CacheControl(vec![ + CacheDirective::MaxAge(100), + CacheDirective::Private, + ])) + ); + + common_header_test!( + extension, + vec![b"foo, bar=baz"], + Some(CacheControl(vec![ + CacheDirective::Extension("foo".to_owned(), None), + CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), + ])) + ); + + #[test] + fn parse_quote_form() { + let req = test::TestRequest::default() + .insert_header((header::CACHE_CONTROL, "max-age=\"200\"")) + .finish(); + + assert_eq!( + Header::parse(&req).ok(), + Some(CacheControl(vec![CacheDirective::MaxAge(200)])) + ) + } + } +} + +/// `CacheControl` contains a list of these directives. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CacheDirective { + /// "no-cache" + NoCache, + /// "no-store" + NoStore, + /// "no-transform" + NoTransform, + /// "only-if-cached" + OnlyIfCached, + + // request directives + /// "max-age=delta" + MaxAge(u32), + /// "max-stale=delta" + MaxStale(u32), + /// "min-fresh=delta" + MinFresh(u32), + + // response directives + /// "must-revalidate" + MustRevalidate, + /// "public" + Public, + /// "private" + Private, + /// "proxy-revalidate" + ProxyRevalidate, + /// "s-maxage=delta" + SMaxAge(u32), + + /// Extension directives. Optionally include an argument. + Extension(String, Option), +} + +impl fmt::Display for CacheDirective { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use self::CacheDirective::*; + + let dir_str = match self { + NoCache => "no-cache", + NoStore => "no-store", + NoTransform => "no-transform", + OnlyIfCached => "only-if-cached", + + MaxAge(secs) => return write!(f, "max-age={}", secs), + MaxStale(secs) => return write!(f, "max-stale={}", secs), + MinFresh(secs) => return write!(f, "min-fresh={}", secs), + + MustRevalidate => "must-revalidate", + Public => "public", + Private => "private", + ProxyRevalidate => "proxy-revalidate", + SMaxAge(secs) => return write!(f, "s-maxage={}", secs), + + Extension(name, None) => name.as_str(), + Extension(name, Some(arg)) => return write!(f, "{}={}", name, arg), + }; + + f.write_str(dir_str) + } +} + +impl str::FromStr for CacheDirective { + type Err = Option<::Err>; + + fn from_str(s: &str) -> Result { + use self::CacheDirective::*; + + match s { + "" => Err(None), + + "no-cache" => Ok(NoCache), + "no-store" => Ok(NoStore), + "no-transform" => Ok(NoTransform), + "only-if-cached" => Ok(OnlyIfCached), + "must-revalidate" => Ok(MustRevalidate), + "public" => Ok(Public), + "private" => Ok(Private), + "proxy-revalidate" => Ok(ProxyRevalidate), + + _ => match s.find('=') { + Some(idx) if idx + 1 < s.len() => { + match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) { + ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some), + ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), + ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), + ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), + (left, right) => Ok(Extension(left.to_owned(), Some(right.to_owned()))), + } + } + Some(_) => Err(None), + None => Ok(Extension(s.to_owned(), None)), + }, + } + } +} diff --git a/actix-http/src/header/common/content_disposition.rs b/src/http/header/content_disposition.rs similarity index 86% rename from actix-http/src/header/common/content_disposition.rs rename to src/http/header/content_disposition.rs index ecc59aba3..8b7101aa1 100644 --- a/actix-http/src/header/common/content_disposition.rs +++ b/src/http/header/content_disposition.rs @@ -1,16 +1,21 @@ -//! # References +//! The `Content-Disposition` header and associated types. //! -//! "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt -//! "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt -//! "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt -//! Browser conformance tests at: http://greenbytes.de/tech/tc2231/ -//! IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml +//! # References +//! - "The Content-Disposition Header Field": +//! +//! - "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)": +//! +//! - "Returning Values from Forms: multipart/form-data": +//! +//! - Browser conformance tests at: +//! - IANA assignment: -use lazy_static::lazy_static; +use once_cell::sync::Lazy; use regex::Regex; use std::fmt::{self, Write}; -use crate::header::{self, ExtendedValue, Header, IntoHeaderValue, Writer}; +use super::{ExtendedValue, Header, TryIntoHeaderValue, Writer}; +use crate::http::header; /// Split at the index of the first `needle` if it exists or at the end. fn split_once(haystack: &str, needle: char) -> (&str, &str) { @@ -33,15 +38,19 @@ fn split_once_and_trim(haystack: &str, needle: char) -> (&str, &str) { /// The implied disposition of the content of the HTTP body. #[derive(Clone, Debug, PartialEq)] pub enum DispositionType { - /// Inline implies default processing + /// Inline implies default processing. Inline, + /// Attachment implies that the recipient should prompt the user to save the response locally, /// rather than process it normally (as per its media type). Attachment, + /// Used in *multipart/form-data* as defined in - /// [RFC7578](https://tools.ietf.org/html/rfc7578) to carry the field name and the file name. + /// [RFC 7578](https://datatracker.ietf.org/doc/html/rfc7578) to carry the field name and + /// optional filename. FormData, - /// Extension type. Should be handled by recipients the same way as Attachment + + /// Extension type. Should be handled by recipients the same way as Attachment. Ext(String), } @@ -63,7 +72,7 @@ impl<'a> From<&'a str> for DispositionType { /// /// # Examples /// ``` -/// use actix_http::http::header::DispositionParam; +/// use actix_web::http::header::DispositionParam; /// /// let param = DispositionParam::Filename(String::from("sample.txt")); /// assert!(param.is_filename()); @@ -75,25 +84,32 @@ pub enum DispositionParam { /// For [`DispositionType::FormData`] (i.e. *multipart/form-data*), the name of an field from /// the form. Name(String), + /// A plain file name. /// - /// It is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any - /// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where + /// It is [not supposed](https://datatracker.ietf.org/doc/html/rfc6266#appendix-D) to contain + /// any non-ASCII characters when used in a *Content-Disposition* HTTP response header, where /// [`FilenameExt`](DispositionParam::FilenameExt) with charset UTF-8 may be used instead /// in case there are Unicode characters in file names. Filename(String), + /// An extended file name. It must not exist for `ContentType::Formdata` according to - /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). + /// [RFC 7578 §4.2](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2). FilenameExt(ExtendedValue), + /// An unrecognized regular parameter as defined in - /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *reg-parameter*, in - /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *token "=" value*. Recipients should - /// ignore unrecognizable parameters. + /// [RFC 5987 §3.2.1](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2.1) as + /// `reg-parameter`, in + /// [RFC 6266 §4.1](https://datatracker.ietf.org/doc/html/rfc6266#section-4.1) as + /// `token "=" value`. Recipients should ignore unrecognizable parameters. Unknown(String, String), + /// An unrecognized extended parameter as defined in - /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *ext-parameter*, in - /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *ext-token "=" ext-value*. The single - /// trailing asterisk is not included. Recipients should ignore unrecognizable parameters. + /// [RFC 5987 §3.2.1](https://datatracker.ietf.org/doc/html/rfc5987#section-3.2.1) as + /// `ext-parameter`, in + /// [RFC 6266 §4.1](https://datatracker.ietf.org/doc/html/rfc6266#section-4.1) as + /// `ext-token "=" ext-value`. The single trailing asterisk is not included. Recipients should + /// ignore unrecognizable parameters. UnknownExt(String, ExtendedValue), } @@ -187,10 +203,10 @@ impl DispositionParam { } /// A *Content-Disposition* header. It is compatible to be used either as -/// [a response header for the main body](https://mdn.io/Content-Disposition#As_a_response_header_for_the_main_body) -/// as (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266), or as +/// [a response header for the main body](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition#as_a_response_header_for_the_main_body) +/// as (re)defined in [RFC 6266](https://datatracker.ietf.org/doc/html/rfc6266), or as /// [a header for a multipart body](https://mdn.io/Content-Disposition#As_a_header_for_a_multipart_body) -/// as (re)defined in [RFC7587](https://tools.ietf.org/html/rfc7578). +/// as (re)defined in [RFC 7587](https://datatracker.ietf.org/doc/html/rfc7578). /// /// In a regular HTTP response, the *Content-Disposition* response header is a header indicating if /// the content is expected to be displayed *inline* in the browser, that is, as a Web page or as @@ -204,8 +220,7 @@ impl DispositionParam { /// itself, *Content-Disposition* has no effect. /// /// # ABNF - -/// ```text +/// ```plain /// content-disposition = "Content-Disposition" ":" /// disposition-type *( ";" disposition-parm ) /// @@ -226,21 +241,19 @@ impl DispositionParam { /// ``` /// /// # Note +/// *filename* is [not supposed](https://datatracker.ietf.org/doc/html/rfc6266#appendix-D) to +/// contain any non-ASCII characters when used in a *Content-Disposition* HTTP response header, +/// where filename* with charset UTF-8 may be used instead in case there are Unicode characters in +/// file names. Filename is [acceptable](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) +/// to be UTF-8 encoded directly in a *Content-Disposition* header for +/// *multipart/form-data*, though. /// -/// filename is [not supposed](https://tools.ietf.org/html/rfc6266#appendix-D) to contain any -/// non-ASCII characters when used in a *Content-Disposition* HTTP response header, where -/// filename* with charset UTF-8 may be used instead in case there are Unicode characters in file -/// names. -/// filename is [acceptable](https://tools.ietf.org/html/rfc7578#section-4.2) to be UTF-8 encoded -/// directly in a *Content-Disposition* header for *multipart/form-data*, though. -/// -/// filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// *filename* [must not](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2) be used within /// *multipart/form-data*. /// -/// # Example -/// +/// # Examples /// ``` -/// use actix_http::http::header::{ +/// use actix_web::http::header::{ /// Charset, ContentDisposition, DispositionParam, DispositionType, /// ExtendedValue, /// }; @@ -284,14 +297,15 @@ impl DispositionParam { /// ``` /// /// # Security Note -/// /// If "filename" parameter is supplied, do not use the file name blindly, check and possibly /// change to match local file system conventions if applicable, and do not use directory path -/// information that may be present. See [RFC2183](https://tools.ietf.org/html/rfc2183#section-2.3). +/// information that may be present. +/// See [RFC 2183 §2.3](https://datatracker.ietf.org/doc/html/rfc2183#section-2.3). #[derive(Clone, Debug, PartialEq)] pub struct ContentDisposition { /// The disposition type pub disposition: DispositionType, + /// Disposition parameters pub parameters: Vec, } @@ -333,7 +347,7 @@ impl ContentDisposition { } else { // regular parameters let value = if left.starts_with('\"') { - // quoted-string: defined in RFC6266 -> RFC2616 Section 3.6 + // quoted-string: defined in RFC 6266 -> RFC 2616 Section 3.6 let mut escaping = false; let mut quoted_string = vec![]; let mut end = None; @@ -384,74 +398,62 @@ impl ContentDisposition { Ok(cd) } - /// Returns `true` if it is [`Inline`](DispositionType::Inline). + /// Returns `true` if type is [`Inline`](DispositionType::Inline). pub fn is_inline(&self) -> bool { matches!(self.disposition, DispositionType::Inline) } - /// Returns `true` if it is [`Attachment`](DispositionType::Attachment). + /// Returns `true` if type is [`Attachment`](DispositionType::Attachment). pub fn is_attachment(&self) -> bool { matches!(self.disposition, DispositionType::Attachment) } - /// Returns `true` if it is [`FormData`](DispositionType::FormData). + /// Returns `true` if type is [`FormData`](DispositionType::FormData). pub fn is_form_data(&self) -> bool { matches!(self.disposition, DispositionType::FormData) } - /// Returns `true` if it is [`Ext`](DispositionType::Ext) and the `disp_type` matches. - pub fn is_ext>(&self, disp_type: T) -> bool { - match self.disposition { - DispositionType::Ext(ref t) - if t.eq_ignore_ascii_case(disp_type.as_ref()) => - { - true - } - _ => false, - } + /// Returns `true` if type is [`Ext`](DispositionType::Ext) and the `disp_type` matches. + pub fn is_ext(&self, disp_type: impl AsRef) -> bool { + matches!( + self.disposition, + DispositionType::Ext(ref t) if t.eq_ignore_ascii_case(disp_type.as_ref()) + ) } /// Return the value of *name* if exists. pub fn get_name(&self) -> Option<&str> { - self.parameters.iter().filter_map(|p| p.as_name()).next() + self.parameters.iter().find_map(DispositionParam::as_name) } /// Return the value of *filename* if exists. pub fn get_filename(&self) -> Option<&str> { self.parameters .iter() - .filter_map(|p| p.as_filename()) - .next() + .find_map(DispositionParam::as_filename) } /// Return the value of *filename\** if exists. pub fn get_filename_ext(&self) -> Option<&ExtendedValue> { self.parameters .iter() - .filter_map(|p| p.as_filename_ext()) - .next() + .find_map(DispositionParam::as_filename_ext) } /// Return the value of the parameter which the `name` matches. - pub fn get_unknown>(&self, name: T) -> Option<&str> { + pub fn get_unknown(&self, name: impl AsRef) -> Option<&str> { let name = name.as_ref(); - self.parameters - .iter() - .filter_map(|p| p.as_unknown(name)) - .next() + self.parameters.iter().find_map(|p| p.as_unknown(name)) } /// Return the value of the extended parameter which the `name` matches. - pub fn get_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + pub fn get_unknown_ext(&self, name: impl AsRef) -> Option<&ExtendedValue> { let name = name.as_ref(); - self.parameters - .iter() - .filter_map(|p| p.as_unknown_ext(name)) - .next() + self.parameters.iter().find_map(|p| p.as_unknown_ext(name)) } } -impl IntoHeaderValue for ContentDisposition { +impl TryIntoHeaderValue for ContentDisposition { type Error = header::InvalidHeaderValue; fn try_into_value(self) -> Result { @@ -468,7 +470,7 @@ impl Header for ContentDisposition { fn parse(msg: &T) -> Result { if let Some(h) = msg.headers().get(&Self::name()) { - Self::from_raw(&h) + Self::from_raw(h) } else { Err(crate::error::ParseError::Header) } @@ -490,7 +492,9 @@ impl fmt::Display for DispositionParam { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // All ASCII control characters (0-30, 127) including horizontal tab, double quote, and // backslash should be escaped in quoted-string (i.e. "foobar"). - // Ref: RFC6266 S4.1 -> RFC2616 S3.6 + // + // Ref: RFC 6266 §4.1 -> RFC 2616 §3.6 + // // filename-parm = "filename" "=" value // value = token | quoted-string // quoted-string = ( <"> *(qdtext | quoted-pair ) <"> ) @@ -504,7 +508,7 @@ impl fmt::Display for DispositionParam { // CTL = // - // Ref: RFC7578 S4.2 -> RFC2183 S2 -> RFC2045 S5.1 + // Ref: RFC 7578 S4.2 -> RFC 2183 S2 -> RFC 2045 S5.1 // parameter := attribute "=" value // attribute := token // ; Matching of attributes @@ -520,23 +524,28 @@ impl fmt::Display for DispositionParam { // // // See also comments in test_from_raw_unnecessary_percent_decode. - lazy_static! { - static ref RE: Regex = Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap(); - } + + static RE: Lazy = + Lazy::new(|| Regex::new("[\x00-\x08\x10-\x1F\x7F\"\\\\]").unwrap()); + match self { DispositionParam::Name(ref value) => write!(f, "name={}", value), + DispositionParam::Filename(ref value) => { write!(f, "filename=\"{}\"", RE.replace_all(value, "\\$0").as_ref()) } + DispositionParam::Unknown(ref name, ref value) => write!( f, "{}=\"{}\"", name, &RE.replace_all(value, "\\$0").as_ref() ), + DispositionParam::FilenameExt(ref ext_value) => { write!(f, "filename*={}", ext_value) } + DispositionParam::UnknownExt(ref name, ref ext_value) => { write!(f, "{}*={}", name, ext_value) } @@ -556,8 +565,7 @@ impl fmt::Display for ContentDisposition { #[cfg(test)] mod tests { use super::{ContentDisposition, DispositionParam, DispositionType}; - use crate::header::shared::Charset; - use crate::header::{ExtendedValue, HeaderValue}; + use crate::http::header::{Charset, ExtendedValue, HeaderValue}; #[test] fn test_from_raw_basic() { @@ -619,8 +627,8 @@ mod tests { charset: Charset::Ext(String::from("UTF-8")), language_tag: None, value: vec![ - 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, - b'r', b'a', b't', b'e', b's', + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, b'r', + b'a', b't', b'e', b's', ], })], }; @@ -636,8 +644,8 @@ mod tests { charset: Charset::Ext(String::from("UTF-8")), language_tag: None, value: vec![ - 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, - b'r', b'a', b't', b'e', b's', + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, b'r', + b'a', b't', b'e', b's', ], })], }; @@ -699,26 +707,22 @@ mod tests { #[test] fn test_from_raw_only_disp() { - let a = ContentDisposition::from_raw(&HeaderValue::from_static("attachment")) - .unwrap(); + let a = ContentDisposition::from_raw(&HeaderValue::from_static("attachment")).unwrap(); let b = ContentDisposition { disposition: DispositionType::Attachment, parameters: vec![], }; assert_eq!(a, b); - let a = - ContentDisposition::from_raw(&HeaderValue::from_static("inline ;")).unwrap(); + let a = ContentDisposition::from_raw(&HeaderValue::from_static("inline ;")).unwrap(); let b = ContentDisposition { disposition: DispositionType::Inline, parameters: vec![], }; assert_eq!(a, b); - let a = ContentDisposition::from_raw(&HeaderValue::from_static( - "unknown-disp-param", - )) - .unwrap(); + let a = ContentDisposition::from_raw(&HeaderValue::from_static("unknown-disp-param")) + .unwrap(); let b = ContentDisposition { disposition: DispositionType::Ext(String::from("unknown-disp-param")), parameters: vec![], @@ -749,7 +753,7 @@ mod tests { #[test] fn from_raw_with_unicode() { - /* RFC7578 Section 4.2: + /* RFC 7578 Section 4.2: Some commonly deployed systems use multipart/form-data with file names directly encoded including octets outside the US-ASCII range. The encoding used for the file names is typically UTF-8, although HTML forms will use the charset associated with the form. @@ -757,8 +761,8 @@ mod tests { Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. (And now, only UTF-8 is handled by this implementation.) */ - let a = HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") - .unwrap(); + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"").unwrap(); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::FormData, @@ -809,8 +813,7 @@ mod tests { #[test] fn test_from_raw_semicolon() { - let a = - HeaderValue::from_static("form-data; filename=\"A semicolon here;.pdf\""); + let a = HeaderValue::from_static("form-data; filename=\"A semicolon here;.pdf\""); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::FormData, @@ -823,9 +826,9 @@ mod tests { #[test] fn test_from_raw_unnecessary_percent_decode() { - // In fact, RFC7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with + // In fact, RFC 7578 (multipart/form-data) Section 2 and 4.2 suggests that filename with // non-ASCII characters MAY be percent-encoded. - // On the contrary, RFC6266 or other RFCs related to Content-Disposition response header + // On the contrary, RFC 6266 or other RFCs related to Content-Disposition response header // do not mention such percent-encoding. // So, it appears to be undecidable whether to percent-decode or not without // knowing the usage scenario (multipart/form-data v.s. HTTP response header) and @@ -846,9 +849,8 @@ mod tests { }; assert_eq!(a, b); - let a = HeaderValue::from_static( - "form-data; name=photo; filename=\"%74%65%73%74.png\"", - ); + let a = + HeaderValue::from_static("form-data; name=photo; filename=\"%74%65%73%74.png\""); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::FormData, @@ -893,8 +895,7 @@ mod tests { #[test] fn test_display_extended() { - let as_string = - "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; + let as_string = "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; let a = HeaderValue::from_static(as_string); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let display_rendered = format!("{}", a); diff --git a/src/http/header/content_language.rs b/src/http/header/content_language.rs new file mode 100644 index 000000000..ff317e1de --- /dev/null +++ b/src/http/header/content_language.rs @@ -0,0 +1,54 @@ +use language_tags::LanguageTag; + +use super::{common_header, QualityItem, CONTENT_LANGUAGE}; + +common_header! { + /// `Content-Language` header, defined + /// in [RFC 7231 §3.1.3.2](https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.3.2) + /// + /// The `Content-Language` header field describes the natural language(s) + /// of the intended audience for the representation. Note that this + /// might not be equivalent to all the languages used within the + /// representation. + /// + /// # ABNF + /// ```plain + /// Content-Language = 1#language-tag + /// ``` + /// + /// # Example Values + /// * `da` + /// * `mi, en` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{ContentLanguage, LanguageTag, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// ContentLanguage(vec![ + /// QualityItem::max(LanguageTag::parse("en").unwrap()), + /// ]) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{ContentLanguage, LanguageTag, QualityItem}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// ContentLanguage(vec![ + /// QualityItem::max(LanguageTag::parse("da").unwrap()), + /// QualityItem::max(LanguageTag::parse("en-GB").unwrap()), + /// ]) + /// ); + /// ``` + (ContentLanguage, CONTENT_LANGUAGE) => (QualityItem)+ + + test_parse_and_format { + crate::http::header::common_header_test!(test1, vec![b"da"]); + crate::http::header::common_header_test!(test2, vec![b"mi, en"]); + } +} diff --git a/actix-http/src/header/common/content_range.rs b/src/http/header/content_range.rs similarity index 78% rename from actix-http/src/header/common/content_range.rs rename to src/http/header/content_range.rs index 8b7552377..bcbe77e66 100644 --- a/actix-http/src/header/common/content_range.rs +++ b/src/http/header/content_range.rs @@ -1,81 +1,81 @@ -use std::fmt::{self, Display, Write}; -use std::str::FromStr; - -use crate::error::ParseError; -use crate::header::{ - HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer, CONTENT_RANGE, +use std::{ + fmt::{self, Display, Write}, + str::FromStr, }; -header! { - /// `Content-Range` header, defined in - /// [RFC7233](http://tools.ietf.org/html/rfc7233#section-4.2) +use super::{HeaderValue, InvalidHeaderValue, TryIntoHeaderValue, Writer, CONTENT_RANGE}; +use crate::error::ParseError; + +crate::http::header::common_header! { + /// `Content-Range` header, defined + /// in [RFC 7233 §4.2](https://datatracker.ietf.org/doc/html/rfc7233#section-4.2) (ContentRange, CONTENT_RANGE) => [ContentRangeSpec] - test_content_range { - test_header!(test_bytes, + test_parse_and_format { + crate::http::header::common_header_test!(test_bytes, vec![b"bytes 0-499/500"], Some(ContentRange(ContentRangeSpec::Bytes { range: Some((0, 499)), instance_length: Some(500) }))); - test_header!(test_bytes_unknown_len, + crate::http::header::common_header_test!(test_bytes_unknown_len, vec![b"bytes 0-499/*"], Some(ContentRange(ContentRangeSpec::Bytes { range: Some((0, 499)), instance_length: None }))); - test_header!(test_bytes_unknown_range, + crate::http::header::common_header_test!(test_bytes_unknown_range, vec![b"bytes */500"], Some(ContentRange(ContentRangeSpec::Bytes { range: None, instance_length: Some(500) }))); - test_header!(test_unregistered, + crate::http::header::common_header_test!(test_unregistered, vec![b"seconds 1-2"], Some(ContentRange(ContentRangeSpec::Unregistered { unit: "seconds".to_owned(), resp: "1-2".to_owned() }))); - test_header!(test_no_len, + crate::http::header::common_header_test!(test_no_len, vec![b"bytes 0-499"], None::); - test_header!(test_only_unit, + crate::http::header::common_header_test!(test_only_unit, vec![b"bytes"], None::); - test_header!(test_end_less_than_start, + crate::http::header::common_header_test!(test_end_less_than_start, vec![b"bytes 499-0/500"], None::); - test_header!(test_blank, + crate::http::header::common_header_test!(test_blank, vec![b""], None::); - test_header!(test_bytes_many_spaces, + crate::http::header::common_header_test!(test_bytes_many_spaces, vec![b"bytes 1-2/500 3"], None::); - test_header!(test_bytes_many_slashes, + crate::http::header::common_header_test!(test_bytes_many_slashes, vec![b"bytes 1-2/500/600"], None::); - test_header!(test_bytes_many_dashes, + crate::http::header::common_header_test!(test_bytes_many_dashes, vec![b"bytes 1-2-3/500"], None::); } } -/// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2) +/// Content-Range header, defined +/// in [RFC 7233 §4.2](https://datatracker.ietf.org/doc/html/rfc7233#section-4.2) /// /// # ABNF -/// -/// ```text +/// ```plain /// Content-Range = byte-content-range /// / other-content-range /// @@ -91,7 +91,7 @@ header! { /// other-content-range = other-range-unit SP other-range-resp /// other-range-resp = *CHAR /// ``` -#[derive(PartialEq, Clone, Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ContentRangeSpec { /// Byte range Bytes { @@ -141,8 +141,7 @@ impl FromStr for ContentRangeSpec { } else { let (first_byte, last_byte) = split_in_two(range, '-').ok_or(ParseError::Header)?; - let first_byte = - first_byte.parse().map_err(|_| ParseError::Header)?; + let first_byte = first_byte.parse().map_err(|_| ParseError::Header)?; let last_byte = last_byte.parse().map_err(|_| ParseError::Header)?; if last_byte < first_byte { return Err(ParseError::Header); @@ -197,7 +196,7 @@ impl Display for ContentRangeSpec { } } -impl IntoHeaderValue for ContentRangeSpec { +impl TryIntoHeaderValue for ContentRangeSpec { type Error = InvalidHeaderValue; fn try_into_value(self) -> Result { diff --git a/actix-http/src/header/common/content_type.rs b/src/http/header/content_type.rs similarity index 65% rename from actix-http/src/header/common/content_type.rs rename to src/http/header/content_type.rs index ac5c7e5b8..1fc75d0e2 100644 --- a/actix-http/src/header/common/content_type.rs +++ b/src/http/header/content_type.rs @@ -1,9 +1,9 @@ -use crate::header::CONTENT_TYPE; +use super::CONTENT_TYPE; use mime::Mime; -header! { - /// `Content-Type` header, defined in - /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-3.1.1.5) +crate::http::header::common_header! { + /// `Content-Type` header, defined + /// in [RFC 7231 §3.1.1.5](https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.5) /// /// The `Content-Type` header field indicates the media type of the /// associated representation: either the representation enclosed in the @@ -18,41 +18,38 @@ header! { /// this is an issue, it's possible to implement `Header` on a custom struct. /// /// # ABNF - /// - /// ```text + /// ```plain /// Content-Type = media-type /// ``` /// - /// # Example values - /// + /// # Example Values /// * `text/html; charset=utf-8` /// * `application/json` /// /// # Examples - /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::ContentType; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::ContentType; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// ContentType::json() /// ); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::ContentType; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::ContentType; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// ContentType(mime::TEXT_HTML) /// ); /// ``` (ContentType, CONTENT_TYPE) => [Mime] - test_content_type { - test_header!( + test_parse_and_format { + crate::http::header::common_header_test!( test1, vec![b"text/html"], Some(HeaderField(mime::TEXT_HTML))); @@ -60,57 +57,56 @@ header! { } impl ContentType { - /// A constructor to easily create a `Content-Type: application/json` + /// A constructor to easily create a `Content-Type: application/json` /// header. #[inline] pub fn json() -> ContentType { ContentType(mime::APPLICATION_JSON) } - /// A constructor to easily create a `Content-Type: text/plain; + /// A constructor to easily create a `Content-Type: text/plain; /// charset=utf-8` header. #[inline] pub fn plaintext() -> ContentType { ContentType(mime::TEXT_PLAIN_UTF_8) } - /// A constructor to easily create a `Content-Type: text/html` header. + /// A constructor to easily create a `Content-Type: text/html; charset=utf-8` + /// header. #[inline] pub fn html() -> ContentType { - ContentType(mime::TEXT_HTML) + ContentType(mime::TEXT_HTML_UTF_8) } - /// A constructor to easily create a `Content-Type: text/xml` header. + /// A constructor to easily create a `Content-Type: text/xml` header. #[inline] pub fn xml() -> ContentType { ContentType(mime::TEXT_XML) } - /// A constructor to easily create a `Content-Type: + /// A constructor to easily create a `Content-Type: /// application/www-form-url-encoded` header. #[inline] pub fn form_url_encoded() -> ContentType { ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) } - /// A constructor to easily create a `Content-Type: image/jpeg` header. + /// A constructor to easily create a `Content-Type: image/jpeg` header. #[inline] pub fn jpeg() -> ContentType { ContentType(mime::IMAGE_JPEG) } - /// A constructor to easily create a `Content-Type: image/png` header. + /// A constructor to easily create a `Content-Type: image/png` header. #[inline] pub fn png() -> ContentType { ContentType(mime::IMAGE_PNG) } - /// A constructor to easily create a `Content-Type: + /// A constructor to easily create a `Content-Type: /// application/octet-stream` header. #[inline] pub fn octet_stream() -> ContentType { ContentType(mime::APPLICATION_OCTET_STREAM) } } - -impl Eq for ContentType {} diff --git a/actix-http/src/header/common/date.rs b/src/http/header/date.rs similarity index 56% rename from actix-http/src/header/common/date.rs rename to src/http/header/date.rs index e5ace95e6..4063deab1 100644 --- a/actix-http/src/header/common/date.rs +++ b/src/http/header/date.rs @@ -1,38 +1,37 @@ -use crate::header::{HttpDate, DATE}; +use super::{HttpDate, DATE}; use std::time::SystemTime; -header! { - /// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2) +crate::http::header::common_header! { + /// `Date` header, defined + /// in [RFC 7231 §7.1.1.2](https://datatracker.ietf.org/doc/html/rfc7231#section-7.1.1.2) /// /// The `Date` header field represents the date and time at which the /// message was originated. /// /// # ABNF - /// - /// ```text + /// ```plain /// Date = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Tue, 15 Nov 1994 08:12:31 GMT` /// /// # Example /// /// ``` /// use std::time::SystemTime; - /// use actix_http::Response; - /// use actix_http::http::header::Date; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::Date; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// Date(SystemTime::now().into()) /// ); /// ``` (Date, DATE) => [HttpDate] - test_date { - test_header!(test1, vec![b"Tue, 15 Nov 1994 08:12:31 GMT"]); + test_parse_and_format { + crate::http::header::common_header_test!(test1, vec![b"Tue, 15 Nov 1994 08:12:31 GMT"]); } } diff --git a/src/http/header/encoding.rs b/src/http/header/encoding.rs new file mode 100644 index 000000000..ab11f50be --- /dev/null +++ b/src/http/header/encoding.rs @@ -0,0 +1,55 @@ +use std::{fmt, str}; + +use actix_http::ContentEncoding; + +/// A value to represent an encoding used in the `Accept-Encoding` and `Content-Encoding` header. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Encoding { + /// A supported content encoding. See [`ContentEncoding`] for variants. + Known(ContentEncoding), + + /// Some other encoding that is less common, can be any string. + Unknown(String), +} + +impl Encoding { + pub const fn identity() -> Self { + Self::Known(ContentEncoding::Identity) + } + + pub const fn brotli() -> Self { + Self::Known(ContentEncoding::Brotli) + } + + pub const fn deflate() -> Self { + Self::Known(ContentEncoding::Deflate) + } + + pub const fn gzip() -> Self { + Self::Known(ContentEncoding::Gzip) + } + + pub const fn zstd() -> Self { + Self::Known(ContentEncoding::Zstd) + } +} + +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Encoding::Known(enc) => enc.as_str(), + Encoding::Unknown(enc) => enc.as_str(), + }) + } +} + +impl str::FromStr for Encoding { + type Err = crate::error::ParseError; + + fn from_str(enc: &str) -> Result { + match enc.parse::() { + Ok(enc) => Ok(Self::Known(enc)), + Err(_) => Ok(Self::Unknown(enc.to_owned())), + } + } +} diff --git a/actix-http/src/header/shared/entity.rs b/src/http/header/entity.rs similarity index 75% rename from actix-http/src/header/shared/entity.rs rename to src/http/header/entity.rs index eb383cd6f..0eaa12b5d 100644 --- a/actix-http/src/header/shared/entity.rs +++ b/src/http/header/entity.rs @@ -1,7 +1,9 @@ -use std::fmt::{self, Display, Write}; -use std::str::FromStr; +use std::{ + fmt::{self, Display, Write}, + str::FromStr, +}; -use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; +use super::{HeaderValue, InvalidHeaderValue, TryIntoHeaderValue, Writer}; /// check that each char in the slice is either: /// 1. `%x21`, or @@ -15,7 +17,7 @@ fn check_slice_validity(slice: &str) -> bool { slice.bytes().all(entity_validate_char) } -/// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) +/// An entity tag, defined in [RFC 7232 §2.3]. /// /// An entity tag consists of a string enclosed by two literal double quotes. /// Preceding the first double quote is an optional weakness indicator, @@ -23,8 +25,7 @@ fn check_slice_validity(slice: &str) -> bool { /// `W/"xyzzy"`. /// /// # ABNF -/// -/// ```text +/// ```plain /// entity-tag = [ weak ] opaque-tag /// weak = %x57.2F ; "W/", case-sensitive /// opaque-tag = DQUOTE *etagc DQUOTE @@ -46,16 +47,20 @@ fn check_slice_validity(slice: &str) -> bool { /// | `W/"1"` | `W/"2"` | no match | no match | /// | `W/"1"` | `"1"` | no match | match | /// | `"1"` | `"1"` | match | match | -#[derive(Clone, Debug, Eq, PartialEq)] +/// +/// [RFC 7232 §2.3](https://datatracker.ietf.org/doc/html/rfc7232#section-2.3) +#[derive(Debug, Clone, PartialEq, Eq)] pub struct EntityTag { /// Weakness indicator for the tag pub weak: bool, + /// The opaque string in between the DQUOTEs tag: String, } impl EntityTag { - /// Constructs a new EntityTag. + /// Constructs a new `EntityTag`. + /// /// # Panics /// If the tag contains invalid characters. pub fn new(weak: bool, tag: String) -> EntityTag { @@ -64,51 +69,64 @@ impl EntityTag { } /// Constructs a new weak EntityTag. + /// /// # Panics /// If the tag contains invalid characters. - pub fn weak(tag: String) -> EntityTag { + pub fn new_weak(tag: String) -> EntityTag { EntityTag::new(true, tag) } + #[deprecated(since = "3.0.0", note = "Renamed to `new_weak`.")] + pub fn weak(tag: String) -> EntityTag { + Self::new_weak(tag) + } + /// Constructs a new strong EntityTag. + /// /// # Panics /// If the tag contains invalid characters. - pub fn strong(tag: String) -> EntityTag { + pub fn new_strong(tag: String) -> EntityTag { EntityTag::new(false, tag) } - /// Get the tag. + #[deprecated(since = "3.0.0", note = "Renamed to `new_strong`.")] + pub fn strong(tag: String) -> EntityTag { + Self::new_strong(tag) + } + + /// Returns tag. pub fn tag(&self) -> &str { self.tag.as_ref() } - /// Set the tag. + /// Sets tag. + /// /// # Panics /// If the tag contains invalid characters. - pub fn set_tag(&mut self, tag: String) { + pub fn set_tag(&mut self, tag: impl Into) { + let tag = tag.into(); assert!(check_slice_validity(&tag), "Invalid tag: {:?}", tag); self.tag = tag } - /// For strong comparison two entity-tags are equivalent if both are not - /// weak and their opaque-tags match character-by-character. + /// For strong comparison two entity-tags are equivalent if both are not weak and their + /// opaque-tags match character-by-character. pub fn strong_eq(&self, other: &EntityTag) -> bool { !self.weak && !other.weak && self.tag == other.tag } - /// For weak comparison two entity-tags are equivalent if their - /// opaque-tags match character-by-character, regardless of either or - /// both being tagged as "weak". + /// For weak comparison two entity-tags are equivalent if their opaque-tags match + /// character-by-character, regardless of either or both being tagged as "weak". pub fn weak_eq(&self, other: &EntityTag) -> bool { self.tag == other.tag } - /// The inverse of `EntityTag.strong_eq()`. + /// Returns the inverse of `strong_eq()`. pub fn strong_ne(&self, other: &EntityTag) -> bool { !self.strong_eq(other) } - /// The inverse of `EntityTag.weak_eq()`. + /// Returns inverse of `weak_eq()`. pub fn weak_ne(&self, other: &EntityTag) -> bool { !self.weak_eq(other) } @@ -127,9 +145,8 @@ impl Display for EntityTag { impl FromStr for EntityTag { type Err = crate::error::ParseError; - fn from_str(s: &str) -> Result { - let length: usize = s.len(); - let slice = &s[..]; + fn from_str(slice: &str) -> Result { + let length = slice.len(); // Early exits if it doesn't terminate in a DQUOTE. if !slice.ends_with('"') || slice.len() < 2 { return Err(crate::error::ParseError::Header); @@ -158,7 +175,7 @@ impl FromStr for EntityTag { } } -impl IntoHeaderValue for EntityTag { +impl TryIntoHeaderValue for EntityTag { type Error = InvalidHeaderValue; fn try_into_value(self) -> Result { @@ -177,23 +194,23 @@ mod tests { // Expected success assert_eq!( "\"foobar\"".parse::().unwrap(), - EntityTag::strong("foobar".to_owned()) + EntityTag::new_strong("foobar".to_owned()) ); assert_eq!( "\"\"".parse::().unwrap(), - EntityTag::strong("".to_owned()) + EntityTag::new_strong("".to_owned()) ); assert_eq!( "W/\"weaktag\"".parse::().unwrap(), - EntityTag::weak("weaktag".to_owned()) + EntityTag::new_weak("weaktag".to_owned()) ); assert_eq!( "W/\"\x65\x62\"".parse::().unwrap(), - EntityTag::weak("\x65\x62".to_owned()) + EntityTag::new_weak("\x65\x62".to_owned()) ); assert_eq!( "W/\"\"".parse::().unwrap(), - EntityTag::weak("".to_owned()) + EntityTag::new_weak("".to_owned()) ); } @@ -213,19 +230,19 @@ mod tests { #[test] fn test_etag_fmt() { assert_eq!( - format!("{}", EntityTag::strong("foobar".to_owned())), + format!("{}", EntityTag::new_strong("foobar".to_owned())), "\"foobar\"" ); - assert_eq!(format!("{}", EntityTag::strong("".to_owned())), "\"\""); + assert_eq!(format!("{}", EntityTag::new_strong("".to_owned())), "\"\""); assert_eq!( - format!("{}", EntityTag::weak("weak-etag".to_owned())), + format!("{}", EntityTag::new_weak("weak-etag".to_owned())), "W/\"weak-etag\"" ); assert_eq!( - format!("{}", EntityTag::weak("\u{0065}".to_owned())), + format!("{}", EntityTag::new_weak("\u{0065}".to_owned())), "W/\"\x65\"" ); - assert_eq!(format!("{}", EntityTag::weak("".to_owned())), "W/\"\""); + assert_eq!(format!("{}", EntityTag::new_weak("".to_owned())), "W/\"\""); } #[test] @@ -236,29 +253,29 @@ mod tests { // | `W/"1"` | `W/"2"` | no match | no match | // | `W/"1"` | `"1"` | no match | match | // | `"1"` | `"1"` | match | match | - let mut etag1 = EntityTag::weak("1".to_owned()); - let mut etag2 = EntityTag::weak("1".to_owned()); + let mut etag1 = EntityTag::new_weak("1".to_owned()); + let mut etag2 = EntityTag::new_weak("1".to_owned()); assert!(!etag1.strong_eq(&etag2)); assert!(etag1.weak_eq(&etag2)); assert!(etag1.strong_ne(&etag2)); assert!(!etag1.weak_ne(&etag2)); - etag1 = EntityTag::weak("1".to_owned()); - etag2 = EntityTag::weak("2".to_owned()); + etag1 = EntityTag::new_weak("1".to_owned()); + etag2 = EntityTag::new_weak("2".to_owned()); assert!(!etag1.strong_eq(&etag2)); assert!(!etag1.weak_eq(&etag2)); assert!(etag1.strong_ne(&etag2)); assert!(etag1.weak_ne(&etag2)); - etag1 = EntityTag::weak("1".to_owned()); - etag2 = EntityTag::strong("1".to_owned()); + etag1 = EntityTag::new_weak("1".to_owned()); + etag2 = EntityTag::new_strong("1".to_owned()); assert!(!etag1.strong_eq(&etag2)); assert!(etag1.weak_eq(&etag2)); assert!(etag1.strong_ne(&etag2)); assert!(!etag1.weak_ne(&etag2)); - etag1 = EntityTag::strong("1".to_owned()); - etag2 = EntityTag::strong("1".to_owned()); + etag1 = EntityTag::new_strong("1".to_owned()); + etag2 = EntityTag::new_strong("1".to_owned()); assert!(etag1.strong_eq(&etag2)); assert!(etag1.weak_eq(&etag2)); assert!(!etag1.strong_ne(&etag2)); diff --git a/src/http/header/etag.rs b/src/http/header/etag.rs new file mode 100644 index 000000000..78f5447b3 --- /dev/null +++ b/src/http/header/etag.rs @@ -0,0 +1,98 @@ +use super::{EntityTag, ETAG}; + +crate::http::header::common_header! { + /// `ETag` header, defined in + /// [RFC 7232 §2.3](https://datatracker.ietf.org/doc/html/rfc7232#section-2.3) + /// + /// The `ETag` header field in a response provides the current entity-tag + /// for the selected representation, as determined at the conclusion of + /// handling the request. An entity-tag is an opaque validator for + /// differentiating between multiple representations of the same + /// resource, regardless of whether those multiple representations are + /// due to resource state changes over time, content negotiation + /// resulting in multiple representations being valid at the same time, + /// or both. An entity-tag consists of an opaque quoted string, possibly + /// prefixed by a weakness indicator. + /// + /// # ABNF + /// ```plain + /// ETag = entity-tag + /// ``` + /// + /// # Example Values + /// * `"xyzzy"` + /// * `W/"xyzzy"` + /// * `""` + /// + /// # Examples + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{ETag, EntityTag}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// ETag(EntityTag::new_strong("xyzzy".to_owned())) + /// ); + /// ``` + /// + /// ``` + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{ETag, EntityTag}; + /// + /// let mut builder = HttpResponse::Ok(); + /// builder.insert_header( + /// ETag(EntityTag::new_weak("xyzzy".to_owned())) + /// ); + /// ``` + (ETag, ETAG) => [EntityTag] + + test_parse_and_format { + // From the RFC + crate::http::header::common_header_test!(test1, + vec![b"\"xyzzy\""], + Some(ETag(EntityTag::new_strong("xyzzy".to_owned())))); + crate::http::header::common_header_test!(test2, + vec![b"W/\"xyzzy\""], + Some(ETag(EntityTag::new_weak("xyzzy".to_owned())))); + crate::http::header::common_header_test!(test3, + vec![b"\"\""], + Some(ETag(EntityTag::new_strong("".to_owned())))); + // Own tests + crate::http::header::common_header_test!(test4, + vec![b"\"foobar\""], + Some(ETag(EntityTag::new_strong("foobar".to_owned())))); + crate::http::header::common_header_test!(test5, + vec![b"\"\""], + Some(ETag(EntityTag::new_strong("".to_owned())))); + crate::http::header::common_header_test!(test6, + vec![b"W/\"weak-etag\""], + Some(ETag(EntityTag::new_weak("weak-etag".to_owned())))); + crate::http::header::common_header_test!(test7, + vec![b"W/\"\x65\x62\""], + Some(ETag(EntityTag::new_weak("\u{0065}\u{0062}".to_owned())))); + crate::http::header::common_header_test!(test8, + vec![b"W/\"\""], + Some(ETag(EntityTag::new_weak("".to_owned())))); + crate::http::header::common_header_test!(test9, + vec![b"no-dquotes"], + None::); + crate::http::header::common_header_test!(test10, + vec![b"w/\"the-first-w-is-case-sensitive\""], + None::); + crate::http::header::common_header_test!(test11, + vec![b""], + None::); + crate::http::header::common_header_test!(test12, + vec![b"\"unmatched-dquotes1"], + None::); + crate::http::header::common_header_test!(test13, + vec![b"unmatched-dquotes2\""], + None::); + crate::http::header::common_header_test!(test14, + vec![b"matched-\"dquotes\""], + None::); + crate::http::header::common_header_test!(test15, + vec![b"\""], + None::); + } +} diff --git a/actix-http/src/header/common/expires.rs b/src/http/header/expires.rs similarity index 60% rename from actix-http/src/header/common/expires.rs rename to src/http/header/expires.rs index 79563955d..5b6c65c53 100644 --- a/actix-http/src/header/common/expires.rs +++ b/src/http/header/expires.rs @@ -1,7 +1,8 @@ -use crate::header::{HttpDate, EXPIRES}; +use super::{HttpDate, EXPIRES}; -header! { - /// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3) +crate::http::header::common_header! { + /// `Expires` header, defined + /// in [RFC 7234 §5.3](https://datatracker.ietf.org/doc/html/rfc7234#section-5.3) /// /// The `Expires` header field gives the date/time after which the /// response is considered stale. @@ -11,22 +12,21 @@ header! { /// time. /// /// # ABNF - /// - /// ```text + /// ```plain /// Expires = HTTP-date /// ``` /// - /// # Example values + /// # Example Values /// * `Thu, 01 Dec 1994 16:00:00 GMT` /// /// # Example /// /// ``` /// use std::time::{SystemTime, Duration}; - /// use actix_http::Response; - /// use actix_http::http::header::Expires; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::Expires; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// let expiration = SystemTime::now() + Duration::from_secs(60 * 60 * 24); /// builder.insert_header( /// Expires(expiration.into()) @@ -34,8 +34,8 @@ header! { /// ``` (Expires, EXPIRES) => [HttpDate] - test_expires { + test_parse_and_format { // Test case from RFC - test_header!(test1, vec![b"Thu, 01 Dec 1994 16:00:00 GMT"]); + crate::http::header::common_header_test!(test1, vec![b"Thu, 01 Dec 1994 16:00:00 GMT"]); } } diff --git a/actix-http/src/header/common/if_match.rs b/src/http/header/if_match.rs similarity index 61% rename from actix-http/src/header/common/if_match.rs rename to src/http/header/if_match.rs index db255e91a..e299d30fe 100644 --- a/actix-http/src/header/common/if_match.rs +++ b/src/http/header/if_match.rs @@ -1,8 +1,8 @@ -use crate::header::{EntityTag, IF_MATCH}; +use super::{common_header, EntityTag, IF_MATCH}; -header! { - /// `If-Match` header, defined in - /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.1) +common_header! { + /// `If-Match` header, defined + /// in [RFC 7232 §3.1](https://datatracker.ietf.org/doc/html/rfc7232#section-3.1) /// /// The `If-Match` header field makes the request method conditional on /// the recipient origin server either having at least one current @@ -17,31 +17,28 @@ header! { /// there have been any changes to the representation data. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Match = "*" / 1#entity-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `"xyzzy"` /// * "xyzzy", "r2d2xxxx", "c3piozzzz" /// /// # Examples - /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::IfMatch; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::IfMatch; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header(IfMatch::Any); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{IfMatch, EntityTag}; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{IfMatch, EntityTag}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// IfMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), @@ -52,19 +49,20 @@ header! { /// ``` (IfMatch, IF_MATCH) => {Any / (EntityTag)+} - test_if_match { - test_header!( + test_parse_and_format { + crate::http::header::common_header_test!( test1, vec![b"\"xyzzy\""], Some(HeaderField::Items( - vec![EntityTag::new(false, "xyzzy".to_owned())]))); - test_header!( + vec![EntityTag::new_strong("xyzzy".to_owned())]))); + + crate::http::header::common_header_test!( test2, vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""], Some(HeaderField::Items( - vec![EntityTag::new(false, "xyzzy".to_owned()), - EntityTag::new(false, "r2d2xxxx".to_owned()), - EntityTag::new(false, "c3piozzzz".to_owned())]))); - test_header!(test3, vec![b"*"], Some(IfMatch::Any)); + vec![EntityTag::new_strong("xyzzy".to_owned()), + EntityTag::new_strong("r2d2xxxx".to_owned()), + EntityTag::new_strong("c3piozzzz".to_owned())]))); + crate::http::header::common_header_test!(test3, vec![b"*"], Some(IfMatch::Any)); } } diff --git a/actix-http/src/header/common/if_modified_since.rs b/src/http/header/if_modified_since.rs similarity index 61% rename from actix-http/src/header/common/if_modified_since.rs rename to src/http/header/if_modified_since.rs index 99c7e441d..14d6c3553 100644 --- a/actix-http/src/header/common/if_modified_since.rs +++ b/src/http/header/if_modified_since.rs @@ -1,8 +1,8 @@ -use crate::header::{HttpDate, IF_MODIFIED_SINCE}; +use super::{HttpDate, IF_MODIFIED_SINCE}; -header! { - /// `If-Modified-Since` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.3) +crate::http::header::common_header! { + /// `If-Modified-Since` header, defined + /// in [RFC 7232 §3.3](https://datatracker.ietf.org/doc/html/rfc7232#section-3.3) /// /// The `If-Modified-Since` header field makes a GET or HEAD request /// method conditional on the selected representation's modification date @@ -11,22 +11,21 @@ header! { /// data has not changed. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Unmodified-Since = HTTP-date /// ``` /// - /// # Example values + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example /// /// ``` /// use std::time::{SystemTime, Duration}; - /// use actix_http::Response; - /// use actix_http::http::header::IfModifiedSince; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::IfModifiedSince; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.insert_header( /// IfModifiedSince(modified.into()) @@ -34,8 +33,8 @@ header! { /// ``` (IfModifiedSince, IF_MODIFIED_SINCE) => [HttpDate] - test_if_modified_since { + test_parse_and_format { // Test case from RFC - test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); } } diff --git a/actix-http/src/header/common/if_none_match.rs b/src/http/header/if_none_match.rs similarity index 63% rename from actix-http/src/header/common/if_none_match.rs rename to src/http/header/if_none_match.rs index 464caf1ae..863be70cf 100644 --- a/actix-http/src/header/common/if_none_match.rs +++ b/src/http/header/if_none_match.rs @@ -1,8 +1,8 @@ -use crate::header::{EntityTag, IF_NONE_MATCH}; +use super::{EntityTag, IF_NONE_MATCH}; -header! { - /// `If-None-Match` header, defined in - /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.2) +crate::http::header::common_header! { + /// `If-None-Match` header, defined + /// in [RFC 7232 §3.2](https://datatracker.ietf.org/doc/html/rfc7232#section-3.2) /// /// The `If-None-Match` header field makes the request method conditional /// on a recipient cache or origin server either not having any current @@ -16,13 +16,11 @@ header! { /// the representation data. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-None-Match = "*" / 1#entity-tag /// ``` /// - /// # Example values - /// + /// # Example Values /// * `"xyzzy"` /// * `W/"xyzzy"` /// * `"xyzzy", "r2d2xxxx", "c3piozzzz"` @@ -30,20 +28,19 @@ header! { /// * `*` /// /// # Examples - /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::IfNoneMatch; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::IfNoneMatch; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header(IfNoneMatch::Any); /// ``` /// /// ``` - /// use actix_http::Response; - /// use actix_http::http::header::{IfNoneMatch, EntityTag}; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::{IfNoneMatch, EntityTag}; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// IfNoneMatch::Items(vec![ /// EntityTag::new(false, "xyzzy".to_owned()), @@ -54,20 +51,20 @@ header! { /// ``` (IfNoneMatch, IF_NONE_MATCH) => {Any / (EntityTag)+} - test_if_none_match { - test_header!(test1, vec![b"\"xyzzy\""]); - test_header!(test2, vec![b"W/\"xyzzy\""]); - test_header!(test3, vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""]); - test_header!(test4, vec![b"W/\"xyzzy\", W/\"r2d2xxxx\", W/\"c3piozzzz\""]); - test_header!(test5, vec![b"*"]); + test_parse_and_format { + crate::http::header::common_header_test!(test1, vec![b"\"xyzzy\""]); + crate::http::header::common_header_test!(test2, vec![b"W/\"xyzzy\""]); + crate::http::header::common_header_test!(test3, vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""]); + crate::http::header::common_header_test!(test4, vec![b"W/\"xyzzy\", W/\"r2d2xxxx\", W/\"c3piozzzz\""]); + crate::http::header::common_header_test!(test5, vec![b"*"]); } } #[cfg(test)] mod tests { use super::IfNoneMatch; - use crate::header::{EntityTag, Header, IF_NONE_MATCH}; - use crate::test::TestRequest; + use crate::http::header::{EntityTag, Header, IF_NONE_MATCH}; + use actix_http::test::TestRequest; #[test] fn test_if_none_match() { @@ -85,8 +82,8 @@ mod tests { if_none_match = Header::parse(&req); let mut entities: Vec = Vec::new(); - let foobar_etag = EntityTag::new(false, "foobar".to_owned()); - let weak_etag = EntityTag::new(true, "weak-etag".to_owned()); + let foobar_etag = EntityTag::new_strong("foobar".to_owned()); + let weak_etag = EntityTag::new_weak("weak-etag".to_owned()); entities.push(foobar_etag); entities.push(weak_etag); assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Items(entities))); diff --git a/actix-http/src/header/common/if_range.rs b/src/http/header/if_range.rs similarity index 69% rename from actix-http/src/header/common/if_range.rs rename to src/http/header/if_range.rs index 0a5749505..b845fb3bf 100644 --- a/actix-http/src/header/common/if_range.rs +++ b/src/http/header/if_range.rs @@ -1,13 +1,15 @@ use std::fmt::{self, Display, Write}; -use crate::error::ParseError; -use crate::header::{ - self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, - IntoHeaderValue, InvalidHeaderValue, Writer, +use super::{ + from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, InvalidHeaderValue, + TryIntoHeaderValue, Writer, }; +use crate::error::ParseError; +use crate::http::header; use crate::HttpMessage; -/// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) +/// `If-Range` header, defined +/// in [RFC 7233 §3.2](https://datatracker.ietf.org/doc/html/rfc7233#section-3.2) /// /// If a client has a partial copy of a representation and wishes to have /// an up-to-date copy of the entire representation, it could use the @@ -23,23 +25,21 @@ use crate::HttpMessage; /// in Range; otherwise, send me the entire representation. /// /// # ABNF -/// -/// ```text +/// ```plain /// If-Range = entity-tag / HTTP-date /// ``` /// -/// # Example values +/// # Example Values /// /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// * `\"xyzzy\"` /// /// # Examples -/// /// ``` -/// use actix_http::Response; -/// use actix_http::http::header::{EntityTag, IfRange}; +/// use actix_web::HttpResponse; +/// use actix_web::http::header::{EntityTag, IfRange}; /// -/// let mut builder = Response::Ok(); +/// let mut builder = HttpResponse::Ok(); /// builder.insert_header( /// IfRange::EntityTag( /// EntityTag::new(false, "abc".to_owned()) @@ -49,9 +49,9 @@ use crate::HttpMessage; /// /// ``` /// use std::time::{Duration, SystemTime}; -/// use actix_http::{http::header::IfRange, Response}; +/// use actix_web::{http::header::IfRange, HttpResponse}; /// -/// let mut builder = Response::Ok(); +/// let mut builder = HttpResponse::Ok(); /// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.insert_header( /// IfRange::Date(fetched.into()) @@ -75,13 +75,11 @@ impl Header for IfRange { where T: HttpMessage, { - let etag: Result = - from_one_raw_str(msg.headers().get(&header::IF_RANGE)); + let etag: Result = from_one_raw_str(msg.headers().get(&header::IF_RANGE)); if let Ok(etag) = etag { return Ok(IfRange::EntityTag(etag)); } - let date: Result = - from_one_raw_str(msg.headers().get(&header::IF_RANGE)); + let date: Result = from_one_raw_str(msg.headers().get(&header::IF_RANGE)); if let Ok(date) = date { return Ok(IfRange::Date(date)); } @@ -98,7 +96,7 @@ impl Display for IfRange { } } -impl IntoHeaderValue for IfRange { +impl TryIntoHeaderValue for IfRange { type Error = InvalidHeaderValue; fn try_into_value(self) -> Result { @@ -109,12 +107,13 @@ impl IntoHeaderValue for IfRange { } #[cfg(test)] -mod test_if_range { - use super::IfRange as HeaderField; - use crate::header::*; +mod test_parse_and_format { use std::str; - test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); - test_header!(test2, vec![b"\"abc\""]); - test_header!(test3, vec![b"this-is-invalid"], None::); + use super::IfRange as HeaderField; + use crate::http::header::*; + + crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + crate::http::header::common_header_test!(test2, vec![b"\"abc\""]); + crate::http::header::common_header_test!(test3, vec![b"this-is-invalid"], None::); } diff --git a/actix-http/src/header/common/if_unmodified_since.rs b/src/http/header/if_unmodified_since.rs similarity index 63% rename from actix-http/src/header/common/if_unmodified_since.rs rename to src/http/header/if_unmodified_since.rs index 1c2b4af78..0df6d7ba0 100644 --- a/actix-http/src/header/common/if_unmodified_since.rs +++ b/src/http/header/if_unmodified_since.rs @@ -1,8 +1,8 @@ -use crate::header::{HttpDate, IF_UNMODIFIED_SINCE}; +use super::{HttpDate, IF_UNMODIFIED_SINCE}; -header! { - /// `If-Unmodified-Since` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.4) +crate::http::header::common_header! { + /// `If-Unmodified-Since` header, defined + /// in [RFC 7232 §3.4](https://datatracker.ietf.org/doc/html/rfc7232#section-3.4) /// /// The `If-Unmodified-Since` header field makes the request method /// conditional on the selected representation's last modification date @@ -11,23 +11,21 @@ header! { /// the user agent does not have an entity-tag for the representation. /// /// # ABNF - /// - /// ```text + /// ```plain /// If-Unmodified-Since = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example /// /// ``` /// use std::time::{SystemTime, Duration}; - /// use actix_http::Response; - /// use actix_http::http::header::IfUnmodifiedSince; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::IfUnmodifiedSince; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.insert_header( /// IfUnmodifiedSince(modified.into()) @@ -35,8 +33,8 @@ header! { /// ``` (IfUnmodifiedSince, IF_UNMODIFIED_SINCE) => [HttpDate] - test_if_unmodified_since { + test_parse_and_format { // Test case from RFC - test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); } } diff --git a/actix-http/src/header/common/last_modified.rs b/src/http/header/last_modified.rs similarity index 57% rename from actix-http/src/header/common/last_modified.rs rename to src/http/header/last_modified.rs index 65608d846..e15443ed1 100644 --- a/actix-http/src/header/common/last_modified.rs +++ b/src/http/header/last_modified.rs @@ -1,8 +1,8 @@ -use crate::header::{HttpDate, LAST_MODIFIED}; +use super::{HttpDate, LAST_MODIFIED}; -header! { - /// `Last-Modified` header, defined in - /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.2) +crate::http::header::common_header! { + /// `Last-Modified` header, defined + /// in [RFC 7232 §2.2](https://datatracker.ietf.org/doc/html/rfc7232#section-2.2) /// /// The `Last-Modified` header field in a response provides a timestamp /// indicating the date and time at which the origin server believes the @@ -10,23 +10,21 @@ header! { /// conclusion of handling the request. /// /// # ABNF - /// - /// ```text + /// ```plain /// Expires = HTTP-date /// ``` /// - /// # Example values - /// + /// # Example Values /// * `Sat, 29 Oct 1994 19:43:31 GMT` /// /// # Example /// /// ``` /// use std::time::{SystemTime, Duration}; - /// use actix_http::Response; - /// use actix_http::http::header::LastModified; + /// use actix_web::HttpResponse; + /// use actix_web::http::header::LastModified; /// - /// let mut builder = Response::Ok(); + /// let mut builder = HttpResponse::Ok(); /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); /// builder.insert_header( /// LastModified(modified.into()) @@ -34,7 +32,8 @@ header! { /// ``` (LastModified, LAST_MODIFIED) => [HttpDate] - test_last_modified { - // Test case from RFC - test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]);} + test_parse_and_format { + // Test case from RFC + crate::http::header::common_header_test!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + } } diff --git a/src/http/header/macros.rs b/src/http/header/macros.rs new file mode 100644 index 000000000..25f40a52b --- /dev/null +++ b/src/http/header/macros.rs @@ -0,0 +1,319 @@ +macro_rules! common_header_test_module { + ($id:ident, $tm:ident{$($tf:item)*}) => { + #[cfg(test)] + mod $tm { + #![allow(unused_imports)] + + use ::core::str; + + use ::actix_http::{Method, test}; + use ::mime::*; + + use $crate::http::header::{self, *}; + use super::{$id as HeaderField, *}; + + $($tf)* + } + } +} + +#[cfg(test)] +macro_rules! common_header_test { + ($id:ident, $raw:expr) => { + #[test] + fn $id() { + use ::actix_http::test; + + let raw = $raw; + let headers = raw.iter().map(|x| x.to_vec()).collect::>(); + + let mut req = test::TestRequest::default(); + + for item in headers { + req = req.append_header((HeaderField::name(), item)).take(); + } + + let req = req.finish(); + let value = HeaderField::parse(&req); + + let result = format!("{}", value.unwrap()); + let expected = ::std::string::String::from_utf8(raw[0].to_vec()).unwrap(); + + let result_cmp: Vec = result + .to_ascii_lowercase() + .split(' ') + .map(|x| x.to_owned()) + .collect(); + let expected_cmp: Vec = expected + .to_ascii_lowercase() + .split(' ') + .map(|x| x.to_owned()) + .collect(); + + assert_eq!(result_cmp.concat(), expected_cmp.concat()); + } + }; + + ($id:ident, $raw:expr, $exp:expr) => { + #[test] + fn $id() { + use actix_http::test; + + let headers = $raw.iter().map(|x| x.to_vec()).collect::>(); + let mut req = test::TestRequest::default(); + + for item in headers { + req.append_header((HeaderField::name(), item)); + } + + let req = req.finish(); + let val = HeaderField::parse(&req); + + let exp: ::core::option::Option = $exp; + + // test parsing + assert_eq!(val.ok(), exp); + + // test formatting + if let Some(exp) = exp { + let raw = &($raw)[..]; + let mut iter = raw.iter().map(|b| str::from_utf8(&b[..]).unwrap()); + let mut joined = String::new(); + if let Some(s) = iter.next() { + joined.push_str(s); + for s in iter { + joined.push_str(", "); + joined.push_str(s); + } + } + assert_eq!(format!("{}", exp), joined); + } + } + }; +} + +macro_rules! common_header { + // TODO: these docs are wrong, there's no $n or $nn + // $attrs:meta: Attributes associated with the header item (usually docs) + // $id:ident: Identifier of the header + // $n:expr: Lowercase name of the header + // $nn:expr: Nice name of the header + + // List header, zero or more items + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)*) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] + pub struct $id(pub Vec<$item>); + + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + + #[inline] + fn parse(msg: &M) -> Result { + let headers = msg.headers().get_all(Self::name()); + $crate::http::header::from_comma_delimited(headers).map($id) + } + } + + impl ::core::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) + } + } + + impl $crate::http::header::TryIntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use ::core::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) + } + } + }; + + // List header, one or more items + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)+) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] + pub struct $id(pub Vec<$item>); + + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + + #[inline] + fn parse(msg: &M) -> Result{ + let headers = msg.headers().get_all(Self::name()); + + $crate::http::header::from_comma_delimited(headers) + .and_then(|items| { + if items.is_empty() { + Err($crate::error::ParseError::Header) + } else { + Ok($id(items)) + } + }) + } + } + + impl ::core::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) + } + } + + impl $crate::http::header::TryIntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use ::core::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) + } + } + }; + + // Single value header + ($(#[$attrs:meta])*($id:ident, $name:expr) => [$value:ty]) => { + $(#[$attrs])* + #[derive(Debug, Clone, PartialEq, Eq, ::derive_more::Deref, ::derive_more::DerefMut)] + pub struct $id(pub $value); + + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + + #[inline] + fn parse(msg: &M) -> Result { + let header = msg.headers().get(Self::name()); + $crate::http::header::from_one_raw_str(header).map($id) + } + } + + impl ::core::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + ::core::fmt::Display::fmt(&self.0, f) + } + } + + impl $crate::http::header::TryIntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + self.0.try_into_value() + } + } + }; + + // List header, one or more items with "*" option + ($(#[$attrs:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+}) => { + $(#[$attrs])* + #[derive(Clone, Debug, PartialEq)] + pub enum $id { + /// Any value is a match + Any, + /// Only the listed items are a match + Items(Vec<$item>), + } + + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + + #[inline] + fn parse(msg: &M) -> Result { + let is_any = msg + .headers() + .get(Self::name()) + .and_then(|hdr| hdr.to_str().ok()) + .map(|hdr| hdr.trim() == "*"); + + if let Some(true) = is_any { + Ok($id::Any) + } else { + let headers = msg.headers().get_all(Self::name()); + Ok($id::Items($crate::http::header::from_comma_delimited(headers)?)) + } + } + } + + impl ::core::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + match *self { + $id::Any => f.write_str("*"), + $id::Items(ref fields) => + $crate::http::header::fmt_comma_delimited(f, &fields[..]) + } + } + } + + impl $crate::http::header::TryIntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValue; + + #[inline] + fn try_into_value(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use ::core::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_maybe_shared(writer.take()) + } + } + }; + + // optional test module + ($(#[$attrs:meta])*($id:ident, $name:expr) => ($item:ty)* $tm:ident{$($tf:item)*}) => { + crate::http::header::common_header! { + $(#[$attrs])* + ($id, $name) => ($item)* + } + + crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} + }; + ($(#[$attrs:meta])*($id:ident, $n:expr) => ($item:ty)+ $tm:ident{$($tf:item)*}) => { + crate::http::header::common_header! { + $(#[$attrs])* + ($id, $n) => ($item)+ + } + + crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} + }; + ($(#[$attrs:meta])*($id:ident, $name:expr) => [$item:ty] $tm:ident{$($tf:item)*}) => { + crate::http::header::common_header! { + $(#[$attrs])* ($id, $name) => [$item] + } + + crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} + }; + ($(#[$attrs:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+} $tm:ident{$($tf:item)*}) => { + crate::http::header::common_header! { + $(#[$attrs])* + ($id, $name) => {Any / ($item)+} + } + + crate::http::header::common_header_test_module! { $id, $tm { $($tf)* }} + }; +} + +pub(crate) use {common_header, common_header_test_module}; + +#[cfg(test)] +pub(crate) use common_header_test; diff --git a/src/http/header/mod.rs b/src/http/header/mod.rs new file mode 100644 index 000000000..9807d5f5e --- /dev/null +++ b/src/http/header/mod.rs @@ -0,0 +1,102 @@ +//! A Collection of Header implementations for common HTTP Headers. +//! +//! ## Mime Types +//! Several header fields use MIME values for their contents. Keeping with the strongly-typed theme, +//! the [mime] crate is used in such headers as [`ContentType`] and [`Accept`]. + +use std::fmt; + +use bytes::{Bytes, BytesMut}; + +// re-export from actix-http +// - header name / value types +// - relevant traits for converting to header name / value +// - all const header names +// - header map +// - the few typed headers from actix-http +// - header parsing utils +pub use actix_http::header::*; + +mod accept; +mod accept_charset; +mod accept_encoding; +mod accept_language; +mod allow; +mod cache_control; +mod content_disposition; +mod content_language; +mod content_range; +mod content_type; +mod date; +mod encoding; +mod entity; +mod etag; +mod expires; +mod if_match; +mod if_modified_since; +mod if_none_match; +mod if_range; +mod if_unmodified_since; +mod last_modified; +mod macros; +mod preference; +mod range; + +#[cfg(test)] +pub(crate) use macros::common_header_test; +pub(crate) use macros::{common_header, common_header_test_module}; + +pub use self::accept::Accept; +pub use self::accept_charset::AcceptCharset; +pub use self::accept_encoding::AcceptEncoding; +pub use self::accept_language::AcceptLanguage; +pub use self::allow::Allow; +pub use self::cache_control::{CacheControl, CacheDirective}; +pub use self::content_disposition::{ContentDisposition, DispositionParam, DispositionType}; +pub use self::content_language::ContentLanguage; +pub use self::content_range::{ContentRange, ContentRangeSpec}; +pub use self::content_type::ContentType; +pub use self::date::Date; +pub use self::encoding::Encoding; +pub use self::entity::EntityTag; +pub use self::etag::ETag; +pub use self::expires::Expires; +pub use self::if_match::IfMatch; +pub use self::if_modified_since::IfModifiedSince; +pub use self::if_none_match::IfNoneMatch; +pub use self::if_range::IfRange; +pub use self::if_unmodified_since::IfUnmodifiedSince; +pub use self::last_modified::LastModified; +pub use self::preference::Preference; +pub use self::range::{ByteRangeSpec, Range}; + +/// Format writer ([`fmt::Write`]) for a [`BytesMut`]. +#[derive(Debug, Default)] +struct Writer { + buf: BytesMut, +} + +impl Writer { + /// Constructs new bytes writer. + pub fn new() -> Writer { + Writer::default() + } + + /// Splits bytes out of writer, leaving writer buffer empty. + pub fn take(&mut self) -> Bytes { + self.buf.split().freeze() + } +} + +impl fmt::Write for Writer { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + self.buf.extend_from_slice(s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + fmt::write(self, args) + } +} diff --git a/src/http/header/preference.rs b/src/http/header/preference.rs new file mode 100644 index 000000000..979fc7720 --- /dev/null +++ b/src/http/header/preference.rs @@ -0,0 +1,70 @@ +use std::{ + fmt::{self, Write as _}, + str, +}; + +/// A wrapper for types used in header values where wildcard (`*`) items are allowed but the +/// underlying type does not support them. +/// +/// For example, we use the `language-tags` crate for the [`AcceptLanguage`](super::AcceptLanguage) +/// typed header but it does not parse `*` successfully. On the other hand, the `mime` crate, used +/// for [`Accept`](super::Accept), has first-party support for wildcard items so this wrapper is not +/// used in those header types. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash)] +pub enum Preference { + /// A wildcard value. + Any, + + /// A valid `T`. + Specific(T), +} + +impl Preference { + /// Returns true if preference is the any/wildcard (`*`) value. + pub fn is_any(&self) -> bool { + matches!(self, Self::Any) + } + + /// Returns true if preference is the specific item (`T`) variant. + pub fn is_specific(&self) -> bool { + matches!(self, Self::Specific(_)) + } + + /// Returns reference to value in `Specific` variant, if it is set. + pub fn item(&self) -> Option<&T> { + match self { + Preference::Specific(ref item) => Some(item), + Preference::Any => None, + } + } + + /// Consumes the container, returning the value in the `Specific` variant, if it is set. + pub fn into_item(self) -> Option { + match self { + Preference::Specific(item) => Some(item), + Preference::Any => None, + } + } +} + +impl fmt::Display for Preference { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Preference::Any => f.write_char('*'), + Preference::Specific(item) => fmt::Display::fmt(item, f), + } + } +} + +impl str::FromStr for Preference { + type Err = T::Err; + + #[inline] + fn from_str(s: &str) -> Result { + match s.trim() { + "*" => Ok(Self::Any), + other => other.parse().map(Preference::Specific), + } + } +} diff --git a/src/http/header/range.rs b/src/http/header/range.rs new file mode 100644 index 000000000..68028f53a --- /dev/null +++ b/src/http/header/range.rs @@ -0,0 +1,427 @@ +use std::{ + cmp, + fmt::{self, Display, Write}, + str::FromStr, +}; + +use actix_http::{error::ParseError, header, HttpMessage}; + +use super::{Header, HeaderName, HeaderValue, InvalidHeaderValue, TryIntoHeaderValue, Writer}; + +/// `Range` header, defined +/// in [RFC 7233 §3.1](https://datatracker.ietf.org/doc/html/rfc7233#section-3.1) +/// +/// The "Range" header field on a GET request modifies the method semantics to request transfer of +/// only one or more sub-ranges of the selected representation data, rather than the entire selected +/// representation data. +/// +/// # ABNF +/// ```plain +/// Range = byte-ranges-specifier / other-ranges-specifier +/// other-ranges-specifier = other-range-unit "=" other-range-set +/// other-range-set = 1*VCHAR +/// +/// bytes-unit = "bytes" +/// +/// byte-ranges-specifier = bytes-unit "=" byte-range-set +/// byte-range-set = 1#(byte-range-spec / suffix-byte-range-spec) +/// byte-range-spec = first-byte-pos "-" [last-byte-pos] +/// suffix-byte-range-spec = "-" suffix-length +/// suffix-length = 1*DIGIT +/// first-byte-pos = 1*DIGIT +/// last-byte-pos = 1*DIGIT +/// ``` +/// +/// # Example Values +/// * `bytes=1000-` +/// * `bytes=-50` +/// * `bytes=0-1,30-40` +/// * `bytes=0-10,20-90,-100` +/// * `custom_unit=0-123` +/// * `custom_unit=xxx-yyy` +/// +/// # Examples +/// ``` +/// use actix_web::http::header::{Range, ByteRangeSpec}; +/// use actix_web::HttpResponse; +/// +/// let mut builder = HttpResponse::Ok(); +/// builder.insert_header(Range::Bytes( +/// vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::From(200)] +/// )); +/// builder.insert_header(Range::Unregistered("letters".to_owned(), "a-f".to_owned())); +/// builder.insert_header(Range::bytes(1, 100)); +/// builder.insert_header(Range::bytes_multi(vec![(1, 100), (200, 300)])); +/// ``` +#[derive(PartialEq, Clone, Debug)] +pub enum Range { + /// Byte range. + Bytes(Vec), + + /// Custom range, with unit not registered at IANA. + /// + /// (`other-range-unit`: String , `other-range-set`: String) + Unregistered(String, String), +} + +/// A range of bytes to fetch. +/// +/// Each [`Range::Bytes`] header can contain one or more `ByteRangeSpec`s. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ByteRangeSpec { + /// All bytes from `x` to `y`, inclusive. + /// + /// Serialized as `x-y`. + /// + /// Example: `bytes=500-999` would represent the second 500 bytes. + FromTo(u64, u64), + + /// All bytes starting from `x`, inclusive. + /// + /// Serialized as `x-`. + /// + /// Example: For a file of 1000 bytes, `bytes=950-` would represent bytes 950-999, inclusive. + From(u64), + + /// The last `y` bytes, inclusive. + /// + /// Using the spec terminology, this is `suffix-byte-range-spec`. Serialized as `-y`. + /// + /// Example: For a file of 1000 bytes, `bytes=-50` is equivalent to `bytes=950-`. + Last(u64), +} + +impl ByteRangeSpec { + /// Given the full length of the entity, attempt to normalize the byte range into an satisfiable + /// end-inclusive `(from, to)` range. + /// + /// The resulting range is guaranteed to be a satisfiable range within the bounds + /// of `0 <= from <= to < full_length`. + /// + /// If the byte range is deemed unsatisfiable, `None` is returned. An unsatisfiable range is + /// generally cause for a server to either reject the client request with a + /// `416 Range Not Satisfiable` status code, or to simply ignore the range header and serve the + /// full entity using a `200 OK` status code. + /// + /// This function closely follows [RFC 7233 §2.1]. As such, it considers ranges to be + /// satisfiable if they meet the following conditions: + /// + /// > If a valid byte-range-set includes at least one byte-range-spec with a first-byte-pos that + /// is less than the current length of the representation, or at least one + /// suffix-byte-range-spec with a non-zero suffix-length, then the byte-range-set + /// is satisfiable. Otherwise, the byte-range-set is unsatisfiable. + /// + /// The function also computes remainder ranges based on the RFC: + /// + /// > If the last-byte-pos value is absent, or if the value is greater than or equal to the + /// current length of the representation data, the byte range is interpreted as the remainder + /// of the representation (i.e., the server replaces the value of last-byte-pos with a value + /// that is one less than the current length of the selected representation). + /// + /// [RFC 7233 §2.1]: https://datatracker.ietf.org/doc/html/rfc7233 + pub fn to_satisfiable_range(&self, full_length: u64) -> Option<(u64, u64)> { + // If the full length is zero, there is no satisfiable end-inclusive range. + if full_length == 0 { + return None; + } + + match *self { + ByteRangeSpec::FromTo(from, to) => { + if from < full_length && from <= to { + Some((from, cmp::min(to, full_length - 1))) + } else { + None + } + } + + ByteRangeSpec::From(from) => { + if from < full_length { + Some((from, full_length - 1)) + } else { + None + } + } + + ByteRangeSpec::Last(last) => { + if last > 0 { + // From the RFC: If the selected representation is shorter than the specified + // suffix-length, the entire representation is used. + if last > full_length { + Some((0, full_length - 1)) + } else { + Some((full_length - last, full_length - 1)) + } + } else { + None + } + } + } + } +} + +impl Range { + /// Constructs a common byte range header. + /// + /// Eg: `bytes=from-to` + pub fn bytes(from: u64, to: u64) -> Range { + Range::Bytes(vec![ByteRangeSpec::FromTo(from, to)]) + } + + /// Constructs a byte range header with multiple subranges. + /// + /// Eg: `bytes=from1-to1,from2-to2,fromX-toX` + pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { + Range::Bytes( + ranges + .into_iter() + .map(|(from, to)| ByteRangeSpec::FromTo(from, to)) + .collect(), + ) + } +} + +impl fmt::Display for ByteRangeSpec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ByteRangeSpec::FromTo(from, to) => write!(f, "{}-{}", from, to), + ByteRangeSpec::Last(pos) => write!(f, "-{}", pos), + ByteRangeSpec::From(pos) => write!(f, "{}-", pos), + } + } +} + +impl fmt::Display for Range { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Range::Bytes(ranges) => { + write!(f, "bytes=")?; + + for (i, range) in ranges.iter().enumerate() { + if i != 0 { + f.write_str(",")?; + } + + Display::fmt(range, f)?; + } + Ok(()) + } + + Range::Unregistered(unit, range_str) => { + write!(f, "{}={}", unit, range_str) + } + } + } +} + +impl FromStr for Range { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + let (unit, val) = s.split_once('=').ok_or(ParseError::Header)?; + + match (unit, val) { + ("bytes", ranges) => { + let ranges = from_comma_delimited(ranges); + + if ranges.is_empty() { + return Err(ParseError::Header); + } + + Ok(Range::Bytes(ranges)) + } + + (_, "") => Err(ParseError::Header), + ("", _) => Err(ParseError::Header), + + (unit, range_str) => Ok(Range::Unregistered(unit.to_owned(), range_str.to_owned())), + } + } +} + +impl FromStr for ByteRangeSpec { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + let (start, end) = s.split_once('-').ok_or(ParseError::Header)?; + + match (start, end) { + ("", end) => end + .parse() + .or(Err(ParseError::Header)) + .map(ByteRangeSpec::Last), + + (start, "") => start + .parse() + .or(Err(ParseError::Header)) + .map(ByteRangeSpec::From), + + (start, end) => match (start.parse(), end.parse()) { + (Ok(start), Ok(end)) if start <= end => Ok(ByteRangeSpec::FromTo(start, end)), + _ => Err(ParseError::Header), + }, + } + } +} + +impl Header for Range { + fn name() -> HeaderName { + header::RANGE + } + + #[inline] + fn parse(msg: &T) -> Result { + header::from_one_raw_str(msg.headers().get(&Self::name())) + } +} + +impl TryIntoHeaderValue for Range { + type Error = InvalidHeaderValue; + + fn try_into_value(self) -> Result { + let mut wrt = Writer::new(); + let _ = write!(wrt, "{}", self); + HeaderValue::from_maybe_shared(wrt.take()) + } +} + +/// Parses 0 or more items out of a comma delimited string, ignoring invalid items. +fn from_comma_delimited(s: &str) -> Vec { + s.split(',') + .filter_map(|x| match x.trim() { + "" => None, + y => Some(y), + }) + .filter_map(|x| x.parse().ok()) + .collect() +} + +#[cfg(test)] +mod tests { + use actix_http::{test::TestRequest, Request}; + + use super::*; + + fn req(s: &str) -> Request { + TestRequest::default() + .insert_header((header::RANGE, s)) + .finish() + } + + #[test] + fn test_parse_bytes_range_valid() { + let r: Range = Header::parse(&req("bytes=1-100")).unwrap(); + let r2: Range = Header::parse(&req("bytes=1-100,-")).unwrap(); + let r3 = Range::bytes(1, 100); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse(&req("bytes=1-100,200-")).unwrap(); + let r2: Range = Header::parse(&req("bytes= 1-100 , 101-xxx, 200- ")).unwrap(); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::From(200), + ]); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse(&req("bytes=1-100,-100")).unwrap(); + let r2: Range = Header::parse(&req("bytes=1-100, ,,-100")).unwrap(); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::Last(100), + ]); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse(&req("custom=1-100,-100")).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + assert_eq!(r, r2); + } + + #[test] + fn test_parse_unregistered_range_valid() { + let r: Range = Header::parse(&req("custom=1-100,-100")).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + assert_eq!(r, r2); + + let r: Range = Header::parse(&req("custom=abcd")).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); + assert_eq!(r, r2); + + let r: Range = Header::parse(&req("custom=xxx-yyy")).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); + assert_eq!(r, r2); + } + + #[test] + fn test_parse_invalid() { + let r: Result = Header::parse(&req("bytes=1-a,-")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("bytes=1-2-3")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("abc")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("bytes=1-100=")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("bytes=")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("custom=")); + assert_eq!(r.ok(), None); + + let r: Result = Header::parse(&req("=1-100")); + assert_eq!(r.ok(), None); + } + + #[test] + fn test_fmt() { + let range = Range::Bytes(vec![ + ByteRangeSpec::FromTo(0, 1000), + ByteRangeSpec::From(2000), + ]); + assert_eq!(&range.to_string(), "bytes=0-1000,2000-"); + + let range = Range::Bytes(vec![]); + + assert_eq!(&range.to_string(), "bytes="); + + let range = Range::Unregistered("custom".to_owned(), "1-xxx".to_owned()); + + assert_eq!(&range.to_string(), "custom=1-xxx"); + } + + #[test] + fn test_byte_range_spec_to_satisfiable_range() { + assert_eq!( + Some((0, 0)), + ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3) + ); + assert_eq!(None, ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0)); + + assert_eq!(Some((0, 2)), ByteRangeSpec::From(0).to_satisfiable_range(3)); + assert_eq!(Some((2, 2)), ByteRangeSpec::From(2).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(3).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(5).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::From(0).to_satisfiable_range(0)); + + assert_eq!(Some((1, 2)), ByteRangeSpec::Last(2).to_satisfiable_range(3)); + assert_eq!(Some((2, 2)), ByteRangeSpec::Last(1).to_satisfiable_range(3)); + assert_eq!(Some((0, 2)), ByteRangeSpec::Last(5).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0)); + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs new file mode 100644 index 000000000..2581532cd --- /dev/null +++ b/src/http/mod.rs @@ -0,0 +1,6 @@ +//! Various HTTP related types. + +pub mod header; + +// TODO: figure out how best to expose http::Error vs actix_http::Error +pub use actix_http::{uri, ConnectionType, Error, Method, StatusCode, Uri, Version}; diff --git a/src/info.rs b/src/info.rs index c9ddf6ec4..ce1ef97c6 100644 --- a/src/info.rs +++ b/src/info.rs @@ -1,188 +1,259 @@ -use std::cell::Ref; +use std::{convert::Infallible, net::SocketAddr}; -use crate::dev::{AppConfig, RequestHead}; -use crate::http::header::{self, HeaderName}; +use actix_utils::future::{err, ok, Ready}; +use derive_more::{Display, Error}; +use once_cell::sync::Lazy; -const X_FORWARDED_FOR: &[u8] = b"x-forwarded-for"; -const X_FORWARDED_HOST: &[u8] = b"x-forwarded-host"; -const X_FORWARDED_PROTO: &[u8] = b"x-forwarded-proto"; +use crate::{ + dev::{AppConfig, Payload, RequestHead}, + http::{ + header::{self, HeaderName}, + uri::{Authority, Scheme}, + }, + FromRequest, HttpRequest, ResponseError, +}; -/// `HttpRequest` connection information +static X_FORWARDED_FOR: Lazy = + Lazy::new(|| HeaderName::from_static("x-forwarded-for")); +static X_FORWARDED_HOST: Lazy = + Lazy::new(|| HeaderName::from_static("x-forwarded-host")); +static X_FORWARDED_PROTO: Lazy = + Lazy::new(|| HeaderName::from_static("x-forwarded-proto")); + +/// Trim whitespace then any quote marks. +fn unquote(val: &str) -> &str { + val.trim().trim_start_matches('"').trim_end_matches('"') +} + +/// Extracts and trims first value for given header name. +fn first_header_value<'a>(req: &'a RequestHead, name: &'_ HeaderName) -> Option<&'a str> { + let hdr = req.headers.get(name)?.to_str().ok()?; + let val = hdr.split(',').next()?.trim(); + Some(val) +} + +/// HTTP connection information. +/// +/// `ConnectionInfo` implements `FromRequest` and can be extracted in handlers. +/// +/// # Examples +/// ``` +/// # use actix_web::{HttpResponse, Responder}; +/// use actix_web::dev::ConnectionInfo; +/// +/// async fn handler(conn: ConnectionInfo) -> impl Responder { +/// match conn.host() { +/// "actix.rs" => HttpResponse::Ok().body("Welcome!"), +/// "admin.actix.rs" => HttpResponse::Ok().body("Admin portal."), +/// _ => HttpResponse::NotFound().finish() +/// } +/// } +/// # let _svc = actix_web::web::to(handler); +/// ``` +/// +/// # Implementation Notes +/// Parses `Forwarded` header information according to [RFC 7239][rfc7239] but does not try to +/// interpret the values for each property. As such, the getter methods on `ConnectionInfo` return +/// strings instead of IP addresses or other types to acknowledge that they may be +/// [obfuscated][rfc7239-63] or [unknown][rfc7239-62]. +/// +/// If the older, related headers are also present (eg. `X-Forwarded-For`), then `Forwarded` +/// is preferred. +/// +/// [rfc7239]: https://datatracker.ietf.org/doc/html/rfc7239 +/// [rfc7239-62]: https://datatracker.ietf.org/doc/html/rfc7239#section-6.2 +/// [rfc7239-63]: https://datatracker.ietf.org/doc/html/rfc7239#section-6.3 #[derive(Debug, Clone, Default)] pub struct ConnectionInfo { - scheme: String, host: String, + scheme: String, + peer_addr: Option, realip_remote_addr: Option, - remote_addr: Option, } impl ConnectionInfo { - /// Create *ConnectionInfo* instance for a request. - pub fn get<'a>(req: &'a RequestHead, cfg: &AppConfig) -> Ref<'a, Self> { - if !req.extensions().contains::() { - req.extensions_mut().insert(ConnectionInfo::new(req, cfg)); - } - Ref::map(req.extensions(), |e| e.get().unwrap()) - } - - #[allow(clippy::cognitive_complexity, clippy::borrow_interior_mutable_const)] - fn new(req: &RequestHead, cfg: &AppConfig) -> ConnectionInfo { + pub(crate) fn new(req: &RequestHead, cfg: &AppConfig) -> ConnectionInfo { let mut host = None; let mut scheme = None; let mut realip_remote_addr = None; - // load forwarded header - for hdr in req.headers.get_all(&header::FORWARDED) { - if let Ok(val) = hdr.to_str() { - for pair in val.split(';') { - for el in pair.split(',') { - let mut items = el.trim().splitn(2, '='); - if let Some(name) = items.next() { - if let Some(val) = items.next() { - match &name.to_lowercase() as &str { - "for" => { - if realip_remote_addr.is_none() { - realip_remote_addr = Some(val.trim()); - } - } - "proto" => { - if scheme.is_none() { - scheme = Some(val.trim()); - } - } - "host" => { - if host.is_none() { - host = Some(val.trim()); - } - } - _ => {} - } - } - } - } + for (name, val) in req + .headers + .get_all(&header::FORWARDED) + .into_iter() + .filter_map(|hdr| hdr.to_str().ok()) + // "for=1.2.3.4, for=5.6.7.8; scheme=https" + .flat_map(|val| val.split(';')) + // ["for=1.2.3.4, for=5.6.7.8", " scheme=https"] + .flat_map(|vals| vals.split(',')) + // ["for=1.2.3.4", " for=5.6.7.8", " scheme=https"] + .flat_map(|pair| { + let mut items = pair.trim().splitn(2, '='); + Some((items.next()?, items.next()?)) + }) + { + // [(name , val ), ... ] + // [("for", "1.2.3.4"), ("for", "5.6.7.8"), ("scheme", "https")] + + // taking the first value for each property is correct because spec states that first + // "for" value is client and rest are proxies; multiple values other properties have + // no defined semantics + // + // > In a chain of proxy servers where this is fully utilized, the first + // > "for" parameter will disclose the client where the request was first + // > made, followed by any subsequent proxy identifiers. + // --- https://datatracker.ietf.org/doc/html/rfc7239#section-5.2 + + match name.trim().to_lowercase().as_str() { + "for" => realip_remote_addr.get_or_insert_with(|| unquote(val)), + "proto" => scheme.get_or_insert_with(|| unquote(val)), + "host" => host.get_or_insert_with(|| unquote(val)), + "by" => { + // TODO: implement https://datatracker.ietf.org/doc/html/rfc7239#section-5.1 + continue; } - } + _ => continue, + }; } - // scheme - if scheme.is_none() { - if let Some(h) = req - .headers - .get(&HeaderName::from_lowercase(X_FORWARDED_PROTO).unwrap()) - { - if let Ok(h) = h.to_str() { - scheme = h.split(',').next().map(|v| v.trim()); - } - } - if scheme.is_none() { - scheme = req.uri.scheme().map(|a| a.as_str()); - if scheme.is_none() && cfg.secure() { - scheme = Some("https") - } - } - } + let scheme = scheme + .or_else(|| first_header_value(req, &*X_FORWARDED_PROTO)) + .or_else(|| req.uri.scheme().map(Scheme::as_str)) + .or_else(|| Some("https").filter(|_| cfg.secure())) + .unwrap_or("http") + .to_owned(); - // host - if host.is_none() { - if let Some(h) = req - .headers - .get(&HeaderName::from_lowercase(X_FORWARDED_HOST).unwrap()) - { - if let Ok(h) = h.to_str() { - host = h.split(',').next().map(|v| v.trim()); - } - } - if host.is_none() { - if let Some(h) = req.headers.get(&header::HOST) { - host = h.to_str().ok(); - } - if host.is_none() { - host = req.uri.authority().map(|a| a.as_str()); - if host.is_none() { - host = Some(cfg.host()); - } - } - } - } + let host = host + .or_else(|| first_header_value(req, &*X_FORWARDED_HOST)) + .or_else(|| req.headers.get(&header::HOST)?.to_str().ok()) + .or_else(|| req.uri.authority().map(Authority::as_str)) + .unwrap_or_else(|| cfg.host()) + .to_owned(); - // get remote_addraddr from socketaddr - let remote_addr = req.peer_addr.map(|addr| format!("{}", addr)); + let realip_remote_addr = realip_remote_addr + .or_else(|| first_header_value(req, &*X_FORWARDED_FOR)) + .map(str::to_owned); - if realip_remote_addr.is_none() { - if let Some(h) = req - .headers - .get(&HeaderName::from_lowercase(X_FORWARDED_FOR).unwrap()) - { - if let Ok(h) = h.to_str() { - realip_remote_addr = h.split(',').next().map(|v| v.trim()); - } - } - } + let peer_addr = req.peer_addr.map(|addr| addr.ip().to_string()); ConnectionInfo { - remote_addr, - scheme: scheme.unwrap_or("http").to_owned(), - host: host.unwrap_or("localhost").to_owned(), - realip_remote_addr: realip_remote_addr.map(|s| s.to_owned()), + host, + scheme, + peer_addr, + realip_remote_addr, } } + /// Real IP (remote address) of client that initiated request. + /// + /// The address is resolved through the following, in order: + /// - `Forwarded` header + /// - `X-Forwarded-For` header + /// - peer address of opened socket (same as [`remote_addr`](Self::remote_addr)) + /// + /// # Security + /// Do not use this function for security purposes unless you can be sure that the `Forwarded` + /// and `X-Forwarded-For` headers cannot be spoofed by the client. If you are running without a + /// proxy then [obtaining the peer address](Self::peer_addr) would be more appropriate. + #[inline] + pub fn realip_remote_addr(&self) -> Option<&str> { + self.realip_remote_addr + .as_deref() + .or_else(|| self.peer_addr.as_deref()) + } + + /// Returns serialized IP address of the peer connection. + /// + /// See [`HttpRequest::peer_addr`] for more details. + #[inline] + pub fn peer_addr(&self) -> Option<&str> { + self.peer_addr.as_deref() + } + + /// Hostname of the request. + /// + /// Hostname is resolved through the following, in order: + /// - `Forwarded` header + /// - `X-Forwarded-Host` header + /// - `Host` header + /// - request target / URI + /// - configured server hostname + #[inline] + pub fn host(&self) -> &str { + &self.host + } + /// Scheme of the request. /// - /// Scheme is resolved through the following headers, in this order: - /// - /// - Forwarded - /// - X-Forwarded-Proto - /// - Uri + /// Scheme is resolved through the following, in order: + /// - `Forwarded` header + /// - `X-Forwarded-Proto` header + /// - request target / URI #[inline] pub fn scheme(&self) -> &str { &self.scheme } - /// Hostname of the request. - /// - /// Hostname is resolved through the following headers, in this order: - /// - /// - Forwarded - /// - X-Forwarded-Host - /// - Host - /// - Uri - /// - Server hostname - pub fn host(&self) -> &str { - &self.host - } - - /// remote_addr address of the request. - /// - /// Get remote_addr address from socket address + #[doc(hidden)] + #[deprecated(since = "4.0.0", note = "Renamed to `peer_addr`.")] pub fn remote_addr(&self) -> Option<&str> { - if let Some(ref remote_addr) = self.remote_addr { - Some(remote_addr) - } else { - None - } + self.peer_addr() } - /// Real ip remote addr of client initiated HTTP request. - /// - /// The addr is resolved through the following headers, in this order: - /// - /// - Forwarded - /// - X-Forwarded-For - /// - remote_addr name of opened socket - /// - /// # Security - /// Do not use this function for security purposes, unless you can ensure the Forwarded and - /// X-Forwarded-For headers cannot be spoofed by the client. If you want the client's socket - /// address explicitly, use - /// [`HttpRequest::peer_addr()`](super::web::HttpRequest::peer_addr()) instead. - #[inline] - pub fn realip_remote_addr(&self) -> Option<&str> { - if let Some(ref r) = self.realip_remote_addr { - Some(r) - } else if let Some(ref remote_addr) = self.remote_addr { - Some(remote_addr) - } else { - None +} + +impl FromRequest for ConnectionInfo { + type Error = Infallible; + type Future = Ready>; + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + ok(req.connection_info().clone()) + } +} + +/// Extractor for peer's socket address. +/// +/// Also see [`HttpRequest::peer_addr`] and [`ConnectionInfo::peer_addr`]. +/// +/// # Examples +/// ``` +/// # use actix_web::Responder; +/// use actix_web::dev::PeerAddr; +/// +/// async fn handler(peer_addr: PeerAddr) -> impl Responder { +/// let socket_addr = peer_addr.0; +/// socket_addr.to_string() +/// } +/// # let _svc = actix_web::web::to(handler); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Display)] +#[display(fmt = "{}", _0)] +pub struct PeerAddr(pub SocketAddr); + +impl PeerAddr { + /// Unwrap into inner `SocketAddr` value. + pub fn into_inner(self) -> SocketAddr { + self.0 + } +} + +#[derive(Debug, Display, Error)] +#[non_exhaustive] +#[display(fmt = "Missing peer address")] +pub struct MissingPeerAddr; + +impl ResponseError for MissingPeerAddr {} + +impl FromRequest for PeerAddr { + type Error = MissingPeerAddr; + type Future = Ready>; + + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + match req.peer_addr() { + Some(addr) => ok(PeerAddr(addr)), + None => { + log::error!("Missing peer address."); + err(MissingPeerAddr) + } } } } @@ -192,13 +263,60 @@ mod tests { use super::*; use crate::test::TestRequest; + const X_FORWARDED_FOR: &str = "x-forwarded-for"; + const X_FORWARDED_HOST: &str = "x-forwarded-host"; + const X_FORWARDED_PROTO: &str = "x-forwarded-proto"; + #[test] - fn test_forwarded() { + fn info_default() { let req = TestRequest::default().to_http_request(); let info = req.connection_info(); assert_eq!(info.scheme(), "http"); assert_eq!(info.host(), "localhost:8080"); + } + #[test] + fn host_header() { + let req = TestRequest::default() + .insert_header((header::HOST, "rust-lang.org")) + .to_http_request(); + + let info = req.connection_info(); + assert_eq!(info.scheme(), "http"); + assert_eq!(info.host(), "rust-lang.org"); + assert_eq!(info.realip_remote_addr(), None); + } + + #[test] + fn x_forwarded_for_header() { + let req = TestRequest::default() + .insert_header((X_FORWARDED_FOR, "192.0.2.60")) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); + } + + #[test] + fn x_forwarded_host_header() { + let req = TestRequest::default() + .insert_header((X_FORWARDED_HOST, "192.0.2.60")) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.host(), "192.0.2.60"); + assert_eq!(info.realip_remote_addr(), None); + } + + #[test] + fn x_forwarded_proto_header() { + let req = TestRequest::default() + .insert_header((X_FORWARDED_PROTO, "https")) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.scheme(), "https"); + } + + #[test] + fn forwarded_header() { let req = TestRequest::default() .insert_header(( header::FORWARDED, @@ -212,31 +330,142 @@ mod tests { assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); let req = TestRequest::default() - .insert_header((header::HOST, "rust-lang.org")) + .insert_header(( + header::FORWARDED, + "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org", + )) .to_http_request(); let info = req.connection_info(); - assert_eq!(info.scheme(), "http"); + assert_eq!(info.scheme(), "https"); assert_eq!(info.host(), "rust-lang.org"); - assert_eq!(info.realip_remote_addr(), None); + assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); + } + #[test] + fn forwarded_case_sensitivity() { let req = TestRequest::default() - .insert_header((X_FORWARDED_FOR, "192.0.2.60")) + .insert_header((header::FORWARDED, "For=192.0.2.60")) .to_http_request(); let info = req.connection_info(); assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); + } + #[test] + fn forwarded_weird_whitespace() { let req = TestRequest::default() - .insert_header((X_FORWARDED_HOST, "192.0.2.60")) + .insert_header((header::FORWARDED, "for= 1.2.3.4; proto= https")) .to_http_request(); let info = req.connection_info(); - assert_eq!(info.host(), "192.0.2.60"); - assert_eq!(info.realip_remote_addr(), None); + assert_eq!(info.realip_remote_addr(), Some("1.2.3.4")); + assert_eq!(info.scheme(), "https"); let req = TestRequest::default() - .insert_header((X_FORWARDED_PROTO, "https")) + .insert_header((header::FORWARDED, " for = 1.2.3.4 ")) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.realip_remote_addr(), Some("1.2.3.4")); + } + + #[test] + fn forwarded_for_quoted() { + let req = TestRequest::default() + .insert_header((header::FORWARDED, r#"for="192.0.2.60:8080""#)) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.realip_remote_addr(), Some("192.0.2.60:8080")); + } + + #[test] + fn forwarded_for_ipv6() { + let req = TestRequest::default() + .insert_header((header::FORWARDED, r#"for="[2001:db8:cafe::17]:4711""#)) + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.realip_remote_addr(), Some("[2001:db8:cafe::17]:4711")); + } + + #[test] + fn forwarded_for_multiple() { + let req = TestRequest::default() + .insert_header((header::FORWARDED, "for=192.0.2.60, for=198.51.100.17")) + .to_http_request(); + let info = req.connection_info(); + // takes the first value + assert_eq!(info.realip_remote_addr(), Some("192.0.2.60")); + } + + #[test] + fn scheme_from_uri() { + let req = TestRequest::get() + .uri("https://actix.rs/test") .to_http_request(); let info = req.connection_info(); assert_eq!(info.scheme(), "https"); } + + #[test] + fn host_from_uri() { + let req = TestRequest::get() + .uri("https://actix.rs/test") + .to_http_request(); + let info = req.connection_info(); + assert_eq!(info.host(), "actix.rs"); + } + + #[test] + fn host_from_server_hostname() { + let mut req = TestRequest::get(); + req.set_server_hostname("actix.rs"); + let req = req.to_http_request(); + + let info = req.connection_info(); + assert_eq!(info.host(), "actix.rs"); + } + + #[actix_rt::test] + async fn conn_info_extract() { + let req = TestRequest::default() + .uri("https://actix.rs/test") + .to_http_request(); + let conn_info = ConnectionInfo::extract(&req).await.unwrap(); + assert_eq!(conn_info.scheme(), "https"); + assert_eq!(conn_info.host(), "actix.rs"); + } + + #[actix_rt::test] + async fn peer_addr_extract() { + let req = TestRequest::default().to_http_request(); + let res = PeerAddr::extract(&req).await; + assert!(res.is_err()); + + let addr = "127.0.0.1:8080".parse().unwrap(); + let req = TestRequest::default().peer_addr(addr).to_http_request(); + let peer_addr = PeerAddr::extract(&req).await.unwrap(); + assert_eq!(peer_addr, PeerAddr(addr)); + } + + #[actix_rt::test] + async fn remote_address() { + let req = TestRequest::default().to_http_request(); + let res = ConnectionInfo::extract(&req).await.unwrap(); + assert!(res.peer_addr().is_none()); + + let addr = "127.0.0.1:8080".parse().unwrap(); + let req = TestRequest::default().peer_addr(addr).to_http_request(); + let conn_info = ConnectionInfo::extract(&req).await.unwrap(); + assert_eq!(conn_info.peer_addr().unwrap(), "127.0.0.1"); + } + + #[actix_rt::test] + async fn real_ip_from_socket_addr() { + let req = TestRequest::default().to_http_request(); + let res = ConnectionInfo::extract(&req).await.unwrap(); + assert!(res.realip_remote_addr().is_none()); + + let addr = "127.0.0.1:8080".parse().unwrap(); + let req = TestRequest::default().peer_addr(addr).to_http_request(); + let conn_info = ConnectionInfo::extract(&req).await.unwrap(); + assert_eq!(conn_info.realip_remote_addr().unwrap(), "127.0.0.1"); + } } diff --git a/src/lib.rs b/src/lib.rs index 16b2ab186..18f0d581d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,7 @@ //! Actix Web is a powerful, pragmatic, and extremely fast web framework for Rust. //! -//! ## Example -//! -//! ```rust,no_run +//! # Examples +//! ```no_run //! use actix_web::{get, web, App, HttpServer, Responder}; //! //! #[get("/{id}/{name}/index.html")] @@ -20,76 +19,73 @@ //! } //! ``` //! -//! ## Documentation & Community Resources -//! +//! # Documentation & Community Resources //! In addition to this API documentation, several other resources are available: //! //! * [Website & User Guide](https://actix.rs/) //! * [Examples Repository](https://github.com/actix/examples) -//! * [Community Chat on Gitter](https://gitter.im/actix/actix-web) +//! * [Community Chat on Discord](https://discord.gg/NWpN5mmg3x) //! //! To get started navigating the API docs, you may consider looking at the following pages first: //! -//! * [App]: This struct represents an Actix Web application and is used to +//! * [`App`]: This struct represents an Actix Web application and is used to //! configure routes and other common application settings. //! -//! * [HttpServer]: This struct represents an HTTP server instance and is +//! * [`HttpServer`]: This struct represents an HTTP server instance and is //! used to instantiate and configure servers. //! -//! * [web]: This module provides essential types for route registration as well as +//! * [`web`]: This module provides essential types for route registration as well as //! common utilities for request handlers. //! -//! * [HttpRequest] and [HttpResponse]: These +//! * [`HttpRequest`] and [`HttpResponse`]: These //! structs represent HTTP requests and responses and expose methods for creating, inspecting, //! and otherwise utilizing them. //! -//! ## Features -//! +//! # Features //! * Supports *HTTP/1.x* and *HTTP/2* //! * Streaming and pipelining //! * Keep-alive and slow requests handling //! * Client/server [WebSockets](https://actix.rs/docs/websockets/) support -//! * Transparent content compression/decompression (br, gzip, deflate) +//! * Transparent content compression/decompression (br, gzip, deflate, zstd) //! * Powerful [request routing](https://actix.rs/docs/url-dispatch/) //! * Multipart streams //! * Static assets //! * SSL support using OpenSSL or Rustls //! * Middlewares ([Logger, Session, CORS, etc](https://actix.rs/docs/middleware/)) -//! * Includes an async [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) -//! * Runs on stable Rust 1.46+ +//! * Includes an async [HTTP client](https://docs.rs/awc/) +//! * Runs on stable Rust 1.54+ //! -//! ## Crate Features -//! -//! * `compress` - content encoding compression support (enabled by default) +//! # Crate Features //! * `cookies` - cookies support (enabled by default) +//! * `compress-brotli` - brotli content encoding compression support (enabled by default) +//! * `compress-gzip` - gzip and deflate content encoding compression support (enabled by default) +//! * `compress-zstd` - zstd content encoding compression support (enabled by default) //! * `openssl` - HTTPS support via `openssl` crate, supports `HTTP/2` //! * `rustls` - HTTPS support via `rustls` crate, supports `HTTP/2` //! * `secure-cookies` - secure cookies support #![deny(rust_2018_idioms, nonstandard_style)] -#![allow(clippy::needless_doctest_main, clippy::type_complexity)] +#![warn(future_incompatible)] #![doc(html_logo_url = "https://actix.rs/img/logo.png")] #![doc(html_favicon_url = "https://actix.rs/favicon.ico")] -#[cfg(feature = "openssl")] -extern crate tls_openssl as openssl; -#[cfg(feature = "rustls")] -extern crate tls_rustls as rustls; - mod app; mod app_service; mod config; mod data; +pub mod dev; pub mod error; mod extract; pub mod guard; mod handler; +mod helpers; +pub mod http; mod info; pub mod middleware; mod request; mod request_data; mod resource; -mod responder; +mod response; mod rmap; mod route; mod scope; @@ -99,105 +95,23 @@ pub mod test; pub(crate) mod types; pub mod web; -#[cfg(feature = "cookies")] -pub use actix_http::cookie; -pub use actix_http::Response as HttpResponse; -pub use actix_http::{body, http, Error, HttpMessage, ResponseError, Result}; +pub use actix_http::{body, HttpMessage}; +#[doc(inline)] pub use actix_rt as rt; pub use actix_web_codegen::*; +#[cfg(feature = "cookies")] +pub use cookie; pub use crate::app::App; +pub use crate::error::{Error, ResponseError, Result}; pub use crate::extract::FromRequest; +pub use crate::handler::Handler; pub use crate::request::HttpRequest; pub use crate::resource::Resource; -pub use crate::responder::Responder; +pub use crate::response::{CustomizeResponder, HttpResponse, HttpResponseBuilder, Responder}; pub use crate::route::Route; pub use crate::scope::Scope; pub use crate::server::HttpServer; -// TODO: is exposing the error directly really needed -pub use crate::types::{Either, EitherExtractError}; +pub use crate::types::Either; -pub mod dev { - //! The `actix-web` prelude for library developers - //! - //! The purpose of this module is to alleviate imports of many common actix - //! traits by adding a glob import to the top of actix heavy modules: - //! - //! ``` - //! # #![allow(unused_imports)] - //! use actix_web::dev::*; - //! ``` - - pub use crate::config::{AppConfig, AppService}; - #[doc(hidden)] - pub use crate::handler::Handler; - pub use crate::info::ConnectionInfo; - pub use crate::rmap::ResourceMap; - pub use crate::service::{HttpServiceFactory, ServiceRequest, ServiceResponse, WebService}; - - pub use crate::types::form::UrlEncoded; - pub use crate::types::json::JsonBody; - pub use crate::types::readlines::Readlines; - - pub use actix_http::body::{Body, BodySize, MessageBody, ResponseBody, SizedStream}; - #[cfg(feature = "compress")] - pub use actix_http::encoding::Decoder as Decompress; - pub use actix_http::ResponseBuilder as HttpResponseBuilder; - pub use actix_http::{Extensions, Payload, PayloadStream, RequestHead, ResponseHead}; - pub use actix_router::{Path, ResourceDef, ResourcePath, Url}; - pub use actix_server::Server; - pub use actix_service::{Service, Transform}; - - pub(crate) fn insert_slash(mut patterns: Vec) -> Vec { - for path in &mut patterns { - if !path.is_empty() && !path.starts_with('/') { - path.insert(0, '/'); - }; - } - patterns - } - - use crate::http::header::ContentEncoding; - use actix_http::{Response, ResponseBuilder}; - - struct Enc(ContentEncoding); - - /// Helper trait that allows to set specific encoding for response. - pub trait BodyEncoding { - /// Get content encoding - fn get_encoding(&self) -> Option; - - /// Set content encoding - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self; - } - - impl BodyEncoding for ResponseBuilder { - fn get_encoding(&self) -> Option { - if let Some(ref enc) = self.extensions().get::() { - Some(enc.0) - } else { - None - } - } - - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { - self.extensions_mut().insert(Enc(encoding)); - self - } - } - - impl BodyEncoding for Response { - fn get_encoding(&self) -> Option { - if let Some(ref enc) = self.extensions().get::() { - Some(enc.0) - } else { - None - } - } - - fn encoding(&mut self, encoding: ContentEncoding) -> &mut Self { - self.extensions_mut().insert(Enc(encoding)); - self - } - } -} +pub(crate) type BoxError = Box; diff --git a/src/middleware/compat.rs b/src/middleware/compat.rs index 6f60264b1..18c9ff6a7 100644 --- a/src/middleware/compat.rs +++ b/src/middleware/compat.rs @@ -6,17 +6,21 @@ use std::{ task::{Context, Poll}, }; -use actix_http::body::{Body, MessageBody, ResponseBody}; -use actix_service::{Service, Transform}; use futures_core::{future::LocalBoxFuture, ready}; +use pin_project_lite::pin_project; -use crate::{error::Error, service::ServiceResponse}; +use crate::{ + body::{BoxBody, MessageBody}, + dev::{Service, Transform}, + error::Error, + service::ServiceResponse, +}; /// Middleware for enabling any middleware to be used in [`Resource::wrap`](crate::Resource::wrap), -/// [`Scope::wrap`](crate::Scope::wrap) and [`Condition`](super::Condition). +/// and [`Condition`](super::Condition). /// /// # Examples -/// ```rust +/// ``` /// use actix_web::middleware::{Logger, Compat}; /// use actix_web::{App, web}; /// @@ -34,6 +38,15 @@ pub struct Compat { transform: T, } +#[cfg(test)] +impl Compat { + pub(crate) fn noop() -> Self { + Self { + transform: super::Noop, + } + } +} + impl Compat { /// Wrap a middleware to give it broader compatibility. pub fn new(middleware: T) -> Self { @@ -49,9 +62,9 @@ where T: Transform, T::Future: 'static, T::Response: MapServiceResponseBody, - Error: From, + T::Error: Into, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Transform = CompatMiddleware; type InitError = T::InitError; @@ -74,15 +87,13 @@ impl Service for CompatMiddleware where S: Service, S::Response: MapServiceResponseBody, - Error: From, + S::Error: Into, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = CompatMiddlewareFuture; - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(From::from) - } + actix_service::forward_ready!(service); fn call(&self, req: Req) -> Self::Future { let fut = self.service.call(req); @@ -90,34 +101,43 @@ where } } -#[pin_project::pin_project] -pub struct CompatMiddlewareFuture { - #[pin] - fut: Fut, +pin_project! { + pub struct CompatMiddlewareFuture { + #[pin] + fut: Fut, + } } impl Future for CompatMiddlewareFuture where Fut: Future>, T: MapServiceResponseBody, - Error: From, + E: Into, { - type Output = Result; + type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = ready!(self.project().fut.poll(cx))?; + let res = match ready!(self.project().fut.poll(cx)) { + Ok(res) => res, + Err(err) => return Poll::Ready(Err(err.into())), + }; + Poll::Ready(Ok(res.map_body())) } } /// Convert `ServiceResponse`'s `ResponseBody` generic type to `ResponseBody`. pub trait MapServiceResponseBody { - fn map_body(self) -> ServiceResponse; + fn map_body(self) -> ServiceResponse; } -impl MapServiceResponseBody for ServiceResponse { - fn map_body(self) -> ServiceResponse { - self.map_body(|_, body| ResponseBody::Other(Body::from_message(body))) +impl MapServiceResponseBody for ServiceResponse +where + B: MessageBody + 'static, +{ + #[inline] + fn map_body(self) -> ServiceResponse { + self.map_into_boxed_body() } } @@ -137,7 +157,7 @@ mod tests { use crate::{web, App, HttpResponse}; #[actix_rt::test] - #[cfg(feature = "cookies")] + #[cfg(all(feature = "cookies", feature = "__compress"))] async fn test_scope_middleware() { use crate::middleware::Compress; @@ -147,7 +167,7 @@ mod tests { let srv = init_service( App::new().service( web::scope("app") - .wrap(Compat::new(logger)) + .wrap(logger) .wrap(Compat::new(compress)) .service(web::resource("/test").route(web::get().to(HttpResponse::Ok))), ), @@ -160,7 +180,7 @@ mod tests { } #[actix_rt::test] - #[cfg(feature = "cookies")] + #[cfg(all(feature = "cookies", feature = "__compress"))] async fn test_resource_scope_middleware() { use crate::middleware::Compress; diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index 698ba768e..16af4c2cd 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -1,125 +1,166 @@ //! For middleware documentation, see [`Compress`]. use std::{ - cmp, future::Future, marker::PhantomData, pin::Pin, - str::FromStr, task::{Context, Poll}, }; -use actix_http::{ - body::MessageBody, - encoding::Encoder, - http::header::{ContentEncoding, ACCEPT_ENCODING}, - Error, -}; +use actix_http::encoding::Encoder; use actix_service::{Service, Transform}; +use actix_utils::future::{ok, Either, Ready}; use futures_core::ready; -use futures_util::future::{ok, Ready}; -use pin_project::pin_project; +use once_cell::sync::Lazy; +use pin_project_lite::pin_project; use crate::{ - dev::BodyEncoding, + body::{EitherBody, MessageBody}, + http::{ + header::{self, AcceptEncoding, Encoding, HeaderValue}, + StatusCode, + }, service::{ServiceRequest, ServiceResponse}, + Error, HttpMessage, HttpResponse, }; /// Middleware for compressing response payloads. /// -/// Use `BodyEncoding` trait for overriding response compression. To disable compression set -/// encoding to `ContentEncoding::Identity`. +/// # Encoding Negotiation +/// `Compress` will read the `Accept-Encoding` header to negotiate which compression codec to use. +/// Payloads are not compressed if the header is not sent. The `compress-*` [feature flags] are also +/// considered in this selection process. +/// +/// # Pre-compressed Payload +/// If you are serving some data is already using a compressed representation (e.g., a gzip +/// compressed HTML file from disk) you can signal this to `Compress` by setting an appropriate +/// `Content-Encoding` header. In addition to preventing double compressing the payload, this header +/// is required by the spec when using compressed representations and will inform the client that +/// the content should be uncompressed. +/// +/// However, it is not advised to unconditionally serve encoded representations of content because +/// the client may not support it. The [`AcceptEncoding`] typed header has some utilities to help +/// perform manual encoding negotiation, if required. When negotiating content encoding, it is also +/// required by the spec to send a `Vary: Accept-Encoding` header. +/// +/// A (naïve) example serving an pre-compressed Gzip file is included below. /// /// # Examples -/// ```rust -/// use actix_web::{web, middleware, App, HttpResponse}; +/// To enable automatic payload compression just include `Compress` as a top-level middleware: +/// ``` +/// use actix_web::{middleware, web, App, HttpResponse}; /// /// let app = App::new() /// .wrap(middleware::Compress::default()) -/// .default_service(web::to(|| HttpResponse::NotFound())); +/// .default_service(web::to(|| HttpResponse::Ok().body("hello world"))); /// ``` -#[derive(Debug, Clone)] -pub struct Compress(ContentEncoding); - -impl Compress { - /// Create new `Compress` middleware with the specified encoding. - pub fn new(encoding: ContentEncoding) -> Self { - Compress(encoding) - } -} - -impl Default for Compress { - fn default() -> Self { - Compress::new(ContentEncoding::Auto) - } -} +/// +/// Pre-compressed Gzip file being served from disk with correct headers added to bypass middleware: +/// ```no_run +/// use actix_web::{middleware, http::header, web, App, HttpResponse, Responder}; +/// +/// async fn index_handler() -> actix_web::Result { +/// Ok(actix_files::NamedFile::open_async("./assets/index.html.gz").await? +/// .customize() +/// .insert_header(header::ContentEncoding::Gzip)) +/// } +/// +/// let app = App::new() +/// .wrap(middleware::Compress::default()) +/// .default_service(web::to(index_handler)); +/// ``` +/// +/// [feature flags]: ../index.html#crate-features +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct Compress; impl Transform for Compress where B: MessageBody, S: Service, Error = Error>, { - type Response = ServiceResponse>; + type Response = ServiceResponse>>; type Error = Error; type Transform = CompressMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(CompressMiddleware { - service, - encoding: self.0, - }) + ok(CompressMiddleware { service }) } } pub struct CompressMiddleware { service: S, - encoding: ContentEncoding, } impl Service for CompressMiddleware where - B: MessageBody, S: Service, Error = Error>, + B: MessageBody, { - type Response = ServiceResponse>; + type Response = ServiceResponse>>; type Error = Error; - type Future = CompressResponse; + #[allow(clippy::type_complexity)] + type Future = Either, Ready>>; actix_service::forward_ready!(service); #[allow(clippy::borrow_interior_mutable_const)] fn call(&self, req: ServiceRequest) -> Self::Future { // negotiate content-encoding - let encoding = if let Some(val) = req.headers().get(&ACCEPT_ENCODING) { - if let Ok(enc) = val.to_str() { - AcceptEncoding::parse(enc, self.encoding) - } else { - ContentEncoding::Identity + let accept_encoding = req.get_header::(); + + let accept_encoding = match accept_encoding { + // missing header; fallback to identity + None => { + return Either::left(CompressResponse { + encoding: Encoding::identity(), + fut: self.service.call(req), + _phantom: PhantomData, + }) } - } else { - ContentEncoding::Identity + + // valid accept-encoding header + Some(accept_encoding) => accept_encoding, }; - CompressResponse { - encoding, - fut: self.service.call(req), - _phantom: PhantomData, + match accept_encoding.negotiate(SUPPORTED_ENCODINGS.iter()) { + None => { + let mut res = HttpResponse::with_body( + StatusCode::NOT_ACCEPTABLE, + SUPPORTED_ENCODINGS_STRING.as_str(), + ); + + res.headers_mut() + .insert(header::VARY, HeaderValue::from_static("Accept-Encoding")); + + Either::right(ok(req + .into_response(res) + .map_into_boxed_body() + .map_into_right_body())) + } + + Some(encoding) => Either::left(CompressResponse { + fut: self.service.call(req), + encoding, + _phantom: PhantomData, + }), } } } -#[pin_project] -pub struct CompressResponse -where - S: Service, - B: MessageBody, -{ - #[pin] - fut: S::Future, - encoding: ContentEncoding, - _phantom: PhantomData, +pin_project! { + pub struct CompressResponse + where + S: Service, + { + #[pin] + fut: S::Future, + encoding: Encoding, + _phantom: PhantomData, + } } impl Future for CompressResponse @@ -127,92 +168,141 @@ where B: MessageBody, S: Service, Error = Error>, { - type Output = Result>, Error>; + type Output = Result>>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); match ready!(this.fut.poll(cx)) { Ok(resp) => { - let enc = if let Some(enc) = resp.response().get_encoding() { - enc - } else { - *this.encoding + let enc = match this.encoding { + Encoding::Known(enc) => *enc, + Encoding::Unknown(enc) => { + unimplemented!("encoding {} should not be here", enc); + } }; - Poll::Ready(Ok( - resp.map_body(move |head, body| Encoder::response(enc, head, body)) - )) + Poll::Ready(Ok(resp.map_body(move |head, body| { + EitherBody::left(Encoder::response(enc, head, body)) + }))) } - Err(e) => Poll::Ready(Err(e)), + + Err(err) => Poll::Ready(Err(err)), } } } -struct AcceptEncoding { - encoding: ContentEncoding, - quality: f64, -} +static SUPPORTED_ENCODINGS_STRING: Lazy = Lazy::new(|| { + #[allow(unused_mut)] // only unused when no compress features enabled + let mut encoding: Vec<&str> = vec![]; -impl Eq for AcceptEncoding {} - -impl Ord for AcceptEncoding { - #[allow(clippy::comparison_chain)] - fn cmp(&self, other: &AcceptEncoding) -> cmp::Ordering { - if self.quality > other.quality { - cmp::Ordering::Less - } else if self.quality < other.quality { - cmp::Ordering::Greater - } else { - cmp::Ordering::Equal - } - } -} - -impl PartialOrd for AcceptEncoding { - fn partial_cmp(&self, other: &AcceptEncoding) -> Option { - Some(self.cmp(other)) - } -} - -impl PartialEq for AcceptEncoding { - fn eq(&self, other: &AcceptEncoding) -> bool { - self.quality == other.quality - } -} - -impl AcceptEncoding { - fn new(tag: &str) -> Option { - let parts: Vec<&str> = tag.split(';').collect(); - let encoding = match parts.len() { - 0 => return None, - _ => ContentEncoding::from(parts[0]), - }; - let quality = match parts.len() { - 1 => encoding.quality(), - _ => f64::from_str(parts[1]).unwrap_or(0.0), - }; - Some(AcceptEncoding { encoding, quality }) + #[cfg(feature = "compress-brotli")] + { + encoding.push("br"); } - /// Parse a raw Accept-Encoding header value into an ordered list. - pub fn parse(raw: &str, encoding: ContentEncoding) -> ContentEncoding { - let mut encodings: Vec<_> = raw - .replace(' ', "") - .split(',') - .map(|l| AcceptEncoding::new(l)) - .collect(); - encodings.sort(); + #[cfg(feature = "compress-gzip")] + { + encoding.push("gzip"); + encoding.push("deflate"); + } - for enc in encodings { - if let Some(enc) = enc { - if encoding == ContentEncoding::Auto { - return enc.encoding; - } else if encoding == enc.encoding { - return encoding; - } - } - } - ContentEncoding::Identity + #[cfg(feature = "compress-zstd")] + { + encoding.push("zstd"); + } + + assert!( + !encoding.is_empty(), + "encoding can not be empty unless __compress feature has been explicitly enabled by itself" + ); + + encoding.join(", ") +}); + +static SUPPORTED_ENCODINGS: Lazy> = Lazy::new(|| { + let mut encodings = vec![Encoding::identity()]; + + #[cfg(feature = "compress-brotli")] + { + encodings.push(Encoding::brotli()); + } + + #[cfg(feature = "compress-gzip")] + { + encodings.push(Encoding::gzip()); + encodings.push(Encoding::deflate()); + } + + #[cfg(feature = "compress-zstd")] + { + encodings.push(Encoding::zstd()); + } + + assert!( + !encodings.is_empty(), + "encodings can not be empty unless __compress feature has been explicitly enabled by itself" + ); + + encodings +}); + +// move cfg(feature) to prevents_double_compressing if more tests are added +#[cfg(feature = "compress-gzip")] +#[cfg(test)] +mod tests { + use super::*; + use crate::{middleware::DefaultHeaders, test, web, App}; + + pub fn gzip_decode(bytes: impl AsRef<[u8]>) -> Vec { + use std::io::Read as _; + let mut decoder = flate2::read::GzDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } + + #[actix_rt::test] + async fn prevents_double_compressing() { + const D: &str = "hello world "; + const DATA: &str = const_str::repeat!(D, 100); + + let app = test::init_service({ + App::new() + .wrap(Compress::default()) + .route( + "/single", + web::get().to(move || HttpResponse::Ok().body(DATA)), + ) + .service( + web::resource("/double") + .wrap(Compress::default()) + .wrap(DefaultHeaders::new().add(("x-double", "true"))) + .route(web::get().to(move || HttpResponse::Ok().body(DATA))), + ) + }) + .await; + + let req = test::TestRequest::default() + .uri("/single") + .insert_header((header::ACCEPT_ENCODING, "gzip")) + .to_request(); + let res = test::call_service(&app, req).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get("x-double"), None); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip"); + let bytes = test::read_body(res).await; + assert_eq!(gzip_decode(bytes), DATA.as_bytes()); + + let req = test::TestRequest::default() + .uri("/double") + .insert_header((header::ACCEPT_ENCODING, "gzip")) + .to_request(); + let res = test::call_service(&app, req).await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get("x-double").unwrap(), "true"); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip"); + let bytes = test::read_body(res).await; + assert_eq!(gzip_decode(bytes), DATA.as_bytes()); } } diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index f0c344062..659f88bc9 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -3,7 +3,9 @@ use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures_util::future::{Either, FutureExt, LocalBoxFuture}; +use actix_utils::future::Either; +use futures_core::future::LocalBoxFuture; +use futures_util::future::FutureExt as _; /// Middleware for conditionally enabling other middleware. /// @@ -12,7 +14,7 @@ use futures_util::future::{Either, FutureExt, LocalBoxFuture}; /// middleware for a workaround. /// /// # Examples -/// ```rust +/// ``` /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// @@ -85,8 +87,8 @@ where fn call(&self, req: Req) -> Self::Future { match self { - ConditionMiddleware::Enable(service) => Either::Left(service.call(req)), - ConditionMiddleware::Disable(service) => Either::Right(service.call(req)), + ConditionMiddleware::Enable(service) => Either::left(service.call(req)), + ConditionMiddleware::Disable(service) => Either::right(service.call(req)), } } } @@ -94,23 +96,28 @@ where #[cfg(test)] mod tests { use actix_service::IntoService; - use futures_util::future::ok; + use actix_utils::future::ok; use super::*; use crate::{ dev::{ServiceRequest, ServiceResponse}, error::Result, - http::{header::CONTENT_TYPE, HeaderValue, StatusCode}, - middleware::err_handlers::*, + http::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + middleware::{err_handlers::*, Compat}, test::{self, TestRequest}, HttpResponse, }; + #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Response(res)) + + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) } #[actix_rt::test] @@ -119,7 +126,9 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = Compat::new( + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) @@ -135,7 +144,9 @@ mod tests { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; - let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); + let mw = Compat::new( + ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), + ); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) diff --git a/src/middleware/default_headers.rs b/src/middleware/default_headers.rs index a36cc2f29..003abd40d 100644 --- a/src/middleware/default_headers.rs +++ b/src/middleware/default_headers.rs @@ -9,17 +9,14 @@ use std::{ task::{Context, Poll}, }; -use futures_util::{ - future::{ready, Ready}, - ready, -}; +use actix_http::error::HttpError; +use actix_utils::future::{ready, Ready}; +use futures_core::ready; +use pin_project_lite::pin_project; use crate::{ dev::{Service, Transform}, - http::{ - header::{HeaderName, HeaderValue, CONTENT_TYPE}, - Error as HttpError, HeaderMap, - }, + http::header::{HeaderMap, HeaderName, HeaderValue, TryIntoHeaderPair, CONTENT_TYPE}, service::{ServiceRequest, ServiceResponse}, Error, }; @@ -29,82 +26,83 @@ use crate::{ /// Headers with the same key that are already set in a response will *not* be overwritten. /// /// # Examples -/// ```rust +/// ``` /// use actix_web::{web, http, middleware, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new() -/// .wrap(middleware::DefaultHeaders::new().header("X-Version", "0.2")) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) -/// ); -/// } +/// let app = App::new() +/// .wrap(middleware::DefaultHeaders::new().add(("X-Version", "0.2"))) +/// .service( +/// web::resource("/test") +/// .route(web::get().to(|| HttpResponse::Ok())) +/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) +/// ); /// ``` -#[derive(Clone)] +#[derive(Debug, Clone, Default)] pub struct DefaultHeaders { inner: Rc, } +#[derive(Debug, Default)] struct Inner { headers: HeaderMap, } -impl Default for DefaultHeaders { - fn default() -> Self { - DefaultHeaders { - inner: Rc::new(Inner { - headers: HeaderMap::new(), - }), - } - } -} - impl DefaultHeaders { /// Constructs an empty `DefaultHeaders` middleware. + #[inline] pub fn new() -> DefaultHeaders { DefaultHeaders::default() } /// Adds a header to the default set. - #[inline] - pub fn header(mut self, key: K, value: V) -> Self + /// + /// # Panics + /// Panics when resolved header name or value is invalid. + #[allow(clippy::should_implement_trait)] + pub fn add(mut self, header: impl TryIntoHeaderPair) -> Self { + // standard header terminology `insert` or `append` for this method would make the behavior + // of this middleware less obvious since it only adds the headers if they are not present + + match header.try_into_pair() { + Ok((key, value)) => Rc::get_mut(&mut self.inner) + .expect("All default headers must be added before cloning.") + .headers + .append(key, value), + Err(err) => panic!("Invalid header: {}", err.into()), + } + + self + } + + #[doc(hidden)] + #[deprecated( + since = "4.0.0", + note = "Prefer `.add((key, value))`. Will be removed in v5." + )] + pub fn header(self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { - #[allow(clippy::match_wild_err_arm)] - match HeaderName::try_from(key) { - Ok(key) => match HeaderValue::try_from(value) { - Ok(value) => { - Rc::get_mut(&mut self.inner) - .expect("Multiple copies exist") - .headers - .append(key, value); - } - Err(_) => panic!("Can not create header value"), - }, - Err(_) => panic!("Can not create header name"), - } - self + self.add(( + HeaderName::try_from(key) + .map_err(Into::into) + .expect("Invalid header name"), + HeaderValue::try_from(value) + .map_err(Into::into) + .expect("Invalid header value"), + )) } /// Adds a default *Content-Type* header if response does not contain one. /// /// Default is `application/octet-stream`. - pub fn add_content_type(mut self) -> Self { - Rc::get_mut(&mut self.inner) - .expect("Multiple `Inner` copies exist.") - .headers - .insert( - CONTENT_TYPE, - HeaderValue::from_static("application/octet-stream"), - ); - - self + pub fn add_content_type(self) -> Self { + #[allow(clippy::declare_interior_mutable_const)] + const HV_MIME: HeaderValue = HeaderValue::from_static("application/octet-stream"); + self.add((CONTENT_TYPE, HV_MIME)) } } @@ -122,7 +120,7 @@ where fn new_transform(&self, service: S) -> Self::Future { ready(Ok(DefaultHeadersMiddleware { service, - inner: self.inner.clone(), + inner: Rc::clone(&self.inner), })) } } @@ -155,12 +153,13 @@ where } } -#[pin_project::pin_project] -pub struct DefaultHeaderFuture, B> { - #[pin] - fut: S::Future, - inner: Rc, - _body: PhantomData, +pin_project! { + pub struct DefaultHeaderFuture, B> { + #[pin] + fut: S::Future, + inner: Rc, + _body: PhantomData, + } } impl Future for DefaultHeaderFuture @@ -188,28 +187,33 @@ where #[cfg(test)] mod tests { use actix_service::IntoService; - use futures_util::future::ok; + use actix_utils::future::ok; use super::*; use crate::{ dev::ServiceRequest, http::header::CONTENT_TYPE, - test::{ok_service, TestRequest}, + test::{self, TestRequest}, HttpResponse, }; #[actix_rt::test] - async fn test_default_headers() { + async fn adding_default_headers() { let mw = DefaultHeaders::new() - .header(CONTENT_TYPE, "0001") - .new_transform(ok_service()) + .add(("X-TEST", "0001")) + .add(("X-TEST-TWO", HeaderValue::from_static("123"))) + .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default().to_srv_request(); - let resp = mw.call(req).await.unwrap(); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + let res = mw.call(req).await.unwrap(); + assert_eq!(res.headers().get("x-test").unwrap(), "0001"); + assert_eq!(res.headers().get("x-test-two").unwrap(), "123"); + } + #[actix_rt::test] + async fn no_override_existing() { let req = TestRequest::default().to_srv_request(); let srv = |req: ServiceRequest| { ok(req.into_response( @@ -219,7 +223,7 @@ mod tests { )) }; let mw = DefaultHeaders::new() - .header(CONTENT_TYPE, "0001") + .add((CONTENT_TYPE, "0001")) .new_transform(srv.into_service()) .await .unwrap(); @@ -228,11 +232,10 @@ mod tests { } #[actix_rt::test] - async fn test_content_type() { - let srv = |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish())); + async fn adding_content_type() { let mw = DefaultHeaders::new() .add_content_type() - .new_transform(srv.into_service()) + .new_transform(test::ok_service()) .await .unwrap(); @@ -243,4 +246,16 @@ mod tests { "application/octet-stream" ); } + + #[test] + #[should_panic] + fn invalid_header_name() { + DefaultHeaders::new().add((":", "hello")); + } + + #[test] + #[should_panic] + fn invalid_header_value() { + DefaultHeaders::new().add(("x-test", "\n")); + } } diff --git a/src/middleware/err_handlers.rs b/src/middleware/err_handlers.rs index 4673ed4ce..bde054330 100644 --- a/src/middleware/err_handlers.rs +++ b/src/middleware/err_handlers.rs @@ -10,20 +10,22 @@ use std::{ use actix_service::{Service, Transform}; use ahash::AHashMap; use futures_core::{future::LocalBoxFuture, ready}; +use pin_project_lite::pin_project; use crate::{ + body::EitherBody, dev::{ServiceRequest, ServiceResponse}, - error::{Error, Result}, http::StatusCode, + Error, Result, }; /// Return type for [`ErrorHandlers`] custom handlers. pub enum ErrorHandlerResponse { /// Immediate HTTP response. - Response(ServiceResponse), + Response(ServiceResponse>), /// A future that resolves to an HTTP response. - Future(LocalBoxFuture<'static, Result, Error>>), + Future(LocalBoxFuture<'static, Result>, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; @@ -34,26 +36,22 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result(mut res: dev::ServiceResponse) -> Result> { -/// res.response_mut() -/// .headers_mut() -/// .insert(http::header::CONTENT_TYPE, http::HeaderValue::from_static("Error")); -/// Ok(ErrorHandlerResponse::Response(res)) +/// fn add_error_header(mut res: dev::ServiceResponse) -> Result> { +/// res.response_mut().headers_mut().insert( +/// header::CONTENT_TYPE, +/// header::HeaderValue::from_static("Error"), +/// ); +/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) /// } /// /// let app = App::new() -/// .wrap( -/// ErrorHandlers::new() -/// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500), -/// ) -/// .service(web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::head().to(|| HttpResponse::MethodNotAllowed()) -/// )); +/// .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header)) +/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError))); /// ``` pub struct ErrorHandlers { handlers: Handlers, @@ -64,7 +62,7 @@ type Handlers = Rc>>>; impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { - handlers: Rc::new(AHashMap::default()), + handlers: Default::default(), } } } @@ -93,7 +91,7 @@ where S::Future: 'static, B: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; type Transform = ErrorHandlersMiddleware; type InitError = (); @@ -117,7 +115,7 @@ where S::Future: 'static, B: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse>; type Error = Error; type Future = ErrorHandlersFuture; @@ -130,44 +128,50 @@ where } } -#[pin_project::pin_project(project = ErrorHandlersProj)] -pub enum ErrorHandlersFuture -where - Fut: Future, -{ - ServiceFuture { - #[pin] - fut: Fut, - handlers: Handlers, - }, - HandlerFuture { - fut: LocalBoxFuture<'static, Fut::Output>, - }, +pin_project! { + #[project = ErrorHandlersProj] + pub enum ErrorHandlersFuture + where + Fut: Future, + { + ServiceFuture { + #[pin] + fut: Fut, + handlers: Handlers, + }, + ErrorHandlerFuture { + fut: LocalBoxFuture<'static, Result>, Error>>, + }, + } } impl Future for ErrorHandlersFuture where Fut: Future, Error>>, { - type Output = Fut::Output; + type Output = Result>, Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { ErrorHandlersProj::ServiceFuture { fut, handlers } => { let res = ready!(fut.poll(cx))?; + match handlers.get(&res.status()) { Some(handler) => match handler(res)? { ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Future(fut) => { self.as_mut() - .set(ErrorHandlersFuture::HandlerFuture { fut }); + .set(ErrorHandlersFuture::ErrorHandlerFuture { fut }); + self.poll(cx) } }, - None => Poll::Ready(Ok(res)), + + None => Poll::Ready(Ok(res.map_into_left_body())), } } - ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx), + + ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx), } } } @@ -175,28 +179,34 @@ where #[cfg(test)] mod tests { use actix_service::IntoService; - use futures_util::future::{ok, FutureExt}; + use actix_utils::future::ok; + use bytes::Bytes; + use futures_util::future::FutureExt as _; use super::*; - use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{self, TestRequest}; - use crate::HttpResponse; - - fn render_500(mut res: ServiceResponse) -> Result> { - res.response_mut() - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Response(res)) - } + use crate::{ + http::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + test::{self, TestRequest}, + }; #[actix_rt::test] - async fn test_handler() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) - }; + async fn add_header_error_handler() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + let srv = test::simple_service(StatusCode::INTERNAL_SERVER_ERROR); let mw = ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) .new_transform(srv.into_service()) .await .unwrap(); @@ -205,23 +215,25 @@ mod tests { assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } - fn render_500_async( - mut res: ServiceResponse, - ) -> Result> { - res.response_mut() - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); - Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) - } - #[actix_rt::test] - async fn test_handler_async() { - let srv = |req: ServiceRequest| { - ok(req.into_response(HttpResponse::InternalServerError().finish())) - }; + async fn add_header_error_handler_async() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + + Ok(ErrorHandlerResponse::Future( + ok(res.map_into_left_body()).boxed_local(), + )) + } + + let srv = test::simple_service(StatusCode::INTERNAL_SERVER_ERROR); let mw = ErrorHandlers::new() - .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) .new_transform(srv.into_service()) .await .unwrap(); @@ -229,4 +241,34 @@ mod tests { let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } + + #[actix_rt::test] + async fn changes_body_type() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler( + res: ServiceResponse, + ) -> Result> { + let (req, res) = res.into_parts(); + let res = res.set_body(Bytes::from("sorry, that's no bueno")); + + let res = ServiceResponse::new(req, res) + .map_into_boxed_body() + .map_into_right_body(); + + Ok(ErrorHandlerResponse::Response(res)) + } + + let srv = test::simple_service(StatusCode::INTERNAL_SERVER_ERROR); + + let mw = ErrorHandlers::new() + .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler) + .new_transform(srv.into_service()) + .await + .unwrap(); + + let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await; + assert_eq!(test::read_body(res).await, "sorry, that's no bueno"); + } + + // TODO: test where error is thrown } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 5b5b5577c..969cb0c10 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -13,18 +13,19 @@ use std::{ }; use actix_service::{Service, Transform}; +use actix_utils::future::{ready, Ready}; use bytes::Bytes; -use futures_util::future::{ok, Ready}; +use futures_core::ready; use log::{debug, warn}; +use pin_project_lite::pin_project; use regex::{Regex, RegexSet}; -use time::OffsetDateTime; +use time::{format_description::well_known::Rfc3339, OffsetDateTime}; use crate::{ - dev::{BodySize, MessageBody, ResponseBody}, - error::{Error, Result}, - http::{HeaderName, StatusCode}, + body::{BodySize, MessageBody}, + http::header::HeaderName, service::{ServiceRequest, ServiceResponse}, - HttpResponse, + Error, HttpResponse, Result, }; /// Middleware for logging request and response summaries to the terminal. @@ -43,7 +44,7 @@ use crate::{ /// ``` /// /// # Examples -/// ```rust +/// ``` /// use actix_web::{middleware::Logger, App}; /// /// // access logs are printed with the INFO level so ensure it is enabled by default @@ -124,8 +125,9 @@ impl Logger { /// It is convention to print "-" to indicate no output instead of an empty string. /// /// # Example - /// ```rust - /// # use actix_web::{http::HeaderValue, middleware::Logger}; + /// ``` + /// # use actix_web::http::{header::HeaderValue}; + /// # use actix_web::middleware::Logger; /// # fn parse_jwt_id (_req: Option<&HeaderValue>) -> String { "jwt_uid".to_owned() } /// Logger::new("example %{JWT_ID}xi") /// .custom_request_replace("JWT_ID", |req| parse_jwt_id(req.headers().get("Authorization"))); @@ -180,8 +182,8 @@ where { type Response = ServiceResponse>; type Error = Error; - type InitError = (); type Transform = LoggerMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { @@ -195,10 +197,10 @@ where } } - ok(LoggerMiddleware { + ready(Ok(LoggerMiddleware { service, inner: self.0.clone(), - }) + })) } } @@ -246,17 +248,18 @@ where } } -#[pin_project::pin_project] -pub struct LoggerResponse -where - B: MessageBody, - S: Service, -{ - #[pin] - fut: S::Future, - time: OffsetDateTime, - format: Option, - _phantom: PhantomData, +pin_project! { + pub struct LoggerResponse + where + B: MessageBody, + S: Service, + { + #[pin] + fut: S::Future, + time: OffsetDateTime, + format: Option, + _phantom: PhantomData, + } } impl Future for LoggerResponse @@ -269,15 +272,13 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let res = match futures_util::ready!(this.fut.poll(cx)) { + let res = match ready!(this.fut.poll(cx)) { Ok(res) => res, Err(e) => return Poll::Ready(Err(e)), }; if let Some(error) = res.response().error() { - if res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR { - debug!("Error in response: {:?}", error); - } + debug!("Error in response: {:?}", error); } if let Some(ref mut format) = this.format { @@ -289,44 +290,42 @@ where let time = *this.time; let format = this.format.take(); - Poll::Ready(Ok(res.map_body(move |_, body| { - ResponseBody::Body(StreamLog { - body, - time, - format, - size: 0, - }) + Poll::Ready(Ok(res.map_body(move |_, body| StreamLog { + body, + time, + format, + size: 0, }))) } } -use pin_project::{pin_project, pinned_drop}; - -#[pin_project(PinnedDrop)] -pub struct StreamLog { - #[pin] - body: ResponseBody, - format: Option, - size: usize, - time: OffsetDateTime, -} - -#[pinned_drop] -impl PinnedDrop for StreamLog { - fn drop(self: Pin<&mut Self>) { - if let Some(ref format) = self.format { - let render = |fmt: &mut fmt::Formatter<'_>| { - for unit in &format.0 { - unit.render(fmt, self.size, self.time)?; - } - Ok(()) - }; - log::info!("{}", FormatDisplay(&render)); +pin_project! { + pub struct StreamLog { + #[pin] + body: B, + format: Option, + size: usize, + time: OffsetDateTime, + } + impl PinnedDrop for StreamLog { + fn drop(this: Pin<&mut Self>) { + if let Some(ref format) = this.format { + let render = |fmt: &mut fmt::Formatter<'_>| { + for unit in &format.0 { + unit.render(fmt, this.size, this.time)?; + } + Ok(()) + }; + log::info!("{}", FormatDisplay(&render)); + } } } } impl MessageBody for StreamLog { + type Error = B::Error; + + #[inline] fn size(&self) -> BodySize { self.body.size() } @@ -334,14 +333,16 @@ impl MessageBody for StreamLog { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let this = self.project(); - match this.body.poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { + + match ready!(this.body.poll_next(cx)) { + Some(Ok(chunk)) => { *this.size += chunk.len(); Poll::Ready(Some(Ok(chunk))) } - val => val, + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), } } } @@ -363,7 +364,7 @@ impl Format { /// Returns `None` if the format string syntax is incorrect. pub fn new(s: &str) -> Format { log::trace!("Access log format: {}", s); - let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi)|[atPrUsbTD]?)").unwrap(); + let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi)|[%atPrUsbTD]?)").unwrap(); let mut idx = 0; let mut results = Vec::new(); @@ -379,7 +380,7 @@ impl Format { results.push(match cap.get(3).unwrap().as_str() { "a" => { if key.as_str() == "r" { - FormatText::RealIPRemoteAddr + FormatText::RealIpRemoteAddr } else { unreachable!() } @@ -433,7 +434,7 @@ enum FormatText { Time, TimeMillis, RemoteAddr, - RealIPRemoteAddr, + RealIpRemoteAddr, UrlPath, RequestHeader(HeaderName), ResponseHeader(HeaderName), @@ -532,7 +533,7 @@ impl FormatText { }; } FormatText::UrlPath => *self = FormatText::Str(req.path().to_string()), - FormatText::RequestTime => *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")), + FormatText::RequestTime => *self = FormatText::Str(now.format(&Rfc3339).unwrap()), FormatText::RequestHeader(ref name) => { let s = if let Some(val) = req.headers().get(name) { if let Ok(s) = val.to_str() { @@ -546,14 +547,14 @@ impl FormatText { *self = FormatText::Str(s.to_string()); } FormatText::RemoteAddr => { - let s = if let Some(ref peer) = req.connection_info().remote_addr() { + let s = if let Some(peer) = req.connection_info().peer_addr() { FormatText::Str((*peer).to_string()) } else { FormatText::Str("-".to_string()) }; *self = s; } - FormatText::RealIPRemoteAddr => { + FormatText::RealIpRemoteAddr => { let s = if let Some(remote) = req.connection_info().realip_remote_addr() { FormatText::Str(remote.to_string()) } else { @@ -588,7 +589,7 @@ impl<'a> fmt::Display for FormatDisplay<'a> { #[cfg(test)] mod tests { use actix_service::{IntoService, Service, Transform}; - use futures_util::future::ok; + use actix_utils::future::ok; use super::*; use crate::http::{header, StatusCode}; @@ -639,6 +640,38 @@ mod tests { let _res = srv.call(req).await.unwrap(); } + #[actix_rt::test] + async fn test_escape_percent() { + let mut format = Format::new("%%{r}a"); + + let req = TestRequest::default() + .insert_header(( + header::FORWARDED, + header::HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43"), + )) + .to_srv_request(); + + let now = OffsetDateTime::now_utc(); + for unit in &mut format.0 { + unit.render_request(now, &req); + } + + let resp = HttpResponse::build(StatusCode::OK).force_close().finish(); + for unit in &mut format.0 { + unit.render_response(&resp); + } + + let entry_time = OffsetDateTime::now_utc(); + let render = |fmt: &mut fmt::Formatter<'_>| { + for unit in &format.0 { + unit.render(fmt, 1024, entry_time)?; + } + Ok(()) + }; + let s = format!("{}", FormatDisplay(&render)); + assert_eq!(s, "%{r}a"); + } + #[actix_rt::test] async fn test_url_path() { let mut format = Format::new("%T %U"); @@ -729,7 +762,7 @@ mod tests { Ok(()) }; let s = format!("{}", FormatDisplay(&render)); - assert!(s.contains(&now.format("%Y-%m-%dT%H:%M:%S"))); + assert!(s.contains(&now.format(&Rfc3339).unwrap())); } #[actix_rt::test] diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index e24782f07..a781052a6 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -5,6 +5,8 @@ mod condition; mod default_headers; mod err_handlers; mod logger; +#[cfg(test)] +mod noop; mod normalize; pub use self::compat::Compat; @@ -12,9 +14,52 @@ pub use self::condition::Condition; pub use self::default_headers::DefaultHeaders; pub use self::err_handlers::{ErrorHandlerResponse, ErrorHandlers}; pub use self::logger::Logger; +#[cfg(test)] +pub(crate) use self::noop::Noop; pub use self::normalize::{NormalizePath, TrailingSlash}; -#[cfg(feature = "compress")] +#[cfg(feature = "__compress")] mod compress; -#[cfg(feature = "compress")] + +#[cfg(feature = "__compress")] pub use self::compress::Compress; + +#[cfg(test)] +mod tests { + use crate::{http::StatusCode, App}; + + use super::*; + + #[test] + fn common_combinations() { + // ensure there's no reason that the built-in middleware cannot compose + + let _ = App::new() + .wrap(Compat::new(Logger::default())) + .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) + .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + })) + .wrap(Logger::default()) + .wrap(NormalizePath::new(TrailingSlash::Trim)); + + let _ = App::new() + .wrap(NormalizePath::new(TrailingSlash::Trim)) + .wrap(Logger::default()) + .wrap(ErrorHandlers::new().handler(StatusCode::FORBIDDEN, |res| { + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + })) + .wrap(DefaultHeaders::new().add(("X-Test2", "X-Value2"))) + .wrap(Condition::new(true, DefaultHeaders::new())) + .wrap(Compat::new(Logger::default())); + + #[cfg(feature = "__compress")] + { + let _ = App::new().wrap(Compress::default()).wrap(Logger::default()); + let _ = App::new().wrap(Logger::default()).wrap(Compress::default()); + let _ = App::new().wrap(Compat::new(Compress::default())); + let _ = App::new().wrap(Condition::new(true, Compat::new(Compress::default()))); + } + } +} diff --git a/src/middleware/noop.rs b/src/middleware/noop.rs new file mode 100644 index 000000000..ae7da1d81 --- /dev/null +++ b/src/middleware/noop.rs @@ -0,0 +1,37 @@ +//! A no-op middleware. See [Noop] for docs. + +use actix_utils::future::{ready, Ready}; + +use crate::dev::{Service, Transform}; + +/// A no-op middleware that passes through request and response untouched. +pub(crate) struct Noop; + +impl, Req> Transform for Noop { + type Response = S::Response; + type Error = S::Error; + type Transform = NoopService; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(NoopService { service })) + } +} + +#[doc(hidden)] +pub(crate) struct NoopService { + service: S, +} + +impl, Req> Service for NoopService { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + crate::dev::forward_ready!(service); + + fn call(&self, req: Req) -> Self::Future { + self.service.call(req) + } +} diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index ea21a7215..3ab908481 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,9 +1,9 @@ //! For middleware documentation, see [`NormalizePath`]. -use actix_http::http::{PathAndQuery, Uri}; +use actix_http::uri::{PathAndQuery, Uri}; use actix_service::{Service, Transform}; +use actix_utils::future::{ready, Ready}; use bytes::Bytes; -use futures_util::future::{ready, Ready}; use regex::Regex; use crate::{ @@ -54,12 +54,12 @@ impl Default for TrailingSlash { /// `TrailingSlash::Always` behavior), as shown in the example tests below. /// /// # Examples -/// ```rust +/// ``` /// use actix_web::{web, middleware, App}; /// /// # actix_web::rt::System::new().block_on(async { /// let app = App::new() -/// .wrap(middleware::NormalizePath::default()) +/// .wrap(middleware::NormalizePath::trim()) /// .route("/test", web::get().to(|| async { "test" })) /// .route("/unmatchable/", web::get().to(|| async { "unmatchable" })); /// @@ -85,13 +85,31 @@ impl Default for TrailingSlash { /// assert_eq!(res.status(), StatusCode::NOT_FOUND); /// # }) /// ``` -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy)] pub struct NormalizePath(TrailingSlash); +impl Default for NormalizePath { + fn default() -> Self { + log::warn!( + "`NormalizePath::default()` is deprecated. The default trailing slash behavior changed \ + in v4 from `Always` to `Trim`. Update your call to `NormalizePath::new(...)`." + ); + + Self(TrailingSlash::Trim) + } +} + impl NormalizePath { /// Create new `NormalizePath` middleware with the specified trailing slash style. pub fn new(trailing_slash_style: TrailingSlash) -> Self { - NormalizePath(trailing_slash_style) + Self(trailing_slash_style) + } + + /// Constructs a new `NormalizePath` middleware with [trim](TrailingSlash::Trim) semantics. + /// + /// Use this instead of `NormalizePath::default()` to avoid deprecation warning. + pub fn trim() -> Self { + Self::new(TrailingSlash::Trim) } } @@ -137,58 +155,63 @@ where let original_path = head.uri.path(); - // Either adds a string to the end (duplicates will be removed anyways) or trims all slashes from the end - let path = match self.trailing_slash_behavior { - TrailingSlash::Always => original_path.to_string() + "/", - TrailingSlash::MergeOnly => original_path.to_string(), - TrailingSlash::Trim => original_path.trim_end_matches('/').to_string(), - }; - - // normalize multiple /'s to one / - let path = self.merge_slash.replace_all(&path, "/"); - - // Ensure root paths are still resolvable. If resulting path is blank after previous step - // it means the path was one or more slashes. Reduce to single slash. - let path = if path.is_empty() { "/" } else { path.as_ref() }; - - // Check whether the path has been changed - // - // This check was previously implemented as string length comparison - // - // That approach fails when a trailing slash is added, - // and a duplicate slash is removed, - // since the length of the strings remains the same - // - // For example, the path "/v1//s" will be normalized to "/v1/s/" - // Both of the paths have the same length, - // so the change can not be deduced from the length comparison - if path != original_path { - let mut parts = head.uri.clone().into_parts(); - let query = parts.path_and_query.as_ref().and_then(|pq| pq.query()); - - let path = if let Some(q) = query { - Bytes::from(format!("{}?{}", path, q)) - } else { - Bytes::copy_from_slice(path.as_bytes()) + // An empty path here means that the URI has no valid path. We skip normalization in this + // case, because adding a path can make the URI invalid + if !original_path.is_empty() { + // Either adds a string to the end (duplicates will be removed anyways) or trims all + // slashes from the end + let path = match self.trailing_slash_behavior { + TrailingSlash::Always => format!("{}/", original_path), + TrailingSlash::MergeOnly => original_path.to_string(), + TrailingSlash::Trim => original_path.trim_end_matches('/').to_string(), }; - parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap()); - let uri = Uri::from_parts(parts).unwrap(); - req.match_info_mut().get_mut().update(&uri); - req.head_mut().uri = uri; + // normalize multiple /'s to one / + let path = self.merge_slash.replace_all(&path, "/"); + + // Ensure root paths are still resolvable. If resulting path is blank after previous + // step it means the path was one or more slashes. Reduce to single slash. + let path = if path.is_empty() { "/" } else { path.as_ref() }; + + // Check whether the path has been changed + // + // This check was previously implemented as string length comparison + // + // That approach fails when a trailing slash is added, + // and a duplicate slash is removed, + // since the length of the strings remains the same + // + // For example, the path "/v1//s" will be normalized to "/v1/s/" + // Both of the paths have the same length, + // so the change can not be deduced from the length comparison + if path != original_path { + let mut parts = head.uri.clone().into_parts(); + let query = parts.path_and_query.as_ref().and_then(|pq| pq.query()); + + let path = match query { + Some(q) => Bytes::from(format!("{}?{}", path, q)), + None => Bytes::copy_from_slice(path.as_bytes()), + }; + parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap()); + + let uri = Uri::from_parts(parts).unwrap(); + req.match_info_mut().get_mut().update(&uri); + req.head_mut().uri = uri; + } } - self.service.call(req) } } #[cfg(test)] mod tests { + use actix_http::StatusCode; use actix_service::IntoService; use super::*; use crate::{ dev::ServiceRequest, + guard::fn_guard, test::{call_service, init_service, TestRequest}, web, App, HttpResponse, }; @@ -199,37 +222,34 @@ mod tests { App::new() .wrap(NormalizePath::default()) .service(web::resource("/").to(HttpResponse::Ok)) - .service(web::resource("/v1/something").to(HttpResponse::Ok)), + .service(web::resource("/v1/something").to(HttpResponse::Ok)) + .service( + web::resource("/v2/something") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), ) .await; - let req = TestRequest::with_uri("/").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); + let test_uris = vec![ + "/", + "/?query=test", + "///", + "/v1//something", + "/v1//something////", + "//v1/something", + "//v1//////something", + "/v2//something?query=test", + "/v2//something////?query=test", + "//v2/something?query=test", + "//v2//////something?query=test", + ]; - let req = TestRequest::with_uri("/?query=test").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); - - let req = TestRequest::with_uri("///").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); - - let req = TestRequest::with_uri("/v1//something////").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); - - let req2 = TestRequest::with_uri("//v1/something").to_request(); - let res2 = call_service(&app, req2).await; - assert!(res2.status().is_success()); - - let req3 = TestRequest::with_uri("//v1//////something").to_request(); - let res3 = call_service(&app, req3).await; - assert!(res3.status().is_success()); - - let req4 = TestRequest::with_uri("/v1//something").to_request(); - let res4 = call_service(&app, req4).await; - assert!(res4.status().is_success()); + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_request(); + let res = call_service(&app, req).await; + assert!(res.status().is_success(), "Failed uri: {}", uri); + } } #[actix_rt::test] @@ -238,38 +258,114 @@ mod tests { App::new() .wrap(NormalizePath(TrailingSlash::Trim)) .service(web::resource("/").to(HttpResponse::Ok)) - .service(web::resource("/v1/something").to(HttpResponse::Ok)), + .service(web::resource("/v1/something").to(HttpResponse::Ok)) + .service( + web::resource("/v2/something") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), ) .await; - // root paths should still work - let req = TestRequest::with_uri("/").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); + let test_uris = vec![ + "/", + "///", + "/v1/something", + "/v1/something/", + "/v1/something////", + "//v1//something", + "//v1//something//", + "/v2/something?query=test", + "/v2/something/?query=test", + "/v2/something////?query=test", + "//v2//something?query=test", + "//v2//something//?query=test", + ]; - let req = TestRequest::with_uri("/?query=test").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_request(); + let res = call_service(&app, req).await; + assert!(res.status().is_success(), "Failed uri: {}", uri); + } + } - let req = TestRequest::with_uri("///").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); + #[actix_rt::test] + async fn trim_root_trailing_slashes_with_query() { + let app = init_service( + App::new().wrap(NormalizePath(TrailingSlash::Trim)).service( + web::resource("/") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), + ) + .await; - let req = TestRequest::with_uri("/v1/something////").to_request(); - let res = call_service(&app, req).await; - assert!(res.status().is_success()); + let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"]; - let req2 = TestRequest::with_uri("/v1/something/").to_request(); - let res2 = call_service(&app, req2).await; - assert!(res2.status().is_success()); + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_request(); + let res = call_service(&app, req).await; + assert!(res.status().is_success(), "Failed uri: {}", uri); + } + } - let req3 = TestRequest::with_uri("//v1//something//").to_request(); - let res3 = call_service(&app, req3).await; - assert!(res3.status().is_success()); + #[actix_rt::test] + async fn ensure_trailing_slash() { + let app = init_service( + App::new() + .wrap(NormalizePath(TrailingSlash::Always)) + .service(web::resource("/").to(HttpResponse::Ok)) + .service(web::resource("/v1/something/").to(HttpResponse::Ok)) + .service( + web::resource("/v2/something/") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), + ) + .await; - let req4 = TestRequest::with_uri("//v1//something").to_request(); - let res4 = call_service(&app, req4).await; - assert!(res4.status().is_success()); + let test_uris = vec![ + "/", + "///", + "/v1/something", + "/v1/something/", + "/v1/something////", + "//v1//something", + "//v1//something//", + "/v2/something?query=test", + "/v2/something/?query=test", + "/v2/something////?query=test", + "//v2//something?query=test", + "//v2//something//?query=test", + ]; + + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_request(); + let res = call_service(&app, req).await; + assert!(res.status().is_success(), "Failed uri: {}", uri); + } + } + + #[actix_rt::test] + async fn ensure_root_trailing_slash_with_query() { + let app = init_service( + App::new() + .wrap(NormalizePath(TrailingSlash::Always)) + .service( + web::resource("/") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), + ) + .await; + + let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"]; + + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_request(); + let res = call_service(&app, req).await; + assert!(res.status().is_success(), "Failed uri: {}", uri); + } } #[actix_rt::test] @@ -279,7 +375,12 @@ mod tests { .wrap(NormalizePath(TrailingSlash::MergeOnly)) .service(web::resource("/").to(HttpResponse::Ok)) .service(web::resource("/v1/something").to(HttpResponse::Ok)) - .service(web::resource("/v1/").to(HttpResponse::Ok)), + .service(web::resource("/v1/").to(HttpResponse::Ok)) + .service( + web::resource("/v2/something") + .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test"))) + .to(HttpResponse::Ok), + ), ) .await; @@ -295,15 +396,35 @@ mod tests { ("/v1////", true), ("//v1//", true), ("///v1", false), + ("/v2/something?query=test", true), + ("/v2/something/?query=test", false), + ("/v2/something//?query=test", false), + ("//v2//something?query=test", true), ]; - for (path, success) in tests { - let req = TestRequest::with_uri(path).to_request(); + for (uri, success) in tests { + let req = TestRequest::with_uri(uri).to_request(); let res = call_service(&app, req).await; - assert_eq!(res.status().is_success(), success); + assert_eq!(res.status().is_success(), success, "Failed uri: {}", uri); } } + #[actix_rt::test] + async fn no_path() { + let app = init_service( + App::new() + .wrap(NormalizePath::default()) + .service(web::resource("/").to(HttpResponse::Ok)), + ) + .await; + + // This URI will be interpreted as an authority form, i.e. there is no path nor scheme + // (https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3) + let req = TestRequest::with_uri("eh").to_request(); + let res = call_service(&app, req).await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + #[actix_rt::test] async fn test_in_place_normalization() { let srv = |req: ServiceRequest| { @@ -316,21 +437,18 @@ mod tests { .await .unwrap(); - let req = TestRequest::with_uri("/v1//something////").to_srv_request(); - let res = normalize.call(req).await.unwrap(); - assert!(res.status().is_success()); + let test_uris = vec![ + "/v1//something////", + "///v1/something", + "//v1///something", + "/v1//something", + ]; - let req2 = TestRequest::with_uri("///v1/something").to_srv_request(); - let res2 = normalize.call(req2).await.unwrap(); - assert!(res2.status().is_success()); - - let req3 = TestRequest::with_uri("//v1///something").to_srv_request(); - let res3 = normalize.call(req3).await.unwrap(); - assert!(res3.status().is_success()); - - let req4 = TestRequest::with_uri("/v1//something").to_srv_request(); - let res4 = normalize.call(req4).await.unwrap(); - assert!(res4.status().is_success()); + for uri in test_uris { + let req = TestRequest::with_uri(uri).to_srv_request(); + let res = normalize.call(req).await.unwrap(); + assert!(res.status().is_success(), "Failed uri: {}", uri); + } } #[actix_rt::test] diff --git a/src/request.rs b/src/request.rs index 514b7466e..e876c3b4d 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,24 +1,32 @@ -use std::cell::{Ref, RefCell, RefMut}; -use std::rc::Rc; -use std::{fmt, net}; +use std::{ + cell::{Ref, RefCell, RefMut}, + fmt, net, + rc::Rc, + str, +}; -use actix_http::http::{HeaderMap, Method, Uri, Version}; -use actix_http::{Error, Extensions, HttpMessage, Message, Payload, RequestHead}; +use actix_http::{ + header::HeaderMap, Extensions, HttpMessage, Message, Method, Payload, RequestHead, Uri, + Version, +}; use actix_router::{Path, Url}; -use futures_util::future::{ok, Ready}; +use actix_utils::future::{ok, Ready}; +#[cfg(feature = "cookies")] +use cookie::{Cookie, ParseError as CookieParseError}; use smallvec::SmallVec; -use crate::app_service::AppInitServiceState; -use crate::config::AppConfig; -use crate::error::UrlGenerationError; -use crate::extract::FromRequest; -use crate::info::ConnectionInfo; -use crate::rmap::ResourceMap; +use crate::{ + app_service::AppInitServiceState, config::AppConfig, error::UrlGenerationError, + info::ConnectionInfo, rmap::ResourceMap, Error, FromRequest, +}; +#[cfg(feature = "cookies")] +struct Cookies(Vec>); + +/// An incoming request. #[derive(Clone)] -/// An HTTP Request pub struct HttpRequest { - /// # Panics + /// # Invariant /// `Rc` is used exclusively and NO `Weak` /// is allowed anywhere in the code. Weak pointer is purposely ignored when /// doing `Rc`'s ref counter check. Expect panics if this invariant is violated. @@ -29,6 +37,8 @@ pub(crate) struct HttpRequestInner { pub(crate) head: Message, pub(crate) path: Path, pub(crate) app_data: SmallVec<[Rc; 4]>, + pub(crate) conn_data: Option>, + pub(crate) req_data: Rc>, app_state: Rc, } @@ -39,6 +49,8 @@ impl HttpRequest { head: Message, app_state: Rc, app_data: Rc, + conn_data: Option>, + req_data: Rc>, ) -> HttpRequest { let mut data = SmallVec::<[Rc; 4]>::new(); data.push(app_data); @@ -49,6 +61,8 @@ impl HttpRequest { path, app_state, app_data: data, + conn_data, + req_data, }), } } @@ -92,7 +106,7 @@ impl HttpRequest { &self.head().headers } - /// The target path of this Request. + /// The target path of this request. #[inline] pub fn path(&self) -> &str { self.head().uri.path() @@ -100,22 +114,22 @@ impl HttpRequest { /// The query string in the URL. /// - /// E.g., id=10 + /// Example: `id=10` #[inline] pub fn query_string(&self) -> &str { - if let Some(query) = self.uri().query().as_ref() { - query - } else { - "" - } + self.uri().query().unwrap_or_default() } - /// Get a reference to the Path parameters. + /// Returns a reference to the URL parameters container. /// - /// Params is a container for url parameters. - /// A variable segment is specified in the form `{identifier}`, - /// where the identifier can be used later in a request handler to - /// access the matched value for that segment. + /// A URL parameter is specified in the form `{identifier}`, where the identifier can be used + /// later in a request handler to access the matched value for that parameter. + /// + /// # Percent Encoding and URL Parameters + /// Because each URL parameter is able to capture multiple path segments, both `["%2F", "%25"]` + /// found in the request URI are not decoded into `["/", "%"]` in order to preserve path + /// segment boundaries. If a url parameter is expected to contain these characters, then it is + /// on the user to decode them. #[inline] pub fn match_info(&self) -> &Path { &self.inner.path @@ -145,42 +159,58 @@ impl HttpRequest { self.resource_map().match_name(self.path()) } - /// Request extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - self.head().extensions() + pub fn req_data(&self) -> Ref<'_, Extensions> { + self.inner.req_data.borrow() } - /// Mutable reference to a the request's extensions - #[inline] - pub fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.head().extensions_mut() + pub fn req_data_mut(&self) -> RefMut<'_, Extensions> { + self.inner.req_data.borrow_mut() } - /// Generate url for named resource + /// Returns a reference a piece of connection data set in an [on-connect] callback. /// - /// ```rust + /// ```ignore + /// let opt_t = req.conn_data::(); + /// ``` + /// + /// [on-connect]: crate::HttpServer::on_connect + pub fn conn_data(&self) -> Option<&T> { + self.inner + .conn_data + .as_deref() + .and_then(|container| container.get::()) + } + + /// Generates URL for a named resource. + /// + /// This substitutes in sequence all URL parameters that appear in the resource itself and in + /// parent [scopes](crate::web::scope), if any. + /// + /// It is worth noting that the characters `['/', '%']` are not escaped and therefore a single + /// URL parameter may expand into multiple path segments and `elements` can be percent-encoded + /// beforehand without worrying about double encoding. Any other character that is not valid in + /// a URL path context is escaped using percent-encoding. + /// + /// # Examples + /// ``` /// # use actix_web::{web, App, HttpRequest, HttpResponse}; - /// # /// fn index(req: HttpRequest) -> HttpResponse { - /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource + /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate URL for "foo" resource /// HttpResponse::Ok().into() /// } /// - /// fn main() { - /// let app = App::new() - /// .service(web::resource("/test/{one}/{two}/{three}") - /// .name("foo") // <- set resource name, then it could be used in `url_for` - /// .route(web::get().to(|| HttpResponse::Ok())) - /// ); - /// } + /// let app = App::new() + /// .service(web::resource("/test/{one}/{two}/{three}") + /// .name("foo") // <- set resource name so it can be used in `url_for` + /// .route(web::get().to(|| HttpResponse::Ok())) + /// ); /// ``` pub fn url_for(&self, name: &str, elements: U) -> Result where U: IntoIterator, I: AsRef, { - self.resource_map().url_for(&self, name, elements) + self.resource_map().url_for(self, name, elements) } /// Generate url for named resource @@ -192,32 +222,42 @@ impl HttpRequest { self.url_for(name, &NO_PARAMS) } - #[inline] /// Get a reference to a `ResourceMap` of current application. + #[inline] pub fn resource_map(&self) -> &ResourceMap { - &self.app_state().rmap() + self.app_state().rmap() } - /// Peer socket address. + /// Returns peer socket address. /// /// Peer address is the directly connected peer's socket address. If a proxy is used in front of /// the Actix Web server, then it would be address of this proxy. /// - /// To get client connection information `.connection_info()` should be used. + /// For expanded client connection information, use [`connection_info`] instead. /// - /// Will only return None when called in unit tests. + /// Will only return None when called in unit tests unless [`TestRequest::peer_addr`] is used. + /// + /// [`TestRequest::peer_addr`]: crate::test::TestRequest::peer_addr + /// [`connection_info`]: Self::connection_info #[inline] pub fn peer_addr(&self) -> Option { self.head().peer_addr } - /// Get *ConnectionInfo* for the current request. + /// Returns connection info for the current request. /// - /// This method panics if request's extensions container is already - /// borrowed. + /// The return type, [`ConnectionInfo`], can also be used as an extractor. + /// + /// # Panics + /// Panics if request's extensions container is already borrowed. #[inline] pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { - ConnectionInfo::get(self.head(), self.app_config()) + if !self.extensions().contains::() { + let info = ConnectionInfo::new(self.head(), &*self.app_config()); + self.extensions_mut().insert(info); + } + + Ref::map(self.extensions(), |data| data.get().unwrap()) } /// App config @@ -226,14 +266,34 @@ impl HttpRequest { self.app_state().config() } - /// Get an application data object stored with `App::data` or `App::app_data` - /// methods during application configuration. + /// Retrieves a piece of application state. /// - /// If `App::data` was used to store object, use `Data`: + /// Extracts any object stored with [`App::app_data()`](crate::App::app_data) (or the + /// counterpart methods on [`Scope`](crate::Scope::app_data) and + /// [`Resource`](crate::Resource::app_data)) during application configuration. /// - /// ```rust,ignore - /// let opt_t = req.app_data::>(); + /// Since the Actix Web router layers application data, the returned object will reference the + /// "closest" instance of the type. For example, if an `App` stores a `u32`, a nested `Scope` + /// also stores a `u32`, and the delegated request handler falls within that `Scope`, then + /// calling `.app_data::()` on an `HttpRequest` within that handler will return the + /// `Scope`'s instance. However, using the same router set up and a request that does not get + /// captured by the `Scope`, `.app_data::()` would return the `App`'s instance. + /// + /// If the state was stored using the [`Data`] wrapper, then it must also be retrieved using + /// this same type. + /// + /// See also the [`Data`] extractor. + /// + /// # Examples + /// ```no_run + /// # use actix_web::{test::TestRequest, web::Data}; + /// # let req = TestRequest::default().to_http_request(); + /// # type T = u32; + /// let opt_t: Option<&Data> = req.app_data::>(); /// ``` + /// + /// [`Data`]: crate::web::Data + #[doc(alias = "state")] pub fn app_data(&self) -> Option<&T> { for container in self.inner.app_data.iter().rev() { if let Some(data) = container.get::() { @@ -248,27 +308,60 @@ impl HttpRequest { fn app_state(&self) -> &AppInitServiceState { &*self.inner.app_state } + + /// Load request cookies. + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> Result>>, CookieParseError> { + use actix_http::header::COOKIE; + + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(COOKIE) { + let s = str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + for cookie_str in s.split(';').map(|s| s.trim()) { + if !cookie_str.is_empty() { + cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); + } + } + } + self.extensions_mut().insert(Cookies(cookies)); + } + + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } + + /// Return request cookie. + #[cfg(feature = "cookies")] + pub fn cookie(&self, name: &str) -> Option> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies.iter() { + if cookie.name() == name { + return Some(cookie.to_owned()); + } + } + } + None + } } impl HttpMessage for HttpRequest { type Stream = (); #[inline] - /// Returns Request's headers. fn headers(&self) -> &HeaderMap { &self.head().headers } - /// Request extensions #[inline] fn extensions(&self) -> Ref<'_, Extensions> { - self.inner.head.extensions() + self.req_data() } - /// Mutable reference to a the request's extensions #[inline] fn extensions_mut(&self) -> RefMut<'_, Extensions> { - self.inner.head.extensions_mut() + self.req_data_mut() } #[inline] @@ -281,17 +374,18 @@ impl Drop for HttpRequest { fn drop(&mut self) { // if possible, contribute to current worker's HttpRequest allocation pool - // This relies on no Weak exists anywhere.(There is none) + // This relies on no weak references to inner existing anywhere within the codebase. if let Some(inner) = Rc::get_mut(&mut self.inner) { if inner.app_state.pool().is_available() { // clear additional app_data and keep the root one for reuse. inner.app_data.truncate(1); - // inner is borrowed mut here. get head's Extension mutably - // to reduce borrow check - inner.head.extensions.get_mut().clear(); + + // Inner is borrowed mut here and; get req data mutably to reduce borrow check. Also + // we know the req_data Rc will not have any cloned at this point to unwrap is okay. + Rc::get_mut(&mut inner.req_data).unwrap().get_mut().clear(); // a re-borrow of pool is necessary here. - let req = self.inner.clone(); + let req = Rc::clone(&self.inner); self.app_state().pool().push(req); } } @@ -300,11 +394,10 @@ impl Drop for HttpRequest { /// It is possible to get `HttpRequest` as an extractor handler parameter /// -/// ## Example -/// -/// ```rust +/// # Examples +/// ``` /// use actix_web::{web, App, HttpRequest}; -/// use serde_derive::Deserialize; +/// use serde::Deserialize; /// /// /// extract `Thing` from request /// async fn index(req: HttpRequest) -> String { @@ -319,7 +412,6 @@ impl Drop for HttpRequest { /// } /// ``` impl FromRequest for HttpRequest { - type Config = (); type Error = Error; type Future = Ready>; @@ -470,9 +562,9 @@ mod tests { #[test] fn test_url_for() { let mut res = ResourceDef::new("/user/{name}.{ext}"); - *res.name_mut() = "index".to_string(); + res.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut res, None); assert!(rmap.has_resource("/user/test.html")); assert!(!rmap.has_resource("/test/unknown")); @@ -500,9 +592,9 @@ mod tests { #[test] fn test_url_for_static() { let mut rdef = ResourceDef::new("/index.html"); - *rdef.name_mut() = "index".to_string(); + rdef.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); assert!(rmap.has_resource("/index.html")); @@ -521,9 +613,9 @@ mod tests { #[test] fn test_match_name() { let mut rdef = ResourceDef::new("/index.html"); - *rdef.name_mut() = "index".to_string(); + rdef.set_name("index"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); assert!(rmap.has_resource("/index.html")); @@ -540,11 +632,10 @@ mod tests { fn test_url_for_external() { let mut rdef = ResourceDef::new("https://youtube.com/watch/{video_id}"); - *rdef.name_mut() = "youtube".to_string(); + rdef.set_name("youtube"); - let mut rmap = ResourceMap::new(ResourceDef::new("")); + let mut rmap = ResourceMap::new(ResourceDef::prefix("")); rmap.add(&mut rdef, None); - assert!(rmap.has_resource("https://youtube.com/watch/unknown")); let req = TestRequest::default().rmap(rmap).to_http_request(); let url = req.url_for("youtube", &["oHg5SJYRHA0"]); @@ -668,6 +759,8 @@ mod tests { assert_eq!(body, Bytes::from_static(b"1")); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_extensions_dropped() { struct Tracker { diff --git a/src/request_data.rs b/src/request_data.rs index beee8ac12..b685fd0d6 100644 --- a/src/request_data.rs +++ b/src/request_data.rs @@ -1,9 +1,8 @@ use std::{any::type_name, ops::Deref}; -use actix_http::error::{Error, ErrorInternalServerError}; -use futures_util::future; +use actix_utils::future::{err, ok, Ready}; -use crate::{dev::Payload, FromRequest, HttpRequest}; +use crate::{dev::Payload, error::ErrorInternalServerError, Error, FromRequest, HttpRequest}; /// Request-local data extractor. /// @@ -18,12 +17,12 @@ use crate::{dev::Payload, FromRequest, HttpRequest}; /// # Mutating Request Data /// Note that since extractors must output owned data, only types that `impl Clone` can use this /// extractor. A clone is taken of the required request data and can, therefore, not be directly -/// mutated in-place. To mutate request data, continue to use [`HttpRequest::extensions_mut`] or +/// mutated in-place. To mutate request data, continue to use [`HttpRequest::req_data_mut`] or /// re-insert the cloned data back into the extensions map. A `DerefMut` impl is intentionally not /// provided to make this potential foot-gun more obvious. /// /// # Example -/// ```rust,no_run +/// ```no_run /// # use actix_web::{web, HttpResponse, HttpRequest, Responder}; /// /// #[derive(Debug, Clone, PartialEq)] @@ -34,12 +33,11 @@ use crate::{dev::Payload, FromRequest, HttpRequest}; /// req: HttpRequest, /// opt_flag: Option>, /// ) -> impl Responder { -/// // use an optional extractor if the middleware is -/// // not guaranteed to add this type of requests data +/// // use an option extractor if middleware is not guaranteed to add this type of req data /// if let Some(flag) = opt_flag { -/// assert_eq!(&flag.into_inner(), req.extensions().get::().unwrap()); +/// assert_eq!(&flag.into_inner(), req.req_data().get::().unwrap()); /// } -/// +/// /// HttpResponse::Ok() /// } /// ``` @@ -65,13 +63,12 @@ impl Deref for ReqData { } impl FromRequest for ReqData { - type Config = (); type Error = Error; - type Future = future::Ready>; + type Future = Ready>; fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { - if let Some(st) = req.extensions().get::() { - future::ok(ReqData(st.clone())) + if let Some(st) = req.req_data().get::() { + ok(ReqData(st.clone())) } else { log::debug!( "Failed to construct App-level ReqData extractor. \ @@ -79,7 +76,7 @@ impl FromRequest for ReqData { req.path(), type_name::(), ); - future::err(ErrorInternalServerError( + err(ErrorInternalServerError( "Missing expected request extension data", )) } diff --git a/src/resource.rs b/src/resource.rs index 944beeefa..8da0a8a85 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,66 +1,61 @@ -use std::cell::RefCell; -use std::fmt; -use std::future::Future; -use std::rc::Rc; -use std::task::Poll; +use std::{cell::RefCell, fmt, future::Future, marker::PhantomData, rc::Rc}; -use actix_http::{Error, Extensions, Response}; -use actix_router::IntoPattern; -use actix_service::boxed::{self, BoxService, BoxServiceFactory}; +use actix_http::{body::BoxBody, Extensions}; +use actix_router::{IntoPatterns, Patterns}; use actix_service::{ - apply, apply_fn_factory, fn_service, IntoServiceFactory, Service, ServiceFactory, + apply, apply_fn_factory, boxed, fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt, Transform, }; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; -use crate::data::Data; -use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; -use crate::extract::FromRequest; -use crate::guard::Guard; -use crate::handler::Handler; -use crate::responder::Responder; -use crate::route::{Route, RouteService}; -use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{ + body::MessageBody, + data::Data, + dev::{ensure_leading_slash, AppService, ResourceDef}, + guard::Guard, + handler::Handler, + route::{Route, RouteService}, + service::{ + BoxedHttpService, BoxedHttpServiceFactory, HttpServiceFactory, ServiceRequest, + ServiceResponse, + }, + Error, FromRequest, HttpResponse, Responder, +}; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; - -/// *Resource* is an entry in resources table which corresponds to requested URL. +/// A collection of [`Route`]s that respond to the same path pattern. /// -/// Resource in turn has at least one route. -/// Route consists of an handlers objects and list of guards -/// (objects that implement `Guard` trait). -/// Resources and routes uses builder-like pattern for configuration. -/// During request handling, resource object iterate through all routes -/// and check guards for specific route, if request matches all -/// guards, route considered matched and route handler get called. +/// Resource in turn has at least one route. Route consists of an handlers objects and list of +/// guards (objects that implement `Guard` trait). Resources and routes uses builder-like pattern +/// for configuration. During request handling, resource object iterate through all routes and check +/// guards for specific route, if request matches all guards, route considered matched and route +/// handler get called. /// -/// ```rust +/// # Examples +/// ``` /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/") -/// .route(web::get().to(|| HttpResponse::Ok()))); -/// } +/// let app = App::new().service( +/// web::resource("/") +/// .route(web::get().to(|| HttpResponse::Ok()))); /// ``` /// -/// If no matching route could be found, *405* response code get returned. -/// Default behavior could be overridden with `default_resource()` method. -pub struct Resource { +/// If no matching route could be found, *405* response code get returned. Default behavior could be +/// overridden with `default_resource()` method. +pub struct Resource { endpoint: T, - rdef: Vec, + rdef: Patterns, name: Option, routes: Vec, app_data: Option, guards: Vec>, - default: HttpNewService, + default: BoxedHttpServiceFactory, factory_ref: Rc>>, + _phantom: PhantomData, } impl Resource { - pub fn new(path: T) -> Resource { + pub fn new(path: T) -> Resource { let fref = Rc::new(RefCell::new(None)); Resource { @@ -72,21 +67,23 @@ impl Resource { guards: Vec::new(), app_data: None, default: boxed::factory(fn_service(|req: ServiceRequest| async { - Ok(req.into_response(Response::MethodNotAllowed().finish())) + Ok(req.into_response(HttpResponse::MethodNotAllowed())) })), + _phantom: PhantomData, } } } -impl Resource +impl Resource where T: ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B: MessageBody, { /// Set resource name. /// @@ -98,7 +95,7 @@ where /// Add match guard to a resource. /// - /// ```rust + /// ``` /// use actix_web::{web, guard, App, HttpResponse}; /// /// async fn index(data: web::Path<(String, String)>) -> &'static str { @@ -131,24 +128,22 @@ where /// Register a new route. /// - /// ```rust + /// ``` /// use actix_web::{web, guard, App, HttpResponse}; /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/").route( - /// web::route() - /// .guard(guard::Any(guard::Get()).or(guard::Put())) - /// .guard(guard::Header("Content-Type", "text/plain")) - /// .to(|| HttpResponse::Ok())) - /// ); - /// } + /// let app = App::new().service( + /// web::resource("/").route( + /// web::route() + /// .guard(guard::Any(guard::Get()).or(guard::Put())) + /// .guard(guard::Header("Content-Type", "text/plain")) + /// .to(|| HttpResponse::Ok())) + /// ); /// ``` /// /// Multiple routes could be added to a resource. Resource object uses /// match guards for route selection. /// - /// ```rust + /// ``` /// use actix_web::{web, guard, App}; /// /// fn main() { @@ -168,40 +163,39 @@ where self } - /// Provide resource specific data. This method allows to add extractor - /// configuration or specific state available via `Data` extractor. - /// Provided data is available for all routes registered for the current resource. - /// Resource data overrides data registered by `App::data()` method. - /// - /// ```rust - /// use actix_web::{web, App, FromRequest}; - /// - /// /// extract text data from request - /// async fn index(body: String) -> String { - /// format!("Body {}!", body) - /// } - /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/index.html") - /// // limit size of the payload - /// .data(String::configure(|cfg| { - /// cfg.limit(4096) - /// })) - /// .route( - /// web::get() - /// // register handler - /// .to(index) - /// )); - /// } - /// ``` - pub fn data(self, data: U) -> Self { - self.app_data(Data::new(data)) - } - /// Add resource data. /// - /// Data of different types from parent contexts will still be accessible. + /// Data of different types from parent contexts will still be accessible. Any `Data` types + /// set here can be extracted in handlers using the `Data` extractor. + /// + /// # Examples + /// ``` + /// use std::cell::Cell; + /// use actix_web::{web, App, HttpRequest, HttpResponse, Responder}; + /// + /// struct MyData { + /// count: std::cell::Cell, + /// } + /// + /// async fn handler(req: HttpRequest, counter: web::Data) -> impl Responder { + /// // note this cannot use the Data extractor because it was not added with it + /// let incr = *req.app_data::().unwrap(); + /// assert_eq!(incr, 3); + /// + /// // update counter using other value from app data + /// counter.count.set(counter.count.get() + incr); + /// + /// HttpResponse::Ok().body(counter.count.get().to_string()) + /// } + /// + /// let app = App::new().service( + /// web::resource("/") + /// .app_data(3usize) + /// .app_data(web::Data::new(MyData { count: Default::default() })) + /// .route(web::get().to(handler)) + /// ); + /// ``` + #[doc(alias = "manage")] pub fn app_data(mut self, data: U) -> Self { self.app_data .get_or_insert_with(Extensions::new) @@ -210,9 +204,17 @@ where self } + /// Add resource data after wrapping in `Data`. + /// + /// Deprecated in favor of [`app_data`](Self::app_data). + #[deprecated(since = "4.0.0", note = "Use `.app_data(Data::new(val))` instead.")] + pub fn data(self, data: U) -> Self { + self.app_data(Data::new(data)) + } + /// Register a new route and add handler. This route matches all requests. /// - /// ```rust + /// ``` /// use actix_web::*; /// /// fn index(req: HttpRequest) -> HttpResponse { @@ -224,18 +226,16 @@ where /// /// This is shortcut for: /// - /// ```rust - /// # extern crate actix_web; + /// ``` /// # use actix_web::*; /// # fn index(req: HttpRequest) -> HttpResponse { unimplemented!() } /// App::new().service(web::resource("/").route(web::route().to(index))); /// ``` - pub fn to(mut self, handler: F) -> Self + pub fn to(mut self, handler: F) -> Self where - F: Handler, - I: FromRequest + 'static, - R: Future + 'static, - R::Output: Responder + 'static, + F: Handler, + Args: FromRequest + 'static, + F::Output: Responder + 'static, { self.routes.push(Route::new().to(handler)); self @@ -248,26 +248,28 @@ where /// type (i.e modify response's body). /// /// **Note**: middlewares get called in opposite order of middlewares registration. - pub fn wrap( + pub fn wrap( self, mw: M, ) -> Resource< impl ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B1, > where M: Transform< T::Service, ServiceRequest, - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B1: MessageBody, { Resource { endpoint: apply(mw, self.endpoint), @@ -278,6 +280,7 @@ where default: self.default, app_data: self.app_data, factory_ref: self.factory_ref, + _phantom: PhantomData, } } @@ -290,10 +293,10 @@ where /// Resource level middlewares are not allowed to change response /// type (i.e modify response's body). /// - /// ```rust + /// ``` /// use actix_service::Service; /// use actix_web::{web, App}; - /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// use actix_web::http::header::{CONTENT_TYPE, HeaderValue}; /// /// async fn index() -> &'static str { /// "Welcome!" @@ -315,21 +318,23 @@ where /// .route(web::get().to(index))); /// } /// ``` - pub fn wrap_fn( + pub fn wrap_fn( self, mw: F, ) -> Resource< impl ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B1, > where F: Fn(ServiceRequest, &T::Service) -> R + Clone, - R: Future>, + R: Future, Error>>, + B1: MessageBody, { Resource { endpoint: apply_fn_factory(self.endpoint, mw), @@ -340,6 +345,7 @@ where default: self.default, app_data: self.app_data, factory_ref: self.factory_ref, + _phantom: PhantomData, } } @@ -367,15 +373,16 @@ where } } -impl HttpServiceFactory for Resource +impl HttpServiceFactory for Resource where T: ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), > + 'static, + B: MessageBody + 'static, { fn register(mut self, config: &mut AppService) { let guards = if self.guards.is_empty() { @@ -385,44 +392,40 @@ where }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(insert_slash(self.rdef.clone())) + ResourceDef::new(ensure_leading_slash(self.rdef.clone())) } else { ResourceDef::new(self.rdef.clone()) }; if let Some(ref name) = self.name { - *rdef.name_mut() = name.clone(); + rdef.set_name(name); } - config.register_service(rdef, guards, self, None) - } -} - -impl IntoServiceFactory for Resource -where - T: ServiceFactory< - ServiceRequest, - Config = (), - Response = ServiceResponse, - Error = Error, - InitError = (), - >, -{ - fn into_factory(self) -> T { *self.factory_ref.borrow_mut() = Some(ResourceFactory { routes: self.routes, - app_data: self.app_data.map(Rc::new), default: self.default, }); - self.endpoint + let resource_data = self.app_data.map(Rc::new); + + // wraps endpoint service (including middleware) call and injects app data for this scope + let endpoint = apply_fn_factory(self.endpoint, move |mut req: ServiceRequest, srv| { + if let Some(ref data) = resource_data { + req.add_data_container(Rc::clone(data)); + } + + let fut = srv.call(req); + + async { Ok(fut.await?.map_into_boxed_body()) } + }); + + config.register_service(rdef, guards, endpoint, None) } } pub struct ResourceFactory { routes: Vec, - app_data: Option>, - default: HttpNewService, + default: BoxedHttpServiceFactory, } impl ServiceFactory for ResourceFactory { @@ -440,8 +443,6 @@ impl ServiceFactory for ResourceFactory { // construct route service factory futures let factory_fut = join_all(self.routes.iter().map(|route| route.new_service(()))); - let app_data = self.app_data.clone(); - Box::pin(async move { let default = default_fut.await?; let routes = factory_fut @@ -449,19 +450,14 @@ impl ServiceFactory for ResourceFactory { .into_iter() .collect::, _>>()?; - Ok(ResourceService { - app_data, - default, - routes, - }) + Ok(ResourceService { routes, default }) }) } } pub struct ResourceService { routes: Vec, - app_data: Option>, - default: HttpService, + default: BoxedHttpService, } impl Service for ResourceService { @@ -472,20 +468,12 @@ impl Service for ResourceService { actix_service::always_ready!(); fn call(&self, mut req: ServiceRequest) -> Self::Future { - for route in self.routes.iter() { + for route in &self.routes { if route.check(&mut req) { - if let Some(ref app_data) = self.app_data { - req.add_data_container(app_data.clone()); - } - return route.call(req); } } - if let Some(ref app_data) = self.app_data { - req.add_data_container(app_data.clone()); - } - self.default.call(req) } } @@ -520,13 +508,49 @@ mod tests { use actix_rt::time::sleep; use actix_service::Service; - use futures_util::future::ok; + use actix_utils::future::ok; - use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::middleware::DefaultHeaders; - use crate::service::ServiceRequest; - use crate::test::{call_service, init_service, TestRequest}; - use crate::{guard, web, App, Error, HttpResponse}; + use super::*; + use crate::{ + guard, + http::{ + header::{self, HeaderValue}, + Method, StatusCode, + }, + middleware::{Compat, DefaultHeaders}, + service::{ServiceRequest, ServiceResponse}, + test::{call_service, init_service, TestRequest}, + web, App, Error, HttpMessage, HttpResponse, + }; + + #[test] + fn can_be_returned_from_fn() { + fn my_resource() -> Resource { + web::resource("/test").route(web::get().to(|| async { "hello" })) + } + + fn my_compat_resource() -> Resource< + impl ServiceFactory< + ServiceRequest, + Config = (), + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > { + web::resource("/test-compat") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { Ok(fut.await?.map_into_right_body::<()>()) } + }) + .wrap(Compat::noop()) + .route(web::get().to(|| async { "hello" })) + } + + App::new() + .service(my_resource()) + .service(my_compat_resource()); + } #[actix_rt::test] async fn test_middleware() { @@ -536,7 +560,7 @@ mod tests { .name("test") .wrap( DefaultHeaders::new() - .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + .add((header::CONTENT_TYPE, HeaderValue::from_static("0001"))), ) .route(web::get().to(HttpResponse::Ok)), ), @@ -693,6 +717,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::NO_CONTENT); } + // allow deprecated `{App, Resource}::data` + #[allow(deprecated)] #[actix_rt::test] async fn test_data() { let srv = init_service( @@ -725,6 +751,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); } + // allow deprecated `{App, Resource}::data` + #[allow(deprecated)] #[actix_rt::test] async fn test_data_default_service() { let srv = init_service( @@ -743,4 +771,61 @@ mod tests { let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } + + #[actix_rt::test] + async fn test_middleware_app_data() { + let srv = init_service( + App::new().service( + web::resource("test") + .app_data(1usize) + .wrap_fn(|req, srv| { + assert_eq!(req.app_data::(), Some(&1usize)); + req.extensions_mut().insert(1usize); + srv.call(req) + }) + .route(web::get().to(HttpResponse::Ok)) + .default_service(|req: ServiceRequest| async move { + let (req, _) = req.into_parts(); + + assert_eq!(req.extensions().get::(), Some(&1)); + + Ok(ServiceResponse::new( + req, + HttpResponse::BadRequest().finish(), + )) + }), + ), + ) + .await; + + let req = TestRequest::get().uri("/test").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::post().uri("/test").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[actix_rt::test] + async fn test_middleware_body_type() { + let srv = init_service( + App::new().service( + web::resource("/test") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { Ok(fut.await?.map_into_right_body::<()>()) } + }) + .route(web::get().to(|| async { "hello" })), + ), + ) + .await; + + // test if `try_into_bytes()` is preserved across scope layer + use actix_http::body::MessageBody as _; + let req = TestRequest::with_uri("/test").to_request(); + let resp = call_service(&srv, req).await; + let body = resp.into_body(); + assert_eq!(body.try_into_bytes().unwrap(), b"hello".as_ref()); + } } diff --git a/src/responder.rs b/src/responder.rs deleted file mode 100644 index 92945cdaa..000000000 --- a/src/responder.rs +++ /dev/null @@ -1,430 +0,0 @@ -use std::fmt; - -use actix_http::{ - error::InternalError, - http::{header::IntoHeaderPair, Error as HttpError, HeaderMap, StatusCode}, - ResponseBuilder, -}; -use bytes::{Bytes, BytesMut}; - -use crate::{Error, HttpRequest, HttpResponse}; - -/// Trait implemented by types that can be converted to an HTTP response. -/// -/// Any types that implement this trait can be used in the return type of a handler. -pub trait Responder { - /// Convert self to `HttpResponse`. - fn respond_to(self, req: &HttpRequest) -> HttpResponse; - - /// Override a status code for a Responder. - /// - /// ```rust - /// use actix_web::{http::StatusCode, HttpRequest, Responder}; - /// - /// fn index(req: HttpRequest) -> impl Responder { - /// "Welcome!".with_status(StatusCode::OK) - /// } - /// ``` - fn with_status(self, status: StatusCode) -> CustomResponder - where - Self: Sized, - { - CustomResponder::new(self).with_status(status) - } - - /// Insert header to the final response. - /// - /// Overrides other headers with the same name. - /// - /// ```rust - /// use actix_web::{web, HttpRequest, Responder}; - /// use serde::Serialize; - /// - /// #[derive(Serialize)] - /// struct MyObj { - /// name: String, - /// } - /// - /// fn index(req: HttpRequest) -> impl Responder { - /// web::Json(MyObj { name: "Name".to_owned() }) - /// .with_header(("x-version", "1.2.3")) - /// } - /// ``` - fn with_header(self, header: H) -> CustomResponder - where - Self: Sized, - H: IntoHeaderPair, - { - CustomResponder::new(self).with_header(header) - } -} - -impl Responder for HttpResponse { - #[inline] - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - self - } -} - -impl Responder for Option { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { - match self { - Some(t) => t.respond_to(req), - None => HttpResponse::build(StatusCode::NOT_FOUND).finish(), - } - } -} - -impl Responder for Result -where - T: Responder, - E: Into, -{ - fn respond_to(self, req: &HttpRequest) -> HttpResponse { - match self { - Ok(val) => val.respond_to(req), - Err(e) => HttpResponse::from_error(e.into()), - } - } -} - -impl Responder for ResponseBuilder { - #[inline] - fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { - self.finish() - } -} - -impl Responder for (T, StatusCode) { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { - let mut res = self.0.respond_to(req); - *res.status_mut() = self.1; - res - } -} - -impl Responder for &'static str { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } -} - -impl Responder for &'static [u8] { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} - -impl Responder for String { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } -} - -impl<'a> Responder for &'a String { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } -} - -impl Responder for Bytes { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} - -impl Responder for BytesMut { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} - -/// Allows overriding status code and headers for a responder. -pub struct CustomResponder { - responder: T, - status: Option, - headers: Option, - error: Option, -} - -impl CustomResponder { - fn new(responder: T) -> Self { - CustomResponder { - responder, - status: None, - headers: None, - error: None, - } - } - - /// Override a status code for the Responder's response. - /// - /// ```rust - /// use actix_web::{HttpRequest, Responder, http::StatusCode}; - /// - /// fn index(req: HttpRequest) -> impl Responder { - /// "Welcome!".with_status(StatusCode::OK) - /// } - /// ``` - pub fn with_status(mut self, status: StatusCode) -> Self { - self.status = Some(status); - self - } - - /// Insert header to the final response. - /// - /// Overrides other headers with the same name. - /// - /// ```rust - /// use actix_web::{web, HttpRequest, Responder}; - /// use serde::Serialize; - /// - /// #[derive(Serialize)] - /// struct MyObj { - /// name: String, - /// } - /// - /// fn index(req: HttpRequest) -> impl Responder { - /// web::Json(MyObj { name: "Name".to_string() }) - /// .with_header(("x-version", "1.2.3")) - /// .with_header(("x-version", "1.2.3")) - /// } - /// ``` - pub fn with_header(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - if self.headers.is_none() { - self.headers = Some(HeaderMap::new()); - } - - match header.try_into_header_pair() { - Ok((key, value)) => self.headers.as_mut().unwrap().append(key, value), - Err(e) => self.error = Some(e.into()), - }; - - self - } -} - -impl Responder for CustomResponder { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { - let mut res = self.responder.respond_to(req); - - if let Some(status) = self.status { - *res.status_mut() = status; - } - - if let Some(ref headers) = self.headers { - for (k, v) in headers { - // TODO: before v4, decide if this should be append instead - res.headers_mut().insert(k.clone(), v.clone()); - } - } - - res - } -} - -impl Responder for InternalError -where - T: fmt::Debug + fmt::Display + 'static, -{ - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::from_error(self.into()) - } -} - -#[cfg(test)] -pub(crate) mod tests { - use actix_service::Service; - use bytes::{Bytes, BytesMut}; - - use super::*; - use crate::dev::{Body, ResponseBody}; - use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; - use crate::test::{init_service, TestRequest}; - use crate::{error, web, App}; - - #[actix_rt::test] - async fn test_option_responder() { - let srv = init_service( - App::new() - .service(web::resource("/none").to(|| async { Option::<&'static str>::None })) - .service(web::resource("/some").to(|| async { Some("some") })), - ) - .await; - - let req = TestRequest::with_uri("/none").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); - - let req = TestRequest::with_uri("/some").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"some")); - } - _ => panic!(), - } - } - - pub(crate) trait BodyTest { - fn bin_ref(&self) -> &[u8]; - fn body(&self) -> &Body; - } - - impl BodyTest for ResponseBody { - fn bin_ref(&self) -> &[u8] { - match self { - ResponseBody::Body(ref b) => match b { - Body::Bytes(ref bin) => &bin, - _ => panic!(), - }, - ResponseBody::Other(ref b) => match b { - Body::Bytes(ref bin) => &bin, - _ => panic!(), - }, - } - } - fn body(&self) -> &Body { - match self { - ResponseBody::Body(ref b) => b, - ResponseBody::Other(ref b) => b, - } - } - } - - #[actix_rt::test] - async fn test_responder() { - let req = TestRequest::default().to_http_request(); - - let resp = "test".respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - - let resp = b"test".respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - - let resp = "test".to_string().respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - - let resp = (&"test".to_string()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - - let resp = Bytes::from_static(b"test").respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - - let resp = BytesMut::from(b"test".as_ref()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/octet-stream") - ); - - // InternalError - let resp = error::InternalError::new("err", StatusCode::BAD_REQUEST).respond_to(&req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } - - #[actix_rt::test] - async fn test_result_responder() { - let req = TestRequest::default().to_http_request(); - - // Result - let resp = Ok::<_, Error>("test".to_string()).respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().bin_ref(), b"test"); - assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("text/plain; charset=utf-8") - ); - - let res = Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) - .respond_to(&req); - - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - } - - #[actix_rt::test] - async fn test_custom_responder() { - let req = TestRequest::default().to_http_request(); - let res = "test" - .to_string() - .with_status(StatusCode::BAD_REQUEST) - .respond_to(&req); - - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); - - let res = "test" - .to_string() - .with_header(("content-type", "json")) - .respond_to(&req); - - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); - assert_eq!( - res.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("json") - ); - } - - #[actix_rt::test] - async fn test_tuple_responder_with_status_code() { - let req = TestRequest::default().to_http_request(); - let res = ("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().bin_ref(), b"test"); - - let req = TestRequest::default().to_http_request(); - let res = ("test".to_string(), StatusCode::OK) - .with_header((CONTENT_TYPE, mime::APPLICATION_JSON)) - .respond_to(&req); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().bin_ref(), b"test"); - assert_eq!( - res.headers().get(CONTENT_TYPE).unwrap(), - HeaderValue::from_static("application/json") - ); - } -} diff --git a/src/response/builder.rs b/src/response/builder.rs new file mode 100644 index 000000000..bdb0aaa12 --- /dev/null +++ b/src/response/builder.rs @@ -0,0 +1,567 @@ +use std::{ + cell::{Ref, RefMut}, + convert::TryInto, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{ + body::{BodyStream, BoxBody, MessageBody}, + error::HttpError, + header::{self, HeaderName, TryIntoHeaderPair, TryIntoHeaderValue}, + ConnectionType, Extensions, Response, ResponseHead, StatusCode, +}; +use bytes::Bytes; +use futures_core::Stream; +use serde::Serialize; + +#[cfg(feature = "cookies")] +use actix_http::header::HeaderValue; +#[cfg(feature = "cookies")] +use cookie::{Cookie, CookieJar}; + +use crate::{ + error::{Error, JsonPayloadError}, + BoxError, HttpRequest, HttpResponse, Responder, +}; + +/// An HTTP response builder. +/// +/// This type can be used to construct an instance of `Response` through a builder-like pattern. +pub struct HttpResponseBuilder { + res: Option>, + err: Option, + #[cfg(feature = "cookies")] + cookies: Option, +} + +impl HttpResponseBuilder { + #[inline] + /// Create response builder + pub fn new(status: StatusCode) -> Self { + Self { + res: Some(Response::with_body(status, BoxBody::new(()))), + err: None, + #[cfg(feature = "cookies")] + cookies: None, + } + } + + /// Set HTTP status code of this response. + #[inline] + pub fn status(&mut self, status: StatusCode) -> &mut Self { + if let Some(parts) = self.inner() { + parts.status = status; + } + self + } + + /// Insert a header, replacing any that were set with an equivalent field name. + /// + /// ``` + /// use actix_web::{HttpResponse, http::header}; + /// + /// HttpResponse::Ok() + /// .insert_header(header::ContentType(mime::APPLICATION_JSON)) + /// .insert_header(("X-TEST", "value")) + /// .finish(); + /// ``` + pub fn insert_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + if let Some(parts) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Append a header, keeping any that were set with an equivalent field name. + /// + /// ``` + /// use actix_web::{HttpResponse, http::header}; + /// + /// HttpResponse::Ok() + /// .append_header(header::ContentType(mime::APPLICATION_JSON)) + /// .append_header(("X-TEST", "value1")) + /// .append_header(("X-TEST", "value2")) + /// .finish(); + /// ``` + pub fn append_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self { + if let Some(parts) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => parts.headers.append(key, value), + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Replaced with [`Self::insert_header()`]. + #[doc(hidden)] + #[deprecated( + since = "4.0.0", + note = "Replaced with `insert_header((key, value))`. Will be removed in v5." + )] + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + K: TryInto, + K::Error: Into, + V: TryIntoHeaderValue, + { + if self.err.is_some() { + return self; + } + + match (key.try_into(), value.try_into_value()) { + (Ok(name), Ok(value)) => return self.insert_header((name, value)), + (Err(err), _) => self.err = Some(err.into()), + (_, Err(err)) => self.err = Some(err.into()), + } + + self + } + + /// Replaced with [`Self::append_header()`]. + #[doc(hidden)] + #[deprecated( + since = "4.0.0", + note = "Replaced with `append_header((key, value))`. Will be removed in v5." + )] + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + K: TryInto, + K::Error: Into, + V: TryIntoHeaderValue, + { + if self.err.is_some() { + return self; + } + + match (key.try_into(), value.try_into_value()) { + (Ok(name), Ok(value)) => return self.append_header((name, value)), + (Err(err), _) => self.err = Some(err.into()), + (_, Err(err)) => self.err = Some(err.into()), + } + + self + } + + /// Set the custom reason for the response. + #[inline] + pub fn reason(&mut self, reason: &'static str) -> &mut Self { + if let Some(parts) = self.inner() { + parts.reason = Some(reason); + } + self + } + + /// Set connection type to KeepAlive + #[inline] + pub fn keep_alive(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::KeepAlive); + } + self + } + + /// Set connection type to Upgrade + #[inline] + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: TryIntoHeaderValue, + { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Upgrade); + } + + if let Ok(value) = value.try_into_value() { + self.insert_header((header::UPGRADE, value)); + } + + self + } + + /// Force close connection, even if it is marked as keep-alive + #[inline] + pub fn force_close(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Close); + } + self + } + + /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. + #[inline] + pub fn no_chunking(&mut self, len: u64) -> &mut Self { + let mut buf = itoa::Buffer::new(); + self.insert_header((header::CONTENT_LENGTH, buf.format(len))); + + if let Some(parts) = self.inner() { + parts.no_chunking(true); + } + self + } + + /// Set response content type. + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + V: TryIntoHeaderValue, + { + if let Some(parts) = self.inner() { + match value.try_into_value() { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set a cookie. + /// + /// ``` + /// use actix_web::{HttpResponse, cookie::Cookie}; + /// + /// HttpResponse::Ok() + /// .cookie( + /// Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .finish(); + /// ``` + #[cfg(feature = "cookies")] + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Remove cookie. + /// + /// A `Set-Cookie` header is added that will delete a cookie with the same name from the client. + /// + /// ``` + /// use actix_web::{HttpRequest, HttpResponse, Responder}; + /// + /// async fn handler(req: HttpRequest) -> impl Responder { + /// let mut builder = HttpResponse::Ok(); + /// + /// if let Some(ref cookie) = req.cookie("name") { + /// builder.del_cookie(cookie); + /// } + /// + /// builder.finish() + /// } + /// ``` + #[cfg(feature = "cookies")] + pub fn del_cookie(&mut self, cookie: &Cookie<'_>) -> &mut Self { + if self.cookies.is_none() { + self.cookies = Some(CookieJar::new()) + } + let jar = self.cookies.as_mut().unwrap(); + let cookie = cookie.clone().into_owned(); + jar.add_original(cookie.clone()); + jar.remove(cookie); + self + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + self.res + .as_ref() + .expect("cannot reuse response builder") + .extensions() + } + + /// Mutable reference to a the response's extensions + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + self.res + .as_mut() + .expect("cannot reuse response builder") + .extensions_mut() + } + + /// Set a body and build the `HttpResponse`. + /// + /// `HttpResponseBuilder` can not be used after this call. + pub fn body(&mut self, body: B) -> HttpResponse + where + B: MessageBody + 'static, + { + match self.message_body(body) { + Ok(res) => res.map_into_boxed_body(), + Err(err) => HttpResponse::from_error(err), + } + } + + /// Set a body and build the `HttpResponse`. + /// + /// `HttpResponseBuilder` can not be used after this call. + pub fn message_body(&mut self, body: B) -> Result, Error> { + if let Some(err) = self.err.take() { + return Err(err.into()); + } + + let res = self + .res + .take() + .expect("cannot reuse response builder") + .set_body(body); + + #[allow(unused_mut)] // mut is only unused when cookies are disabled + let mut res = HttpResponse::from(res); + + #[cfg(feature = "cookies")] + if let Some(ref jar) = self.cookies { + for cookie in jar.delta() { + match HeaderValue::from_str(&cookie.to_string()) { + Ok(val) => res.headers_mut().append(header::SET_COOKIE, val), + Err(err) => return Err(err.into()), + }; + } + } + + Ok(res) + } + + /// Set a streaming body and build the `HttpResponse`. + /// + /// `HttpResponseBuilder` can not be used after this call. + #[inline] + pub fn streaming(&mut self, stream: S) -> HttpResponse + where + S: Stream> + 'static, + E: Into + 'static, + { + self.body(BodyStream::new(stream)) + } + + /// Set a JSON body and build the `HttpResponse`. + /// + /// `HttpResponseBuilder` can not be used after this call. + pub fn json(&mut self, value: impl Serialize) -> HttpResponse { + match serde_json::to_string(&value) { + Ok(body) => { + let contains = if let Some(parts) = self.inner() { + parts.headers.contains_key(header::CONTENT_TYPE) + } else { + true + }; + + if !contains { + self.insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); + } + + self.body(body) + } + Err(err) => HttpResponse::from_error(JsonPayloadError::Serialize(err)), + } + } + + /// Set an empty body and build the `HttpResponse`. + /// + /// `HttpResponseBuilder` can not be used after this call. + #[inline] + pub fn finish(&mut self) -> HttpResponse { + self.body(()) + } + + /// This method construct new `HttpResponseBuilder` + pub fn take(&mut self) -> Self { + Self { + res: self.res.take(), + err: self.err.take(), + #[cfg(feature = "cookies")] + cookies: self.cookies.take(), + } + } + + #[inline] + fn inner(&mut self) -> Option<&mut ResponseHead> { + if self.err.is_some() { + return None; + } + + self.res.as_mut().map(Response::head_mut) + } +} + +impl From for HttpResponse { + fn from(mut builder: HttpResponseBuilder) -> Self { + builder.finish() + } +} + +impl From for Response { + fn from(mut builder: HttpResponseBuilder) -> Self { + builder.finish().into() + } +} + +impl Future for HttpResponseBuilder { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Ok(self.finish())) + } +} + +impl Responder for HttpResponseBuilder { + type Body = BoxBody; + + #[inline] + fn respond_to(mut self, _: &HttpRequest) -> HttpResponse { + self.finish() + } +} + +#[cfg(test)] +mod tests { + use actix_http::body; + + use super::*; + use crate::{ + http::{ + header::{self, HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + test::assert_body_eq, + }; + + #[test] + fn test_basic_builder() { + let resp = HttpResponse::Ok() + .insert_header(("X-TEST", "value")) + .finish(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_upgrade() { + let resp = HttpResponseBuilder::new(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); + } + + #[test] + fn test_force_close() { + let resp = HttpResponseBuilder::new(StatusCode::OK) + .force_close() + .finish(); + assert!(!resp.keep_alive()) + } + + #[test] + fn test_content_type() { + let resp = HttpResponseBuilder::new(StatusCode::OK) + .content_type("text/plain") + .body(Bytes::new()); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") + } + + #[actix_rt::test] + async fn test_json() { + let res = HttpResponse::Ok().json(vec!["v1", "v2", "v3"]); + let ct = res.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_body_eq!(res, br#"["v1","v2","v3"]"#); + + let res = HttpResponse::Ok().json(&["v1", "v2", "v3"]); + let ct = res.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_body_eq!(res, br#"["v1","v2","v3"]"#); + + // content type override + let res = HttpResponse::Ok() + .insert_header((CONTENT_TYPE, "text/json")) + .json(&vec!["v1", "v2", "v3"]); + let ct = res.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_body_eq!(res, br#"["v1","v2","v3"]"#); + } + + #[actix_rt::test] + async fn test_serde_json_in_body() { + let resp = HttpResponse::Ok().body( + serde_json::to_vec(&serde_json::json!({ "test-key": "test-value" })).unwrap(), + ); + + assert_eq!( + body::to_bytes(resp.into_body()).await.unwrap().as_ref(), + br#"{"test-key":"test-value"}"# + ); + } + + #[test] + fn response_builder_header_insert_kv() { + let mut res = HttpResponse::Ok(); + res.insert_header(("Content-Type", "application/octet-stream")); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_insert_typed() { + let mut res = HttpResponse::Ok(); + res.insert_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_append_kv() { + let mut res = HttpResponse::Ok(); + res.append_header(("Content-Type", "application/octet-stream")); + res.append_header(("Content-Type", "application/json")); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } + + #[test] + fn response_builder_header_append_typed() { + let mut res = HttpResponse::Ok(); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } +} diff --git a/src/response/customize_responder.rs b/src/response/customize_responder.rs new file mode 100644 index 000000000..8cb146dda --- /dev/null +++ b/src/response/customize_responder.rs @@ -0,0 +1,241 @@ +use actix_http::{ + body::EitherBody, error::HttpError, header::HeaderMap, header::TryIntoHeaderPair, + StatusCode, +}; + +use crate::{HttpRequest, HttpResponse, Responder}; + +/// Allows overriding status code and headers for a [`Responder`]. +/// +/// Created by the [`Responder::customize`] method. +pub struct CustomizeResponder { + inner: CustomizeResponderInner, + error: Option, +} + +struct CustomizeResponderInner { + responder: R, + status: Option, + override_headers: HeaderMap, + append_headers: HeaderMap, +} + +impl CustomizeResponder { + pub(crate) fn new(responder: R) -> Self { + CustomizeResponder { + inner: CustomizeResponderInner { + responder, + status: None, + override_headers: HeaderMap::new(), + append_headers: HeaderMap::new(), + }, + error: None, + } + } + + /// Override a status code for the Responder's response. + /// + /// # Examples + /// ``` + /// use actix_web::{Responder, http::StatusCode, test::TestRequest}; + /// + /// let responder = "Welcome!".customize().with_status(StatusCode::ACCEPTED); + /// + /// let request = TestRequest::default().to_http_request(); + /// let response = responder.respond_to(&request); + /// assert_eq!(response.status(), StatusCode::ACCEPTED); + /// ``` + pub fn with_status(mut self, status: StatusCode) -> Self { + if let Some(inner) = self.inner() { + inner.status = Some(status); + } + + self + } + + /// Insert (override) header in the final response. + /// + /// Overrides other headers with the same name. + /// See [`HeaderMap::insert`](crate::http::header::HeaderMap::insert). + /// + /// Headers added with this method will be inserted before those added + /// with [`append_header`](Self::append_header). As such, header(s) can be overridden with more + /// than one new header by first calling `insert_header` followed by `append_header`. + /// + /// # Examples + /// ``` + /// use actix_web::{Responder, test::TestRequest}; + /// + /// let responder = "Hello world!" + /// .customize() + /// .insert_header(("x-version", "1.2.3")); + /// + /// let request = TestRequest::default().to_http_request(); + /// let response = responder.respond_to(&request); + /// assert_eq!(response.headers().get("x-version").unwrap(), "1.2.3"); + /// ``` + pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self { + if let Some(inner) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => { + inner.override_headers.insert(key, value); + } + Err(err) => self.error = Some(err.into()), + }; + } + + self + } + + /// Append header to the final response. + /// + /// Unlike [`insert_header`](Self::insert_header), this will not override existing headers. + /// See [`HeaderMap::append`](crate::http::header::HeaderMap::append). + /// + /// Headers added here are appended _after_ additions/overrides from `insert_header`. + /// + /// # Examples + /// ``` + /// use actix_web::{Responder, test::TestRequest}; + /// + /// let responder = "Hello world!" + /// .customize() + /// .append_header(("x-version", "1.2.3")); + /// + /// let request = TestRequest::default().to_http_request(); + /// let response = responder.respond_to(&request); + /// assert_eq!(response.headers().get("x-version").unwrap(), "1.2.3"); + /// ``` + pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self { + if let Some(inner) = self.inner() { + match header.try_into_pair() { + Ok((key, value)) => { + inner.append_headers.append(key, value); + } + Err(err) => self.error = Some(err.into()), + }; + } + + self + } + + #[doc(hidden)] + #[deprecated(since = "4.0.0", note = "Renamed to `insert_header`.")] + pub fn with_header(self, header: impl TryIntoHeaderPair) -> Self + where + Self: Sized, + { + self.insert_header(header) + } + + fn inner(&mut self) -> Option<&mut CustomizeResponderInner> { + if self.error.is_some() { + None + } else { + Some(&mut self.inner) + } + } +} + +impl Responder for CustomizeResponder +where + T: Responder, +{ + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + if let Some(err) = self.error { + return HttpResponse::from_error(err).map_into_right_body(); + } + + let mut res = self.inner.responder.respond_to(req); + + if let Some(status) = self.inner.status { + *res.status_mut() = status; + } + + for (k, v) in self.inner.override_headers { + res.headers_mut().insert(k, v); + } + + for (k, v) in self.inner.append_headers { + res.headers_mut().append(k, v); + } + + res.map_into_left_body() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use actix_http::body::to_bytes; + + use super::*; + use crate::{ + http::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + test::TestRequest, + }; + + #[actix_rt::test] + async fn customize_responder() { + let req = TestRequest::default().to_http_request(); + let res = "test" + .to_string() + .customize() + .with_status(StatusCode::BAD_REQUEST) + .respond_to(&req); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = "test" + .to_string() + .customize() + .insert_header(("content-type", "json")) + .respond_to(&req); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("json") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + } + + #[actix_rt::test] + async fn tuple_responder_with_status_code() { + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::BAD_REQUEST).respond_to(&req); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let req = TestRequest::default().to_http_request(); + let res = ("test".to_string(), StatusCode::OK) + .customize() + .insert_header((CONTENT_TYPE, mime::APPLICATION_JSON)) + .respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/json") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + } +} diff --git a/actix-http/src/http_codes.rs b/src/response/http_codes.rs similarity index 91% rename from actix-http/src/http_codes.rs rename to src/response/http_codes.rs index 688a08be5..986735346 100644 --- a/actix-http/src/http_codes.rs +++ b/src/response/http_codes.rs @@ -1,21 +1,19 @@ //! Status code based HTTP response builders. -#![allow(non_upper_case_globals)] +use actix_http::StatusCode; -use http::StatusCode; - -use crate::response::{Response, ResponseBuilder}; +use crate::{HttpResponse, HttpResponseBuilder}; macro_rules! static_resp { ($name:ident, $status:expr) => { #[allow(non_snake_case, missing_docs)] - pub fn $name() -> ResponseBuilder { - ResponseBuilder::new($status) + pub fn $name() -> HttpResponseBuilder { + HttpResponseBuilder::new($status) } }; } -impl Response { +impl HttpResponse { static_resp!(Continue, StatusCode::CONTINUE); static_resp!(SwitchingProtocols, StatusCode::SWITCHING_PROTOCOLS); static_resp!(Processing, StatusCode::PROCESSING); @@ -89,13 +87,12 @@ impl Response { #[cfg(test)] mod tests { - use crate::body::Body; - use crate::response::Response; - use http::StatusCode; + use crate::http::StatusCode; + use crate::HttpResponse; #[test] fn test_build() { - let resp = Response::Ok().body(Body::Empty); + let resp = HttpResponse::Ok().finish(); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/response/mod.rs b/src/response/mod.rs new file mode 100644 index 000000000..977147104 --- /dev/null +++ b/src/response/mod.rs @@ -0,0 +1,14 @@ +mod builder; +mod customize_responder; +mod http_codes; +mod responder; +#[allow(clippy::module_inception)] +mod response; + +pub use self::builder::HttpResponseBuilder; +pub use self::customize_responder::CustomizeResponder; +pub use self::responder::Responder; +pub use self::response::HttpResponse; + +#[cfg(feature = "cookies")] +pub use self::response::CookieIter; diff --git a/src/response/responder.rs b/src/response/responder.rs new file mode 100644 index 000000000..d1b9e49e0 --- /dev/null +++ b/src/response/responder.rs @@ -0,0 +1,325 @@ +use std::borrow::Cow; + +use actix_http::{ + body::{BoxBody, EitherBody, MessageBody}, + header::TryIntoHeaderPair, + StatusCode, +}; +use bytes::{Bytes, BytesMut}; + +use crate::{Error, HttpRequest, HttpResponse}; + +use super::CustomizeResponder; + +/// Trait implemented by types that can be converted to an HTTP response. +/// +/// Any types that implement this trait can be used in the return type of a handler. +// # TODO: more about implementation notes and foreign impls +pub trait Responder { + type Body: MessageBody + 'static; + + /// Convert self to `HttpResponse`. + fn respond_to(self, req: &HttpRequest) -> HttpResponse; + + /// Wraps responder to allow alteration of its response. + /// + /// See [`CustomizeResponder`] docs for its capabilities. + /// + /// # Examples + /// ``` + /// use actix_web::{Responder, http::StatusCode, test::TestRequest}; + /// + /// let responder = "Hello world!" + /// .customize() + /// .with_status(StatusCode::BAD_REQUEST) + /// .insert_header(("x-hello", "world")); + /// + /// let request = TestRequest::default().to_http_request(); + /// let response = responder.respond_to(&request); + /// assert_eq!(response.status(), StatusCode::BAD_REQUEST); + /// assert_eq!(response.headers().get("x-hello").unwrap(), "world"); + /// ``` + #[inline] + fn customize(self) -> CustomizeResponder + where + Self: Sized, + { + CustomizeResponder::new(self) + } + + #[doc(hidden)] + #[deprecated(since = "4.0.0", note = "Prefer `.customize().insert_header(header)`.")] + fn with_header(self, header: impl TryIntoHeaderPair) -> CustomizeResponder + where + Self: Sized, + { + self.customize().insert_header(header) + } +} + +impl Responder for actix_http::Response { + type Body = BoxBody; + + #[inline] + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::from(self) + } +} + +impl Responder for actix_http::ResponseBuilder { + type Body = BoxBody; + + #[inline] + fn respond_to(mut self, req: &HttpRequest) -> HttpResponse { + self.finish().map_into_boxed_body().respond_to(req) + } +} + +impl Responder for Option +where + T: Responder, +{ + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + match self { + Some(val) => val.respond_to(req).map_into_left_body(), + None => HttpResponse::new(StatusCode::NOT_FOUND).map_into_right_body(), + } + } +} + +impl Responder for Result +where + T: Responder, + E: Into, +{ + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + match self { + Ok(val) => val.respond_to(req).map_into_left_body(), + Err(err) => HttpResponse::from_error(err.into()).map_into_right_body(), + } + } +} + +impl Responder for (T, StatusCode) { + type Body = T::Body; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { + let mut res = self.0.respond_to(req); + *res.status_mut() = self.1; + res + } +} + +macro_rules! impl_responder_by_forward_into_base_response { + ($res:ty, $body:ty) => { + impl Responder for $res { + type Body = $body; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + let res: actix_http::Response<_> = self.into(); + res.into() + } + } + }; + + ($res:ty) => { + impl_responder_by_forward_into_base_response!($res, $res); + }; +} + +impl_responder_by_forward_into_base_response!(&'static [u8]); +impl_responder_by_forward_into_base_response!(Bytes); +impl_responder_by_forward_into_base_response!(BytesMut); + +impl_responder_by_forward_into_base_response!(&'static str); +impl_responder_by_forward_into_base_response!(String); + +macro_rules! impl_into_string_responder { + ($res:ty) => { + impl Responder for $res { + type Body = String; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + let string: String = self.into(); + let res: actix_http::Response<_> = string.into(); + res.into() + } + } + }; +} + +impl_into_string_responder!(&'_ String); +impl_into_string_responder!(Cow<'_, str>); + +#[cfg(test)] +pub(crate) mod tests { + use actix_service::Service; + use bytes::{Bytes, BytesMut}; + + use actix_http::body::to_bytes; + + use super::*; + use crate::{ + error, + http::{ + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + test::{assert_body_eq, init_service, TestRequest}, + web, App, + }; + + #[actix_rt::test] + async fn test_option_responder() { + let srv = init_service( + App::new() + .service(web::resource("/none").to(|| async { Option::<&'static str>::None })) + .service(web::resource("/some").to(|| async { Some("some") })), + ) + .await; + + let req = TestRequest::with_uri("/none").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/some").to_request(); + let resp = srv.call(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_body_eq!(resp, b"some"); + } + + #[actix_rt::test] + async fn test_responder() { + let req = TestRequest::default().to_http_request(); + + let res = "test".respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = b"test".respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = "test".to_string().respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = (&"test".to_string()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let s = String::from("test"); + let res = Cow::Borrowed(s.as_str()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Cow::<'_, str>::Owned(s).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Cow::Borrowed("test").respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Bytes::from_static(b"test").respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = BytesMut::from(b"test".as_ref()).respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!( + to_bytes(res.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + // InternalError + let res = error::InternalError::new("err", StatusCode::BAD_REQUEST).respond_to(&req); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[actix_rt::test] + async fn test_result_responder() { + let req = TestRequest::default().to_http_request(); + + // Result + let resp = Ok::<_, Error>("test".to_string()).respond_to(&req); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!( + to_bytes(resp.into_body()).await.unwrap(), + Bytes::from_static(b"test"), + ); + + let res = Err::(error::InternalError::new("err", StatusCode::BAD_REQUEST)) + .respond_to(&req); + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/src/response/response.rs b/src/response/response.rs new file mode 100644 index 000000000..f24a75b19 --- /dev/null +++ b/src/response/response.rs @@ -0,0 +1,367 @@ +use std::{ + cell::{Ref, RefMut}, + fmt, + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{ + body::{BoxBody, EitherBody, MessageBody}, + header::HeaderMap, + Extensions, Response, ResponseHead, StatusCode, +}; + +#[cfg(feature = "cookies")] +use { + actix_http::{ + error::HttpError, + header::{self, HeaderValue}, + }, + cookie::Cookie, +}; + +use crate::{error::Error, HttpRequest, HttpResponseBuilder, Responder}; + +/// An outgoing response. +pub struct HttpResponse { + res: Response, + pub(crate) error: Option, +} + +impl HttpResponse { + /// Constructs a response. + #[inline] + pub fn new(status: StatusCode) -> Self { + Self { + res: Response::new(status), + error: None, + } + } + + /// Constructs a response builder with specific HTTP status. + #[inline] + pub fn build(status: StatusCode) -> HttpResponseBuilder { + HttpResponseBuilder::new(status) + } + + /// Create an error response. + #[inline] + pub fn from_error(error: impl Into) -> Self { + let error = error.into(); + let mut response = error.as_response_error().error_response(); + response.error = Some(error); + response + } +} + +impl HttpResponse { + /// Constructs a response with body + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Self { + Self { + res: Response::with_body(status, body), + error: None, + } + } + + /// Returns a reference to response head. + #[inline] + pub fn head(&self) -> &ResponseHead { + self.res.head() + } + + /// Returns a mutable reference to response head. + #[inline] + pub fn head_mut(&mut self) -> &mut ResponseHead { + self.res.head_mut() + } + + /// The source `error` for this response + #[inline] + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.res.status() + } + + /// Set the `StatusCode` for this response + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + self.res.status_mut() + } + + /// Get the headers from the response + #[inline] + pub fn headers(&self) -> &HeaderMap { + self.res.headers() + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + self.res.headers_mut() + } + + /// Get an iterator for the cookies set by this response. + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> CookieIter<'_> { + CookieIter { + iter: self.headers().get_all(header::SET_COOKIE), + } + } + + /// Add a cookie to this response + #[cfg(feature = "cookies")] + pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { + HeaderValue::from_str(&cookie.to_string()) + .map(|c| { + self.headers_mut().append(header::SET_COOKIE, c); + }) + .map_err(|e| e.into()) + } + + /// Remove all cookies with the given name from this response. Returns + /// the number of cookies removed. + #[cfg(feature = "cookies")] + pub fn del_cookie(&mut self, name: &str) -> usize { + let headers = self.headers_mut(); + + let vals: Vec = headers + .get_all(header::SET_COOKIE) + .map(|v| v.to_owned()) + .collect(); + + headers.remove(header::SET_COOKIE); + + let mut count: usize = 0; + for v in vals { + if let Ok(s) = v.to_str() { + if let Ok(c) = Cookie::parse_encoded(s) { + if c.name() == name { + count += 1; + continue; + } + } + } + + // put set-cookie header head back if it does not validate + headers.append(header::SET_COOKIE, v); + } + + count + } + + /// Connection upgrade status + #[inline] + pub fn upgrade(&self) -> bool { + self.res.upgrade() + } + + /// Keep-alive status for this connection + pub fn keep_alive(&self) -> bool { + self.res.keep_alive() + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + self.res.extensions() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + self.res.extensions_mut() + } + + /// Get body of this response + #[inline] + pub fn body(&self) -> &B { + self.res.body() + } + + /// Set a body + pub fn set_body(self, body: B2) -> HttpResponse { + HttpResponse { + res: self.res.set_body(body), + error: None, + // error: self.error, ?? + } + } + + /// Split response and body + pub fn into_parts(self) -> (HttpResponse<()>, B) { + let (head, body) = self.res.into_parts(); + + ( + HttpResponse { + res: head, + error: None, + }, + body, + ) + } + + /// Drop request's body + pub fn drop_body(self) -> HttpResponse<()> { + HttpResponse { + res: self.res.drop_body(), + error: None, + } + } + + /// Set a body and return previous body value + pub fn map_body(self, f: F) -> HttpResponse + where + F: FnOnce(&mut ResponseHead, B) -> B2, + { + HttpResponse { + res: self.res.map_body(f), + error: self.error, + } + } + + // TODO: docs for the body map methods below + + #[inline] + pub fn map_into_left_body(self) -> HttpResponse> { + self.map_body(|_, body| EitherBody::left(body)) + } + + #[inline] + pub fn map_into_right_body(self) -> HttpResponse> { + self.map_body(|_, body| EitherBody::right(body)) + } + + #[inline] + pub fn map_into_boxed_body(self) -> HttpResponse + where + B: MessageBody + 'static, + { + self.map_body(|_, body| body.boxed()) + } + + /// Extract response body + pub fn into_body(self) -> B { + self.res.into_body() + } +} + +impl fmt::Debug for HttpResponse +where + B: MessageBody, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpResponse") + .field("error", &self.error) + .field("res", &self.res) + .finish() + } +} + +impl From> for HttpResponse { + fn from(res: Response) -> Self { + HttpResponse { res, error: None } + } +} + +impl From for HttpResponse { + fn from(err: Error) -> Self { + HttpResponse::from_error(err) + } +} + +impl From> for Response { + fn from(res: HttpResponse) -> Self { + // this impl will always be called as part of dispatcher + + // TODO: expose cause somewhere? + // if let Some(err) = res.error { + // return Response::from_error(err); + // } + + res.res + } +} + +// Future is only implemented for BoxBody payload type because it's the most useful for making +// simple handlers without async blocks. Making it generic over all MessageBody types requires a +// future impl on Response which would cause it's body field to be, undesirably, Option. +// +// This impl is not particularly efficient due to the Response construction and should probably +// not be invoked if performance is important. Prefer an async fn/block in such cases. +impl Future for HttpResponse { + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + if let Some(err) = self.error.take() { + return Poll::Ready(Err(err)); + } + + Poll::Ready(Ok(mem::replace( + &mut self.res, + Response::new(StatusCode::default()), + ))) + } +} + +impl Responder for HttpResponse +where + B: MessageBody + 'static, +{ + type Body = B; + + #[inline] + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + self + } +} + +#[cfg(feature = "cookies")] +pub struct CookieIter<'a> { + iter: std::slice::Iter<'a, HeaderValue>, +} + +#[cfg(feature = "cookies")] +impl<'a> Iterator for CookieIter<'a> { + type Item = Cookie<'a>; + + #[inline] + fn next(&mut self) -> Option> { + for v in self.iter.by_ref() { + if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { + return Some(c); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + use crate::http::header::{HeaderValue, COOKIE}; + + assert_impl_all!(HttpResponse: Responder); + assert_impl_all!(HttpResponse: Responder); + assert_impl_all!(HttpResponse<&'static str>: Responder); + assert_impl_all!(HttpResponse: Responder); + + #[test] + fn test_debug() { + let resp = HttpResponse::Ok() + .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) + .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) + .finish(); + let dbg = format!("{:?}", resp); + assert!(dbg.contains("HttpResponse")); + } +} diff --git a/src/rmap.rs b/src/rmap.rs index 3c8805d57..432eaf83c 100644 --- a/src/rmap.rs +++ b/src/rmap.rs @@ -1,53 +1,86 @@ -use std::cell::RefCell; -use std::rc::{Rc, Weak}; +use std::{ + borrow::Cow, + cell::RefCell, + rc::{Rc, Weak}, +}; use actix_router::ResourceDef; use ahash::AHashMap; use url::Url; -use crate::error::UrlGenerationError; -use crate::request::HttpRequest; +use crate::{error::UrlGenerationError, request::HttpRequest}; #[derive(Clone, Debug)] pub struct ResourceMap { - root: ResourceDef, + pattern: ResourceDef, + + /// Named resources within the tree or, for external resources, + /// it points to isolated nodes outside the tree. + named: AHashMap>, + parent: RefCell>, - named: AHashMap, - patterns: Vec<(ResourceDef, Option>)>, + + /// Must be `None` for "edge" nodes. + nodes: Option>>, } impl ResourceMap { + /// Creates a _container_ node in the `ResourceMap` tree. pub fn new(root: ResourceDef) -> Self { ResourceMap { - root, - parent: RefCell::new(Weak::new()), + pattern: root, named: AHashMap::default(), - patterns: Vec::new(), + parent: RefCell::new(Weak::new()), + nodes: Some(Vec::new()), } } + /// Adds a (possibly nested) resource. + /// + /// To add a non-prefix pattern, `nested` must be `None`. + /// To add external resource, supply a pattern without a leading `/`. + /// The root pattern of `nested`, if present, should match `pattern`. pub fn add(&mut self, pattern: &mut ResourceDef, nested: Option>) { - pattern.set_id(self.patterns.len() as u16); - self.patterns.push((pattern.clone(), nested)); - if !pattern.name().is_empty() { - self.named - .insert(pattern.name().to_string(), pattern.clone()); + pattern.set_id(self.nodes.as_ref().unwrap().len() as u16); + + if let Some(new_node) = nested { + assert_eq!(&new_node.pattern, pattern, "`patern` and `nested` mismatch"); + self.named.extend(new_node.named.clone().into_iter()); + self.nodes.as_mut().unwrap().push(new_node); + } else { + let new_node = Rc::new(ResourceMap { + pattern: pattern.clone(), + named: AHashMap::default(), + parent: RefCell::new(Weak::new()), + nodes: None, + }); + + if let Some(name) = pattern.name() { + self.named.insert(name.to_owned(), Rc::clone(&new_node)); + } + + let is_external = match pattern.pattern() { + Some(p) => !p.is_empty() && !p.starts_with('/'), + None => false, + }; + + // Don't add external resources to the tree + if !is_external { + self.nodes.as_mut().unwrap().push(new_node); + } } } - pub(crate) fn finish(&self, current: Rc) { - for (_, nested) in &self.patterns { - if let Some(ref nested) = nested { - *nested.parent.borrow_mut() = Rc::downgrade(¤t); - nested.finish(nested.clone()); - } + pub(crate) fn finish(self: &Rc) { + for node in self.nodes.iter().flatten() { + node.parent.replace(Rc::downgrade(self)); + ResourceMap::finish(node); } } /// Generate url for named resource /// - /// Check [`HttpRequest::url_for()`](../struct.HttpRequest.html#method. - /// url_for) for detailed information. + /// Check [`HttpRequest::url_for`] for detailed information. pub fn url_for( &self, req: &HttpRequest, @@ -58,192 +91,108 @@ impl ResourceMap { U: IntoIterator, I: AsRef, { - let mut path = String::new(); let mut elements = elements.into_iter(); - if self.patterns_for(name, &mut path, &mut elements)?.is_some() { - if path.starts_with('/') { - let conn = req.connection_info(); - Ok(Url::parse(&format!( - "{}://{}{}", - conn.scheme(), - conn.host(), - path - ))?) - } else { - Ok(Url::parse(&path)?) - } + let path = self + .named + .get(name) + .ok_or(UrlGenerationError::ResourceNotFound)? + .root_rmap_fn(String::with_capacity(24), |mut acc, node| { + node.pattern + .resource_path_from_iter(&mut acc, &mut elements) + .then(|| acc) + }) + .ok_or(UrlGenerationError::NotEnoughElements)?; + + let (base, path): (Cow<'_, _>, _) = if path.starts_with('/') { + // build full URL from connection info parts and resource path + let conn = req.connection_info(); + let base = format!("{}://{}", conn.scheme(), conn.host()); + (Cow::Owned(base), path.as_str()) } else { - Err(UrlGenerationError::ResourceNotFound) - } + // external resource; third slash would be the root slash in the path + let third_slash_index = path + .char_indices() + .filter_map(|(i, c)| (c == '/').then(|| i)) + .nth(2) + .unwrap_or_else(|| path.len()); + + ( + Cow::Borrowed(&path[..third_slash_index]), + &path[third_slash_index..], + ) + }; + + let mut url = Url::parse(&base)?; + url.set_path(path); + Ok(url) } pub fn has_resource(&self, path: &str) -> bool { - let path = if path.is_empty() { "/" } else { path }; - - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(plen) = pattern.is_prefix_match(path) { - return rmap.has_resource(&path[plen..]); - } - } else if pattern.is_match(path) || pattern.pattern() == "" && path == "/" { - return true; - } - } - false + self.find_matching_node(path).is_some() } /// Returns the name of the route that matches the given path or None if no full match - /// is possible. + /// is possible or the matching resource is not named. pub fn match_name(&self, path: &str) -> Option<&str> { - let path = if path.is_empty() { "/" } else { path }; - - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(plen) = pattern.is_prefix_match(path) { - return rmap.match_name(&path[plen..]); - } - } else if pattern.is_match(path) { - return match pattern.name() { - "" => None, - s => Some(s), - }; - } - } - - None + self.find_matching_node(path)?.pattern.name() } /// Returns the full resource pattern matched against a path or None if no full match /// is possible. pub fn match_pattern(&self, path: &str) -> Option { - let path = if path.is_empty() { "/" } else { path }; - - // ensure a full match exists - if !self.has_resource(path) { - return None; - } - - Some(self.traverse_resource_pattern(path)) + self.find_matching_node(path)?.root_rmap_fn( + String::with_capacity(24), + |mut acc, node| { + acc.push_str(node.pattern.pattern()?); + Some(acc) + }, + ) } - /// Takes remaining path and tries to match it up against a resource definition within the - /// current resource map recursively, returning a concatenation of all resource prefixes and - /// patterns matched in the tree. - /// - /// Should only be used after checking the resource exists in the map so that partial match - /// patterns are not returned. - fn traverse_resource_pattern(&self, remaining: &str) -> String { - for (pattern, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if let Some(prefix_len) = pattern.is_prefix_match(remaining) { - let prefix = pattern.pattern().to_owned(); - - return [ - prefix, - rmap.traverse_resource_pattern(&remaining[prefix_len..]), - ] - .concat(); - } - } else if pattern.is_match(remaining) { - return pattern.pattern().to_owned(); - } - } - - String::new() + fn find_matching_node(&self, path: &str) -> Option<&ResourceMap> { + self._find_matching_node(path).flatten() } - fn patterns_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> + /// Returns `None` if root pattern doesn't match; + /// `Some(None)` if root pattern matches but there is no matching child pattern. + /// Don't search sideways when `Some(none)` is returned. + fn _find_matching_node(&self, path: &str) -> Option> { + let matched_len = self.pattern.find_match(path)?; + let path = &path[matched_len..]; + + Some(match &self.nodes { + // find first sub-node to match remaining path + Some(nodes) => nodes + .iter() + .filter_map(|node| node._find_matching_node(path)) + .next() + .flatten(), + + // only terminate at edge nodes + None => Some(self), + }) + } + + /// Find `self`'s highest ancestor and then run `F`, providing `B`, in that rmap context. + fn root_rmap_fn(&self, init: B, mut f: F) -> Option where - U: Iterator, - I: AsRef, + F: FnMut(B, &ResourceMap) -> Option, { - if self.pattern_for(name, path, elements)?.is_some() { - Ok(Some(())) - } else { - self.parent_pattern_for(name, path, elements) - } + self._root_rmap_fn(init, &mut f) } - fn pattern_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> + /// Run `F`, providing `B`, if `self` is top-level resource map, else recurse to parent map. + fn _root_rmap_fn(&self, init: B, f: &mut F) -> Option where - U: Iterator, - I: AsRef, + F: FnMut(B, &ResourceMap) -> Option, { - if let Some(pattern) = self.named.get(name) { - if pattern.pattern().starts_with('/') { - self.fill_root(path, elements)?; - } - if pattern.resource_path(path, elements) { - Ok(Some(())) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } else { - for (_, rmap) in &self.patterns { - if let Some(ref rmap) = rmap { - if rmap.pattern_for(name, path, elements)?.is_some() { - return Ok(Some(())); - } - } - } - Ok(None) - } - } + let data = match self.parent.borrow().upgrade() { + Some(ref parent) => parent._root_rmap_fn(init, f)?, + None => init, + }; - fn fill_root( - &self, - path: &mut String, - elements: &mut U, - ) -> Result<(), UrlGenerationError> - where - U: Iterator, - I: AsRef, - { - if let Some(ref parent) = self.parent.borrow().upgrade() { - parent.fill_root(path, elements)?; - } - if self.root.resource_path(path, elements) { - Ok(()) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } - - fn parent_pattern_for( - &self, - name: &str, - path: &mut String, - elements: &mut U, - ) -> Result, UrlGenerationError> - where - U: Iterator, - I: AsRef, - { - if let Some(ref parent) = self.parent.borrow().upgrade() { - if let Some(pattern) = parent.named.get(name) { - self.fill_root(path, elements)?; - if pattern.resource_path(path, elements) { - Ok(Some(())) - } else { - Err(UrlGenerationError::NotEnoughElements) - } - } else { - parent.parent_pattern_for(name, path, elements) - } - } else { - Ok(None) - } + f(data, self) } } @@ -255,7 +204,7 @@ mod tests { fn extract_matched_pattern() { let mut root = ResourceMap::new(ResourceDef::root_prefix("")); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); user_map.add(&mut ResourceDef::new("/"), None); user_map.add(&mut ResourceDef::new("/profile"), None); user_map.add(&mut ResourceDef::new("/article/{id}"), None); @@ -271,9 +220,10 @@ mod tests { &mut ResourceDef::root_prefix("/user/{id}"), Some(Rc::new(user_map)), ); + root.add(&mut ResourceDef::new("/info"), None); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // sanity check resource map setup @@ -284,7 +234,7 @@ mod tests { assert!(root.has_resource("/v2")); assert!(!root.has_resource("/v33")); - assert!(root.has_resource("/user/22")); + assert!(!root.has_resource("/user/22")); assert!(root.has_resource("/user/22/")); assert!(root.has_resource("/user/22/profile")); @@ -329,15 +279,15 @@ mod tests { let mut root = ResourceMap::new(ResourceDef::root_prefix("")); let mut rdef = ResourceDef::new("/info"); - *rdef.name_mut() = "root_info".to_owned(); + rdef.set_name("root_info"); root.add(&mut rdef, None); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); let mut rdef = ResourceDef::new("/"); user_map.add(&mut rdef, None); let mut rdef = ResourceDef::new("/post/{post_id}"); - *rdef.name_mut() = "user_post".to_owned(); + rdef.set_name("user_post"); user_map.add(&mut rdef, None); root.add( @@ -346,14 +296,14 @@ mod tests { ); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // sanity check resource map setup assert!(root.has_resource("/info")); assert!(!root.has_resource("/bar")); - assert!(root.has_resource("/user/22")); + assert!(!root.has_resource("/user/22")); assert!(root.has_resource("/user/22/")); assert!(root.has_resource("/user/22/post/55")); @@ -373,7 +323,7 @@ mod tests { // ref: https://github.com/actix/actix-web/issues/1582 let mut root = ResourceMap::new(ResourceDef::root_prefix("")); - let mut user_map = ResourceMap::new(ResourceDef::root_prefix("")); + let mut user_map = ResourceMap::new(ResourceDef::root_prefix("/user/{id}")); user_map.add(&mut ResourceDef::new("/"), None); user_map.add(&mut ResourceDef::new("/profile"), None); user_map.add(&mut ResourceDef::new("/article/{id}"), None); @@ -389,20 +339,155 @@ mod tests { ); let root = Rc::new(root); - root.finish(Rc::clone(&root)); + ResourceMap::finish(&root); // check root has no parent assert!(root.parent.borrow().upgrade().is_none()); // check child has parent reference - assert!(root.patterns[0].1.is_some()); + assert!(root.nodes.as_ref().unwrap()[0] + .parent + .borrow() + .upgrade() + .is_some()); // check child's parent root id matches root's root id - assert_eq!( - root.patterns[0].1.as_ref().unwrap().root.id(), - root.root.id() - ); + assert!(Rc::ptr_eq( + &root.nodes.as_ref().unwrap()[0] + .parent + .borrow() + .upgrade() + .unwrap(), + &root + )); let output = format!("{:?}", root); assert!(output.starts_with("ResourceMap {")); assert!(output.ends_with(" }")); } + + #[test] + fn short_circuit() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut user_root = ResourceDef::prefix("/user"); + let mut user_map = ResourceMap::new(user_root.clone()); + user_map.add(&mut ResourceDef::new("/u1"), None); + user_map.add(&mut ResourceDef::new("/u2"), None); + + root.add(&mut ResourceDef::new("/user/u3"), None); + root.add(&mut user_root, Some(Rc::new(user_map))); + root.add(&mut ResourceDef::new("/user/u4"), None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(rmap.has_resource("/user/u1")); + assert!(rmap.has_resource("/user/u2")); + assert!(rmap.has_resource("/user/u3")); + assert!(!rmap.has_resource("/user/u4")); + } + + #[test] + fn url_for() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut user_scope_rdef = ResourceDef::prefix("/user"); + let mut user_scope_map = ResourceMap::new(user_scope_rdef.clone()); + + let mut user_rdef = ResourceDef::new("/{user_id}"); + let mut user_map = ResourceMap::new(user_rdef.clone()); + + let mut post_rdef = ResourceDef::new("/post/{sub_id}"); + post_rdef.set_name("post"); + + user_map.add(&mut post_rdef, None); + user_scope_map.add(&mut user_rdef, Some(Rc::new(user_map))); + root.add(&mut user_scope_rdef, Some(Rc::new(user_scope_map))); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + let url = rmap + .url_for(&req, "post", &["u123", "foobar"]) + .unwrap() + .to_string(); + assert_eq!(url, "http://localhost:8888/user/u123/post/foobar"); + + assert!(rmap.url_for(&req, "missing", &["u123"]).is_err()); + } + + #[test] + fn url_for_parser() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef_1 = ResourceDef::new("/{var}"); + rdef_1.set_name("internal"); + + let mut rdef_2 = ResourceDef::new("http://host.dom/{var}"); + rdef_2.set_name("external.1"); + + let mut rdef_3 = ResourceDef::new("{var}"); + rdef_3.set_name("external.2"); + + root.add(&mut rdef_1, None); + root.add(&mut rdef_2, None); + root.add(&mut rdef_3, None); + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + const INPUT: &[&str] = &["a/../quick brown%20fox/%nan?query#frag"]; + const OUTPUT: &str = "/quick%20brown%20fox/%nan%3Fquery%23frag"; + + let url = rmap.url_for(&req, "internal", INPUT).unwrap(); + assert_eq!(url.path(), OUTPUT); + + let url = rmap.url_for(&req, "external.1", INPUT).unwrap(); + assert_eq!(url.path(), OUTPUT); + + assert!(rmap.url_for(&req, "external.2", INPUT).is_err()); + assert!(rmap.url_for(&req, "external.2", &[""]).is_err()); + } + + #[test] + fn external_resource_with_no_name() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef = ResourceDef::new("https://duck.com/{query}"); + root.add(&mut rdef, None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(!rmap.has_resource("https://duck.com/abc")); + } + + #[test] + fn external_resource_with_name() { + let mut root = ResourceMap::new(ResourceDef::prefix("")); + + let mut rdef = ResourceDef::new("https://duck.com/{query}"); + rdef.set_name("duck"); + root.add(&mut rdef, None); + + let rmap = Rc::new(root); + ResourceMap::finish(&rmap); + + assert!(!rmap.has_resource("https://duck.com/abc")); + + let mut req = crate::test::TestRequest::default(); + req.set_server_hostname("localhost:8888"); + let req = req.to_http_request(); + + assert_eq!( + rmap.url_for(&req, "duck", &["abcd"]).unwrap().to_string(), + "https://duck.com/abcd" + ); + } } diff --git a/src/route.rs b/src/route.rs index b6b2482cd..0410b99dd 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,48 +1,25 @@ -#![allow(clippy::rc_buffer)] // inner value is mutated before being shared (`Rc::get_mut`) +use std::{mem, rc::Rc}; -use std::future::Future; -use std::pin::Pin; -use std::rc::Rc; -use std::task::{Context, Poll}; +use actix_http::Method; +use actix_service::{ + boxed::{self, BoxService}, + fn_service, Service, ServiceFactory, ServiceFactoryExt, +}; +use futures_core::future::LocalBoxFuture; -use actix_http::{http::Method, Error}; -use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ready, FutureExt, LocalBoxFuture}; +use crate::{ + guard::{self, Guard}, + handler::{handler_service, Handler}, + service::{BoxedHttpServiceFactory, ServiceRequest, ServiceResponse}, + Error, FromRequest, HttpResponse, Responder, +}; -use crate::extract::FromRequest; -use crate::guard::{self, Guard}; -use crate::handler::{Handler, HandlerService}; -use crate::responder::Responder; -use crate::service::{ServiceRequest, ServiceResponse}; -use crate::HttpResponse; - -type BoxedRouteService = Box< - dyn Service< - ServiceRequest, - Response = ServiceResponse, - Error = Error, - Future = LocalBoxFuture<'static, Result>, - >, ->; - -type BoxedRouteNewService = Box< - dyn ServiceFactory< - ServiceRequest, - Config = (), - Response = ServiceResponse, - Error = Error, - InitError = (), - Service = BoxedRouteService, - Future = LocalBoxFuture<'static, Result>, - >, ->; - -/// Resource route definition +/// A request handler with [guards](guard). /// -/// Route uses builder-like pattern for configuration. -/// If handler is not explicitly set, default *404 Not Found* handler is used. +/// Route uses a builder-like pattern for configuration. If handler is not set, a `404 Not Found` +/// handler is used. pub struct Route { - service: BoxedRouteNewService, + service: BoxedHttpServiceFactory, guards: Rc>>, } @@ -51,64 +28,49 @@ impl Route { #[allow(clippy::new_without_default)] pub fn new() -> Route { Route { - service: Box::new(RouteNewService::new(HandlerService::new(|| { - ready(HttpResponse::NotFound()) - }))), + service: boxed::factory(fn_service(|req: ServiceRequest| async { + Ok(req.into_response(HttpResponse::NotFound())) + })), guards: Rc::new(Vec::new()), } } pub(crate) fn take_guards(&mut self) -> Vec> { - std::mem::take(Rc::get_mut(&mut self.guards).unwrap()) + mem::take(Rc::get_mut(&mut self.guards).unwrap()) } } impl ServiceFactory for Route { - type Config = (); type Response = ServiceResponse; type Error = Error; - type InitError = (); + type Config = (); type Service = RouteService; - type Future = CreateRouteService; + type InitError = (); + type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { - CreateRouteService { - fut: self.service.new_service(()), - guards: self.guards.clone(), - } - } -} + let fut = self.service.new_service(()); + let guards = self.guards.clone(); -pub struct CreateRouteService { - fut: LocalBoxFuture<'static, Result>, - guards: Rc>>, -} - -impl Future for CreateRouteService { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - match this.fut.as_mut().poll(cx)? { - Poll::Ready(service) => Poll::Ready(Ok(RouteService { - service, - guards: this.guards.clone(), - })), - Poll::Pending => Poll::Pending, - } + Box::pin(async move { + let service = fut.await?; + Ok(RouteService { service, guards }) + }) } } pub struct RouteService { - service: BoxedRouteService, + service: BoxService, guards: Rc>>, } impl RouteService { + // TODO: does this need to take &mut ? pub fn check(&self, req: &mut ServiceRequest) -> bool { - for f in self.guards.iter() { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in self.guards.iter() { + if !guard.check(&guard_ctx) { return false; } } @@ -121,9 +83,7 @@ impl Service for RouteService { type Error = Error; type Future = LocalBoxFuture<'static, Result>; - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } + actix_service::forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { self.service.call(req) @@ -133,7 +93,8 @@ impl Service for RouteService { impl Route { /// Add method guard to the route. /// - /// ```rust + /// # Examples + /// ``` /// # use actix_web::*; /// # fn main() { /// App::new().service(web::resource("/path").route( @@ -153,7 +114,8 @@ impl Route { /// Add guard to the route. /// - /// ```rust + /// # Examples + /// ``` /// # use actix_web::*; /// # fn main() { /// App::new().service(web::resource("/path").route( @@ -171,9 +133,10 @@ impl Route { /// Set handler function, use request extractors for parameters. /// - /// ```rust + /// # Examples + /// ``` /// use actix_web::{web, http, App}; - /// use serde_derive::Deserialize; + /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct Info { @@ -185,19 +148,16 @@ impl Route { /// format!("Welcome {}!", info.username) /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to(index)) // <- register handler - /// ); - /// } + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) // <- register handler + /// ); /// ``` /// /// It is possible to use multiple extractors for one handler function. - /// - /// ```rust + /// ``` /// # use std::collections::HashMap; - /// # use serde_derive::Deserialize; + /// # use serde::Deserialize; /// use actix_web::{web, App}; /// /// #[derive(Deserialize)] @@ -206,107 +166,89 @@ impl Route { /// } /// /// /// extract path info using serde - /// async fn index(path: web::Path, query: web::Query>, body: web::Json) -> String { + /// async fn index( + /// path: web::Path, + /// query: web::Query>, + /// body: web::Json + /// ) -> String { /// format!("Welcome {}!", path.username) /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::resource("/{username}/index.html") // <- define path parameters - /// .route(web::get().to(index)) - /// ); - /// } + /// let app = App::new().service( + /// web::resource("/{username}/index.html") // <- define path parameters + /// .route(web::get().to(index)) + /// ); /// ``` - pub fn to(mut self, handler: F) -> Self + pub fn to(mut self, handler: F) -> Self where - F: Handler, - T: FromRequest + 'static, - R: Future + 'static, - R::Output: Responder + 'static, + F: Handler, + Args: FromRequest + 'static, + F::Output: Responder + 'static, { - self.service = Box::new(RouteNewService::new(HandlerService::new(handler))); + self.service = handler_service(handler); self } -} -struct RouteNewService -where - T: ServiceFactory, -{ - service: T, -} - -impl RouteNewService -where - T: ServiceFactory, - T::Future: 'static, - T::Service: 'static, - >::Future: 'static, -{ - pub fn new(service: T) -> Self { - RouteNewService { service } - } -} - -impl ServiceFactory for RouteNewService -where - T: ServiceFactory, - T::Future: 'static, - T::Service: 'static, - >::Future: 'static, -{ - type Response = ServiceResponse; - type Error = Error; - type Config = (); - type Service = BoxedRouteService; - type InitError = (); - type Future = LocalBoxFuture<'static, Result>; - - fn new_service(&self, _: ()) -> Self::Future { - self.service - .new_service(()) - .map(|result| match result { - Ok(service) => { - let service = Box::new(RouteServiceWrapper { service }) as _; - Ok(service) - } - Err(_) => Err(()), - }) - .boxed_local() - } -} - -struct RouteServiceWrapper> { - service: T, -} - -impl Service for RouteServiceWrapper -where - T::Future: 'static, - T: Service, -{ - type Response = ServiceResponse; - type Error = Error; - type Future = LocalBoxFuture<'static, Result>; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } - - fn call(&self, req: ServiceRequest) -> Self::Future { - Box::pin(self.service.call(req)) + /// Set raw service to be constructed and called as the request handler. + /// + /// # Examples + /// ``` + /// # use std::convert::Infallible; + /// # use futures_util::future::LocalBoxFuture; + /// # use actix_web::{*, dev::*, http::header}; + /// struct HelloWorld; + /// + /// impl Service for HelloWorld { + /// type Response = ServiceResponse; + /// type Error = Infallible; + /// type Future = LocalBoxFuture<'static, Result>; + /// + /// dev::always_ready!(); + /// + /// fn call(&self, req: ServiceRequest) -> Self::Future { + /// let (req, _) = req.into_parts(); + /// + /// let res = HttpResponse::Ok() + /// .insert_header(header::ContentType::plaintext()) + /// .body("Hello world!"); + /// + /// Box::pin(async move { Ok(ServiceResponse::new(req, res)) }) + /// } + /// } + /// + /// App::new().route( + /// "/", + /// web::get().service(fn_factory(|| async { Ok(HelloWorld) })), + /// ); + /// ``` + pub fn service(mut self, service_factory: S) -> Self + where + S: ServiceFactory< + ServiceRequest, + Response = ServiceResponse, + Error = E, + InitError = (), + Config = (), + > + 'static, + E: Into + 'static, + { + self.service = boxed::factory(service_factory.map_err(Into::into)); + self } } #[cfg(test)] mod tests { - use std::time::Duration; + use std::{convert::Infallible, time::Duration}; use actix_rt::time::sleep; use bytes::Bytes; - use serde_derive::Serialize; + use futures_core::future::LocalBoxFuture; + use serde::Serialize; - use crate::http::{Method, StatusCode}; + use crate::dev::{always_ready, fn_factory, fn_service, Service}; + use crate::http::{header, Method, StatusCode}; + use crate::service::{ServiceRequest, ServiceResponse}; use crate::test::{call_service, init_service, read_body, TestRequest}; use crate::{error, web, App, HttpResponse}; @@ -327,7 +269,7 @@ mod tests { })) .route(web::post().to(|| async { sleep(Duration::from_millis(100)).await; - Ok::<_, ()>(HttpResponse::Created()) + Ok::<_, Infallible>(HttpResponse::Created()) })) .route(web::delete().to(|| async { sleep(Duration::from_millis(100)).await; @@ -380,4 +322,65 @@ mod tests { let body = read_body(resp).await; assert_eq!(body, Bytes::from_static(b"{\"name\":\"test\"}")); } + + #[actix_rt::test] + async fn test_service_handler() { + struct HelloWorld; + + impl Service for HelloWorld { + type Response = ServiceResponse; + type Error = crate::Error; + type Future = LocalBoxFuture<'static, Result>; + + always_ready!(); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let (req, _) = req.into_parts(); + + let res = HttpResponse::Ok() + .insert_header(header::ContentType::plaintext()) + .body("Hello world!"); + + Box::pin(async move { Ok(ServiceResponse::new(req, res)) }) + } + } + + let srv = init_service( + App::new() + .route( + "/hello", + web::get().service(fn_factory(|| async { Ok(HelloWorld) })), + ) + .route( + "/bye", + web::get().service(fn_factory(|| async { + Ok::<_, ()>(fn_service(|req: ServiceRequest| async { + let (req, _) = req.into_parts(); + + let res = HttpResponse::Ok() + .insert_header(header::ContentType::plaintext()) + .body("Goodbye, and thanks for all the fish!"); + + Ok::<_, Infallible>(ServiceResponse::new(req, res)) + })) + })), + ), + ) + .await; + + let req = TestRequest::get().uri("/hello").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!(body, Bytes::from_static(b"Hello world!")); + + let req = TestRequest::get().uri("/bye").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + let body = read_body(resp).await; + assert_eq!( + body, + Bytes::from_static(b"Goodbye, and thanks for all the fish!") + ); + } } diff --git a/src/scope.rs b/src/scope.rs index dd02501b0..fa9807f42 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,159 +1,160 @@ -use std::cell::RefCell; -use std::fmt; -use std::future::Future; -use std::rc::Rc; -use std::task::Poll; +use std::{cell::RefCell, fmt, future::Future, marker::PhantomData, mem, rc::Rc}; -use actix_http::Extensions; +use actix_http::{ + body::{BoxBody, MessageBody}, + Extensions, +}; use actix_router::{ResourceDef, Router}; -use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ - apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt, - Transform, + apply, apply_fn_factory, boxed, IntoServiceFactory, Service, ServiceFactory, + ServiceFactoryExt, Transform, }; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; -use crate::config::ServiceConfig; -use crate::data::Data; -use crate::dev::{AppService, HttpServiceFactory}; -use crate::error::Error; -use crate::guard::Guard; -use crate::resource::Resource; -use crate::rmap::ResourceMap; -use crate::route::Route; -use crate::service::{ - AppServiceFactory, ServiceFactoryWrapper, ServiceRequest, ServiceResponse, +use crate::{ + config::ServiceConfig, + data::Data, + dev::AppService, + guard::Guard, + rmap::ResourceMap, + service::{ + AppServiceFactory, BoxedHttpService, BoxedHttpServiceFactory, HttpServiceFactory, + ServiceFactoryWrapper, ServiceRequest, ServiceResponse, + }, + Error, Resource, Route, }; type Guards = Vec>; -type HttpService = BoxService; -type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; -/// Resources scope. +/// A collection of [`Route`]s, [`Resource`]s, or other services that share a common path prefix. /// -/// Scope is a set of resources with common root path. -/// Scopes collect multiple paths under a common path prefix. -/// Scope path can contain variable path segments as resources. -/// Scope prefix is always complete path segment, i.e `/app` would -/// be converted to a `/app/` and it would not match `/app` path. +/// The `Scope`'s path can contain [dynamic segments]. The dynamic segments can be extracted from +/// requests using the [`Path`](crate::web::Path) extractor or +/// with [`HttpRequest::match_info()`](crate::HttpRequest::match_info). /// -/// You can get variable path segments from `HttpRequest::match_info()`. -/// `Path` extractor also is able to extract scope level variable segments. +/// # Avoid Trailing Slashes +/// Avoid using trailing slashes in the scope prefix (e.g., `web::scope("/scope/")`). It will almost +/// certainly not have the expected behavior. See the [documentation on resource definitions][pat] +/// to understand why this is the case and how to correctly construct scope/prefix definitions. /// -/// ```rust +/// # Examples +/// ``` /// use actix_web::{web, App, HttpResponse}; /// -/// fn main() { -/// let app = App::new().service( -/// web::scope("/{project_id}/") -/// .service(web::resource("/path1").to(|| async { HttpResponse::Ok() })) -/// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) -/// .service(web::resource("/path3").route(web::head().to(|| HttpResponse::MethodNotAllowed()))) -/// ); -/// } +/// let app = App::new().service( +/// web::scope("/{project_id}/") +/// .service(web::resource("/path1").to(|| async { "OK" })) +/// .service(web::resource("/path2").route(web::get().to(|| HttpResponse::Ok()))) +/// .service(web::resource("/path3").route(web::head().to(HttpResponse::MethodNotAllowed))) +/// ); /// ``` /// /// In the above example three routes get registered: -/// * /{project_id}/path1 - responds to all http method -/// * /{project_id}/path2 - `GET` requests -/// * /{project_id}/path3 - `HEAD` requests -pub struct Scope { +/// - /{project_id}/path1 - responds to all HTTP methods +/// - /{project_id}/path2 - responds to `GET` requests +/// - /{project_id}/path3 - responds to `HEAD` requests +/// +/// [pat]: crate::dev::ResourceDef#prefix-resources +/// [dynamic segments]: crate::dev::ResourceDef#dynamic-segments +pub struct Scope { endpoint: T, rdef: String, app_data: Option, services: Vec>, guards: Vec>, - default: Option>, + default: Option>, external: Vec, factory_ref: Rc>>, + _phantom: PhantomData, } impl Scope { /// Create a new scope pub fn new(path: &str) -> Scope { - let fref = Rc::new(RefCell::new(None)); + let factory_ref = Rc::new(RefCell::new(None)); + Scope { - endpoint: ScopeEndpoint::new(fref.clone()), + endpoint: ScopeEndpoint::new(Rc::clone(&factory_ref)), rdef: path.to_string(), app_data: None, guards: Vec::new(), services: Vec::new(), default: None, external: Vec::new(), - factory_ref: fref, + factory_ref, + _phantom: Default::default(), } } } -impl Scope +impl Scope where T: ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B: 'static, { /// Add match guard to a scope. /// - /// ```rust + /// ``` /// use actix_web::{web, guard, App, HttpRequest, HttpResponse}; /// /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::scope("/app") - /// .guard(guard::Header("content-type", "text/plain")) - /// .route("/test1", web::get().to(index)) - /// .route("/test2", web::post().to(|r: HttpRequest| { - /// HttpResponse::MethodNotAllowed() - /// })) - /// ); - /// } + /// let app = App::new().service( + /// web::scope("/app") + /// .guard(guard::Header("content-type", "text/plain")) + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|r: HttpRequest| { + /// HttpResponse::MethodNotAllowed() + /// })) + /// ); /// ``` pub fn guard(mut self, guard: G) -> Self { self.guards.push(Box::new(guard)); self } - /// Set or override application data. Application data could be accessed - /// by using `Data` extractor where `T` is data type. - /// - /// ```rust - /// use std::cell::Cell; - /// use actix_web::{web, App, HttpResponse, Responder}; - /// - /// struct MyData { - /// counter: Cell, - /// } - /// - /// async fn index(data: web::Data) -> impl Responder { - /// data.counter.set(data.counter.get() + 1); - /// HttpResponse::Ok() - /// } - /// - /// fn main() { - /// let app = App::new().service( - /// web::scope("/app") - /// .data(MyData{ counter: Cell::new(0) }) - /// .service( - /// web::resource("/index.html").route( - /// web::get().to(index))) - /// ); - /// } - /// ``` - pub fn data(self, data: U) -> Self { - self.app_data(Data::new(data)) - } - /// Add scope data. /// - /// Data of different types from parent contexts will still be accessible. + /// Data of different types from parent contexts will still be accessible. Any `Data` types + /// set here can be extracted in handlers using the `Data` extractor. + /// + /// # Examples + /// ``` + /// use std::cell::Cell; + /// use actix_web::{web, App, HttpRequest, HttpResponse, Responder}; + /// + /// struct MyData { + /// count: std::cell::Cell, + /// } + /// + /// async fn handler(req: HttpRequest, counter: web::Data) -> impl Responder { + /// // note this cannot use the Data extractor because it was not added with it + /// let incr = *req.app_data::().unwrap(); + /// assert_eq!(incr, 3); + /// + /// // update counter using other value from app data + /// counter.count.set(counter.count.get() + incr); + /// + /// HttpResponse::Ok().body(counter.count.get().to_string()) + /// } + /// + /// let app = App::new().service( + /// web::scope("/app") + /// .app_data(3usize) + /// .app_data(web::Data::new(MyData { count: Default::default() })) + /// .route("/", web::get().to(handler)) + /// ); + /// ``` + #[doc(alias = "manage")] pub fn app_data(mut self, data: U) -> Self { self.app_data .get_or_insert_with(Extensions::new) @@ -162,15 +163,20 @@ where self } - /// Run external configuration as part of the scope building - /// process + /// Add scope data after wrapping in `Data`. /// - /// This function is useful for moving parts of configuration to a - /// different module or even library. For example, - /// some of the resource's configuration could be moved to different module. + /// Deprecated in favor of [`app_data`](Self::app_data). + #[deprecated(since = "4.0.0", note = "Use `.app_data(Data::new(val))` instead.")] + pub fn data(self, data: U) -> Self { + self.app_data(Data::new(data)) + } + + /// Run external configuration as part of the scope building process. /// - /// ```rust - /// # extern crate actix_web; + /// This function is useful for moving parts of configuration to a different module or library. + /// For example, some of the resource's configuration could be moved to different module. + /// + /// ``` /// use actix_web::{web, middleware, App, HttpResponse}; /// /// // this function could be located in different module @@ -181,28 +187,29 @@ where /// ); /// } /// - /// fn main() { - /// let app = App::new() - /// .wrap(middleware::Logger::default()) - /// .service( - /// web::scope("/api") - /// .configure(config) - /// ) - /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); - /// } + /// let app = App::new() + /// .wrap(middleware::Logger::default()) + /// .service( + /// web::scope("/api") + /// .configure(config) + /// ) + /// .route("/index.html", web::get().to(|| HttpResponse::Ok())); /// ``` - pub fn configure(mut self, f: F) -> Self + pub fn configure(mut self, cfg_fn: F) -> Self where F: FnOnce(&mut ServiceConfig), { let mut cfg = ServiceConfig::new(); - f(&mut cfg); + cfg_fn(&mut cfg); + self.services.extend(cfg.services); self.external.extend(cfg.external); + // TODO: add Extensions::is_empty check and conditionally insert data self.app_data .get_or_insert_with(Extensions::new) .extend(cfg.app_data); + self } @@ -216,7 +223,7 @@ where /// * *Scope* is a set of resources with common root path. /// * "StaticFiles" is a service for static files support /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpRequest}; /// /// struct AppState; @@ -225,13 +232,11 @@ where /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::scope("/app").service( - /// web::scope("/v1") - /// .service(web::resource("/test1").to(index))) - /// ); - /// } + /// let app = App::new().service( + /// web::scope("/app").service( + /// web::scope("/v1") + /// .service(web::resource("/test1").to(index))) + /// ); /// ``` pub fn service(mut self, factory: F) -> Self where @@ -248,20 +253,18 @@ where /// This method can be called multiple times, in that case /// multiple resources with one route would be registered for same resource path. /// - /// ```rust + /// ``` /// use actix_web::{web, App, HttpResponse}; /// /// async fn index(data: web::Path<(String, String)>) -> &'static str { /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::scope("/app") - /// .route("/test1", web::get().to(index)) - /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())) - /// ); - /// } + /// let app = App::new().service( + /// web::scope("/app") + /// .route("/test1", web::get().to(index)) + /// .route("/test2", web::post().to(|| HttpResponse::MethodNotAllowed())) + /// ); /// ``` pub fn route(self, path: &str, mut route: Route) -> Self { self.service( @@ -293,32 +296,29 @@ where self } - /// Registers middleware, in the form of a middleware component (type), - /// that runs during inbound processing in the request - /// life-cycle (request -> response), modifying request as - /// necessary, across all requests managed by the *Scope*. Scope-level - /// middleware is more limited in what it can modify, relative to Route or - /// Application level middleware, in that Scope-level middleware can not modify - /// ServiceResponse. + /// Registers middleware, in the form of a middleware component (type), that runs during inbound + /// processing in the request life-cycle (request -> response), modifying request as necessary, + /// across all requests managed by the *Scope*. /// /// Use middleware when you need to read or modify *every* request in some way. - pub fn wrap( + pub fn wrap( self, mw: M, ) -> Scope< impl ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B1, > where M: Transform< T::Service, ServiceRequest, - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, @@ -332,56 +332,54 @@ where default: self.default, external: self.external, factory_ref: self.factory_ref, + _phantom: PhantomData, } } - /// Registers middleware, in the form of a closure, that runs during inbound - /// processing in the request life-cycle (request -> response), modifying - /// request as necessary, across all requests managed by the *Scope*. - /// Scope-level middleware is more limited in what it can modify, relative - /// to Route or Application level middleware, in that Scope-level middleware - /// can not modify ServiceResponse. + /// Registers middleware, in the form of a closure, that runs during inbound processing in the + /// request life-cycle (request -> response), modifying request as necessary, across all + /// requests managed by the *Scope*. /// - /// ```rust + /// # Examples + /// ``` /// use actix_service::Service; /// use actix_web::{web, App}; - /// use actix_web::http::{header::CONTENT_TYPE, HeaderValue}; + /// use actix_web::http::header::{CONTENT_TYPE, HeaderValue}; /// /// async fn index() -> &'static str { /// "Welcome!" /// } /// - /// fn main() { - /// let app = App::new().service( - /// web::scope("/app") - /// .wrap_fn(|req, srv| { - /// let fut = srv.call(req); - /// async { - /// let mut res = fut.await?; - /// res.headers_mut().insert( - /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), - /// ); - /// Ok(res) - /// } - /// }) - /// .route("/index.html", web::get().to(index))); - /// } + /// let app = App::new().service( + /// web::scope("/app") + /// .wrap_fn(|req, srv| { + /// let fut = srv.call(req); + /// async { + /// let mut res = fut.await?; + /// res.headers_mut().insert( + /// CONTENT_TYPE, HeaderValue::from_static("text/plain"), + /// ); + /// Ok(res) + /// } + /// }) + /// .route("/index.html", web::get().to(index))); /// ``` - pub fn wrap_fn( + pub fn wrap_fn( self, mw: F, ) -> Scope< impl ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), >, + B1, > where F: Fn(ServiceRequest, &T::Service) -> R + Clone, - R: Future>, + R: Future, Error>>, { Scope { endpoint: apply_fn_factory(self.endpoint, mw), @@ -392,19 +390,21 @@ where default: self.default, external: self.external, factory_ref: self.factory_ref, + _phantom: PhantomData, } } } -impl HttpServiceFactory for Scope +impl HttpServiceFactory for Scope where T: ServiceFactory< ServiceRequest, Config = (), - Response = ServiceResponse, + Response = ServiceResponse, Error = Error, InitError = (), > + 'static, + B: MessageBody + 'static, { fn register(mut self, config: &mut AppService) { // update default resource if needed @@ -419,13 +419,12 @@ where let mut rmap = ResourceMap::new(ResourceDef::root_prefix(&self.rdef)); // external resources - for mut rdef in std::mem::take(&mut self.external) { + for mut rdef in mem::take(&mut self.external) { rmap.add(&mut rdef, None); } // complete scope pipeline creation *self.factory_ref.borrow_mut() = Some(ScopeFactory { - app_data: self.app_data.take().map(Rc::new), default, services: cfg .into_services() @@ -447,20 +446,39 @@ where Some(self.guards) }; + let scope_data = self.app_data.map(Rc::new); + + // wraps endpoint service (including middleware) call and injects app data for this scope + let endpoint = apply_fn_factory(self.endpoint, move |mut req: ServiceRequest, srv| { + if let Some(ref data) = scope_data { + req.add_data_container(Rc::clone(data)); + } + + let fut = srv.call(req); + + async { Ok(fut.await?.map_into_boxed_body()) } + }); + // register final service config.register_service( ResourceDef::root_prefix(&self.rdef), guards, - self.endpoint, + endpoint, Some(Rc::new(rmap)), ) } } pub struct ScopeFactory { - app_data: Option>, - services: Rc<[(ResourceDef, HttpNewService, RefCell>)]>, - default: Rc, + #[allow(clippy::type_complexity)] + services: Rc< + [( + ResourceDef, + BoxedHttpServiceFactory, + RefCell>, + )], + >, + default: Rc, } impl ServiceFactory for ScopeFactory { @@ -486,8 +504,6 @@ impl ServiceFactory for ScopeFactory { } })); - let app_data = self.app_data.clone(); - Box::pin(async move { let default = default_fut.await?; @@ -503,19 +519,14 @@ impl ServiceFactory for ScopeFactory { }) .finish(); - Ok(ScopeService { - app_data, - router, - default, - }) + Ok(ScopeService { router, default }) }) } } pub struct ScopeService { - app_data: Option>, - router: Router>>, - default: HttpService, + router: Router>>, + default: BoxedHttpService, } impl Service for ScopeService { @@ -526,21 +537,20 @@ impl Service for ScopeService { actix_service::always_ready!(); fn call(&self, mut req: ServiceRequest) -> Self::Future { - let res = self.router.recognize_checked(&mut req, |req, guards| { + let res = self.router.recognize_fn(&mut req, |req, guards| { if let Some(ref guards) = guards { - for f in guards { - if !f.check(req.head()) { + let guard_ctx = req.guard_ctx(); + + for guard in guards { + if !guard.check(&guard_ctx) { return false; } } } + true }); - if let Some(ref app_data) = self.app_data { - req.add_data_container(app_data.clone()); - } - if let Some((srv, _info)) = res { srv.call(req) } else { @@ -576,15 +586,49 @@ impl ServiceFactory for ScopeEndpoint { #[cfg(test)] mod tests { use actix_service::Service; + use actix_utils::future::ok; use bytes::Bytes; - use futures_util::future::ok; - use crate::dev::{Body, ResponseBody}; - use crate::http::{header, HeaderValue, Method, StatusCode}; - use crate::middleware::DefaultHeaders; - use crate::service::ServiceRequest; - use crate::test::{call_service, init_service, read_body, TestRequest}; - use crate::{guard, web, App, HttpRequest, HttpResponse}; + use super::*; + use crate::{ + guard, + http::{ + header::{self, HeaderValue}, + Method, StatusCode, + }, + middleware::{Compat, DefaultHeaders}, + service::{ServiceRequest, ServiceResponse}, + test::{assert_body_eq, call_service, init_service, read_body, TestRequest}, + web, App, HttpMessage, HttpRequest, HttpResponse, + }; + + #[test] + fn can_be_returned_from_fn() { + fn my_scope() -> Scope { + web::scope("/test") + .service(web::resource("").route(web::get().to(|| async { "hello" }))) + } + + fn my_compat_scope() -> Scope< + impl ServiceFactory< + ServiceRequest, + Config = (), + Response = ServiceResponse, + Error = Error, + InitError = (), + >, + > { + web::scope("/test-compat") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { Ok(fut.await?.map_into_right_body::<()>()) } + }) + .wrap(Compat::noop()) + .service(web::resource("").route(web::get().to(|| async { "hello" }))) + } + + App::new().service(my_scope()).service(my_compat_scope()); + } #[actix_rt::test] async fn test_scope() { @@ -745,20 +789,13 @@ mod tests { .await; let req = TestRequest::with_uri("/ab-project1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: project1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_body_eq!(res, b"project: project1"); let req = TestRequest::with_uri("/aa-project1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] @@ -846,16 +883,9 @@ mod tests { .await; let req = TestRequest::with_uri("/app/project_1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: project_1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::CREATED); + assert_body_eq!(res, b"project: project_1"); } #[actix_rt::test] @@ -874,20 +904,13 @@ mod tests { .await; let req = TestRequest::with_uri("/app/test/1/path1").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - match resp.response().body() { - ResponseBody::Body(Body::Bytes(ref b)) => { - let bytes = b.clone(); - assert_eq!(bytes, Bytes::from_static(b"project: test - 1")); - } - _ => panic!(), - } + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::CREATED); + assert_body_eq!(res, b"project: test - 1"); let req = TestRequest::with_uri("/app/test/1/path2").to_request(); - let resp = srv.call(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); + let res = srv.call(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); } #[actix_rt::test] @@ -916,10 +939,7 @@ mod tests { async fn test_default_resource_propagation() { let srv = init_service( App::new() - .service( - web::scope("/app1") - .default_service(web::resource("").to(HttpResponse::BadRequest)), - ) + .service(web::scope("/app1").default_service(web::to(HttpResponse::BadRequest))) .service(web::scope("/app2")) .default_service(|r: ServiceRequest| { ok(r.into_response(HttpResponse::MethodNotAllowed())) @@ -947,7 +967,7 @@ mod tests { web::scope("app") .wrap( DefaultHeaders::new() - .header(header::CONTENT_TYPE, HeaderValue::from_static("0001")), + .add((header::CONTENT_TYPE, HeaderValue::from_static("0001"))), ) .service(web::resource("/test").route(web::get().to(HttpResponse::Ok))), ), @@ -963,6 +983,29 @@ mod tests { ); } + #[actix_rt::test] + async fn test_middleware_body_type() { + // Compile test that Scope accepts any body type; test for `EitherBody` + let srv = init_service( + App::new().service( + web::scope("app") + .wrap_fn(|req, srv| { + let fut = srv.call(req); + async { Ok(fut.await?.map_into_right_body::<()>()) } + }) + .service(web::resource("/test").route(web::get().to(|| async { "hello" }))), + ), + ) + .await; + + // test if `MessageBody::try_into_bytes()` is preserved across scope layer + use actix_http::body::MessageBody as _; + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&srv, req).await; + let body = resp.into_body(); + assert_eq!(body.try_into_bytes().unwrap(), b"hello".as_ref()); + } + #[actix_rt::test] async fn test_middleware_fn() { let srv = init_service( @@ -991,6 +1034,43 @@ mod tests { ); } + #[actix_rt::test] + async fn test_middleware_app_data() { + let srv = init_service( + App::new().service( + web::scope("app") + .app_data(1usize) + .wrap_fn(|req, srv| { + assert_eq!(req.app_data::(), Some(&1usize)); + req.extensions_mut().insert(1usize); + srv.call(req) + }) + .route("/test", web::get().to(HttpResponse::Ok)) + .default_service(|req: ServiceRequest| async move { + let (req, _) = req.into_parts(); + + assert_eq!(req.extensions().get::(), Some(&1)); + + Ok(ServiceResponse::new( + req, + HttpResponse::BadRequest().finish(), + )) + }), + ), + ) + .await; + + let req = TestRequest::with_uri("/app/test").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/default").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + // allow deprecated {App, Scope}::data + #[allow(deprecated)] #[actix_rt::test] async fn test_override_data() { let srv = init_service(App::new().data(1usize).service( @@ -1009,6 +1089,8 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); } + // allow deprecated `{App, Scope}::data` + #[allow(deprecated)] #[actix_rt::test] async fn test_override_data_default_service() { let srv = init_service(App::new().data(1usize).service( @@ -1114,4 +1196,70 @@ mod tests { Bytes::from_static(b"http://localhost:8080/a/b/c/12345") ); } + + #[actix_rt::test] + async fn dynamic_scopes() { + let srv = init_service( + App::new().service( + web::scope("/{a}/").service( + web::scope("/{b}/") + .route("", web::get().to(|_: HttpRequest| HttpResponse::Created())) + .route( + "/", + web::get().to(|_: HttpRequest| HttpResponse::Accepted()), + ) + .route("/{c}", web::get().to(|_: HttpRequest| HttpResponse::Ok())), + ), + ), + ) + .await; + + // note the unintuitive behavior with trailing slashes on scopes with dynamic segments + let req = TestRequest::with_uri("/a//b//c").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/a//b/").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/a//b//").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::ACCEPTED); + + let req = TestRequest::with_uri("/a//b//c/d").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + let srv = init_service( + App::new().service( + web::scope("/{a}").service( + web::scope("/{b}") + .route("", web::get().to(|_: HttpRequest| HttpResponse::Created())) + .route( + "/", + web::get().to(|_: HttpRequest| HttpResponse::Accepted()), + ) + .route("/{c}", web::get().to(|_: HttpRequest| HttpResponse::Ok())), + ), + ), + ) + .await; + + let req = TestRequest::with_uri("/a/b/c").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let req = TestRequest::with_uri("/a/b").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::CREATED); + + let req = TestRequest::with_uri("/a/b/").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::ACCEPTED); + + let req = TestRequest::with_uri("/a/b/c/d").to_request(); + let resp = call_service(&srv, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + } } diff --git a/src/server.rs b/src/server.rs index d69d6570d..ed0c965b3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,25 +6,18 @@ use std::{ sync::{Arc, Mutex}, }; -use actix_http::{ - body::MessageBody, Error, Extensions, HttpService, KeepAlive, Request, Response, -}; +use actix_http::{body::MessageBody, Extensions, HttpService, KeepAlive, Request, Response}; use actix_server::{Server, ServerBuilder}; -use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; - -#[cfg(unix)] -use actix_http::Protocol; -#[cfg(unix)] -use actix_service::pipeline_factory; -#[cfg(unix)] -use futures_util::future::ok; +use actix_service::{ + map_config, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, +}; #[cfg(feature = "openssl")] -use actix_tls::accept::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +use actix_tls::accept::openssl::reexports::{AlpnError, SslAcceptor, SslAcceptorBuilder}; #[cfg(feature = "rustls")] -use actix_tls::accept::rustls::ServerConfig as RustlsServerConfig; +use actix_tls::accept::rustls::reexports::ServerConfig as RustlsServerConfig; -use crate::config::AppConfig; +use crate::{config::AppConfig, Error}; struct Socket { scheme: &'static str, @@ -42,7 +35,7 @@ struct Config { /// /// Create new HTTP server with application factory. /// -/// ```rust,no_run +/// ```no_run /// use actix_web::{web, App, HttpResponse, HttpServer}; /// /// #[actix_rt::main] @@ -70,6 +63,7 @@ where backlog: u32, sockets: Vec, builder: ServerBuilder, + #[allow(clippy::type_complexity)] on_connect_fn: Option>, _phantom: PhantomData<(S, B)>, } @@ -78,12 +72,14 @@ impl HttpServer where F: Fn() -> I + Send + Clone + 'static, I: IntoServiceFactory, + S: ServiceFactory + 'static, S::Error: Into + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, >::Future: 'static, S::Service: 'static, + B: MessageBody + 'static, { /// Create new HTTP server with application factory @@ -106,14 +102,14 @@ where /// Sets function that will be called once before each connection is handled. /// It will receive a `&std::any::Any`, which contains underlying connection type and an - /// [Extensions] container so that request-local data can be passed to middleware and handlers. + /// [Extensions] container so that connection data can be accessed in middleware and handlers. /// - /// For example: - /// - `actix_tls::openssl::SslStream` when using openssl. - /// - `actix_tls::rustls::TlsStream` when using rustls. + /// # Connection Types + /// - `actix_tls::accept::openssl::TlsStream` when using openssl. + /// - `actix_tls::accept::rustls::TlsStream` when using rustls. /// - `actix_web::rt::net::TcpStream` when no encryption is used. /// - /// See `on_connect` example for additional details. + /// See the `on_connect` example for additional details. pub fn on_connect(self, f: CB) -> HttpServer where CB: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static, @@ -160,7 +156,7 @@ where /// /// By default max connections is set to a 25k. pub fn max_connections(mut self, num: usize) -> Self { - self.builder = self.builder.maxconn(num); + self.builder = self.builder.max_concurrent_connections(num); self } @@ -177,6 +173,16 @@ where self } + /// Set max number of threads for each worker's blocking task thread pool. + /// + /// One thread pool is set up **per worker**; not shared across workers. + /// + /// By default set to 512 / workers. + pub fn worker_max_blocking_threads(mut self, num: usize) -> Self { + self.builder = self.builder.worker_max_blocking_threads(num); + self + } + /// Set server keep-alive setting. /// /// By default keep alive is set to a 5 seconds. @@ -224,7 +230,7 @@ where self } - /// Stop actix system. + /// Stop Actix `System` after server shutdown. pub fn system_exit(mut self) -> Self { self.builder = self.builder.system_exit(); self @@ -283,19 +289,24 @@ where let c = cfg.lock().unwrap(); let host = c.host.clone().unwrap_or_else(|| format!("{}", addr)); - let svc = HttpService::build() + let mut svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown) .local_addr(addr); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| (handler)(io as &dyn Any, ext)) - } else { - svc + if let Some(handler) = on_connect_fn.clone() { + svc = svc.on_connect_ext(move |io: &_, ext: _| { + (handler)(io as &dyn Any, ext) + }) }; - svc.finish(map_config(factory(), move |_| { - AppConfig::new(false, addr, host.clone()) + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + svc.finish(map_config(fac, move |_| { + AppConfig::new(false, host.clone(), addr) })) .tcp() })?; @@ -339,7 +350,8 @@ where let svc = HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown); + .client_disconnect(c.client_shutdown) + .local_addr(addr); let svc = if let Some(handler) = on_connect_fn.clone() { svc.on_connect_ext(move |io: &_, ext: _| { @@ -349,18 +361,23 @@ where svc }; - svc.finish(map_config(factory(), move |_| { - AppConfig::new(true, addr, host.clone()) + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + svc.finish(map_config(fac, move |_| { + AppConfig::new(true, host.clone(), addr) })) .openssl(acceptor.clone()) })?; + Ok(self) } #[cfg(feature = "rustls")] /// Use listener for accepting incoming tls connection requests /// - /// This method sets alpn protocols to "h2" and "http/1.1" + /// This method prepends alpn protocols "h2" and "http/1.1" to configured ones pub fn listen_rustls( self, lst: net::TcpListener, @@ -402,11 +419,16 @@ where svc }; - svc.finish(map_config(factory(), move |_| { - AppConfig::new(true, addr, host.clone()) + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + svc.finish(map_config(fac, move |_| { + AppConfig::new(true, host.clone(), addr) })) .rustls(config.clone()) })?; + Ok(self) } @@ -438,17 +460,15 @@ where } } - if !success { - if let Some(e) = err.take() { - Err(e) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - "Can not bind to address.", - )) - } - } else { + if success { Ok(sockets) + } else if let Some(e) = err.take() { + Err(e) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Can not bind to address.", + )) } } @@ -473,7 +493,7 @@ where #[cfg(feature = "rustls")] /// Start listening for incoming tls connections. /// - /// This method sets alpn protocols to "h2" and "http/1.1" + /// This method prepends alpn protocols "h2" and "http/1.1" to configured ones pub fn bind_rustls( mut self, addr: A, @@ -489,7 +509,9 @@ where #[cfg(unix)] /// Start listening for unix domain (UDS) connections on existing listener. pub fn listen_uds(mut self, lst: std::os::unix::net::UnixListener) -> io::Result { + use actix_http::Protocol; use actix_rt::net::UnixStream; + use actix_service::{fn_service, ServiceFactoryExt as _}; let cfg = self.config.clone(); let factory = self.factory.clone(); @@ -500,46 +522,54 @@ where addr: socket_addr, }); - let addr = format!("actix-web-service-{:?}", lst.local_addr()?); + let addr = lst.local_addr()?; + let name = format!("actix-web-service-{:?}", addr); let on_connect_fn = self.on_connect_fn.clone(); - self.builder = self.builder.listen_uds(addr, lst, move || { + self.builder = self.builder.listen_uds(name, lst, move || { let c = cfg.lock().unwrap(); let config = AppConfig::new( false, - socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), + socket_addr, ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then({ - let svc = HttpService::build() + fn_service(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) }).and_then({ + let mut svc = HttpService::build() .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout); + .client_timeout(c.client_timeout) + .client_disconnect(c.client_shutdown); - let svc = if let Some(handler) = on_connect_fn.clone() { - svc.on_connect_ext(move |io: &_, ext: _| (&*handler)(io as &dyn Any, ext)) - } else { - svc - }; + if let Some(handler) = on_connect_fn.clone() { + svc = svc + .on_connect_ext(move |io: &_, ext: _| (&*handler)(io as &dyn Any, ext)); + } - svc.finish(map_config(factory(), move |_| config.clone())) + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + svc.finish(map_config(fac, move |_| config.clone())) }) })?; Ok(self) } - #[cfg(unix)] /// Start listening for incoming unix domain connections. + #[cfg(unix)] pub fn bind_uds(mut self, addr: A) -> io::Result where A: AsRef, { + use actix_http::Protocol; use actix_rt::net::UnixStream; + use actix_service::{fn_service, ServiceFactoryExt as _}; let cfg = self.config.clone(); let factory = self.factory.clone(); let socket_addr = net::SocketAddr::new(net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), 8080); + self.sockets.push(Socket { scheme: "http", addr: socket_addr, @@ -552,17 +582,24 @@ where let c = cfg.lock().unwrap(); let config = AppConfig::new( false, - socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), + socket_addr, ); - pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( + + let fac = factory() + .into_factory() + .map_err(|err| err.into().error_response()); + + fn_service(|io: UnixStream| async { Ok((io, Protocol::Http1, None)) }).and_then( HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .finish(map_config(factory(), move |_| config.clone())), + .client_disconnect(c.client_shutdown) + .finish(map_config(fac, move |_| config.clone())), ) }, )?; + Ok(self) } } @@ -587,7 +624,7 @@ where /// This methods panics if no socket address can be bound or an `Actix` system is not yet /// configured. /// - /// ```rust,no_run + /// ```no_run /// use std::io; /// use actix_web::{web, App, HttpResponse, HttpServer}; /// @@ -606,21 +643,18 @@ where fn create_tcp_listener(addr: net::SocketAddr, backlog: u32) -> io::Result { use socket2::{Domain, Protocol, Socket, Type}; - let domain = match addr { - net::SocketAddr::V4(_) => Domain::ipv4(), - net::SocketAddr::V6(_) => Domain::ipv6(), - }; - let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?; + let domain = Domain::for_address(addr); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; socket.set_reuse_address(true)?; socket.bind(&addr.into())?; // clamp backlog to max u32 that fits in i32 range let backlog = cmp::min(backlog, i32::MAX as u32) as i32; socket.listen(backlog)?; - Ok(socket.into_tcp_listener()) + Ok(net::TcpListener::from(socket)) } -#[cfg(feature = "openssl")] /// Configure `SslAcceptorBuilder` with custom server flags. +#[cfg(feature = "openssl")] fn openssl_acceptor(mut builder: SslAcceptorBuilder) -> io::Result { builder.set_alpn_select_callback(|_, protocols| { const H2: &[u8] = b"\x02h2"; diff --git a/src/service.rs b/src/service.rs index fcbe61a02..f15cbfc9f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,21 +1,35 @@ -use std::cell::{Ref, RefMut}; -use std::rc::Rc; -use std::{fmt, net}; - -use actix_http::body::{Body, MessageBody, ResponseBody}; -use actix_http::http::{HeaderMap, Method, StatusCode, Uri, Version}; -use actix_http::{ - Error, Extensions, HttpMessage, Payload, PayloadStream, RequestHead, Response, ResponseHead, +use std::{ + cell::{Ref, RefMut}, + fmt, net, + rc::Rc, }; -use actix_router::{IntoPattern, Path, Resource, ResourceDef, Url}; -use actix_service::{IntoServiceFactory, ServiceFactory}; -use crate::config::{AppConfig, AppService}; -use crate::dev::insert_slash; -use crate::guard::Guard; -use crate::info::ConnectionInfo; -use crate::request::HttpRequest; -use crate::rmap::ResourceMap; +use actix_http::{ + body::{BoxBody, EitherBody, MessageBody}, + header::HeaderMap, + BoxedPayloadStream, Extensions, HttpMessage, Method, Payload, RequestHead, Response, + ResponseHead, StatusCode, Uri, Version, +}; +use actix_router::{IntoPatterns, Path, Patterns, Resource, ResourceDef, Url}; +use actix_service::{ + boxed::{BoxService, BoxServiceFactory}, + IntoServiceFactory, ServiceFactory, +}; +#[cfg(feature = "cookies")] +use cookie::{Cookie, ParseError as CookieParseError}; + +use crate::{ + config::{AppConfig, AppService}, + dev::ensure_leading_slash, + guard::{Guard, GuardContext}, + info::ConnectionInfo, + rmap::ResourceMap, + Error, HttpRequest, HttpResponse, +}; + +pub(crate) type BoxedHttpService = BoxService, Error>; +pub(crate) type BoxedHttpServiceFactory = + BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; pub trait HttpServiceFactory { fn register(self, config: &mut AppService); @@ -55,9 +69,9 @@ where } } -/// An service http request +/// A service level request wrapper. /// -/// ServiceRequest allows mutable access to request's internal structures +/// Allows mutable access to request's internal structures. pub struct ServiceRequest { req: HttpRequest, payload: Payload, @@ -75,6 +89,12 @@ impl ServiceRequest { (self.req, self.payload) } + /// Get mutable access to inner `HttpRequest` and `Payload` + #[inline] + pub fn parts_mut(&mut self) -> (&mut HttpRequest, &mut Payload) { + (&mut self.req, &mut self.payload) + } + /// Construct request from parts. pub fn from_parts(req: HttpRequest, payload: Payload) -> Self { Self { req, payload } @@ -93,20 +113,21 @@ impl ServiceRequest { /// Create service response #[inline] pub fn into_response>>(self, res: R) -> ServiceResponse { - ServiceResponse::new(self.req, res.into()) + let res = HttpResponse::from(res.into()); + ServiceResponse::new(self.req, res) } /// Create service response for error #[inline] - pub fn error_response>(self, err: E) -> ServiceResponse { - let res: Response = err.into().into(); - ServiceResponse::new(self.req, res.into_body()) + pub fn error_response>(self, err: E) -> ServiceResponse { + let res = HttpResponse::from_error(err.into()); + ServiceResponse::new(self.req, res) } /// This method returns reference to the request head #[inline] pub fn head(&self) -> &RequestHead { - &self.req.head() + self.req.head() } /// This method returns reference to the request head @@ -151,16 +172,10 @@ impl ServiceRequest { self.head().uri.path() } - /// The query string in the URL. - /// - /// E.g., id=10 + /// Counterpart to [`HttpRequest::query_string`]. #[inline] pub fn query_string(&self) -> &str { - if let Some(query) = self.uri().query().as_ref() { - query - } else { - "" - } + self.req.query_string() } /// Peer socket address. @@ -179,7 +194,7 @@ impl ServiceRequest { /// Get *ConnectionInfo* for the current request. #[inline] pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { - ConnectionInfo::get(self.head(), &*self.app_config()) + self.req.connection_info() } /// Get a reference to the Path parameters. @@ -193,26 +208,26 @@ impl ServiceRequest { self.req.match_info() } - /// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). + /// Counterpart to [`HttpRequest::match_name`]. #[inline] pub fn match_name(&self) -> Option<&str> { self.req.match_name() } - /// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). + /// Counterpart to [`HttpRequest::match_pattern`]. #[inline] pub fn match_pattern(&self) -> Option { self.req.match_pattern() } - #[inline] /// Get a mutable reference to the Path parameters. + #[inline] pub fn match_info_mut(&mut self) -> &mut Path { self.req.match_info_mut() } - #[inline] /// Get a reference to a `ResourceMap` of current application. + #[inline] pub fn resource_map(&self) -> &ResourceMap { self.req.resource_map() } @@ -223,7 +238,8 @@ impl ServiceRequest { self.req.app_config() } - /// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). + /// Counterpart to [`HttpRequest::app_data`]. + #[inline] pub fn app_data(&self) -> Option<&T> { for container in self.req.inner.app_data.iter().rev() { if let Some(data) = container.get::() { @@ -234,7 +250,39 @@ impl ServiceRequest { None } + /// Counterpart to [`HttpRequest::conn_data`]. + #[inline] + pub fn conn_data(&self) -> Option<&T> { + self.req.conn_data() + } + + /// Counterpart to [`HttpRequest::req_data`]. + #[inline] + pub fn req_data(&self) -> Ref<'_, Extensions> { + self.req.req_data() + } + + /// Counterpart to [`HttpRequest::req_data_mut`]. + #[inline] + pub fn req_data_mut(&self) -> RefMut<'_, Extensions> { + self.req.req_data_mut() + } + + #[cfg(feature = "cookies")] + #[inline] + pub fn cookies(&self) -> Result>>, CookieParseError> { + self.req.cookies() + } + + /// Return request cookie. + #[cfg(feature = "cookies")] + #[inline] + pub fn cookie(&self, name: &str) -> Option> { + self.req.cookie(name) + } + /// Set request payload. + #[inline] pub fn set_payload(&mut self, payload: Payload) { self.payload = payload; } @@ -249,16 +297,27 @@ impl ServiceRequest { .app_data .push(extensions); } + + /// Creates a context object for use with a [guard](crate::guard). + /// + /// Useful if you are implementing + #[inline] + pub fn guard_ctx(&self) -> GuardContext<'_> { + GuardContext { req: self } + } } -impl Resource for ServiceRequest { - fn resource_path(&mut self) -> &mut Path { +impl Resource for ServiceRequest { + type Path = Url; + + #[inline] + fn resource_path(&mut self) -> &mut Path { self.match_info_mut() } } impl HttpMessage for ServiceRequest { - type Stream = PayloadStream; + type Stream = BoxedPayloadStream; #[inline] /// Returns Request's headers. @@ -307,36 +366,35 @@ impl fmt::Debug for ServiceRequest { } } -pub struct ServiceResponse { +/// A service level response wrapper. +pub struct ServiceResponse { request: HttpRequest, - response: Response, + response: HttpResponse, +} + +impl ServiceResponse { + /// Create service response from the error + pub fn from_err>(err: E, request: HttpRequest) -> Self { + let response = HttpResponse::from_error(err); + ServiceResponse { request, response } + } } impl ServiceResponse { /// Create service response instance - pub fn new(request: HttpRequest, response: Response) -> Self { + pub fn new(request: HttpRequest, response: HttpResponse) -> Self { ServiceResponse { request, response } } - /// Create service response from the error - pub fn from_err>(err: E, request: HttpRequest) -> Self { - let e: Error = err.into(); - let res: Response = e.into(); - ServiceResponse { - request, - response: res.into_body(), - } - } - /// Create service response for error #[inline] - pub fn error_response>(self, err: E) -> Self { - Self::from_err(err, self.request) + pub fn error_response>(self, err: E) -> ServiceResponse { + ServiceResponse::from_err(err, self.request) } /// Create service response #[inline] - pub fn into_response(self, response: Response) -> ServiceResponse { + pub fn into_response(self, response: HttpResponse) -> ServiceResponse { ServiceResponse::new(self.request, response) } @@ -348,13 +406,13 @@ impl ServiceResponse { /// Get reference to response #[inline] - pub fn response(&self) -> &Response { + pub fn response(&self) -> &HttpResponse { &self.response } /// Get mutable reference to response #[inline] - pub fn response_mut(&mut self) -> &mut Response { + pub fn response_mut(&mut self) -> &mut HttpResponse { &mut self.response } @@ -370,38 +428,29 @@ impl ServiceResponse { self.response.headers() } - #[inline] /// Returns mutable response's headers. + #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { self.response.headers_mut() } - /// Execute closure and in case of error convert it to response. - pub fn checked_expr(mut self, f: F) -> Self - where - F: FnOnce(&mut Self) -> Result<(), E>, - E: Into, - { - match f(&mut self) { - Ok(_) => self, - Err(err) => { - let res: Response = err.into().into(); - ServiceResponse::new(self.request, res.into_body()) - } - } + /// Destructures `ServiceResponse` into request and response components. + #[inline] + pub fn into_parts(self) -> (HttpRequest, HttpResponse) { + (self.request, self.response) } /// Extract response body - pub fn take_body(&mut self) -> ResponseBody { - self.response.take_body() + #[inline] + pub fn into_body(self) -> B { + self.response.into_body() } -} -impl ServiceResponse { /// Set a new body + #[inline] pub fn map_body(self, f: F) -> ServiceResponse where - F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, + F: FnOnce(&mut ResponseHead, B) -> B2, { let response = self.response.map_body(f); @@ -410,15 +459,43 @@ impl ServiceResponse { request: self.request, } } + + #[inline] + pub fn map_into_left_body(self) -> ServiceResponse> { + self.map_body(|_, body| EitherBody::left(body)) + } + + #[inline] + pub fn map_into_right_body(self) -> ServiceResponse> { + self.map_body(|_, body| EitherBody::right(body)) + } + + #[inline] + pub fn map_into_boxed_body(self) -> ServiceResponse + where + B: MessageBody + 'static, + { + self.map_body(|_, body| body.boxed()) + } } -impl From> for Response { - fn from(res: ServiceResponse) -> Response { +impl From> for HttpResponse { + fn from(res: ServiceResponse) -> HttpResponse { res.response } } -impl fmt::Debug for ServiceResponse { +impl From> for Response { + fn from(res: ServiceResponse) -> Response { + res.response.into() + } +} + +impl fmt::Debug for ServiceResponse +where + B: MessageBody, + B::Error: Into, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( f, @@ -437,14 +514,14 @@ impl fmt::Debug for ServiceResponse { } pub struct WebService { - rdef: Vec, + rdef: Patterns, name: Option, guards: Vec>, } impl WebService { /// Create new `WebService` instance. - pub fn new(path: T) -> Self { + pub fn new(path: T) -> Self { WebService { rdef: path.patterns(), name: None, @@ -454,7 +531,7 @@ impl WebService { /// Set service name. /// - /// Name is used for url generation. + /// Name is used for URL generation. pub fn name(mut self, name: &str) -> Self { self.name = Some(name.to_string()); self @@ -462,7 +539,7 @@ impl WebService { /// Add match guard to a web service. /// - /// ```rust + /// ``` /// use actix_web::{web, guard, dev, App, Error, HttpResponse}; /// /// async fn index(req: dev::ServiceRequest) -> Result { @@ -506,7 +583,7 @@ impl WebService { struct WebServiceImpl { srv: T, - rdef: Vec, + rdef: Patterns, name: Option, guards: Vec>, } @@ -529,13 +606,15 @@ where }; let mut rdef = if config.is_root() || !self.rdef.is_empty() { - ResourceDef::new(insert_slash(self.rdef)) + ResourceDef::new(ensure_leading_slash(self.rdef)) } else { ResourceDef::new(self.rdef) }; + if let Some(ref name) = self.name { - *rdef.name_mut() = name.clone(); + rdef.set_name(name); } + config.register_service(rdef, guards, self.srv, None) } } @@ -547,7 +626,6 @@ where /// The max number of services can be grouped together is 12. /// /// # Examples -/// /// ``` /// use actix_web::{services, web, App}; /// @@ -573,31 +651,28 @@ macro_rules! services { } /// HttpServiceFactory trait impl for tuples -macro_rules! service_tuple ({ $(($n:tt, $T:ident)),+} => { +macro_rules! service_tuple ({ $($T:ident)+ } => { impl<$($T: HttpServiceFactory),+> HttpServiceFactory for ($($T,)+) { + #[allow(non_snake_case)] fn register(self, config: &mut AppService) { - $(self.$n.register(config);)+ + let ($($T,)*) = self; + $($T.register(config);)+ } } }); -#[rustfmt::skip] -mod m { - use super::*; - - service_tuple!((0, A)); - service_tuple!((0, A), (1, B)); - service_tuple!((0, A), (1, B), (2, C)); - service_tuple!((0, A), (1, B), (2, C), (3, D)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J), (10, K)); - service_tuple!((0, A), (1, B), (2, C), (3, D), (4, E), (5, F), (6, G), (7, H), (8, I), (9, J), (10, K), (11, L)); -} +service_tuple! { A } +service_tuple! { A B } +service_tuple! { A B C } +service_tuple! { A B C D } +service_tuple! { A B C D E } +service_tuple! { A B C D E F } +service_tuple! { A B C D E F G } +service_tuple! { A B C D E F G H } +service_tuple! { A B C D E F G H I } +service_tuple! { A B C D E F G H I J } +service_tuple! { A B C D E F G H I J K } +service_tuple! { A B C D E F G H I J K L } #[cfg(test)] mod tests { @@ -605,7 +680,7 @@ mod tests { use crate::test::{init_service, TestRequest}; use crate::{guard, http, web, App, HttpResponse}; use actix_service::Service; - use futures_util::future::ok; + use actix_utils::future::ok; #[actix_rt::test] async fn test_service() { @@ -632,6 +707,8 @@ mod tests { assert_eq!(resp.status(), http::StatusCode::NOT_FOUND); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_service_data() { let srv = diff --git a/src/test.rs b/src/test.rs deleted file mode 100644 index 2ec4252b1..000000000 --- a/src/test.rs +++ /dev/null @@ -1,1222 +0,0 @@ -//! Various helpers for Actix applications to use during testing. - -use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::mpsc; -use std::{fmt, net, thread, time}; - -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -#[cfg(feature = "cookies")] -use actix_http::cookie::Cookie; -use actix_http::http::header::{ContentType, IntoHeaderPair}; -use actix_http::http::{Method, StatusCode, Uri, Version}; -use actix_http::test::TestRequest as HttpTestRequest; -use actix_http::{ws, Extensions, HttpService, Request}; -use actix_router::{Path, ResourceDef, Url}; -use actix_rt::{time::sleep, System}; -use actix_service::{map_config, IntoService, IntoServiceFactory, Service, ServiceFactory}; -use awc::error::PayloadError; -use awc::{Client, ClientRequest, ClientResponse, Connector}; -use bytes::{Bytes, BytesMut}; -use futures_core::Stream; -use futures_util::future::ok; -use futures_util::StreamExt; -use serde::de::DeserializeOwned; -use serde::Serialize; -use socket2::{Domain, Protocol, Socket, Type}; - -pub use actix_http::test::TestBuffer; - -use crate::app_service::AppInitServiceState; -use crate::config::AppConfig; -use crate::data::Data; -use crate::dev::{Body, MessageBody, Payload, Server}; -use crate::rmap::ResourceMap; -use crate::service::{ServiceRequest, ServiceResponse}; -use crate::{Error, HttpRequest, HttpResponse}; - -/// Create service that always responds with `HttpResponse::Ok()` -pub fn ok_service( -) -> impl Service, Error = Error> { - default_service(StatusCode::OK) -} - -/// Create service that responds with response with specified status code -pub fn default_service( - status_code: StatusCode, -) -> impl Service, Error = Error> { - (move |req: ServiceRequest| { - ok(req.into_response(HttpResponse::build(status_code).finish())) - }) - .into_service() -} - -/// This method accepts application builder instance, and constructs -/// service. -/// -/// ```rust -/// use actix_service::Service; -/// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; -/// -/// #[actix_rt::test] -/// async fn test_init_service() { -/// let app = test::init_service( -/// App::new() -/// .service(web::resource("/test").to(|| async { HttpResponse::Ok() })) -/// ).await; -/// -/// // Create request object -/// let req = test::TestRequest::with_uri("/test").to_request(); -/// -/// // Execute application -/// let resp = app.call(req).await.unwrap(); -/// assert_eq!(resp.status(), StatusCode::OK); -/// } -/// ``` -pub async fn init_service( - app: R, -) -> impl Service, Error = E> -where - R: IntoServiceFactory, - S: ServiceFactory, Error = E>, - S::InitError: std::fmt::Debug, -{ - try_init_service(app) - .await - .expect("service initilization failed") -} - -/// Fallible version of init_service that allows testing data factory errors. -pub(crate) async fn try_init_service( - app: R, -) -> Result, Error = E>, S::InitError> -where - R: IntoServiceFactory, - S: ServiceFactory, Error = E>, - S::InitError: std::fmt::Debug, -{ - let srv = app.into_factory(); - srv.new_service(AppConfig::default()).await -} - -/// Calls service and waits for response future completion. -/// -/// ```rust -/// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; -/// -/// #[actix_rt::test] -/// async fn test_response() { -/// let app = test::init_service( -/// App::new() -/// .service(web::resource("/test").to(|| async { -/// HttpResponse::Ok() -/// })) -/// ).await; -/// -/// // Create request object -/// let req = test::TestRequest::with_uri("/test").to_request(); -/// -/// // Call application -/// let resp = test::call_service(&app, req).await; -/// assert_eq!(resp.status(), StatusCode::OK); -/// } -/// ``` -pub async fn call_service(app: &S, req: R) -> S::Response -where - S: Service, Error = E>, - E: std::fmt::Debug, -{ - app.call(req).await.unwrap() -} - -/// Helper function that returns a response body of a TestRequest -/// -/// ```rust -/// use actix_web::{test, web, App, HttpResponse, http::header}; -/// use bytes::Bytes; -/// -/// #[actix_rt::test] -/// async fn test_index() { -/// let app = test::init_service( -/// App::new().service( -/// web::resource("/index.html") -/// .route(web::post().to(|| async { -/// HttpResponse::Ok().body("welcome!") -/// }))) -/// ).await; -/// -/// let req = test::TestRequest::post() -/// .uri("/index.html") -/// .header(header::CONTENT_TYPE, "application/json") -/// .to_request(); -/// -/// let result = test::read_response(&app, req).await; -/// assert_eq!(result, Bytes::from_static(b"welcome!")); -/// } -/// ``` -pub async fn read_response(app: &S, req: Request) -> Bytes -where - S: Service, Error = Error>, - B: MessageBody + Unpin, -{ - let mut resp = app - .call(req) - .await - .unwrap_or_else(|_| panic!("read_response failed at application call")); - - let mut body = resp.take_body(); - let mut bytes = BytesMut::new(); - while let Some(item) = body.next().await { - bytes.extend_from_slice(&item.unwrap()); - } - bytes.freeze() -} - -/// Helper function that returns a response body of a ServiceResponse. -/// -/// ```rust -/// use actix_web::{test, web, App, HttpResponse, http::header}; -/// use bytes::Bytes; -/// -/// #[actix_rt::test] -/// async fn test_index() { -/// let app = test::init_service( -/// App::new().service( -/// web::resource("/index.html") -/// .route(web::post().to(|| async { -/// HttpResponse::Ok().body("welcome!") -/// }))) -/// ).await; -/// -/// let req = test::TestRequest::post() -/// .uri("/index.html") -/// .header(header::CONTENT_TYPE, "application/json") -/// .to_request(); -/// -/// let resp = test::call_service(&app, req).await; -/// let result = test::read_body(resp).await; -/// assert_eq!(result, Bytes::from_static(b"welcome!")); -/// } -/// ``` -pub async fn read_body(mut res: ServiceResponse) -> Bytes -where - B: MessageBody + Unpin, -{ - let mut body = res.take_body(); - let mut bytes = BytesMut::new(); - while let Some(item) = body.next().await { - bytes.extend_from_slice(&item.unwrap()); - } - bytes.freeze() -} - -/// Helper function that returns a deserialized response body of a ServiceResponse. -/// -/// ```rust -/// use actix_web::{App, test, web, HttpResponse, http::header}; -/// use serde::{Serialize, Deserialize}; -/// -/// #[derive(Serialize, Deserialize)] -/// pub struct Person { -/// id: String, -/// name: String, -/// } -/// -/// #[actix_rt::test] -/// async fn test_post_person() { -/// let app = test::init_service( -/// App::new().service( -/// web::resource("/people") -/// .route(web::post().to(|person: web::Json| async { -/// HttpResponse::Ok() -/// .json(person)}) -/// )) -/// ).await; -/// -/// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); -/// -/// let resp = test::TestRequest::post() -/// .uri("/people") -/// .header(header::CONTENT_TYPE, "application/json") -/// .set_payload(payload) -/// .send_request(&mut app) -/// .await; -/// -/// assert!(resp.status().is_success()); -/// -/// let result: Person = test::read_body_json(resp).await; -/// } -/// ``` -pub async fn read_body_json(res: ServiceResponse) -> T -where - B: MessageBody + Unpin, - T: DeserializeOwned, -{ - let body = read_body(res).await; - - serde_json::from_slice(&body) - .unwrap_or_else(|e| panic!("read_response_json failed during deserialization: {}", e)) -} - -pub async fn load_stream(mut stream: S) -> Result -where - S: Stream> + Unpin, -{ - let mut data = BytesMut::new(); - while let Some(item) = stream.next().await { - data.extend_from_slice(&item?); - } - Ok(data.freeze()) -} - -/// Helper function that returns a deserialized response body of a TestRequest -/// -/// ```rust -/// use actix_web::{App, test, web, HttpResponse, http::header}; -/// use serde::{Serialize, Deserialize}; -/// -/// #[derive(Serialize, Deserialize)] -/// pub struct Person { -/// id: String, -/// name: String -/// } -/// -/// #[actix_rt::test] -/// async fn test_add_person() { -/// let app = test::init_service( -/// App::new().service( -/// web::resource("/people") -/// .route(web::post().to(|person: web::Json| async { -/// HttpResponse::Ok() -/// .json(person)}) -/// )) -/// ).await; -/// -/// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); -/// -/// let req = test::TestRequest::post() -/// .uri("/people") -/// .header(header::CONTENT_TYPE, "application/json") -/// .set_payload(payload) -/// .to_request(); -/// -/// let result: Person = test::read_response_json(&mut app, req).await; -/// } -/// ``` -pub async fn read_response_json(app: &S, req: Request) -> T -where - S: Service, Error = Error>, - B: MessageBody + Unpin, - T: DeserializeOwned, -{ - let body = read_response(app, req).await; - - serde_json::from_slice(&body) - .unwrap_or_else(|_| panic!("read_response_json failed during deserialization")) -} - -/// Test `Request` builder. -/// -/// For unit testing, actix provides a request builder type and a simple handler runner. TestRequest implements a builder-like pattern. -/// You can generate various types of request via TestRequest's methods: -/// * `TestRequest::to_request` creates `actix_http::Request` instance. -/// * `TestRequest::to_srv_request` creates `ServiceRequest` instance, which is used for testing middlewares and chain adapters. -/// * `TestRequest::to_srv_response` creates `ServiceResponse` instance. -/// * `TestRequest::to_http_request` creates `HttpRequest` instance, which is used for testing handlers. -/// -/// ```rust -/// use actix_web::{test, HttpRequest, HttpResponse, HttpMessage}; -/// use actix_web::http::{header, StatusCode}; -/// -/// async fn index(req: HttpRequest) -> HttpResponse { -/// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { -/// HttpResponse::Ok().into() -/// } else { -/// HttpResponse::BadRequest().into() -/// } -/// } -/// -/// #[test] -/// fn test_index() { -/// let req = test::TestRequest::default().insert_header("content-type", "text/plain") -/// .to_http_request(); -/// -/// let resp = index(req).await.unwrap(); -/// assert_eq!(resp.status(), StatusCode::OK); -/// -/// let req = test::TestRequest::default().to_http_request(); -/// let resp = index(req).await.unwrap(); -/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); -/// } -/// ``` -pub struct TestRequest { - req: HttpTestRequest, - rmap: ResourceMap, - config: AppConfig, - path: Path, - peer_addr: Option, - app_data: Extensions, -} - -impl Default for TestRequest { - fn default() -> TestRequest { - TestRequest { - req: HttpTestRequest::default(), - rmap: ResourceMap::new(ResourceDef::new("")), - config: AppConfig::default(), - path: Path::new(Url::new(Uri::default())), - peer_addr: None, - app_data: Extensions::new(), - } - } -} - -#[allow(clippy::wrong_self_convention)] -impl TestRequest { - /// Create TestRequest and set request uri - pub fn with_uri(path: &str) -> TestRequest { - TestRequest::default().uri(path) - } - - /// Create TestRequest and set method to `Method::GET` - pub fn get() -> TestRequest { - TestRequest::default().method(Method::GET) - } - - /// Create TestRequest and set method to `Method::POST` - pub fn post() -> TestRequest { - TestRequest::default().method(Method::POST) - } - - /// Create TestRequest and set method to `Method::PUT` - pub fn put() -> TestRequest { - TestRequest::default().method(Method::PUT) - } - - /// Create TestRequest and set method to `Method::PATCH` - pub fn patch() -> TestRequest { - TestRequest::default().method(Method::PATCH) - } - - /// Create TestRequest and set method to `Method::DELETE` - pub fn delete() -> TestRequest { - TestRequest::default().method(Method::DELETE) - } - - /// Set HTTP version of this request - pub fn version(mut self, ver: Version) -> Self { - self.req.version(ver); - self - } - - /// Set HTTP method of this request - pub fn method(mut self, meth: Method) -> Self { - self.req.method(meth); - self - } - - /// Set HTTP Uri of this request - pub fn uri(mut self, path: &str) -> Self { - self.req.uri(path); - self - } - - /// Insert a header, replacing any that were set with an equivalent field name. - pub fn insert_header(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - self.req.insert_header(header); - self - } - - /// Append a header, keeping any that were set with an equivalent field name. - pub fn append_header(mut self, header: H) -> Self - where - H: IntoHeaderPair, - { - self.req.append_header(header); - self - } - - /// Set cookie for this request. - #[cfg(feature = "cookies")] - pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { - self.req.cookie(cookie); - self - } - - /// Set request path pattern parameter - pub fn param(mut self, name: &'static str, value: &'static str) -> Self { - self.path.add_static(name, value); - self - } - - /// Set peer addr - pub fn peer_addr(mut self, addr: SocketAddr) -> Self { - self.peer_addr = Some(addr); - self - } - - /// Set request payload - pub fn set_payload>(mut self, data: B) -> Self { - self.req.set_payload(data); - self - } - - /// Serialize `data` to a URL encoded form and set it as the request payload. The `Content-Type` - /// header is set to `application/x-www-form-urlencoded`. - pub fn set_form(mut self, data: &T) -> Self { - let bytes = serde_urlencoded::to_string(data) - .expect("Failed to serialize test data as a urlencoded form"); - self.req.set_payload(bytes); - self.req.insert_header(ContentType::form_url_encoded()); - self - } - - /// Serialize `data` to JSON and set it as the request payload. The `Content-Type` header is - /// set to `application/json`. - pub fn set_json(mut self, data: &T) -> Self { - let bytes = serde_json::to_string(data).expect("Failed to serialize test data to json"); - self.req.set_payload(bytes); - self.req.insert_header(ContentType::json()); - self - } - - /// Set application data. This is equivalent of `App::data()` method - /// for testing purpose. - pub fn data(mut self, data: T) -> Self { - self.app_data.insert(Data::new(data)); - self - } - - /// Set application data. This is equivalent of `App::app_data()` method - /// for testing purpose. - pub fn app_data(mut self, data: T) -> Self { - self.app_data.insert(data); - self - } - - #[cfg(test)] - /// Set request config - pub(crate) fn rmap(mut self, rmap: ResourceMap) -> Self { - self.rmap = rmap; - self - } - - /// Complete request creation and generate `Request` instance - pub fn to_request(mut self) -> Request { - let mut req = self.req.finish(); - req.head_mut().peer_addr = self.peer_addr; - req - } - - /// Complete request creation and generate `ServiceRequest` instance - pub fn to_srv_request(mut self) -> ServiceRequest { - let (mut head, payload) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; - self.path.get_mut().update(&head.uri); - - let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); - - ServiceRequest::new( - HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)), - payload, - ) - } - - /// Complete request creation and generate `ServiceResponse` instance - pub fn to_srv_response(self, res: HttpResponse) -> ServiceResponse { - self.to_srv_request().into_response(res) - } - - /// Complete request creation and generate `HttpRequest` instance - pub fn to_http_request(mut self) -> HttpRequest { - let (mut head, _) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; - self.path.get_mut().update(&head.uri); - - let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); - - HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)) - } - - /// Complete request creation and generate `HttpRequest` and `Payload` instances - pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { - let (mut head, payload) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; - self.path.get_mut().update(&head.uri); - - let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); - - let req = HttpRequest::new(self.path, head, app_state, Rc::new(self.app_data)); - - (req, payload) - } - - /// Complete request creation, calls service and waits for response future completion. - pub async fn send_request(self, app: &S) -> S::Response - where - S: Service, Error = E>, - E: std::fmt::Debug, - { - let req = self.to_request(); - call_service(app, req).await - } -} - -/// Start test server with default configuration -/// -/// Test server is very simple server that simplify process of writing -/// integration tests cases for actix web applications. -/// -/// # Examples -/// -/// ```rust -/// use actix_web::{web, test, App, HttpResponse, Error}; -/// -/// async fn my_handler() -> Result { -/// Ok(HttpResponse::Ok().into()) -/// } -/// -/// #[actix_rt::test] -/// async fn test_example() { -/// let srv = test::start( -/// || App::new().service( -/// web::resource("/").to(my_handler)) -/// ); -/// -/// let req = srv.get("/"); -/// let response = req.send().await.unwrap(); -/// assert!(response.status().is_success()); -/// } -/// ``` -pub fn start(factory: F) -> TestServer -where - F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory + 'static, - S::Error: Into + 'static, - S::InitError: fmt::Debug, - S::Response: Into> + 'static, - >::Future: 'static, - B: MessageBody + 'static, -{ - start_with(TestServerConfig::default(), factory) -} - -/// Start test server with custom configuration -/// -/// Test server could be configured in different ways, for details check -/// `TestServerConfig` docs. -/// -/// # Examples -/// -/// ```rust -/// use actix_web::{web, test, App, HttpResponse, Error}; -/// -/// async fn my_handler() -> Result { -/// Ok(HttpResponse::Ok().into()) -/// } -/// -/// #[actix_rt::test] -/// async fn test_example() { -/// let srv = test::start_with(test::config().h1(), || -/// App::new().service(web::resource("/").to(my_handler)) -/// ); -/// -/// let req = srv.get("/"); -/// let response = req.send().await.unwrap(); -/// assert!(response.status().is_success()); -/// } -/// ``` -pub fn start_with(cfg: TestServerConfig, factory: F) -> TestServer -where - F: Fn() -> I + Send + Clone + 'static, - I: IntoServiceFactory, - S: ServiceFactory + 'static, - S::Error: Into + 'static, - S::InitError: fmt::Debug, - S::Response: Into> + 'static, - >::Future: 'static, - B: MessageBody + 'static, -{ - let (tx, rx) = mpsc::channel(); - - let ssl = match cfg.stream { - StreamType::Tcp => false, - #[cfg(feature = "openssl")] - StreamType::Openssl(_) => true, - #[cfg(feature = "rustls")] - StreamType::Rustls(_) => true, - }; - - // run server in separate thread - thread::spawn(move || { - let sys = System::new(); - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); - let factory = factory.clone(); - let cfg = cfg.clone(); - let ctimeout = cfg.client_timeout; - let builder = Server::build().workers(1).disable_signals(); - - let srv = match cfg.stream { - StreamType::Tcp => match cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h1(map_config(factory(), move |_| cfg.clone())) - .tcp() - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h2(map_config(factory(), move |_| cfg.clone())) - .tcp() - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(false, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .finish(map_config(factory(), move |_| cfg.clone())) - .tcp() - }), - }, - #[cfg(feature = "openssl")] - StreamType::Openssl(acceptor) => match cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h1(map_config(factory(), move |_| cfg.clone())) - .openssl(acceptor.clone()) - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h2(map_config(factory(), move |_| cfg.clone())) - .openssl(acceptor.clone()) - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .finish(map_config(factory(), move |_| cfg.clone())) - .openssl(acceptor.clone()) - }), - }, - #[cfg(feature = "rustls")] - StreamType::Rustls(config) => match cfg.tp { - HttpVer::Http1 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h1(map_config(factory(), move |_| cfg.clone())) - .rustls(config.clone()) - }), - HttpVer::Http2 => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .h2(map_config(factory(), move |_| cfg.clone())) - .rustls(config.clone()) - }), - HttpVer::Both => builder.listen("test", tcp, move || { - let cfg = AppConfig::new(true, local_addr, format!("{}", local_addr)); - HttpService::build() - .client_timeout(ctimeout) - .finish(map_config(factory(), move |_| cfg.clone())) - .rustls(config.clone()) - }), - }, - } - .unwrap(); - - sys.block_on(async { - let srv = srv.run(); - tx.send((System::current(), srv, local_addr)).unwrap(); - }); - - sys.run() - }); - - let (system, server, addr) = rx.recv().unwrap(); - - let client = { - let connector = { - #[cfg(feature = "openssl")] - { - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - let _ = builder - .set_alpn_protos(b"\x02h2\x08http/1.1") - .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - .ssl(builder.build()) - .finish() - } - #[cfg(not(feature = "openssl"))] - { - Connector::new() - .conn_lifetime(time::Duration::from_secs(0)) - .timeout(time::Duration::from_millis(30000)) - .finish() - } - }; - - Client::builder().connector(connector).finish() - }; - - TestServer { - ssl, - addr, - client, - system, - server, - } -} - -#[derive(Clone)] -pub struct TestServerConfig { - tp: HttpVer, - stream: StreamType, - client_timeout: u64, -} - -#[derive(Clone)] -enum HttpVer { - Http1, - Http2, - Both, -} - -#[derive(Clone)] -enum StreamType { - Tcp, - #[cfg(feature = "openssl")] - Openssl(openssl::ssl::SslAcceptor), - #[cfg(feature = "rustls")] - Rustls(rustls::ServerConfig), -} - -impl Default for TestServerConfig { - fn default() -> Self { - TestServerConfig::new() - } -} - -/// Create default test server config -pub fn config() -> TestServerConfig { - TestServerConfig::new() -} - -impl TestServerConfig { - /// Create default server configuration - pub(crate) fn new() -> TestServerConfig { - TestServerConfig { - tp: HttpVer::Both, - stream: StreamType::Tcp, - client_timeout: 5000, - } - } - - /// Start HTTP/1.1 server only - pub fn h1(mut self) -> Self { - self.tp = HttpVer::Http1; - self - } - - /// Start HTTP/2 server only - pub fn h2(mut self) -> Self { - self.tp = HttpVer::Http2; - self - } - - /// Start openssl server - #[cfg(feature = "openssl")] - pub fn openssl(mut self, acceptor: openssl::ssl::SslAcceptor) -> Self { - self.stream = StreamType::Openssl(acceptor); - self - } - - /// Start rustls server - #[cfg(feature = "rustls")] - pub fn rustls(mut self, config: rustls::ServerConfig) -> Self { - self.stream = StreamType::Rustls(config); - self - } - - /// Set server client timeout in milliseconds for first request. - pub fn client_timeout(mut self, val: u64) -> Self { - self.client_timeout = val; - self - } -} - -/// Get first available unused address -pub fn unused_addr() -> net::SocketAddr { - let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); - let socket = Socket::new(Domain::ipv4(), Type::stream(), Some(Protocol::tcp())).unwrap(); - socket.bind(&addr.into()).unwrap(); - socket.set_reuse_address(true).unwrap(); - let tcp = socket.into_tcp_listener(); - tcp.local_addr().unwrap() -} - -/// Test server controller -pub struct TestServer { - addr: net::SocketAddr, - client: awc::Client, - system: actix_rt::System, - ssl: bool, - server: Server, -} - -impl TestServer { - /// Construct test server url - pub fn addr(&self) -> net::SocketAddr { - self.addr - } - - /// Construct test server url - pub fn url(&self, uri: &str) -> String { - let scheme = if self.ssl { "https" } else { "http" }; - - if uri.starts_with('/') { - format!("{}://localhost:{}{}", scheme, self.addr.port(), uri) - } else { - format!("{}://localhost:{}/{}", scheme, self.addr.port(), uri) - } - } - - /// Create `GET` request - pub fn get>(&self, path: S) -> ClientRequest { - self.client.get(self.url(path.as_ref()).as_str()) - } - - /// Create `POST` request - pub fn post>(&self, path: S) -> ClientRequest { - self.client.post(self.url(path.as_ref()).as_str()) - } - - /// Create `HEAD` request - pub fn head>(&self, path: S) -> ClientRequest { - self.client.head(self.url(path.as_ref()).as_str()) - } - - /// Create `PUT` request - pub fn put>(&self, path: S) -> ClientRequest { - self.client.put(self.url(path.as_ref()).as_str()) - } - - /// Create `PATCH` request - pub fn patch>(&self, path: S) -> ClientRequest { - self.client.patch(self.url(path.as_ref()).as_str()) - } - - /// Create `DELETE` request - pub fn delete>(&self, path: S) -> ClientRequest { - self.client.delete(self.url(path.as_ref()).as_str()) - } - - /// Create `OPTIONS` request - pub fn options>(&self, path: S) -> ClientRequest { - self.client.options(self.url(path.as_ref()).as_str()) - } - - /// Connect to test HTTP server - pub fn request>(&self, method: Method, path: S) -> ClientRequest { - self.client.request(method, path.as_ref()) - } - - pub async fn load_body( - &mut self, - mut response: ClientResponse, - ) -> Result - where - S: Stream> + Unpin + 'static, - { - response.body().limit(10_485_760).await - } - - /// Connect to WebSocket server at a given path. - pub async fn ws_at( - &mut self, - path: &str, - ) -> Result, awc::error::WsClientError> { - let url = self.url(path); - let connect = self.client.ws(url).connect(); - connect.await.map(|(_, framed)| framed) - } - - /// Connect to a WebSocket server. - pub async fn ws( - &mut self, - ) -> Result, awc::error::WsClientError> { - self.ws_at("/").await - } - - /// Gracefully stop HTTP server - pub async fn stop(self) { - self.server.stop(true).await; - self.system.stop(); - sleep(time::Duration::from_millis(100)).await; - } -} - -impl Drop for TestServer { - fn drop(&mut self) { - self.system.stop() - } -} - -#[cfg(test)] -mod tests { - use actix_http::HttpMessage; - use serde::{Deserialize, Serialize}; - use std::time::SystemTime; - - use super::*; - use crate::{http::header, web, App, HttpResponse, Responder}; - - #[actix_rt::test] - async fn test_basics() { - let req = TestRequest::default() - .version(Version::HTTP_2) - .insert_header(header::ContentType::json()) - .insert_header(header::Date(SystemTime::now().into())) - .param("test", "123") - .data(10u32) - .app_data(20u64) - .peer_addr("127.0.0.1:8081".parse().unwrap()) - .to_http_request(); - assert!(req.headers().contains_key(header::CONTENT_TYPE)); - assert!(req.headers().contains_key(header::DATE)); - assert_eq!( - req.head().peer_addr, - Some("127.0.0.1:8081".parse().unwrap()) - ); - assert_eq!(&req.match_info()["test"], "123"); - assert_eq!(req.version(), Version::HTTP_2); - let data = req.app_data::>().unwrap(); - assert!(req.app_data::>().is_none()); - assert_eq!(*data.get_ref(), 10); - - assert!(req.app_data::().is_none()); - let data = req.app_data::().unwrap(); - assert_eq!(*data, 20); - } - - #[actix_rt::test] - async fn test_request_methods() { - let app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::put().to(|| HttpResponse::Ok().body("put!"))) - .route(web::patch().to(|| HttpResponse::Ok().body("patch!"))) - .route(web::delete().to(|| HttpResponse::Ok().body("delete!"))), - ), - ) - .await; - - let put_req = TestRequest::put() - .uri("/index.html") - .insert_header((header::CONTENT_TYPE, "application/json")) - .to_request(); - - let result = read_response(&app, put_req).await; - assert_eq!(result, Bytes::from_static(b"put!")); - - let patch_req = TestRequest::patch() - .uri("/index.html") - .insert_header((header::CONTENT_TYPE, "application/json")) - .to_request(); - - let result = read_response(&app, patch_req).await; - assert_eq!(result, Bytes::from_static(b"patch!")); - - let delete_req = TestRequest::delete().uri("/index.html").to_request(); - let result = read_response(&app, delete_req).await; - assert_eq!(result, Bytes::from_static(b"delete!")); - } - - #[actix_rt::test] - async fn test_response() { - let app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::post().to(|| HttpResponse::Ok().body("welcome!"))), - ), - ) - .await; - - let req = TestRequest::post() - .uri("/index.html") - .insert_header((header::CONTENT_TYPE, "application/json")) - .to_request(); - - let result = read_response(&app, req).await; - assert_eq!(result, Bytes::from_static(b"welcome!")); - } - - #[actix_rt::test] - async fn test_send_request() { - let app = init_service( - App::new().service( - web::resource("/index.html") - .route(web::get().to(|| HttpResponse::Ok().body("welcome!"))), - ), - ) - .await; - - let resp = TestRequest::get() - .uri("/index.html") - .send_request(&app) - .await; - - let result = read_body(resp).await; - assert_eq!(result, Bytes::from_static(b"welcome!")); - } - - #[derive(Serialize, Deserialize)] - pub struct Person { - id: String, - name: String, - } - - #[actix_rt::test] - async fn test_response_json() { - let app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), - ))) - .await; - - let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); - - let req = TestRequest::post() - .uri("/people") - .insert_header((header::CONTENT_TYPE, "application/json")) - .set_payload(payload) - .to_request(); - - let result: Person = read_response_json(&app, req).await; - assert_eq!(&result.id, "12345"); - } - - #[actix_rt::test] - async fn test_body_json() { - let app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), - ))) - .await; - - let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); - - let resp = TestRequest::post() - .uri("/people") - .insert_header((header::CONTENT_TYPE, "application/json")) - .set_payload(payload) - .send_request(&app) - .await; - - let result: Person = read_body_json(resp).await; - assert_eq!(&result.name, "User name"); - } - - #[actix_rt::test] - async fn test_request_response_form() { - let app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Form| HttpResponse::Ok().json(person)), - ))) - .await; - - let payload = Person { - id: "12345".to_string(), - name: "User name".to_string(), - }; - - let req = TestRequest::post() - .uri("/people") - .set_form(&payload) - .to_request(); - - assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); - - let result: Person = read_response_json(&app, req).await; - assert_eq!(&result.id, "12345"); - assert_eq!(&result.name, "User name"); - } - - #[actix_rt::test] - async fn test_request_response_json() { - let app = init_service(App::new().service(web::resource("/people").route( - web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), - ))) - .await; - - let payload = Person { - id: "12345".to_string(), - name: "User name".to_string(), - }; - - let req = TestRequest::post() - .uri("/people") - .set_json(&payload) - .to_request(); - - assert_eq!(req.content_type(), "application/json"); - - let result: Person = read_response_json(&app, req).await; - assert_eq!(&result.id, "12345"); - assert_eq!(&result.name, "User name"); - } - - #[actix_rt::test] - async fn test_async_with_block() { - async fn async_with_block() -> Result { - let res = web::block(move || Some(4usize).ok_or("wrong")).await; - - match res { - Ok(value) => Ok(HttpResponse::Ok() - .content_type("text/plain") - .body(format!("Async with block value: {:?}", value))), - Err(_) => panic!("Unexpected"), - } - } - - let app = - init_service(App::new().service(web::resource("/index.html").to(async_with_block))) - .await; - - let req = TestRequest::post().uri("/index.html").to_request(); - let res = app.call(req).await.unwrap(); - assert!(res.status().is_success()); - } - - #[actix_rt::test] - async fn test_server_data() { - async fn handler(data: web::Data) -> impl Responder { - assert_eq!(**data, 10); - HttpResponse::Ok() - } - - let app = init_service( - App::new() - .data(10usize) - .service(web::resource("/index.html").to(handler)), - ) - .await; - - let req = TestRequest::post().uri("/index.html").to_request(); - let res = app.call(req).await.unwrap(); - assert!(res.status().is_success()); - } -} diff --git a/src/test/mod.rs b/src/test/mod.rs new file mode 100644 index 000000000..a29dfc437 --- /dev/null +++ b/src/test/mod.rs @@ -0,0 +1,81 @@ +//! Various helpers for Actix applications to use during testing. +//! +//! # Creating A Test Service +//! - [`init_service`] +//! +//! # Off-The-Shelf Test Services +//! - [`ok_service`] +//! - [`simple_service`] +//! +//! # Calling Test Service +//! - [`TestRequest`] +//! - [`call_service`] +//! - [`call_and_read_body`] +//! - [`call_and_read_body_json`] +//! +//! # Reading Response Payloads +//! - [`read_body`] +//! - [`read_body_json`] + +// TODO: more docs on generally how testing works with these parts + +pub use actix_http::test::TestBuffer; + +mod test_request; +mod test_services; +mod test_utils; + +pub use self::test_request::TestRequest; +#[allow(deprecated)] +pub use self::test_services::{default_service, ok_service, simple_service}; +#[allow(deprecated)] +pub use self::test_utils::{ + call_and_read_body, call_and_read_body_json, call_service, init_service, read_body, + read_body_json, read_response, read_response_json, +}; + +#[cfg(test)] +pub(crate) use self::test_utils::try_init_service; + +/// Reduces boilerplate code when testing expected response payloads. +/// +/// Must be used inside an async test. Works for both `ServiceRequest` and `HttpRequest`. +/// +/// # Examples +/// ``` +/// use actix_web::{http::StatusCode, HttpResponse}; +/// +/// let res = HttpResponse::with_body(StatusCode::OK, "http response"); +/// assert_body_eq!(res, b"http response"); +/// ``` +#[cfg(test)] +macro_rules! assert_body_eq { + ($res:ident, $expected:expr) => { + assert_eq!( + ::actix_http::body::to_bytes($res.into_body()) + .await + .expect("error reading test response body"), + ::bytes::Bytes::from_static($expected), + ) + }; +} + +#[cfg(test)] +pub(crate) use assert_body_eq; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{http::StatusCode, service::ServiceResponse, HttpResponse}; + + #[actix_rt::test] + async fn assert_body_works_for_service_and_regular_response() { + let res = HttpResponse::with_body(StatusCode::OK, "http response"); + assert_body_eq!(res, b"http response"); + + let req = TestRequest::default().to_http_request(); + let res = HttpResponse::with_body(StatusCode::OK, "service response"); + let res = ServiceResponse::new(req, res); + assert_body_eq!(res, b"service response"); + } +} diff --git a/src/test/test_request.rs b/src/test/test_request.rs new file mode 100644 index 000000000..fc42253d7 --- /dev/null +++ b/src/test/test_request.rs @@ -0,0 +1,434 @@ +use std::{borrow::Cow, net::SocketAddr, rc::Rc}; + +use actix_http::{test::TestRequest as HttpTestRequest, Request}; +use serde::Serialize; + +use crate::{ + app_service::AppInitServiceState, + config::AppConfig, + data::Data, + dev::{Extensions, Path, Payload, ResourceDef, Service, Url}, + http::header::ContentType, + http::{header::TryIntoHeaderPair, Method, Uri, Version}, + rmap::ResourceMap, + service::{ServiceRequest, ServiceResponse}, + test, + web::Bytes, + HttpRequest, HttpResponse, +}; + +#[cfg(feature = "cookies")] +use crate::cookie::{Cookie, CookieJar}; + +/// Test `Request` builder. +/// +/// For unit testing, actix provides a request builder type and a simple handler runner. TestRequest implements a builder-like pattern. +/// You can generate various types of request via TestRequest's methods: +/// * `TestRequest::to_request` creates `actix_http::Request` instance. +/// * `TestRequest::to_srv_request` creates `ServiceRequest` instance, which is used for testing middlewares and chain adapters. +/// * `TestRequest::to_srv_response` creates `ServiceResponse` instance. +/// * `TestRequest::to_http_request` creates `HttpRequest` instance, which is used for testing handlers. +/// +/// ``` +/// use actix_web::{test, HttpRequest, HttpResponse, HttpMessage}; +/// use actix_web::http::{header, StatusCode}; +/// +/// async fn index(req: HttpRequest) -> HttpResponse { +/// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { +/// HttpResponse::Ok().into() +/// } else { +/// HttpResponse::BadRequest().into() +/// } +/// } +/// +/// #[actix_web::test] +/// async fn test_index() { +/// let req = test::TestRequest::default().insert_header("content-type", "text/plain") +/// .to_http_request(); +/// +/// let resp = index(req).await.unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// let req = test::TestRequest::default().to_http_request(); +/// let resp = index(req).await.unwrap(); +/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +/// } +/// ``` +pub struct TestRequest { + req: HttpTestRequest, + rmap: ResourceMap, + config: AppConfig, + path: Path, + peer_addr: Option, + app_data: Extensions, + #[cfg(feature = "cookies")] + cookies: CookieJar, +} + +impl Default for TestRequest { + fn default() -> TestRequest { + TestRequest { + req: HttpTestRequest::default(), + rmap: ResourceMap::new(ResourceDef::new("")), + config: AppConfig::default(), + path: Path::new(Url::new(Uri::default())), + peer_addr: None, + app_data: Extensions::new(), + #[cfg(feature = "cookies")] + cookies: CookieJar::new(), + } + } +} + +#[allow(clippy::wrong_self_convention)] +impl TestRequest { + /// Create TestRequest and set request uri + pub fn with_uri(path: &str) -> TestRequest { + TestRequest::default().uri(path) + } + + /// Create TestRequest and set method to `Method::GET` + pub fn get() -> TestRequest { + TestRequest::default().method(Method::GET) + } + + /// Create TestRequest and set method to `Method::POST` + pub fn post() -> TestRequest { + TestRequest::default().method(Method::POST) + } + + /// Create TestRequest and set method to `Method::PUT` + pub fn put() -> TestRequest { + TestRequest::default().method(Method::PUT) + } + + /// Create TestRequest and set method to `Method::PATCH` + pub fn patch() -> TestRequest { + TestRequest::default().method(Method::PATCH) + } + + /// Create TestRequest and set method to `Method::DELETE` + pub fn delete() -> TestRequest { + TestRequest::default().method(Method::DELETE) + } + + /// Set HTTP version of this request + pub fn version(mut self, ver: Version) -> Self { + self.req.version(ver); + self + } + + /// Set HTTP method of this request + pub fn method(mut self, meth: Method) -> Self { + self.req.method(meth); + self + } + + /// Set HTTP URI of this request + pub fn uri(mut self, path: &str) -> Self { + self.req.uri(path); + self + } + + /// Insert a header, replacing any that were set with an equivalent field name. + pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self { + self.req.insert_header(header); + self + } + + /// Append a header, keeping any that were set with an equivalent field name. + pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self { + self.req.append_header(header); + self + } + + /// Set cookie for this request. + #[cfg(feature = "cookies")] + pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { + self.cookies.add(cookie.into_owned()); + self + } + + /// Set request path pattern parameter. + /// + /// # Examples + /// ``` + /// use actix_web::test::TestRequest; + /// + /// let req = TestRequest::default().param("foo", "bar"); + /// let req = TestRequest::default().param("foo".to_owned(), "bar".to_owned()); + /// ``` + pub fn param( + mut self, + name: impl Into>, + value: impl Into>, + ) -> Self { + self.path.add_static(name, value); + self + } + + /// Set peer addr. + pub fn peer_addr(mut self, addr: SocketAddr) -> Self { + self.peer_addr = Some(addr); + self + } + + /// Set request payload. + pub fn set_payload(mut self, data: impl Into) -> Self { + self.req.set_payload(data); + self + } + + /// Serialize `data` to a URL encoded form and set it as the request payload. + /// + /// The `Content-Type` header is set to `application/x-www-form-urlencoded`. + pub fn set_form(mut self, data: impl Serialize) -> Self { + let bytes = serde_urlencoded::to_string(&data) + .expect("Failed to serialize test data as a urlencoded form"); + self.req.set_payload(bytes); + self.req.insert_header(ContentType::form_url_encoded()); + self + } + + /// Serialize `data` to JSON and set it as the request payload. + /// + /// The `Content-Type` header is set to `application/json`. + pub fn set_json(mut self, data: impl Serialize) -> Self { + let bytes = + serde_json::to_string(&data).expect("Failed to serialize test data to json"); + self.req.set_payload(bytes); + self.req.insert_header(ContentType::json()); + self + } + + /// Set application data. This is equivalent of `App::data()` method + /// for testing purpose. + pub fn data(mut self, data: T) -> Self { + self.app_data.insert(Data::new(data)); + self + } + + /// Set application data. This is equivalent of `App::app_data()` method + /// for testing purpose. + pub fn app_data(mut self, data: T) -> Self { + self.app_data.insert(data); + self + } + + #[cfg(test)] + /// Set request config + pub(crate) fn rmap(mut self, rmap: ResourceMap) -> Self { + self.rmap = rmap; + self + } + + fn finish(&mut self) -> Request { + // mut used when cookie feature is enabled + #[allow(unused_mut)] + let mut req = self.req.finish(); + + #[cfg(feature = "cookies")] + { + use actix_http::header::{HeaderValue, COOKIE}; + + let cookie: String = self + .cookies + .delta() + // ensure only name=value is written to cookie header + .map(|c| c.stripped().encoded().to_string()) + .collect::>() + .join("; "); + + if !cookie.is_empty() { + req.headers_mut() + .insert(COOKIE, HeaderValue::from_str(&cookie).unwrap()); + } + } + + req + } + + /// Complete request creation and generate `Request` instance + pub fn to_request(mut self) -> Request { + let mut req = self.finish(); + req.head_mut().peer_addr = self.peer_addr; + req + } + + /// Complete request creation and generate `ServiceRequest` instance + pub fn to_srv_request(mut self) -> ServiceRequest { + let (mut head, payload) = self.finish().into_parts(); + head.peer_addr = self.peer_addr; + self.path.get_mut().update(&head.uri); + + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + ServiceRequest::new( + HttpRequest::new( + self.path, + head, + app_state, + Rc::new(self.app_data), + None, + Default::default(), + ), + payload, + ) + } + + /// Complete request creation and generate `ServiceResponse` instance + pub fn to_srv_response(self, res: HttpResponse) -> ServiceResponse { + self.to_srv_request().into_response(res) + } + + /// Complete request creation and generate `HttpRequest` instance + pub fn to_http_request(mut self) -> HttpRequest { + let (mut head, _) = self.finish().into_parts(); + head.peer_addr = self.peer_addr; + self.path.get_mut().update(&head.uri); + + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + HttpRequest::new( + self.path, + head, + app_state, + Rc::new(self.app_data), + None, + Default::default(), + ) + } + + /// Complete request creation and generate `HttpRequest` and `Payload` instances + pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { + let (mut head, payload) = self.finish().into_parts(); + head.peer_addr = self.peer_addr; + self.path.get_mut().update(&head.uri); + + let app_state = AppInitServiceState::new(Rc::new(self.rmap), self.config.clone()); + + let req = HttpRequest::new( + self.path, + head, + app_state, + Rc::new(self.app_data), + None, + Default::default(), + ); + + (req, payload) + } + + /// Complete request creation, calls service and waits for response future completion. + pub async fn send_request(self, app: &S) -> S::Response + where + S: Service, Error = E>, + E: std::fmt::Debug, + { + let req = self.to_request(); + test::call_service(app, req).await + } + + #[cfg(test)] + pub fn set_server_hostname(&mut self, host: &str) { + self.config.set_host(host) + } +} + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use super::*; + use crate::{http::header, test::init_service, web, App, Error, HttpResponse, Responder}; + + #[actix_rt::test] + async fn test_basics() { + let req = TestRequest::default() + .version(Version::HTTP_2) + .insert_header(header::ContentType::json()) + .insert_header(header::Date(SystemTime::now().into())) + .param("test", "123") + .data(10u32) + .app_data(20u64) + .peer_addr("127.0.0.1:8081".parse().unwrap()) + .to_http_request(); + assert!(req.headers().contains_key(header::CONTENT_TYPE)); + assert!(req.headers().contains_key(header::DATE)); + assert_eq!( + req.head().peer_addr, + Some("127.0.0.1:8081".parse().unwrap()) + ); + assert_eq!(&req.match_info()["test"], "123"); + assert_eq!(req.version(), Version::HTTP_2); + let data = req.app_data::>().unwrap(); + assert!(req.app_data::>().is_none()); + assert_eq!(*data.get_ref(), 10); + + assert!(req.app_data::().is_none()); + let data = req.app_data::().unwrap(); + assert_eq!(*data, 20); + } + + #[actix_rt::test] + async fn test_send_request() { + let app = init_service( + App::new().service( + web::resource("/index.html") + .route(web::get().to(|| HttpResponse::Ok().body("welcome!"))), + ), + ) + .await; + + let resp = TestRequest::get() + .uri("/index.html") + .send_request(&app) + .await; + + let result = test::read_body(resp).await; + assert_eq!(result, Bytes::from_static(b"welcome!")); + } + + #[actix_rt::test] + async fn test_async_with_block() { + async fn async_with_block() -> Result { + let res = web::block(move || Some(4usize).ok_or("wrong")).await; + + match res { + Ok(value) => Ok(HttpResponse::Ok() + .content_type("text/plain") + .body(format!("Async with block value: {:?}", value))), + Err(_) => panic!("Unexpected"), + } + } + + let app = + init_service(App::new().service(web::resource("/index.html").to(async_with_block))) + .await; + + let req = TestRequest::post().uri("/index.html").to_request(); + let res = app.call(req).await.unwrap(); + assert!(res.status().is_success()); + } + + // allow deprecated App::data + #[allow(deprecated)] + #[actix_rt::test] + async fn test_server_data() { + async fn handler(data: web::Data) -> impl Responder { + assert_eq!(**data, 10); + HttpResponse::Ok() + } + + let app = init_service( + App::new() + .data(10usize) + .service(web::resource("/index.html").to(handler)), + ) + .await; + + let req = TestRequest::post().uri("/index.html").to_request(); + let res = app.call(req).await.unwrap(); + assert!(res.status().is_success()); + } +} diff --git a/src/test/test_services.rs b/src/test/test_services.rs new file mode 100644 index 000000000..b4810cfd8 --- /dev/null +++ b/src/test/test_services.rs @@ -0,0 +1,31 @@ +use actix_utils::future::ok; + +use crate::{ + body::BoxBody, + dev::{fn_service, Service, ServiceRequest, ServiceResponse}, + http::StatusCode, + Error, HttpResponseBuilder, +}; + +/// Creates service that always responds with `200 OK` and no body. +pub fn ok_service( +) -> impl Service, Error = Error> { + simple_service(StatusCode::OK) +} + +/// Creates service that always responds with given status code and no body. +pub fn simple_service( + status_code: StatusCode, +) -> impl Service, Error = Error> { + fn_service(move |req: ServiceRequest| { + ok(req.into_response(HttpResponseBuilder::new(status_code).finish())) + }) +} + +#[doc(hidden)] +#[deprecated(since = "4.0.0", note = "Renamed to `simple_service`.")] +pub fn default_service( + status_code: StatusCode, +) -> impl Service, Error = Error> { + simple_service(status_code) +} diff --git a/src/test/test_utils.rs b/src/test/test_utils.rs new file mode 100644 index 000000000..02d4c9bf3 --- /dev/null +++ b/src/test/test_utils.rs @@ -0,0 +1,474 @@ +use std::fmt; + +use actix_http::Request; +use actix_service::IntoServiceFactory; +use serde::de::DeserializeOwned; + +use crate::{ + body::{self, MessageBody}, + config::AppConfig, + dev::{Service, ServiceFactory}, + service::ServiceResponse, + web::Bytes, + Error, +}; + +/// Initialize service from application builder instance. +/// +/// # Examples +/// ``` +/// use actix_service::Service; +/// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; +/// +/// #[actix_web::test] +/// async fn test_init_service() { +/// let app = test::init_service( +/// App::new() +/// .service(web::resource("/test").to(|| async { "OK" })) +/// ).await; +/// +/// // Create request object +/// let req = test::TestRequest::with_uri("/test").to_request(); +/// +/// // Execute application +/// let res = app.call(req).await.unwrap(); +/// assert_eq!(res.status(), StatusCode::OK); +/// } +/// ``` +/// +/// # Panics +/// Panics if service initialization returns an error. +pub async fn init_service( + app: R, +) -> impl Service, Error = E> +where + R: IntoServiceFactory, + S: ServiceFactory, Error = E>, + S::InitError: std::fmt::Debug, +{ + try_init_service(app) + .await + .expect("service initialization failed") +} + +/// Fallible version of [`init_service`] that allows testing initialization errors. +pub(crate) async fn try_init_service( + app: R, +) -> Result, Error = E>, S::InitError> +where + R: IntoServiceFactory, + S: ServiceFactory, Error = E>, + S::InitError: std::fmt::Debug, +{ + let srv = app.into_factory(); + srv.new_service(AppConfig::default()).await +} + +/// Calls service and waits for response future completion. +/// +/// # Examples +/// ``` +/// use actix_web::{test, web, App, HttpResponse, http::StatusCode}; +/// +/// #[actix_web::test] +/// async fn test_response() { +/// let app = test::init_service( +/// App::new() +/// .service(web::resource("/test").to(|| async { +/// HttpResponse::Ok() +/// })) +/// ).await; +/// +/// // Create request object +/// let req = test::TestRequest::with_uri("/test").to_request(); +/// +/// // Call application +/// let res = test::call_service(&app, req).await; +/// assert_eq!(res.status(), StatusCode::OK); +/// } +/// ``` +/// +/// # Panics +/// Panics if service call returns error. +pub async fn call_service(app: &S, req: R) -> S::Response +where + S: Service, Error = E>, + E: std::fmt::Debug, +{ + app.call(req) + .await + .expect("test service call returned error") +} + +/// Helper function that returns a response body of a TestRequest +/// +/// # Examples +/// ``` +/// use actix_web::{test, web, App, HttpResponse, http::header}; +/// use bytes::Bytes; +/// +/// #[actix_web::test] +/// async fn test_index() { +/// let app = test::init_service( +/// App::new().service( +/// web::resource("/index.html") +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; +/// +/// let req = test::TestRequest::post() +/// .uri("/index.html") +/// .header(header::CONTENT_TYPE, "application/json") +/// .to_request(); +/// +/// let result = test::call_and_read_body(&app, req).await; +/// assert_eq!(result, Bytes::from_static(b"welcome!")); +/// } +/// ``` +/// +/// # Panics +/// Panics if: +/// - service call returns error; +/// - body yields an error while it is being read. +pub async fn call_and_read_body(app: &S, req: Request) -> Bytes +where + S: Service, Error = Error>, + B: MessageBody, + B::Error: fmt::Debug, +{ + let res = call_service(app, req).await; + read_body(res).await +} + +#[doc(hidden)] +#[deprecated(since = "4.0.0", note = "Renamed to `call_and_read_body`.")] +pub async fn read_response(app: &S, req: Request) -> Bytes +where + S: Service, Error = Error>, + B: MessageBody, + B::Error: fmt::Debug, +{ + let res = call_service(app, req).await; + read_body(res).await +} + +/// Helper function that returns a response body of a ServiceResponse. +/// +/// # Examples +/// ``` +/// use actix_web::{test, web, App, HttpResponse, http::header}; +/// use bytes::Bytes; +/// +/// #[actix_web::test] +/// async fn test_index() { +/// let app = test::init_service( +/// App::new().service( +/// web::resource("/index.html") +/// .route(web::post().to(|| async { +/// HttpResponse::Ok().body("welcome!") +/// }))) +/// ).await; +/// +/// let req = test::TestRequest::post() +/// .uri("/index.html") +/// .header(header::CONTENT_TYPE, "application/json") +/// .to_request(); +/// +/// let res = test::call_service(&app, req).await; +/// let result = test::read_body(res).await; +/// assert_eq!(result, Bytes::from_static(b"welcome!")); +/// } +/// ``` +/// +/// # Panics +/// Panics if body yields an error while it is being read. +pub async fn read_body(res: ServiceResponse) -> Bytes +where + B: MessageBody, + B::Error: fmt::Debug, +{ + let body = res.into_body(); + body::to_bytes(body) + .await + .expect("error reading test response body") +} + +/// Helper function that returns a deserialized response body of a ServiceResponse. +/// +/// # Examples +/// ``` +/// use actix_web::{App, test, web, HttpResponse, http::header}; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct Person { +/// id: String, +/// name: String, +/// } +/// +/// #[actix_web::test] +/// async fn test_post_person() { +/// let app = test::init_service( +/// App::new().service( +/// web::resource("/people") +/// .route(web::post().to(|person: web::Json| async { +/// HttpResponse::Ok() +/// .json(person)}) +/// )) +/// ).await; +/// +/// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); +/// +/// let res = test::TestRequest::post() +/// .uri("/people") +/// .header(header::CONTENT_TYPE, "application/json") +/// .set_payload(payload) +/// .send_request(&mut app) +/// .await; +/// +/// assert!(res.status().is_success()); +/// +/// let result: Person = test::read_body_json(res).await; +/// } +/// ``` +/// +/// # Panics +/// Panics if: +/// - body yields an error while it is being read; +/// - received body is not a valid JSON representation of `T`. +pub async fn read_body_json(res: ServiceResponse) -> T +where + B: MessageBody, + B::Error: fmt::Debug, + T: DeserializeOwned, +{ + let body = read_body(res).await; + + serde_json::from_slice(&body).unwrap_or_else(|err| { + panic!( + "could not deserialize body into a {}\nerr: {}\nbody: {:?}", + std::any::type_name::(), + err, + body, + ) + }) +} + +/// Helper function that returns a deserialized response body of a TestRequest +/// +/// # Examples +/// ``` +/// use actix_web::{App, test, web, HttpResponse, http::header}; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// pub struct Person { +/// id: String, +/// name: String +/// } +/// +/// #[actix_web::test] +/// async fn test_add_person() { +/// let app = test::init_service( +/// App::new().service( +/// web::resource("/people") +/// .route(web::post().to(|person: web::Json| async { +/// HttpResponse::Ok() +/// .json(person)}) +/// )) +/// ).await; +/// +/// let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); +/// +/// let req = test::TestRequest::post() +/// .uri("/people") +/// .header(header::CONTENT_TYPE, "application/json") +/// .set_payload(payload) +/// .to_request(); +/// +/// let result: Person = test::call_and_read_body_json(&mut app, req).await; +/// } +/// ``` +/// +/// # Panics +/// Panics if: +/// - service call returns an error body yields an error while it is being read; +/// - body yields an error while it is being read; +/// - received body is not a valid JSON representation of `T`. +pub async fn call_and_read_body_json(app: &S, req: Request) -> T +where + S: Service, Error = Error>, + B: MessageBody, + B::Error: fmt::Debug, + T: DeserializeOwned, +{ + let res = call_service(app, req).await; + read_body_json(res).await +} + +#[doc(hidden)] +#[deprecated(since = "4.0.0", note = "Renamed to `call_and_read_body_json`.")] +pub async fn read_response_json(app: &S, req: Request) -> T +where + S: Service, Error = Error>, + B: MessageBody, + B::Error: fmt::Debug, + T: DeserializeOwned, +{ + call_and_read_body_json(app, req).await +} + +#[cfg(test)] +mod tests { + + use serde::{Deserialize, Serialize}; + + use super::*; + use crate::{http::header, test::TestRequest, web, App, HttpMessage, HttpResponse}; + + #[actix_rt::test] + async fn test_request_methods() { + let app = init_service( + App::new().service( + web::resource("/index.html") + .route(web::put().to(|| HttpResponse::Ok().body("put!"))) + .route(web::patch().to(|| HttpResponse::Ok().body("patch!"))) + .route(web::delete().to(|| HttpResponse::Ok().body("delete!"))), + ), + ) + .await; + + let put_req = TestRequest::put() + .uri("/index.html") + .insert_header((header::CONTENT_TYPE, "application/json")) + .to_request(); + + let result = call_and_read_body(&app, put_req).await; + assert_eq!(result, Bytes::from_static(b"put!")); + + let patch_req = TestRequest::patch() + .uri("/index.html") + .insert_header((header::CONTENT_TYPE, "application/json")) + .to_request(); + + let result = call_and_read_body(&app, patch_req).await; + assert_eq!(result, Bytes::from_static(b"patch!")); + + let delete_req = TestRequest::delete().uri("/index.html").to_request(); + let result = call_and_read_body(&app, delete_req).await; + assert_eq!(result, Bytes::from_static(b"delete!")); + } + + #[derive(Serialize, Deserialize)] + pub struct Person { + id: String, + name: String, + } + + #[actix_rt::test] + async fn test_response_json() { + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), + ))) + .await; + + let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); + + let req = TestRequest::post() + .uri("/people") + .insert_header((header::CONTENT_TYPE, "application/json")) + .set_payload(payload) + .to_request(); + + let result: Person = call_and_read_body_json(&app, req).await; + assert_eq!(&result.id, "12345"); + } + + #[actix_rt::test] + async fn test_body_json() { + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), + ))) + .await; + + let payload = r#"{"id":"12345","name":"User name"}"#.as_bytes(); + + let res = TestRequest::post() + .uri("/people") + .insert_header((header::CONTENT_TYPE, "application/json")) + .set_payload(payload) + .send_request(&app) + .await; + + let result: Person = read_body_json(res).await; + assert_eq!(&result.name, "User name"); + } + + #[actix_rt::test] + async fn test_request_response_form() { + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Form| HttpResponse::Ok().json(person)), + ))) + .await; + + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; + + let req = TestRequest::post() + .uri("/people") + .set_form(&payload) + .to_request(); + + assert_eq!(req.content_type(), "application/x-www-form-urlencoded"); + + let result: Person = call_and_read_body_json(&app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + } + + #[actix_rt::test] + async fn test_response() { + let app = init_service( + App::new().service( + web::resource("/index.html") + .route(web::post().to(|| HttpResponse::Ok().body("welcome!"))), + ), + ) + .await; + + let req = TestRequest::post() + .uri("/index.html") + .insert_header((header::CONTENT_TYPE, "application/json")) + .to_request(); + + let result = call_and_read_body(&app, req).await; + assert_eq!(result, Bytes::from_static(b"welcome!")); + } + + #[actix_rt::test] + async fn test_request_response_json() { + let app = init_service(App::new().service(web::resource("/people").route( + web::post().to(|person: web::Json| HttpResponse::Ok().json(person)), + ))) + .await; + + let payload = Person { + id: "12345".to_string(), + name: "User name".to_string(), + }; + + let req = TestRequest::post() + .uri("/people") + .set_json(&payload) + .to_request(); + + assert_eq!(req.content_type(), "application/json"); + + let result: Person = call_and_read_body_json(&app, req).await; + assert_eq!(&result.id, "12345"); + assert_eq!(&result.name, "User name"); + } +} diff --git a/src/types/either.rs b/src/types/either.rs index bbab48dec..0eafb9e43 100644 --- a/src/types/either.rs +++ b/src/types/either.rs @@ -1,9 +1,18 @@ //! For either helper, see [`Either`]. +use std::{ + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + use bytes::Bytes; -use futures_util::{future::LocalBoxFuture, FutureExt, TryFutureExt}; +use futures_core::ready; +use pin_project_lite::pin_project; use crate::{ + body::EitherBody, dev, web::{Form, Json}, Error, FromRequest, HttpRequest, HttpResponse, Responder, @@ -11,8 +20,6 @@ use crate::{ /// Combines two extractor or responder types into a single type. /// -/// Can be converted to and from an [`either::Either`]. -/// /// # Extractor /// Provides a mechanism for trying two extractors, a primary and a fallback. Useful for /// "polymorphic payloads" where, for example, a form might be JSON or URL encoded. @@ -93,24 +100,6 @@ impl Either, Form> { } } -impl From> for Either { - fn from(val: either::Either) -> Self { - match val { - either::Either::Left(l) => Either::Left(l), - either::Either::Right(r) => Either::Right(r), - } - } -} - -impl From> for either::Either { - fn from(val: Either) -> Self { - match val { - Either::Left(l) => either::Either::Left(l), - Either::Right(r) => either::Either::Right(r), - } - } -} - #[cfg(test)] impl Either { pub(self) fn unwrap_left(self) -> L { @@ -138,10 +127,12 @@ where L: Responder, R: Responder, { - fn respond_to(self, req: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, req: &HttpRequest) -> HttpResponse { match self { - Either::Left(a) => a.respond_to(req), - Either::Right(b) => b.respond_to(req), + Either::Left(a) => a.respond_to(req).map_into_left_body(), + Either::Right(b) => b.respond_to(req).map_into_right_body(), } } } @@ -155,7 +146,7 @@ pub enum EitherExtractError { /// Error from payload buffering, such as exceeding payload max size limit. Bytes(Error), - /// Error from primary extractor. + /// Error from primary and fallback extractors. Extract(L, R), } @@ -179,41 +170,113 @@ where R: FromRequest + 'static, { type Error = EitherExtractError; - type Future = LocalBoxFuture<'static, Result>; - type Config = (); + type Future = EitherExtractFut; fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { - let req2 = req.clone(); - - Bytes::from_request(req, payload) - .map_err(EitherExtractError::Bytes) - .and_then(|bytes| bytes_to_l_or_r(req2, bytes)) - .boxed_local() + EitherExtractFut { + req: req.clone(), + state: EitherExtractState::Bytes { + bytes: Bytes::from_request(req, payload), + }, + } } } -async fn bytes_to_l_or_r( - req: HttpRequest, - bytes: Bytes, -) -> Result, EitherExtractError> +pin_project! { + pub struct EitherExtractFut + where + R: FromRequest, + L: FromRequest, + { + req: HttpRequest, + #[pin] + state: EitherExtractState, + } +} + +pin_project! { + #[project = EitherExtractProj] + pub enum EitherExtractState + where + L: FromRequest, + R: FromRequest, + { + Bytes { + #[pin] + bytes: ::Future, + }, + Left { + #[pin] + left: L::Future, + fallback: Bytes, + }, + Right { + #[pin] + right: R::Future, + left_err: Option, + }, + } +} + +impl Future for EitherExtractFut where - L: FromRequest + 'static, - R: FromRequest + 'static, + L: FromRequest, + R: FromRequest, + LF: Future> + 'static, + RF: Future> + 'static, + LE: Into, + RE: Into, { - let fallback = bytes.clone(); - let a_err; + type Output = Result, EitherExtractError>; - let mut pl = payload_from_bytes(bytes); - match L::from_request(&req, &mut pl).await { - Ok(a_data) => return Ok(Either::Left(a_data)), - // store A's error for returning if B also fails - Err(err) => a_err = err, - }; - - let mut pl = payload_from_bytes(fallback); - match R::from_request(&req, &mut pl).await { - Ok(b_data) => return Ok(Either::Right(b_data)), - Err(b_err) => Err(EitherExtractError::Extract(a_err, b_err)), + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + let ready = loop { + let next = match this.state.as_mut().project() { + EitherExtractProj::Bytes { bytes } => { + let res = ready!(bytes.poll(cx)); + match res { + Ok(bytes) => { + let fallback = bytes.clone(); + let left = + L::from_request(this.req, &mut payload_from_bytes(bytes)); + EitherExtractState::Left { left, fallback } + } + Err(err) => break Err(EitherExtractError::Bytes(err)), + } + } + EitherExtractProj::Left { left, fallback } => { + let res = ready!(left.poll(cx)); + match res { + Ok(extracted) => break Ok(Either::Left(extracted)), + Err(left_err) => { + let right = R::from_request( + this.req, + &mut payload_from_bytes(mem::take(fallback)), + ); + EitherExtractState::Right { + left_err: Some(left_err), + right, + } + } + } + } + EitherExtractProj::Right { right, left_err } => { + let res = ready!(right.poll(cx)); + match res { + Ok(data) => break Ok(Either::Right(data)), + Err(err) => { + break Err(EitherExtractError::Extract( + left_err.take().unwrap(), + err, + )); + } + } + } + }; + this.state.set(next); + }; + Poll::Ready(ready) } } diff --git a/src/types/form.rs b/src/types/form.rs index 0b5c3c1b4..9c09c6b73 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -1,6 +1,7 @@ //! For URL encoded form helper documentation, see [`Form`]. use std::{ + borrow::Cow, fmt, future::Future, ops, @@ -12,17 +13,16 @@ use std::{ use actix_http::Payload; use bytes::BytesMut; use encoding_rs::{Encoding, UTF_8}; -use futures_util::{ - future::{FutureExt, LocalBoxFuture}, - StreamExt, -}; +use futures_core::{future::LocalBoxFuture, ready}; +use futures_util::{FutureExt as _, StreamExt as _}; use serde::{de::DeserializeOwned, Serialize}; -#[cfg(feature = "compress")] +#[cfg(feature = "__compress")] use crate::dev::Decompress; use crate::{ - error::UrlencodedError, extract::FromRequest, http::header::CONTENT_LENGTH, web, Error, - HttpMessage, HttpRequest, HttpResponse, Responder, + body::EitherBody, error::UrlencodedError, extract::FromRequest, + http::header::CONTENT_LENGTH, web, Error, HttpMessage, HttpRequest, HttpResponse, + Responder, }; /// URL encoded payload extractor and responder. @@ -31,9 +31,9 @@ use crate::{ /// /// # Extractor /// To extract typed data from a request body, the inner type `T` must implement the -/// [`serde::Deserialize`] trait. +/// [`DeserializeOwned`] trait. /// -/// Use [`FormConfig`] to configure extraction process. +/// Use [`FormConfig`] to configure extraction options. /// /// ``` /// use actix_web::{post, web}; @@ -82,7 +82,11 @@ use crate::{ /// }) /// } /// ``` -#[derive(PartialEq, Eq, PartialOrd, Ord)] +/// +/// # Panics +/// URL encoded forms consist of unordered `key=value` pairs, therefore they cannot be decoded into +/// any type which depends upon data ordering (eg. tuples). Trying to do so will result in a panic. +#[derive(PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Form(pub T); impl Form { @@ -106,43 +110,66 @@ impl ops::DerefMut for Form { } } +impl Serialize for Form +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) + } +} + /// See [here](#extractor) for example of usage as an extractor. impl FromRequest for Form where T: DeserializeOwned + 'static, { - type Config = FormConfig; type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = FormExtractFut; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - let req2 = req.clone(); - let (limit, err_handler) = req - .app_data::() - .or_else(|| { - req.app_data::>() - .map(|d| d.as_ref()) - }) - .map(|c| (c.limit, c.err_handler.clone())) - .unwrap_or((16384, None)); + let FormConfig { limit, err_handler } = FormConfig::from_req(req).clone(); - UrlEncoded::new(req, payload) - .limit(limit) - .map(move |res| match res { - Err(err) => match err_handler { - Some(err_handler) => Err((err_handler)(err, &req2)), - None => Err(err.into()), - }, - Ok(item) => Ok(Form(item)), - }) - .boxed_local() + FormExtractFut { + fut: UrlEncoded::new(req, payload).limit(limit), + req: req.clone(), + err_handler, + } } } -impl fmt::Debug for Form { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) +type FormErrHandler = Option Error>>; + +pub struct FormExtractFut { + fut: UrlEncoded, + err_handler: FormErrHandler, + req: HttpRequest, +} + +impl Future for FormExtractFut +where + T: DeserializeOwned + 'static, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let res = ready!(Pin::new(&mut this.fut).poll(cx)); + + let res = match res { + Err(err) => match &this.err_handler { + Some(err_handler) => Err((err_handler)(err, &this.req)), + None => Err(err.into()), + }, + Ok(item) => Ok(Form(item)), + }; + + Poll::Ready(res) } } @@ -154,12 +181,21 @@ impl fmt::Display for Form { /// See [here](#responder) for example of usage as a handler return type. impl Responder for Form { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { match serde_urlencoded::to_string(&self.0) { - Ok(body) => HttpResponse::Ok() + Ok(body) => match HttpResponse::Ok() .content_type(mime::APPLICATION_WWW_FORM_URLENCODED) - .body(body), - Err(err) => HttpResponse::from_error(err.into()), + .message_body(body) + { + Ok(res) => res.map_into_left_body(), + Err(err) => HttpResponse::from_error(err).map_into_right_body(), + }, + + Err(err) => { + HttpResponse::from_error(UrlencodedError::Serialize(err)).map_into_right_body() + } } } } @@ -189,7 +225,7 @@ impl Responder for Form { #[derive(Clone)] pub struct FormConfig { limit: usize, - err_handler: Option Error>>, + err_handler: FormErrHandler, } impl FormConfig { @@ -207,14 +243,26 @@ impl FormConfig { self.err_handler = Some(Rc::new(f)); self } + + /// Extract payload config from app data. + /// + /// Checks both `T` and `Data`, in that order, and falls back to the default payload config. + fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } } +/// Allow shared refs used as default. +const DEFAULT_CONFIG: FormConfig = FormConfig { + limit: 16_384, // 2^14 bytes (~16kB) + err_handler: None, +}; + impl Default for FormConfig { fn default() -> Self { - FormConfig { - limit: 16_384, // 2^14 bytes (~16kB) - err_handler: None, - } + DEFAULT_CONFIG } } @@ -226,9 +274,9 @@ impl Default for FormConfig { /// - content type is not `application/x-www-form-urlencoded` /// - content length is greater than [limit](UrlEncoded::limit()) pub struct UrlEncoded { - #[cfg(feature = "compress")] + #[cfg(feature = "__compress")] stream: Option>, - #[cfg(not(feature = "compress"))] + #[cfg(not(feature = "__compress"))] stream: Option, limit: usize, @@ -264,10 +312,15 @@ impl UrlEncoded { } }; - #[cfg(feature = "compress")] - let payload = Decompress::from_headers(payload.take(), req.headers()); - #[cfg(not(feature = "compress"))] - let payload = payload.take(); + let payload = { + cfg_if::cfg_if! { + if #[cfg(feature = "__compress")] { + Decompress::from_headers(payload.take(), req.headers()) + } else { + payload.take() + } + } + }; UrlEncoded { encoding, @@ -342,14 +395,14 @@ where } if encoding == UTF_8 { - serde_urlencoded::from_bytes::(&body).map_err(|_| UrlencodedError::Parse) + serde_urlencoded::from_bytes::(&body).map_err(UrlencodedError::Parse) } else { let body = encoding .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or(UrlencodedError::Parse)?; + .map(Cow::into_owned) + .ok_or(UrlencodedError::Encoding)?; - serde_urlencoded::from_str::(&body).map_err(|_| UrlencodedError::Parse) + serde_urlencoded::from_str::(&body).map_err(UrlencodedError::Parse) } } .boxed_local(), @@ -365,11 +418,14 @@ mod tests { use serde::{Deserialize, Serialize}; use super::*; - use crate::http::{ - header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, - StatusCode, - }; use crate::test::TestRequest; + use crate::{ + http::{ + header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, + StatusCode, + }, + test::assert_body_eq, + }; #[derive(Deserialize, Serialize, Debug, PartialEq)] struct Info { @@ -400,12 +456,8 @@ mod tests { UrlencodedError::Overflow { .. } => { matches!(other, UrlencodedError::Overflow { .. }) } - UrlencodedError::UnknownLength => { - matches!(other, UrlencodedError::UnknownLength) - } - UrlencodedError::ContentType => { - matches!(other, UrlencodedError::ContentType) - } + UrlencodedError::UnknownLength => matches!(other, UrlencodedError::UnknownLength), + UrlencodedError::ContentType => matches!(other, UrlencodedError::ContentType), _ => false, } } @@ -481,15 +533,13 @@ mod tests { hello: "world".to_string(), counter: 123, }); - let resp = form.respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); + let res = form.respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(CONTENT_TYPE).unwrap(), + res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/x-www-form-urlencoded") ); - - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"hello=world&counter=123"); + assert_body_eq!(res, b"hello=world&counter=123"); } #[actix_rt::test] diff --git a/src/types/header.rs b/src/types/header.rs new file mode 100644 index 000000000..6ea77faf6 --- /dev/null +++ b/src/types/header.rs @@ -0,0 +1,102 @@ +//! For header extractor helper documentation, see [`Header`](crate::types::Header). + +use std::{fmt, ops}; + +use actix_utils::future::{err, ok, Ready}; + +use crate::{ + dev::Payload, error::ParseError, extract::FromRequest, http::header::Header as ParseHeader, + HttpRequest, +}; + +/// Extract typed headers from the request. +/// +/// To extract a header, the inner type `T` must implement the +/// [`Header`](crate::http::header::Header) trait. +/// +/// # Examples +/// ``` +/// use actix_web::{get, web, http::header}; +/// +/// #[get("/")] +/// async fn index(date: web::Header) -> String { +/// format!("Request was sent at {}", date.to_string()) +/// } +/// ``` +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub struct Header(pub T); + +impl Header { + /// Unwrap into the inner `T` value. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl ops::Deref for Header { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +impl ops::DerefMut for Header { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + +impl fmt::Display for Header +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl FromRequest for Header +where + T: ParseHeader, +{ + type Error = ParseError; + type Future = Ready>; + + #[inline] + fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { + match ParseHeader::parse(req) { + Ok(header) => ok(Header(header)), + Err(e) => err(e), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::{header, Method}; + use crate::test::TestRequest; + + #[actix_rt::test] + async fn test_header_extract() { + let (req, mut pl) = TestRequest::default() + .insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) + .insert_header((header::ALLOW, header::Allow(vec![Method::GET]))) + .to_http_parts(); + + let s = Header::::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.into_inner().0, mime::APPLICATION_JSON); + + let s = Header::::from_request(&req, &mut pl) + .await + .unwrap(); + assert_eq!(s.into_inner().0, vec![Method::GET]); + + assert!(Header::::from_request(&req, &mut pl) + .await + .is_err()); + } +} diff --git a/src/types/json.rs b/src/types/json.rs index 28960402a..8fdbfafa4 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -11,14 +11,15 @@ use std::{ }; use bytes::BytesMut; -use futures_util::{ready, stream::Stream}; +use futures_core::{ready, Stream as _}; use serde::{de::DeserializeOwned, Serialize}; use actix_http::Payload; -#[cfg(feature = "compress")] +#[cfg(feature = "__compress")] use crate::dev::Decompress; use crate::{ + body::EitherBody, error::{Error, JsonPayloadError}, extract::FromRequest, http::header::CONTENT_LENGTH, @@ -34,7 +35,7 @@ use crate::{ /// To extract typed data from a request body, the inner type `T` must implement the /// [`serde::Deserialize`] trait. /// -/// Use [`JsonConfig`] to configure extraction process. +/// Use [`JsonConfig`] to configure extraction options. /// /// ``` /// use actix_web::{post, web, App}; @@ -73,6 +74,7 @@ use crate::{ /// }) /// } /// ``` +#[derive(Debug)] pub struct Json(pub T); impl Json { @@ -96,21 +98,18 @@ impl ops::DerefMut for Json { } } -impl fmt::Debug for Json -where - T: fmt::Debug, -{ +impl fmt::Display for Json { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Json: {:?}", self.0) + fmt::Display::fmt(&self.0, f) } } -impl fmt::Display for Json -where - T: fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, f) +impl Serialize for Json { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) } } @@ -118,36 +117,42 @@ where /// /// If serialization failed impl Responder for Json { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { + type Body = EitherBody; + + fn respond_to(self, _: &HttpRequest) -> HttpResponse { match serde_json::to_string(&self.0) { - Ok(body) => HttpResponse::Ok() + Ok(body) => match HttpResponse::Ok() .content_type(mime::APPLICATION_JSON) - .body(body), - Err(err) => HttpResponse::from_error(err.into()), + .message_body(body) + { + Ok(res) => res.map_into_left_body(), + Err(err) => HttpResponse::from_error(err).map_into_right_body(), + }, + + Err(err) => { + HttpResponse::from_error(JsonPayloadError::Serialize(err)).map_into_right_body() + } } } } /// See [here](#extractor) for example of usage as an extractor. -impl FromRequest for Json -where - T: DeserializeOwned + 'static, -{ +impl FromRequest for Json { type Error = Error; type Future = JsonExtractFut; - type Config = JsonConfig; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { let config = JsonConfig::from_req(req); let limit = config.limit; - let ctype = config.content_type.as_deref(); + let ctype_required = config.content_type_required; + let ctype_fn = config.content_type.as_deref(); let err_handler = config.err_handler.clone(); JsonExtractFut { req: Some(req.clone()), - fut: JsonBody::new(req, payload, ctype).limit(limit), + fut: JsonBody::new(req, payload, ctype_fn, ctype_required).limit(limit), err_handler, } } @@ -162,10 +167,7 @@ pub struct JsonExtractFut { err_handler: JsonErrorHandler, } -impl Future for JsonExtractFut -where - T: DeserializeOwned + 'static, -{ +impl Future for JsonExtractFut { type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -221,7 +223,7 @@ where /// .content_type(|mime| mime == mime::TEXT_PLAIN) /// // use custom error handler /// .error_handler(|err, req| { -/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() +/// error::InternalError::from_response(err, HttpResponse::Conflict().into()).into() /// }); /// /// App::new() @@ -233,10 +235,11 @@ pub struct JsonConfig { limit: usize, err_handler: JsonErrorHandler, content_type: Option bool + Send + Sync>>, + content_type_required: bool, } impl JsonConfig { - /// Set maximum accepted payload size. By default this limit is 32kB. + /// Set maximum accepted payload size. By default this limit is 2MB. pub fn limit(mut self, limit: usize) -> Self { self.limit = limit; self @@ -260,6 +263,12 @@ impl JsonConfig { self } + /// Sets whether or not the request must have a `Content-Type` header to be parsed. + pub fn content_type_required(mut self, content_type_required: bool) -> Self { + self.content_type_required = content_type_required; + self + } + /// Extract payload config from app data. Check both `T` and `Data`, in that order, and fall /// back to the default payload config. fn from_req(req: &HttpRequest) -> &Self { @@ -269,11 +278,14 @@ impl JsonConfig { } } +const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb + /// Allow shared refs used as default. const DEFAULT_CONFIG: JsonConfig = JsonConfig { - limit: 32_768, // 2^15 bytes, (~32kB) + limit: DEFAULT_LIMIT, err_handler: None, content_type: None, + content_type_required: true, }; impl Default for JsonConfig { @@ -284,19 +296,22 @@ impl Default for JsonConfig { /// Future that resolves to some `T` when parsed from a JSON payload. /// -/// Form can be deserialized from any type `T` that implements [`serde::Deserialize`]. +/// Can deserialize any type `T` that implements [`Deserialize`][serde::Deserialize]. /// /// Returns error if: -/// - content type is not `application/json` -/// - content length is greater than [limit](JsonBody::limit()) +/// - `Content-Type` is not `application/json` when `ctype_required` (passed to [`new`][Self::new]) +/// is `true`. +/// - `Content-Length` is greater than [limit](JsonBody::limit()). +/// - The payload, when consumed, is not valid JSON. pub enum JsonBody { Error(Option), Body { limit: usize, + /// Length as reported by `Content-Length` header, if present. length: Option, - #[cfg(feature = "compress")] + #[cfg(feature = "__compress")] payload: Decompress, - #[cfg(not(feature = "compress"))] + #[cfg(not(feature = "__compress"))] payload: Payload, buf: BytesMut, _res: PhantomData, @@ -305,27 +320,27 @@ pub enum JsonBody { impl Unpin for JsonBody {} -impl JsonBody -where - T: DeserializeOwned + 'static, -{ +impl JsonBody { /// Create a new future to decode a JSON request payload. #[allow(clippy::borrow_interior_mutable_const)] pub fn new( req: &HttpRequest, payload: &mut Payload, - ctype: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>, + ctype_fn: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>, + ctype_required: bool, ) -> Self { // check content-type - let json = if let Ok(Some(mime)) = req.mime_type() { + let can_parse_json = if let Ok(Some(mime)) = req.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - || ctype.map_or(false, |predicate| predicate(mime)) + || ctype_fn.map_or(false, |predicate| predicate(mime)) } else { - false + // if `ctype_required` is false, assume payload is + // json even when content-type header is missing + !ctype_required }; - if !json { + if !can_parse_json { return JsonBody::Error(Some(JsonPayloadError::ContentType)); } @@ -335,17 +350,22 @@ where .and_then(|l| l.to_str().ok()) .and_then(|s| s.parse::().ok()); - // Notice the content_length is not checked against limit of json config here. + // Notice the content-length is not checked against limit of json config here. // As the internal usage always call JsonBody::limit after JsonBody::new. // And limit check to return an error variant of JsonBody happens there. - #[cfg(feature = "compress")] - let payload = Decompress::from_headers(payload.take(), req.headers()); - #[cfg(not(feature = "compress"))] - let payload = payload.take(); + let payload = { + cfg_if::cfg_if! { + if #[cfg(feature = "__compress")] { + Decompress::from_headers(payload.take(), req.headers()) + } else { + payload.take() + } + } + }; JsonBody::Body { - limit: 262_144, + limit: DEFAULT_LIMIT, length, payload, buf: BytesMut::with_capacity(8192), @@ -353,7 +373,7 @@ where } } - /// Set maximum accepted payload size. The default limit is 256kB. + /// Set maximum accepted payload size. The default limit is 2MB. pub fn limit(self, limit: usize) -> Self { match self { JsonBody::Body { @@ -364,7 +384,10 @@ where } => { if let Some(len) = length { if len > limit { - return JsonBody::Error(Some(JsonPayloadError::Overflow)); + return JsonBody::Error(Some(JsonPayloadError::OverflowKnownLength { + length: len, + limit, + })); } } @@ -381,10 +404,7 @@ where } } -impl Future for JsonBody -where - T: DeserializeOwned + 'static, -{ +impl Future for JsonBody { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -401,14 +421,18 @@ where match res { Some(chunk) => { let chunk = chunk?; - if (buf.len() + chunk.len()) > *limit { - return Poll::Ready(Err(JsonPayloadError::Overflow)); + let buf_len = buf.len() + chunk.len(); + if buf_len > *limit { + return Poll::Ready(Err(JsonPayloadError::Overflow { + limit: *limit, + })); } else { buf.extend_from_slice(&chunk); } } None => { - let json = serde_json::from_slice::(&buf)?; + let json = serde_json::from_slice::(buf) + .map_err(JsonPayloadError::Deserialize)?; return Poll::Ready(Ok(json)); } } @@ -425,12 +449,13 @@ mod tests { use super::*; use crate::{ + body, error::InternalError, http::{ header::{self, CONTENT_LENGTH, CONTENT_TYPE}, StatusCode, }, - test::{load_stream, TestRequest}, + test::{assert_body_eq, TestRequest}, }; #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -440,10 +465,13 @@ mod tests { fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool { match err { - JsonPayloadError::Overflow => matches!(other, JsonPayloadError::Overflow), - JsonPayloadError::ContentType => { - matches!(other, JsonPayloadError::ContentType) + JsonPayloadError::Overflow { .. } => { + matches!(other, JsonPayloadError::Overflow { .. }) } + JsonPayloadError::OverflowKnownLength { .. } => { + matches!(other, JsonPayloadError::OverflowKnownLength { .. }) + } + JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType), _ => false, } } @@ -455,15 +483,13 @@ mod tests { let j = Json(MyObject { name: "test".to_string(), }); - let resp = j.respond_to(&req); - assert_eq!(resp.status(), StatusCode::OK); + let res = j.respond_to(&req); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - resp.headers().get(header::CONTENT_TYPE).unwrap(), + res.headers().get(header::CONTENT_TYPE).unwrap(), header::HeaderValue::from_static("application/json") ); - - use crate::responder::tests::BodyTest; - assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}"); + assert_body_eq!(res, b"{\"name\":\"test\"}"); } #[actix_rt::test] @@ -489,10 +515,10 @@ mod tests { .to_http_parts(); let s = Json::::from_request(&req, &mut pl).await; - let mut resp = HttpResponse::from_error(s.err().unwrap()); + let resp = HttpResponse::from_error(s.unwrap_err()); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let body = load_stream(resp.take_body()).await.unwrap(); + let body = body::to_bytes(resp.into_body()).await.unwrap(); let msg: MyObject = serde_json::from_slice(&body).unwrap(); assert_eq!(msg.name, "invalid request"); } @@ -535,7 +561,7 @@ mod tests { let s = Json::::from_request(&req, &mut pl).await; assert!(format!("{}", s.err().unwrap()) - .contains("Json payload size is bigger than allowed")); + .contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes).")); let (req, mut pl) = TestRequest::default() .insert_header(( @@ -560,7 +586,7 @@ mod tests { #[actix_rt::test] async fn test_json_body() { let (req, mut pl) = TestRequest::default().to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -569,7 +595,7 @@ mod tests { header::HeaderValue::from_static("application/text"), )) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let (req, mut pl) = TestRequest::default() @@ -583,10 +609,33 @@ mod tests { )) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None) + let json = JsonBody::::new(&req, &mut pl, None, true) .limit(100) .await; - assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow)); + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::OverflowKnownLength { + length: 10000, + limit: 100 + } + )); + + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + )) + .set_payload(Bytes::from_static(&[0u8; 1000])) + .to_http_parts(); + + let json = JsonBody::::new(&req, &mut pl, None, true) + .limit(100) + .await; + + assert!(json_eq( + json.err().unwrap(), + JsonPayloadError::Overflow { limit: 100 } + )); let (req, mut pl) = TestRequest::default() .insert_header(( @@ -600,7 +649,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .to_http_parts(); - let json = JsonBody::::new(&req, &mut pl, None).await; + let json = JsonBody::::new(&req, &mut pl, None, true).await; assert_eq!( json.ok().unwrap(), MyObject { @@ -670,6 +719,21 @@ mod tests { assert!(s.is_err()) } + #[actix_rt::test] + async fn test_json_with_no_content_type() { + let (req, mut pl) = TestRequest::default() + .insert_header(( + header::CONTENT_LENGTH, + header::HeaderValue::from_static("16"), + )) + .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .app_data(JsonConfig::default().content_type_required(false)) + .to_http_parts(); + + let s = Json::::from_request(&req, &mut pl).await; + assert!(s.is_ok()) + } + #[actix_rt::test] async fn test_with_config_in_data_wrapper() { let (req, mut pl) = TestRequest::default() @@ -683,6 +747,7 @@ mod tests { assert!(s.is_err()); let err_str = s.err().unwrap().to_string(); - assert!(err_str.contains("Json payload size is bigger than allowed")); + assert!(err_str + .contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes).")); } } diff --git a/src/types/mod.rs b/src/types/mod.rs index a062c351e..bab7c3bc0 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,17 +1,18 @@ //! Common extractors and responders. -// TODO: review visibility mod either; -pub(crate) mod form; -pub(crate) mod json; +mod form; +mod header; +mod json; mod path; -pub(crate) mod payload; +mod payload; mod query; -pub(crate) mod readlines; +mod readlines; -pub use self::either::{Either, EitherExtractError}; -pub use self::form::{Form, FormConfig}; -pub use self::json::{Json, JsonConfig}; +pub use self::either::Either; +pub use self::form::{Form, FormConfig, UrlEncoded}; +pub use self::header::Header; +pub use self::json::{Json, JsonBody, JsonConfig}; pub use self::path::{Path, PathConfig}; pub use self::payload::{Payload, PayloadConfig}; pub use self::query::{Query, QueryConfig}; diff --git a/src/types/path.rs b/src/types/path.rs index 4ab124d53..c3efc22c0 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -2,16 +2,20 @@ use std::{fmt, ops, sync::Arc}; -use actix_http::error::{Error, ErrorNotFound}; use actix_router::PathDeserializer; -use futures_util::future::{ready, Ready}; +use actix_utils::future::{ready, Ready}; use serde::de; -use crate::{dev::Payload, error::PathError, FromRequest, HttpRequest}; +use crate::{ + dev::Payload, + error::{Error, ErrorNotFound, PathError}, + web::Data, + FromRequest, HttpRequest, +}; /// Extract typed data from request path segments. /// -/// Use [`PathConfig`] to configure extraction process. +/// Use [`PathConfig`] to configure extraction option. /// /// # Examples /// ``` @@ -20,7 +24,7 @@ use crate::{dev::Payload, error::PathError, FromRequest, HttpRequest}; /// // extract path info from "/{name}/{count}/index.html" into tuple /// // {name} - deserialize a String /// // {count} - deserialize a u32 -/// #[get("/")] +/// #[get("/{name}/{count}/index.html")] /// async fn index(path: web::Path<(String, u32)>) -> String { /// let (name, count) = path.into_inner(); /// format!("Welcome {}! {}", name, count) @@ -40,12 +44,12 @@ use crate::{dev::Payload, error::PathError, FromRequest, HttpRequest}; /// } /// /// // extract `Info` from a path using serde -/// #[get("/")] +/// #[get("/{name}")] /// async fn index(info: web::Path) -> String { /// format!("Welcome {}!", info.name) /// } /// ``` -#[derive(PartialEq, Eq, PartialOrd, Ord)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Path(T); impl Path { @@ -81,48 +85,42 @@ impl From for Path { } } -impl fmt::Debug for Path { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - impl fmt::Display for Path { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } -/// See [here](#usage) for example of usage as an extractor. +/// See [here](#Examples) for example of usage as an extractor. impl FromRequest for Path where T: de::DeserializeOwned, { type Error = Error; - type Future = Ready>; - type Config = PathConfig; + type Future = Ready>; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let error_handler = req - .app_data::() - .map(|c| c.ehandler.clone()) - .unwrap_or(None); + .app_data::() + .or_else(|| req.app_data::>().map(Data::get_ref)) + .and_then(|c| c.err_handler.clone()); ready( de::Deserialize::deserialize(PathDeserializer::new(req.match_info())) .map(Path) - .map_err(move |e| { + .map_err(move |err| { log::debug!( "Failed during Path extractor deserialization. \ Request path: {:?}", req.path() ); + if let Some(error_handler) = error_handler { - let e = PathError::Deserialize(e); + let e = PathError::Deserialize(err); (error_handler)(e, req) } else { - ErrorNotFound(e) + ErrorNotFound(err) } }), ) @@ -140,6 +138,7 @@ where /// enum Folder { /// #[serde(rename = "inbox")] /// Inbox, +/// /// #[serde(rename = "outbox")] /// Outbox, /// } @@ -149,42 +148,34 @@ where /// format!("Selected folder: {:?}!", folder) /// } /// -/// fn main() { -/// let app = App::new().service( -/// web::resource("/messages/{folder}") -/// .app_data(PathConfig::default().error_handler(|err, req| { -/// error::InternalError::from_response( -/// err, -/// HttpResponse::Conflict().finish(), -/// ) -/// .into() -/// })) -/// .route(web::post().to(index)), -/// ); -/// } +/// let app = App::new().service( +/// web::resource("/messages/{folder}") +/// .app_data(PathConfig::default().error_handler(|err, req| { +/// error::InternalError::from_response( +/// err, +/// HttpResponse::Conflict().into(), +/// ) +/// .into() +/// })) +/// .route(web::post().to(index)), +/// ); /// ``` -#[derive(Clone)] +#[derive(Clone, Default)] pub struct PathConfig { - ehandler: Option Error + Send + Sync>>, + err_handler: Option Error + Send + Sync>>, } impl PathConfig { - /// Set custom error handler + /// Set custom error handler. pub fn error_handler(mut self, f: F) -> Self where F: Fn(PathError, &HttpRequest) -> Error + Send + Sync + 'static, { - self.ehandler = Some(Arc::new(f)); + self.err_handler = Some(Arc::new(f)); self } } -impl Default for PathConfig { - fn default() -> Self { - PathConfig { ehandler: None } - } -} - #[cfg(test)] mod tests { use actix_router::ResourceDef; @@ -213,7 +204,7 @@ mod tests { let resource = ResourceDef::new("/{value}/"); let mut req = TestRequest::with_uri("/32/").to_srv_request(); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); assert_eq!(*Path::::from_request(&req, &mut pl).await.unwrap(), 32); @@ -225,7 +216,7 @@ mod tests { let resource = ResourceDef::new("/{key}/{value}/"); let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let (Path(res),) = <(Path<(String, String)>,)>::from_request(&req, &mut pl) @@ -251,7 +242,7 @@ mod tests { let mut req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let mut s = Path::::from_request(&req, &mut pl).await.unwrap(); @@ -261,7 +252,7 @@ mod tests { assert_eq!(s.value, "user2"); assert_eq!( format!("{}, {:?}", s, s), - "MyStruct(name, user2), MyStruct { key: \"name\", value: \"user2\" }" + "MyStruct(name, user2), Path(MyStruct { key: \"name\", value: \"user2\" })" ); let s = s.into_inner(); assert_eq!(s.value, "user2"); @@ -274,7 +265,7 @@ mod tests { let mut req = TestRequest::with_uri("/name/32/").to_srv_request(); let resource = ResourceDef::new("/{key}/{value}/"); - resource.match_path(req.match_info_mut()); + resource.capture_match_info(req.match_info_mut()); let (req, mut pl) = req.into_parts(); let s = Path::::from_request(&req, &mut pl).await.unwrap(); @@ -294,6 +285,18 @@ mod tests { assert_eq!(res[1], "32".to_owned()); } + #[actix_rt::test] + async fn paths_decoded() { + let resource = ResourceDef::new("/{key}/{value}"); + let mut req = TestRequest::with_uri("/na%2Bme/us%2Fer%251").to_srv_request(); + resource.capture_match_info(req.match_info_mut()); + + let (req, mut pl) = req.into_parts(); + let path_items = Path::::from_request(&req, &mut pl).await.unwrap(); + assert_eq!(path_items.key, "na+me"); + assert_eq!(path_items.value, "us/er%1"); + } + #[actix_rt::test] async fn test_custom_err_handler() { let (req, mut pl) = TestRequest::with_uri("/name/user1/") @@ -306,7 +309,7 @@ mod tests { let s = Path::<(usize,)>::from_request(&req, &mut pl) .await .unwrap_err(); - let res: HttpResponse = s.into(); + let res = HttpResponse::from_error(s); assert_eq!(res.status(), http::StatusCode::CONFLICT); } diff --git a/src/types/payload.rs b/src/types/payload.rs index 781347b84..d2ab29639 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -1,23 +1,24 @@ //! Basic binary and string payload extractors. use std::{ + borrow::Cow, future::Future, pin::Pin, str, task::{Context, Poll}, }; -use actix_http::error::{ErrorBadRequest, PayloadError}; +use actix_http::error::PayloadError; +use actix_utils::future::{ready, Either, Ready}; use bytes::{Bytes, BytesMut}; use encoding_rs::{Encoding, UTF_8}; -use futures_core::stream::Stream; -use futures_util::{ - future::{ready, Either, ErrInto, Ready, TryFutureExt as _}, - ready, -}; +use futures_core::{ready, stream::Stream}; use mime::Mime; -use crate::{dev, http::header, web, Error, FromRequest, HttpMessage, HttpRequest}; +use crate::{ + dev, error::ErrorBadRequest, http::header, web, Error, FromRequest, HttpMessage, + HttpRequest, +}; /// Extract a request's raw payload stream. /// @@ -26,7 +27,7 @@ use crate::{dev, http::header, web, Error, FromRequest, HttpMessage, HttpRequest /// # Examples /// ``` /// use std::future::Future; -/// use futures_util::stream::{Stream, StreamExt}; +/// use futures_util::stream::StreamExt as _; /// use actix_web::{post, web}; /// /// // `body: web::Payload` parameter extracts raw payload stream from request @@ -42,11 +43,12 @@ use crate::{dev, http::header, web, Error, FromRequest, HttpMessage, HttpRequest /// Ok(format!("Request Body Bytes:\n{:?}", bytes)) /// } /// ``` -pub struct Payload(pub crate::dev::Payload); +pub struct Payload(dev::Payload); impl Payload { /// Unwrap to inner Payload type. - pub fn into_inner(self) -> crate::dev::Payload { + #[inline] + pub fn into_inner(self) -> dev::Payload { self.0 } } @@ -60,9 +62,8 @@ impl Stream for Payload { } } -/// See [here](#usage) for example of usage as an extractor. +/// See [here](#Examples) for example of usage as an extractor. impl FromRequest for Payload { - type Config = PayloadConfig; type Error = Error; type Future = Ready>; @@ -89,9 +90,8 @@ impl FromRequest for Payload { /// } /// ``` impl FromRequest for Bytes { - type Config = PayloadConfig; type Error = Error; - type Future = Either, Ready>>; + type Future = Either>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -99,12 +99,25 @@ impl FromRequest for Bytes { let cfg = PayloadConfig::from_req(req); if let Err(err) = cfg.check_mimetype(req) { - return Either::Right(ready(Err(err))); + return Either::right(ready(Err(err))); } - let limit = cfg.limit; - let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left(fut.err_into()) + Either::left(BytesExtractFut { + body_fut: HttpMessageBody::new(req, payload).limit(cfg.limit), + }) + } +} + +/// Future for `Bytes` extractor. +pub struct BytesExtractFut { + body_fut: HttpMessageBody, +} + +impl<'a> Future for BytesExtractFut { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.body_fut).poll(cx).map_err(Into::into) } } @@ -112,8 +125,7 @@ impl FromRequest for Bytes { /// /// Text extractor automatically decode body according to the request's charset. /// -/// [**PayloadConfig**](PayloadConfig) allows to configure -/// extraction process. +/// Use [`PayloadConfig`] to configure extraction process. /// /// # Examples /// ``` @@ -125,7 +137,6 @@ impl FromRequest for Bytes { /// format!("Body {}!", text) /// } impl FromRequest for String { - type Config = PayloadConfig; type Error = Error; type Future = Either>>; @@ -135,21 +146,22 @@ impl FromRequest for String { // check content-type if let Err(err) = cfg.check_mimetype(req) { - return Either::Right(ready(Err(err))); + return Either::right(ready(Err(err))); } // check charset let encoding = match req.encoding() { Ok(enc) => enc, - Err(err) => return Either::Right(ready(Err(err.into()))), + Err(err) => return Either::right(ready(Err(err.into()))), }; let limit = cfg.limit; let body_fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left(StringExtractFut { body_fut, encoding }) + Either::left(StringExtractFut { body_fut, encoding }) } } +/// Future for `String` extractor. pub struct StringExtractFut { body_fut: HttpMessageBody, encoding: &'static Encoding, @@ -176,21 +188,22 @@ fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result, - #[cfg(feature = "compress")] + #[cfg(feature = "__compress")] stream: dev::Decompress, - #[cfg(not(feature = "compress"))] + #[cfg(not(feature = "__compress"))] stream: dev::Payload, buf: BytesMut, err: Option, @@ -298,10 +312,15 @@ impl HttpMessageBody { } } - #[cfg(feature = "compress")] - let stream = dev::Decompress::from_headers(payload.take(), req.headers()); - #[cfg(not(feature = "compress"))] - let stream = payload.take(); + let stream = { + cfg_if::cfg_if! { + if #[cfg(feature = "__compress")] { + dev::Decompress::from_headers(payload.take(), req.headers()) + } else { + payload.take() + } + } + }; HttpMessageBody { stream, @@ -379,6 +398,8 @@ mod tests { assert!(cfg.check_mimetype(&req).is_ok()); } + // allow deprecated App::data + #[allow(deprecated)] #[actix_rt::test] async fn test_config_recall_locations() { async fn bytes_handler(_: Bytes) -> impl Responder { diff --git a/src/types/query.rs b/src/types/query.rs index 691a4792b..97d17123d 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -2,15 +2,15 @@ use std::{fmt, ops, sync::Arc}; -use futures_util::future::{err, ok, Ready}; -use serde::de; +use actix_utils::future::{err, ok, Ready}; +use serde::de::DeserializeOwned; use crate::{dev::Payload, error::QueryPayloadError, Error, FromRequest, HttpRequest}; /// Extract typed information from the request's query. /// /// To extract typed data from the URL query string, the inner type `T` must implement the -/// [`serde::Deserialize`] trait. +/// [`DeserializeOwned`] trait. /// /// Use [`QueryConfig`] to configure extraction process. /// @@ -29,7 +29,7 @@ use crate::{dev::Payload, error::QueryPayloadError, Error, FromRequest, HttpRequ /// Code /// } /// -/// #[derive(Deserialize)] +/// #[derive(Debug, Deserialize)] /// pub struct AuthRequest { /// id: u64, /// response_type: ResponseType, @@ -42,17 +42,33 @@ use crate::{dev::Payload, error::QueryPayloadError, Error, FromRequest, HttpRequ /// async fn index(info: web::Query) -> String { /// format!("Authorization request for id={} and type={:?}!", info.id, info.response_type) /// } +/// +/// // To access the entire underlying query struct, use `.into_inner()`. +/// #[get("/debug1")] +/// async fn debug1(info: web::Query) -> String { +/// dbg!("Authorization object = {:?}", info.into_inner()); +/// "OK".to_string() +/// } +/// +/// // Or use destructuring, which is equivalent to `.into_inner()`. +/// #[get("/debug2")] +/// async fn debug2(web::Query(info): web::Query) -> String { +/// dbg!("Authorization object = {:?}", info); +/// "OK".to_string() +/// } /// ``` -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct Query(T); +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Query(pub T); impl Query { /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } +} - /// Deserialize `T` from a URL encoded query parameter string. +impl Query { + /// Deserialize a `T` from the URL encoded query parameter string. /// /// ``` /// # use std::collections::HashMap; @@ -62,10 +78,7 @@ impl Query { /// assert_eq!(numbers.get("two"), Some(&2)); /// assert!(numbers.get("three").is_none()); /// ``` - pub fn from_query(query_str: &str) -> Result - where - T: de::DeserializeOwned, - { + pub fn from_query(query_str: &str) -> Result { serde_urlencoded::from_str::(query_str) .map(Self) .map_err(QueryPayloadError::Deserialize) @@ -86,33 +99,22 @@ impl ops::DerefMut for Query { } } -impl fmt::Debug for Query { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - impl fmt::Display for Query { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } -/// See [here](#usage) for example of usage as an extractor. -impl FromRequest for Query -where - T: de::DeserializeOwned, -{ +/// See [here](#Examples) for example of usage as an extractor. +impl FromRequest for Query { type Error = Error; type Future = Ready>; - type Config = QueryConfig; #[inline] fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let error_handler = req - .app_data::() - .map(|c| c.err_handler.clone()) - .unwrap_or(None); + .app_data::() + .and_then(|c| c.err_handler.clone()); serde_urlencoded::from_str::(req.query_string()) .map(|val| ok(Query(val))) @@ -165,7 +167,7 @@ where /// .app_data(query_cfg) /// .service(index); /// ``` -#[derive(Clone)] +#[derive(Clone, Default)] pub struct QueryConfig { err_handler: Option Error + Send + Sync>>, } @@ -181,22 +183,14 @@ impl QueryConfig { } } -impl Default for QueryConfig { - fn default() -> Self { - QueryConfig { err_handler: None } - } -} - #[cfg(test)] mod tests { - use actix_http::http::StatusCode; + use actix_http::StatusCode; use derive_more::Display; use serde::Deserialize; use super::*; - use crate::error::InternalError; - use crate::test::TestRequest; - use crate::HttpResponse; + use crate::{error::InternalError, test::TestRequest, HttpResponse}; #[derive(Deserialize, Debug, Display)] struct Id { @@ -206,13 +200,16 @@ mod tests { #[actix_rt::test] async fn test_service_request_extract() { let req = TestRequest::with_uri("/name/user1/").to_srv_request(); - assert!(Query::::from_query(&req.query_string()).is_err()); + assert!(Query::::from_query(req.query_string()).is_err()); let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request(); - let mut s = Query::::from_query(&req.query_string()).unwrap(); + let mut s = Query::::from_query(req.query_string()).unwrap(); assert_eq!(s.id, "test"); - assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); + assert_eq!( + format!("{}, {:?}", s, s), + "test, Query(Id { id: \"test\" })" + ); s.id = "test1".to_string(); let s = s.into_inner(); @@ -230,7 +227,10 @@ mod tests { let mut s = Query::::from_request(&req, &mut pl).await.unwrap(); assert_eq!(s.id, "test"); - assert_eq!(format!("{}, {:?}", s, s), "test, Id { id: \"test\" }"); + assert_eq!( + format!("{}, {:?}", s, s), + "test, Query(Id { id: \"test\" })" + ); s.id = "test1".to_string(); let s = s.into_inner(); diff --git a/src/types/readlines.rs b/src/types/readlines.rs index b8bdcc504..6c456e21c 100644 --- a/src/types/readlines.rs +++ b/src/types/readlines.rs @@ -177,7 +177,7 @@ where #[cfg(test)] mod tests { - use futures_util::stream::StreamExt; + use futures_util::stream::StreamExt as _; use super::*; use crate::test::TestRequest; diff --git a/src/web.rs b/src/web.rs index 1cef37109..4858600af 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,49 +1,39 @@ //! Essentials helper functions and types for application registration. -use actix_http::http::Method; -use actix_router::IntoPattern; use std::future::Future; -pub use actix_http::Response as HttpResponse; +use actix_router::IntoPatterns; pub use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::error::BlockingError; -use crate::extract::FromRequest; -use crate::handler::Handler; -use crate::resource::Resource; -use crate::responder::Responder; -use crate::route::Route; -use crate::scope::Scope; -use crate::service::WebService; +use crate::{ + error::BlockingError, http::Method, service::WebService, FromRequest, Handler, Resource, + Responder, Route, Scope, +}; pub use crate::config::ServiceConfig; pub use crate::data::Data; pub use crate::request::HttpRequest; pub use crate::request_data::ReqData; +pub use crate::response::HttpResponse; pub use crate::types::*; -/// Create resource for a specific path. +/// Creates a new resource for a specific path. /// -/// Resources may have variable path segments. For example, a -/// resource with the path `/a/{name}/c` would match all incoming -/// requests with paths such as `/a/b/c`, `/a/1/c`, or `/a/etc/c`. +/// Resources may have dynamic path segments. For example, a resource with the path `/a/{name}/c` +/// would match all incoming requests with paths such as `/a/b/c`, `/a/1/c`, or `/a/etc/c`. /// -/// A variable segment is specified in the form `{identifier}`, -/// where the identifier can be used later in a request handler to -/// access the matched value for that segment. This is done by -/// looking up the identifier in the `Params` object returned by -/// `HttpRequest.match_info()` method. +/// A dynamic segment is specified in the form `{identifier}`, where the identifier can be used +/// later in a request handler to access the matched value for that segment. This is done by looking +/// up the identifier in the `Path` object returned by [`HttpRequest.match_info()`] method. /// /// By default, each segment matches the regular expression `[^{}/]+`. /// /// You can also specify a custom regex in the form `{identifier:regex}`: /// -/// For instance, to route `GET`-requests on any route matching -/// `/users/{userid}/{friend}` and store `userid` and `friend` in -/// the exposed `Params` object: +/// For instance, to route `GET`-requests on any route matching `/users/{userid}/{friend}` and store +/// `userid` and `friend` in the exposed `Path` object: /// -/// ```rust -/// # extern crate actix_web; +/// ``` /// use actix_web::{web, App, HttpResponse}; /// /// let app = App::new().service( @@ -52,16 +42,27 @@ pub use crate::types::*; /// .route(web::head().to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` -pub fn resource(path: T) -> Resource { +pub fn resource(path: T) -> Resource { Resource::new(path) } -/// Configure scope for common root path. +/// Creates scope for common path prefix. /// -/// Scopes collect multiple paths under a common path prefix. -/// Scope path can contain variable path segments as resources. +/// Scopes collect multiple paths under a common path prefix. The scope's path can contain dynamic +/// path segments. /// -/// ```rust +/// # Avoid Trailing Slashes +/// Avoid using trailing slashes in the scope prefix (e.g., `web::scope("/scope/")`). It will almost +/// certainly not have the expected behavior. See the [documentation on resource definitions][pat] +/// to understand why this is the case and how to correctly construct scope/prefix definitions. +/// +/// # Examples +/// In this example, three routes are set up (and will handle any method): +/// - `/{project_id}/path1` +/// - `/{project_id}/path2` +/// - `/{project_id}/path3` +/// +/// ``` /// use actix_web::{web, App, HttpResponse}; /// /// let app = App::new().service( @@ -72,149 +73,51 @@ pub fn resource(path: T) -> Resource { /// ); /// ``` /// -/// In the above example, three routes get added: -/// * /{project_id}/path1 -/// * /{project_id}/path2 -/// * /{project_id}/path3 -/// +/// [pat]: crate::dev::ResourceDef#prefix-resources pub fn scope(path: &str) -> Scope { Scope::new(path) } -/// Create *route* without configuration. +/// Creates a new un-configured route. pub fn route() -> Route { Route::new() } -/// Create *route* with `GET` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `GET` route gets added: -/// * /{project_id} -/// -pub fn get() -> Route { - method(Method::GET) +macro_rules! method_route { + ($method_fn:ident, $method_const:ident) => { + #[doc = concat!(" Creates a new route with `", stringify!($method_const), "` method guard.")] + /// + /// # Examples + #[doc = concat!(" In this example, one `", stringify!($method_const), " /{project_id}` route is set up:")] + /// ``` + /// use actix_web::{web, App, HttpResponse}; + /// + /// let app = App::new().service( + /// web::resource("/{project_id}") + #[doc = concat!(" .route(web::", stringify!($method_fn), "().to(|| HttpResponse::Ok()))")] + /// + /// ); + /// ``` + pub fn $method_fn() -> Route { + method(Method::$method_const) + } + }; } -/// Create *route* with `POST` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::post().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `POST` route gets added: -/// * /{project_id} -/// -pub fn post() -> Route { - method(Method::POST) -} +method_route!(get, GET); +method_route!(post, POST); +method_route!(put, PUT); +method_route!(patch, PATCH); +method_route!(delete, DELETE); +method_route!(head, HEAD); +method_route!(trace, TRACE); -/// Create *route* with `PUT` method guard. +/// Creates a new route with specified method guard. /// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; +/// # Examples +/// In this example, one `GET /{project_id}` route is set up: /// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::put().to(|| HttpResponse::Ok())) -/// ); /// ``` -/// -/// In the above example, one `PUT` route gets added: -/// * /{project_id} -/// -pub fn put() -> Route { - method(Method::PUT) -} - -/// Create *route* with `PATCH` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::patch().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `PATCH` route gets added: -/// * /{project_id} -/// -pub fn patch() -> Route { - method(Method::PATCH) -} - -/// Create *route* with `DELETE` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::delete().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `DELETE` route gets added: -/// * /{project_id} -/// -pub fn delete() -> Route { - method(Method::DELETE) -} - -/// Create *route* with `HEAD` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::head().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `HEAD` route gets added: -/// * /{project_id} -/// -pub fn head() -> Route { - method(Method::HEAD) -} - -/// Create *route* with `TRACE` method guard. -/// -/// ```rust -/// use actix_web::{web, App, HttpResponse}; -/// -/// let app = App::new().service( -/// web::resource("/{project_id}") -/// .route(web::trace().to(|| HttpResponse::Ok())) -/// ); -/// ``` -/// -/// In the above example, one `HEAD` route gets added: -/// * /{project_id} -/// -pub fn trace() -> Route { - method(Method::TRACE) -} - -/// Create *route* and add method guard. -/// -/// ```rust /// use actix_web::{web, http, App, HttpResponse}; /// /// let app = App::new().service( @@ -222,17 +125,13 @@ pub fn trace() -> Route { /// .route(web::method(http::Method::GET).to(|| HttpResponse::Ok())) /// ); /// ``` -/// -/// In the above example, one `GET` route gets added: -/// * /{project_id} -/// pub fn method(method: Method) -> Route { Route::new().method(method) } -/// Create a new route and add handler. +/// Creates a new any-method route with handler. /// -/// ```rust +/// ``` /// use actix_web::{web, App, HttpResponse, Responder}; /// /// async fn index() -> impl Responder { @@ -244,19 +143,18 @@ pub fn method(method: Method) -> Route { /// web::to(index)) /// ); /// ``` -pub fn to(handler: F) -> Route +pub fn to(handler: F) -> Route where - F: Handler, - I: FromRequest + 'static, - R: Future + 'static, - R::Output: Responder + 'static, + F: Handler, + Args: FromRequest + 'static, + F::Output: Responder + 'static, { Route::new().to(handler) } -/// Create raw service for a specific path. +/// Creates a raw service for a specific path. /// -/// ```rust +/// ``` /// use actix_web::{dev, web, guard, App, Error, HttpResponse}; /// /// async fn my_service(req: dev::ServiceRequest) -> Result { @@ -269,12 +167,12 @@ where /// .finish(my_service) /// ); /// ``` -pub fn service(path: T) -> WebService { +pub fn service(path: T) -> WebService { WebService::new(path) } -/// Execute blocking function on a thread pool, returns future that resolves -/// to result of the function execution. +/// Executes blocking function on a thread pool, returns future that resolves to result of the +/// function execution. pub fn block(f: F) -> impl Future> where F: FnOnce() -> R + Send + 'static, diff --git a/tests/compression.rs b/tests/compression.rs new file mode 100644 index 000000000..88c462f60 --- /dev/null +++ b/tests/compression.rs @@ -0,0 +1,307 @@ +use actix_http::ContentEncoding; +use actix_web::{ + http::{header, StatusCode}, + middleware::Compress, + web, App, HttpResponse, +}; +use bytes::Bytes; + +mod utils; + +static LOREM: &[u8] = include_bytes!("fixtures/lorem.txt"); +static LOREM_GZIP: &[u8] = include_bytes!("fixtures/lorem.txt.gz"); +static LOREM_BR: &[u8] = include_bytes!("fixtures/lorem.txt.br"); +static LOREM_ZSTD: &[u8] = include_bytes!("fixtures/lorem.txt.zst"); +static LOREM_XZ: &[u8] = include_bytes!("fixtures/lorem.txt.xz"); + +macro_rules! test_server { + () => { + actix_test::start(|| { + App::new() + .wrap(Compress::default()) + .route("/static", web::to(|| HttpResponse::Ok().body(LOREM))) + .route( + "/static-gzip", + web::to(|| { + HttpResponse::Ok() + // signal to compressor that content should not be altered + // signal to client that content is encoded + .insert_header(ContentEncoding::Gzip) + .body(LOREM_GZIP) + }), + ) + .route( + "/static-br", + web::to(|| { + HttpResponse::Ok() + // signal to compressor that content should not be altered + // signal to client that content is encoded + .insert_header(ContentEncoding::Brotli) + .body(LOREM_BR) + }), + ) + .route( + "/static-zstd", + web::to(|| { + HttpResponse::Ok() + // signal to compressor that content should not be altered + // signal to client that content is encoded + .insert_header(ContentEncoding::Zstd) + .body(LOREM_ZSTD) + }), + ) + .route( + "/static-xz", + web::to(|| { + HttpResponse::Ok() + // signal to compressor that content should not be altered + // signal to client that content is encoded as 7zip + .insert_header((header::CONTENT_ENCODING, "xz")) + .body(LOREM_XZ) + }), + ) + .route( + "/echo", + web::to(|body: Bytes| HttpResponse::Ok().body(body)), + ) + }) + }; +} + +#[actix_rt::test] +async fn negotiate_encoding_identity() { + let srv = test_server!(); + + let req = srv + .post("/static") + .insert_header((header::ACCEPT_ENCODING, "identity")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING), None); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + srv.stop().await; +} + +#[actix_rt::test] +async fn negotiate_encoding_gzip() { + let srv = test_server!(); + + let req = srv + .post("/static") + .insert_header((header::ACCEPT_ENCODING, "gzip,br,zstd")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + let mut res = srv + .post("/static") + .no_decompress() + .insert_header((header::ACCEPT_ENCODING, "gzip,br,zstd")) + .send() + .await + .unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::gzip::decode(bytes), LOREM); + + srv.stop().await; +} + +#[actix_rt::test] +async fn negotiate_encoding_br() { + let srv = test_server!(); + + let req = srv + .post("/static") + .insert_header((header::ACCEPT_ENCODING, "br,zstd,gzip")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "br"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + let mut res = srv + .post("/static") + .no_decompress() + .insert_header((header::ACCEPT_ENCODING, "br,zstd,gzip")) + .send() + .await + .unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::brotli::decode(bytes), LOREM); + + srv.stop().await; +} + +#[actix_rt::test] +async fn negotiate_encoding_zstd() { + let srv = test_server!(); + + let req = srv + .post("/static") + .insert_header((header::ACCEPT_ENCODING, "zstd,gzip,br")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "zstd"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + let mut res = srv + .post("/static") + .no_decompress() + .insert_header((header::ACCEPT_ENCODING, "zstd,gzip,br")) + .send() + .await + .unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::zstd::decode(bytes), LOREM); + + srv.stop().await; +} + +#[cfg(all( + feature = "compress-brotli", + feature = "compress-gzip", + feature = "compress-zstd", +))] +#[actix_rt::test] +async fn client_encoding_prefers_brotli() { + let srv = test_server!(); + + let req = srv.post("/static").send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "br"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + srv.stop().await; +} + +#[actix_rt::test] +async fn gzip_no_decompress() { + let srv = test_server!(); + + let req = srv + .post("/static-gzip") + // don't decompress response body + .no_decompress() + // signal that we want a compressed body + .insert_header((header::ACCEPT_ENCODING, "gzip,br,zstd")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM_GZIP)); + + srv.stop().await; +} + +#[actix_rt::test] +async fn manual_custom_coding() { + let srv = test_server!(); + + let req = srv + .post("/static-xz") + // don't decompress response body + .no_decompress() + // signal that we want a compressed body + .insert_header((header::ACCEPT_ENCODING, "xz")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "xz"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM_XZ)); + + srv.stop().await; +} + +#[actix_rt::test] +async fn deny_identity_coding() { + let srv = test_server!(); + + let req = srv + .post("/static") + // signal that we want a compressed body + .insert_header((header::ACCEPT_ENCODING, "br, identity;q=0")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "br"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM)); + + srv.stop().await; +} + +#[actix_rt::test] +async fn deny_identity_coding_no_decompress() { + let srv = test_server!(); + + let req = srv + .post("/static-br") + // don't decompress response body + .no_decompress() + // signal that we want a compressed body + .insert_header((header::ACCEPT_ENCODING, "br, identity;q=0")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "br"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM_BR)); + + srv.stop().await; +} + +// TODO: fix test +// currently fails because negotiation doesn't consider unknown encoding types +#[ignore] +#[actix_rt::test] +async fn deny_identity_for_manual_coding() { + let srv = test_server!(); + + let req = srv + .post("/static-xz") + // don't decompress response body + .no_decompress() + // signal that we want a compressed body + .insert_header((header::ACCEPT_ENCODING, "xz, identity;q=0")) + .send(); + + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "xz"); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(LOREM_XZ)); + + srv.stop().await; +} diff --git a/tests/fixtures/lorem.txt b/tests/fixtures/lorem.txt new file mode 100644 index 000000000..a2abb6432 --- /dev/null +++ b/tests/fixtures/lorem.txt @@ -0,0 +1,5 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Proin interdum tincidunt lacus, sed tempor lorem consectetur et. Pellentesque et egestas sem, at cursus massa. Nunc feugiat elit sit amet ipsum commodo luctus. Proin auctor dignissim pharetra. Integer iaculis quam a tellus auctor, vitae auctor nisl viverra. Nullam consequat maximus porttitor. Pellentesque tortor enim, molestie at varius non, tempor non nibh. Suspendisse tempus erat lorem, vel faucibus magna blandit vel. Sed pellentesque ligula augue, vitae fermentum eros blandit et. Cras dignissim in massa ut varius. Vestibulum commodo nunc sit amet pellentesque dignissim. + +Donec imperdiet blandit lobortis. Suspendisse fringilla nunc quis venenatis tempor. Nunc tempor sed erat sed convallis. Pellentesque aliquet elit lectus, quis vulputate arcu pharetra sed. Etiam laoreet aliquet arcu cursus vehicula. Maecenas odio odio, elementum faucibus sollicitudin vitae, pellentesque ac purus. Donec venenatis faucibus lorem, et finibus lacus tincidunt vitae. Quisque laoreet metus sapien, vitae euismod mauris lobortis malesuada. Integer sit amet elementum turpis. Maecenas ex mauris, dapibus eu placerat vitae, rutrum convallis enim. Nulla vitae orci ultricies, sagittis turpis et, lacinia dui. Praesent egestas urna turpis, sit amet feugiat mauris tristique eu. Quisque id tempor libero. Donec ullamcorper dapibus lorem, vel consequat risus congue a. + +Nullam dignissim ut lectus vitae tempor. Pellentesque ut odio fringilla, volutpat mi et, vulputate tellus. Fusce eget diam non odio tincidunt viverra eu vel augue. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Nullam sed eleifend purus, vitae aliquam orci. Cras fringilla justo eget tempus bibendum. Phasellus venenatis, odio nec pulvinar commodo, quam neque lacinia turpis, ut rutrum tortor massa eu nulla. Vivamus tincidunt ut lectus a gravida. Donec varius mi quis enim interdum ultrices. Sed aliquam consectetur nisi vitae viverra. Praesent nec ligula egestas, porta lectus sed, consectetur augue. diff --git a/tests/fixtures/lorem.txt.br b/tests/fixtures/lorem.txt.br new file mode 100644 index 000000000..e12d7510e Binary files /dev/null and b/tests/fixtures/lorem.txt.br differ diff --git a/tests/fixtures/lorem.txt.gz b/tests/fixtures/lorem.txt.gz new file mode 100644 index 000000000..d3a0664a3 Binary files /dev/null and b/tests/fixtures/lorem.txt.gz differ diff --git a/tests/fixtures/lorem.txt.xz b/tests/fixtures/lorem.txt.xz new file mode 100644 index 000000000..bb45fe753 Binary files /dev/null and b/tests/fixtures/lorem.txt.xz differ diff --git a/tests/fixtures/lorem.txt.zst b/tests/fixtures/lorem.txt.zst new file mode 100644 index 000000000..13bea6d7a Binary files /dev/null and b/tests/fixtures/lorem.txt.zst differ diff --git a/tests/test-macro-import-conflict.rs b/tests/test-macro-import-conflict.rs new file mode 100644 index 000000000..0d23bb41d --- /dev/null +++ b/tests/test-macro-import-conflict.rs @@ -0,0 +1,15 @@ +//! Checks that test macro does not cause problems in the presence of imports named "test" that +//! could be either a module with test items or the "test with runtime" macro itself. +//! +//! Before actix/actix-net#399 was implemented, this macro was running twice. The first run output +//! `#[test]` and it got run again and since it was in scope. +//! +//! Prevented by using the fully-qualified test marker (`#[::core::prelude::v1::test]`). + +use actix_web::test; + +#[actix_web::test] +async fn test_macro_naming_conflict() { + let _req = test::TestRequest::default(); + assert_eq!(async { 1 }.await, 1); +} diff --git a/tests/test_error_propagation.rs b/tests/test_error_propagation.rs new file mode 100644 index 000000000..958276b62 --- /dev/null +++ b/tests/test_error_propagation.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use actix_utils::future::{ok, Ready}; +use actix_web::{ + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + get, + test::{call_service, init_service, TestRequest}, + ResponseError, +}; +use futures_core::future::LocalBoxFuture; +use futures_util::lock::Mutex; + +#[derive(Debug, Clone)] +pub struct MyError; + +impl ResponseError for MyError {} + +impl std::fmt::Display for MyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "A custom error") + } +} + +#[get("/test")] +async fn test() -> Result { + Err(MyError.into()) +} + +#[derive(Clone)] +pub struct SpyMiddleware(Arc>>); + +impl Transform for SpyMiddleware +where + S: Service, Error = actix_web::Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type Transform = Middleware; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(Middleware { + was_error: self.0.clone(), + service, + }) + } +} + +#[doc(hidden)] +pub struct Middleware { + was_error: Arc>>, + service: S, +} + +impl Service for Middleware +where + S: Service, Error = actix_web::Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let lock = self.was_error.clone(); + let response_future = self.service.call(req); + Box::pin(async move { + let response = response_future.await; + if let Ok(success) = &response { + *lock.lock().await = Some(success.response().error().is_some()); + } + response + }) + } +} + +#[actix_rt::test] +async fn error_cause_should_be_propagated_to_middlewares() { + let lock = Arc::new(Mutex::new(None)); + let spy_middleware = SpyMiddleware(lock.clone()); + + let app = init_service( + actix_web::App::new() + .wrap(spy_middleware.clone()) + .service(test), + ) + .await; + + call_service(&app, TestRequest::with_uri("/test").to_request()).await; + + let was_error_captured = lock.lock().await.unwrap(); + assert!(was_error_captured); +} diff --git a/tests/test_httpserver.rs b/tests/test_httpserver.rs index 3aa1d36b0..464a650a2 100644 --- a/tests/test_httpserver.rs +++ b/tests/test_httpserver.rs @@ -1,78 +1,62 @@ -use std::sync::mpsc; -use std::{thread, time::Duration}; - #[cfg(feature = "openssl")] extern crate tls_openssl as openssl; -#[cfg(feature = "rustls")] -extern crate tls_rustls as rustls; -#[cfg(feature = "openssl")] -use openssl::ssl::SslAcceptorBuilder; - -use actix_web::{test, web, App, HttpResponse, HttpServer}; +#[cfg(any(unix, feature = "openssl"))] +use { + actix_web::{web, App, HttpResponse, HttpServer}, + std::{sync::mpsc, thread, time::Duration}, +}; #[cfg(unix)] #[actix_rt::test] async fn test_start() { - let addr = test::unused_addr(); + let addr = actix_test::unused_addr(); let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new(); + actix_rt::System::new() + .block_on(async { + let srv = HttpServer::new(|| { + App::new().service( + web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), + ) + }) + .workers(1) + .backlog(1) + .max_connections(10) + .max_connection_rate(10) + .keep_alive(10) + .client_timeout(5000) + .client_shutdown(0) + .server_hostname("localhost") + .system_exit() + .disable_signals() + .bind(format!("{}", addr)) + .unwrap() + .run(); - sys.block_on(async { - let srv = HttpServer::new(|| { - App::new().service( - web::resource("/").route(web::to(|| HttpResponse::Ok().body("test"))), - ) + tx.send(srv.handle()).unwrap(); + + srv.await }) - .workers(1) - .backlog(1) - .max_connections(10) - .max_connection_rate(10) - .keep_alive(10) - .client_timeout(5000) - .client_shutdown(0) - .server_hostname("localhost") - .system_exit() - .disable_signals() - .bind(format!("{}", addr)) - .unwrap() - .run(); - - let _ = tx.send((srv, actix_rt::System::current())); - }); - - let _ = sys.run(); + .unwrap(); }); - let (srv, sys) = rx.recv().unwrap(); - #[cfg(feature = "client")] - { - use actix_http::client; + let srv = rx.recv().unwrap(); - let client = awc::Client::builder() - .connector( - client::Connector::new() - .timeout(Duration::from_millis(100)) - .finish(), - ) - .finish(); + let client = awc::Client::builder() + .connector(awc::Connector::new().timeout(Duration::from_millis(100))) + .finish(); - let host = format!("http://{}", addr); - let response = client.get(host.clone()).send().await.unwrap(); - assert!(response.status().is_success()); - } + let host = format!("http://{}", addr); + let response = client.get(host.clone()).send().await.unwrap(); + assert!(response.status().is_success()); - // stop - let _ = srv.stop(false); - - thread::sleep(Duration::from_millis(100)); - let _ = sys.stop(); + srv.stop(false).await; } #[cfg(feature = "openssl")] -fn ssl_acceptor() -> std::io::Result { +fn ssl_acceptor() -> openssl::ssl::SslAcceptorBuilder { use openssl::{ pkey::PKey, ssl::{SslAcceptor, SslMethod}, @@ -89,44 +73,45 @@ fn ssl_acceptor() -> std::io::Result { builder.set_certificate(&cert).unwrap(); builder.set_private_key(&key).unwrap(); - Ok(builder) + builder } #[actix_rt::test] #[cfg(feature = "openssl")] async fn test_start_ssl() { use actix_web::HttpRequest; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; - let addr = test::unused_addr(); + let addr = actix_test::unused_addr(); let (tx, rx) = mpsc::channel(); thread::spawn(move || { - let sys = actix_rt::System::new(); - let builder = ssl_acceptor().unwrap(); + actix_rt::System::new() + .block_on(async { + let builder = ssl_acceptor(); - let srv = HttpServer::new(|| { - App::new().service(web::resource("/").route(web::to(|req: HttpRequest| { - assert!(req.app_config().secure()); - HttpResponse::Ok().body("test") - }))) - }) - .workers(1) - .shutdown_timeout(1) - .system_exit() - .disable_signals() - .bind_openssl(format!("{}", addr), builder) - .unwrap(); + let srv = HttpServer::new(|| { + App::new().service(web::resource("/").route(web::to(|req: HttpRequest| { + assert!(req.app_config().secure()); + HttpResponse::Ok().body("test") + }))) + }) + .workers(1) + .shutdown_timeout(1) + .system_exit() + .disable_signals() + .bind_openssl(format!("{}", addr), builder) + .unwrap(); - sys.block_on(async { - let srv = srv.run(); - let _ = tx.send((srv, actix_rt::System::current())); - }); + let srv = srv.run(); + tx.send(srv.handle()).unwrap(); - let _ = sys.run(); + srv.await + }) + .unwrap() }); - let (srv, sys) = rx.recv().unwrap(); + let srv = rx.recv().unwrap(); - use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); let _ = builder @@ -136,9 +121,8 @@ async fn test_start_ssl() { let client = awc::Client::builder() .connector( awc::Connector::new() - .ssl(builder.build()) - .timeout(Duration::from_millis(100)) - .finish(), + .openssl(builder.build()) + .timeout(Duration::from_millis(100)), ) .finish(); @@ -146,9 +130,5 @@ async fn test_start_ssl() { let response = client.get(host.clone()).send().await.unwrap(); assert!(response.status().is_success()); - // stop - let _ = srv.stop(false); - - thread::sleep(Duration::from_millis(100)); - let _ = sys.stop(); + srv.stop(false).await; } diff --git a/tests/test_server.rs b/tests/test_server.rs index 2466730f9..987e51a65 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -10,50 +10,29 @@ use std::{ task::{Context, Poll}, }; -use actix_http::http::header::{ - ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING, +use actix_web::{ + cookie::{Cookie, CookieBuilder}, + http::{header, StatusCode}, + middleware::{Compress, NormalizePath, TrailingSlash}, + web, App, Error, HttpResponse, }; -use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::Bytes; -use flate2::{ - read::GzDecoder, - write::{GzEncoder, ZlibDecoder, ZlibEncoder}, - Compression, -}; -use futures_util::ready; +use futures_core::ready; +use rand::{distributions::Alphanumeric, Rng as _}; + +#[cfg(feature = "openssl")] use openssl::{ pkey::PKey, ssl::{SslAcceptor, SslMethod}, x509::X509, }; -use rand::{distributions::Alphanumeric, Rng}; -use actix_web::dev::BodyEncoding; -use actix_web::middleware::{Compress, NormalizePath, TrailingSlash}; -use actix_web::{dev, test, web, App, Error, HttpResponse}; +mod utils; -const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World \ - Hello World Hello World Hello World Hello World Hello World"; +const S: &str = "Hello World "; +const STR: &str = const_str::repeat!(S, 100); +#[cfg(feature = "openssl")] fn openssl_config() -> SslAcceptor { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); @@ -113,166 +92,92 @@ impl futures_core::stream::Stream for TestBody { #[actix_rt::test] async fn test_body() { - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) }); - let mut response = srv.get("/").send().await.unwrap(); - assert!(response.status().is_success()); + let mut res = srv.get("/").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } -#[actix_rt::test] -async fn test_body_gzip() { - let srv = test::start_with(test::config().h1(), || { - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| HttpResponse::Ok().body(STR)))) - }); +// enforcing an encoding per-response is removed +// #[actix_rt::test] +// async fn test_body_encoding_override() { +// let srv = actix_test::start_with(actix_test::config().h1(), || { +// App::new() +// .wrap(Compress::default()) +// .service(web::resource("/").route(web::to(|| { +// HttpResponse::Ok() +// .encode_with(ContentEncoding::Deflate) +// .body(STR) +// }))) +// .service(web::resource("/raw").route(web::to(|| { +// let mut res = HttpResponse::with_body(actix_web::http::StatusCode::OK, STR); +// res.encode_with(ContentEncoding::Deflate); +// res.map_into_boxed_body() +// }))) +// }); - let mut response = srv - .get("/") - .no_decompress() - .append_header((ACCEPT_ENCODING, "gzip")) - .send() - .await - .unwrap(); - assert!(response.status().is_success()); +// // Builder +// let mut res = srv +// .get("/") +// .no_decompress() +// .append_header((ACCEPT_ENCODING, "deflate")) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); +// let bytes = res.body().await.unwrap(); +// assert_eq!(utils::deflate::decode(bytes), STR.as_bytes()); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} +// // Raw Response +// let mut res = srv +// .request(actix_web::http::Method::GET, srv.url("/raw")) +// .no_decompress() +// .append_header((ACCEPT_ENCODING, "deflate")) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.status(), StatusCode::OK); + +// let bytes = res.body().await.unwrap(); +// assert_eq!(utils::deflate::decode(bytes), STR.as_bytes()); + +// srv.stop().await; +// } #[actix_rt::test] -async fn test_body_gzip2() { - let srv = test::start_with(test::config().h1(), || { - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - HttpResponse::Ok().body(STR).into_body::() - }))) - }); - - let mut response = srv - .get("/") - .no_decompress() - .append_header((ACCEPT_ENCODING, "gzip")) - .send() - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - -#[actix_rt::test] -async fn test_body_encoding_override() { - let srv = test::start_with(test::config().h1(), || { - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service(web::resource("/").route(web::to(|| { - HttpResponse::Ok() - .encoding(ContentEncoding::Deflate) - .body(STR) - }))) - .service(web::resource("/raw").route(web::to(|| { - let body = actix_web::dev::Body::Bytes(STR.into()); - let mut response = - HttpResponse::with_body(actix_web::http::StatusCode::OK, body); - - response.encoding(ContentEncoding::Deflate); - - response - }))) - }); - - // Builder - let mut response = srv - .get("/") - .no_decompress() - .append_header((ACCEPT_ENCODING, "deflate")) - .send() - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - - // decode - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); - - // Raw Response - let mut response = srv - .request(actix_web::http::Method::GET, srv.url("/raw")) - .no_decompress() - .append_header((ACCEPT_ENCODING, "deflate")) - .send() - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - - // decode - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - -#[actix_rt::test] -async fn test_body_gzip_large() { +async fn body_gzip_large() { let data = STR.repeat(10); let srv_data = data.clone(); - let srv = test::start_with(test::config().h1(), move || { + let srv = actix_test::start_with(actix_test::config().h1(), move || { let data = srv_data.clone(); - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || HttpResponse::Ok().body(data.clone()))), - ) + + App::new().wrap(Compress::default()).service( + web::resource("/").route(web::to(move || HttpResponse::Ok().body(data.clone()))), + ) }); - let mut response = srv + let mut res = srv .get("/") .no_decompress() - .append_header((ACCEPT_ENCODING, "gzip")) + .append_header((header::ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::gzip::decode(bytes), data.as_bytes()); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from(data)); + srv.stop().await; } #[actix_rt::test] @@ -284,126 +189,107 @@ async fn test_body_gzip_large_random() { .collect::(); let srv_data = data.clone(); - let srv = test::start_with(test::config().h1(), move || { + let srv = actix_test::start_with(actix_test::config().h1(), move || { let data = srv_data.clone(); - App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) - .service( - web::resource("/") - .route(web::to(move || HttpResponse::Ok().body(data.clone()))), - ) + App::new().wrap(Compress::default()).service( + web::resource("/").route(web::to(move || HttpResponse::Ok().body(data.clone()))), + ) }); - let mut response = srv + let mut res = srv .get("/") .no_decompress() - .append_header((ACCEPT_ENCODING, "gzip")) + .append_header((header::ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::gzip::decode(bytes), data.as_bytes()); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(dec.len(), data.len()); - assert_eq!(Bytes::from(dec), Bytes::from(data)); + srv.stop().await; } #[actix_rt::test] async fn test_body_chunked_implicit() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new() - .wrap(Compress::new(ContentEncoding::Gzip)) + .wrap(Compress::default()) .service(web::resource("/").route(web::get().to(move || { HttpResponse::Ok() .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) }))) }); - let mut response = srv + let mut res = srv .get("/") .no_decompress() - .append_header((ACCEPT_ENCODING, "gzip")) + .append_header((header::ACCEPT_ENCODING, "gzip")) .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); assert_eq!( - response.headers().get(TRANSFER_ENCODING).unwrap(), - &b"chunked"[..] + res.headers().get(header::TRANSFER_ENCODING).unwrap(), + "chunked" ); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::gzip::decode(bytes), STR.as_bytes()); - // decode - let mut e = GzDecoder::new(&bytes[..]); - let mut dec = Vec::new(); - e.read_to_end(&mut dec).unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + srv.stop().await; } #[actix_rt::test] async fn test_body_br_streaming() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new() - .wrap(Compress::new(ContentEncoding::Br)) + .wrap(Compress::default()) .service(web::resource("/").route(web::to(move || { HttpResponse::Ok() .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) }))) }); - let mut response = srv + let mut res = srv .get("/") - .append_header((ACCEPT_ENCODING, "br")) + .append_header((header::ACCEPT_ENCODING, "br")) .no_decompress() .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); - println!("TEST: {:?}", bytes.len()); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::brotli::decode(bytes), STR.as_bytes()); - // decode br - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - println!("T: {:?}", Bytes::copy_from_slice(&dec)); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + srv.stop().await; } #[actix_rt::test] async fn test_head_binary() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::head().to(move || HttpResponse::Ok().body(STR))), ) }); - let mut response = srv.head("/").send().await.unwrap(); - assert!(response.status().is_success()); + let mut res = srv.head("/").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - { - let len = response.headers().get(CONTENT_LENGTH).unwrap(); - assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - } + let len = res.headers().get(header::CONTENT_LENGTH).unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert!(bytes.is_empty()); + + srv.stop().await; } #[actix_rt::test] async fn test_no_chunking() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service(web::resource("/").route(web::to(move || { HttpResponse::Ok() .no_chunking(STR.len() as u64) @@ -411,144 +297,225 @@ async fn test_no_chunking() { }))) }); - let mut response = srv.get("/").send().await.unwrap(); - assert!(response.status().is_success()); - assert!(!response.headers().contains_key(TRANSFER_ENCODING)); + let mut res = srv.get("/").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert!(!res.headers().contains_key(header::TRANSFER_ENCODING)); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] async fn test_body_deflate() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new() - .wrap(Compress::new(ContentEncoding::Deflate)) + .wrap(Compress::default()) .service(web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR)))) }); - // client request - let mut response = srv + let mut res = srv .get("/") - .append_header((ACCEPT_ENCODING, "deflate")) + .append_header((header::ACCEPT_ENCODING, "deflate")) .no_decompress() .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::deflate::decode(bytes), STR.as_bytes()); - let mut e = ZlibDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + srv.stop().await; } #[actix_rt::test] async fn test_body_brotli() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new() - .wrap(Compress::new(ContentEncoding::Br)) + .wrap(Compress::default()) .service(web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR)))) }); - // client request - let mut response = srv + let mut res = srv .get("/") - .append_header((ACCEPT_ENCODING, "br")) + .append_header((header::ACCEPT_ENCODING, "br")) .no_decompress() .send() .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); + assert_eq!(utils::brotli::decode(bytes), STR.as_bytes()); - // decode brotli - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + srv.stop().await; +} + +#[actix_rt::test] +async fn test_body_zstd() { + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new() + .wrap(Compress::default()) + .service(web::resource("/").route(web::to(move || HttpResponse::Ok().body(STR)))) + }); + + let mut res = srv + .get("/") + .append_header((header::ACCEPT_ENCODING, "zstd")) + .no_decompress() + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let bytes = res.body().await.unwrap(); + assert_eq!(utils::zstd::decode(bytes), STR.as_bytes()); + + srv.stop().await; +} + +#[actix_rt::test] +async fn test_body_zstd_streaming() { + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new() + .wrap(Compress::default()) + .service(web::resource("/").route(web::to(move || { + HttpResponse::Ok() + .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) + }))) + }); + + let mut res = srv + .get("/") + .append_header((header::ACCEPT_ENCODING, "zstd")) + .no_decompress() + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let bytes = res.body().await.unwrap(); + assert_eq!(utils::zstd::decode(bytes), STR.as_bytes()); + + srv.stop().await; +} + +#[actix_rt::test] +async fn test_zstd_encoding() { + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new().service( + web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), + ) + }); + + let request = srv + .post("/") + .append_header((header::CONTENT_ENCODING, "zstd")) + .send_body(utils::zstd::encode(STR)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; +} + +#[actix_rt::test] +async fn test_zstd_encoding_large() { + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(320_000) + .map(char::from) + .collect::(); + + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new().service( + web::resource("/") + .app_data(web::PayloadConfig::new(320_000)) + .route(web::to(move |body: Bytes| { + HttpResponse::Ok().streaming(TestBody::new(body, 10240)) + })), + ) + }); + + let request = srv + .post("/") + .append_header((header::CONTENT_ENCODING, "zstd")) + .send_body(utils::zstd::encode(&data)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let bytes = res.body().limit(320_000).await.unwrap(); + assert_eq!(bytes, data.as_bytes()); + + srv.stop().await; } #[actix_rt::test] async fn test_encoding() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().wrap(Compress::default()).service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - let request = srv .post("/") - .insert_header((CONTENT_ENCODING, "gzip")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .insert_header((header::CONTENT_ENCODING, "gzip")) + .send_body(utils::gzip::encode(STR)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] async fn test_gzip_encoding() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - let request = srv .post("/") - .append_header((CONTENT_ENCODING, "gzip")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "gzip")) + .send_body(utils::gzip::encode(STR)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, STR.as_bytes()); + + srv.stop().await; } #[actix_rt::test] async fn test_gzip_encoding_large() { let data = STR.repeat(10); - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - let request = srv + let req = srv .post("/") - .append_header((CONTENT_ENCODING, "gzip")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "gzip")) + .send_body(utils::gzip::encode(&data)); + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, data); + + srv.stop().await; } #[actix_rt::test] @@ -559,79 +526,66 @@ async fn test_reading_gzip_encoding_large_random() { .map(char::from) .collect::(); - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - // client request - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - let request = srv .post("/") - .append_header((CONTENT_ENCODING, "gzip")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "gzip")) + .send_body(utils::gzip::encode(&data)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data)); + let bytes = res.body().await.unwrap(); + assert_eq!(bytes, data.as_bytes()); + + srv.stop().await; } #[actix_rt::test] async fn test_reading_deflate_encoding() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let request = srv .post("/") - .append_header((CONTENT_ENCODING, "deflate")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "deflate")) + .send_body(utils::deflate::encode(STR)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] async fn test_reading_deflate_encoding_large() { let data = STR.repeat(10); - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let request = srv .post("/") - .append_header((CONTENT_ENCODING, "deflate")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "deflate")) + .send_body(utils::deflate::encode(&data)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] @@ -642,53 +596,45 @@ async fn test_reading_deflate_encoding_large_random() { .map(char::from) .collect::(); - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let request = srv .post("/") - .append_header((CONTENT_ENCODING, "deflate")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "deflate")) + .send_body(utils::deflate::encode(&data)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[actix_rt::test] async fn test_brotli_encoding() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/").route(web::to(move |body: Bytes| HttpResponse::Ok().body(body))), ) }); - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let request = srv .post("/") - .append_header((CONTENT_ENCODING, "br")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "br")) + .send_body(utils::brotli::encode(STR)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + srv.stop().await; } #[actix_rt::test] @@ -699,7 +645,7 @@ async fn test_brotli_encoding_large() { .map(char::from) .collect::(); - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new().service( web::resource("/") .app_data(web::PayloadConfig::new(320_000)) @@ -709,137 +655,130 @@ async fn test_brotli_encoding_large() { ) }); - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let request = srv .post("/") - .append_header((CONTENT_ENCODING, "br")) - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); + .append_header((header::CONTENT_ENCODING, "br")) + .send_body(utils::brotli::encode(&data)); + let mut res = request.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().limit(320_000).await.unwrap(); + let bytes = res.body().limit(320_000).await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } #[cfg(feature = "openssl")] #[actix_rt::test] async fn test_brotli_encoding_large_openssl() { + use actix_web::http::header; + let data = STR.repeat(10); - let srv = test::start_with(test::config().openssl(openssl_config()), move || { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { - HttpResponse::Ok() - .encoding(actix_web::http::ContentEncoding::Identity) - .body(bytes) - }))) - }); + let srv = + actix_test::start_with(actix_test::config().openssl(openssl_config()), move || { + App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + // echo decompressed request body back in response + HttpResponse::Ok() + .insert_header(header::ContentEncoding::Identity) + .body(bytes) + }))) + }); - // body - let mut e = BrotliEncoder::new(Vec::new(), 3); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let mut response = srv + let mut res = srv .post("/") - .append_header((actix_web::http::header::CONTENT_ENCODING, "br")) - .send_body(enc) + .append_header((header::CONTENT_ENCODING, "br")) + .send_body(utils::brotli::encode(&data)) .await .unwrap(); - assert!(response.status().is_success()); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } -#[cfg(all(feature = "rustls", feature = "openssl"))] +#[cfg(feature = "rustls")] mod plus_rustls { use std::io::BufReader; - use rustls::{ - internal::pemfile::{certs, pkcs8_private_keys}, - NoClientAuth, ServerConfig as RustlsServerConfig, - }; + use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig}; + use rustls_pemfile::{certs, pkcs8_private_keys}; use super::*; - fn rustls_config() -> RustlsServerConfig { + fn tls_config() -> RustlsServerConfig { let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap(); let cert_file = cert.serialize_pem().unwrap(); let key_file = cert.serialize_private_key_pem(); - let mut config = RustlsServerConfig::new(NoClientAuth::new()); let cert_file = &mut BufReader::new(cert_file.as_bytes()); let key_file = &mut BufReader::new(key_file.as_bytes()); - let cert_chain = certs(cert_file).unwrap(); + let cert_chain = certs(cert_file) + .unwrap() + .into_iter() + .map(Certificate) + .collect(); let mut keys = pkcs8_private_keys(key_file).unwrap(); - config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); - config + RustlsServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert_chain, PrivateKey(keys.remove(0))) + .unwrap() } #[actix_rt::test] async fn test_reading_deflate_encoding_large_random_rustls() { - use rustls::internal::pemfile::{certs, pkcs8_private_keys}; - use rustls::{NoClientAuth, ServerConfig}; - use std::fs::File; - use std::io::BufReader; - let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(160_000) .map(char::from) .collect::(); - let srv = test::start_with(test::config().rustls(rustls_config()), || { + let srv = actix_test::start_with(actix_test::config().rustls(tls_config()), || { App::new().service(web::resource("/").route(web::to(|bytes: Bytes| { + // echo decompressed request body back in response HttpResponse::Ok() - .encoding(actix_web::http::ContentEncoding::Identity) + .insert_header(header::ContentEncoding::Identity) .body(bytes) }))) }); - // encode data - let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request let req = srv .post("/") - .insert_header((actix_web::http::header::CONTENT_ENCODING, "deflate")) - .send_stream(TestBody::new(Bytes::from(enc), 1024)); + .insert_header((header::CONTENT_ENCODING, "deflate")) + .send_stream(TestBody::new( + Bytes::from(utils::deflate::encode(&data)), + 1024, + )); - let mut response = req.await.unwrap(); - assert!(response.status().is_success()); + let mut res = req.await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); - // read response - let bytes = response.body().await.unwrap(); + let bytes = res.body().await.unwrap(); assert_eq!(bytes.len(), data.len()); assert_eq!(bytes, Bytes::from(data)); + + srv.stop().await; } } #[actix_rt::test] async fn test_server_cookies() { - use actix_web::{http, HttpMessage}; + use actix_web::http; - let srv = test::start(|| { + let srv = actix_test::start(|| { App::new().default_service(web::to(|| { HttpResponse::Ok() .cookie( - http::CookieBuilder::new("first", "first_value") + CookieBuilder::new("first", "first_value") .http_only(true) .finish(), ) - .cookie(http::Cookie::new("second", "first_value")) - .cookie(http::Cookie::new("second", "second_value")) + .cookie(Cookie::new("second", "first_value")) + .cookie(Cookie::new("second", "second_value")) .finish() })) }); @@ -848,10 +787,10 @@ async fn test_server_cookies() { let res = req.send().await.unwrap(); assert!(res.status().is_success()); - let first_cookie = http::CookieBuilder::new("first", "first_value") + let first_cookie = CookieBuilder::new("first", "first_value") .http_only(true) .finish(); - let second_cookie = http::Cookie::new("second", "second_value"); + let second_cookie = Cookie::new("second", "second_value"); let cookies = res.cookies().expect("To have cookies"); assert_eq!(cookies.len(), 2); @@ -877,13 +816,15 @@ async fn test_server_cookies() { assert_eq!(cookies[0], second_cookie); assert_eq!(cookies[1], first_cookie); } + + srv.stop().await; } #[actix_rt::test] async fn test_slow_request() { use std::net; - let srv = test::start_with(test::config().client_timeout(200), || { + let srv = actix_test::start_with(actix_test::config().client_timeout(200), || { App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))) }); @@ -897,16 +838,96 @@ async fn test_slow_request() { let mut data = String::new(); let _ = stream.read_to_string(&mut data); assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + srv.stop().await; } #[actix_rt::test] async fn test_normalize() { - let srv = test::start_with(test::config().h1(), || { + let srv = actix_test::start_with(actix_test::config().h1(), || { App::new() .wrap(NormalizePath::new(TrailingSlash::Trim)) - .service(web::resource("/one").route(web::to(|| HttpResponse::Ok().finish()))) + .service(web::resource("/one").route(web::to(HttpResponse::Ok))) }); - let response = srv.get("/one/").send().await.unwrap(); - assert!(response.status().is_success()); + let res = srv.get("/one/").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + srv.stop().await +} + +// allow deprecated App::data +#[allow(deprecated)] +#[actix_rt::test] +async fn test_data_drop() { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + struct TestData(Arc); + + impl TestData { + fn new(inner: Arc) -> Self { + let _ = inner.fetch_add(1, Ordering::SeqCst); + Self(inner) + } + } + + impl Clone for TestData { + fn clone(&self) -> Self { + let inner = self.0.clone(); + let _ = inner.fetch_add(1, Ordering::SeqCst); + Self(inner) + } + } + + impl Drop for TestData { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::SeqCst); + } + } + + let num = Arc::new(AtomicUsize::new(0)); + let data = TestData::new(num.clone()); + assert_eq!(num.load(Ordering::SeqCst), 1); + + let srv = actix_test::start(move || { + let data = data.clone(); + + App::new() + .data(data) + .service(web::resource("/").to(|_data: web::Data| async { "ok" })) + }); + + assert!(srv.get("/").send().await.unwrap().status().is_success()); + srv.stop().await; + + assert_eq!(num.load(Ordering::SeqCst), 0); +} + +#[actix_rt::test] +async fn test_accept_encoding_no_match() { + let srv = actix_test::start_with(actix_test::config().h1(), || { + App::new() + .wrap(Compress::default()) + .service(web::resource("/").route(web::to(move || HttpResponse::Ok().finish()))) + }); + + let mut res = srv + .get("/") + .insert_header((header::ACCEPT_ENCODING, "xz, identity;q=0")) + .no_decompress() + .send() + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE); + assert_eq!(res.headers().get(header::CONTENT_ENCODING), None); + + let bytes = res.body().await.unwrap(); + // body should contain the supported encodings + assert!(!bytes.is_empty()); + + srv.stop().await; } diff --git a/tests/utils.rs b/tests/utils.rs new file mode 100644 index 000000000..9a3743d8b --- /dev/null +++ b/tests/utils.rs @@ -0,0 +1,76 @@ +// compiling some tests will trigger unused function warnings even though other tests use them +#![allow(dead_code)] + +use std::io::{Read as _, Write as _}; + +pub mod gzip { + use super::*; + use flate2::{read::GzDecoder, write::GzEncoder, Compression}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = GzDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod deflate { + use super::*; + use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = ZlibDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod brotli { + use super::*; + use ::brotli2::{read::BrotliDecoder, write::BrotliEncoder}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = BrotliEncoder::new(Vec::new(), 3); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = BrotliDecoder::new(bytes.as_ref()); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} + +pub mod zstd { + use super::*; + use ::zstd::stream::{read::Decoder, write::Encoder}; + + pub fn encode(bytes: impl AsRef<[u8]>) -> Vec { + let mut encoder = Encoder::new(Vec::new(), 3).unwrap(); + encoder.write_all(bytes.as_ref()).unwrap(); + encoder.finish().unwrap() + } + + pub fn decode(bytes: impl AsRef<[u8]>) -> Vec { + let mut decoder = Decoder::new(bytes.as_ref()).unwrap(); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf).unwrap(); + buf + } +} diff --git a/tests/weird_poll.rs b/tests/weird_poll.rs new file mode 100644 index 000000000..5844ea2c2 --- /dev/null +++ b/tests/weird_poll.rs @@ -0,0 +1,30 @@ +//! Regression test for https://github.com/actix/actix-web/issues/1321 + +// use actix_http::body::{BodyStream, MessageBody}; +// use bytes::Bytes; +// use futures_channel::oneshot; +// use futures_util::{ +// stream::once, +// task::{noop_waker, Context}, +// }; + +// #[test] +// fn weird_poll() { +// let (sender, receiver) = oneshot::channel(); +// let mut body_stream = Ok(BodyStream::new(once(async { +// let x = Box::new(0); +// let y = &x; +// receiver.await.unwrap(); +// let _z = **y; +// Ok::<_, ()>(Bytes::new()) +// }))); + +// let waker = noop_waker(); +// let mut cx = Context::from_waker(&waker); + +// let _ = body_stream.as_mut().unwrap().poll_next(&mut cx); +// sender.send(()).unwrap(); +// let _ = std::mem::replace(&mut body_stream, Err([0; 32])) +// .unwrap() +// .poll_next(&mut cx); +// }