diff --git a/.appveyor.yml b/.appveyor.yml index 7addc8c08..2f0a4a7dd 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,6 +1,6 @@ environment: global: - PROJECT_NAME: actix + PROJECT_NAME: actix-web matrix: # Stable channel - TARGET: i686-pc-windows-msvc @@ -37,4 +37,5 @@ build: false # Equivalent to Travis' `script` phase test_script: + - cargo clean - cargo test --no-default-features --features="flate2-rust" diff --git a/.travis.yml b/.travis.yml index 54a86aa7a..9b1bcff54 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,14 +30,17 @@ before_script: script: - | - if [[ "$TRAVIS_RUST_VERSION" != "stable" ]]; then + if [[ "$TRAVIS_RUST_VERSION" != "nightly" ]]; then cargo clean - cargo test --features="alpn,tls" -- --nocapture + cargo check --features rust-tls + cargo check --features ssl + cargo check --features tls + cargo test --features="ssl,tls,rust-tls,uds" -- --nocapture fi - | - if [[ "$TRAVIS_RUST_VERSION" == "stable" ]]; then + if [[ "$TRAVIS_RUST_VERSION" == "nightly" ]]; then RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install -f cargo-tarpaulin - cargo tarpaulin --features="alpn,tls" --out Xml --no-count + RUST_BACKTRACE=1 cargo tarpaulin --features="ssl,tls,rust-tls" --out Xml bash <(curl -s https://codecov.io/bash) echo "Uploaded code coverage" fi @@ -45,8 +48,8 @@ script: # Upload docs after_success: - | - if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "beta" ]]; then - cargo doc --features "alpn, tls, session" --no-deps && + if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "stable" ]]; then + cargo doc --features "ssl,tls,rust-tls,session" --no-deps && echo "" > target/doc/index.html && git clone https://github.com/davisp/ghp-import.git && ./ghp-import/ghp_import.py -n -p -f -m "Documentation upload" -r https://"$GH_TOKEN"@github.com/"$TRAVIS_REPO_SLUG.git" target/doc && diff --git a/CHANGES.md b/CHANGES.md index 15786fb69..6092544e9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,16 +1,260 @@ # Changes -## [0.7.1] - 2018-07-21 +## [0.7.15] - 2018-12-05 + +## Changed + +* `ClientConnector::resolver` now accepts `Into` instead of `Addr`. It enables user to implement own resolver. + +* `QueryConfig` and `PathConfig` are made public. + +* `AsyncResult::async` is changed to `AsyncResult::future` as `async` is reserved keyword in 2018 edition. ### Added - * Add implementation of `FromRequest` for `Option` and `Result` +* By default, `Path` extractor now percent decode all characters. This behaviour can be disabled + with `PathConfig::default().disable_decoding()` + + +## [0.7.14] - 2018-11-14 + +### Added + +* Add method to configure custom error handler to `Query` and `Path` extractors. + +* Add method to configure `SameSite` option in `CookieIdentityPolicy`. + +* By default, `Path` extractor now percent decode all characters. This behaviour can be disabled + with `PathConfig::default().disable_decoding()` + + +### Fixed + +* Fix websockets connection drop if request contains "content-length" header #567 + +* Fix keep-alive timer reset + +* HttpServer now treats streaming bodies the same for HTTP/1.x protocols. #549 + +* Set nodelay for socket #560 + + +## [0.7.13] - 2018-10-14 + +### Fixed + +* Fixed rustls support + +* HttpServer not sending streamed request body on HTTP/2 requests #544 + + +## [0.7.12] - 2018-10-10 + +### Changed + +* Set min version for actix + +* Set min version for actix-net + + +## [0.7.11] - 2018-10-09 + +### Fixed + +* Fixed 204 responses for http/2 + + +## [0.7.10] - 2018-10-09 + +### Fixed + +* Fixed panic during graceful shutdown + + +## [0.7.9] - 2018-10-09 + +### Added + +* Added client shutdown timeout setting + +* Added slow request timeout setting + +* Respond with 408 response on slow request timeout #523 + + +### Fixed + +* HTTP1 decoding errors are reported to the client. #512 + +* Correctly compose multiple allowed origins in CORS. #517 + +* Websocket server finished() isn't called if client disconnects #511 + +* Responses with the following codes: 100, 101, 102, 204 -- are sent without Content-Length header. #521 + +* Correct usage of `no_http2` flag in `bind_*` methods. #519 + + +## [0.7.8] - 2018-09-17 + +### Added + +* Use server `Keep-Alive` setting as slow request timeout #439 + +### Changed + +* Use 5 seconds keep-alive timer by default. + +### Fixed + +* Fixed wrong error message for i16 type #510 + + +## [0.7.7] - 2018-09-11 + +### Fixed + +* Fix linked list of HttpChannels #504 + +* Fix requests to TestServer fail #508 + + +## [0.7.6] - 2018-09-07 + +### Fixed + +* Fix system_exit in HttpServer #501 + +* Fix parsing of route param containin regexes with repetition #500 + +### Changes + +* Unhide `SessionBackend` and `SessionImpl` traits #455 + + +## [0.7.5] - 2018-09-04 + +### Added + +* Added the ability to pass a custom `TlsConnector`. + +* Allow to register handlers on scope level #465 + + +### Fixed + +* Handle socket read disconnect + +* Handling scoped paths without leading slashes #460 + + +### Changed + +* Read client response until eof if connection header set to close #464 + + +## [0.7.4] - 2018-08-23 + +### Added + +* Added `HttpServer::maxconn()` and `HttpServer::maxconnrate()`, + accept backpressure #250 + +* Allow to customize connection handshake process via `HttpServer::listen_with()` + and `HttpServer::bind_with()` methods + +* Support making client connections via `tokio-uds`'s `UnixStream` when "uds" feature is enabled #472 + +### Changed + +* It is allowed to use function with up to 10 parameters for handler with `extractor parameters`. + `Route::with_config()`/`Route::with_async_config()` always passes configuration objects as tuple + even for handler with one parameter. + +* native-tls - 0.2 + +* `Content-Disposition` is re-worked. Its parser is now more robust and handles quoted content better. See #461 + +### Fixed + +* Use zlib instead of raw deflate for decoding and encoding payloads with + `Content-Encoding: deflate`. + +* Fixed headers formating for CORS Middleware Access-Control-Expose-Headers #436 + +* Fix adding multiple response headers #446 + +* Client includes port in HOST header when it is not default(e.g. not 80 and 443). #448 + +* Panic during access without routing being set #452 + +* Fixed http/2 error handling + +### Deprecated + +* `HttpServer::no_http2()` is deprecated, use `OpensslAcceptor::with_flags()` or + `RustlsAcceptor::with_flags()` instead + +* `HttpServer::listen_tls()`, `HttpServer::listen_ssl()`, `HttpServer::listen_rustls()` have been + deprecated in favor of `HttpServer::listen_with()` with specific `acceptor`. + +* `HttpServer::bind_tls()`, `HttpServer::bind_ssl()`, `HttpServer::bind_rustls()` have been + deprecated in favor of `HttpServer::bind_with()` with specific `acceptor`. + + +## [0.7.3] - 2018-08-01 + +### Added + +* Support HTTP/2 with rustls #36 + +* Allow TestServer to open a websocket on any URL (TestServer::ws_at()) #433 + +### Fixed + +* Fixed failure 0.1.2 compatibility + +* Do not override HOST header for client request #428 + +* Gz streaming, use `flate2::write::GzDecoder` #228 + +* HttpRequest::url_for is not working with scopes #429 + +* Fixed headers' formating for CORS Middleware `Access-Control-Expose-Headers` header value to HTTP/1.1 & HTTP/2 spec-compliant format #436 + + +## [0.7.2] - 2018-07-26 + +### Added + +* Add implementation of `FromRequest` for `Option` and `Result` + +* Allow to handle application prefix, i.e. allow to handle `/app` path + for application with `/app` prefix. + Check [`App::prefix()`](https://actix.rs/actix-web/actix_web/struct.App.html#method.prefix) + api doc. + +* Add `CookieSessionBackend::http_only` method to set `HttpOnly` directive of cookies + +### Changed + +* Upgrade to cookie 0.11 + +* Removed the timestamp from the default logger middleware + +### Fixed + +* Missing response header "content-encoding" #421 + +* Fix stream draining for http/2 connections #290 + + +## [0.7.1] - 2018-07-21 ### Fixed * Fixed default_resource 'not yet implemented' panic #410 -* Add `CookieSessionBackend::http_only` method to set `HttpOnly` directive of cookies ## [0.7.0] - 2018-07-21 diff --git a/Cargo.toml b/Cargo.toml index a6b73ee55..7b8dcec35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-web" -version = "0.7.1" +version = "0.7.15" authors = ["Nikolay Kim "] description = "Actix web is a simple, pragmatic and extremely fast web framework for Rust." readme = "README.md" @@ -17,7 +17,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] build = "build.rs" [package.metadata.docs.rs] -features = ["tls", "alpn", "session", "brotli", "flate2-c"] +features = ["tls", "ssl", "rust-tls", "session", "brotli", "flate2-c"] [badges] travis-ci = { repository = "actix/actix-web", branch = "master" } @@ -29,13 +29,22 @@ name = "actix_web" path = "src/lib.rs" [features] -default = ["session", "brotli", "flate2-c"] +default = ["session", "brotli", "flate2-c", "cell"] # tls -tls = ["native-tls", "tokio-tls"] +tls = ["native-tls", "tokio-tls", "actix-net/tls"] # openssl -alpn = ["openssl", "tokio-openssl"] +ssl = ["openssl", "tokio-openssl", "actix-net/ssl"] + +# deprecated, use "ssl" +alpn = ["openssl", "tokio-openssl", "actix-net/ssl"] + +# rustls +rust-tls = ["rustls", "tokio-rustls", "webpki", "webpki-roots", "actix-net/rust-tls"] + +# unix sockets +uds = ["tokio-uds"] # sessions feature, session require "ring" crate and c compiler session = ["cookie/secure"] @@ -49,21 +58,25 @@ flate2-c = ["flate2/miniz-sys"] # rust backend for flate2 crate flate2-rust = ["flate2/rust_backend"] -[dependencies] -actix = "0.7.0" +cell = ["actix-net/cell"] -base64 = "0.9" +[dependencies] +actix = "0.7.7" +actix-net = "0.2.2" + +askama_escape = "0.1.0" +base64 = "0.10" bitflags = "1.0" +failure = "^0.1.2" h2 = "0.1" -htmlescape = "0.3" -http = "^0.1.5" +http = "^0.1.8" httparse = "1.3" log = "0.4" mime = "0.3" mime_guess = "2.0.0-alpha" num_cpus = "1.0" percent-encoding = "1.0" -rand = "0.5" +rand = "0.6" regex = "1.0" serde = "1.0" serde_json = "1.0" @@ -74,13 +87,12 @@ encoding = "0.2" language-tags = "0.2" lazy_static = "1.0" lazycell = "1.0.0" -parking_lot = "0.6" +parking_lot = "0.7" +serde_urlencoded = "^0.5.3" url = { version="1.7", features=["query_encoding"] } -cookie = { version="0.10", features=["percent-encode"] } +cookie = { version="0.11", features=["percent-encode"] } brotli2 = { version="^0.3.2", optional = true } -flate2 = { version="1.0", optional = true, default-features = false } - -failure = "=0.1.1" +flate2 = { version="^1.0.2", optional = true, default-features = false } # io mio = "^0.6.13" @@ -95,21 +107,27 @@ tokio-io = "0.1" tokio-tcp = "0.1" tokio-timer = "0.2" tokio-reactor = "0.1" +tokio-current-thread = "0.1" # native-tls -native-tls = { version="0.1", optional = true } -tokio-tls = { version="0.1", optional = true } +native-tls = { version="0.2", optional = true } +tokio-tls = { version="0.2", optional = true } # openssl openssl = { version="0.10", optional = true } tokio-openssl = { version="0.2", optional = true } -# forked url_encoded -itoa = "0.4" -dtoa = "0.4" +#rustls +rustls = { version = "0.14", optional = true } +tokio-rustls = { version = "0.8", optional = true } +webpki = { version = "0.18", optional = true } +webpki-roots = { version = "0.15", optional = true } + +# unix sockets +tokio-uds = { version="0.2", optional = true } [dev-dependencies] -env_logger = "0.5" +env_logger = "0.6" serde_derive = "1.0" [build-dependencies] @@ -119,8 +137,3 @@ version_check = "0.1" lto = true opt-level = 3 codegen-units = 1 - -[workspace] -members = [ - "./", -] diff --git a/MIGRATION.md b/MIGRATION.md index 29bf0c348..6b49e3e6a 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -1,3 +1,39 @@ +## 0.7.15 + +* The `' '` character is not percent decoded anymore before matching routes. If you need to use it in + your routes, you should use `%20`. + + instead of + + ```rust + fn main() { + let app = App::new().resource("/my index", |r| { + r.method(http::Method::GET) + .with(index); + }); + } + ``` + + use + + ```rust + fn main() { + let app = App::new().resource("/my%20index", |r| { + r.method(http::Method::GET) + .with(index); + }); + } + ``` + +* If you used `AsyncResult::async` you need to replace it with `AsyncResult::future` + + +## 0.7.4 + +* `Route::with_config()`/`Route::with_async_config()` always passes configuration objects as tuple + even for handler with one parameter. + + ## 0.7 * `HttpRequest` does not implement `Stream` anymore. If you need to read request payload diff --git a/README.md b/README.md index ec8c439ef..db3cc68c5 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,13 @@ Actix web is a simple, pragmatic and extremely fast web framework for Rust. * Client/server [WebSockets](https://actix.rs/docs/websockets/) support * Transparent content compression/decompression (br, gzip, deflate) * Configurable [request routing](https://actix.rs/docs/url-dispatch/) -* Graceful server shutdown * Multipart streams * Static assets * SSL support with OpenSSL or `native-tls` -* Middlewares ([Logger,Session,CORS,CSRF,etc](https://actix.rs/docs/middleware/)) +* Middlewares ([Logger, Session, CORS, CSRF, etc](https://actix.rs/docs/middleware/)) * Includes an asynchronous [HTTP client](https://actix.rs/actix-web/actix_web/client/index.html) * Built on top of [Actix actor framework](https://github.com/actix/actix) +* Experimental [Async/Await](https://github.com/mehcode/actix-web-async-await) support. ## Documentation & community resources @@ -51,7 +51,7 @@ fn main() { * [Protobuf support](https://github.com/actix/examples/tree/master/protobuf/) * [Multipart streams](https://github.com/actix/examples/tree/master/multipart/) * [Simple websocket](https://github.com/actix/examples/tree/master/websocket/) -* [Tera](https://github.com/actix/examples/tree/master/template_tera/) / +* [Tera](https://github.com/actix/examples/tree/master/template_tera/) / [Askama](https://github.com/actix/examples/tree/master/template_askama/) templates * [Diesel integration](https://github.com/actix/examples/tree/master/diesel/) * [r2d2](https://github.com/actix/examples/tree/master/r2d2/) @@ -66,8 +66,6 @@ You may consider checking out * [TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r16&hw=ph&test=plaintext) -* Some basic benchmarks could be found in this [repository](https://github.com/fafhrd91/benchmarks). - ## License This project is licensed under either of diff --git a/src/application.rs b/src/application.rs index f36adf69e..d8a6cbe7b 100644 --- a/src/application.rs +++ b/src/application.rs @@ -12,6 +12,7 @@ use resource::Resource; use router::{ResourceDef, Router}; use scope::Scope; use server::{HttpHandler, HttpHandlerTask, IntoHttpHandler, Request}; +use with::WithFactory; /// Application pub struct HttpApplication { @@ -134,13 +135,13 @@ where /// instance for each thread, thus application state must be constructed /// multiple times. If you want to share state between different /// threads, a shared object should be used, e.g. `Arc`. Application - /// state does not need to be `Send` and `Sync`. + /// state does not need to be `Send` or `Sync`. pub fn with_state(state: S) -> App { App { parts: Some(ApplicationParts { state, prefix: "".to_owned(), - router: Router::new(), + router: Router::new(ResourceDef::prefix("")), middlewares: Vec::new(), filters: Vec::new(), encoding: ContentEncoding::Auto, @@ -171,7 +172,9 @@ where /// In the following example only requests with an `/app/` path /// prefix get handled. Requests with path `/app/test/` would be /// handled, while requests with the paths `/application` or - /// `/other/...` would return `NOT FOUND`. + /// `/other/...` would return `NOT FOUND`. It is also possible to + /// handle `/app` path, to do this you can register resource for + /// empty string `""` /// /// ```rust /// # extern crate actix_web; @@ -180,6 +183,8 @@ where /// fn main() { /// let app = App::new() /// .prefix("/app") + /// .resource("", |r| r.f(|_| HttpResponse::Ok())) // <- handle `/app` path + /// .resource("/", |r| r.f(|_| HttpResponse::Ok())) // <- handle `/app/` path /// .resource("/test", |r| { /// r.get().f(|_| HttpResponse::Ok()); /// r.head().f(|_| HttpResponse::MethodNotAllowed()); @@ -194,6 +199,7 @@ where if !prefix.starts_with('/') { prefix.insert(0, '/') } + parts.router.set_prefix(&prefix); parts.prefix = prefix; } self @@ -244,7 +250,7 @@ where /// ``` pub fn route(mut self, path: &str, method: Method, f: F) -> App where - F: Fn(T) -> R + 'static, + F: WithFactory, R: Responder + 'static, T: FromRequest + 'static, { @@ -441,11 +447,8 @@ where { let mut path = path.trim().trim_right_matches('/').to_owned(); if !path.is_empty() && !path.starts_with('/') { - path.insert(0, '/') - } - if path.len() > 1 && path.ends_with('/') { - path.pop(); - } + path.insert(0, '/'); + }; self.parts .as_mut() .expect("Use after finish") @@ -770,8 +773,7 @@ mod tests { .route("/test", Method::GET, |_: HttpRequest| HttpResponse::Ok()) .route("/test", Method::POST, |_: HttpRequest| { HttpResponse::Created() - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/test").method(Method::GET).request(); let resp = app.run(req); @@ -822,6 +824,23 @@ mod tests { assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); } + #[test] + fn test_option_responder() { + let app = App::new() + .resource("/none", |r| r.f(|_| -> Option<&'static str> { None })) + .resource("/some", |r| r.f(|_| Some("some"))) + .finish(); + + let req = TestRequest::with_uri("/none").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/some").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + assert_eq!(resp.as_msg().body(), &Body::Binary(Binary::Slice(b"some"))); + } + #[test] fn test_filter() { let mut srv = TestServer::with_factory(|| { @@ -840,19 +859,21 @@ mod tests { } #[test] - fn test_option_responder() { - let app = App::new() - .resource("/none", |r| r.f(|_| -> Option<&'static str> { None })) - .resource("/some", |r| r.f(|_| Some("some"))) - .finish(); + fn test_prefix_root() { + let mut srv = TestServer::with_factory(|| { + App::new() + .prefix("/test") + .resource("/", |r| r.f(|_| HttpResponse::Ok())) + .resource("", |r| r.f(|_| HttpResponse::Created())) + }); - let req = TestRequest::with_uri("/none").request(); - let resp = app.run(req); - assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + let request = srv.get().uri(srv.url("/test/")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::OK); - let req = TestRequest::with_uri("/some").request(); - let resp = app.run(req); - assert_eq!(resp.as_msg().status(), StatusCode::OK); - assert_eq!(resp.as_msg().body(), &Body::Binary(Binary::Slice(b"some"))); + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::CREATED); } + } diff --git a/src/body.rs b/src/body.rs index a93db1e92..5487dbba4 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,5 +1,6 @@ use bytes::{Bytes, BytesMut}; use futures::Stream; +use std::borrow::Cow; use std::sync::Arc; use std::{fmt, mem}; @@ -194,12 +195,30 @@ impl From> for Binary { } } +impl From> for Binary { + fn from(b: Cow<'static, [u8]>) -> Binary { + match b { + Cow::Borrowed(s) => Binary::Slice(s), + Cow::Owned(vec) => Binary::Bytes(Bytes::from(vec)), + } + } +} + impl From for Binary { fn from(s: String) -> Binary { Binary::Bytes(Bytes::from(s)) } } +impl From> for Binary { + fn from(s: Cow<'static, str>) -> Binary { + match s { + Cow::Borrowed(s) => Binary::Slice(s.as_ref()), + Cow::Owned(s) => Binary::Bytes(Bytes::from(s)), + } + } +} + impl<'a> From<&'a String> for Binary { fn from(s: &'a String) -> Binary { Binary::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s))) @@ -287,6 +306,16 @@ mod tests { assert_eq!(Binary::from("test").as_ref(), b"test"); } + #[test] + fn test_cow_str() { + let cow: Cow<'static, str> = Cow::Borrowed("test"); + assert_eq!(Binary::from(cow.clone()).len(), 4); + assert_eq!(Binary::from(cow.clone()).as_ref(), b"test"); + let cow: Cow<'static, str> = Cow::Owned("test".to_owned()); + assert_eq!(Binary::from(cow.clone()).len(), 4); + assert_eq!(Binary::from(cow.clone()).as_ref(), b"test"); + } + #[test] fn test_static_bytes() { assert_eq!(Binary::from(b"test".as_ref()).len(), 4); @@ -307,6 +336,16 @@ mod tests { assert_eq!(Binary::from(Bytes::from("test")).as_ref(), b"test"); } + #[test] + fn test_cow_bytes() { + let cow: Cow<'static, [u8]> = Cow::Borrowed(b"test"); + assert_eq!(Binary::from(cow.clone()).len(), 4); + assert_eq!(Binary::from(cow.clone()).as_ref(), b"test"); + let cow: Cow<'static, [u8]> = Cow::Owned(Vec::from("test")); + assert_eq!(Binary::from(cow.clone()).len(), 4); + assert_eq!(Binary::from(cow.clone()).as_ref(), b"test"); + } + #[test] fn test_arc_string() { let b = Arc::new("test".to_owned()); diff --git a/src/client/connector.rs b/src/client/connector.rs index 6d391af87..f5affad37 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -5,7 +5,7 @@ use std::{fmt, io, mem, time}; use actix::resolver::{Connect as ResolveConnect, Resolver, ResolverError}; use actix::{ - fut, Actor, ActorFuture, ActorResponse, Addr, AsyncContext, Context, + fut, Actor, ActorFuture, ActorResponse, AsyncContext, Context, ContextFutureSpawner, Handler, Message, Recipient, StreamHandler, Supervised, SystemService, WrapFuture, }; @@ -16,18 +16,40 @@ use http::{Error as HttpError, HttpTryFrom, Uri}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; -#[cfg(feature = "alpn")] -use openssl::ssl::{Error as OpensslError, SslConnector, SslMethod}; -#[cfg(feature = "alpn")] -use tokio_openssl::SslConnectorExt; +#[cfg(any(feature = "alpn", feature = "ssl"))] +use { + openssl::ssl::{Error as SslError, SslConnector, SslMethod}, + tokio_openssl::SslConnectorExt, +}; -#[cfg(all(feature = "tls", not(feature = "alpn")))] -use native_tls::{Error as TlsError, TlsConnector}; -#[cfg(all(feature = "tls", not(feature = "alpn")))] -use tokio_tls::TlsConnectorExt; +#[cfg(all( + feature = "tls", + not(any(feature = "alpn", feature = "ssl", feature = "rust-tls")) +))] +use { + native_tls::{Error as SslError, TlsConnector as NativeTlsConnector}, + tokio_tls::TlsConnector as SslConnector, +}; + +#[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls", feature = "ssl")) +))] +use { + rustls::ClientConfig, std::io::Error as SslError, std::sync::Arc, + tokio_rustls::TlsConnector as SslConnector, webpki::DNSNameRef, webpki_roots, +}; + +#[cfg(not(any( + feature = "alpn", + feature = "ssl", + feature = "tls", + feature = "rust-tls" +)))] +type SslConnector = (); use server::IoStream; -use {HAS_OPENSSL, HAS_TLS}; +use {HAS_OPENSSL, HAS_RUSTLS, HAS_TLS}; /// Client connector usage stats #[derive(Default, Message)] @@ -130,14 +152,14 @@ pub enum ClientConnectorError { SslIsNotSupported, /// SSL error - #[cfg(feature = "alpn")] + #[cfg(any( + feature = "tls", + feature = "alpn", + feature = "ssl", + feature = "rust-tls", + ))] #[fail(display = "{}", _0)] - SslError(#[cause] OpensslError), - - /// SSL error - #[cfg(all(feature = "tls", not(feature = "alpn")))] - #[fail(display = "{}", _0)] - SslError(#[cause] TlsError), + SslError(#[cause] SslError), /// Resolver error #[fail(display = "{}", _0)] @@ -189,10 +211,8 @@ impl Paused { /// `ClientConnector` type is responsible for transport layer of a /// client connection. pub struct ClientConnector { - #[cfg(all(feature = "alpn"))] + #[allow(dead_code)] connector: SslConnector, - #[cfg(all(feature = "tls", not(feature = "alpn")))] - connector: TlsConnector, stats: ClientConnectorStats, subscriber: Option>, @@ -200,7 +220,7 @@ pub struct ClientConnector { acq_tx: mpsc::UnboundedSender, acq_rx: Option>, - resolver: Option>, + resolver: Option>, conn_lifetime: Duration, conn_keep_alive: Duration, limit: usize, @@ -219,7 +239,7 @@ impl Actor for ClientConnector { fn started(&mut self, ctx: &mut Self::Context) { if self.resolver.is_none() { - self.resolver = Some(Resolver::from_registry()) + self.resolver = Some(Resolver::from_registry().recipient()) } self.collect_periodic(ctx); ctx.add_stream(self.acq_rx.take().unwrap()); @@ -233,63 +253,47 @@ impl SystemService for ClientConnector {} impl Default for ClientConnector { fn default() -> ClientConnector { - #[cfg(all(feature = "alpn"))] - { - let builder = SslConnector::builder(SslMethod::tls()).unwrap(); - ClientConnector::with_connector(builder.build()) - } - #[cfg(all(feature = "tls", not(feature = "alpn")))] - { - let (tx, rx) = mpsc::unbounded(); - let builder = TlsConnector::builder().unwrap(); - ClientConnector { - stats: ClientConnectorStats::default(), - subscriber: None, - acq_tx: tx, - acq_rx: Some(rx), - resolver: None, - connector: builder.build().unwrap(), - conn_lifetime: Duration::from_secs(75), - conn_keep_alive: Duration::from_secs(15), - limit: 100, - limit_per_host: 0, - acquired: 0, - acquired_per_host: HashMap::new(), - available: HashMap::new(), - to_close: Vec::new(), - waiters: Some(HashMap::new()), - wait_timeout: None, - paused: Paused::No, + let connector = { + #[cfg(all(any(feature = "alpn", feature = "ssl")))] + { + SslConnector::builder(SslMethod::tls()).unwrap().build() } - } - #[cfg(not(any(feature = "alpn", feature = "tls")))] - { - let (tx, rx) = mpsc::unbounded(); - ClientConnector { - stats: ClientConnectorStats::default(), - subscriber: None, - acq_tx: tx, - acq_rx: Some(rx), - resolver: None, - conn_lifetime: Duration::from_secs(75), - conn_keep_alive: Duration::from_secs(15), - limit: 100, - limit_per_host: 0, - acquired: 0, - acquired_per_host: HashMap::new(), - available: HashMap::new(), - to_close: Vec::new(), - waiters: Some(HashMap::new()), - wait_timeout: None, - paused: Paused::No, + #[cfg(all( + feature = "tls", + not(any(feature = "alpn", feature = "ssl", feature = "rust-tls")) + ))] + { + NativeTlsConnector::builder().build().unwrap().into() } - } + + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls", feature = "ssl")) + ))] + { + let mut config = ClientConfig::new(); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + SslConnector::from(Arc::new(config)) + } + + #[cfg_attr(rustfmt, rustfmt_skip)] + #[cfg(not(any( + feature = "alpn", feature = "ssl", feature = "tls", feature = "rust-tls")))] + { + () + } + }; + + #[cfg_attr(feature = "cargo-clippy", allow(let_unit_value))] + ClientConnector::with_connector_impl(connector) } } impl ClientConnector { - #[cfg(feature = "alpn")] + #[cfg(any(feature = "alpn", feature = "ssl"))] /// Create `ClientConnector` actor with custom `SslConnector` instance. /// /// By default `ClientConnector` uses very a simple SSL configuration. @@ -302,7 +306,6 @@ impl ClientConnector { /// # extern crate futures; /// # use futures::{future, Future}; /// # use std::io::Write; - /// # use std::process; /// # use actix_web::actix::Actor; /// extern crate openssl; /// use actix_web::{actix, client::ClientConnector, client::Connect}; @@ -325,10 +328,112 @@ impl ClientConnector { /// # actix::System::current().stop(); /// Ok(()) /// }) - /// ); + /// }); /// } /// ``` pub fn with_connector(connector: SslConnector) -> ClientConnector { + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(connector) + } + + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "ssl", feature = "tls")) + ))] + /// Create `ClientConnector` actor with custom `SslConnector` instance. + /// + /// By default `ClientConnector` uses very a simple SSL configuration. + /// With `with_connector` method it is possible to use a custom + /// `SslConnector` object. + /// + /// ```rust + /// # #![cfg(feature = "rust-tls")] + /// # extern crate actix_web; + /// # extern crate futures; + /// # use futures::{future, Future}; + /// # use std::io::Write; + /// # use actix_web::actix::Actor; + /// extern crate rustls; + /// extern crate webpki_roots; + /// use actix_web::{actix, client::ClientConnector, client::Connect}; + /// + /// use rustls::ClientConfig; + /// use std::sync::Arc; + /// + /// fn main() { + /// actix::run(|| { + /// // Start `ClientConnector` with custom `ClientConfig` + /// let mut config = ClientConfig::new(); + /// config + /// .root_store + /// .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + /// let conn = ClientConnector::with_connector(config).start(); + /// + /// conn.send( + /// Connect::new("https://www.rust-lang.org").unwrap()) // <- connect to host + /// .map_err(|_| ()) + /// .and_then(|res| { + /// if let Ok(mut stream) = res { + /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); + /// } + /// # actix::System::current().stop(); + /// Ok(()) + /// }) + /// }); + /// } + /// ``` + pub fn with_connector(connector: ClientConfig) -> ClientConnector { + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(SslConnector::from(Arc::new(connector))) + } + + #[cfg(all( + feature = "tls", + not(any(feature = "ssl", feature = "alpn", feature = "rust-tls")) + ))] + /// Create `ClientConnector` actor with custom `SslConnector` instance. + /// + /// By default `ClientConnector` uses very a simple SSL configuration. + /// With `with_connector` method it is possible to use a custom + /// `SslConnector` object. + /// + /// ```rust + /// # #![cfg(feature = "tls")] + /// # extern crate actix_web; + /// # extern crate futures; + /// # use futures::{future, Future}; + /// # use std::io::Write; + /// # use actix_web::actix::Actor; + /// extern crate native_tls; + /// extern crate webpki_roots; + /// use native_tls::TlsConnector; + /// use actix_web::{actix, client::ClientConnector, client::Connect}; + /// + /// fn main() { + /// actix::run(|| { + /// let connector = TlsConnector::new().unwrap(); + /// let conn = ClientConnector::with_connector(connector.into()).start(); + /// + /// conn.send( + /// Connect::new("https://www.rust-lang.org").unwrap()) // <- connect to host + /// .map_err(|_| ()) + /// .and_then(|res| { + /// if let Ok(mut stream) = res { + /// stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); + /// } + /// # actix::System::current().stop(); + /// Ok(()) + /// }) + /// }); + /// } + /// ``` + pub fn with_connector(connector: SslConnector) -> ClientConnector { + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(connector) + } + + #[inline] + fn with_connector_impl(connector: SslConnector) -> ClientConnector { let (tx, rx) = mpsc::unbounded(); ClientConnector { @@ -398,8 +503,10 @@ impl ClientConnector { } /// Use custom resolver actor - pub fn resolver(mut self, addr: Addr) -> Self { - self.resolver = Some(addr); + /// + /// By default actix's Resolver is used. + pub fn resolver>>(mut self, addr: A) -> Self { + self.resolver = Some(addr.into()); self } @@ -599,7 +706,7 @@ impl ClientConnector { } Acquire::Available => { // create new connection - self.connect_waiter(key.clone(), waiter, ctx); + self.connect_waiter(&key, waiter, ctx); } } } @@ -608,7 +715,8 @@ impl ClientConnector { self.waiters = Some(act_waiters); } - fn connect_waiter(&mut self, key: Key, waiter: Waiter, ctx: &mut Context) { + fn connect_waiter(&mut self, key: &Key, waiter: Waiter, ctx: &mut Context) { + let key = key.clone(); let conn = AcquiredConn(key.clone(), Some(self.acq_tx.clone())); let key2 = key.clone(); @@ -620,118 +728,164 @@ impl ClientConnector { ).map_err(move |_, act, _| { act.release_key(&key2); () - }) - .and_then(move |res, act, _| { - #[cfg(feature = "alpn")] - match res { - Err(err) => { - let _ = waiter.tx.send(Err(err.into())); - fut::Either::B(fut::err(())) - } - Ok(stream) => { - act.stats.opened += 1; - if conn.0.ssl { - fut::Either::A( - act.connector - .connect_async(&key.host, stream) - .into_actor(act) - .then(move |res, act, _| { - match res { - Err(e) => { - let _ = waiter.tx.send(Err( - ClientConnectorError::SslError(e), - )); - } - Ok(stream) => { - let _ = - waiter.tx.send(Ok(Connection::new( - conn.0.clone(), - Some(conn), - Box::new(stream), - ))); - } + }).and_then(move |res, act, _| { + #[cfg(any(feature = "alpn", feature = "ssl"))] + match res { + Err(err) => { + let _ = waiter.tx.send(Err(err.into())); + fut::Either::B(fut::err(())) + } + Ok(stream) => { + act.stats.opened += 1; + if conn.0.ssl { + fut::Either::A( + act.connector + .connect_async(&key.host, stream) + .into_actor(act) + .then(move |res, _, _| { + match res { + Err(e) => { + let _ = waiter.tx.send(Err( + ClientConnectorError::SslError(e), + )); } - fut::ok(()) - }), - ) - } else { - let _ = waiter.tx.send(Ok(Connection::new( - conn.0.clone(), - Some(conn), - Box::new(stream), - ))); - fut::Either::B(fut::ok(())) - } - } - } - - #[cfg(all(feature = "tls", not(feature = "alpn")))] - match res { - Err(err) => { - let _ = waiter.tx.send(Err(err.into())); - fut::Either::B(fut::err(())) - } - Ok(stream) => { - act.stats.opened += 1; - if conn.0.ssl { - fut::Either::A( - act.connector - .connect_async(&conn.0.host, stream) - .into_actor(act) - .then(move |res, _, _| { - match res { - Err(e) => { - let _ = waiter.tx.send(Err( - ClientConnectorError::SslError(e), - )); - } - Ok(stream) => { - let _ = - waiter.tx.send(Ok(Connection::new( - conn.0.clone(), - Some(conn), - Box::new(stream), - ))); - } + Ok(stream) => { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); } - fut::ok(()) - }), - ) - } else { - let _ = waiter.tx.send(Ok(Connection::new( - conn.0.clone(), - Some(conn), - Box::new(stream), - ))); - fut::Either::B(fut::ok(())) - } + } + fut::ok(()) + }), + ) + } else { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + fut::Either::B(fut::ok(())) } } + } - #[cfg(not(any(feature = "alpn", feature = "tls")))] - match res { - Err(err) => { - let _ = waiter.tx.send(Err(err.into())); - fut::err(()) - } - Ok(stream) => { - act.stats.opened += 1; - if conn.0.ssl { - let _ = waiter - .tx - .send(Err(ClientConnectorError::SslIsNotSupported)); - } else { - let _ = waiter.tx.send(Ok(Connection::new( - conn.0.clone(), - Some(conn), - Box::new(stream), - ))); - }; - fut::ok(()) + #[cfg(all(feature = "tls", not(any(feature = "alpn", feature = "ssl"))))] + match res { + Err(err) => { + let _ = waiter.tx.send(Err(err.into())); + fut::Either::B(fut::err(())) + } + Ok(stream) => { + act.stats.opened += 1; + if conn.0.ssl { + fut::Either::A( + act.connector + .connect(&conn.0.host, stream) + .into_actor(act) + .then(move |res, _, _| { + match res { + Err(e) => { + let _ = waiter.tx.send(Err( + ClientConnectorError::SslError(e), + )); + } + Ok(stream) => { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + } + } + fut::ok(()) + }), + ) + } else { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + fut::Either::B(fut::ok(())) } } - }) - .spawn(ctx); + } + + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "ssl", feature = "tls")) + ))] + match res { + Err(err) => { + let _ = waiter.tx.send(Err(err.into())); + fut::Either::B(fut::err(())) + } + Ok(stream) => { + act.stats.opened += 1; + if conn.0.ssl { + let host = DNSNameRef::try_from_ascii_str(&key.host).unwrap(); + fut::Either::A( + act.connector + .connect(host, stream) + .into_actor(act) + .then(move |res, _, _| { + match res { + Err(e) => { + let _ = waiter.tx.send(Err( + ClientConnectorError::SslError(e), + )); + } + Ok(stream) => { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + } + } + fut::ok(()) + }), + ) + } else { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + fut::Either::B(fut::ok(())) + } + } + } + + #[cfg(not(any( + feature = "alpn", + feature = "ssl", + feature = "tls", + feature = "rust-tls" + )))] + match res { + Err(err) => { + let _ = waiter.tx.send(Err(err.into())); + fut::err(()) + } + Ok(stream) => { + act.stats.opened += 1; + if conn.0.ssl { + let _ = + waiter.tx.send(Err(ClientConnectorError::SslIsNotSupported)); + } else { + let _ = waiter.tx.send(Ok(Connection::new( + conn.0.clone(), + Some(conn), + Box::new(stream), + ))); + }; + fut::ok(()) + } + } + }).spawn(ctx); } } @@ -783,12 +937,12 @@ impl Handler for ClientConnector { }; // check ssl availability - if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS { + if proto.is_secure() && !HAS_OPENSSL && !HAS_TLS && !HAS_RUSTLS { return ActorResponse::reply(Err(ClientConnectorError::SslIsNotSupported)); } let host = uri.host().unwrap().to_owned(); - let port = uri.port().unwrap_or_else(|| proto.port()); + let port = uri.port_part().map(|port| port.as_u16()).unwrap_or_else(|| proto.port()); let key = Key { host, port, @@ -828,7 +982,7 @@ impl Handler for ClientConnector { wait, conn_timeout, }; - self.connect_waiter(key.clone(), waiter, ctx); + self.connect_waiter(&key, waiter, ctx); return ActorResponse::async( rx.map_err(|_| ClientConnectorError::Disconnected) @@ -885,7 +1039,7 @@ impl Handler for ClientConnector { wait, conn_timeout, }; - self.connect_waiter(key.clone(), waiter, ctx); + self.connect_waiter(&key, waiter, ctx); ActorResponse::async( rx.map_err(|_| ClientConnectorError::Disconnected) @@ -1089,6 +1243,10 @@ impl Connection { } /// Create a new connection from an IO Stream + /// + /// The stream can be a `UnixStream` if the Unix-only "uds" feature is enabled. + /// + /// See also `ClientRequestBuilder::with_connection()`. pub fn from_stream(io: T) -> Connection { Connection::new(Key::empty(), None, Box::new(io)) } @@ -1122,6 +1280,11 @@ impl IoStream for Connection { fn set_linger(&mut self, dur: Option) -> io::Result<()> { IoStream::set_linger(&mut *self.stream, dur) } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + IoStream::set_keepalive(&mut *self.stream, dur) + } } impl io::Read for Connection { @@ -1147,3 +1310,31 @@ impl AsyncWrite for Connection { self.stream.shutdown() } } + +#[cfg(feature = "tls")] +use tokio_tls::TlsStream; + +#[cfg(feature = "tls")] +/// This is temp solution untile actix-net migration +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_keepalive(dur) + } +} diff --git a/src/client/parser.rs b/src/client/parser.rs index f5390cc34..92a7abe13 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -20,6 +20,7 @@ const MAX_HEADERS: usize = 96; #[derive(Default)] pub struct HttpResponseParser { decoder: Option, + eof: bool, // indicate that we read payload until stream eof } #[derive(Debug, Fail)] @@ -38,43 +39,42 @@ impl HttpResponseParser { where T: IoStream, { - // if buf is empty parse_message will always return NotReady, let's avoid that - if buf.is_empty() { + loop { + // Don't call parser until we have data to parse. + if !buf.is_empty() { + match HttpResponseParser::parse_message(buf) + .map_err(HttpResponseParserError::Error)? + { + Async::Ready((msg, info)) => { + if let Some((decoder, eof)) = info { + self.eof = eof; + self.decoder = Some(decoder); + } else { + self.eof = false; + self.decoder = None; + } + return Ok(Async::Ready(msg)); + } + Async::NotReady => { + if buf.len() >= MAX_BUFFER_SIZE { + return Err(HttpResponseParserError::Error( + ParseError::TooLarge, + )); + } + // Parser needs more data. + } + } + } + // Read some more data into the buffer for the parser. match io.read_available(buf) { - Ok(Async::Ready(true)) => { + Ok(Async::Ready((false, true))) => { return Err(HttpResponseParserError::Disconnect) } - Ok(Async::Ready(false)) => (), + Ok(Async::Ready(_)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(HttpResponseParserError::Error(err.into())), } } - - loop { - match HttpResponseParser::parse_message(buf) - .map_err(HttpResponseParserError::Error)? - { - Async::Ready((msg, decoder)) => { - self.decoder = decoder; - return Ok(Async::Ready(msg)); - } - Async::NotReady => { - if buf.capacity() >= MAX_BUFFER_SIZE { - return Err(HttpResponseParserError::Error(ParseError::TooLarge)); - } - match io.read_available(buf) { - Ok(Async::Ready(true)) => { - return Err(HttpResponseParserError::Disconnect) - } - Ok(Async::Ready(false)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { - return Err(HttpResponseParserError::Error(err.into())) - } - } - } - } - } } pub fn parse_payload( @@ -87,8 +87,8 @@ impl HttpResponseParser { loop { // read payload let (not_ready, stream_finished) = match io.read_available(buf) { - Ok(Async::Ready(true)) => (false, true), - Ok(Async::Ready(false)) => (false, false), + Ok(Async::Ready((_, true))) => (false, true), + Ok(Async::Ready((_, false))) => (false, false), Ok(Async::NotReady) => (true, false), Err(err) => return Err(err.into()), }; @@ -104,7 +104,12 @@ impl HttpResponseParser { return Ok(Async::NotReady); } if stream_finished { - return Err(PayloadError::Incomplete); + // read untile eof? + if self.eof { + return Ok(Async::Ready(None)); + } else { + return Err(PayloadError::Incomplete); + } } } Err(err) => return Err(err.into()), @@ -117,7 +122,7 @@ impl HttpResponseParser { fn parse_message( buf: &mut BytesMut, - ) -> Poll<(ClientResponse, Option), ParseError> { + ) -> Poll<(ClientResponse, Option<(EncodingDecoder, bool)>), ParseError> { // Unsafe: we read only this data only after httparse parses headers into. // performance bump for pipeline benchmarks. let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; @@ -163,12 +168,12 @@ impl HttpResponseParser { } let decoder = if status == StatusCode::SWITCHING_PROTOCOLS { - Some(EncodingDecoder::eof()) + Some((EncodingDecoder::eof(), true)) } else if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { // Content-Length if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { - Some(EncodingDecoder::length(len)) + Some((EncodingDecoder::length(len), false)) } else { debug!("illegal Content-Length: {:?}", len); return Err(ParseError::Header); @@ -179,7 +184,18 @@ impl HttpResponseParser { } } else if chunked(&hdrs)? { // Chunked encoding - Some(EncodingDecoder::chunked()) + Some((EncodingDecoder::chunked(), false)) + } else if let Some(value) = hdrs.get(header::CONNECTION) { + let close = if let Ok(s) = value.to_str() { + s == "close" + } else { + false + }; + if close { + Some((EncodingDecoder::eof(), true)) + } else { + None + } } else { None }; diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index e5538b060..394b7a6cd 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -216,7 +216,7 @@ impl Future for SendRequest { match pl.parse() { Ok(Async::Ready(mut resp)) => { - if self.req.method() == &Method::HEAD { + if self.req.method() == Method::HEAD { pl.parser.take(); } resp.set_pipeline(pl); diff --git a/src/client/request.rs b/src/client/request.rs index 650f0eeaa..71da8f74d 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -254,16 +254,16 @@ impl ClientRequest { impl fmt::Debug for ClientRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = writeln!( + writeln!( f, "\nClientRequest {:?} {}:{}", self.version, self.method, self.uri - ); - let _ = writeln!(f, " headers:"); + )?; + writeln!(f, " headers:")?; for (key, val) in self.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); + writeln!(f, " {:?}: {:?}", key, val)?; } - res + Ok(()) } } @@ -291,10 +291,6 @@ impl ClientRequestBuilder { fn _uri(&mut self, url: &str) -> &mut Self { match Uri::try_from(url) { Ok(uri) => { - // set request host header - if let Some(host) = uri.host() { - self.set_header(header::HOST, host); - } if let Some(parts) = parts(&mut self.request, &self.err) { parts.uri = uri; } @@ -316,8 +312,7 @@ impl ClientRequestBuilder { /// Set HTTP method of this request. #[inline] pub fn get_method(&mut self) -> &Method { - let parts = - parts(&mut self.request, &self.err).expect("cannot reuse request builder"); + let parts = self.request.as_ref().expect("cannot reuse request builder"); &parts.method } @@ -630,9 +625,31 @@ impl ClientRequestBuilder { self.set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate"); } + // set request host header + if let Some(parts) = parts(&mut self.request, &self.err) { + if let Some(host) = parts.uri.host() { + if !parts.headers.contains_key(header::HOST) { + let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); + + let _ = match parts.uri.port_part().map(|port| port.as_u16()) { + None | Some(80) | Some(443) => write!(wrt, "{}", host), + Some(port) => write!(wrt, "{}:{}", host, port), + }; + + match wrt.get_mut().take().freeze().try_into() { + Ok(value) => { + parts.headers.insert(header::HOST, value); + } + Err(e) => self.err = Some(e.into()), + } + } + } + } + + // user agent self.set_header_if_none( header::USER_AGENT, - concat!("Actix-web/", env!("CARGO_PKG_VERSION")), + concat!("actix-web/", env!("CARGO_PKG_VERSION")), ); } @@ -733,16 +750,16 @@ fn parts<'a>( impl fmt::Debug for ClientRequestBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(ref parts) = self.request { - let res = writeln!( + writeln!( f, "\nClientRequestBuilder {:?} {}:{}", parts.version, parts.method, parts.uri - ); - let _ = writeln!(f, " headers:"); + )?; + writeln!(f, " headers:")?; for (key, val) in parts.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); + writeln!(f, " {:?}: {:?}", key, val)?; } - res + Ok(()) } else { write!(f, "ClientRequestBuilder(Consumed)") } diff --git a/src/client/response.rs b/src/client/response.rs index 0c094a2aa..5f1f42649 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -95,12 +95,12 @@ impl ClientResponse { impl fmt::Debug for ClientResponse { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status()); - let _ = writeln!(f, " headers:"); + writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status())?; + writeln!(f, " headers:")?; for (key, val) in self.headers().iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); + writeln!(f, " {:?}: {:?}", key, val)?; } - res + Ok(()) } } diff --git a/src/client/writer.rs b/src/client/writer.rs index b691407dd..321753bbf 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -1,4 +1,7 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] +#![cfg_attr( + feature = "cargo-clippy", + allow(redundant_field_names) +)] use std::cell::RefCell; use std::fmt::Write as FmtWrite; @@ -8,7 +11,7 @@ use std::io::{self, Write}; use brotli2::write::BrotliEncoder; use bytes::{BufMut, BytesMut}; #[cfg(feature = "flate2")] -use flate2::write::{DeflateEncoder, GzEncoder}; +use flate2::write::{GzEncoder, ZlibEncoder}; #[cfg(feature = "flate2")] use flate2::Compression; use futures::{Async, Poll}; @@ -232,7 +235,7 @@ fn content_encoder(buf: BytesMut, req: &mut ClientRequest) -> Output { let mut enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::default()), + ZlibEncoder::new(transfer, Compression::default()), ), #[cfg(feature = "flate2")] ContentEncoding::Gzip => ContentEncoder::Gzip(GzEncoder::new( @@ -302,10 +305,9 @@ fn content_encoder(buf: BytesMut, req: &mut ClientRequest) -> Output { req.replace_body(body); let enc = match encoding { #[cfg(feature = "flate2")] - ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( - transfer, - Compression::default(), - )), + ContentEncoding::Deflate => { + ContentEncoder::Deflate(ZlibEncoder::new(transfer, Compression::default())) + } #[cfg(feature = "flate2")] ContentEncoding::Gzip => { ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::default())) diff --git a/src/de.rs b/src/de.rs index ecb2fa9ae..05f8914f8 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,7 +1,10 @@ +use std::rc::Rc; + use serde::de::{self, Deserializer, Error as DeError, Visitor}; use httprequest::HttpRequest; use param::ParamsIter; +use uri::RESERVED_QUOTER; macro_rules! unsupported_type { ($trait_fn:ident, $name:expr) => { @@ -13,6 +16,20 @@ macro_rules! unsupported_type { }; } +macro_rules! percent_decode_if_needed { + ($value:expr, $decode:expr) => { + if $decode { + if let Some(ref mut value) = RESERVED_QUOTER.requote($value.as_bytes()) { + Rc::make_mut(value).parse() + } else { + $value.parse() + } + } else { + $value.parse() + } + } +} + macro_rules! parse_single_value { ($trait_fn:ident, $visit_fn:ident, $tp:tt) => { fn $trait_fn(self, visitor: V) -> Result @@ -23,11 +40,11 @@ macro_rules! parse_single_value { format!("wrong number of parameters: {} expected 1", self.req.match_info().len()).as_str())) } else { - let v = self.req.match_info()[0].parse().map_err( - |_| de::value::Error::custom( - format!("can not parse {:?} to a {}", - &self.req.match_info()[0], $tp)))?; - visitor.$visit_fn(v) + let v_parsed = percent_decode_if_needed!(&self.req.match_info()[0], self.decode) + .map_err(|_| de::value::Error::custom( + format!("can not parse {:?} to a {}", &self.req.match_info()[0], $tp) + ))?; + visitor.$visit_fn(v_parsed) } } } @@ -35,11 +52,12 @@ macro_rules! parse_single_value { pub struct PathDeserializer<'de, S: 'de> { req: &'de HttpRequest, + decode: bool, } impl<'de, S: 'de> PathDeserializer<'de, S> { - pub fn new(req: &'de HttpRequest) -> Self { - PathDeserializer { req } + pub fn new(req: &'de HttpRequest, decode: bool) -> Self { + PathDeserializer { req, decode } } } @@ -53,6 +71,7 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { visitor.visit_map(ParamsDeserializer { params: self.req.match_info().iter(), current: None, + decode: self.decode, }) } @@ -107,6 +126,7 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { } else { visitor.visit_seq(ParamsSeq { params: self.req.match_info().iter(), + decode: self.decode, }) } } @@ -128,6 +148,7 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { } else { visitor.visit_seq(ParamsSeq { params: self.req.match_info().iter(), + decode: self.decode, }) } } @@ -141,28 +162,13 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { Err(de::value::Error::custom("unsupported type: enum")) } - fn deserialize_str(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - if self.req.match_info().len() != 1 { - Err(de::value::Error::custom( - format!( - "wrong number of parameters: {} expected 1", - self.req.match_info().len() - ).as_str(), - )) - } else { - visitor.visit_str(&self.req.match_info()[0]) - } - } - fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(ParamsSeq { params: self.req.match_info().iter(), + decode: self.decode, }) } @@ -175,7 +181,7 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { parse_single_value!(deserialize_bool, visit_bool, "bool"); parse_single_value!(deserialize_i8, visit_i8, "i8"); parse_single_value!(deserialize_i16, visit_i16, "i16"); - parse_single_value!(deserialize_i32, visit_i32, "i16"); + parse_single_value!(deserialize_i32, visit_i32, "i32"); parse_single_value!(deserialize_i64, visit_i64, "i64"); parse_single_value!(deserialize_u8, visit_u8, "u8"); parse_single_value!(deserialize_u16, visit_u16, "u16"); @@ -184,13 +190,16 @@ impl<'de, S: 'de> Deserializer<'de> for PathDeserializer<'de, S> { parse_single_value!(deserialize_f32, visit_f32, "f32"); parse_single_value!(deserialize_f64, visit_f64, "f64"); parse_single_value!(deserialize_string, visit_string, "String"); + parse_single_value!(deserialize_str, visit_string, "String"); parse_single_value!(deserialize_byte_buf, visit_string, "String"); parse_single_value!(deserialize_char, visit_char, "char"); + } struct ParamsDeserializer<'de> { params: ParamsIter<'de>, current: Option<(&'de str, &'de str)>, + decode: bool, } impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> { @@ -212,7 +221,7 @@ impl<'de> de::MapAccess<'de> for ParamsDeserializer<'de> { V: de::DeserializeSeed<'de>, { if let Some((_, value)) = self.current.take() { - seed.deserialize(Value { value }) + seed.deserialize(Value { value, decode: self.decode }) } else { Err(de::value::Error::custom("unexpected item")) } @@ -252,16 +261,18 @@ macro_rules! parse_value { fn $trait_fn(self, visitor: V) -> Result where V: Visitor<'de> { - let v = self.value.parse().map_err( - |_| de::value::Error::custom( - format!("can not parse {:?} to a {}", self.value, $tp)))?; - visitor.$visit_fn(v) + let v_parsed = percent_decode_if_needed!(&self.value, self.decode) + .map_err(|_| de::value::Error::custom( + format!("can not parse {:?} to a {}", &self.value, $tp) + ))?; + visitor.$visit_fn(v_parsed) } } } struct Value<'de> { value: &'de str, + decode: bool, } impl<'de> Deserializer<'de> for Value<'de> { @@ -377,6 +388,7 @@ impl<'de> Deserializer<'de> for Value<'de> { struct ParamsSeq<'de> { params: ParamsIter<'de>, + decode: bool, } impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> { @@ -387,7 +399,7 @@ impl<'de> de::SeqAccess<'de> for ParamsSeq<'de> { T: de::DeserializeSeed<'de>, { match self.params.next() { - Some(item) => Ok(Some(seed.deserialize(Value { value: item.1 })?)), + Some(item) => Ok(Some(seed.deserialize(Value { value: item.1, decode: self.decode })?)), None => Ok(None), } } diff --git a/src/error.rs b/src/error.rs index 461b23e20..1766c1523 100644 --- a/src/error.rs +++ b/src/error.rs @@ -52,7 +52,8 @@ pub struct Error { impl Error { /// Deprecated way to reference the underlying response error. #[deprecated( - since = "0.6.0", note = "please use `Error::as_response_error()` instead" + since = "0.6.0", + note = "please use `Error::as_response_error()` instead" )] pub fn cause(&self) -> &ResponseError { self.cause.as_ref() @@ -97,21 +98,9 @@ impl Error { // // So we first downcast into that compat, to then further downcast through // the failure's Error downcasting system into the original failure. - // - // This currently requires a transmute. This could be avoided if failure - // provides a deref: https://github.com/rust-lang-nursery/failure/pull/213 let compat: Option<&failure::Compat> = Fail::downcast_ref(self.cause.as_fail()); - if let Some(compat) = compat { - pub struct CompatWrappedError { - error: failure::Error, - } - let compat: &CompatWrappedError = - unsafe { &*(compat as *const _ as *const CompatWrappedError) }; - compat.error.downcast_ref() - } else { - None - } + compat.and_then(|e| e.get_ref().downcast_ref()) } } @@ -770,6 +759,16 @@ where InternalError::new(err, StatusCode::UNAUTHORIZED).into() } +/// Helper function that creates wrapper of any error and generate +/// *PAYMENT_REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPaymentRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYMENT_REQUIRED).into() +} + /// Helper function that creates wrapper of any error and generate *FORBIDDEN* /// response. #[allow(non_snake_case)] @@ -800,6 +799,26 @@ where InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() } +/// Helper function that creates wrapper of any error and generate *NOT +/// ACCEPTABLE* response. +#[allow(non_snake_case)] +pub fn ErrorNotAcceptable(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_ACCEPTABLE).into() +} + +/// Helper function that creates wrapper of any error and generate *PROXY +/// AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorProxyAuthenticationRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PROXY_AUTHENTICATION_REQUIRED).into() +} + /// Helper function that creates wrapper of any error and generate *REQUEST /// TIMEOUT* response. #[allow(non_snake_case)] @@ -830,6 +849,16 @@ where InternalError::new(err, StatusCode::GONE).into() } +/// Helper function that creates wrapper of any error and generate *LENGTH +/// REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorLengthRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LENGTH_REQUIRED).into() +} + /// Helper function that creates wrapper of any error and generate /// *PRECONDITION FAILED* response. #[allow(non_snake_case)] @@ -840,6 +869,46 @@ where InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() } +/// Helper function that creates wrapper of any error and generate +/// *PAYLOAD TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorPayloadTooLarge(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYLOAD_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *URI TOO LONG* response. +#[allow(non_snake_case)] +pub fn ErrorUriTooLong(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::URI_TOO_LONG).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNSUPPORTED MEDIA TYPE* response. +#[allow(non_snake_case)] +pub fn ErrorUnsupportedMediaType(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNSUPPORTED_MEDIA_TYPE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *RANGE NOT SATISFIABLE* response. +#[allow(non_snake_case)] +pub fn ErrorRangeNotSatisfiable(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::RANGE_NOT_SATISFIABLE).into() +} + /// Helper function that creates wrapper of any error and generate /// *EXPECTATION FAILED* response. #[allow(non_snake_case)] @@ -850,6 +919,106 @@ where InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() } +/// Helper function that creates wrapper of any error and generate +/// *IM A TEAPOT* response. +#[allow(non_snake_case)] +pub fn ErrorImATeapot(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::IM_A_TEAPOT).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *MISDIRECTED REQUEST* response. +#[allow(non_snake_case)] +pub fn ErrorMisdirectedRequest(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::MISDIRECTED_REQUEST).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNPROCESSABLE ENTITY* response. +#[allow(non_snake_case)] +pub fn ErrorUnprocessableEntity(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNPROCESSABLE_ENTITY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *LOCKED* response. +#[allow(non_snake_case)] +pub fn ErrorLocked(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOCKED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *FAILED DEPENDENCY* response. +#[allow(non_snake_case)] +pub fn ErrorFailedDependency(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::FAILED_DEPENDENCY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UPGRADE REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorUpgradeRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UPGRADE_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PRECONDITION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPreconditionRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PRECONDITION_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *TOO MANY REQUESTS* response. +#[allow(non_snake_case)] +pub fn ErrorTooManyRequests(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::TOO_MANY_REQUESTS).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *REQUEST HEADER FIELDS TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorRequestHeaderFieldsTooLarge(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNAVAILABLE FOR LEGAL REASONS* response. +#[allow(non_snake_case)] +pub fn ErrorUnavailableForLegalReasons(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS).into() +} + /// Helper function that creates wrapper of any error and /// generate *INTERNAL SERVER ERROR* response. #[allow(non_snake_case)] @@ -900,6 +1069,66 @@ where InternalError::new(err, StatusCode::GATEWAY_TIMEOUT).into() } +/// Helper function that creates wrapper of any error and +/// generate *HTTP VERSION NOT SUPPORTED* response. +#[allow(non_snake_case)] +pub fn ErrorHttpVersionNotSupported(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::HTTP_VERSION_NOT_SUPPORTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *VARIANT ALSO NEGOTIATES* response. +#[allow(non_snake_case)] +pub fn ErrorVariantAlsoNegotiates(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::VARIANT_ALSO_NEGOTIATES).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *INSUFFICIENT STORAGE* response. +#[allow(non_snake_case)] +pub fn ErrorInsufficientStorage(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::INSUFFICIENT_STORAGE).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *LOOP DETECTED* response. +#[allow(non_snake_case)] +pub fn ErrorLoopDetected(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOOP_DETECTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NOT EXTENDED* response. +#[allow(non_snake_case)] +pub fn ErrorNotExtended(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_EXTENDED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NETWORK AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorNetworkAuthenticationRequired(err: T) -> Error +where + T: Send + Sync + fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() +} + #[cfg(test)] mod tests { use super::*; @@ -1079,6 +1308,9 @@ mod tests { let r: HttpResponse = ErrorUnauthorized("err").into(); assert_eq!(r.status(), StatusCode::UNAUTHORIZED); + let r: HttpResponse = ErrorPaymentRequired("err").into(); + assert_eq!(r.status(), StatusCode::PAYMENT_REQUIRED); + let r: HttpResponse = ErrorForbidden("err").into(); assert_eq!(r.status(), StatusCode::FORBIDDEN); @@ -1088,6 +1320,12 @@ mod tests { let r: HttpResponse = ErrorMethodNotAllowed("err").into(); assert_eq!(r.status(), StatusCode::METHOD_NOT_ALLOWED); + let r: HttpResponse = ErrorNotAcceptable("err").into(); + assert_eq!(r.status(), StatusCode::NOT_ACCEPTABLE); + + let r: HttpResponse = ErrorProxyAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::PROXY_AUTHENTICATION_REQUIRED); + let r: HttpResponse = ErrorRequestTimeout("err").into(); assert_eq!(r.status(), StatusCode::REQUEST_TIMEOUT); @@ -1097,12 +1335,57 @@ mod tests { let r: HttpResponse = ErrorGone("err").into(); assert_eq!(r.status(), StatusCode::GONE); + let r: HttpResponse = ErrorLengthRequired("err").into(); + assert_eq!(r.status(), StatusCode::LENGTH_REQUIRED); + let r: HttpResponse = ErrorPreconditionFailed("err").into(); assert_eq!(r.status(), StatusCode::PRECONDITION_FAILED); + let r: HttpResponse = ErrorPayloadTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::PAYLOAD_TOO_LARGE); + + let r: HttpResponse = ErrorUriTooLong("err").into(); + assert_eq!(r.status(), StatusCode::URI_TOO_LONG); + + let r: HttpResponse = ErrorUnsupportedMediaType("err").into(); + assert_eq!(r.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + + let r: HttpResponse = ErrorRangeNotSatisfiable("err").into(); + assert_eq!(r.status(), StatusCode::RANGE_NOT_SATISFIABLE); + let r: HttpResponse = ErrorExpectationFailed("err").into(); assert_eq!(r.status(), StatusCode::EXPECTATION_FAILED); + let r: HttpResponse = ErrorImATeapot("err").into(); + assert_eq!(r.status(), StatusCode::IM_A_TEAPOT); + + let r: HttpResponse = ErrorMisdirectedRequest("err").into(); + assert_eq!(r.status(), StatusCode::MISDIRECTED_REQUEST); + + let r: HttpResponse = ErrorUnprocessableEntity("err").into(); + assert_eq!(r.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let r: HttpResponse = ErrorLocked("err").into(); + assert_eq!(r.status(), StatusCode::LOCKED); + + let r: HttpResponse = ErrorFailedDependency("err").into(); + assert_eq!(r.status(), StatusCode::FAILED_DEPENDENCY); + + let r: HttpResponse = ErrorUpgradeRequired("err").into(); + assert_eq!(r.status(), StatusCode::UPGRADE_REQUIRED); + + let r: HttpResponse = ErrorPreconditionRequired("err").into(); + assert_eq!(r.status(), StatusCode::PRECONDITION_REQUIRED); + + let r: HttpResponse = ErrorTooManyRequests("err").into(); + assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS); + + let r: HttpResponse = ErrorRequestHeaderFieldsTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); + + let r: HttpResponse = ErrorUnavailableForLegalReasons("err").into(); + assert_eq!(r.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); + let r: HttpResponse = ErrorInternalServerError("err").into(); assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR); @@ -1117,5 +1400,23 @@ mod tests { let r: HttpResponse = ErrorGatewayTimeout("err").into(); assert_eq!(r.status(), StatusCode::GATEWAY_TIMEOUT); + + let r: HttpResponse = ErrorHttpVersionNotSupported("err").into(); + assert_eq!(r.status(), StatusCode::HTTP_VERSION_NOT_SUPPORTED); + + let r: HttpResponse = ErrorVariantAlsoNegotiates("err").into(); + assert_eq!(r.status(), StatusCode::VARIANT_ALSO_NEGOTIATES); + + let r: HttpResponse = ErrorInsufficientStorage("err").into(); + assert_eq!(r.status(), StatusCode::INSUFFICIENT_STORAGE); + + let r: HttpResponse = ErrorLoopDetected("err").into(); + assert_eq!(r.status(), StatusCode::LOOP_DETECTED); + + let r: HttpResponse = ErrorNotExtended("err").into(); + assert_eq!(r.status(), StatusCode::NOT_EXTENDED); + + let r: HttpResponse = ErrorNetworkAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::NETWORK_AUTHENTICATION_REQUIRED); } } diff --git a/src/extensions.rs b/src/extensions.rs index da7b5ba24..430b87bda 100644 --- a/src/extensions.rs +++ b/src/extensions.rs @@ -31,6 +31,7 @@ impl Hasher for IdHasher { type AnyMap = HashMap, BuildHasherDefault>; +#[derive(Default)] /// A type map of request extensions. pub struct Extensions { map: AnyMap, @@ -39,7 +40,7 @@ pub struct Extensions { impl Extensions { /// Create an empty `Extensions`. #[inline] - pub(crate) fn new() -> Extensions { + pub fn new() -> Extensions { Extensions { map: HashMap::default(), } diff --git a/src/extractor.rs b/src/extractor.rs index 5b3a69a89..861334f32 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -6,7 +6,7 @@ use std::{fmt, str}; use bytes::Bytes; use encoding::all::UTF_8; use encoding::types::{DecoderTrap, Encoding}; -use futures::{Async, Future, Poll, future}; +use futures::{future, Async, Future, Poll}; use mime::Mime; use serde::de::{self, DeserializeOwned}; use serde_urlencoded; @@ -19,7 +19,8 @@ use httprequest::HttpRequest; use Either; #[derive(PartialEq, Eq, PartialOrd, Ord)] -/// Extract typed information from the request's path. +/// Extract typed information from the request's path. Information from the path is +/// URL decoded. Decoding of special characters can be disabled through `PathConfig`. /// /// ## Example /// @@ -102,22 +103,83 @@ impl Path { } } +impl From for Path { + fn from(inner: T) -> Path { + Path { inner } + } +} + impl FromRequest for Path where T: DeserializeOwned, { - type Config = (); + type Config = PathConfig; type Result = Result; #[inline] - fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { let req = req.clone(); - de::Deserialize::deserialize(PathDeserializer::new(&req)) - .map_err(ErrorNotFound) + let req2 = req.clone(); + let err = Rc::clone(&cfg.ehandler); + de::Deserialize::deserialize(PathDeserializer::new(&req, cfg.decode)) + .map_err(move |e| (*err)(e, &req2)) .map(|inner| Path { inner }) } } +/// Path extractor configuration +/// +/// ```rust +/// # extern crate actix_web; +/// use actix_web::{error, http, App, HttpResponse, Path, Result}; +/// +/// /// deserialize `Info` from request's body, max payload size is 4kb +/// fn index(info: Path<(u32, String)>) -> Result { +/// Ok(format!("Welcome {}!", info.1)) +/// } +/// +/// fn main() { +/// let app = App::new().resource("/index.html/{id}/{name}", |r| { +/// r.method(http::Method::GET).with_config(index, |cfg| { +/// cfg.0.error_handler(|err, req| { +/// // <- create custom error response +/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() +/// }); +/// }) +/// }); +/// } +/// ``` +pub struct PathConfig { + ehandler: Rc) -> Error>, + decode: bool, +} +impl PathConfig { + /// Set custom error handler + pub fn error_handler(&mut self, f: F) -> &mut Self + where + F: Fn(serde_urlencoded::de::Error, &HttpRequest) -> Error + 'static, + { + self.ehandler = Rc::new(f); + self + } + + /// Disable decoding of URL encoded special charaters from the path + pub fn disable_decoding(&mut self) -> &mut Self + { + self.decode = false; + self + } +} + +impl Default for PathConfig { + fn default() -> Self { + PathConfig { + ehandler: Rc::new(|e, _| ErrorNotFound(e)), + decode: true, + } + } +} + impl fmt::Debug for Path { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.inner.fmt(f) @@ -195,17 +257,69 @@ impl FromRequest for Query where T: de::DeserializeOwned, { - type Config = (); + type Config = QueryConfig; type Result = Result; #[inline] - fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { + fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { + let req2 = req.clone(); + let err = Rc::clone(&cfg.ehandler); serde_urlencoded::from_str::(req.query_string()) - .map_err(|e| e.into()) + .map_err(move |e| (*err)(e, &req2)) .map(Query) } } +/// Query extractor configuration +/// +/// ```rust +/// # extern crate actix_web; +/// #[macro_use] extern crate serde_derive; +/// use actix_web::{error, http, App, HttpResponse, Query, Result}; +/// +/// #[derive(Deserialize)] +/// struct Info { +/// username: String, +/// } +/// +/// /// deserialize `Info` from request's body, max payload size is 4kb +/// fn index(info: Query) -> Result { +/// Ok(format!("Welcome {}!", info.username)) +/// } +/// +/// fn main() { +/// let app = App::new().resource("/index.html", |r| { +/// r.method(http::Method::GET).with_config(index, |cfg| { +/// cfg.0.error_handler(|err, req| { +/// // <- create custom error response +/// error::InternalError::from_response(err, HttpResponse::Conflict().finish()).into() +/// }); +/// }) +/// }); +/// } +/// ``` +pub struct QueryConfig { + ehandler: Rc) -> Error>, +} +impl QueryConfig { + /// Set custom error handler + pub fn error_handler(&mut self, f: F) -> &mut Self + where + F: Fn(serde_urlencoded::de::Error, &HttpRequest) -> Error + 'static, + { + self.ehandler = Rc::new(f); + self + } +} + +impl Default for QueryConfig { + fn default() -> Self { + QueryConfig { + ehandler: Rc::new(|e, _| e.into()), + } + } +} + impl fmt::Debug for Query { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.fmt(f) @@ -327,7 +441,7 @@ impl fmt::Display for Form { /// |r| { /// r.method(http::Method::GET) /// // register form handler and change form extractor configuration -/// .with_config(index, |cfg| {cfg.limit(4096);}) +/// .with_config(index, |cfg| {cfg.0.limit(4096);}) /// }, /// ); /// } @@ -422,7 +536,7 @@ impl FromRequest for Bytes { /// let app = App::new().resource("/index.html", |r| { /// r.method(http::Method::GET) /// .with_config(index, |cfg| { // <- register handler with extractor params -/// cfg.limit(4096); // <- limit size of the payload +/// cfg.0.limit(4096); // <- limit size of the payload /// }) /// }); /// } @@ -505,19 +619,18 @@ impl FromRequest for String { /// }); /// } /// ``` -impl FromRequest for Option where T: FromRequest { +impl FromRequest for Option +where + T: FromRequest, +{ type Config = T::Config; type Result = Box, Error = Error>>; #[inline] fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { - Box::new(T::from_request(req, cfg).into().then( |r| { - match r { - Ok(v) => future::ok(Some(v)), - Err(e) => { - future::ok(None) - } - } + Box::new(T::from_request(req, cfg).into().then(|r| match r { + Ok(v) => future::ok(Some(v)), + Err(_) => future::ok(None), })) } } @@ -711,13 +824,16 @@ impl Default for EitherConfig where A: FromRequest, B: FromRequ /// }); /// } /// ``` -impl FromRequest for Result where T: FromRequest{ +impl FromRequest for Result +where + T: FromRequest, +{ type Config = T::Config; type Result = Box, Error = Error>>; #[inline] fn from_request(req: &HttpRequest, cfg: &Self::Config) -> Self::Result { - Box::new(T::from_request(req, cfg).into().then( |r| { future::ok(r) })) + Box::new(T::from_request(req, cfg).into().then(future::ok)) } } @@ -833,6 +949,12 @@ macro_rules! tuple_from_req ({$fut_type:ident, $(($n:tt, $T:ident)),+} => { } }); +impl FromRequest for () { + type Config = (); + type Result = Self; + fn from_request(_req: &HttpRequest, _cfg: &Self::Config) -> Self::Result {} +} + tuple_from_req!(TupleFromRequest1, (0, A)); tuple_from_req!(TupleFromRequest2, (0, A), (1, B)); tuple_from_req!(TupleFromRequest3, (0, A), (1, B), (2, C)); @@ -938,8 +1060,8 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world")) - .finish(); + .set_payload(Bytes::from_static(b"hello=world")) + .finish(); let mut cfg = FormConfig::default(); cfg.limit(4096); @@ -961,7 +1083,10 @@ mod tests { let mut cfg = FormConfig::default(); cfg.limit(4096); - match Option::>::from_request(&req, &cfg).poll().unwrap() { + match Option::>::from_request(&req, &cfg) + .poll() + .unwrap() + { Async::Ready(r) => assert_eq!(r, None), _ => unreachable!(), } @@ -970,11 +1095,19 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"hello=world")) - .finish(); + .set_payload(Bytes::from_static(b"hello=world")) + .finish(); - match Option::>::from_request(&req, &cfg).poll().unwrap() { - Async::Ready(r) => assert_eq!(r, Some(Form(Info { hello: "world".into() }))), + match Option::>::from_request(&req, &cfg) + .poll() + .unwrap() + { + Async::Ready(r) => assert_eq!( + r, + Some(Form(Info { + hello: "world".into() + })) + ), _ => unreachable!(), } @@ -982,10 +1115,13 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"bye=world")) - .finish(); + .set_payload(Bytes::from_static(b"bye=world")) + .finish(); - match Option::>::from_request(&req, &cfg).poll().unwrap() { + match Option::>::from_request(&req, &cfg) + .poll() + .unwrap() + { Async::Ready(r) => assert_eq!(r, None), _ => unreachable!(), } @@ -1039,11 +1175,19 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world")) - .finish(); + .set_payload(Bytes::from_static(b"hello=world")) + .finish(); - match Result::, Error>::from_request(&req, &FormConfig::default()).poll().unwrap() { - Async::Ready(Ok(r)) => assert_eq!(r, Form(Info { hello: "world".into() })), + match Result::, Error>::from_request(&req, &FormConfig::default()) + .poll() + .unwrap() + { + Async::Ready(Ok(r)) => assert_eq!( + r, + Form(Info { + hello: "world".into() + }) + ), _ => unreachable!(), } @@ -1051,17 +1195,18 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "9") - .set_payload(Bytes::from_static(b"bye=world")) - .finish(); + .set_payload(Bytes::from_static(b"bye=world")) + .finish(); - match Result::, Error>::from_request(&req, &FormConfig::default()).poll().unwrap() { + match Result::, Error>::from_request(&req, &FormConfig::default()) + .poll() + .unwrap() + { Async::Ready(r) => assert!(r.is_err()), _ => unreachable!(), } } - - #[test] fn test_payload_config() { let req = TestRequest::default().finish(); @@ -1101,33 +1246,33 @@ mod tests { fn test_request_extract() { let req = TestRequest::with_uri("/name/user1/?id=test").finish(); - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); let info = router.recognize(&req, &(), 0); let req = req.with_route_info(info); - let s = Path::::from_request(&req, &()).unwrap(); + let s = Path::::from_request(&req, &PathConfig::default()).unwrap(); assert_eq!(s.key, "name"); assert_eq!(s.value, "user1"); - let s = Path::<(String, String)>::from_request(&req, &()).unwrap(); + let s = Path::<(String, String)>::from_request(&req, &PathConfig::default()).unwrap(); assert_eq!(s.0, "name"); assert_eq!(s.1, "user1"); - let s = Query::::from_request(&req, &()).unwrap(); + let s = Query::::from_request(&req, &QueryConfig::default()).unwrap(); assert_eq!(s.id, "test"); - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); let req = TestRequest::with_uri("/name/32/").finish(); let info = router.recognize(&req, &(), 0); let req = req.with_route_info(info); - let s = Path::::from_request(&req, &()).unwrap(); + let s = Path::::from_request(&req, &PathConfig::default()).unwrap(); assert_eq!(s.as_ref().key, "name"); assert_eq!(s.value, 32); - let s = Path::<(String, u8)>::from_request(&req, &()).unwrap(); + let s = Path::<(String, u8)>::from_request(&req, &PathConfig::default()).unwrap(); assert_eq!(s.0, "name"); assert_eq!(s.1, 32); @@ -1138,18 +1283,80 @@ mod tests { #[test] fn test_extract_path_single() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); let req = TestRequest::with_uri("/32/").finish(); let info = router.recognize(&req, &(), 0); let req = req.with_route_info(info); - assert_eq!(*Path::::from_request(&req, &()).unwrap(), 32); + assert_eq!(*Path::::from_request(&req, &&PathConfig::default()).unwrap(), 32); + } + + #[test] + fn test_extract_path_decode() { + let mut router = Router::<()>::default(); + router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + macro_rules! test_single_value { + ($value:expr, $expected:expr) => { + { + let req = TestRequest::with_uri($value).finish(); + let info = router.recognize(&req, &(), 0); + let req = req.with_route_info(info); + assert_eq!(*Path::::from_request(&req, &PathConfig::default()).unwrap(), $expected); + } + } + } + + test_single_value!("/%25/", "%"); + test_single_value!("/%40%C2%A3%24%25%5E%26%2B%3D/", "@£$%^&+="); + test_single_value!("/%2B/", "+"); + test_single_value!("/%252B/", "%2B"); + test_single_value!("/%2F/", "/"); + test_single_value!("/%252F/", "%2F"); + test_single_value!("/http%3A%2F%2Flocalhost%3A80%2Ffoo/", "http://localhost:80/foo"); + test_single_value!("/%2Fvar%2Flog%2Fsyslog/", "/var/log/syslog"); + test_single_value!( + "/http%3A%2F%2Flocalhost%3A80%2Ffile%2F%252Fvar%252Flog%252Fsyslog/", + "http://localhost:80/file/%2Fvar%2Flog%2Fsyslog" + ); + + let req = TestRequest::with_uri("/%25/7/?id=test").finish(); + + let mut router = Router::<()>::default(); + router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); + let info = router.recognize(&req, &(), 0); + let req = req.with_route_info(info); + + let s = Path::::from_request(&req, &PathConfig::default()).unwrap(); + assert_eq!(s.key, "%"); + assert_eq!(s.value, 7); + + let s = Path::<(String, String)>::from_request(&req, &PathConfig::default()).unwrap(); + assert_eq!(s.0, "%"); + assert_eq!(s.1, "7"); + } + + #[test] + fn test_extract_path_no_decode() { + let mut router = Router::<()>::default(); + router.register_resource(Resource::new(ResourceDef::new("/{value}/"))); + + let req = TestRequest::with_uri("/%25/").finish(); + let info = router.recognize(&req, &(), 0); + let req = req.with_route_info(info); + assert_eq!( + *Path::::from_request( + &req, + &&PathConfig::default().disable_decoding() + ).unwrap(), + "%25" + ); } #[test] fn test_tuple_extract() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/{key}/{value}/"))); let req = TestRequest::with_uri("/name/user1/?id=test").finish(); @@ -1173,5 +1380,7 @@ mod tests { assert_eq!((res.0).1, "user1"); assert_eq!((res.1).0, "name"); assert_eq!((res.1).1, "user1"); + + let () = <()>::extract(&req); } } diff --git a/src/fs.rs b/src/fs.rs index f23ba12cd..aec058aaf 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -11,10 +11,10 @@ use std::{cmp, io}; #[cfg(unix)] use std::os::unix::fs::MetadataExt; +use askama_escape::{escape as escape_html_entity}; use bytes::Bytes; use futures::{Async, Future, Poll, Stream}; use futures_cpupool::{CpuFuture, CpuPool}; -use htmlescape::encode_minimal as escape_html_entity; use mime; use mime_guess::{get_mime_type, guess_mime_type}; use percent_encoding::{utf8_percent_encode, DEFAULT_ENCODE_SET}; @@ -164,11 +164,7 @@ impl NamedFile { let disposition_type = C::content_disposition_map(ct.type_()); let cd = ContentDisposition { disposition: disposition_type, - parameters: vec![DispositionParam::Filename( - header::Charset::Ext("UTF-8".to_owned()), - None, - filename.as_bytes().to_vec(), - )], + parameters: vec![DispositionParam::Filename(filename.into_owned())], }; (ct, cd) }; @@ -373,11 +369,7 @@ impl Responder for NamedFile { .body("This resource only supports GET and HEAD.")); } - let etag = if C::is_use_etag() { - self.etag() - } else { - None - }; + let etag = if C::is_use_etag() { self.etag() } else { None }; let last_modified = if C::is_use_last_modifier() { self.last_modified() } else { @@ -480,6 +472,7 @@ impl Responder for NamedFile { } } +#[doc(hidden)] /// A helper created from a `std::fs::File` which reads the file /// chunk-by-chunk on a `CpuPool`. pub struct ChunkedReadFile { @@ -522,7 +515,8 @@ impl Stream for ChunkedReadFile { max_bytes = cmp::min(size.saturating_sub(counter), 65_536) as usize; let mut buf = Vec::with_capacity(max_bytes); file.seek(io::SeekFrom::Start(offset))?; - let nbytes = file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; + let nbytes = + file.by_ref().take(max_bytes as u64).read_to_end(&mut buf)?; if nbytes == 0 { return Err(io::ErrorKind::UnexpectedEof.into()); } @@ -568,8 +562,23 @@ impl Directory { } } +// show file url as relative to static path +macro_rules! encode_file_url { + ($path:ident) => { + utf8_percent_encode(&$path.to_string_lossy(), DEFAULT_ENCODE_SET) + }; +} + +// " -- " & -- & ' -- ' < -- < > -- > / -- / +macro_rules! encode_file_name { + ($entry:ident) => { + escape_html_entity(&$entry.file_name().to_string_lossy()) + }; +} + fn directory_listing( - dir: &Directory, req: &HttpRequest, + dir: &Directory, + req: &HttpRequest, ) -> Result { let index_of = format!("Index of {}", req.path()); let mut body = String::new(); @@ -582,11 +591,6 @@ fn directory_listing( Ok(p) => base.join(p), Err(_) => continue, }; - // show file url as relative to static path - let file_url = utf8_percent_encode(&p.to_string_lossy(), DEFAULT_ENCODE_SET) - .to_string(); - // " -- " & -- & ' -- ' < -- < > -- > - let file_name = escape_html_entity(&entry.file_name().to_string_lossy()); // if file is a directory, add '/' to the end of the name if let Ok(metadata) = entry.metadata() { @@ -594,13 +598,15 @@ fn directory_listing( let _ = write!( body, "
  • {}/
  • ", - file_url, file_name + encode_file_url!(p), + encode_file_name!(entry), ); } else { let _ = write!( body, "
  • {}
  • ", - file_url, file_name + encode_file_url!(p), + encode_file_name!(entry), ); } } else { @@ -663,7 +669,8 @@ impl StaticFiles { /// Create new `StaticFiles` instance for specified base directory and /// `CpuPool`. pub fn with_pool>( - dir: T, pool: CpuPool, + dir: T, + pool: CpuPool, ) -> Result, Error> { Self::with_config_pool(dir, pool, DefaultConfig) } @@ -674,7 +681,8 @@ impl StaticFiles { /// /// Identical with `new` but allows to specify configiration to use. pub fn with_config>( - dir: T, config: C, + dir: T, + config: C, ) -> Result, Error> { // use default CpuPool let pool = { DEFAULT_CPUPOOL.lock().clone() }; @@ -685,7 +693,9 @@ impl StaticFiles { /// Create new `StaticFiles` instance for specified base directory with config and /// `CpuPool`. pub fn with_config_pool>( - dir: T, pool: CpuPool, _: C, + dir: T, + pool: CpuPool, + _: C, ) -> Result, Error> { let dir = dir.into().canonicalize()?; @@ -743,7 +753,8 @@ impl StaticFiles { } fn try_handle( - &self, req: &HttpRequest, + &self, + req: &HttpRequest, ) -> Result, Error> { let tail: String = req.match_info().query("tail")?; let relpath = PathBuf::from_param(tail.trim_left_matches('/'))?; @@ -873,8 +884,7 @@ impl HttpRange { length: length as u64, })) } - }) - .collect::>()?; + }).collect::>()?; let ranges: Vec = all_ranges.into_iter().filter_map(|x| x).collect(); @@ -990,11 +1000,7 @@ mod tests { use header::{ContentDisposition, DispositionParam, DispositionType}; let cd = ContentDisposition { disposition: DispositionType::Attachment, - parameters: vec![DispositionParam::Filename( - header::Charset::Ext("UTF-8".to_owned()), - None, - "test.png".as_bytes().to_vec(), - )], + parameters: vec![DispositionParam::Filename(String::from("test.png"))], }; let mut file = NamedFile::open("tests/test.png") .unwrap() diff --git a/src/handler.rs b/src/handler.rs index 1735ffde6..c68808181 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -250,7 +250,7 @@ pub(crate) enum AsyncResultItem { impl AsyncResult { /// Create async response #[inline] - pub fn async(fut: Box>) -> AsyncResult { + pub fn future(fut: Box>) -> AsyncResult { AsyncResult(Some(AsyncResultItem::Future(fut))) } @@ -353,13 +353,17 @@ impl> From> for AsyncResult { } } -impl> From>, E>> - for AsyncResult +impl From>, E>> for AsyncResult +where + T: 'static, + E: Into + 'static, { #[inline] - fn from(res: Result>, E>) -> Self { + fn from(res: Result>, E>) -> Self { match res { - Ok(fut) => AsyncResult(Some(AsyncResultItem::Future(fut))), + Ok(fut) => AsyncResult(Some(AsyncResultItem::Future(Box::new( + fut.map_err(|e| e.into()), + )))), Err(err) => AsyncResult(Some(AsyncResultItem::Err(err.into()))), } } @@ -397,7 +401,7 @@ where }, Err(e) => err(e), }); - Ok(AsyncResult::async(Box::new(fut))) + Ok(AsyncResult::future(Box::new(fut))) } } @@ -498,7 +502,7 @@ where Err(e) => Either::A(err(e)), } }); - AsyncResult::async(Box::new(fut)) + AsyncResult::future(Box::new(fut)) } } @@ -526,8 +530,7 @@ where /// } /// /// /// extract path info using serde -/// fn index(data: (State, Path)) -> String { -/// let (state, path) = data; +/// fn index(state: State, path: Path) -> String { /// format!("{} {}!", state.msg, path.username) /// } /// diff --git a/src/header/common/content_disposition.rs b/src/header/common/content_disposition.rs index ff04ef565..5e8cbd67a 100644 --- a/src/header/common/content_disposition.rs +++ b/src/header/common/content_disposition.rs @@ -2,17 +2,35 @@ // // "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt // "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt -// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc2388.txt +// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt // Browser conformance tests at: http://greenbytes.de/tech/tc2231/ // IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml -use language_tags::LanguageTag; use header; +use header::ExtendedValue; use header::{Header, IntoHeaderValue, Writer}; -use header::shared::Charset; +use regex::Regex; use std::fmt::{self, Write}; +/// Split at the index of the first `needle` if it exists or at the end. +fn split_once(haystack: &str, needle: char) -> (&str, &str) { + haystack.find(needle).map_or_else( + || (haystack, ""), + |sc| { + let (first, last) = haystack.split_at(sc); + (first, last.split_at(1).1) + }, + ) +} + +/// Split at the index of the first `needle` if it exists or at the end, trim the right of the +/// first part and the left of the last part. +fn split_once_and_trim(haystack: &str, needle: char) -> (&str, &str) { + let (first, last) = split_once(haystack, needle); + (first.trim_right(), last.trim_left()) +} + /// The implied disposition of the content of the HTTP body. #[derive(Clone, Debug, PartialEq)] pub enum DispositionType { @@ -21,27 +39,164 @@ pub enum DispositionType { /// Attachment implies that the recipient should prompt the user to save the response locally, /// rather than process it normally (as per its media type). Attachment, - /// Extension type. Should be handled by recipients the same way as Attachment - Ext(String) + /// Used in *multipart/form-data* as defined in + /// [RFC7578](https://tools.ietf.org/html/rfc7578) to carry the field name and the file name. + FormData, + /// Extension type. Should be handled by recipients the same way as Attachment + Ext(String), } -/// A parameter to the disposition type. +impl<'a> From<&'a str> for DispositionType { + fn from(origin: &'a str) -> DispositionType { + if origin.eq_ignore_ascii_case("inline") { + DispositionType::Inline + } else if origin.eq_ignore_ascii_case("attachment") { + DispositionType::Attachment + } else if origin.eq_ignore_ascii_case("form-data") { + DispositionType::FormData + } else { + DispositionType::Ext(origin.to_owned()) + } + } +} + +/// Parameter in [`ContentDisposition`]. +/// +/// # Examples +/// ``` +/// use actix_web::http::header::DispositionParam; +/// +/// let param = DispositionParam::Filename(String::from("sample.txt")); +/// assert!(param.is_filename()); +/// assert_eq!(param.as_filename().unwrap(), "sample.txt"); +/// ``` #[derive(Clone, Debug, PartialEq)] pub enum DispositionParam { - /// A Filename consisting of a Charset, an optional LanguageTag, and finally a sequence of - /// bytes representing the filename - Filename(Charset, Option, Vec), - /// Extension type consisting of token and value. Recipients should ignore unrecognized - /// parameters. - Ext(String, String) + /// For [`DispositionType::FormData`] (i.e. *multipart/form-data*), the name of an field from + /// the form. + Name(String), + /// A plain file name. + Filename(String), + /// An extended file name. It must not exist for `ContentType::Formdata` according to + /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). + FilenameExt(ExtendedValue), + /// An unrecognized regular parameter as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *reg-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *token "=" value*. Recipients should + /// ignore unrecognizable parameters. + Unknown(String, String), + /// An unrecognized extended paramater as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *ext-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *ext-token "=" ext-value*. The single + /// trailling asterisk is not included. Recipients should ignore unrecognizable parameters. + UnknownExt(String, ExtendedValue), } -/// A `Content-Disposition` header, (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266). +impl DispositionParam { + /// Returns `true` if the paramater is [`Name`](DispositionParam::Name). + #[inline] + pub fn is_name(&self) -> bool { + self.as_name().is_some() + } + + /// Returns `true` if the paramater is [`Filename`](DispositionParam::Filename). + #[inline] + pub fn is_filename(&self) -> bool { + self.as_filename().is_some() + } + + /// Returns `true` if the paramater is [`FilenameExt`](DispositionParam::FilenameExt). + #[inline] + pub fn is_filename_ext(&self) -> bool { + self.as_filename_ext().is_some() + } + + /// Returns `true` if the paramater is [`Unknown`](DispositionParam::Unknown) and the `name` + #[inline] + /// matches. + pub fn is_unknown>(&self, name: T) -> bool { + self.as_unknown(name).is_some() + } + + /// Returns `true` if the paramater is [`UnknownExt`](DispositionParam::UnknownExt) and the + /// `name` matches. + #[inline] + pub fn is_unknown_ext>(&self, name: T) -> bool { + self.as_unknown_ext(name).is_some() + } + + /// Returns the name if applicable. + #[inline] + pub fn as_name(&self) -> Option<&str> { + match self { + DispositionParam::Name(ref name) => Some(name.as_str()), + _ => None, + } + } + + /// Returns the filename if applicable. + #[inline] + pub fn as_filename(&self) -> Option<&str> { + match self { + DispositionParam::Filename(ref filename) => Some(filename.as_str()), + _ => None, + } + } + + /// Returns the filename* if applicable. + #[inline] + pub fn as_filename_ext(&self) -> Option<&ExtendedValue> { + match self { + DispositionParam::FilenameExt(ref value) => Some(value), + _ => None, + } + } + + /// Returns the value of the unrecognized regular parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown>(&self, name: T) -> Option<&str> { + match self { + DispositionParam::Unknown(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value.as_str()) + } + _ => None, + } + } + + /// Returns the value of the unrecognized extended parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + match self { + DispositionParam::UnknownExt(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value) + } + _ => None, + } + } +} + +/// A *Content-Disposition* header. It is compatible to be used either as +/// [a response header for the main body](https://mdn.io/Content-Disposition#As_a_response_header_for_the_main_body) +/// as (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266), or as +/// [a header for a multipart body](https://mdn.io/Content-Disposition#As_a_header_for_a_multipart_body) +/// as (re)defined in [RFC7587](https://tools.ietf.org/html/rfc7578). /// -/// The Content-Disposition response header field is used to convey -/// additional information about how to process the response payload, and -/// also can be used to attach additional metadata, such as the filename -/// to use when saving the response payload locally. +/// In a regular HTTP response, the *Content-Disposition* response header is a header indicating if +/// the content is expected to be displayed *inline* in the browser, that is, as a Web page or as +/// part of a Web page, or as an attachment, that is downloaded and saved locally, and also can be +/// used to attach additional metadata, such as the filename to use when saving the response payload +/// locally. +/// +/// In a *multipart/form-data* body, the HTTP *Content-Disposition* general header is a header that +/// can be used on the subpart of a multipart body to give information about the field it applies to. +/// The subpart is delimited by the boundary defined in the *Content-Type* header. Used on the body +/// itself, *Content-Disposition* has no effect. /// /// # ABNF @@ -65,88 +220,211 @@ pub enum DispositionParam { /// ext-token = /// ``` /// +/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// *multipart/form-data*. +/// /// # Example /// /// ``` -/// use actix_web::http::header::{ContentDisposition, DispositionType, DispositionParam, Charset}; +/// use actix_web::http::header::{ +/// Charset, ContentDisposition, DispositionParam, DispositionType, +/// ExtendedValue, +/// }; /// /// let cd1 = ContentDisposition { /// disposition: DispositionType::Attachment, -/// parameters: vec![DispositionParam::Filename( -/// Charset::Iso_8859_1, // The character set for the bytes of the filename -/// None, // The optional language tag (see `language-tag` crate) -/// b"\xa9 Copyright 1989.txt".to_vec() // the actual bytes of the filename -/// )] +/// parameters: vec![DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Iso_8859_1, // The character set for the bytes of the filename +/// language_tag: None, // The optional language tag (see `language-tag` crate) +/// value: b"\xa9 Copyright 1989.txt".to_vec(), // the actual bytes of the filename +/// })], /// }; +/// assert!(cd1.is_attachment()); +/// assert!(cd1.get_filename_ext().is_some()); /// /// let cd2 = ContentDisposition { -/// disposition: DispositionType::Inline, -/// parameters: vec![DispositionParam::Filename( -/// Charset::Ext("UTF-8".to_owned()), -/// None, -/// "\u{2764}".as_bytes().to_vec() -/// )] +/// disposition: DispositionType::FormData, +/// parameters: vec![ +/// DispositionParam::Name(String::from("file")), +/// DispositionParam::Filename(String::from("bill.odt")), +/// ], /// }; +/// assert_eq!(cd2.get_name(), Some("file")); // field name +/// assert_eq!(cd2.get_filename(), Some("bill.odt")); /// ``` +/// +/// # WARN +/// If "filename" parameter is supplied, do not use the file name blindly, check and possibly +/// change to match local file system conventions if applicable, and do not use directory path +/// information that may be present. See [RFC2183](https://tools.ietf.org/html/rfc2183#section-2.3) +/// . #[derive(Clone, Debug, PartialEq)] pub struct ContentDisposition { - /// The disposition + /// The disposition type pub disposition: DispositionType, /// Disposition parameters pub parameters: Vec, } + impl ContentDisposition { - /// Parse a raw Content-Disposition header value + /// Parse a raw Content-Disposition header value. pub fn from_raw(hv: &header::HeaderValue) -> Result { - header::from_one_raw_str(Some(hv)).and_then(|s: String| { - let mut sections = s.split(';'); - let disposition = match sections.next() { - Some(s) => s.trim(), - None => return Err(::error::ParseError::Header), - }; + // `header::from_one_raw_str` invokes `hv.to_str` which assumes `hv` contains only visible + // ASCII characters. So `hv.as_bytes` is necessary here. + let hv = String::from_utf8(hv.as_bytes().to_vec()) + .map_err(|_| ::error::ParseError::Header)?; + let (disp_type, mut left) = split_once_and_trim(hv.as_str().trim(), ';'); + if disp_type.is_empty() { + return Err(::error::ParseError::Header); + } + let mut cd = ContentDisposition { + disposition: disp_type.into(), + parameters: Vec::new(), + }; - let mut cd = ContentDisposition { - disposition: if disposition.eq_ignore_ascii_case("inline") { - DispositionType::Inline - } else if disposition.eq_ignore_ascii_case("attachment") { - DispositionType::Attachment - } else { - DispositionType::Ext(disposition.to_owned()) - }, - parameters: Vec::new(), - }; - - for section in sections { - let mut parts = section.splitn(2, '='); - - let key = if let Some(key) = parts.next() { - key.trim() - } else { - return Err(::error::ParseError::Header); - }; - - let val = if let Some(val) = parts.next() { - val.trim() - } else { - return Err(::error::ParseError::Header); - }; - - cd.parameters.push( - if key.eq_ignore_ascii_case("filename") { - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), None, - val.trim_matches('"').as_bytes().to_owned()) - } else if key.eq_ignore_ascii_case("filename*") { - let extended_value = try!(header::parse_extended_value(val)); - DispositionParam::Filename(extended_value.charset, extended_value.language_tag, extended_value.value) - } else { - DispositionParam::Ext(key.to_owned(), val.trim_matches('"').to_owned()) - } - ); + while !left.is_empty() { + let (param_name, new_left) = split_once_and_trim(left, '='); + if param_name.is_empty() || param_name == "*" || new_left.is_empty() { + return Err(::error::ParseError::Header); } + left = new_left; + if param_name.ends_with('*') { + // extended parameters + let param_name = ¶m_name[..param_name.len() - 1]; // trim asterisk + let (ext_value, new_left) = split_once_and_trim(left, ';'); + left = new_left; + let ext_value = header::parse_extended_value(ext_value)?; - Ok(cd) - }) + let param = if param_name.eq_ignore_ascii_case("filename") { + DispositionParam::FilenameExt(ext_value) + } else { + DispositionParam::UnknownExt(param_name.to_owned(), ext_value) + }; + cd.parameters.push(param); + } else { + // regular parameters + let value = if left.starts_with('\"') { + // quoted-string: defined in RFC6266 -> RFC2616 Section 3.6 + let mut escaping = false; + let mut quoted_string = vec![]; + let mut end = None; + // search for closing quote + for (i, &c) in left.as_bytes().iter().skip(1).enumerate() { + if escaping { + escaping = false; + quoted_string.push(c); + } else if c == 0x5c { + // backslash + escaping = true; + } else if c == 0x22 { + // double quote + end = Some(i + 1); // cuz skipped 1 for the leading quote + break; + } else { + quoted_string.push(c); + } + } + left = &left[end.ok_or(::error::ParseError::Header)? + 1..]; + left = split_once(left, ';').1.trim_left(); + // In fact, it should not be Err if the above code is correct. + String::from_utf8(quoted_string).map_err(|_| ::error::ParseError::Header)? + } else { + // token: won't contains semicolon according to RFC 2616 Section 2.2 + let (token, new_left) = split_once_and_trim(left, ';'); + left = new_left; + token.to_owned() + }; + if value.is_empty() { + return Err(::error::ParseError::Header); + } + + let param = if param_name.eq_ignore_ascii_case("name") { + DispositionParam::Name(value) + } else if param_name.eq_ignore_ascii_case("filename") { + DispositionParam::Filename(value) + } else { + DispositionParam::Unknown(param_name.to_owned(), value) + }; + cd.parameters.push(param); + } + } + + Ok(cd) + } + + /// Returns `true` if it is [`Inline`](DispositionType::Inline). + pub fn is_inline(&self) -> bool { + match self.disposition { + DispositionType::Inline => true, + _ => false, + } + } + + /// Returns `true` if it is [`Attachment`](DispositionType::Attachment). + pub fn is_attachment(&self) -> bool { + match self.disposition { + DispositionType::Attachment => true, + _ => false, + } + } + + /// Returns `true` if it is [`FormData`](DispositionType::FormData). + pub fn is_form_data(&self) -> bool { + match self.disposition { + DispositionType::FormData => true, + _ => false, + } + } + + /// Returns `true` if it is [`Ext`](DispositionType::Ext) and the `disp_type` matches. + pub fn is_ext>(&self, disp_type: T) -> bool { + match self.disposition { + DispositionType::Ext(ref t) + if t.eq_ignore_ascii_case(disp_type.as_ref()) => + { + true + } + _ => false, + } + } + + /// Return the value of *name* if exists. + pub fn get_name(&self) -> Option<&str> { + self.parameters.iter().filter_map(|p| p.as_name()).nth(0) + } + + /// Return the value of *filename* if exists. + pub fn get_filename(&self) -> Option<&str> { + self.parameters + .iter() + .filter_map(|p| p.as_filename()) + .nth(0) + } + + /// Return the value of *filename\** if exists. + pub fn get_filename_ext(&self) -> Option<&ExtendedValue> { + self.parameters + .iter() + .filter_map(|p| p.as_filename_ext()) + .nth(0) + } + + /// Return the value of the parameter which the `name` matches. + pub fn get_unknown>(&self, name: T) -> Option<&str> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown(name)) + .nth(0) + } + + /// Return the value of the extended parameter which the `name` matches. + pub fn get_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown_ext(name)) + .nth(0) } } @@ -174,67 +452,76 @@ impl Header for ContentDisposition { } } -impl fmt::Display for ContentDisposition { +impl fmt::Display for DispositionType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.disposition { - DispositionType::Inline => try!(write!(f, "inline")), - DispositionType::Attachment => try!(write!(f, "attachment")), - DispositionType::Ext(ref s) => try!(write!(f, "{}", s)), + match self { + DispositionType::Inline => write!(f, "inline"), + DispositionType::Attachment => write!(f, "attachment"), + DispositionType::FormData => write!(f, "form-data"), + DispositionType::Ext(ref s) => write!(f, "{}", s), } - for param in &self.parameters { - match *param { - DispositionParam::Filename(ref charset, ref opt_lang, ref bytes) => { - let mut use_simple_format: bool = false; - if opt_lang.is_none() { - if let Charset::Ext(ref ext) = *charset { - if ext.eq_ignore_ascii_case("utf-8") { - use_simple_format = true; - } - } - } - if use_simple_format { - use std::str; - try!(write!(f, "; filename=\"{}\"", - match str::from_utf8(bytes) { - Ok(s) => s, - Err(_) => return Err(fmt::Error), - })); - } else { - try!(write!(f, "; filename*={}'", charset)); - if let Some(ref lang) = *opt_lang { - try!(write!(f, "{}", lang)); - }; - try!(write!(f, "'")); - try!(header::http_percent_encode(f, bytes)) - } - }, - DispositionParam::Ext(ref k, ref v) => try!(write!(f, "; {}=\"{}\"", k, v)), + } +} + +impl fmt::Display for DispositionParam { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and + // backslash should be escaped in quoted-string (i.e. "foobar"). + // Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... . + lazy_static! { + static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap(); + } + match self { + DispositionParam::Name(ref value) => write!(f, "name={}", value), + DispositionParam::Filename(ref value) => { + write!(f, "filename=\"{}\"", RE.replace_all(value, "\\$0").as_ref()) + } + DispositionParam::Unknown(ref name, ref value) => write!( + f, + "{}=\"{}\"", + name, + &RE.replace_all(value, "\\$0").as_ref() + ), + DispositionParam::FilenameExt(ref ext_value) => { + write!(f, "filename*={}", ext_value) + } + DispositionParam::UnknownExt(ref name, ref ext_value) => { + write!(f, "{}*={}", name, ext_value) } } - Ok(()) + } +} + +impl fmt::Display for ContentDisposition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.disposition)?; + self.parameters + .iter() + .map(|param| write!(f, "; {}", param)) + .collect() } } #[cfg(test)] mod tests { - use super::{ContentDisposition,DispositionType,DispositionParam}; - use header::HeaderValue; + use super::{ContentDisposition, DispositionParam, DispositionType}; use header::shared::Charset; + use header::{ExtendedValue, HeaderValue}; #[test] - fn test_from_raw() { + fn test_from_raw_basic() { assert!(ContentDisposition::from_raw(&HeaderValue::from_static("")).is_err()); - let a = HeaderValue::from_static("form-data; dummy=3; name=upload; filename=\"sample.png\""); + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"sample.png\"", + ); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { - disposition: DispositionType::Ext("form-data".to_owned()), + disposition: DispositionType::FormData, parameters: vec![ - DispositionParam::Ext("dummy".to_owned(), "3".to_owned()), - DispositionParam::Ext("name".to_owned(), "upload".to_owned()), - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - "sample.png".bytes().collect()) ] + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], }; assert_eq!(a, b); @@ -242,44 +529,386 @@ mod tests { let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::Attachment, - parameters: vec![ - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - "image.jpg".bytes().collect()) ] + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], }; assert_eq!(a, b); - let a = HeaderValue::from_static("attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates"); + let a = HeaderValue::from_static("inline; filename=image.jpg"); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; creation-date=\"Wed, 12 Feb 1997 16:29:51 -0500\"", + ); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); let b = ContentDisposition { disposition: DispositionType::Attachment, - parameters: vec![ - DispositionParam::Filename( - Charset::Ext("UTF-8".to_owned()), - None, - vec![0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, - 0xe2, 0x82, 0xac, 0x20, b'r', b'a', b't', b'e', b's']) ] + parameters: vec![DispositionParam::Unknown( + String::from("creation-date"), + "Wed, 12 Feb 1997 16:29:51 -0500".to_owned(), + )], }; assert_eq!(a, b); } #[test] - fn test_display() { - let as_string = "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; + fn test_from_raw_extended() { + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_extra_whitespace() { + let a = HeaderValue::from_static( + "form-data ; du-mmy= 3 ; name =upload ; filename = \"sample.png\" ; ", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("du-mmy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_unordered() { + let a = HeaderValue::from_static( + "form-data; dummy=3; filename=\"sample.png\" ; name=upload;", + // Actually, a trailling semolocon is not compliant. But it is fine to accept. + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + DispositionParam::Name("upload".to_owned()), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_str( + "attachment; filename*=iso-8859-1''foo-%E4.html; filename=\"foo-ä.html\"", + ).unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_only_disp() { + let a = ContentDisposition::from_raw(&HeaderValue::from_static("attachment")) + .unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = + ContentDisposition::from_raw(&HeaderValue::from_static("inline ;")).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = ContentDisposition::from_raw(&HeaderValue::from_static( + "unknown-disp-param", + )).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Ext(String::from("unknown-disp-param")), + parameters: vec![], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_mixed_case() { + let a = HeaderValue::from_str( + "InLInE; fIlenAME*=iso-8859-1''foo-%E4.html; filEName=\"foo-ä.html\"", + ).unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_unicode() { + /* RFC7578 Section 4.2: + Some commonly deployed systems use multipart/form-data with file names directly encoded + including octets outside the US-ASCII range. The encoding used for the file names is + typically UTF-8, although HTML forms will use the charset associated with the form. + + Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. + (And now, only UTF-8 is handled by this implementation.) + */ + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from("文件.webp")), + ], + }; + assert_eq!(a, b); + + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"余固知謇謇之為患兮,忍而不能舍也.pptx\"").unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from( + "余固知謇謇之為患兮,忍而不能舍也.pptx", + )), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_escape() { + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"s\\amp\\\"le.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename( + ['s', 'a', 'm', 'p', '\"', 'l', 'e', '.', 'p', 'n', 'g'] + .iter() + .collect(), + ), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_semicolon() { + let a = + HeaderValue::from_static("form-data; filename=\"A semicolon here;.pdf\""); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![DispositionParam::Filename(String::from( + "A semicolon here;.pdf", + ))], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_uncessary_percent_decode() { + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded! + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74%2e%70%6e%67")), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74.png")), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_param_value_missing() { + let a = HeaderValue::from_static("form-data; name=upload ; filename="); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("attachment; dummy=; filename=invoice.pdf"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename= "); + assert!(ContentDisposition::from_raw(&a).is_err()); + } + + #[test] + fn test_from_raw_param_name_missing() { + let a = HeaderValue::from_static("inline; =\"test.txt\""); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; =diary.odt"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; ="); + assert!(ContentDisposition::from_raw(&a).is_err()); + } + + #[test] + fn test_display_extended() { + let as_string = + "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; let a = HeaderValue::from_static(as_string); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); - let display_rendered = format!("{}",a); + let display_rendered = format!("{}", a); assert_eq!(as_string, display_rendered); - let a = HeaderValue::from_static("attachment; filename*=UTF-8''black%20and%20white.csv"); - let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); - let display_rendered = format!("{}",a); - assert_eq!("attachment; filename=\"black and white.csv\"".to_owned(), display_rendered); - let a = HeaderValue::from_static("attachment; filename=colourful.csv"); let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); - let display_rendered = format!("{}",a); - assert_eq!("attachment; filename=\"colourful.csv\"".to_owned(), display_rendered); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"colourful.csv\"".to_owned(), + display_rendered + ); + } + + #[test] + fn test_display_quote() { + let as_string = "form-data; name=upload; filename=\"Quote\\\"here.png\""; + as_string + .find(['\\', '\"'].iter().collect::().as_str()) + .unwrap(); // ensure `\"` is there + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + } + + #[test] + fn test_display_space_tab() { + let as_string = "form-data; name=upload; filename=\"Space here.png\""; + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("Tab\there.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"Tab\x09here.png\"", display_rendered); + } + + #[test] + fn test_display_control_characters() { + /* let a = "attachment; filename=\"carriage\rreturn.png\""; + let a = HeaderValue::from_static(a); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"carriage\\\rreturn.png\"", + display_rendered + );*/ + // No way to create a HeaderValue containing a carriage return. + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("bell\x07.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"bell\\\x07.png\"", display_rendered); + } + + #[test] + fn test_param_methods() { + let param = DispositionParam::Filename(String::from("sample.txt")); + assert!(param.is_filename()); + assert_eq!(param.as_filename().unwrap(), "sample.txt"); + + let param = DispositionParam::Unknown(String::from("foo"), String::from("bar")); + assert!(param.is_unknown("foo")); + assert_eq!(param.as_unknown("fOo"), Some("bar")); + } + + #[test] + fn test_disposition_methods() { + let cd = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(cd.get_name(), Some("upload")); + assert_eq!(cd.get_unknown("dummy"), Some("3")); + assert_eq!(cd.get_filename(), Some("sample.png")); + assert_eq!(cd.get_unknown_ext("dummy"), None); + assert_eq!(cd.get_unknown("duMMy"), Some("3")); } } diff --git a/src/header/mod.rs b/src/header/mod.rs index 291bc6eac..74e4b03e5 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -223,8 +223,7 @@ pub fn from_comma_delimited( .filter_map(|x| match x.trim() { "" => None, y => Some(y), - }) - .filter_map(|x| x.trim().parse().ok()), + }).filter_map(|x| x.trim().parse().ok()), ) } Ok(result) @@ -263,8 +262,10 @@ where // From hyper v0.11.27 src/header/parsing.rs -/// An extended header parameter value (i.e., tagged with a character set and optionally, -/// a language), as defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// The value part of an extended parameter consisting of three parts: +/// the REQUIRED character set name (`charset`), the OPTIONAL language information (`language_tag`), +/// and a character sequence representing the actual value (`value`), separated by single quote +/// characters. It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). #[derive(Clone, Debug, PartialEq)] pub struct ExtendedValue { /// The character set that is used to encode the `value` to a string. diff --git a/src/helpers.rs b/src/helpers.rs index 400b12253..e82d61616 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -279,8 +279,7 @@ mod tests { true, StatusCode::MOVED_PERMANENTLY, )) - }) - .finish(); + }).finish(); // trailing slashes let params = vec![ diff --git a/src/httpmessage.rs b/src/httpmessage.rs index 5db2f075b..60f77b07e 100644 --- a/src/httpmessage.rs +++ b/src/httpmessage.rs @@ -479,8 +479,7 @@ where body.extend_from_slice(&chunk); Ok(body) } - }) - .map(|body| body.freeze()), + }).map(|body| body.freeze()), )); self.poll() } @@ -588,8 +587,7 @@ where body.extend_from_slice(&chunk); Ok(body) } - }) - .and_then(move |body| { + }).and_then(move |body| { if (encoding as *const Encoding) == UTF_8 { serde_urlencoded::from_bytes::(&body) .map_err(|_| UrlencodedError::Parse) @@ -694,8 +692,7 @@ mod tests { .header( header::TRANSFER_ENCODING, Bytes::from_static(b"some va\xadscc\xacas0xsdasdlue"), - ) - .finish(); + ).finish(); assert!(req.chunked().is_err()); } @@ -734,7 +731,7 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "xxxx") - .finish(); + .finish(); assert_eq!( req.urlencoded::().poll().err().unwrap(), UrlencodedError::UnknownLength @@ -744,7 +741,7 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "1000000") - .finish(); + .finish(); assert_eq!( req.urlencoded::().poll().err().unwrap(), UrlencodedError::Overflow @@ -765,8 +762,8 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded", ).header(header::CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world")) - .finish(); + .set_payload(Bytes::from_static(b"hello=world")) + .finish(); let result = req.urlencoded::().poll().ok().unwrap(); assert_eq!( @@ -780,8 +777,8 @@ mod tests { header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8", ).header(header::CONTENT_LENGTH, "11") - .set_payload(Bytes::from_static(b"hello=world")) - .finish(); + .set_payload(Bytes::from_static(b"hello=world")) + .finish(); let result = req.urlencoded().poll().ok().unwrap(); assert_eq!( @@ -830,8 +827,7 @@ mod tests { b"Lorem Ipsum is simply dummy text of the printing and typesetting\n\ industry. Lorem Ipsum has been the industry's standard dummy\n\ Contrary to popular belief, Lorem Ipsum is not simply random text.", - )) - .finish(); + )).finish(); let mut r = Readlines::new(&req); match r.poll().ok().unwrap() { Async::Ready(Some(s)) => assert_eq!( diff --git a/src/httprequest.rs b/src/httprequest.rs index 83017dfa0..0e4f74e5e 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -81,6 +81,15 @@ impl HttpRequest { } } + /// Construct new http request with empty state. + pub fn drop_state(&self) -> HttpRequest { + HttpRequest { + state: Rc::new(()), + req: self.req.as_ref().map(|r| r.clone()), + resource: self.resource.clone(), + } + } + #[inline] /// Construct new http request with new RouteInfo. pub(crate) fn with_route_info(&self, mut resource: ResourceInfo) -> HttpRequest { @@ -207,7 +216,7 @@ impl HttpRequest { self.url_for(name, &NO_PARAMS) } - /// This method returns reference to current `RouteInfo` object. + /// This method returns reference to current `ResourceInfo` object. #[inline] pub fn resource(&self) -> &ResourceInfo { &self.resource @@ -255,7 +264,8 @@ impl HttpRequest { if self.extensions().get::().is_none() { let mut cookies = Vec::new(); for hdr in self.request().inner.headers.get_all(header::COOKIE) { - let s = str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + let s = + str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; for cookie_str in s.split(';').map(|s| s.trim()) { if !cookie_str.is_empty() { cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); @@ -344,24 +354,24 @@ impl FromRequest for HttpRequest { impl fmt::Debug for HttpRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = writeln!( + writeln!( f, "\nHttpRequest {:?} {}:{}", self.version(), self.method(), self.path() - ); + )?; if !self.query_string().is_empty() { - let _ = writeln!(f, " query: ?{:?}", self.query_string()); + writeln!(f, " query: ?{:?}", self.query_string())?; } if !self.match_info().is_empty() { - let _ = writeln!(f, " params: {:?}", self.match_info()); + writeln!(f, " params: {:?}", self.match_info())?; } - let _ = writeln!(f, " headers:"); + writeln!(f, " headers:")?; for (key, val) in self.headers().iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); + writeln!(f, " {:?}: {:?}", key, val)?; } - res + Ok(()) } } @@ -420,7 +430,7 @@ mod tests { #[test] fn test_request_match_info() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/{key}/"))); let req = TestRequest::with_uri("/value/?id=test").finish(); @@ -430,7 +440,7 @@ mod tests { #[test] fn test_url_for() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); let mut resource = Resource::new(ResourceDef::new("/user/{name}.{ext}")); resource.name("index"); router.register_resource(resource); @@ -464,7 +474,8 @@ mod tests { fn test_url_for_with_prefix() { let mut resource = Resource::new(ResourceDef::new("/user/{name}.html")); resource.name("index"); - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); + router.set_prefix("/prefix"); router.register_resource(resource); let mut info = router.default_route_info(); @@ -490,7 +501,8 @@ mod tests { fn test_url_for_static() { let mut resource = Resource::new(ResourceDef::new("/index.html")); resource.name("index"); - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); + router.set_prefix("/prefix"); router.register_resource(resource); let mut info = router.default_route_info(); @@ -513,7 +525,7 @@ mod tests { #[test] fn test_url_for_external() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_external( "youtube", ResourceDef::external("https://youtube.com/watch/{video_id}"), diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 2673da2a3..52dd8046b 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -142,8 +142,7 @@ impl HttpResponse { HeaderValue::from_str(&cookie.to_string()) .map(|c| { h.append(header::SET_COOKIE, c); - }) - .map_err(|e| e.into()) + }).map_err(|e| e.into()) } /// Remove all cookies with the given name from this response. Returns @@ -273,7 +272,7 @@ impl HttpResponse { self.get_mut().response_size = size; } - /// Set write buffer capacity + /// Get write buffer capacity pub fn write_buffer_capacity(&self) -> usize { self.get_ref().write_capacity } @@ -650,7 +649,14 @@ impl HttpResponseBuilder { /// /// `HttpResponseBuilder` can not be used after this call. pub fn json(&mut self, value: T) -> HttpResponse { - match serde_json::to_string(&value) { + self.json2(&value) + } + + /// Set a json body and generate `HttpResponse` + /// + /// `HttpResponseBuilder` can not be used after this call. + pub fn json2(&mut self, value: &T) -> HttpResponse { + match serde_json::to_string(value) { Ok(body) => { let contains = if let Some(parts) = parts(&mut self.response, &self.err) { @@ -1072,8 +1078,7 @@ mod tests { .http_only(true) .max_age(Duration::days(1)) .finish(), - ) - .del_cookie(&cookies[0]) + ).del_cookie(&cookies[0]) .finish(); let mut val: Vec<_> = resp @@ -1186,6 +1191,30 @@ mod tests { ); } + #[test] + fn test_json2() { + let resp = HttpResponse::build(StatusCode::OK).json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_eq!( + *resp.body(), + Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) + ); + } + + #[test] + fn test_json2_ct() { + let resp = HttpResponse::build(StatusCode::OK) + .header(CONTENT_TYPE, "text/json") + .json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_eq!( + *resp.body(), + Body::from(Bytes::from_static(b"[\"v1\",\"v2\",\"v3\"]")) + ); + } + impl Body { pub(crate) fn bin_ref(&self) -> &Binary { match *self { diff --git a/src/info.rs b/src/info.rs index b15ba9886..43c22123e 100644 --- a/src/info.rs +++ b/src/info.rs @@ -16,7 +16,10 @@ pub struct ConnectionInfo { impl ConnectionInfo { /// Create *ConnectionInfo* instance for a request. - #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] + #[cfg_attr( + feature = "cargo-clippy", + allow(cyclomatic_complexity) + )] pub fn update(&mut self, req: &Request) { let mut host = None; let mut scheme = None; @@ -174,8 +177,7 @@ mod tests { .header( header::FORWARDED, "for=192.0.2.60; proto=https; by=203.0.113.43; host=rust-lang.org", - ) - .request(); + ).request(); let mut info = ConnectionInfo::default(); info.update(&req); diff --git a/src/json.rs b/src/json.rs index c76aeaa7d..178143f11 100644 --- a/src/json.rs +++ b/src/json.rs @@ -172,7 +172,7 @@ where /// let app = App::new().resource("/index.html", |r| { /// r.method(http::Method::POST) /// .with_config(index, |cfg| { -/// cfg.limit(4096) // <- change json extractor configuration +/// cfg.0.limit(4096) // <- change json extractor configuration /// .error_handler(|err, req| { // <- create custom error response /// error::InternalError::from_response( /// err, HttpResponse::Conflict().finish()).into() @@ -327,8 +327,7 @@ impl Future for JsonBod body.extend_from_slice(&chunk); Ok(body) } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); + }).and_then(|body| Ok(serde_json::from_slice::(&body)?)); self.fut = Some(Box::new(fut)); self.poll() } @@ -388,8 +387,7 @@ mod tests { .header( header::CONTENT_TYPE, header::HeaderValue::from_static("application/text"), - ) - .finish(); + ).finish(); let mut json = req.json::(); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); @@ -397,12 +395,10 @@ mod tests { .header( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + ).header( header::CONTENT_LENGTH, header::HeaderValue::from_static("10000"), - ) - .finish(); + ).finish(); let mut json = req.json::().limit(100); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow); @@ -410,12 +406,10 @@ mod tests { .header( header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), - ) - .header( + ).header( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + ).set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .finish(); let mut json = req.json::(); @@ -442,9 +436,8 @@ mod tests { ).header( header::CONTENT_LENGTH, header::HeaderValue::from_static("16"), - ) - .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) - .finish(); + ).set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) + .finish(); assert!(handler.handle(&req).as_err().is_none()) } } diff --git a/src/lib.rs b/src/lib.rs index 0ab4a1bef..f8326886f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,25 +64,24 @@ //! ## Package feature //! //! * `tls` - enables ssl support via `native-tls` crate -//! * `alpn` - enables ssl support via `openssl` crate, require for `http/2` -//! support +//! * `ssl` - enables ssl support via `openssl` crate, supports `http/2` +//! * `rust-tls` - enables ssl support via `rustls` crate, supports `http/2` +//! * `uds` - enables support for making client requests via Unix Domain Sockets. +//! Unix only. Not necessary for *serving* requests. //! * `session` - enables session support, includes `ring` crate as //! dependency //! * `brotli` - enables `brotli` compression support, requires `c` //! compiler -//! * `flate-c` - enables `gzip`, `deflate` compression support, requires +//! * `flate2-c` - enables `gzip`, `deflate` compression support, requires //! `c` compiler -//! * `flate-rust` - experimental rust based implementation for +//! * `flate2-rust` - experimental rust based implementation for //! `gzip`, `deflate` compression. //! #![cfg_attr(actix_nightly, feature( specialization, // for impl ErrorResponse for std::error::Error extern_prelude, + tool_lints, ))] -#![cfg_attr( - feature = "cargo-clippy", - allow(decimal_literal_representation, suspicious_arithmetic_impl) -)] #![warn(missing_docs)] #[macro_use] @@ -101,9 +100,9 @@ extern crate failure; extern crate lazy_static; #[macro_use] extern crate futures; +extern crate askama_escape; extern crate cookie; extern crate futures_cpupool; -extern crate htmlescape; extern crate http as modhttp; extern crate httparse; extern crate language_tags; @@ -116,10 +115,13 @@ extern crate parking_lot; extern crate rand; extern crate slab; extern crate tokio; +extern crate tokio_current_thread; extern crate tokio_io; extern crate tokio_reactor; extern crate tokio_tcp; extern crate tokio_timer; +#[cfg(all(unix, feature = "uds"))] +extern crate tokio_uds; extern crate url; #[macro_use] extern crate serde; @@ -130,10 +132,13 @@ extern crate encoding; extern crate flate2; extern crate h2 as http2; extern crate num_cpus; +extern crate serde_urlencoded; #[macro_use] extern crate percent_encoding; extern crate serde_json; extern crate smallvec; + +extern crate actix_net; #[macro_use] extern crate actix as actix_inner; @@ -151,6 +156,15 @@ extern crate openssl; #[cfg(feature = "openssl")] extern crate tokio_openssl; +#[cfg(feature = "rust-tls")] +extern crate rustls; +#[cfg(feature = "rust-tls")] +extern crate tokio_rustls; +#[cfg(feature = "rust-tls")] +extern crate webpki; +#[cfg(feature = "rust-tls")] +extern crate webpki_roots; + mod application; mod body; mod context; @@ -173,7 +187,6 @@ mod resource; mod route; mod router; mod scope; -mod serde_urlencoded; mod uri; mod with; @@ -224,6 +237,11 @@ pub(crate) const HAS_TLS: bool = true; #[cfg(not(feature = "tls"))] pub(crate) const HAS_TLS: bool = false; +#[cfg(feature = "rust-tls")] +pub(crate) const HAS_RUSTLS: bool = true; +#[cfg(not(feature = "rust-tls"))] +pub(crate) const HAS_RUSTLS: bool = false; + pub mod dev { //! The `actix-web` prelude for library developers //! @@ -237,14 +255,15 @@ pub mod dev { pub use body::BodyStream; pub use context::Drain; - pub use extractor::{FormConfig, PayloadConfig}; + pub use extractor::{FormConfig, PayloadConfig, QueryConfig, PathConfig}; pub use handler::{AsyncResult, Handler}; - pub use httpmessage::{MessageBody, UrlEncoded}; + pub use httpmessage::{MessageBody, Readlines, UrlEncoded}; pub use httpresponse::HttpResponseBuilder; pub use info::ConnectionInfo; pub use json::{JsonBody, JsonConfig}; pub use param::{FromParam, Params}; pub use payload::{Payload, PayloadBuffer}; + pub use pipeline::Pipeline; pub use resource::Resource; pub use route::Route; pub use router::{ResourceDef, ResourceInfo, ResourceType, Router}; @@ -266,7 +285,9 @@ pub mod http { /// Various http headers pub mod header { pub use header::*; - pub use header::{ContentDisposition, DispositionType, DispositionParam, Charset, LanguageTag}; + pub use header::{ + Charset, ContentDisposition, DispositionParam, DispositionType, LanguageTag, + }; } pub use header::ContentEncoding; pub use httpresponse::ConnectionType; diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 052e4da23..953f2911c 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -387,12 +387,10 @@ impl Middleware for Cors { header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str(), ); - }) - .if_some(headers, |headers, resp| { + }).if_some(headers, |headers, resp| { let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }) - .if_true(self.inner.origins.is_all(), |resp| { + }).if_true(self.inner.origins.is_all(), |resp| { if self.inner.send_wildcard { resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); } else { @@ -402,17 +400,14 @@ impl Middleware for Cors { origin.clone(), ); } - }) - .if_true(self.inner.origins.is_some(), |resp| { + }).if_true(self.inner.origins.is_some(), |resp| { resp.header( header::ACCESS_CONTROL_ALLOW_ORIGIN, self.inner.origins_str.as_ref().unwrap().clone(), ); - }) - .if_true(self.inner.supports_credentials, |resp| { + }).if_true(self.inner.supports_credentials, |resp| { resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }) - .header( + }).header( header::ACCESS_CONTROL_ALLOW_METHODS, &self .inner @@ -420,8 +415,7 @@ impl Middleware for Cors { .iter() .fold(String::new(), |s, v| s + "," + v.as_str()) .as_str()[1..], - ) - .finish(), + ).finish(), )) } else { // Only check requests with a origin header. @@ -832,15 +826,15 @@ impl CorsBuilder { if let AllOrSome::Some(ref origins) = cors.origins { let s = origins .iter() - .fold(String::new(), |s, v| s + &v.to_string()); - cors.origins_str = Some(HeaderValue::try_from(s.as_str()).unwrap()); + .fold(String::new(), |s, v| format!("{}, {}", s, v)); + cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); } if !self.expose_hdrs.is_empty() { cors.expose_hdrs = Some( self.expose_hdrs .iter() - .fold(String::new(), |s, v| s + v.as_str())[1..] + .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..] .to_owned(), ); } @@ -978,8 +972,7 @@ mod tests { .header( header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) + ).method(Method::OPTIONS) .finish(); let resp = cors.start(&req).unwrap().response(); @@ -1073,12 +1066,14 @@ mod tests { #[test] fn test_response() { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let cors = Cors::build() .send_wildcard() .disable_preflight() .max_age(3600) .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) .finish(); @@ -1100,6 +1095,22 @@ mod tests { resp.headers().get(header::VARY).unwrap().as_bytes() ); + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + let resp: HttpResponse = HttpResponse::Ok().header(header::VARY, "Accept").finish(); let resp = cors.response(&req, resp).unwrap().response(); @@ -1111,16 +1122,29 @@ mod tests { let cors = Cors::build() .disable_vary_header() .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") .finish(); let resp: HttpResponse = HttpResponse::Ok().into(); let resp = cors.response(&req, resp).unwrap().response(); - assert_eq!( - &b"https://www.example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + + let origins_str = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .to_str() + .unwrap(); + + if origins_str.starts_with("https://www.example.com") { + assert_eq!( + "https://www.example.com, https://www.google.com", + origins_str + ); + } else { + assert_eq!( + "https://www.google.com, https://www.example.com", + origins_str + ); + } } #[test] diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs index cda1d324c..cacfc8d53 100644 --- a/src/middleware/csrf.rs +++ b/src/middleware/csrf.rs @@ -76,7 +76,7 @@ impl ResponseError for CsrfError { } fn uri_origin(uri: &Uri) -> Option { - match (uri.scheme_part(), uri.host(), uri.port()) { + match (uri.scheme_part(), uri.host(), uri.port_part().map(|port| port.as_u16())) { (Some(scheme), Some(host), Some(port)) => { Some(format!("{}://{}:{}", scheme, host, port)) } @@ -93,8 +93,7 @@ fn origin(headers: &HeaderMap) -> Option, CsrfError>> { .to_str() .map_err(|_| CsrfError::BadOrigin) .map(|o| o.into()) - }) - .or_else(|| { + }).or_else(|| { headers.get(header::REFERER).map(|referer| { Uri::try_from(Bytes::from(referer.as_bytes())) .ok() @@ -251,7 +250,7 @@ mod tests { "Referer", "https://www.example.com/some/path?query=param", ).method(Method::POST) - .finish(); + .finish(); assert!(csrf.start(&req).is_ok()); } diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 83c66aae1..c7d19d334 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -131,7 +131,7 @@ mod tests { ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500), ).middleware(MiddlewareOne) - .handler(|_| HttpResponse::Ok()) + .handler(|_| HttpResponse::Ok()) }); let request = srv.get().finish().unwrap(); diff --git a/src/middleware/identity.rs b/src/middleware/identity.rs index d890bebef..a664ba1f0 100644 --- a/src/middleware/identity.rs +++ b/src/middleware/identity.rs @@ -48,7 +48,7 @@ //! ``` use std::rc::Rc; -use cookie::{Cookie, CookieJar, Key}; +use cookie::{Cookie, CookieJar, Key, SameSite}; use futures::future::{err as FutErr, ok as FutOk, FutureResult}; use futures::Future; use time::Duration; @@ -237,6 +237,7 @@ struct CookieIdentityInner { domain: Option, secure: bool, max_age: Option, + same_site: Option, } impl CookieIdentityInner { @@ -248,6 +249,7 @@ impl CookieIdentityInner { domain: None, secure: true, max_age: None, + same_site: None, } } @@ -268,6 +270,10 @@ impl CookieIdentityInner { cookie.set_max_age(max_age); } + if let Some(same_site) = self.same_site { + cookie.set_same_site(same_site); + } + let mut jar = CookieJar::new(); if some { jar.private(&self.key).add(cookie); @@ -370,6 +376,12 @@ impl CookieIdentityPolicy { Rc::get_mut(&mut self.0).unwrap().max_age = Some(value); self } + + /// Sets the `same_site` field in the session cookie being built. + pub fn same_site(mut self, same_site: SameSite) -> Self { + Rc::get_mut(&mut self.0).unwrap().same_site = Some(same_site); + self + } } impl IdentityPolicy for CookieIdentityPolicy { diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 103cbf373..b7bb1bb80 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -25,7 +25,7 @@ use middleware::{Finished, Middleware, Started}; /// default format: /// /// ```ignore -/// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T +/// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// ``` /// ```rust /// # extern crate actix_web; @@ -94,7 +94,7 @@ impl Default for Logger { /// Create `Logger` middleware with format: /// /// ```ignore - /// %a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T + /// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T /// ``` fn default() -> Logger { Logger { @@ -143,7 +143,7 @@ struct Format(Vec); impl Default for Format { /// Return the default formatting style for the `Logger`: fn default() -> Format { - Format::new(r#"%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#) + Format::new(r#"%a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#) } } diff --git a/src/middleware/session.rs b/src/middleware/session.rs index cc7aab6b4..e8b0e5558 100644 --- a/src/middleware/session.rs +++ b/src/middleware/session.rs @@ -270,14 +270,17 @@ impl> Middleware for SessionStorage { } /// A simple key-value storage interface that is internally used by `Session`. -#[doc(hidden)] pub trait SessionImpl: 'static { + /// Get session value by key fn get(&self, key: &str) -> Option<&str>; + /// Set session value fn set(&mut self, key: &str, value: String); + /// Remove specific key from session fn remove(&mut self, key: &str); + /// Remove all values from session fn clear(&mut self); /// Write session to storage backend. @@ -285,9 +288,10 @@ pub trait SessionImpl: 'static { } /// Session's storage backend trait definition. -#[doc(hidden)] pub trait SessionBackend: Sized + 'static { + /// Session item type Session: SessionImpl; + /// Future that reads session type ReadFuture: Future; /// Parse the session from request and load data from a storage backend. @@ -579,8 +583,7 @@ mod tests { App::new() .middleware(SessionStorage::new( CookieSessionBackend::signed(&[0; 32]).secure(false), - )) - .resource("/", |r| { + )).resource("/", |r| { r.f(|req| { let _ = req.session().set("counter", 100); "test" @@ -599,8 +602,7 @@ mod tests { App::new() .middleware(SessionStorage::new( CookieSessionBackend::signed(&[0; 32]).secure(false), - )) - .resource("/", |r| { + )).resource("/", |r| { r.with(|ses: Session| { let _ = ses.set("counter", 100); "test" diff --git a/src/multipart.rs b/src/multipart.rs index d4b6059f2..862f60ecb 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -441,13 +441,13 @@ where impl fmt::Debug for Field { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = writeln!(f, "\nMultipartField: {}", self.ct); - let _ = writeln!(f, " boundary: {}", self.inner.borrow().boundary); - let _ = writeln!(f, " headers:"); + writeln!(f, "\nMultipartField: {}", self.ct)?; + writeln!(f, " boundary: {}", self.inner.borrow().boundary)?; + writeln!(f, " headers:")?; for (key, val) in self.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); + writeln!(f, " {:?}: {:?}", key, val)?; } - res + Ok(()) } } @@ -756,13 +756,10 @@ mod tests { { use http::header::{DispositionParam, DispositionType}; let cd = field.content_disposition().unwrap(); - assert_eq!( - cd.disposition, - DispositionType::Ext("form-data".into()) - ); + assert_eq!(cd.disposition, DispositionType::FormData); assert_eq!( cd.parameters[0], - DispositionParam::Ext("name".into(), "file".into()) + DispositionParam::Name("file".into()) ); } assert_eq!(field.content_type().type_(), mime::TEXT); @@ -813,7 +810,6 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } } diff --git a/src/param.rs b/src/param.rs index 2704b60d0..a3f602599 100644 --- a/src/param.rs +++ b/src/param.rs @@ -8,7 +8,7 @@ use http::StatusCode; use smallvec::SmallVec; use error::{InternalError, ResponseError, UriSegmentError}; -use uri::Url; +use uri::{Url, RESERVED_QUOTER}; /// A trait to abstract the idea of creating a new instance of a type from a /// path parameter. @@ -103,6 +103,17 @@ impl Params { } } + /// Get URL-decoded matched parameter by name without type conversion + pub fn get_decoded(&self, key: &str) -> Option { + self.get(key).map(|value| { + if let Some(ref mut value) = RESERVED_QUOTER.requote(value.as_bytes()) { + Rc::make_mut(value).to_string() + } else { + value.to_string() + } + }) + } + /// Get unprocessed part of path pub fn unprocessed(&self) -> &str { &self.url.path()[(self.tail as usize)..] @@ -236,7 +247,6 @@ macro_rules! FROM_STR { ($type:ty) => { impl FromParam for $type { type Err = InternalError<<$type as FromStr>::Err>; - fn from_param(val: &str) -> Result { <$type as FromStr>::from_str(val) .map_err(|e| InternalError::new(e, StatusCode::BAD_REQUEST)) @@ -301,4 +311,24 @@ mod tests { Ok(PathBuf::from_iter(vec!["seg2"])) ); } + + #[test] + fn test_get_param_by_name() { + let mut params = Params::new(); + params.add_static("item1", "path"); + params.add_static("item2", "http%3A%2F%2Flocalhost%3A80%2Ffoo"); + + assert_eq!(params.get("item0"), None); + assert_eq!(params.get_decoded("item0"), None); + assert_eq!(params.get("item1"), Some("path")); + assert_eq!(params.get_decoded("item1"), Some("path".to_string())); + assert_eq!( + params.get("item2"), + Some("http%3A%2F%2Flocalhost%3A80%2Ffoo") + ); + assert_eq!( + params.get_decoded("item2"), + Some("http://localhost:80/foo".to_string()) + ); + } } diff --git a/src/payload.rs b/src/payload.rs index b20bec652..2131e3c3c 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -1,6 +1,8 @@ //! Payload stream use bytes::{Bytes, BytesMut}; -use futures::task::{current as current_task, Task}; +#[cfg(not(test))] +use futures::task::current as current_task; +use futures::task::Task; use futures::{Async, Poll, Stream}; use std::cell::RefCell; use std::cmp; @@ -513,8 +515,7 @@ where .fold(BytesMut::new(), |mut b, c| { b.extend_from_slice(c); b - }) - .freeze() + }).freeze() } } @@ -553,8 +554,7 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -578,8 +578,7 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -596,8 +595,7 @@ mod tests { payload.readany().err().unwrap(); let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -625,8 +623,7 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -659,8 +656,7 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -693,8 +689,7 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } #[test] @@ -715,7 +710,6 @@ mod tests { let res: Result<(), ()> = Ok(()); result(res) - })) - .unwrap(); + })).unwrap(); } } diff --git a/src/pipeline.rs b/src/pipeline.rs index dbe9e58ad..a938f2eb2 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -42,13 +42,6 @@ enum PipelineState { } impl> PipelineState { - fn is_response(&self) -> bool { - match *self { - PipelineState::Response(_) => true, - _ => false, - } - } - fn poll( &mut self, info: &mut PipelineInfo, mws: &[Box>], ) -> Option> { @@ -58,9 +51,8 @@ impl> PipelineState { PipelineState::RunMiddlewares(ref mut state) => state.poll(info, mws), PipelineState::Finishing(ref mut state) => state.poll(info, mws), PipelineState::Completed(ref mut state) => state.poll(info), - PipelineState::Response(_) | PipelineState::None | PipelineState::Error => { - None - } + PipelineState::Response(ref mut state) => state.poll(info, mws), + PipelineState::None | PipelineState::Error => None, } } } @@ -89,7 +81,7 @@ impl PipelineInfo { } impl> Pipeline { - pub fn new( + pub(crate) fn new( req: HttpRequest, mws: Rc>>>, handler: Rc, ) -> Pipeline { let mut info = PipelineInfo { @@ -130,22 +122,20 @@ impl> HttpHandlerTask for Pipeline { let mut state = mem::replace(&mut self.1, PipelineState::None); loop { - if state.is_response() { - if let PipelineState::Response(st) = state { - match st.poll_io(io, &mut self.0, &self.2) { - Ok(state) => { - self.1 = state; - if let Some(error) = self.0.error.take() { - return Err(error); - } else { - return Ok(Async::Ready(self.is_done())); - } - } - Err(state) => { - self.1 = state; - return Ok(Async::NotReady); + if let PipelineState::Response(st) = state { + match st.poll_io(io, &mut self.0, &self.2) { + Ok(state) => { + self.1 = state; + if let Some(error) = self.0.error.take() { + return Err(error); + } else { + return Ok(Async::Ready(self.is_done())); } } + Err(state) => { + self.1 = state; + return Ok(Async::NotReady); + } } } match state { @@ -401,7 +391,7 @@ impl RunMiddlewares { } struct ProcessResponse { - resp: HttpResponse, + resp: Option, iostate: IOState, running: RunningState, drain: Option>, @@ -409,7 +399,7 @@ struct ProcessResponse { _h: PhantomData, } -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] enum RunningState { Running, Paused, @@ -442,7 +432,7 @@ impl ProcessResponse { #[inline] fn init(resp: HttpResponse) -> PipelineState { PipelineState::Response(ProcessResponse { - resp, + resp: Some(resp), iostate: IOState::Response, running: RunningState::Running, drain: None, @@ -451,6 +441,79 @@ impl ProcessResponse { }) } + fn poll( + &mut self, info: &mut PipelineInfo, mws: &[Box>], + ) -> Option> { + // connection is dead at this point + match mem::replace(&mut self.iostate, IOState::Done) { + IOState::Response => Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )), + IOState::Payload(_) => Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )), + IOState::Actor(mut ctx) => { + if info.disconnected.take().is_some() { + ctx.disconnected(); + } + loop { + match ctx.poll() { + Ok(Async::Ready(Some(vec))) => { + if vec.is_empty() { + continue; + } + for frame in vec { + match frame { + Frame::Chunk(None) => { + info.context = Some(ctx); + return Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )); + } + Frame::Chunk(Some(_)) => (), + Frame::Drain(fut) => { + let _ = fut.send(()); + } + } + } + } + Ok(Async::Ready(None)) => { + return Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )) + } + Ok(Async::NotReady) => { + self.iostate = IOState::Actor(ctx); + return None; + } + Err(err) => { + info.context = Some(ctx); + info.error = Some(err); + return Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )); + } + } + } + } + IOState::Done => Some(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )), + } + } + fn poll_io( mut self, io: &mut Writer, info: &mut PipelineInfo, mws: &[Box>], @@ -461,29 +524,39 @@ impl ProcessResponse { 'inner: loop { let result = match mem::replace(&mut self.iostate, IOState::Done) { IOState::Response => { - let encoding = - self.resp.content_encoding().unwrap_or(info.encoding); + let encoding = self + .resp + .as_ref() + .unwrap() + .content_encoding() + .unwrap_or(info.encoding); - let result = - match io.start(&info.req, &mut self.resp, encoding) { - Ok(res) => res, - Err(err) => { - info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init( - info, mws, self.resp, - )); - } - }; + let result = match io.start( + &info.req, + self.resp.as_mut().unwrap(), + encoding, + ) { + Ok(res) => res, + Err(err) => { + info.error = Some(err.into()); + return Ok(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )); + } + }; - if let Some(err) = self.resp.error() { - if self.resp.status().is_server_error() { + if let Some(err) = self.resp.as_ref().unwrap().error() { + if self.resp.as_ref().unwrap().status().is_server_error() + { error!( - "Error occured during request handling, status: {} {}", - self.resp.status(), err + "Error occurred during request handling, status: {} {}", + self.resp.as_ref().unwrap().status(), err ); } else { warn!( - "Error occured during request handling: {}", + "Error occurred during request handling: {}", err ); } @@ -493,7 +566,7 @@ impl ProcessResponse { } // always poll stream or actor for the first time - match self.resp.replace_body(Body::Empty) { + match self.resp.as_mut().unwrap().replace_body(Body::Empty) { Body::Streaming(stream) => { self.iostate = IOState::Payload(stream); continue 'inner; @@ -512,7 +585,9 @@ impl ProcessResponse { if let Err(err) = io.write_eof() { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, + mws, + self.resp.take().unwrap(), )); } break; @@ -523,7 +598,9 @@ impl ProcessResponse { Err(err) => { info.error = Some(err.into()); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, + mws, + self.resp.take().unwrap(), )); } Ok(result) => result, @@ -536,7 +613,9 @@ impl ProcessResponse { Err(err) => { info.error = Some(err); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, + mws, + self.resp.take().unwrap(), )); } }, @@ -559,26 +638,30 @@ impl ProcessResponse { info.error = Some(err.into()); return Ok( FinishingMiddlewares::init( - info, mws, self.resp, + info, + mws, + self.resp.take().unwrap(), ), ); } break 'inner; } - Frame::Chunk(Some(chunk)) => { - match io.write(&chunk) { - Err(err) => { - info.context = Some(ctx); - info.error = Some(err.into()); - return Ok( - FinishingMiddlewares::init( - info, mws, self.resp, - ), - ); - } - Ok(result) => res = Some(result), + Frame::Chunk(Some(chunk)) => match io + .write(&chunk) + { + Err(err) => { + info.context = Some(ctx); + info.error = Some(err.into()); + return Ok( + FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + ), + ); } - } + Ok(result) => res = Some(result), + }, Frame::Drain(fut) => self.drain = Some(fut), } } @@ -598,7 +681,9 @@ impl ProcessResponse { info.context = Some(ctx); info.error = Some(err); return Ok(FinishingMiddlewares::init( - info, mws, self.resp, + info, + mws, + self.resp.take().unwrap(), )); } } @@ -638,7 +723,11 @@ impl ProcessResponse { info.context = Some(ctx); } info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, mws, self.resp)); + return Ok(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )); } } } @@ -652,11 +741,19 @@ impl ProcessResponse { Ok(_) => (), Err(err) => { info.error = Some(err.into()); - return Ok(FinishingMiddlewares::init(info, mws, self.resp)); + return Ok(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )); } } - self.resp.set_response_size(io.written()); - Ok(FinishingMiddlewares::init(info, mws, self.resp)) + self.resp.as_mut().unwrap().set_response_size(io.written()); + Ok(FinishingMiddlewares::init( + info, + mws, + self.resp.take().unwrap(), + )) } _ => Err(PipelineState::Response(self)), } diff --git a/src/pred.rs b/src/pred.rs index 22f12ac2a..99d6e608b 100644 --- a/src/pred.rs +++ b/src/pred.rs @@ -264,8 +264,7 @@ mod tests { .header( header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), - ) - .finish(); + ).finish(); let pred = Host("www.rust-lang.org"); assert!(pred.check(&req, req.state())); diff --git a/src/resource.rs b/src/resource.rs index 1bf8d88fa..d884dd447 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -13,6 +13,7 @@ use middleware::Middleware; use pred; use route::Route; use router::ResourceDef; +use with::WithFactory; #[derive(Copy, Clone)] pub(crate) struct RouteId(usize); @@ -217,7 +218,7 @@ impl Resource { /// ``` pub fn with(&mut self, handler: F) where - F: Fn(T) -> R + 'static, + F: WithFactory, R: Responder + 'static, T: FromRequest + 'static, { diff --git a/src/route.rs b/src/route.rs index d383d90be..884a367ed 100644 --- a/src/route.rs +++ b/src/route.rs @@ -16,7 +16,7 @@ use middleware::{ Started as MiddlewareStarted, }; use pred::Predicate; -use with::{With, WithAsync}; +use with::{WithAsyncFactory, WithFactory}; /// Resource route definition /// @@ -57,7 +57,7 @@ impl Route { pub(crate) fn compose( &self, req: HttpRequest, mws: Rc>>>, ) -> AsyncResult { - AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone()))) + AsyncResult::future(Box::new(Compose::new(req, mws, self.handler.clone()))) } /// Add match predicate to route. @@ -134,8 +134,7 @@ impl Route { /// } /// ``` /// - /// It is possible to use tuples for specifing multiple extractors for one - /// handler function. + /// It is possible to use multiple extractors for one handler function. /// /// ```rust /// # extern crate bytes; @@ -152,9 +151,9 @@ impl Route { /// /// /// extract path info using serde /// fn index( - /// info: (Path, Query>, Json), + /// path: Path, query: Query>, body: Json, /// ) -> Result { - /// Ok(format!("Welcome {}!", info.0.username)) + /// Ok(format!("Welcome {}!", path.username)) /// } /// /// fn main() { @@ -166,15 +165,15 @@ impl Route { /// ``` pub fn with(&mut self, handler: F) where - F: Fn(T) -> R + 'static, + F: WithFactory + 'static, R: Responder + 'static, T: FromRequest + 'static, { - self.h(With::new(handler, ::default())); + self.h(handler.create()); } /// Set handler function. Same as `.with()` but it allows to configure - /// extractor. + /// extractor. Configuration closure accepts config objects as tuple. /// /// ```rust /// # extern crate bytes; @@ -192,21 +191,21 @@ impl Route { /// let app = App::new().resource("/index.html", |r| { /// r.method(http::Method::GET) /// .with_config(index, |cfg| { // <- register handler - /// cfg.limit(4096); // <- limit size of the payload + /// cfg.0.limit(4096); // <- limit size of the payload /// }) /// }); /// } /// ``` pub fn with_config(&mut self, handler: F, cfg_f: C) where - F: Fn(T) -> R + 'static, + F: WithFactory, R: Responder + 'static, T: FromRequest + 'static, C: FnOnce(&mut T::Config), { let mut cfg = ::default(); cfg_f(&mut cfg); - self.h(With::new(handler, cfg)); + self.h(handler.create_with_config(cfg)); } /// Set async handler function, use request extractor for parameters. @@ -240,17 +239,18 @@ impl Route { /// ``` pub fn with_async(&mut self, handler: F) where - F: Fn(T) -> R + 'static, + F: WithAsyncFactory, R: Future + 'static, I: Responder + 'static, E: Into + 'static, T: FromRequest + 'static, { - self.h(WithAsync::new(handler, ::default())); + self.h(handler.create()); } /// Set async handler function, use request extractor for parameters. - /// This method allows to configure extractor. + /// This method allows to configure extractor. Configuration closure + /// accepts config objects as tuple. /// /// ```rust /// # extern crate bytes; @@ -275,14 +275,14 @@ impl Route { /// "/{username}/index.html", // <- define path parameters /// |r| r.method(http::Method::GET) /// .with_async_config(index, |cfg| { - /// cfg.limit(4096); + /// cfg.0.limit(4096); /// }), /// ); // <- use `with` extractor /// } /// ``` pub fn with_async_config(&mut self, handler: F, cfg: C) where - F: Fn(T) -> R + 'static, + F: WithAsyncFactory, R: Future + 'static, I: Responder + 'static, E: Into + 'static, @@ -291,7 +291,7 @@ impl Route { { let mut extractor_cfg = ::default(); cfg(&mut extractor_cfg); - self.h(WithAsync::new(handler, extractor_cfg)); + self.h(handler.create_with_config(extractor_cfg)); } } diff --git a/src/router.rs b/src/router.rs index e79dc93da..aa15e46d2 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::cmp::min; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -16,6 +17,7 @@ use pred::Predicate; use resource::{DefaultResource, Resource}; use scope::Scope; use server::Request; +use with::WithFactory; #[derive(Debug, Copy, Clone, PartialEq)] pub(crate) enum ResourceId { @@ -111,9 +113,14 @@ impl ResourceInfo { U: IntoIterator, I: AsRef, { - if let Some(pattern) = self.rmap.named.get(name) { - let path = - pattern.resource_path(elements, &req.path()[..(self.prefix as usize)])?; + let mut path = String::new(); + let mut elements = elements.into_iter(); + + if self + .rmap + .patterns_for(name, &mut path, &mut elements)? + .is_some() + { if path.starts_with('/') { let conn = req.connection_info(); Ok(Url::parse(&format!( @@ -160,12 +167,15 @@ impl ResourceInfo { } pub(crate) struct ResourceMap { + root: ResourceDef, + parent: RefCell>>, named: HashMap, patterns: Vec<(ResourceDef, Option>)>, + nested: Vec>, } impl ResourceMap { - pub fn has_resource(&self, path: &str) -> bool { + fn has_resource(&self, path: &str) -> bool { let path = if path.is_empty() { "/" } else { path }; for (pattern, rmap) in &self.patterns { @@ -179,20 +189,91 @@ impl ResourceMap { } false } + + fn patterns_for( + &self, name: &str, path: &mut String, elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if self.pattern_for(name, path, elements)?.is_some() { + Ok(Some(())) + } else { + self.parent_pattern_for(name, path, elements) + } + } + + fn pattern_for( + &self, name: &str, path: &mut String, elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(pattern) = self.named.get(name) { + self.fill_root(path, elements)?; + pattern.resource_path(path, elements)?; + Ok(Some(())) + } else { + for rmap in &self.nested { + if rmap.pattern_for(name, path, elements)?.is_some() { + return Ok(Some(())); + } + } + Ok(None) + } + } + + fn fill_root( + &self, path: &mut String, elements: &mut U, + ) -> Result<(), UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(ref parent) = *self.parent.borrow() { + parent.fill_root(path, elements)?; + } + self.root.resource_path(path, elements) + } + + fn parent_pattern_for( + &self, name: &str, path: &mut String, elements: &mut U, + ) -> Result, UrlGenerationError> + where + U: Iterator, + I: AsRef, + { + if let Some(ref parent) = *self.parent.borrow() { + if let Some(pattern) = parent.named.get(name) { + self.fill_root(path, elements)?; + pattern.resource_path(path, elements)?; + Ok(Some(())) + } else { + parent.parent_pattern_for(name, path, elements) + } + } else { + Ok(None) + } + } } impl Default for Router { fn default() -> Self { - Router::new() + Router::new(ResourceDef::new("")) } } impl Router { - pub(crate) fn new() -> Self { + pub(crate) fn new(root: ResourceDef) -> Self { Router { rmap: Rc::new(ResourceMap { + root, + parent: RefCell::new(None), named: HashMap::new(), patterns: Vec::new(), + nested: Vec::new(), }), resources: Vec::new(), patterns: Vec::new(), @@ -210,19 +291,6 @@ impl Router { } } - #[cfg(test)] - pub(crate) fn route_info(&self, req: &Request, prefix: u16) -> ResourceInfo { - let mut params = Params::with_url(req.url()); - params.set_tail(prefix); - - ResourceInfo { - params, - prefix: 0, - rmap: self.rmap.clone(), - resource: ResourceId::Default, - } - } - #[cfg(test)] pub(crate) fn default_route_info(&self) -> ResourceInfo { ResourceInfo { @@ -233,6 +301,10 @@ impl Router { } } + pub(crate) fn set_prefix(&mut self, path: &str) { + Rc::get_mut(&mut self.rmap).unwrap().root = ResourceDef::new(path); + } + pub(crate) fn register_resource(&mut self, resource: Resource) { { let rmap = Rc::get_mut(&mut self.rmap).unwrap(); @@ -258,6 +330,11 @@ impl Router { .unwrap() .patterns .push((scope.rdef().clone(), Some(scope.router().rmap.clone()))); + Rc::get_mut(&mut self.rmap) + .unwrap() + .nested + .push(scope.router().rmap.clone()); + let filters = scope.take_filters(); self.patterns .push(ResourcePattern::Scope(scope.rdef().clone(), filters)); @@ -286,22 +363,25 @@ impl Router { } pub(crate) fn finish(&mut self) { - if let Some(ref default) = self.default { - for resource in &mut self.resources { - match resource { - ResourceItem::Resource(_) => (), - ResourceItem::Scope(scope) => { - if !scope.has_default_resource() { + for resource in &mut self.resources { + match resource { + ResourceItem::Resource(_) => (), + ResourceItem::Scope(scope) => { + if !scope.has_default_resource() { + if let Some(ref default) = self.default { scope.default_resource(default.clone()); } - scope.finish() } - ResourceItem::Handler(hnd) => { - if !hnd.has_default_resource() { + *scope.router().rmap.parent.borrow_mut() = Some(self.rmap.clone()); + scope.finish(); + } + ResourceItem::Handler(hnd) => { + if !hnd.has_default_resource() { + if let Some(ref default) = self.default { hnd.default_resource(default.clone()); } - hnd.finish() } + hnd.finish() } } } @@ -319,7 +399,7 @@ impl Router { pub(crate) fn register_route(&mut self, path: &str, method: Method, f: F) where - F: Fn(T) -> R + 'static, + F: WithFactory, R: Responder + 'static, T: FromRequest + 'static, { @@ -459,35 +539,38 @@ pub struct ResourceDef { } impl ResourceDef { - /// Parse path pattern and create new `Resource` instance. + /// Parse path pattern and create new `ResourceDef` instance. /// /// Panics if path pattern is wrong. pub fn new(path: &str) -> Self { - ResourceDef::with_prefix(path, "/", false) + ResourceDef::with_prefix(path, false, !path.is_empty()) } - /// Parse path pattern and create new `Resource` instance. + /// Parse path pattern and create new `ResourceDef` instance. /// /// Use `prefix` type instead of `static`. /// /// Panics if path regex pattern is wrong. pub fn prefix(path: &str) -> Self { - ResourceDef::with_prefix(path, "/", true) + ResourceDef::with_prefix(path, true, !path.is_empty()) } - /// Construct external resource + /// Construct external resource def /// /// Panics if path pattern is wrong. pub fn external(path: &str) -> Self { - let mut resource = ResourceDef::with_prefix(path, "/", false); + let mut resource = ResourceDef::with_prefix(path, false, false); resource.rtp = ResourceType::External; resource } - /// Parse path pattern and create new `Resource` instance with custom prefix - pub fn with_prefix(path: &str, prefix: &str, for_prefix: bool) -> Self { - let (pattern, elements, is_dynamic, len) = - ResourceDef::parse(path, prefix, for_prefix); + /// Parse path pattern and create new `ResourceDef` instance with custom prefix + pub fn with_prefix(path: &str, for_prefix: bool, slash: bool) -> Self { + let mut path = path.to_owned(); + if slash && !path.starts_with('/') { + path.insert(0, '/'); + } + let (pattern, elements, is_dynamic, len) = ResourceDef::parse(&path, for_prefix); let tp = if is_dynamic { let re = match Regex::new(&pattern) { @@ -705,23 +788,21 @@ impl ResourceDef { /// Build resource path. pub fn resource_path( - &self, elements: U, prefix: &str, - ) -> Result + &self, path: &mut String, elements: &mut U, + ) -> Result<(), UrlGenerationError> where - U: IntoIterator, + U: Iterator, I: AsRef, { - let mut path = match self.tp { - PatternType::Prefix(ref p) => p.to_owned(), - PatternType::Static(ref p) => p.to_owned(), + match self.tp { + PatternType::Prefix(ref p) => path.push_str(p), + PatternType::Static(ref p) => path.push_str(p), PatternType::Dynamic(..) => { - let mut path = String::new(); - let mut iter = elements.into_iter(); for el in &self.elements { match *el { PatternElement::Str(ref s) => path.push_str(s), PatternElement::Var(_) => { - if let Some(val) = iter.next() { + if let Some(val) = elements.next() { path.push_str(val.as_ref()) } else { return Err(UrlGenerationError::NotEnoughElements); @@ -729,99 +810,75 @@ impl ResourceDef { } } } - path } }; + Ok(()) + } - if self.rtp != ResourceType::External { - if prefix.ends_with('/') { - if path.starts_with('/') { - path.insert_str(0, &prefix[..prefix.len() - 1]); - } else { - path.insert_str(0, prefix); + fn parse_param(pattern: &str) -> (PatternElement, String, &str) { + const DEFAULT_PATTERN: &str = "[^/]+"; + let mut params_nesting = 0usize; + let close_idx = pattern + .find(|c| match c { + '{' => { + params_nesting += 1; + false } - } else { - if !path.starts_with('/') { - path.insert(0, '/'); + '}' => { + params_nesting -= 1; + params_nesting == 0 } - path.insert_str(0, prefix); + _ => false, + }).expect("malformed param"); + let (mut param, rem) = pattern.split_at(close_idx + 1); + param = ¶m[1..param.len() - 1]; // Remove outer brackets + let (name, pattern) = match param.find(':') { + Some(idx) => { + let (name, pattern) = param.split_at(idx); + (name, &pattern[1..]) } - } - Ok(path) + None => (param, DEFAULT_PATTERN), + }; + ( + PatternElement::Var(name.to_string()), + format!(r"(?P<{}>{})", &name, &pattern), + rem, + ) } fn parse( - pattern: &str, prefix: &str, for_prefix: bool, + mut pattern: &str, for_prefix: bool, ) -> (String, Vec, bool, usize) { - const DEFAULT_PATTERN: &str = "[^/]+"; - - let mut re1 = String::from("^") + prefix; - let mut re2 = String::from(prefix); - let mut el = String::new(); - let mut in_param = false; - let mut in_param_pattern = false; - let mut param_name = String::new(); - let mut param_pattern = String::from(DEFAULT_PATTERN); - let mut is_dynamic = false; - let mut elems = Vec::new(); - let mut len = 0; - - for (index, ch) in pattern.chars().enumerate() { - // All routes must have a leading slash so its optional to have one - if index == 0 && ch == '/' { - continue; - } - - if in_param { - // In parameter segment: `{....}` - if ch == '}' { - elems.push(PatternElement::Var(param_name.clone())); - re1.push_str(&format!(r"(?P<{}>{})", ¶m_name, ¶m_pattern)); - - param_name.clear(); - param_pattern = String::from(DEFAULT_PATTERN); - - len = 0; - in_param_pattern = false; - in_param = false; - } else if ch == ':' { - // The parameter name has been determined; custom pattern land - in_param_pattern = true; - param_pattern.clear(); - } else if in_param_pattern { - // Ignore leading whitespace for pattern - if !(ch == ' ' && param_pattern.is_empty()) { - param_pattern.push(ch); - } - } else { - param_name.push(ch); - } - } else if ch == '{' { - in_param = true; - is_dynamic = true; - elems.push(PatternElement::Str(el.clone())); - el.clear(); - } else { - re1.push_str(escape(&ch.to_string()).as_str()); - re2.push(ch); - el.push(ch); - len += 1; - } - } - - if !el.is_empty() { - elems.push(PatternElement::Str(el.clone())); - } - - let re = if is_dynamic { - if !for_prefix { - re1.push('$'); - } - re1 - } else { - re2 + if pattern.find('{').is_none() { + return ( + String::from(pattern), + vec![PatternElement::Str(String::from(pattern))], + false, + pattern.chars().count(), + ); }; - (re, elems, is_dynamic, len) + + let mut elems = Vec::new(); + let mut re = String::from("^"); + + while let Some(idx) = pattern.find('{') { + let (prefix, rem) = pattern.split_at(idx); + elems.push(PatternElement::Str(String::from(prefix))); + re.push_str(&escape(prefix)); + let (param_pattern, re_part, rem) = Self::parse_param(rem); + elems.push(param_pattern); + re.push_str(&re_part); + pattern = rem; + } + + elems.push(PatternElement::Str(String::from(pattern))); + re.push_str(&escape(pattern)); + + if !for_prefix { + re.push_str("$"); + } + + (re, elems, true, pattern.chars().count()) } } @@ -846,7 +903,7 @@ mod tests { #[test] fn test_recognizer10() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/name"))); router.register_resource(Resource::new(ResourceDef::new("/name/{val}"))); router.register_resource(Resource::new(ResourceDef::new( @@ -858,7 +915,7 @@ mod tests { ))); router.register_resource(Resource::new(ResourceDef::new("/v/{tail:.*}"))); router.register_resource(Resource::new(ResourceDef::new("/test2/{test}.html"))); - router.register_resource(Resource::new(ResourceDef::new("{test}/index.html"))); + router.register_resource(Resource::new(ResourceDef::new("/{test}/index.html"))); let req = TestRequest::with_uri("/name").finish(); let info = router.recognize(&req, &(), 0); @@ -909,7 +966,7 @@ mod tests { #[test] fn test_recognizer_2() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/index.json"))); router.register_resource(Resource::new(ResourceDef::new("/{source}.json"))); @@ -924,7 +981,7 @@ mod tests { #[test] fn test_recognizer_with_prefix() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/name"))); router.register_resource(Resource::new(ResourceDef::new("/name/{val}"))); @@ -943,7 +1000,7 @@ mod tests { assert_eq!(&info.match_info()["val"], "value"); // same patterns - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); router.register_resource(Resource::new(ResourceDef::new("/name"))); router.register_resource(Resource::new(ResourceDef::new("/name/{val}"))); @@ -1012,6 +1069,16 @@ mod tests { let info = re.match_with_params(&req, 0).unwrap(); assert_eq!(info.get("version").unwrap(), "151"); assert_eq!(info.get("id").unwrap(), "adahg32"); + + let re = ResourceDef::new("/{id:[[:digit:]]{6}}"); + assert!(re.is_match("/012345")); + assert!(!re.is_match("/012")); + assert!(!re.is_match("/01234567")); + assert!(!re.is_match("/XXXXXX")); + + let req = TestRequest::with_uri("/012345").finish(); + let info = re.match_with_params(&req, 0).unwrap(); + assert_eq!(info.get("id").unwrap(), "012345"); } #[test] @@ -1049,7 +1116,7 @@ mod tests { #[test] fn test_request_resource() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); let mut resource = Resource::new(ResourceDef::new("/index.json")); resource.name("r1"); router.register_resource(resource); @@ -1071,7 +1138,7 @@ mod tests { #[test] fn test_has_resource() { - let mut router = Router::<()>::new(); + let mut router = Router::<()>::default(); let scope = Scope::new("/test").resource("/name", |_| "done"); router.register_scope(scope); @@ -1088,4 +1155,93 @@ mod tests { let info = router.default_route_info(); assert!(info.has_resource("/test2/test10/name")); } + + #[test] + fn test_url_for() { + let mut router = Router::<()>::new(ResourceDef::prefix("")); + + let mut resource = Resource::new(ResourceDef::new("/tttt")); + resource.name("r0"); + router.register_resource(resource); + + let scope = Scope::new("/test").resource("/name", |r| { + r.name("r1"); + }); + router.register_scope(scope); + + let scope = Scope::new("/test2") + .nested("/test10", |s| s.resource("/name", |r| r.name("r2"))); + router.register_scope(scope); + router.finish(); + + let req = TestRequest::with_uri("/test").request(); + { + let info = router.default_route_info(); + + let res = info + .url_for(&req, "r0", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/tttt"); + + let res = info + .url_for(&req, "r1", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/test/name"); + + let res = info + .url_for(&req, "r2", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/test2/test10/name"); + } + + let req = TestRequest::with_uri("/test/name").request(); + let info = router.recognize(&req, &(), 0); + assert_eq!(info.resource, ResourceId::Normal(1)); + + let res = info + .url_for(&req, "r0", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/tttt"); + + let res = info + .url_for(&req, "r1", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/test/name"); + + let res = info + .url_for(&req, "r2", Vec::<&'static str>::new()) + .unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/test2/test10/name"); + } + + #[test] + fn test_url_for_dynamic() { + let mut router = Router::<()>::new(ResourceDef::prefix("")); + + let mut resource = Resource::new(ResourceDef::new("/{name}/test/index.{ext}")); + resource.name("r0"); + router.register_resource(resource); + + let scope = Scope::new("/{name1}").nested("/{name2}", |s| { + s.resource("/{name3}/test/index.{ext}", |r| r.name("r2")) + }); + router.register_scope(scope); + router.finish(); + + let req = TestRequest::with_uri("/test").request(); + { + let info = router.default_route_info(); + + let res = info.url_for(&req, "r0", vec!["sec1", "html"]).unwrap(); + assert_eq!(res.as_str(), "http://localhost:8080/sec1/test/index.html"); + + let res = info + .url_for(&req, "r2", vec!["sec1", "sec2", "sec3", "html"]) + .unwrap(); + assert_eq!( + res.as_str(), + "http://localhost:8080/sec1/sec2/sec3/test/index.html" + ); + } + } } diff --git a/src/scope.rs b/src/scope.rs index 43d078529..fb9e7514a 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -5,7 +5,10 @@ use std::rc::Rc; use futures::{Async, Future, Poll}; use error::Error; -use handler::{AsyncResult, AsyncResultItem, FromRequest, Responder, RouteHandler}; +use handler::{ + AsyncResult, AsyncResultItem, FromRequest, Handler, Responder, RouteHandler, + WrapHandler, +}; use http::Method; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -17,6 +20,7 @@ use pred::Predicate; use resource::{DefaultResource, Resource}; use router::{ResourceDef, Router}; use server::Request; +use with::WithFactory; /// Resources scope /// @@ -55,14 +59,17 @@ pub struct Scope { middlewares: Rc>>>, } -#[cfg_attr(feature = "cargo-clippy", allow(new_without_default_derive))] +#[cfg_attr( + feature = "cargo-clippy", + allow(new_without_default_derive) +)] impl Scope { /// Create a new scope - // TODO: Why is this not exactly the default impl? pub fn new(path: &str) -> Scope { + let rdef = ResourceDef::prefix(path); Scope { - rdef: ResourceDef::prefix(path), - router: Rc::new(Router::new()), + rdef: rdef.clone(), + router: Rc::new(Router::new(rdef)), filters: Vec::new(), middlewares: Rc::new(Vec::new()), } @@ -132,10 +139,11 @@ impl Scope { where F: FnOnce(Scope) -> Scope, { + let rdef = ResourceDef::prefix(path); let scope = Scope { - rdef: ResourceDef::prefix(path), + rdef: rdef.clone(), filters: Vec::new(), - router: Rc::new(Router::new()), + router: Rc::new(Router::new(rdef)), middlewares: Rc::new(Vec::new()), }; let mut scope = f(scope); @@ -178,10 +186,11 @@ impl Scope { where F: FnOnce(Scope) -> Scope, { + let rdef = ResourceDef::prefix(&insert_slash(path)); let scope = Scope { - rdef: ResourceDef::prefix(&path), + rdef: rdef.clone(), filters: Vec::new(), - router: Rc::new(Router::new()), + router: Rc::new(Router::new(rdef)), middlewares: Rc::new(Vec::new()), }; Rc::get_mut(&mut self.router) @@ -220,13 +229,15 @@ impl Scope { /// ``` pub fn route(mut self, path: &str, method: Method, f: F) -> Scope where - F: Fn(T) -> R + 'static, + F: WithFactory, R: Responder + 'static, T: FromRequest + 'static, { - Rc::get_mut(&mut self.router) - .unwrap() - .register_route(path, method, f); + Rc::get_mut(&mut self.router).unwrap().register_route( + &insert_slash(path), + method, + f, + ); self } @@ -258,12 +269,7 @@ impl Scope { F: FnOnce(&mut Resource) -> R + 'static, { // add resource - let pattern = ResourceDef::with_prefix( - path, - if path.is_empty() { "" } else { "/" }, - false, - ); - let mut resource = Resource::new(pattern); + let mut resource = Resource::new(ResourceDef::new(&insert_slash(path))); f(&mut resource); Rc::get_mut(&mut self.router) @@ -288,6 +294,35 @@ impl Scope { self } + /// Configure handler for specific path prefix. + /// + /// A path prefix consists of valid path segments, i.e for the + /// prefix `/app` any request with the paths `/app`, `/app/` or + /// `/app/test` would match, but the path `/application` would + /// not. + /// + /// ```rust + /// # extern crate actix_web; + /// use actix_web::{http, App, HttpRequest, HttpResponse}; + /// + /// fn main() { + /// let app = App::new().scope("/scope-prefix", |scope| { + /// scope.handler("/app", |req: &HttpRequest| match *req.method() { + /// http::Method::GET => HttpResponse::Ok(), + /// http::Method::POST => HttpResponse::MethodNotAllowed(), + /// _ => HttpResponse::NotFound(), + /// }) + /// }); + /// } + /// ``` + pub fn handler>(mut self, path: &str, handler: H) -> Scope { + let path = insert_slash(path.trim().trim_right_matches('/')); + Rc::get_mut(&mut self.router) + .expect("Multiple copies of scope router") + .register_handler(&path, Box::new(WrapHandler::new(handler)), None); + self + } + /// Register a scope middleware /// /// This is similar to `App's` middlewares, but @@ -303,6 +338,14 @@ impl Scope { } } +fn insert_slash(path: &str) -> String { + let mut path = path.to_owned(); + if !path.is_empty() && !path.starts_with('/') { + path.insert(0, '/'); + }; + path +} + impl RouteHandler for Scope { fn handle(&self, req: &HttpRequest) -> AsyncResult { let tail = req.match_info().tail as usize; @@ -313,7 +356,7 @@ impl RouteHandler for Scope { if self.middlewares.is_empty() { self.router.handle(&req2) } else { - AsyncResult::async(Box::new(Compose::new( + AsyncResult::future(Box::new(Compose::new( req2, Rc::clone(&self.router), Rc::clone(&self.middlewares), @@ -717,8 +760,7 @@ mod tests { let app = App::new() .scope("/app", |scope| { scope.resource("/path1", |r| r.f(|_| HttpResponse::Ok())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/path1").request(); let resp = app.run(req); @@ -732,8 +774,7 @@ mod tests { scope .resource("", |r| r.f(|_| HttpResponse::Ok())) .resource("/", |r| r.f(|_| HttpResponse::Created())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app").request(); let resp = app.run(req); @@ -749,8 +790,7 @@ mod tests { let app = App::new() .scope("/app/", |scope| { scope.resource("", |r| r.f(|_| HttpResponse::Ok())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app").request(); let resp = app.run(req); @@ -766,8 +806,7 @@ mod tests { let app = App::new() .scope("/app/", |scope| { scope.resource("/", |r| r.f(|_| HttpResponse::Ok())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app").request(); let resp = app.run(req); @@ -785,12 +824,38 @@ mod tests { scope .route("/path1", Method::GET, |_: HttpRequest<_>| { HttpResponse::Ok() - }) - .route("/path1", Method::DELETE, |_: HttpRequest<_>| { + }).route("/path1", Method::DELETE, |_: HttpRequest<_>| { HttpResponse::Ok() }) - }) - .finish(); + }).finish(); + + let req = TestRequest::with_uri("/app/path1").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::DELETE) + .request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + + let req = TestRequest::with_uri("/app/path1") + .method(Method::POST) + .request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + } + + #[test] + fn test_scope_route_without_leading_slash() { + let app = App::new() + .scope("app", |scope| { + scope + .route("path1", Method::GET, |_: HttpRequest<_>| HttpResponse::Ok()) + .route("path1", Method::DELETE, |_: HttpRequest<_>| { + HttpResponse::Ok() + }) + }).finish(); let req = TestRequest::with_uri("/app/path1").request(); let resp = app.run(req); @@ -816,8 +881,7 @@ mod tests { scope .filter(pred::Get()) .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/path1") .method(Method::POST) @@ -842,8 +906,7 @@ mod tests { .body(format!("project: {}", &r.match_info()["project"])) }) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/ab-project1/path1").request(); let resp = app.run(req); @@ -871,8 +934,7 @@ mod tests { scope.with_state("/t1", State, |scope| { scope.resource("/path1", |r| r.f(|_| HttpResponse::Created())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1/path1").request(); let resp = app.run(req); @@ -890,8 +952,7 @@ mod tests { .resource("", |r| r.f(|_| HttpResponse::Ok())) .resource("/", |r| r.f(|_| HttpResponse::Created())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1").request(); let resp = app.run(req); @@ -911,8 +972,7 @@ mod tests { scope.with_state("/t1/", State, |scope| { scope.resource("", |r| r.f(|_| HttpResponse::Ok())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1").request(); let resp = app.run(req); @@ -932,8 +992,7 @@ mod tests { scope.with_state("/t1/", State, |scope| { scope.resource("/", |r| r.f(|_| HttpResponse::Ok())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1").request(); let resp = app.run(req); @@ -955,8 +1014,7 @@ mod tests { .filter(pred::Get()) .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1/path1") .method(Method::POST) @@ -978,8 +1036,21 @@ mod tests { scope.nested("/t1", |scope| { scope.resource("/path1", |r| r.f(|_| HttpResponse::Created())) }) - }) - .finish(); + }).finish(); + + let req = TestRequest::with_uri("/app/t1/path1").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::CREATED); + } + + #[test] + fn test_nested_scope_no_slash() { + let app = App::new() + .scope("/app", |scope| { + scope.nested("t1", |scope| { + scope.resource("/path1", |r| r.f(|_| HttpResponse::Created())) + }) + }).finish(); let req = TestRequest::with_uri("/app/t1/path1").request(); let resp = app.run(req); @@ -995,8 +1066,7 @@ mod tests { .resource("", |r| r.f(|_| HttpResponse::Ok())) .resource("/", |r| r.f(|_| HttpResponse::Created())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1").request(); let resp = app.run(req); @@ -1016,8 +1086,7 @@ mod tests { .filter(pred::Get()) .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/t1/path1") .method(Method::POST) @@ -1046,8 +1115,7 @@ mod tests { }) }) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/project_1/path1").request(); let resp = app.run(req); @@ -1079,8 +1147,7 @@ mod tests { }) }) }) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/test/1/path1").request(); let resp = app.run(req); @@ -1106,8 +1173,7 @@ mod tests { scope .resource("/path1", |r| r.f(|_| HttpResponse::Ok())) .default_resource(|r| r.f(|_| HttpResponse::BadRequest())) - }) - .finish(); + }).finish(); let req = TestRequest::with_uri("/app/path2").request(); let resp = app.run(req); @@ -1123,8 +1189,7 @@ mod tests { let app = App::new() .scope("/app1", |scope| { scope.default_resource(|r| r.f(|_| HttpResponse::BadRequest())) - }) - .scope("/app2", |scope| scope) + }).scope("/app2", |scope| scope) .default_resource(|r| r.f(|_| HttpResponse::MethodNotAllowed())) .finish(); @@ -1140,4 +1205,32 @@ mod tests { let resp = app.run(req); assert_eq!(resp.as_msg().status(), StatusCode::METHOD_NOT_ALLOWED); } + + #[test] + fn test_handler() { + let app = App::new() + .scope("/scope", |scope| { + scope.handler("/test", |_: &_| HttpResponse::Ok()) + }).finish(); + + let req = TestRequest::with_uri("/scope/test").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + + let req = TestRequest::with_uri("/scope/test/").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + + let req = TestRequest::with_uri("/scope/test/app").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::OK); + + let req = TestRequest::with_uri("/scope/testapp").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + + let req = TestRequest::with_uri("/scope/blah").request(); + let resp = app.run(req); + assert_eq!(resp.as_msg().status(), StatusCode::NOT_FOUND); + } } diff --git a/src/serde_urlencoded/de.rs b/src/serde_urlencoded/de.rs deleted file mode 100644 index ae14afbf5..000000000 --- a/src/serde_urlencoded/de.rs +++ /dev/null @@ -1,305 +0,0 @@ -//! Deserialization support for the `application/x-www-form-urlencoded` format. - -use serde::de::Error as de_Error; -use serde::de::{ - self, DeserializeSeed, EnumAccess, IntoDeserializer, VariantAccess, Visitor, -}; - -use serde::de::value::MapDeserializer; -use std::borrow::Cow; -use std::io::Read; -use url::form_urlencoded::parse; -use url::form_urlencoded::Parse as UrlEncodedParse; - -#[doc(inline)] -pub use serde::de::value::Error; - -/// Deserializes a `application/x-wwww-url-encoded` value from a `&[u8]`. -/// -/// ```ignore -/// let meal = vec![ -/// ("bread".to_owned(), "baguette".to_owned()), -/// ("cheese".to_owned(), "comté".to_owned()), -/// ("meat".to_owned(), "ham".to_owned()), -/// ("fat".to_owned(), "butter".to_owned()), -/// ]; -/// -/// assert_eq!( -/// serde_urlencoded::from_bytes::>( -/// b"bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter"), -/// Ok(meal)); -/// ``` -pub fn from_bytes<'de, T>(input: &'de [u8]) -> Result -where - T: de::Deserialize<'de>, -{ - T::deserialize(Deserializer::new(parse(input))) -} - -/// Deserializes a `application/x-wwww-url-encoded` value from a `&str`. -/// -/// ```ignore -/// let meal = vec![ -/// ("bread".to_owned(), "baguette".to_owned()), -/// ("cheese".to_owned(), "comté".to_owned()), -/// ("meat".to_owned(), "ham".to_owned()), -/// ("fat".to_owned(), "butter".to_owned()), -/// ]; -/// -/// assert_eq!( -/// serde_urlencoded::from_str::>( -/// "bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter"), -/// Ok(meal)); -/// ``` -pub fn from_str<'de, T>(input: &'de str) -> Result -where - T: de::Deserialize<'de>, -{ - from_bytes(input.as_bytes()) -} - -#[allow(dead_code)] -/// Convenience function that reads all bytes from `reader` and deserializes -/// them with `from_bytes`. -pub fn from_reader(mut reader: R) -> Result -where - T: de::DeserializeOwned, - R: Read, -{ - let mut buf = vec![]; - reader - .read_to_end(&mut buf) - .map_err(|e| de::Error::custom(format_args!("could not read input: {}", e)))?; - from_bytes(&buf) -} - -/// A deserializer for the `application/x-www-form-urlencoded` format. -/// -/// * Supported top-level outputs are structs, maps and sequences of pairs, -/// with or without a given length. -/// -/// * Main `deserialize` methods defers to `deserialize_map`. -/// -/// * Everything else but `deserialize_seq` and `deserialize_seq_fixed_size` -/// defers to `deserialize`. -pub struct Deserializer<'de> { - inner: MapDeserializer<'de, PartIterator<'de>, Error>, -} - -impl<'de> Deserializer<'de> { - /// Returns a new `Deserializer`. - pub fn new(parser: UrlEncodedParse<'de>) -> Self { - Deserializer { - inner: MapDeserializer::new(PartIterator(parser)), - } - } -} - -impl<'de> de::Deserializer<'de> for Deserializer<'de> { - type Error = Error; - - fn deserialize_any(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - self.deserialize_map(visitor) - } - - fn deserialize_map(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - visitor.visit_map(self.inner) - } - - fn deserialize_seq(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - visitor.visit_seq(self.inner) - } - - fn deserialize_unit(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - self.inner.end()?; - visitor.visit_unit() - } - - forward_to_deserialize_any! { - bool - u8 - u16 - u32 - u64 - i8 - i16 - i32 - i64 - f32 - f64 - char - str - string - option - bytes - byte_buf - unit_struct - newtype_struct - tuple_struct - struct - identifier - tuple - enum - ignored_any - } -} - -struct PartIterator<'de>(UrlEncodedParse<'de>); - -impl<'de> Iterator for PartIterator<'de> { - type Item = (Part<'de>, Part<'de>); - - fn next(&mut self) -> Option { - self.0.next().map(|(k, v)| (Part(k), Part(v))) - } -} - -struct Part<'de>(Cow<'de, str>); - -impl<'de> IntoDeserializer<'de> for Part<'de> { - type Deserializer = Self; - - fn into_deserializer(self) -> Self::Deserializer { - self - } -} - -macro_rules! forward_parsed_value { - ($($ty:ident => $method:ident,)*) => { - $( - fn $method(self, visitor: V) -> Result - where V: de::Visitor<'de> - { - match self.0.parse::<$ty>() { - Ok(val) => val.into_deserializer().$method(visitor), - Err(e) => Err(de::Error::custom(e)) - } - } - )* - } -} - -impl<'de> de::Deserializer<'de> for Part<'de> { - type Error = Error; - - fn deserialize_any(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - self.0.into_deserializer().deserialize_any(visitor) - } - - fn deserialize_option(self, visitor: V) -> Result - where - V: de::Visitor<'de>, - { - visitor.visit_some(self) - } - - fn deserialize_enum( - self, _name: &'static str, _variants: &'static [&'static str], visitor: V, - ) -> Result - where - V: de::Visitor<'de>, - { - visitor.visit_enum(ValueEnumAccess { value: self.0 }) - } - - forward_to_deserialize_any! { - char - str - string - unit - bytes - byte_buf - unit_struct - newtype_struct - tuple_struct - struct - identifier - tuple - ignored_any - seq - map - } - - forward_parsed_value! { - bool => deserialize_bool, - u8 => deserialize_u8, - u16 => deserialize_u16, - u32 => deserialize_u32, - u64 => deserialize_u64, - i8 => deserialize_i8, - i16 => deserialize_i16, - i32 => deserialize_i32, - i64 => deserialize_i64, - f32 => deserialize_f32, - f64 => deserialize_f64, - } -} - -/// Provides access to a keyword which can be deserialized into an enum variant. The enum variant -/// must be a unit variant, otherwise deserialization will fail. -struct ValueEnumAccess<'de> { - value: Cow<'de, str>, -} - -impl<'de> EnumAccess<'de> for ValueEnumAccess<'de> { - type Error = Error; - type Variant = UnitOnlyVariantAccess; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> - where - V: DeserializeSeed<'de>, - { - let variant = seed.deserialize(self.value.into_deserializer())?; - Ok((variant, UnitOnlyVariantAccess)) - } -} - -/// A visitor for deserializing the contents of the enum variant. As we only support -/// `unit_variant`, all other variant types will return an error. -struct UnitOnlyVariantAccess; - -impl<'de> VariantAccess<'de> for UnitOnlyVariantAccess { - type Error = Error; - - fn unit_variant(self) -> Result<(), Self::Error> { - Ok(()) - } - - fn newtype_variant_seed(self, _seed: T) -> Result - where - T: DeserializeSeed<'de>, - { - Err(Error::custom("expected unit variant")) - } - - fn tuple_variant(self, _len: usize, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::custom("expected unit variant")) - } - - fn struct_variant( - self, _fields: &'static [&'static str], _visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - Err(Error::custom("expected unit variant")) - } -} diff --git a/src/serde_urlencoded/mod.rs b/src/serde_urlencoded/mod.rs deleted file mode 100644 index 7e2cf33ae..000000000 --- a/src/serde_urlencoded/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -//! `x-www-form-urlencoded` meets Serde - -extern crate dtoa; -extern crate itoa; - -pub mod de; -pub mod ser; - -#[doc(inline)] -pub use self::de::{from_bytes, from_reader, from_str, Deserializer}; -#[doc(inline)] -pub use self::ser::{to_string, Serializer}; - -#[cfg(test)] -mod tests { - #[test] - fn deserialize_bytes() { - let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; - - assert_eq!(super::from_bytes(b"first=23&last=42"), Ok(result)); - } - - #[test] - fn deserialize_str() { - let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; - - assert_eq!(super::from_str("first=23&last=42"), Ok(result)); - } - - #[test] - fn deserialize_reader() { - let result = vec![("first".to_owned(), 23), ("last".to_owned(), 42)]; - - assert_eq!(super::from_reader(b"first=23&last=42" as &[_]), Ok(result)); - } - - #[test] - fn deserialize_option() { - let result = vec![ - ("first".to_owned(), Some(23)), - ("last".to_owned(), Some(42)), - ]; - assert_eq!(super::from_str("first=23&last=42"), Ok(result)); - } - - #[test] - fn deserialize_unit() { - assert_eq!(super::from_str(""), Ok(())); - assert_eq!(super::from_str("&"), Ok(())); - assert_eq!(super::from_str("&&"), Ok(())); - assert!(super::from_str::<()>("first=23").is_err()); - } - - #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] - enum X { - A, - B, - C, - } - - #[test] - fn deserialize_unit_enum() { - let result = vec![ - ("one".to_owned(), X::A), - ("two".to_owned(), X::B), - ("three".to_owned(), X::C), - ]; - - assert_eq!(super::from_str("one=A&two=B&three=C"), Ok(result)); - } - - #[test] - fn serialize_option_map_int() { - let params = &[("first", Some(23)), ("middle", None), ("last", Some(42))]; - - assert_eq!(super::to_string(params), Ok("first=23&last=42".to_owned())); - } - - #[test] - fn serialize_option_map_string() { - let params = &[ - ("first", Some("hello")), - ("middle", None), - ("last", Some("world")), - ]; - - assert_eq!( - super::to_string(params), - Ok("first=hello&last=world".to_owned()) - ); - } - - #[test] - fn serialize_option_map_bool() { - let params = &[("one", Some(true)), ("two", Some(false))]; - - assert_eq!( - super::to_string(params), - Ok("one=true&two=false".to_owned()) - ); - } - - #[test] - fn serialize_map_bool() { - let params = &[("one", true), ("two", false)]; - - assert_eq!( - super::to_string(params), - Ok("one=true&two=false".to_owned()) - ); - } - - #[test] - fn serialize_unit_enum() { - let params = &[("one", X::A), ("two", X::B), ("three", X::C)]; - assert_eq!( - super::to_string(params), - Ok("one=A&two=B&three=C".to_owned()) - ); - } -} diff --git a/src/serde_urlencoded/ser/key.rs b/src/serde_urlencoded/ser/key.rs deleted file mode 100644 index 48497a558..000000000 --- a/src/serde_urlencoded/ser/key.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::super::ser::part::Sink; -use super::super::ser::Error; -use serde::Serialize; -use std::borrow::Cow; -use std::ops::Deref; - -pub enum Key<'key> { - Static(&'static str), - Dynamic(Cow<'key, str>), -} - -impl<'key> Deref for Key<'key> { - type Target = str; - - fn deref(&self) -> &str { - match *self { - Key::Static(key) => key, - Key::Dynamic(ref key) => key, - } - } -} - -impl<'key> From> for Cow<'static, str> { - fn from(key: Key<'key>) -> Self { - match key { - Key::Static(key) => key.into(), - Key::Dynamic(key) => key.into_owned().into(), - } - } -} - -pub struct KeySink { - end: End, -} - -impl KeySink -where - End: for<'key> FnOnce(Key<'key>) -> Result, -{ - pub fn new(end: End) -> Self { - KeySink { end } - } -} - -impl Sink for KeySink -where - End: for<'key> FnOnce(Key<'key>) -> Result, -{ - type Ok = Ok; - - fn serialize_static_str(self, value: &'static str) -> Result { - (self.end)(Key::Static(value)) - } - - fn serialize_str(self, value: &str) -> Result { - (self.end)(Key::Dynamic(value.into())) - } - - fn serialize_string(self, value: String) -> Result { - (self.end)(Key::Dynamic(value.into())) - } - - fn serialize_none(self) -> Result { - Err(self.unsupported()) - } - - fn serialize_some(self, _value: &T) -> Result { - Err(self.unsupported()) - } - - fn unsupported(self) -> Error { - Error::Custom("unsupported key".into()) - } -} diff --git a/src/serde_urlencoded/ser/mod.rs b/src/serde_urlencoded/ser/mod.rs deleted file mode 100644 index b4022d563..000000000 --- a/src/serde_urlencoded/ser/mod.rs +++ /dev/null @@ -1,490 +0,0 @@ -//! Serialization support for the `application/x-www-form-urlencoded` format. - -mod key; -mod pair; -mod part; -mod value; - -use serde::ser; -use std::borrow::Cow; -use std::error; -use std::fmt; -use std::str; -use url::form_urlencoded::Serializer as UrlEncodedSerializer; -use url::form_urlencoded::Target as UrlEncodedTarget; - -/// Serializes a value into a `application/x-wwww-url-encoded` `String` buffer. -/// -/// ```ignore -/// let meal = &[ -/// ("bread", "baguette"), -/// ("cheese", "comté"), -/// ("meat", "ham"), -/// ("fat", "butter"), -/// ]; -/// -/// assert_eq!( -/// serde_urlencoded::to_string(meal), -/// Ok("bread=baguette&cheese=comt%C3%A9&meat=ham&fat=butter".to_owned())); -/// ``` -pub fn to_string(input: T) -> Result { - let mut urlencoder = UrlEncodedSerializer::new("".to_owned()); - input.serialize(Serializer::new(&mut urlencoder))?; - Ok(urlencoder.finish()) -} - -/// A serializer for the `application/x-www-form-urlencoded` format. -/// -/// * Supported top-level inputs are structs, maps and sequences of pairs, -/// with or without a given length. -/// -/// * Supported keys and values are integers, bytes (if convertible to strings), -/// unit structs and unit variants. -/// -/// * Newtype structs defer to their inner values. -pub struct Serializer<'output, Target: 'output + UrlEncodedTarget> { - urlencoder: &'output mut UrlEncodedSerializer, -} - -impl<'output, Target: 'output + UrlEncodedTarget> Serializer<'output, Target> { - /// Returns a new `Serializer`. - pub fn new(urlencoder: &'output mut UrlEncodedSerializer) -> Self { - Serializer { urlencoder } - } -} - -/// Errors returned during serializing to `application/x-www-form-urlencoded`. -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Error { - Custom(Cow<'static, str>), - Utf8(str::Utf8Error), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::Custom(ref msg) => msg.fmt(f), - Error::Utf8(ref err) => write!(f, "invalid UTF-8: {}", err), - } - } -} - -impl error::Error for Error { - fn description(&self) -> &str { - match *self { - Error::Custom(ref msg) => msg, - Error::Utf8(ref err) => error::Error::description(err), - } - } - - /// The lower-level cause of this error, in the case of a `Utf8` error. - fn cause(&self) -> Option<&error::Error> { - match *self { - Error::Custom(_) => None, - Error::Utf8(ref err) => Some(err), - } - } -} - -impl ser::Error for Error { - fn custom(msg: T) -> Self { - Error::Custom(format!("{}", msg).into()) - } -} - -/// Sequence serializer. -pub struct SeqSerializer<'output, Target: 'output + UrlEncodedTarget> { - urlencoder: &'output mut UrlEncodedSerializer, -} - -/// Tuple serializer. -/// -/// Mostly used for arrays. -pub struct TupleSerializer<'output, Target: 'output + UrlEncodedTarget> { - urlencoder: &'output mut UrlEncodedSerializer, -} - -/// Tuple struct serializer. -/// -/// Never instantiated, tuple structs are not supported. -pub struct TupleStructSerializer<'output, T: 'output + UrlEncodedTarget> { - inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, -} - -/// Tuple variant serializer. -/// -/// Never instantiated, tuple variants are not supported. -pub struct TupleVariantSerializer<'output, T: 'output + UrlEncodedTarget> { - inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, -} - -/// Map serializer. -pub struct MapSerializer<'output, Target: 'output + UrlEncodedTarget> { - urlencoder: &'output mut UrlEncodedSerializer, - key: Option>, -} - -/// Struct serializer. -pub struct StructSerializer<'output, Target: 'output + UrlEncodedTarget> { - urlencoder: &'output mut UrlEncodedSerializer, -} - -/// Struct variant serializer. -/// -/// Never instantiated, struct variants are not supported. -pub struct StructVariantSerializer<'output, T: 'output + UrlEncodedTarget> { - inner: ser::Impossible<&'output mut UrlEncodedSerializer, Error>, -} - -impl<'output, Target> ser::Serializer for Serializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - type SerializeSeq = SeqSerializer<'output, Target>; - type SerializeTuple = TupleSerializer<'output, Target>; - type SerializeTupleStruct = TupleStructSerializer<'output, Target>; - type SerializeTupleVariant = TupleVariantSerializer<'output, Target>; - type SerializeMap = MapSerializer<'output, Target>; - type SerializeStruct = StructSerializer<'output, Target>; - type SerializeStructVariant = StructVariantSerializer<'output, Target>; - - /// Returns an error. - fn serialize_bool(self, _v: bool) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_i8(self, _v: i8) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_i16(self, _v: i16) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_i32(self, _v: i32) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_i64(self, _v: i64) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_u8(self, _v: u8) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_u16(self, _v: u16) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_u32(self, _v: u32) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_u64(self, _v: u64) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_f32(self, _v: f32) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_f64(self, _v: f64) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_char(self, _v: char) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_str(self, _value: &str) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_bytes(self, _value: &[u8]) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_unit(self) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_unit_struct(self, _name: &'static str) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_unit_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - ) -> Result { - Err(Error::top_level()) - } - - /// Serializes the inner value, ignoring the newtype name. - fn serialize_newtype_struct( - self, _name: &'static str, value: &T, - ) -> Result { - value.serialize(self) - } - - /// Returns an error. - fn serialize_newtype_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _value: &T, - ) -> Result { - Err(Error::top_level()) - } - - /// Returns `Ok`. - fn serialize_none(self) -> Result { - Ok(self.urlencoder) - } - - /// Serializes the given value. - fn serialize_some( - self, value: &T, - ) -> Result { - value.serialize(self) - } - - /// Serialize a sequence, given length (if any) is ignored. - fn serialize_seq(self, _len: Option) -> Result { - Ok(SeqSerializer { - urlencoder: self.urlencoder, - }) - } - - /// Returns an error. - fn serialize_tuple(self, _len: usize) -> Result { - Ok(TupleSerializer { - urlencoder: self.urlencoder, - }) - } - - /// Returns an error. - fn serialize_tuple_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Err(Error::top_level()) - } - - /// Returns an error. - fn serialize_tuple_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(Error::top_level()) - } - - /// Serializes a map, given length is ignored. - fn serialize_map(self, _len: Option) -> Result { - Ok(MapSerializer { - urlencoder: self.urlencoder, - key: None, - }) - } - - /// Serializes a struct, given length is ignored. - fn serialize_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Ok(StructSerializer { - urlencoder: self.urlencoder, - }) - } - - /// Returns an error. - fn serialize_struct_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(Error::top_level()) - } -} - -impl<'output, Target> ser::SerializeSeq for SeqSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_element( - &mut self, value: &T, - ) -> Result<(), Error> { - value.serialize(pair::PairSerializer::new(self.urlencoder)) - } - - fn end(self) -> Result { - Ok(self.urlencoder) - } -} - -impl<'output, Target> ser::SerializeTuple for TupleSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_element( - &mut self, value: &T, - ) -> Result<(), Error> { - value.serialize(pair::PairSerializer::new(self.urlencoder)) - } - - fn end(self) -> Result { - Ok(self.urlencoder) - } -} - -impl<'output, Target> ser::SerializeTupleStruct - for TupleStructSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_field( - &mut self, value: &T, - ) -> Result<(), Error> { - self.inner.serialize_field(value) - } - - fn end(self) -> Result { - self.inner.end() - } -} - -impl<'output, Target> ser::SerializeTupleVariant - for TupleVariantSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_field( - &mut self, value: &T, - ) -> Result<(), Error> { - self.inner.serialize_field(value) - } - - fn end(self) -> Result { - self.inner.end() - } -} - -impl<'output, Target> ser::SerializeMap for MapSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_entry( - &mut self, key: &K, value: &V, - ) -> Result<(), Error> { - let key_sink = key::KeySink::new(|key| { - let value_sink = value::ValueSink::new(self.urlencoder, &key); - value.serialize(part::PartSerializer::new(value_sink))?; - self.key = None; - Ok(()) - }); - let entry_serializer = part::PartSerializer::new(key_sink); - key.serialize(entry_serializer) - } - - fn serialize_key( - &mut self, key: &T, - ) -> Result<(), Error> { - let key_sink = key::KeySink::new(|key| Ok(key.into())); - let key_serializer = part::PartSerializer::new(key_sink); - self.key = Some(key.serialize(key_serializer)?); - Ok(()) - } - - fn serialize_value( - &mut self, value: &T, - ) -> Result<(), Error> { - { - let key = self.key.as_ref().ok_or_else(Error::no_key)?; - let value_sink = value::ValueSink::new(self.urlencoder, &key); - value.serialize(part::PartSerializer::new(value_sink))?; - } - self.key = None; - Ok(()) - } - - fn end(self) -> Result { - Ok(self.urlencoder) - } -} - -impl<'output, Target> ser::SerializeStruct for StructSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_field( - &mut self, key: &'static str, value: &T, - ) -> Result<(), Error> { - let value_sink = value::ValueSink::new(self.urlencoder, key); - value.serialize(part::PartSerializer::new(value_sink)) - } - - fn end(self) -> Result { - Ok(self.urlencoder) - } -} - -impl<'output, Target> ser::SerializeStructVariant - for StructVariantSerializer<'output, Target> -where - Target: 'output + UrlEncodedTarget, -{ - type Ok = &'output mut UrlEncodedSerializer; - type Error = Error; - - fn serialize_field( - &mut self, key: &'static str, value: &T, - ) -> Result<(), Error> { - self.inner.serialize_field(key, value) - } - - fn end(self) -> Result { - self.inner.end() - } -} - -impl Error { - fn top_level() -> Self { - let msg = "top-level serializer supports only maps and structs"; - Error::Custom(msg.into()) - } - - fn no_key() -> Self { - let msg = "tried to serialize a value before serializing key"; - Error::Custom(msg.into()) - } -} diff --git a/src/serde_urlencoded/ser/pair.rs b/src/serde_urlencoded/ser/pair.rs deleted file mode 100644 index 68db144f9..000000000 --- a/src/serde_urlencoded/ser/pair.rs +++ /dev/null @@ -1,239 +0,0 @@ -use super::super::ser::key::KeySink; -use super::super::ser::part::PartSerializer; -use super::super::ser::value::ValueSink; -use super::super::ser::Error; -use serde::ser; -use std::borrow::Cow; -use std::mem; -use url::form_urlencoded::Serializer as UrlEncodedSerializer; -use url::form_urlencoded::Target as UrlEncodedTarget; - -pub struct PairSerializer<'target, Target: 'target + UrlEncodedTarget> { - urlencoder: &'target mut UrlEncodedSerializer, - state: PairState, -} - -impl<'target, Target> PairSerializer<'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - pub fn new(urlencoder: &'target mut UrlEncodedSerializer) -> Self { - PairSerializer { - urlencoder, - state: PairState::WaitingForKey, - } - } -} - -impl<'target, Target> ser::Serializer for PairSerializer<'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - type Ok = (); - type Error = Error; - type SerializeSeq = ser::Impossible<(), Error>; - type SerializeTuple = Self; - type SerializeTupleStruct = ser::Impossible<(), Error>; - type SerializeTupleVariant = ser::Impossible<(), Error>; - type SerializeMap = ser::Impossible<(), Error>; - type SerializeStruct = ser::Impossible<(), Error>; - type SerializeStructVariant = ser::Impossible<(), Error>; - - fn serialize_bool(self, _v: bool) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_i8(self, _v: i8) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_i16(self, _v: i16) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_i32(self, _v: i32) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_i64(self, _v: i64) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_u8(self, _v: u8) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_u16(self, _v: u16) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_u32(self, _v: u32) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_u64(self, _v: u64) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_f32(self, _v: f32) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_f64(self, _v: f64) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_char(self, _v: char) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_str(self, _value: &str) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_bytes(self, _value: &[u8]) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_unit(self) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_unit_struct(self, _name: &'static str) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_unit_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - ) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_newtype_struct( - self, _name: &'static str, value: &T, - ) -> Result<(), Error> { - value.serialize(self) - } - - fn serialize_newtype_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _value: &T, - ) -> Result<(), Error> { - Err(Error::unsupported_pair()) - } - - fn serialize_none(self) -> Result<(), Error> { - Ok(()) - } - - fn serialize_some(self, value: &T) -> Result<(), Error> { - value.serialize(self) - } - - fn serialize_seq(self, _len: Option) -> Result { - Err(Error::unsupported_pair()) - } - - fn serialize_tuple(self, len: usize) -> Result { - if len == 2 { - Ok(self) - } else { - Err(Error::unsupported_pair()) - } - } - - fn serialize_tuple_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Err(Error::unsupported_pair()) - } - - fn serialize_tuple_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(Error::unsupported_pair()) - } - - fn serialize_map(self, _len: Option) -> Result { - Err(Error::unsupported_pair()) - } - - fn serialize_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Err(Error::unsupported_pair()) - } - - fn serialize_struct_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(Error::unsupported_pair()) - } -} - -impl<'target, Target> ser::SerializeTuple for PairSerializer<'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - type Ok = (); - type Error = Error; - - fn serialize_element( - &mut self, value: &T, - ) -> Result<(), Error> { - match mem::replace(&mut self.state, PairState::Done) { - PairState::WaitingForKey => { - let key_sink = KeySink::new(|key| Ok(key.into())); - let key_serializer = PartSerializer::new(key_sink); - self.state = PairState::WaitingForValue { - key: value.serialize(key_serializer)?, - }; - Ok(()) - } - PairState::WaitingForValue { key } => { - let result = { - let value_sink = ValueSink::new(self.urlencoder, &key); - let value_serializer = PartSerializer::new(value_sink); - value.serialize(value_serializer) - }; - if result.is_ok() { - self.state = PairState::Done; - } else { - self.state = PairState::WaitingForValue { key }; - } - result - } - PairState::Done => Err(Error::done()), - } - } - - fn end(self) -> Result<(), Error> { - if let PairState::Done = self.state { - Ok(()) - } else { - Err(Error::not_done()) - } - } -} - -enum PairState { - WaitingForKey, - WaitingForValue { key: Cow<'static, str> }, - Done, -} - -impl Error { - fn done() -> Self { - Error::Custom("this pair has already been serialized".into()) - } - - fn not_done() -> Self { - Error::Custom("this pair has not yet been serialized".into()) - } - - fn unsupported_pair() -> Self { - Error::Custom("unsupported pair".into()) - } -} diff --git a/src/serde_urlencoded/ser/part.rs b/src/serde_urlencoded/ser/part.rs deleted file mode 100644 index 4874dd34b..000000000 --- a/src/serde_urlencoded/ser/part.rs +++ /dev/null @@ -1,201 +0,0 @@ -use serde; - -use super::super::dtoa; -use super::super::itoa; -use super::super::ser::Error; -use std::str; - -pub struct PartSerializer { - sink: S, -} - -impl PartSerializer { - pub fn new(sink: S) -> Self { - PartSerializer { sink } - } -} - -pub trait Sink: Sized { - type Ok; - - fn serialize_static_str(self, value: &'static str) -> Result; - - fn serialize_str(self, value: &str) -> Result; - fn serialize_string(self, value: String) -> Result; - fn serialize_none(self) -> Result; - - fn serialize_some( - self, value: &T, - ) -> Result; - - fn unsupported(self) -> Error; -} - -impl serde::ser::Serializer for PartSerializer { - type Ok = S::Ok; - type Error = Error; - type SerializeSeq = serde::ser::Impossible; - type SerializeTuple = serde::ser::Impossible; - type SerializeTupleStruct = serde::ser::Impossible; - type SerializeTupleVariant = serde::ser::Impossible; - type SerializeMap = serde::ser::Impossible; - type SerializeStruct = serde::ser::Impossible; - type SerializeStructVariant = serde::ser::Impossible; - - fn serialize_bool(self, v: bool) -> Result { - self.sink - .serialize_static_str(if v { "true" } else { "false" }) - } - - fn serialize_i8(self, v: i8) -> Result { - self.serialize_integer(v) - } - - fn serialize_i16(self, v: i16) -> Result { - self.serialize_integer(v) - } - - fn serialize_i32(self, v: i32) -> Result { - self.serialize_integer(v) - } - - fn serialize_i64(self, v: i64) -> Result { - self.serialize_integer(v) - } - - fn serialize_u8(self, v: u8) -> Result { - self.serialize_integer(v) - } - - fn serialize_u16(self, v: u16) -> Result { - self.serialize_integer(v) - } - - fn serialize_u32(self, v: u32) -> Result { - self.serialize_integer(v) - } - - fn serialize_u64(self, v: u64) -> Result { - self.serialize_integer(v) - } - - fn serialize_f32(self, v: f32) -> Result { - self.serialize_floating(v) - } - - fn serialize_f64(self, v: f64) -> Result { - self.serialize_floating(v) - } - - fn serialize_char(self, v: char) -> Result { - self.sink.serialize_string(v.to_string()) - } - - fn serialize_str(self, value: &str) -> Result { - self.sink.serialize_str(value) - } - - fn serialize_bytes(self, value: &[u8]) -> Result { - match str::from_utf8(value) { - Ok(value) => self.sink.serialize_str(value), - Err(err) => Err(Error::Utf8(err)), - } - } - - fn serialize_unit(self) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_unit_struct(self, name: &'static str) -> Result { - self.sink.serialize_static_str(name) - } - - fn serialize_unit_variant( - self, _name: &'static str, _variant_index: u32, variant: &'static str, - ) -> Result { - self.sink.serialize_static_str(variant) - } - - fn serialize_newtype_struct( - self, _name: &'static str, value: &T, - ) -> Result { - value.serialize(self) - } - - fn serialize_newtype_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _value: &T, - ) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_none(self) -> Result { - self.sink.serialize_none() - } - - fn serialize_some( - self, value: &T, - ) -> Result { - self.sink.serialize_some(value) - } - - fn serialize_seq(self, _len: Option) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_tuple(self, _len: usize) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_tuple_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_tuple_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_map(self, _len: Option) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_struct( - self, _name: &'static str, _len: usize, - ) -> Result { - Err(self.sink.unsupported()) - } - - fn serialize_struct_variant( - self, _name: &'static str, _variant_index: u32, _variant: &'static str, - _len: usize, - ) -> Result { - Err(self.sink.unsupported()) - } -} - -impl PartSerializer { - fn serialize_integer(self, value: I) -> Result - where - I: itoa::Integer, - { - let mut buf = [b'\0'; 20]; - let len = itoa::write(&mut buf[..], value).unwrap(); - let part = unsafe { str::from_utf8_unchecked(&buf[0..len]) }; - serde::ser::Serializer::serialize_str(self, part) - } - - fn serialize_floating(self, value: F) -> Result - where - F: dtoa::Floating, - { - let mut buf = [b'\0'; 24]; - let len = dtoa::write(&mut buf[..], value).unwrap(); - let part = unsafe { str::from_utf8_unchecked(&buf[0..len]) }; - serde::ser::Serializer::serialize_str(self, part) - } -} diff --git a/src/serde_urlencoded/ser/value.rs b/src/serde_urlencoded/ser/value.rs deleted file mode 100644 index 3c47739f3..000000000 --- a/src/serde_urlencoded/ser/value.rs +++ /dev/null @@ -1,59 +0,0 @@ -use super::super::ser::part::{PartSerializer, Sink}; -use super::super::ser::Error; -use serde::ser::Serialize; -use std::str; -use url::form_urlencoded::Serializer as UrlEncodedSerializer; -use url::form_urlencoded::Target as UrlEncodedTarget; - -pub struct ValueSink<'key, 'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - urlencoder: &'target mut UrlEncodedSerializer, - key: &'key str, -} - -impl<'key, 'target, Target> ValueSink<'key, 'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - pub fn new( - urlencoder: &'target mut UrlEncodedSerializer, key: &'key str, - ) -> Self { - ValueSink { urlencoder, key } - } -} - -impl<'key, 'target, Target> Sink for ValueSink<'key, 'target, Target> -where - Target: 'target + UrlEncodedTarget, -{ - type Ok = (); - - fn serialize_str(self, value: &str) -> Result<(), Error> { - self.urlencoder.append_pair(self.key, value); - Ok(()) - } - - fn serialize_static_str(self, value: &'static str) -> Result<(), Error> { - self.serialize_str(value) - } - - fn serialize_string(self, value: String) -> Result<(), Error> { - self.serialize_str(&value) - } - - fn serialize_none(self) -> Result { - Ok(()) - } - - fn serialize_some( - self, value: &T, - ) -> Result { - value.serialize(PartSerializer::new(self)) - } - - fn unsupported(self) -> Error { - Error::Custom("unsupported value".into()) - } -} diff --git a/src/server/acceptor.rs b/src/server/acceptor.rs new file mode 100644 index 000000000..994b4b7bd --- /dev/null +++ b/src/server/acceptor.rs @@ -0,0 +1,383 @@ +use std::time::Duration; +use std::{fmt, net}; + +use actix_net::server::ServerMessage; +use actix_net::service::{NewService, Service}; +use futures::future::{err, ok, Either, FutureResult}; +use futures::{Async, Future, Poll}; +use tokio_reactor::Handle; +use tokio_tcp::TcpStream; +use tokio_timer::{sleep, Delay}; + +use super::error::AcceptorError; +use super::IoStream; + +/// This trait indicates types that can create acceptor service for http server. +pub trait AcceptorServiceFactory: Send + Clone + 'static { + type Io: IoStream + Send; + type NewService: NewService; + + fn create(&self) -> Self::NewService; +} + +impl AcceptorServiceFactory for F +where + F: Fn() -> T + Send + Clone + 'static, + T::Response: IoStream + Send, + T: NewService, + T::InitError: fmt::Debug, +{ + type Io = T::Response; + type NewService = T; + + fn create(&self) -> T { + (self)() + } +} + +#[derive(Clone)] +/// Default acceptor service convert `TcpStream` to a `tokio_tcp::TcpStream` +pub(crate) struct DefaultAcceptor; + +impl AcceptorServiceFactory for DefaultAcceptor { + type Io = TcpStream; + type NewService = DefaultAcceptor; + + fn create(&self) -> Self::NewService { + DefaultAcceptor + } +} + +impl NewService for DefaultAcceptor { + type Request = TcpStream; + type Response = TcpStream; + type Error = (); + type InitError = (); + type Service = DefaultAcceptor; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(DefaultAcceptor) + } +} + +impl Service for DefaultAcceptor { + type Request = TcpStream; + type Response = TcpStream; + type Error = (); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + ok(req) + } +} + +pub(crate) struct TcpAcceptor { + inner: T, +} + +impl TcpAcceptor +where + T: NewService>, + T::InitError: fmt::Debug, +{ + pub(crate) fn new(inner: T) -> Self { + TcpAcceptor { inner } + } +} + +impl NewService for TcpAcceptor +where + T: NewService>, + T::InitError: fmt::Debug, +{ + type Request = net::TcpStream; + type Response = T::Response; + type Error = AcceptorError; + type InitError = T::InitError; + type Service = TcpAcceptorService; + type Future = TcpAcceptorResponse; + + fn new_service(&self) -> Self::Future { + TcpAcceptorResponse { + fut: self.inner.new_service(), + } + } +} + +pub(crate) struct TcpAcceptorResponse +where + T: NewService, + T::InitError: fmt::Debug, +{ + fut: T::Future, +} + +impl Future for TcpAcceptorResponse +where + T: NewService, + T::InitError: fmt::Debug, +{ + type Item = TcpAcceptorService; + type Error = T::InitError; + + fn poll(&mut self) -> Poll { + match self.fut.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(service)) => { + Ok(Async::Ready(TcpAcceptorService { inner: service })) + } + Err(e) => { + error!("Can not create accetor service: {:?}", e); + Err(e) + } + } + } +} + +pub(crate) struct TcpAcceptorService { + inner: T, +} + +impl Service for TcpAcceptorService +where + T: Service>, +{ + type Request = net::TcpStream; + type Response = T::Response; + type Error = AcceptorError; + type Future = Either>; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.inner.poll_ready() + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let stream = TcpStream::from_std(req, &Handle::default()).map_err(|e| { + error!("Can not convert to an async tcp stream: {}", e); + AcceptorError::Io(e) + }); + + match stream { + Ok(stream) => Either::A(self.inner.call(stream)), + Err(e) => Either::B(err(e)), + } + } +} + +#[doc(hidden)] +/// Acceptor timeout middleware +/// +/// Applies timeout to request prcoessing. +pub struct AcceptorTimeout { + inner: T, + timeout: Duration, +} + +impl AcceptorTimeout { + /// Create new `AcceptorTimeout` instance. timeout is in milliseconds. + pub fn new(timeout: u64, inner: T) -> Self { + Self { + inner, + timeout: Duration::from_millis(timeout), + } + } +} + +impl NewService for AcceptorTimeout { + type Request = T::Request; + type Response = T::Response; + type Error = AcceptorError; + type InitError = T::InitError; + type Service = AcceptorTimeoutService; + type Future = AcceptorTimeoutFut; + + fn new_service(&self) -> Self::Future { + AcceptorTimeoutFut { + fut: self.inner.new_service(), + timeout: self.timeout, + } + } +} + +#[doc(hidden)] +pub struct AcceptorTimeoutFut { + fut: T::Future, + timeout: Duration, +} + +impl Future for AcceptorTimeoutFut { + type Item = AcceptorTimeoutService; + type Error = T::InitError; + + fn poll(&mut self) -> Poll { + let inner = try_ready!(self.fut.poll()); + Ok(Async::Ready(AcceptorTimeoutService { + inner, + timeout: self.timeout, + })) + } +} + +#[doc(hidden)] +/// Acceptor timeout service +/// +/// Applies timeout to request prcoessing. +pub struct AcceptorTimeoutService { + inner: T, + timeout: Duration, +} + +impl Service for AcceptorTimeoutService { + type Request = T::Request; + type Response = T::Response; + type Error = AcceptorError; + type Future = AcceptorTimeoutResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.inner.poll_ready().map_err(AcceptorError::Service) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + AcceptorTimeoutResponse { + fut: self.inner.call(req), + sleep: sleep(self.timeout), + } + } +} + +#[doc(hidden)] +pub struct AcceptorTimeoutResponse { + fut: T::Future, + sleep: Delay, +} + +impl Future for AcceptorTimeoutResponse { + type Item = T::Response; + type Error = AcceptorError; + + fn poll(&mut self) -> Poll { + match self.fut.poll().map_err(AcceptorError::Service)? { + Async::NotReady => match self.sleep.poll() { + Err(_) => Err(AcceptorError::Timeout), + Ok(Async::Ready(_)) => Err(AcceptorError::Timeout), + Ok(Async::NotReady) => Ok(Async::NotReady), + }, + Async::Ready(resp) => Ok(Async::Ready(resp)), + } + } +} + +pub(crate) struct ServerMessageAcceptor { + inner: T, +} + +impl ServerMessageAcceptor +where + T: NewService, +{ + pub(crate) fn new(inner: T) -> Self { + ServerMessageAcceptor { inner } + } +} + +impl NewService for ServerMessageAcceptor +where + T: NewService, +{ + type Request = ServerMessage; + type Response = (); + type Error = T::Error; + type InitError = T::InitError; + type Service = ServerMessageAcceptorService; + type Future = ServerMessageAcceptorResponse; + + fn new_service(&self) -> Self::Future { + ServerMessageAcceptorResponse { + fut: self.inner.new_service(), + } + } +} + +pub(crate) struct ServerMessageAcceptorResponse +where + T: NewService, +{ + fut: T::Future, +} + +impl Future for ServerMessageAcceptorResponse +where + T: NewService, +{ + type Item = ServerMessageAcceptorService; + type Error = T::InitError; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(service) => Ok(Async::Ready(ServerMessageAcceptorService { + inner: service, + })), + } + } +} + +pub(crate) struct ServerMessageAcceptorService { + inner: T, +} + +impl Service for ServerMessageAcceptorService +where + T: Service, +{ + type Request = ServerMessage; + type Response = (); + type Error = T::Error; + type Future = + Either, FutureResult<(), Self::Error>>; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.inner.poll_ready() + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + match req { + ServerMessage::Connect(stream) => { + Either::A(ServerMessageAcceptorServiceFut { + fut: self.inner.call(stream), + }) + } + ServerMessage::Shutdown(_) => Either::B(ok(())), + ServerMessage::ForceShutdown => { + // self.settings + // .head() + // .traverse(|proto: &mut HttpProtocol| proto.shutdown()); + Either::B(ok(())) + } + } + } +} + +pub(crate) struct ServerMessageAcceptorServiceFut { + fut: T::Future, +} + +impl Future for ServerMessageAcceptorServiceFut +where + T: Service, +{ + type Item = (); + type Error = T::Error; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(_) => Ok(Async::Ready(())), + } + } +} diff --git a/src/server/builder.rs b/src/server/builder.rs new file mode 100644 index 000000000..ea3638f10 --- /dev/null +++ b/src/server/builder.rs @@ -0,0 +1,134 @@ +use std::{fmt, net}; + +use actix_net::either::Either; +use actix_net::server::{Server, ServiceFactory}; +use actix_net::service::{NewService, NewServiceExt}; + +use super::acceptor::{ + AcceptorServiceFactory, AcceptorTimeout, ServerMessageAcceptor, TcpAcceptor, +}; +use super::error::AcceptorError; +use super::handler::IntoHttpHandler; +use super::service::{HttpService, StreamConfiguration}; +use super::settings::{ServerSettings, ServiceConfig}; +use super::KeepAlive; + +pub(crate) trait ServiceProvider { + fn register( + &self, + server: Server, + lst: net::TcpListener, + host: String, + addr: net::SocketAddr, + keep_alive: KeepAlive, + secure: bool, + client_timeout: u64, + client_shutdown: u64, + ) -> Server; +} + +/// Utility type that builds complete http pipeline +pub(crate) struct HttpServiceBuilder +where + F: Fn() -> H + Send + Clone, +{ + factory: F, + acceptor: A, +} + +impl HttpServiceBuilder +where + F: Fn() -> H + Send + Clone + 'static, + H: IntoHttpHandler, + A: AcceptorServiceFactory, + ::InitError: fmt::Debug, +{ + /// Create http service builder + pub fn new(factory: F, acceptor: A) -> Self { + Self { factory, acceptor } + } + + fn finish( + &self, + host: String, + addr: net::SocketAddr, + keep_alive: KeepAlive, + secure: bool, + client_timeout: u64, + client_shutdown: u64, + ) -> impl ServiceFactory { + let factory = self.factory.clone(); + let acceptor = self.acceptor.clone(); + move || { + let app = (factory)().into_handler(); + let settings = ServiceConfig::new( + app, + keep_alive, + client_timeout, + client_shutdown, + ServerSettings::new(addr, &host, false), + ); + + if secure { + Either::B(ServerMessageAcceptor::new( + TcpAcceptor::new(AcceptorTimeout::new( + client_timeout, + acceptor.create(), + )).map_err(|_| ()) + .map_init_err(|_| ()) + .and_then(StreamConfiguration::new().nodelay(true)) + .and_then( + HttpService::new(settings) + .map_init_err(|_| ()) + .map_err(|_| ()), + ), + )) + } else { + Either::A(ServerMessageAcceptor::new( + TcpAcceptor::new(acceptor.create().map_err(AcceptorError::Service)) + .map_err(|_| ()) + .map_init_err(|_| ()) + .and_then(StreamConfiguration::new().nodelay(true)) + .and_then( + HttpService::new(settings) + .map_init_err(|_| ()) + .map_err(|_| ()), + ), + )) + } + } + } +} + +impl ServiceProvider for HttpServiceBuilder +where + F: Fn() -> H + Send + Clone + 'static, + A: AcceptorServiceFactory, + ::InitError: fmt::Debug, + H: IntoHttpHandler, +{ + fn register( + &self, + server: Server, + lst: net::TcpListener, + host: String, + addr: net::SocketAddr, + keep_alive: KeepAlive, + secure: bool, + client_timeout: u64, + client_shutdown: u64, + ) -> Server { + server.listen2( + "actix-web", + lst, + self.finish( + host, + addr, + keep_alive, + secure, + client_timeout, + client_shutdown, + ), + ) + } +} diff --git a/src/server/channel.rs b/src/server/channel.rs index b817b4160..d65b05e85 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -1,22 +1,43 @@ -use std::net::{Shutdown, SocketAddr}; -use std::rc::Rc; -use std::{io, ptr, time}; +use std::net::Shutdown; +use std::{io, mem, time}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use futures::{Async, Future, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_timer::Delay; -use super::settings::WorkerSettings; +use super::error::HttpDispatchError; +use super::settings::ServiceConfig; use super::{h1, h2, HttpHandler, IoStream}; +use http::StatusCode; const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; -enum HttpProtocol { - H1(h1::Http1), +pub(crate) enum HttpProtocol { + H1(h1::Http1Dispatcher), H2(h2::Http2), - Unknown(Rc>, Option, T, BytesMut), + Unknown(ServiceConfig, T, BytesMut), + None, } +// impl HttpProtocol { +// fn shutdown_(&mut self) { +// match self { +// HttpProtocol::H1(ref mut h1) => { +// let io = h1.io(); +// let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); +// let _ = IoStream::shutdown(io, Shutdown::Both); +// } +// HttpProtocol::H2(ref mut h2) => h2.shutdown(), +// HttpProtocol::Unknown(_, io, _) => { +// let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); +// let _ = IoStream::shutdown(io, Shutdown::Both); +// } +// HttpProtocol::None => (), +// } +// } +// } + enum ProtocolKind { Http1, Http2, @@ -28,8 +49,8 @@ where T: IoStream, H: HttpHandler + 'static, { - proto: Option>, - node: Option>>, + proto: HttpProtocol, + ka_timeout: Option, } impl HttpChannel @@ -37,45 +58,12 @@ where T: IoStream, H: HttpHandler + 'static, { - pub(crate) fn new( - settings: Rc>, mut io: T, peer: Option, - http2: bool, - ) -> HttpChannel { - settings.add_channel(); - let _ = io.set_nodelay(true); + pub(crate) fn new(settings: ServiceConfig, io: T) -> HttpChannel { + let ka_timeout = settings.client_timer(); - if http2 { - HttpChannel { - node: None, - proto: Some(HttpProtocol::H2(h2::Http2::new( - settings, - io, - peer, - Bytes::new(), - ))), - } - } else { - HttpChannel { - node: None, - proto: Some(HttpProtocol::Unknown( - settings, - peer, - io, - BytesMut::with_capacity(8192), - )), - } - } - } - - fn shutdown(&mut self) { - match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - let io = h1.io(); - let _ = IoStream::set_linger(io, Some(time::Duration::new(0, 0))); - let _ = IoStream::shutdown(io, Shutdown::Both); - } - Some(HttpProtocol::H2(ref mut h2)) => h2.shutdown(), - _ => (), + HttpChannel { + ka_timeout, + proto: HttpProtocol::Unknown(settings, io, BytesMut::with_capacity(8192)), } } } @@ -86,70 +74,58 @@ where H: HttpHandler + 'static, { type Item = (); - type Error = (); + type Error = HttpDispatchError; fn poll(&mut self) -> Poll { - if self.node.is_some() { - let el = self as *mut _; - self.node = Some(Node::new(el)); - let _ = match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - self.node.as_mut().map(|n| h1.settings().head().insert(n)) + // keep-alive timer + if self.ka_timeout.is_some() { + match self.ka_timeout.as_mut().unwrap().poll() { + Ok(Async::Ready(_)) => { + trace!("Slow request timed out, close connection"); + let proto = mem::replace(&mut self.proto, HttpProtocol::None); + if let HttpProtocol::Unknown(settings, io, buf) = proto { + self.proto = HttpProtocol::H1(h1::Http1Dispatcher::for_error( + settings, + io, + StatusCode::REQUEST_TIMEOUT, + self.ka_timeout.take(), + buf, + )); + return self.poll(); + } + return Ok(Async::Ready(())); } - Some(HttpProtocol::H2(ref mut h2)) => { - self.node.as_mut().map(|n| h2.settings().head().insert(n)) - } - Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => { - self.node.as_mut().map(|n| settings.head().insert(n)) - } - None => unreachable!(), - }; + Ok(Async::NotReady) => (), + Err(_) => panic!("Something is really wrong"), + } } + let mut is_eof = false; let kind = match self.proto { - Some(HttpProtocol::H1(ref mut h1)) => { - let result = h1.poll(); - match result { - Ok(Async::Ready(())) | Err(_) => { - h1.settings().remove_channel(); - if let Some(n) = self.node.as_mut() { - n.remove() - }; - } - _ => (), - } - return result; - } - Some(HttpProtocol::H2(ref mut h2)) => { - let result = h2.poll(); - match result { - Ok(Async::Ready(())) | Err(_) => { - h2.settings().remove_channel(); - if let Some(n) = self.node.as_mut() { - n.remove() - }; - } - _ => (), - } - return result; - } - Some(HttpProtocol::Unknown( - ref mut settings, - _, - ref mut io, - ref mut buf, - )) => { + HttpProtocol::H1(ref mut h1) => return h1.poll(), + HttpProtocol::H2(ref mut h2) => return h2.poll(), + HttpProtocol::Unknown(_, ref mut io, ref mut buf) => { + let mut err = None; + let mut disconnect = false; match io.read_available(buf) { - Ok(Async::Ready(true)) | Err(_) => { - debug!("Ignored premature client disconnection"); - settings.remove_channel(); - if let Some(n) = self.node.as_mut() { - n.remove() - }; - return Err(()); + Ok(Async::Ready((read_some, stream_closed))) => { + is_eof = stream_closed; + // Only disconnect if no data was read. + if is_eof && !read_some { + disconnect = true; + } + } + Err(e) => { + err = Some(e.into()); } _ => (), } + if disconnect { + debug!("Ignored premature client disconnection"); + return Ok(Async::Ready(())); + } else if let Some(e) = err { + return Err(e); + } if buf.len() >= 14 { if buf[..14] == HTTP2_PREFACE[..] { @@ -161,24 +137,30 @@ where return Ok(Async::NotReady); } } - None => unreachable!(), + HttpProtocol::None => unreachable!(), }; // upgrade to specific http protocol - if let Some(HttpProtocol::Unknown(settings, addr, io, buf)) = self.proto.take() { + let proto = mem::replace(&mut self.proto, HttpProtocol::None); + if let HttpProtocol::Unknown(settings, io, buf) = proto { match kind { ProtocolKind::Http1 => { - self.proto = - Some(HttpProtocol::H1(h1::Http1::new(settings, io, addr, buf))); + self.proto = HttpProtocol::H1(h1::Http1Dispatcher::new( + settings, + io, + buf, + is_eof, + self.ka_timeout.take(), + )); return self.poll(); } ProtocolKind::Http2 => { - self.proto = Some(HttpProtocol::H2(h2::Http2::new( + self.proto = HttpProtocol::H2(h2::Http2::new( settings, io, - addr, buf.freeze(), - ))); + self.ka_timeout.take(), + )); return self.poll(); } } @@ -187,79 +169,45 @@ where } } -pub(crate) struct Node { - next: Option<*mut Node>, - prev: Option<*mut Node>, - element: *mut T, +#[doc(hidden)] +pub struct H1Channel +where + T: IoStream, + H: HttpHandler + 'static, +{ + proto: HttpProtocol, } -impl Node { - fn new(el: *mut T) -> Self { - Node { - next: None, - prev: None, - element: el, - } - } - - fn insert(&mut self, next: &mut Node) { - unsafe { - let next: *mut Node = next as *const _ as *mut _; - - if let Some(ref mut next2) = self.next { - let n = next2.as_mut().unwrap(); - n.prev = Some(next); - } - self.next = Some(next); - - let next: &mut Node = &mut *next; - next.prev = Some(self as *mut _); - } - } - - fn remove(&mut self) { - unsafe { - self.element = ptr::null_mut(); - let next = self.next.take(); - let mut prev = self.prev.take(); - - if let Some(ref mut prev) = prev { - prev.as_mut().unwrap().next = next; - } +impl H1Channel +where + T: IoStream, + H: HttpHandler + 'static, +{ + pub(crate) fn new(settings: ServiceConfig, io: T) -> H1Channel { + H1Channel { + proto: HttpProtocol::H1(h1::Http1Dispatcher::new( + settings, + io, + BytesMut::with_capacity(8192), + false, + None, + )), } } } -impl Node<()> { - pub(crate) fn head() -> Self { - Node { - next: None, - prev: None, - element: ptr::null_mut(), - } - } +impl Future for H1Channel +where + T: IoStream, + H: HttpHandler + 'static, +{ + type Item = (); + type Error = HttpDispatchError; - pub(crate) fn traverse(&self) - where - T: IoStream, - H: HttpHandler + 'static, - { - let mut next = self.next.as_ref(); - loop { - if let Some(n) = next { - unsafe { - let n: &Node<()> = &*(n.as_ref().unwrap() as *const _); - next = n.next.as_ref(); - - if !n.element.is_null() { - let ch: &mut HttpChannel = - &mut *(&mut *(n.element as *mut _) as *mut () as *mut _); - ch.shutdown(); - } - } - } else { - return; - } + fn poll(&mut self) -> Poll { + match self.proto { + HttpProtocol::H1(ref mut h1) => h1.poll(), + _ => unreachable!(), } } } @@ -297,6 +245,10 @@ where fn set_linger(&mut self, _: Option) -> io::Result<()> { Ok(()) } + #[inline] + fn set_keepalive(&mut self, _: Option) -> io::Result<()> { + Ok(()) + } } impl io::Read for WrapperStream diff --git a/src/server/error.rs b/src/server/error.rs index b3c79a066..70f100998 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -1,9 +1,84 @@ +use std::io; + use futures::{Async, Poll}; +use http2; use super::{helpers, HttpHandlerTask, Writer}; use http::{StatusCode, Version}; use Error; +/// Errors produced by `AcceptorError` service. +#[derive(Debug)] +pub enum AcceptorError { + /// The inner service error + Service(T), + + /// Io specific error + Io(io::Error), + + /// The request did not complete within the specified timeout. + Timeout, +} + +#[derive(Fail, Debug)] +/// A set of errors that can occur during dispatching http requests +pub enum HttpDispatchError { + /// Application error + #[fail(display = "Application specific error: {}", _0)] + App(Error), + + /// An `io::Error` that occurred while trying to read or write to a network + /// stream. + #[fail(display = "IO error: {}", _0)] + Io(io::Error), + + /// The first request did not complete within the specified timeout. + #[fail(display = "The first request did not complete within the specified timeout")] + SlowRequestTimeout, + + /// Shutdown timeout + #[fail(display = "Connection shutdown timeout")] + ShutdownTimeout, + + /// HTTP2 error + #[fail(display = "HTTP2 error: {}", _0)] + Http2(http2::Error), + + /// Payload is not consumed + #[fail(display = "Task is completed but request's payload is not consumed")] + PayloadIsNotConsumed, + + /// Malformed request + #[fail(display = "Malformed request")] + MalformedRequest, + + /// Internal error + #[fail(display = "Internal error")] + InternalError, + + /// Unknown error + #[fail(display = "Unknown error")] + Unknown, +} + +impl From for HttpDispatchError { + fn from(err: Error) -> Self { + HttpDispatchError::App(err) + } +} + +impl From for HttpDispatchError { + fn from(err: io::Error) -> Self { + HttpDispatchError::Io(err) + } +} + +impl From for HttpDispatchError { + fn from(err: http2::Error) -> Self { + HttpDispatchError::Http2(err) + } +} + pub(crate) struct ServerError(Version, StatusCode); impl ServerError { @@ -16,8 +91,17 @@ impl HttpHandlerTask for ServerError { fn poll_io(&mut self, io: &mut Writer) -> Poll { { let bytes = io.buffer(); + // Buffer should have sufficient capacity for status line + // and extra space + bytes.reserve(helpers::STATUS_LINE_BUF_SIZE + 1); helpers::write_status_line(self.0, self.1.as_u16(), bytes); } + // Convert Status Code to Reason. + let reason = self.1.canonical_reason().unwrap_or(""); + io.buffer().extend_from_slice(reason.as_bytes()); + // No response body. + io.buffer().extend_from_slice(b"\r\ncontent-length: 0\r\n"); + // date header io.set_date(); Ok(Async::Ready(true)) } diff --git a/src/server/h1.rs b/src/server/h1.rs index 511b32bce..f491ba597 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -1,122 +1,164 @@ use std::collections::VecDeque; -use std::net::SocketAddr; -use std::rc::Rc; +use std::net::{Shutdown, SocketAddr}; use std::time::{Duration, Instant}; use bytes::BytesMut; use futures::{Async, Future, Poll}; +use tokio_current_thread::spawn; use tokio_timer::Delay; use error::{Error, PayloadError}; use http::{StatusCode, Version}; use payload::{Payload, PayloadStatus, PayloadWriter}; -use super::error::ServerError; +use super::error::{HttpDispatchError, ServerError}; use super::h1decoder::{DecoderError, H1Decoder, Message}; use super::h1writer::H1Writer; +use super::handler::{HttpHandler, HttpHandlerTask, HttpHandlerTaskFut}; use super::input::PayloadType; -use super::settings::WorkerSettings; -use super::Writer; -use super::{HttpHandler, HttpHandlerTask, IoStream}; +use super::settings::ServiceConfig; +use super::{IoStream, Writer}; const MAX_PIPELINED_MESSAGES: usize = 16; bitflags! { - struct Flags: u8 { - const STARTED = 0b0000_0001; - const ERROR = 0b0000_0010; - const KEEPALIVE = 0b0000_0100; - const SHUTDOWN = 0b0000_1000; - const DISCONNECTED = 0b0001_0000; - const POLLED = 0b0010_0000; + pub struct Flags: u8 { + const STARTED = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const SHUTDOWN = 0b0000_1000; + const READ_DISCONNECTED = 0b0001_0000; + const WRITE_DISCONNECTED = 0b0010_0000; + const POLLED = 0b0100_0000; + const FLUSHED = 0b1000_0000; } } -bitflags! { - struct EntryFlags: u8 { - const EOF = 0b0000_0001; - const ERROR = 0b0000_0010; - const FINISHED = 0b0000_0100; - } -} - -pub(crate) struct Http1 { +/// Dispatcher for HTTP/1.1 protocol +pub struct Http1Dispatcher { flags: Flags, - settings: Rc>, + settings: ServiceConfig, addr: Option, stream: H1Writer, decoder: H1Decoder, payload: Option, buf: BytesMut, tasks: VecDeque>, - keepalive_timer: Option, + error: Option, + ka_expire: Instant, + ka_timer: Option, } -enum EntryPipe { +enum Entry { Task(H::Task), Error(Box), } -impl EntryPipe { +impl Entry { + fn into_task(self) -> H::Task { + match self { + Entry::Task(task) => task, + Entry::Error(_) => panic!(), + } + } fn disconnected(&mut self) { match *self { - EntryPipe::Task(ref mut task) => task.disconnected(), - EntryPipe::Error(ref mut task) => task.disconnected(), + Entry::Task(ref mut task) => task.disconnected(), + Entry::Error(ref mut task) => task.disconnected(), } } fn poll_io(&mut self, io: &mut Writer) -> Poll { match *self { - EntryPipe::Task(ref mut task) => task.poll_io(io), - EntryPipe::Error(ref mut task) => task.poll_io(io), + Entry::Task(ref mut task) => task.poll_io(io), + Entry::Error(ref mut task) => task.poll_io(io), } } fn poll_completed(&mut self) -> Poll<(), Error> { match *self { - EntryPipe::Task(ref mut task) => task.poll_completed(), - EntryPipe::Error(ref mut task) => task.poll_completed(), + Entry::Task(ref mut task) => task.poll_completed(), + Entry::Error(ref mut task) => task.poll_completed(), } } } -struct Entry { - pipe: EntryPipe, - flags: EntryFlags, -} - -impl Http1 +impl Http1Dispatcher where T: IoStream, H: HttpHandler + 'static, { pub fn new( - settings: Rc>, stream: T, addr: Option, + settings: ServiceConfig, + stream: T, buf: BytesMut, + is_eof: bool, + keepalive_timer: Option, ) -> Self { - Http1 { - flags: Flags::KEEPALIVE, - stream: H1Writer::new(stream, Rc::clone(&settings)), + let addr = stream.peer_addr(); + let (ka_expire, ka_timer) = if let Some(delay) = keepalive_timer { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = settings.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (settings.now(), None) + }; + + let flags = if is_eof { + Flags::READ_DISCONNECTED | Flags::FLUSHED + } else if settings.keep_alive_enabled() { + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED + } else { + Flags::empty() + }; + + Http1Dispatcher { + stream: H1Writer::new(stream, settings.clone()), decoder: H1Decoder::new(), payload: None, tasks: VecDeque::new(), - keepalive_timer: None, + error: None, + flags, addr, buf, settings, + ka_timer, + ka_expire, } } - #[inline] - pub fn settings(&self) -> &WorkerSettings { - self.settings.as_ref() - } + pub(crate) fn for_error( + settings: ServiceConfig, + stream: T, + status: StatusCode, + mut keepalive_timer: Option, + buf: BytesMut, + ) -> Self { + if let Some(deadline) = settings.client_timer_expire() { + let _ = keepalive_timer.as_mut().map(|delay| delay.reset(deadline)); + } - #[inline] - pub(crate) fn io(&mut self) -> &mut T { - self.stream.get_mut() + let mut disp = Http1Dispatcher { + flags: Flags::STARTED | Flags::READ_DISCONNECTED | Flags::FLUSHED, + stream: H1Writer::new(stream, settings.clone()), + decoder: H1Decoder::new(), + payload: None, + tasks: VecDeque::new(), + error: None, + addr: None, + ka_timer: keepalive_timer, + ka_expire: settings.now(), + buf, + settings, + }; + disp.push_response_entry(status); + disp } #[inline] fn can_read(&self) -> bool { + if self.flags.contains(Flags::READ_DISCONNECTED) { + return false; + } + if let Some(ref info) = self.payload { info.need_read() == PayloadStatus::Read } else { @@ -124,242 +166,305 @@ where } } - fn notify_disconnect(&mut self) { - // notify all tasks - self.stream.disconnected(); - for task in &mut self.tasks { - task.pipe.disconnected(); + // if checked is set to true, delay disconnect until all tasks have finished. + fn client_disconnected(&mut self, checked: bool) { + self.flags.insert(Flags::READ_DISCONNECTED); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete); + } + + if !checked || self.tasks.is_empty() { + self.flags + .insert(Flags::WRITE_DISCONNECTED | Flags::FLUSHED); + self.stream.disconnected(); + + // notify all tasks + for mut task in self.tasks.drain(..) { + task.disconnected(); + match task.poll_completed() { + Ok(Async::NotReady) => { + // spawn not completed task, it does not require access to io + // at this point + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + Ok(Async::Ready(_)) => (), + Err(err) => { + error!("Unhandled application error: {}", err); + } + } + } } } #[inline] - pub fn poll(&mut self) -> Poll<(), ()> { - // keep-alive timer - if let Some(ref mut timer) = self.keepalive_timer { - match timer.poll() { - Ok(Async::Ready(_)) => { - trace!("Keep-alive timeout, close connection"); - self.flags.insert(Flags::SHUTDOWN); - } - Ok(Async::NotReady) => (), - Err(_) => unreachable!(), - } - } + pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { + // check connection keep-alive + self.poll_keepalive()?; // shutdown if self.flags.contains(Flags::SHUTDOWN) { - match self.stream.poll_completed(true) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(_)) => return Ok(Async::Ready(())), - Err(err) => { - debug!("Error sending data: {}", err); - return Err(()); - } + if self.flags.contains(Flags::WRITE_DISCONNECTED) { + return Ok(Async::Ready(())); } + return self.poll_flush(true); } - self.poll_io(); + // process incoming requests + if !self.flags.contains(Flags::WRITE_DISCONNECTED) { + self.poll_handler()?; - loop { - match self.poll_handler()? { - Async::Ready(true) => { - self.poll_io(); + // flush stream + self.poll_flush(false)?; + + // deal with keep-alive and stream eof (client-side write shutdown) + if self.tasks.is_empty() && self.flags.contains(Flags::FLUSHED) { + // handle stream eof + if self + .flags + .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) + { + return Ok(Async::Ready(())); } - Async::Ready(false) => { + // no keep-alive + if self.flags.contains(Flags::STARTED) + && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) + || !self.flags.contains(Flags::KEEPALIVE)) + { self.flags.insert(Flags::SHUTDOWN); return self.poll(); } - Async::NotReady => return Ok(Async::NotReady), } + Ok(Async::NotReady) + } else if let Some(err) = self.error.take() { + Err(err) + } else { + Ok(Async::Ready(())) } } - #[inline] - /// read data from stream - pub fn poll_io(&mut self) { - if !self.flags.contains(Flags::POLLED) { - self.parse(); - self.flags.insert(Flags::POLLED); - return; + /// Flush stream + fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> { + if shutdown || self.flags.contains(Flags::STARTED) { + match self.stream.poll_completed(shutdown) { + Ok(Async::NotReady) => { + // mark stream + if !self.stream.flushed() { + self.flags.remove(Flags::FLUSHED); + } + Ok(Async::NotReady) + } + Err(err) => { + debug!("Error sending data: {}", err); + self.client_disconnected(false); + Err(err.into()) + } + Ok(Async::Ready(_)) => { + // if payload is not consumed we can not use connection + if self.payload.is_some() && self.tasks.is_empty() { + return Err(HttpDispatchError::PayloadIsNotConsumed); + } + self.flags.insert(Flags::FLUSHED); + Ok(Async::Ready(())) + } + } + } else { + Ok(Async::Ready(())) } - // read io from socket - if !self.flags.intersects(Flags::ERROR) - && self.tasks.len() < MAX_PIPELINED_MESSAGES - && self.can_read() - { - match self.stream.get_mut().read_available(&mut self.buf) { - Ok(Async::Ready(disconnected)) => { - if disconnected { - // notify all tasks - self.notify_disconnect(); - // kill keepalive - self.keepalive_timer.take(); + } - // on parse error, stop reading stream but tasks need to be - // completed - self.flags.insert(Flags::ERROR); + /// keep-alive timer. returns `true` is keep-alive, otherwise drop + fn poll_keepalive(&mut self) -> Result<(), HttpDispatchError> { + if let Some(ref mut timer) = self.ka_timer { + match timer.poll() { + Ok(Async::Ready(_)) => { + // if we get timer during shutdown, just drop connection + if self.flags.contains(Flags::SHUTDOWN) { + let io = self.stream.get_mut(); + let _ = IoStream::set_linger(io, Some(Duration::from_secs(0))); + let _ = IoStream::shutdown(io, Shutdown::Both); + return Err(HttpDispatchError::ShutdownTimeout); + } + if timer.deadline() >= self.ka_expire { + // check for any outstanding request handling + if self.tasks.is_empty() && self.flags.contains(Flags::FLUSHED) { + if !self.flags.contains(Flags::STARTED) { + // timeout on first request (slow request) return 408 + trace!("Slow request timeout"); + self.flags + .insert(Flags::STARTED | Flags::READ_DISCONNECTED); + self.tasks.push_back(Entry::Error(ServerError::err( + Version::HTTP_11, + StatusCode::REQUEST_TIMEOUT, + ))); + } else { + trace!("Keep-alive timeout, close connection"); + self.flags.insert(Flags::SHUTDOWN); - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); + // start shutdown timer + if let Some(deadline) = + self.settings.client_shutdown_timer() + { + timer.reset(deadline); + let _ = timer.poll(); + } else { + return Ok(()); + } + } + } else if let Some(dl) = self.settings.keep_alive_expire() { + timer.reset(dl); + let _ = timer.poll(); } } else { - self.parse(); + timer.reset(self.ka_expire); + let _ = timer.poll(); } } Ok(Async::NotReady) => (), - Err(_) => { - // notify all tasks - self.notify_disconnect(); - // kill keepalive - self.keepalive_timer.take(); - - // on parse error, stop reading stream but tasks need to be - // completed - self.flags.insert(Flags::ERROR); - - if let Some(mut payload) = self.payload.take() { - payload.set_error(PayloadError::Incomplete); - } + Err(e) => { + error!("Timer error {:?}", e); + return Err(HttpDispatchError::Unknown); } } } + + Ok(()) } - pub fn poll_handler(&mut self) -> Poll { - let retry = self.can_read(); - - // check in-flight messages - let mut io = false; - let mut idx = 0; - while idx < self.tasks.len() { - // only one task can do io operation in http/1 - if !io && !self.tasks[idx].flags.contains(EntryFlags::EOF) { - // io is corrupted, send buffer - if self.tasks[idx].flags.contains(EntryFlags::ERROR) { - if let Ok(Async::NotReady) = self.stream.poll_completed(true) { - return Ok(Async::NotReady); - } - self.flags.insert(Flags::ERROR); - return Err(()); - } - - match self.tasks[idx].pipe.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); - - if ready { - self.tasks[idx] - .flags - .insert(EntryFlags::EOF | EntryFlags::FINISHED); - } else { - self.tasks[idx].flags.insert(EntryFlags::EOF); - } - } - // no more IO for this iteration - Ok(Async::NotReady) => { - // check if previously read backpressure was enabled - if self.can_read() && !retry { - return Ok(Async::Ready(true)); - } - io = true; - } - Err(err) => { - // it is not possible to recover from error - // during pipe handling, so just drop connection - self.notify_disconnect(); - self.tasks[idx].flags.insert(EntryFlags::ERROR); - error!("Unhandled error1: {}", err); - continue; - } - } - } else if !self.tasks[idx].flags.contains(EntryFlags::FINISHED) { - match self.tasks[idx].pipe.poll_completed() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - self.tasks[idx].flags.insert(EntryFlags::FINISHED) - } - Err(err) => { - self.notify_disconnect(); - self.tasks[idx].flags.insert(EntryFlags::ERROR); - error!("Unhandled error: {}", err); - continue; - } - } - } - idx += 1; - } - - // cleanup finished tasks - let max = self.tasks.len() >= MAX_PIPELINED_MESSAGES; - while !self.tasks.is_empty() { - if self.tasks[0] - .flags - .contains(EntryFlags::EOF | EntryFlags::FINISHED) - { - self.tasks.pop_front(); - } else { - break; + #[inline] + /// read data from the stream + pub(self) fn poll_io(&mut self) -> Result { + if !self.flags.contains(Flags::POLLED) { + self.flags.insert(Flags::POLLED); + if !self.buf.is_empty() { + let updated = self.parse()?; + return Ok(updated); } } - // read more message - if max && self.tasks.len() >= MAX_PIPELINED_MESSAGES { - return Ok(Async::Ready(true)); - } - // check stream state - if self.flags.contains(Flags::STARTED) { - match self.stream.poll_completed(false) { - Ok(Async::NotReady) => return Ok(Async::NotReady), + // read io from socket + let mut updated = false; + if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES { + match self.stream.get_mut().read_available(&mut self.buf) { + Ok(Async::Ready((read_some, disconnected))) => { + if read_some && self.parse()? { + updated = true; + } + if disconnected { + self.client_disconnected(true); + } + } + Ok(Async::NotReady) => (), Err(err) => { - debug!("Error sending data: {}", err); - self.notify_disconnect(); - return Err(()); - } - Ok(Async::Ready(_)) => { - // non consumed payload in that case close connection - if self.payload.is_some() && self.tasks.is_empty() { - return Ok(Async::Ready(false)); - } + self.client_disconnected(false); + return Err(err.into()); } } } - - // deal with keep-alive - if self.tasks.is_empty() { - // no keep-alive - if self.flags.contains(Flags::ERROR) - || (!self.flags.contains(Flags::KEEPALIVE) - || !self.settings.keep_alive_enabled()) - && self.flags.contains(Flags::STARTED) - { - return Ok(Async::Ready(false)); - } - - // start keep-alive timer - let keep_alive = self.settings.keep_alive(); - if self.keepalive_timer.is_none() && keep_alive > 0 { - trace!("Start keep-alive timer"); - let mut timer = - Delay::new(Instant::now() + Duration::new(keep_alive, 0)); - // register timer - let _ = timer.poll(); - self.keepalive_timer = Some(timer); - } - } - Ok(Async::NotReady) + Ok(updated) } - pub fn parse(&mut self) { + pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> { + self.poll_io()?; + let mut retry = self.can_read(); + + // process first pipelined response, only first task can do io operation in http/1 + while !self.tasks.is_empty() { + match self.tasks[0].poll_io(&mut self.stream) { + Ok(Async::Ready(ready)) => { + // override keep-alive state + if self.stream.keepalive() { + self.flags.insert(Flags::KEEPALIVE); + } else { + self.flags.remove(Flags::KEEPALIVE); + } + // prepare stream for next response + self.stream.reset(); + + let task = self.tasks.pop_front().unwrap(); + if !ready { + // task is done with io operations but still needs to do more work + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + } + Ok(Async::NotReady) => { + // check if we need timer + if self.ka_timer.is_some() && self.stream.upgrade() { + self.ka_timer.take(); + } + + // if read-backpressure is enabled and we consumed some data. + // we may read more dataand retry + if !retry && self.can_read() && self.poll_io()? { + retry = self.can_read(); + continue; + } + break; + } + Err(err) => { + error!("Unhandled error1: {}", err); + // it is not possible to recover from error + // during pipe handling, so just drop connection + self.client_disconnected(false); + return Err(err.into()); + } + } + } + + // check in-flight messages. all tasks must be alive, + // they need to produce response. if app returned error + // and we can not continue processing incoming requests. + let mut idx = 1; + while idx < self.tasks.len() { + let stop = match self.tasks[idx].poll_completed() { + Ok(Async::NotReady) => false, + Ok(Async::Ready(_)) => true, + Err(err) => { + self.error = Some(err.into()); + true + } + }; + if stop { + // error in task handling or task is completed, + // so no response for this task which means we can not read more requests + // because pipeline sequence is broken. + // but we can safely complete existing tasks + self.flags.insert(Flags::READ_DISCONNECTED); + + for mut task in self.tasks.drain(idx..) { + task.disconnected(); + match task.poll_completed() { + Ok(Async::NotReady) => { + // spawn not completed task, it does not require access to io + // at this point + spawn(HttpHandlerTaskFut::new(task.into_task())); + } + Ok(Async::Ready(_)) => (), + Err(err) => { + error!("Unhandled application error: {}", err); + } + } + } + break; + } else { + idx += 1; + } + } + + Ok(()) + } + + fn push_response_entry(&mut self, status: StatusCode) { + self.tasks + .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); + } + + pub(self) fn parse(&mut self) -> Result { + let mut updated = false; + 'outer: loop { match self.decoder.decode(&mut self.buf, &self.settings) { Ok(Some(Message::Message { mut msg, payload })) => { + updated = true; self.flags.insert(Flags::STARTED); if payload { @@ -368,85 +473,83 @@ where self.payload = Some(PayloadType::new(&msg.inner.headers, ps)); } + // stream extensions + msg.inner_mut().stream_extensions = + self.stream.get_mut().extensions(); + // set remote addr msg.inner_mut().addr = self.addr; - // stop keepalive timer - self.keepalive_timer.take(); - // search handler for request - for h in self.settings.handlers().iter_mut() { - msg = match h.handle(msg) { - Ok(mut pipe) => { - if self.tasks.is_empty() { - match pipe.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); + match self.settings.handler().handle(msg) { + Ok(mut task) => { + if self.tasks.is_empty() { + match task.poll_io(&mut self.stream) { + Ok(Async::Ready(ready)) => { + // override keep-alive state + if self.stream.keepalive() { + self.flags.insert(Flags::KEEPALIVE); + } else { + self.flags.remove(Flags::KEEPALIVE); + } + // prepare stream for next response + self.stream.reset(); - if !ready { - let item = Entry { - pipe: EntryPipe::Task(pipe), - flags: EntryFlags::EOF, - }; - self.tasks.push_back(item); - } - continue 'outer; - } - Ok(Async::NotReady) => {} - Err(err) => { - error!("Unhandled error: {}", err); - self.flags.insert(Flags::ERROR); - return; + if !ready { + // task is done with io operations + // but still needs to do more work + spawn(HttpHandlerTaskFut::new(task)); } + continue 'outer; + } + Ok(Async::NotReady) => (), + Err(err) => { + error!("Unhandled error: {}", err); + self.client_disconnected(false); + return Err(err.into()); } } - self.tasks.push_back(Entry { - pipe: EntryPipe::Task(pipe), - flags: EntryFlags::empty(), - }); - continue 'outer; } - Err(msg) => msg, + self.tasks.push_back(Entry::Task(task)); + continue 'outer; + } + Err(_) => { + // handler is not found + self.push_response_entry(StatusCode::NOT_FOUND); } } - - // handler is not found - self.tasks.push_back(Entry { - pipe: EntryPipe::Error(ServerError::err( - Version::HTTP_11, - StatusCode::NOT_FOUND, - )), - flags: EntryFlags::empty(), - }); } Ok(Some(Message::Chunk(chunk))) => { + updated = true; if let Some(ref mut payload) = self.payload { payload.feed_data(chunk); } else { error!("Internal server error: unexpected payload chunk"); - self.flags.insert(Flags::ERROR); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); break; } } Ok(Some(Message::Eof)) => { + updated = true; if let Some(mut payload) = self.payload.take() { payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); - self.flags.insert(Flags::ERROR); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); + self.error = Some(HttpDispatchError::InternalError); break; } } - Ok(None) => break, + Ok(None) => { + if self.flags.contains(Flags::READ_DISCONNECTED) { + self.client_disconnected(true); + } + break; + } Err(e) => { - self.flags.insert(Flags::ERROR); if let Some(mut payload) = self.payload.take() { let e = match e { DecoderError::Io(e) => PayloadError::Io(e), @@ -454,10 +557,22 @@ where }; payload.set_error(e); } + + // Malformed requests should be responded with 400 + self.push_response_entry(StatusCode::BAD_REQUEST); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); + self.error = Some(HttpDispatchError::MalformedRequest); break; } } } + + if self.ka_timer.is_some() && updated { + if let Some(expire) = self.settings.keep_alive_expire() { + self.ka_expire = expire; + } + } + Ok(updated) } } @@ -466,17 +581,30 @@ mod tests { use std::net::Shutdown; use std::{cmp, io, time}; + use actix::System; use bytes::{Buf, Bytes, BytesMut}; + use futures::future; use http::{Method, Version}; use tokio_io::{AsyncRead, AsyncWrite}; use super::*; - use application::HttpApplication; + use application::{App, HttpApplication}; use httpmessage::HttpMessage; use server::h1decoder::Message; - use server::settings::{ServerSettings, WorkerSettings}; + use server::handler::IntoHttpHandler; + use server::settings::{ServerSettings, ServiceConfig}; use server::{KeepAlive, Request}; + fn wrk_settings() -> ServiceConfig { + ServiceConfig::::new( + App::new().into_handler(), + KeepAlive::Os, + 5000, + 2000, + ServerSettings::default(), + ) + } + impl Message { fn message(self) -> Request { match self { @@ -506,8 +634,7 @@ mod tests { macro_rules! parse_ready { ($e:expr) => {{ - let settings: WorkerSettings = - WorkerSettings::new(Vec::new(), KeepAlive::Os, ServerSettings::default()); + let settings = wrk_settings(); match H1Decoder::new().decode($e, &settings) { Ok(Some(msg)) => msg.message(), Ok(_) => unreachable!("Eof during parsing http request"), @@ -518,8 +645,7 @@ mod tests { macro_rules! expect_parse_err { ($e:expr) => {{ - let settings: WorkerSettings = - WorkerSettings::new(Vec::new(), KeepAlive::Os, ServerSettings::default()); + let settings = wrk_settings(); match H1Decoder::new().decode($e, &settings) { Err(err) => match err { @@ -573,6 +699,9 @@ mod tests { fn set_linger(&mut self, _: Option) -> io::Result<()> { Ok(()) } + fn set_keepalive(&mut self, _: Option) -> io::Result<()> { + Ok(()) + } } impl io::Write for Buffer { fn write(&mut self, buf: &[u8]) -> io::Result { @@ -591,46 +720,28 @@ mod tests { } } - #[test] - fn test_req_parse() { - let buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); - let readbuf = BytesMut::new(); - let settings = Rc::new(WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - )); - - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); - h1.poll_io(); - h1.poll_io(); - assert_eq!(h1.tasks.len(), 1); - } - #[test] fn test_req_parse_err() { - let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); - let readbuf = BytesMut::new(); - let settings = Rc::new(WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - )); + let mut sys = System::new("test"); + let _ = sys.block_on(future::lazy(|| { + let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); + let readbuf = BytesMut::new(); + let settings = wrk_settings(); - let mut h1 = Http1::new(Rc::clone(&settings), buf, None, readbuf); - h1.poll_io(); - h1.poll_io(); - assert!(h1.flags.contains(Flags::ERROR)); + let mut h1 = + Http1Dispatcher::new(settings.clone(), buf, readbuf, false, None); + assert!(h1.poll_io().is_ok()); + assert!(h1.poll_io().is_ok()); + assert!(h1.flags.contains(Flags::READ_DISCONNECTED)); + assert_eq!(h1.tasks.len(), 1); + future::ok::<_, ()>(()) + })); } #[test] fn test_parse() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); match reader.decode(&mut buf, &settings) { @@ -647,11 +758,7 @@ mod tests { #[test] fn test_parse_partial() { let mut buf = BytesMut::from("PUT /test HTTP/1"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); match reader.decode(&mut buf, &settings) { @@ -674,11 +781,7 @@ mod tests { #[test] fn test_parse_post() { let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); match reader.decode(&mut buf, &settings) { @@ -696,11 +799,7 @@ mod tests { fn test_parse_body() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); match reader.decode(&mut buf, &settings) { @@ -727,11 +826,7 @@ mod tests { fn test_parse_body_crlf() { let mut buf = BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); match reader.decode(&mut buf, &settings) { @@ -757,11 +852,7 @@ mod tests { #[test] fn test_parse_partial_eof() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); assert!(reader.decode(&mut buf, &settings).unwrap().is_none()); @@ -780,11 +871,7 @@ mod tests { #[test] fn test_headers_split_field() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); assert!{ reader.decode(&mut buf, &settings).unwrap().is_none() } @@ -815,11 +902,7 @@ mod tests { Set-Cookie: c1=cookie1\r\n\ Set-Cookie: c2=cookie2\r\n\r\n", ); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); let req = msg.message(); @@ -1015,11 +1098,7 @@ mod tests { #[test] fn test_http_request_upgrade() { - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut buf = BytesMut::from( "GET /test HTTP/1.1\r\n\ connection: upgrade\r\n\ @@ -1085,12 +1164,7 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); - + let settings = wrk_settings(); let mut reader = H1Decoder::new(); let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); assert!(msg.is_payload()); @@ -1125,11 +1199,7 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); assert!(msg.is_payload()); @@ -1163,11 +1233,7 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); @@ -1214,11 +1280,7 @@ mod tests { &"GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"[..], ); - let settings = WorkerSettings::::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); + let settings = wrk_settings(); let mut reader = H1Decoder::new(); let msg = reader.decode(&mut buf, &settings).unwrap().unwrap(); diff --git a/src/server/h1decoder.rs b/src/server/h1decoder.rs index d1948a0d1..10f7e68a0 100644 --- a/src/server/h1decoder.rs +++ b/src/server/h1decoder.rs @@ -5,7 +5,7 @@ use futures::{Async, Poll}; use httparse; use super::message::{MessageFlags, Request}; -use super::settings::WorkerSettings; +use super::settings::ServiceConfig; use error::ParseError; use http::header::{HeaderName, HeaderValue}; use http::{header, HttpTryFrom, Method, Uri, Version}; @@ -18,6 +18,7 @@ pub(crate) struct H1Decoder { decoder: Option, } +#[derive(Debug)] pub(crate) enum Message { Message { msg: Request, payload: bool }, Chunk(Bytes), @@ -42,7 +43,9 @@ impl H1Decoder { } pub fn decode( - &mut self, src: &mut BytesMut, settings: &WorkerSettings, + &mut self, + src: &mut BytesMut, + settings: &ServiceConfig, ) -> Result, DecoderError> { // read payload if self.decoder.is_some() { @@ -79,7 +82,9 @@ impl H1Decoder { } fn parse_message( - &self, buf: &mut BytesMut, settings: &WorkerSettings, + &self, + buf: &mut BytesMut, + settings: &ServiceConfig, ) -> Poll<(Request, Option), ParseError> { // Parse http message let mut has_upgrade = false; @@ -166,9 +171,9 @@ impl H1Decoder { { true } else { - version == Version::HTTP_11 - && !(conn.contains("close") - || conn.contains("upgrade")) + version == Version::HTTP_11 && !(conn + .contains("close") + || conn.contains("upgrade")) } } else { false @@ -177,6 +182,13 @@ impl H1Decoder { } header::UPGRADE => { has_upgrade = true; + // check content-length, some clients (dart) + // sends "content-length: 0" with websocket upgrade + if let Ok(val) = value.to_str() { + if val == "websocket" { + content_length = None; + } + } } _ => (), } @@ -220,7 +232,9 @@ pub(crate) struct HeaderIndex { impl HeaderIndex { pub(crate) fn record( - bytes: &[u8], headers: &[httparse::Header], indices: &mut [HeaderIndex], + bytes: &[u8], + headers: &[httparse::Header], + indices: &mut [HeaderIndex], ) { let bytes_ptr = bytes.as_ptr() as usize; for (header, indices) in headers.iter().zip(indices.iter_mut()) { @@ -368,7 +382,10 @@ macro_rules! byte ( impl ChunkedState { fn step( - &self, body: &mut BytesMut, size: &mut u64, buf: &mut Option, + &self, + body: &mut BytesMut, + size: &mut u64, + buf: &mut Option, ) -> Poll { use self::ChunkedState::*; match *self { @@ -431,7 +448,8 @@ impl ChunkedState { } } fn read_size_lf( - rdr: &mut BytesMut, size: &mut u64, + rdr: &mut BytesMut, + size: &mut u64, ) -> Poll { match byte!(rdr) { b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), @@ -444,7 +462,9 @@ impl ChunkedState { } fn read_body( - rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option, + rdr: &mut BytesMut, + rem: &mut u64, + buf: &mut Option, ) -> Poll { trace!("Chunked read, remaining={:?}", rem); diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index e8f172f40..97ce6dff9 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -1,7 +1,6 @@ // #![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] use std::io::{self, Write}; -use std::rc::Rc; use bytes::{BufMut, BytesMut}; use futures::{Async, Poll}; @@ -9,7 +8,7 @@ use tokio_io::AsyncWrite; use super::helpers; use super::output::{Output, ResponseInfo, ResponseLength}; -use super::settings::WorkerSettings; +use super::settings::ServiceConfig; use super::Request; use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; use body::{Binary, Body}; @@ -38,11 +37,11 @@ pub(crate) struct H1Writer { headers_size: u32, buffer: Output, buffer_capacity: usize, - settings: Rc>, + settings: ServiceConfig, } impl H1Writer { - pub fn new(stream: T, settings: Rc>) -> H1Writer { + pub fn new(stream: T, settings: ServiceConfig) -> H1Writer { H1Writer { flags: Flags::KEEPALIVE, written: 0, @@ -63,7 +62,17 @@ impl H1Writer { self.flags = Flags::KEEPALIVE; } - pub fn disconnected(&mut self) {} + pub fn flushed(&mut self) -> bool { + self.buffer.is_empty() + } + + pub fn disconnected(&mut self) { + self.flags.insert(Flags::DISCONNECTED); + } + + pub fn upgrade(&self) -> bool { + self.flags.contains(Flags::UPGRADE) + } pub fn keepalive(&self) -> bool { self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) @@ -152,8 +161,7 @@ impl Writer for H1Writer { let reason = msg.reason().as_bytes(); if let Body::Binary(ref bytes) = body { buffer.reserve( - 256 - + msg.headers().len() * AVERAGE_HEADER_SIZE + 256 + msg.headers().len() * AVERAGE_HEADER_SIZE + bytes.len() + reason.len(), ); @@ -168,13 +176,11 @@ impl Writer for H1Writer { buffer.extend_from_slice(reason); // content length + let mut len_is_set = true; match info.length { ResponseLength::Chunked => { buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") } - ResponseLength::Zero => { - buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n") - } ResponseLength::Length(len) => { helpers::write_content_length(len, &mut buffer) } @@ -183,6 +189,10 @@ impl Writer for H1Writer { write!(buffer.writer(), "{}", len)?; buffer.extend_from_slice(b"\r\n"); } + ResponseLength::Zero => { + len_is_set = false; + buffer.extend_from_slice(b"\r\n"); + } ResponseLength::None => buffer.extend_from_slice(b"\r\n"), } if let Some(ce) = info.content_encoding { @@ -195,47 +205,57 @@ impl Writer for H1Writer { let mut pos = 0; let mut has_date = false; let mut remaining = buffer.remaining_mut(); - unsafe { - let mut buf = &mut *(buffer.bytes_mut() as *mut [u8]); - for (key, value) in msg.headers() { - match *key { - TRANSFER_ENCODING => continue, - CONTENT_ENCODING => if encoding != ContentEncoding::Identity { - continue; - }, - CONTENT_LENGTH => match info.length { - ResponseLength::None => (), - _ => continue, - }, - DATE => { - has_date = true; + let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) }; + for (key, value) in msg.headers() { + match *key { + TRANSFER_ENCODING => continue, + CONTENT_ENCODING => if encoding != ContentEncoding::Identity { + continue; + }, + CONTENT_LENGTH => match info.length { + ResponseLength::None => (), + ResponseLength::Zero => { + len_is_set = true; } - _ => (), + _ => continue, + }, + DATE => { + has_date = true; } + _ => (), + } - let v = value.as_ref(); - let k = key.as_str().as_bytes(); - let len = k.len() + v.len() + 4; - if len > remaining { + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { buffer.advance_mut(pos); - pos = 0; - buffer.reserve(len); - remaining = buffer.remaining_mut(); + } + pos = 0; + buffer.reserve(len); + remaining = buffer.remaining_mut(); + unsafe { buf = &mut *(buffer.bytes_mut() as *mut _); } - - buf[pos..pos + k.len()].copy_from_slice(k); - pos += k.len(); - buf[pos..pos + 2].copy_from_slice(b": "); - pos += 2; - buf[pos..pos + v.len()].copy_from_slice(v); - pos += v.len(); - buf[pos..pos + 2].copy_from_slice(b"\r\n"); - pos += 2; - remaining -= len; } + + buf[pos..pos + k.len()].copy_from_slice(k); + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + unsafe { buffer.advance_mut(pos); } + if !len_is_set { + buffer.extend_from_slice(b"content-length: 0\r\n") + } // optimized date header, set_date writes \r\n if !has_date { @@ -269,10 +289,7 @@ impl Writer for H1Writer { let pl: &[u8] = payload.as_ref(); let n = match Self::write_data(&mut self.stream, pl) { Err(err) => { - if err.kind() == io::ErrorKind::WriteZero { - self.disconnected(); - } - + self.disconnected(); return Err(err); } Ok(val) => val, @@ -316,14 +333,15 @@ impl Writer for H1Writer { #[inline] fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { + if self.flags.contains(Flags::DISCONNECTED) { + return Err(io::Error::new(io::ErrorKind::Other, "disconnected")); + } + if !self.buffer.is_empty() { let written = { match Self::write_data(&mut self.stream, self.buffer.as_ref().as_ref()) { Err(err) => { - if err.kind() == io::ErrorKind::WriteZero { - self.disconnected(); - } - + self.disconnected(); return Err(err); } Ok(val) => val, @@ -337,9 +355,10 @@ impl Writer for H1Writer { } } if shutdown { + self.stream.poll_flush()?; self.stream.shutdown() } else { - Ok(Async::Ready(())) + Ok(self.stream.poll_flush()?) } } } diff --git a/src/server/h2.rs b/src/server/h2.rs index 2322f755a..c9e968a39 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use std::io::{Read, Write}; use std::net::SocketAddr; use std::rc::Rc; -use std::time::{Duration, Instant}; +use std::time::Instant; use std::{cmp, io, mem}; use bytes::{Buf, Bytes}; @@ -14,19 +14,21 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; use error::{Error, PayloadError}; +use extensions::Extensions; use http::{StatusCode, Version}; use payload::{Payload, PayloadStatus, PayloadWriter}; use uri::Url; -use super::error::ServerError; +use super::error::{HttpDispatchError, ServerError}; use super::h2writer::H2Writer; use super::input::PayloadType; -use super::settings::WorkerSettings; -use super::{HttpHandler, HttpHandlerTask, Writer}; +use super::settings::ServiceConfig; +use super::{HttpHandler, HttpHandlerTask, IoStream, Writer}; bitflags! { struct Flags: u8 { - const DISCONNECTED = 0b0000_0010; + const DISCONNECTED = 0b0000_0001; + const SHUTDOWN = 0b0000_0010; } } @@ -37,11 +39,13 @@ where H: HttpHandler + 'static, { flags: Flags, - settings: Rc>, + settings: ServiceConfig, addr: Option, state: State>, tasks: VecDeque>, - keepalive_timer: Option, + extensions: Option>, + ka_expire: Instant, + ka_timer: Option, } enum State { @@ -52,12 +56,27 @@ enum State { impl Http2 where - T: AsyncRead + AsyncWrite + 'static, + T: IoStream + 'static, H: HttpHandler + 'static, { pub fn new( - settings: Rc>, io: T, addr: Option, buf: Bytes, + settings: ServiceConfig, + io: T, + buf: Bytes, + keepalive_timer: Option, ) -> Self { + let addr = io.peer_addr(); + let extensions = io.extensions(); + + // keep-alive timeout + let (ka_expire, ka_timer) = if let Some(delay) = keepalive_timer { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = settings.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (settings.now(), None) + }; + Http2 { flags: Flags::empty(), tasks: VecDeque::new(), @@ -65,84 +84,84 @@ where unread: if buf.is_empty() { None } else { Some(buf) }, inner: io, })), - keepalive_timer: None, addr, settings, + extensions, + ka_expire, + ka_timer, } } - pub(crate) fn shutdown(&mut self) { - self.state = State::Empty; - self.tasks.clear(); - self.keepalive_timer.take(); - } + pub fn poll(&mut self) -> Poll<(), HttpDispatchError> { + self.poll_keepalive()?; - pub fn settings(&self) -> &WorkerSettings { - self.settings.as_ref() - } - - pub fn poll(&mut self) -> Poll<(), ()> { // server if let State::Connection(ref mut conn) = self.state { - // keep-alive timer - if let Some(ref mut timeout) = self.keepalive_timer { - match timeout.poll() { - Ok(Async::Ready(_)) => { - trace!("Keep-alive timeout, close connection"); - return Ok(Async::Ready(())); - } - Ok(Async::NotReady) => (), - Err(_) => unreachable!(), - } - } - loop { + // shutdown connection + if self.flags.contains(Flags::SHUTDOWN) { + return conn.poll_close().map_err(|e| e.into()); + } + let mut not_ready = true; + let disconnected = self.flags.contains(Flags::DISCONNECTED); // check in-flight connections for item in &mut self.tasks { // read payload - item.poll_payload(); + if !disconnected { + item.poll_payload(); + } if !item.flags.contains(EntryFlags::EOF) { - let retry = item.payload.need_read() == PayloadStatus::Read; - loop { - match item.task.poll_io(&mut item.stream) { - Ok(Async::Ready(ready)) => { - if ready { + if disconnected { + item.flags.insert(EntryFlags::EOF); + } else { + let retry = item.payload.need_read() == PayloadStatus::Read; + loop { + match item.task.poll_io(&mut item.stream) { + Ok(Async::Ready(ready)) => { + if ready { + item.flags.insert( + EntryFlags::EOF | EntryFlags::FINISHED, + ); + } else { + item.flags.insert(EntryFlags::EOF); + } + not_ready = false; + } + Ok(Async::NotReady) => { + if item.payload.need_read() + == PayloadStatus::Read + && !retry + { + continue; + } + } + Err(err) => { + error!("Unhandled error: {}", err); item.flags.insert( - EntryFlags::EOF | EntryFlags::FINISHED, + EntryFlags::EOF + | EntryFlags::ERROR + | EntryFlags::WRITE_DONE, ); - } else { - item.flags.insert(EntryFlags::EOF); - } - not_ready = false; - } - Ok(Async::NotReady) => { - if item.payload.need_read() == PayloadStatus::Read - && !retry - { - continue; + item.stream.reset(Reason::INTERNAL_ERROR); } } - Err(err) => { - error!("Unhandled error: {}", err); - item.flags.insert( - EntryFlags::EOF - | EntryFlags::ERROR - | EntryFlags::WRITE_DONE, - ); - item.stream.reset(Reason::INTERNAL_ERROR); - } + break; } - break; } - } else if !item.flags.contains(EntryFlags::FINISHED) { + } + + if item.flags.contains(EntryFlags::EOF) + && !item.flags.contains(EntryFlags::FINISHED) + { match item.task.poll_completed() { Ok(Async::NotReady) => (), Ok(Async::Ready(_)) => { - not_ready = false; - item.flags.insert(EntryFlags::FINISHED); + item.flags.insert( + EntryFlags::FINISHED | EntryFlags::WRITE_DONE, + ); } Err(err) => { item.flags.insert( @@ -155,14 +174,17 @@ where } } - if !item.flags.contains(EntryFlags::WRITE_DONE) { + if item.flags.contains(EntryFlags::FINISHED) + && !item.flags.contains(EntryFlags::WRITE_DONE) + && !disconnected + { match item.stream.poll_completed(false) { Ok(Async::NotReady) => (), Ok(Async::Ready(_)) => { not_ready = false; item.flags.insert(EntryFlags::WRITE_DONE); } - Err(_err) => { + Err(_) => { item.flags.insert(EntryFlags::ERROR); } } @@ -171,7 +193,7 @@ where // cleanup finished tasks while !self.tasks.is_empty() { - if self.tasks[0].flags.contains(EntryFlags::EOF) + if self.tasks[0].flags.contains(EntryFlags::FINISHED) && self.tasks[0].flags.contains(EntryFlags::WRITE_DONE) || self.tasks[0].flags.contains(EntryFlags::ERROR) { @@ -195,50 +217,30 @@ where not_ready = false; let (parts, body) = req.into_parts(); - // stop keepalive timer - self.keepalive_timer.take(); + // update keep-alive expire + if self.ka_timer.is_some() { + if let Some(expire) = self.settings.keep_alive_expire() { + self.ka_expire = expire; + } + } self.tasks.push_back(Entry::new( parts, body, resp, self.addr, - &self.settings, + self.settings.clone(), + self.extensions.clone(), )); } - Ok(Async::NotReady) => { - // start keep-alive timer - if self.tasks.is_empty() { - if self.settings.keep_alive_enabled() { - let keep_alive = self.settings.keep_alive(); - if keep_alive > 0 && self.keepalive_timer.is_none() { - trace!("Start keep-alive timer"); - let mut timeout = Delay::new( - Instant::now() - + Duration::new(keep_alive, 0), - ); - // register timeout - let _ = timeout.poll(); - self.keepalive_timer = Some(timeout); - } - } else { - // keep-alive disable, drop connection - return conn.poll_close().map_err(|e| { - error!("Error during connection close: {}", e) - }); - } - } else { - // keep-alive unset, rely on operating system - return Ok(Async::NotReady); - } - } + Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { trace!("Connection error: {}", err); - self.flags.insert(Flags::DISCONNECTED); + self.flags.insert(Flags::SHUTDOWN); for entry in &mut self.tasks { entry.task.disconnected() } - self.keepalive_timer.take(); + continue; } } } @@ -246,9 +248,7 @@ where if not_ready { if self.tasks.is_empty() && self.flags.contains(Flags::DISCONNECTED) { - return conn - .poll_close() - .map_err(|e| error!("Error during connection close: {}", e)); + return conn.poll_close().map_err(|e| e.into()); } else { return Ok(Async::NotReady); } @@ -263,7 +263,7 @@ where Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { trace!("Error handling connection: {}", err); - return Err(()); + return Err(err.into()); } } } else { @@ -272,6 +272,39 @@ where self.poll() } + + /// keep-alive timer. returns `true` is keep-alive, otherwise drop + fn poll_keepalive(&mut self) -> Result<(), HttpDispatchError> { + if let Some(ref mut timer) = self.ka_timer { + match timer.poll() { + Ok(Async::Ready(_)) => { + // if we get timer during shutdown, just drop connection + if self.flags.contains(Flags::SHUTDOWN) { + return Err(HttpDispatchError::ShutdownTimeout); + } + if timer.deadline() >= self.ka_expire { + // check for any outstanding request handling + if self.tasks.is_empty() { + return Err(HttpDispatchError::ShutdownTimeout); + } else if let Some(dl) = self.settings.keep_alive_expire() { + timer.reset(dl); + let _ = timer.poll(); + } + } else { + timer.reset(self.ka_expire); + let _ = timer.poll(); + } + } + Ok(Async::NotReady) => (), + Err(e) => { + error!("Timer error {:?}", e); + return Err(HttpDispatchError::Unknown); + } + } + } + + Ok(()) + } } bitflags! { @@ -320,8 +353,12 @@ struct Entry { impl Entry { fn new( - parts: Parts, recv: RecvStream, resp: SendResponse, - addr: Option, settings: &Rc>, + parts: Parts, + recv: RecvStream, + resp: SendResponse, + addr: Option, + settings: ServiceConfig, + extensions: Option>, ) -> Entry where H: HttpHandler + 'static, @@ -336,6 +373,7 @@ impl Entry { inner.method = parts.method; inner.version = parts.version; inner.headers = parts.headers; + inner.stream_extensions = extensions; *inner.payload.borrow_mut() = Some(payload); inner.addr = addr; } @@ -344,28 +382,20 @@ impl Entry { let psender = PayloadType::new(msg.headers(), psender); // start request processing - let mut task = None; - for h in settings.handlers().iter_mut() { - msg = match h.handle(msg) { - Ok(t) => { - task = Some(t); - break; - } - Err(msg) => msg, - } - } + let task = match settings.handler().handle(msg) { + Ok(task) => EntryPipe::Task(task), + Err(_) => EntryPipe::Error(ServerError::err( + Version::HTTP_2, + StatusCode::NOT_FOUND, + )), + }; Entry { - task: task.map(EntryPipe::Task).unwrap_or_else(|| { - EntryPipe::Error(ServerError::err( - Version::HTTP_2, - StatusCode::NOT_FOUND, - )) - }), - payload: psender, - stream: H2Writer::new(resp, Rc::clone(settings)), - flags: EntryFlags::empty(), + task, recv, + payload: psender, + stream: H2Writer::new(resp, settings), + flags: EntryFlags::empty(), } } diff --git a/src/server/h2writer.rs b/src/server/h2writer.rs index c4fc59972..fef6f889a 100644 --- a/src/server/h2writer.rs +++ b/src/server/h2writer.rs @@ -1,23 +1,27 @@ -#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] +#![cfg_attr( + feature = "cargo-clippy", + allow(redundant_field_names) +)] + +use std::{cmp, io}; use bytes::{Bytes, BytesMut}; use futures::{Async, Poll}; use http2::server::SendResponse; use http2::{Reason, SendStream}; use modhttp::Response; -use std::rc::Rc; -use std::{cmp, io}; - -use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use http::{HttpTryFrom, Method, Version}; use super::helpers; use super::message::Request; -use super::output::{Output, ResponseInfo}; -use super::settings::WorkerSettings; +use super::output::{Output, ResponseInfo, ResponseLength}; +use super::settings::ServiceConfig; use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; use body::{Binary, Body}; use header::ContentEncoding; +use http::header::{ + HeaderValue, CONNECTION, CONTENT_ENCODING, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use http::{HttpTryFrom, Method, Version}; use httpresponse::HttpResponse; const CHUNK_SIZE: usize = 16_384; @@ -38,13 +42,11 @@ pub(crate) struct H2Writer { written: u64, buffer: Output, buffer_capacity: usize, - settings: Rc>, + settings: ServiceConfig, } impl H2Writer { - pub fn new( - respond: SendResponse, settings: Rc>, - ) -> H2Writer { + pub fn new(respond: SendResponse, settings: ServiceConfig) -> H2Writer { H2Writer { stream: None, flags: Flags::empty(), @@ -92,50 +94,74 @@ impl Writer for H2Writer { let mut info = ResponseInfo::new(req.inner.method == Method::HEAD); self.buffer.for_server(&mut info, &req.inner, msg, encoding); - // http2 specific - msg.headers_mut().remove(CONNECTION); - msg.headers_mut().remove(TRANSFER_ENCODING); - - // using helpers::date is quite a lot faster - if !msg.headers().contains_key(DATE) { - let mut bytes = BytesMut::with_capacity(29); - self.settings.set_date(&mut bytes, false); - msg.headers_mut() - .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); - } - - let body = msg.replace_body(Body::Empty); - match body { - Body::Binary(ref bytes) => { - if bytes.is_empty() { - msg.headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - self.flags.insert(Flags::EOF); - } else { - let mut val = BytesMut::new(); - helpers::convert_usize(bytes.len(), &mut val); - let l = val.len(); - msg.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(val.split_to(l - 2).freeze()).unwrap(), - ); - } - } - Body::Empty => { - self.flags.insert(Flags::EOF); - msg.headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - } - _ => (), - } - + let mut has_date = false; let mut resp = Response::new(()); + let mut len_is_set = false; *resp.status_mut() = msg.status(); *resp.version_mut() = Version::HTTP_2; for (key, value) in msg.headers().iter() { - resp.headers_mut().insert(key, value.clone()); + match *key { + // http2 specific + CONNECTION | TRANSFER_ENCODING => continue, + CONTENT_ENCODING => if encoding != ContentEncoding::Identity { + continue; + }, + CONTENT_LENGTH => match info.length { + ResponseLength::None => (), + ResponseLength::Zero => { + len_is_set = true; + } + _ => continue, + }, + DATE => has_date = true, + _ => (), + } + resp.headers_mut().append(key, value.clone()); } + // set date header + if !has_date { + let mut bytes = BytesMut::with_capacity(29); + self.settings.set_date(&mut bytes, false); + resp.headers_mut() + .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + } + + // content length + match info.length { + ResponseLength::Zero => { + if !len_is_set { + resp.headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")); + } + self.flags.insert(Flags::EOF); + } + ResponseLength::Length(len) => { + let mut val = BytesMut::new(); + helpers::convert_usize(len, &mut val); + let l = val.len(); + resp.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(val.split_to(l - 2).freeze()).unwrap(), + ); + } + ResponseLength::Length64(len) => { + let l = format!("{}", len); + resp.headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::try_from(l.as_str()).unwrap()); + } + ResponseLength::None => { + self.flags.insert(Flags::EOF); + } + _ => (), + } + if let Some(ce) = info.content_encoding { + resp.headers_mut() + .insert(CONTENT_ENCODING, HeaderValue::try_from(ce).unwrap()); + } + + trace!("Response: {:?}", resp); + match self .respond .send_response(resp, self.flags.contains(Flags::EOF)) @@ -144,14 +170,12 @@ impl Writer for H2Writer { Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "err")), } - trace!("Response: {:?}", msg); - + let body = msg.replace_body(Body::Empty); if let Body::Binary(bytes) = body { if bytes.is_empty() { Ok(WriterState::Done) } else { self.flags.insert(Flags::EOF); - self.written = bytes.len() as u64; self.buffer.write(bytes.as_ref())?; if let Some(ref mut stream) = self.stream { self.flags.insert(Flags::RESERVED); @@ -167,8 +191,6 @@ impl Writer for H2Writer { } fn write(&mut self, payload: &Binary) -> io::Result { - self.written = payload.len() as u64; - if !self.flags.contains(Flags::DISCONNECTED) { if self.flags.contains(Flags::STARTED) { // TODO: add warning, write after EOF @@ -229,14 +251,18 @@ impl Writer for H2Writer { let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); stream.reserve_capacity(cap); } else { + if eof { + stream.reserve_capacity(0); + continue; + } self.flags.remove(Flags::RESERVED); - return Ok(Async::NotReady); + return Ok(Async::Ready(())); } } Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), } } } - Ok(Async::NotReady) + Ok(Async::Ready(())) } } diff --git a/src/server/handler.rs b/src/server/handler.rs new file mode 100644 index 000000000..33e50ac34 --- /dev/null +++ b/src/server/handler.rs @@ -0,0 +1,208 @@ +use futures::{Async, Future, Poll}; + +use super::message::Request; +use super::Writer; +use error::Error; + +/// Low level http request handler +#[allow(unused_variables)] +pub trait HttpHandler: 'static { + /// Request handling task + type Task: HttpHandlerTask; + + /// Handle request + fn handle(&self, req: Request) -> Result; +} + +impl HttpHandler for Box>> { + type Task = Box; + + fn handle(&self, req: Request) -> Result, Request> { + self.as_ref().handle(req) + } +} + +/// Low level http request handler +pub trait HttpHandlerTask { + /// Poll task, this method is used before or after *io* object is available + fn poll_completed(&mut self) -> Poll<(), Error> { + Ok(Async::Ready(())) + } + + /// Poll task when *io* object is available + fn poll_io(&mut self, io: &mut Writer) -> Poll; + + /// Connection is disconnected + fn disconnected(&mut self) {} +} + +impl HttpHandlerTask for Box { + fn poll_io(&mut self, io: &mut Writer) -> Poll { + self.as_mut().poll_io(io) + } +} + +pub(super) struct HttpHandlerTaskFut { + task: T, +} + +impl HttpHandlerTaskFut { + pub(crate) fn new(task: T) -> Self { + Self { task } + } +} + +impl Future for HttpHandlerTaskFut { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + self.task.poll_completed().map_err(|_| ()) + } +} + +/// Conversion helper trait +pub trait IntoHttpHandler { + /// The associated type which is result of conversion. + type Handler: HttpHandler; + + /// Convert into `HttpHandler` object. + fn into_handler(self) -> Self::Handler; +} + +impl IntoHttpHandler for T { + type Handler = T; + + fn into_handler(self) -> Self::Handler { + self + } +} + +impl IntoHttpHandler for Vec { + type Handler = VecHttpHandler; + + fn into_handler(self) -> Self::Handler { + VecHttpHandler(self.into_iter().map(|item| item.into_handler()).collect()) + } +} + +#[doc(hidden)] +pub struct VecHttpHandler(Vec); + +impl HttpHandler for VecHttpHandler { + type Task = H::Task; + + fn handle(&self, mut req: Request) -> Result { + for h in &self.0 { + req = match h.handle(req) { + Ok(task) => return Ok(task), + Err(e) => e, + }; + } + Err(req) + } +} + +macro_rules! http_handler ({$EN:ident, $(($n:tt, $T:ident)),+} => { + impl<$($T: HttpHandler,)+> HttpHandler for ($($T,)+) { + type Task = $EN<$($T,)+>; + + fn handle(&self, mut req: Request) -> Result { + $( + req = match self.$n.handle(req) { + Ok(task) => return Ok($EN::$T(task)), + Err(e) => e, + }; + )+ + Err(req) + } + } + + #[doc(hidden)] + pub enum $EN<$($T: HttpHandler,)+> { + $($T ($T::Task),)+ + } + + impl<$($T: HttpHandler,)+> HttpHandlerTask for $EN<$($T,)+> + { + fn poll_completed(&mut self) -> Poll<(), Error> { + match self { + $($EN :: $T(ref mut task) => task.poll_completed(),)+ + } + } + + fn poll_io(&mut self, io: &mut Writer) -> Poll { + match self { + $($EN::$T(ref mut task) => task.poll_io(io),)+ + } + } + + /// Connection is disconnected + fn disconnected(&mut self) { + match self { + $($EN::$T(ref mut task) => task.disconnected(),)+ + } + } + } +}); + +http_handler!(HttpHandlerTask1, (0, A)); +http_handler!(HttpHandlerTask2, (0, A), (1, B)); +http_handler!(HttpHandlerTask3, (0, A), (1, B), (2, C)); +http_handler!(HttpHandlerTask4, (0, A), (1, B), (2, C), (3, D)); +http_handler!(HttpHandlerTask5, (0, A), (1, B), (2, C), (3, D), (4, E)); +http_handler!( + HttpHandlerTask6, + (0, A), + (1, B), + (2, C), + (3, D), + (4, E), + (5, F) +); +http_handler!( + HttpHandlerTask7, + (0, A), + (1, B), + (2, C), + (3, D), + (4, E), + (5, F), + (6, G) +); +http_handler!( + HttpHandlerTask8, + (0, A), + (1, B), + (2, C), + (3, D), + (4, E), + (5, F), + (6, G), + (7, H) +); +http_handler!( + HttpHandlerTask9, + (0, A), + (1, B), + (2, C), + (3, D), + (4, E), + (5, F), + (6, G), + (7, H), + (8, I) +); +http_handler!( + HttpHandlerTask10, + (0, A), + (1, B), + (2, C), + (3, D), + (4, E), + (5, F), + (6, G), + (7, H), + (8, I), + (9, J) +); diff --git a/src/server/helpers.rs b/src/server/helpers.rs index 03bbc8310..e4ccd8aef 100644 --- a/src/server/helpers.rs +++ b/src/server/helpers.rs @@ -8,8 +8,10 @@ const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\ 6061626364656667686970717273747576777879\ 8081828384858687888990919293949596979899"; +pub(crate) const STATUS_LINE_BUF_SIZE: usize = 13; + pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) { - let mut buf: [u8; 13] = [ + let mut buf: [u8; STATUS_LINE_BUF_SIZE] = [ b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', b' ', b' ', b' ', b' ', b' ', ]; match version { @@ -27,20 +29,24 @@ pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesM let lut_ptr = DEC_DIGITS_LUT.as_ptr(); let four = n > 999; + // decode 2 more chars, if > 2 chars + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; unsafe { - // decode 2 more chars, if > 2 chars - let d1 = (n % 100) << 1; - n /= 100; - curr -= 2; ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); + } - // decode last 1 or 2 chars - if n < 10 { - curr -= 1; + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { *buf_ptr.offset(curr) = (n as u8) + b'0'; - } else { - let d1 = n << 1; - curr -= 2; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { ptr::copy_nonoverlapping( lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), @@ -72,7 +78,7 @@ pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { let d1 = n << 1; unsafe { ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), + DEC_DIGITS_LUT.as_ptr().add(d1), buf.as_mut_ptr().offset(18), 2, ); @@ -88,7 +94,7 @@ pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { n /= 100; unsafe { ptr::copy_nonoverlapping( - DEC_DIGITS_LUT.as_ptr().offset(d1 as isize), + DEC_DIGITS_LUT.as_ptr().add(d1), buf.as_mut_ptr().offset(19), 2, ) @@ -105,47 +111,55 @@ pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { } pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { - unsafe { - let mut curr: isize = 39; - let mut buf: [u8; 41] = mem::uninitialized(); - buf[39] = b'\r'; - buf[40] = b'\n'; - let buf_ptr = buf.as_mut_ptr(); - let lut_ptr = DEC_DIGITS_LUT.as_ptr(); + let mut curr: isize = 39; + let mut buf: [u8; 41] = unsafe { mem::uninitialized() }; + buf[39] = b'\r'; + buf[40] = b'\n'; + let buf_ptr = buf.as_mut_ptr(); + let lut_ptr = DEC_DIGITS_LUT.as_ptr(); - // eagerly decode 4 characters at a time - while n >= 10_000 { - let rem = (n % 10_000) as isize; - n /= 10_000; + // eagerly decode 4 characters at a time + while n >= 10_000 { + let rem = (n % 10_000) as isize; + n /= 10_000; - let d1 = (rem / 100) << 1; - let d2 = (rem % 100) << 1; - curr -= 4; + let d1 = (rem / 100) << 1; + let d2 = (rem % 100) << 1; + curr -= 4; + unsafe { ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2); } + } - // if we reach here numbers are <= 9999, so at most 4 chars long - let mut n = n as isize; // possibly reduce 64bit math + // if we reach here numbers are <= 9999, so at most 4 chars long + let mut n = n as isize; // possibly reduce 64bit math - // decode 2 more chars, if > 2 chars - if n >= 100 { - let d1 = (n % 100) << 1; - n /= 100; - curr -= 2; + // decode 2 more chars, if > 2 chars + if n >= 100 { + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; + unsafe { ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); } + } - // decode last 1 or 2 chars - if n < 10 { - curr -= 1; + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { *buf_ptr.offset(curr) = (n as u8) + b'0'; - } else { - let d1 = n << 1; - curr -= 2; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); } + } + unsafe { bytes.extend_from_slice(slice::from_raw_parts( buf_ptr.offset(curr), 41 - curr as usize, diff --git a/src/server/http.rs b/src/server/http.rs new file mode 100644 index 000000000..0bec8be3f --- /dev/null +++ b/src/server/http.rs @@ -0,0 +1,579 @@ +use std::{fmt, io, mem, net}; + +use actix::{Addr, System}; +use actix_net::server::Server; +use actix_net::service::NewService; +use actix_net::ssl; + +use net2::TcpBuilder; +use num_cpus; + +#[cfg(feature = "tls")] +use native_tls::TlsAcceptor; + +#[cfg(any(feature = "alpn", feature = "ssl"))] +use openssl::ssl::SslAcceptorBuilder; + +#[cfg(feature = "rust-tls")] +use rustls::ServerConfig; + +use super::acceptor::{AcceptorServiceFactory, DefaultAcceptor}; +use super::builder::{HttpServiceBuilder, ServiceProvider}; +use super::{IntoHttpHandler, KeepAlive}; + +struct Socket { + scheme: &'static str, + lst: net::TcpListener, + addr: net::SocketAddr, + handler: Box, +} + +/// An HTTP Server +/// +/// By default it serves HTTP2 when HTTPs is enabled, +/// in order to change it, use `ServerFlags` that can be provided +/// to acceptor service. +pub struct HttpServer +where + H: IntoHttpHandler + 'static, + F: Fn() -> H + Send + Clone, +{ + pub(super) factory: F, + pub(super) host: Option, + pub(super) keep_alive: KeepAlive, + pub(super) client_timeout: u64, + pub(super) client_shutdown: u64, + backlog: i32, + threads: usize, + exit: bool, + shutdown_timeout: u16, + no_http2: bool, + no_signals: bool, + maxconn: usize, + maxconnrate: usize, + sockets: Vec, +} + +impl HttpServer +where + H: IntoHttpHandler + 'static, + F: Fn() -> H + Send + Clone + 'static, +{ + /// Create new http server with application factory + pub fn new(factory: F) -> HttpServer { + HttpServer { + factory, + threads: num_cpus::get(), + host: None, + backlog: 2048, + keep_alive: KeepAlive::Timeout(5), + shutdown_timeout: 30, + exit: false, + no_http2: false, + no_signals: false, + maxconn: 25_600, + maxconnrate: 256, + client_timeout: 5000, + client_shutdown: 5000, + sockets: Vec::new(), + } + } + + /// Set number of workers to start. + /// + /// By default http server uses number of available logical cpu as threads + /// count. + pub fn workers(mut self, num: usize) -> Self { + self.threads = num; + self + } + + /// Set the maximum number of pending connections. + /// + /// This refers to the number of clients that can be waiting to be served. + /// Exceeding this number results in the client getting an error when + /// attempting to connect. It should only affect servers under significant + /// load. + /// + /// Generally set in the 64-2048 range. Default value is 2048. + /// + /// This method should be called before `bind()` method call. + pub fn backlog(mut self, num: i32) -> Self { + self.backlog = num; + self + } + + /// Sets the maximum per-worker number of concurrent connections. + /// + /// All socket listeners will stop accepting connections when this limit is reached + /// for each worker. + /// + /// By default max connections is set to a 25k. + pub fn maxconn(mut self, num: usize) -> Self { + self.maxconn = num; + self + } + + /// Sets the maximum per-worker concurrent connection establish process. + /// + /// All listeners will stop accepting connections when this limit is reached. It + /// can be used to limit the global SSL CPU usage. + /// + /// By default max connections is set to a 256. + pub fn maxconnrate(mut self, num: usize) -> Self { + self.maxconnrate = num; + self + } + + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(mut self, val: T) -> Self { + self.keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } + + /// Set server connection shutdown timeout in milliseconds. + /// + /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete + /// within this time, the request is dropped. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_shutdown(mut self, val: u64) -> Self { + self.client_shutdown = val; + self + } + + /// Set server host name. + /// + /// Host name is used by application router aa a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + pub fn server_hostname(mut self, val: String) -> Self { + self.host = Some(val); + self + } + + /// Stop actix system. + /// + /// `SystemExit` message stops currently running system. + pub fn system_exit(mut self) -> Self { + self.exit = true; + self + } + + /// Disable signal handling + pub fn disable_signals(mut self) -> Self { + self.no_signals = true; + self + } + + /// Timeout for graceful workers shutdown. + /// + /// After receiving a stop signal, workers have this much time to finish + /// serving requests. Workers still alive after the timeout are force + /// dropped. + /// + /// By default shutdown timeout sets to 30 seconds. + pub fn shutdown_timeout(mut self, sec: u16) -> Self { + self.shutdown_timeout = sec; + self + } + + /// Disable `HTTP/2` support + pub fn no_http2(mut self) -> Self { + self.no_http2 = true; + self + } + + /// Get addresses of bound sockets. + pub fn addrs(&self) -> Vec { + self.sockets.iter().map(|s| s.addr).collect() + } + + /// Get addresses of bound sockets and the scheme for it. + /// + /// This is useful when the server is bound from different sources + /// with some sockets listening on http and some listening on https + /// and the user should be presented with an enumeration of which + /// socket requires which protocol. + pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { + self.sockets.iter().map(|s| (s.addr, s.scheme)).collect() + } + + /// Use listener for accepting incoming connection requests + /// + /// HttpServer does not change any configuration for TcpListener, + /// it needs to be configured before passing it to listen() method. + pub fn listen(mut self, lst: net::TcpListener) -> Self { + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + lst, + addr, + scheme: "http", + handler: Box::new(HttpServiceBuilder::new( + self.factory.clone(), + DefaultAcceptor, + )), + }); + + self + } + + #[doc(hidden)] + /// Use listener for accepting incoming connection requests + pub fn listen_with(mut self, lst: net::TcpListener, acceptor: A) -> Self + where + A: AcceptorServiceFactory, + ::InitError: fmt::Debug, + { + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + lst, + addr, + scheme: "https", + handler: Box::new(HttpServiceBuilder::new(self.factory.clone(), acceptor)), + }); + + self + } + + #[cfg(feature = "tls")] + /// Use listener for accepting incoming tls connection requests + /// + /// HttpServer does not change any configuration for TcpListener, + /// it needs to be configured before passing it to listen() method. + pub fn listen_tls(self, lst: net::TcpListener, acceptor: TlsAcceptor) -> Self { + use actix_net::service::NewServiceExt; + + self.listen_with(lst, move || { + ssl::NativeTlsAcceptor::new(acceptor.clone()).map_err(|_| ()) + }) + } + + #[cfg(any(feature = "alpn", feature = "ssl"))] + /// Use listener for accepting incoming tls connection requests + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn listen_ssl( + self, lst: net::TcpListener, builder: SslAcceptorBuilder, + ) -> io::Result { + use super::{openssl_acceptor_with_flags, ServerFlags}; + use actix_net::service::NewServiceExt; + + let flags = if self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + let acceptor = openssl_acceptor_with_flags(builder, flags)?; + Ok(self.listen_with(lst, move || { + ssl::OpensslAcceptor::new(acceptor.clone()).map_err(|_| ()) + })) + } + + #[cfg(feature = "rust-tls")] + /// Use listener for accepting incoming tls connection requests + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn listen_rustls(self, lst: net::TcpListener, config: ServerConfig) -> Self { + use super::{RustlsAcceptor, ServerFlags}; + use actix_net::service::NewServiceExt; + + // alpn support + let flags = if self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.listen_with(lst, move || { + RustlsAcceptor::with_flags(config.clone(), flags).map_err(|_| ()) + }) + } + + /// The socket address to bind + /// + /// To bind multiple addresses this method can be called multiple times. + pub fn bind(mut self, addr: S) -> io::Result { + let sockets = self.bind2(addr)?; + + for lst in sockets { + self = self.listen(lst); + } + + Ok(self) + } + + /// Start listening for incoming connections with supplied acceptor. + #[doc(hidden)] + #[cfg_attr( + feature = "cargo-clippy", + allow(needless_pass_by_value) + )] + pub fn bind_with(mut self, addr: S, acceptor: A) -> io::Result + where + S: net::ToSocketAddrs, + A: AcceptorServiceFactory, + ::InitError: fmt::Debug, + { + let sockets = self.bind2(addr)?; + + for lst in sockets { + let addr = lst.local_addr().unwrap(); + self.sockets.push(Socket { + lst, + addr, + scheme: "https", + handler: Box::new(HttpServiceBuilder::new( + self.factory.clone(), + acceptor.clone(), + )), + }); + } + + Ok(self) + } + + fn bind2( + &self, addr: S, + ) -> io::Result> { + let mut err = None; + let mut succ = false; + let mut sockets = Vec::new(); + for addr in addr.to_socket_addrs()? { + match create_tcp_listener(addr, self.backlog) { + Ok(lst) => { + succ = true; + sockets.push(lst); + } + Err(e) => err = Some(e), + } + } + + if !succ { + if let Some(e) = err.take() { + Err(e) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Can not bind to address.", + )) + } + } else { + Ok(sockets) + } + } + + #[cfg(feature = "tls")] + /// The ssl socket address to bind + /// + /// To bind multiple addresses this method can be called multiple times. + pub fn bind_tls( + self, addr: S, acceptor: TlsAcceptor, + ) -> io::Result { + use actix_net::service::NewServiceExt; + use actix_net::ssl::NativeTlsAcceptor; + + self.bind_with(addr, move || { + NativeTlsAcceptor::new(acceptor.clone()).map_err(|_| ()) + }) + } + + #[cfg(any(feature = "alpn", feature = "ssl"))] + /// Start listening for incoming tls connections. + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn bind_ssl(self, addr: S, builder: SslAcceptorBuilder) -> io::Result + where + S: net::ToSocketAddrs, + { + use super::{openssl_acceptor_with_flags, ServerFlags}; + use actix_net::service::NewServiceExt; + + // alpn support + let flags = if self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + let acceptor = openssl_acceptor_with_flags(builder, flags)?; + self.bind_with(addr, move || { + ssl::OpensslAcceptor::new(acceptor.clone()).map_err(|_| ()) + }) + } + + #[cfg(feature = "rust-tls")] + /// Start listening for incoming tls connections. + /// + /// This method sets alpn protocols to "h2" and "http/1.1" + pub fn bind_rustls( + self, addr: S, builder: ServerConfig, + ) -> io::Result { + use super::{RustlsAcceptor, ServerFlags}; + use actix_net::service::NewServiceExt; + + // alpn support + let flags = if self.no_http2 { + ServerFlags::HTTP1 + } else { + ServerFlags::HTTP1 | ServerFlags::HTTP2 + }; + + self.bind_with(addr, move || { + RustlsAcceptor::with_flags(builder.clone(), flags).map_err(|_| ()) + }) + } +} + +impl H + Send + Clone> HttpServer { + /// Start listening for incoming connections. + /// + /// This method starts number of http workers in separate threads. + /// For each address this method starts separate thread which does + /// `accept()` in a loop. + /// + /// This methods panics if no socket addresses get bound. + /// + /// This method requires to run within properly configured `Actix` system. + /// + /// ```rust + /// extern crate actix_web; + /// use actix_web::{actix, server, App, HttpResponse}; + /// + /// fn main() { + /// let sys = actix::System::new("example"); // <- create Actix system + /// + /// server::new(|| App::new().resource("/", |r| r.h(|_: &_| HttpResponse::Ok()))) + /// .bind("127.0.0.1:0") + /// .expect("Can not bind to 127.0.0.1:0") + /// .start(); + /// # actix::System::current().stop(); + /// sys.run(); // <- Run actix system, this method starts all async processes + /// } + /// ``` + pub fn start(mut self) -> Addr { + ssl::max_concurrent_ssl_connect(self.maxconnrate); + + let mut srv = Server::new() + .workers(self.threads) + .maxconn(self.maxconn) + .shutdown_timeout(self.shutdown_timeout); + + srv = if self.exit { srv.system_exit() } else { srv }; + srv = if self.no_signals { + srv.disable_signals() + } else { + srv + }; + + let sockets = mem::replace(&mut self.sockets, Vec::new()); + + for socket in sockets { + let host = self + .host + .as_ref() + .map(|h| h.to_owned()) + .unwrap_or_else(|| format!("{}", socket.addr)); + let (secure, client_shutdown) = if socket.scheme == "https" { + (true, self.client_shutdown) + } else { + (false, 0) + }; + srv = socket.handler.register( + srv, + socket.lst, + host, + socket.addr, + self.keep_alive, + secure, + self.client_timeout, + client_shutdown, + ); + } + srv.start() + } + + /// Spawn new thread and start listening for incoming connections. + /// + /// This method spawns new thread and starts new actix system. Other than + /// that it is similar to `start()` method. This method blocks. + /// + /// This methods panics if no socket addresses get bound. + /// + /// ```rust,ignore + /// # extern crate futures; + /// # extern crate actix_web; + /// # use futures::Future; + /// use actix_web::*; + /// + /// fn main() { + /// HttpServer::new(|| App::new().resource("/", |r| r.h(|_| HttpResponse::Ok()))) + /// .bind("127.0.0.1:0") + /// .expect("Can not bind to 127.0.0.1:0") + /// .run(); + /// } + /// ``` + pub fn run(self) { + let sys = System::new("http-server"); + self.start(); + sys.run(); + } + + /// Register current http server as actix-net's server service + pub fn register(self, mut srv: Server) -> Server { + for socket in self.sockets { + let host = self + .host + .as_ref() + .map(|h| h.to_owned()) + .unwrap_or_else(|| format!("{}", socket.addr)); + let (secure, client_shutdown) = if socket.scheme == "https" { + (true, self.client_shutdown) + } else { + (false, 0) + }; + srv = socket.handler.register( + srv, + socket.lst, + host, + socket.addr, + self.keep_alive, + secure, + self.client_timeout, + client_shutdown, + ); + } + srv + } +} + +fn create_tcp_listener( + addr: net::SocketAddr, backlog: i32, +) -> io::Result { + let builder = match addr { + net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, + net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, + }; + builder.reuse_address(true)?; + builder.bind(addr)?; + Ok(builder.listen(backlog)?) +} diff --git a/src/server/incoming.rs b/src/server/incoming.rs new file mode 100644 index 000000000..b13bba2a7 --- /dev/null +++ b/src/server/incoming.rs @@ -0,0 +1,69 @@ +//! Support for `Stream`, deprecated! +use std::{io, net}; + +use actix::{Actor, Arbiter, AsyncContext, Context, Handler, Message}; +use futures::{Future, Stream}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use super::channel::{HttpChannel, WrapperStream}; +use super::handler::{HttpHandler, IntoHttpHandler}; +use super::http::HttpServer; +use super::settings::{ServerSettings, ServiceConfig}; + +impl Message for WrapperStream { + type Result = (); +} + +impl HttpServer +where + H: IntoHttpHandler, + F: Fn() -> H + Send + Clone, +{ + #[doc(hidden)] + #[deprecated(since = "0.7.8")] + /// Start listening for incoming connections from a stream. + /// + /// This method uses only one thread for handling incoming connections. + pub fn start_incoming(self, stream: S, secure: bool) + where + S: Stream + 'static, + T: AsyncRead + AsyncWrite + 'static, + { + // set server settings + let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap(); + let apps = (self.factory)().into_handler(); + let settings = ServiceConfig::new( + apps, + self.keep_alive, + self.client_timeout, + self.client_shutdown, + ServerSettings::new(addr, "127.0.0.1:8080", secure), + ); + + // start server + HttpIncoming::create(move |ctx| { + ctx.add_message_stream(stream.map_err(|_| ()).map(WrapperStream::new)); + HttpIncoming { settings } + }); + } +} + +struct HttpIncoming { + settings: ServiceConfig, +} + +impl Actor for HttpIncoming { + type Context = Context; +} + +impl Handler> for HttpIncoming +where + T: AsyncRead + AsyncWrite, + H: HttpHandler, +{ + type Result = (); + + fn handle(&mut self, msg: WrapperStream, _: &mut Context) -> Self::Result { + Arbiter::spawn(HttpChannel::new(self.settings.clone(), msg).map_err(|_| ())); + } +} diff --git a/src/server/input.rs b/src/server/input.rs index 8c11c2463..d23d1e991 100644 --- a/src/server/input.rs +++ b/src/server/input.rs @@ -1,14 +1,11 @@ -use std::io::{Read, Write}; -use std::{cmp, io}; +use std::io::{self, Write}; #[cfg(feature = "brotli")] use brotli2::write::BrotliDecoder; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use error::PayloadError; #[cfg(feature = "flate2")] -use flate2::read::GzDecoder; -#[cfg(feature = "flate2")] -use flate2::write::DeflateDecoder; +use flate2::write::{GzDecoder, ZlibDecoder}; use header::ContentEncoding; use http::header::{HeaderMap, CONTENT_ENCODING}; use payload::{PayloadSender, PayloadStatus, PayloadWriter}; @@ -142,48 +139,14 @@ impl PayloadWriter for EncodedPayload { pub(crate) enum Decoder { #[cfg(feature = "flate2")] - Deflate(Box>), + Deflate(Box>), #[cfg(feature = "flate2")] - Gzip(Option>>), + Gzip(Box>), #[cfg(feature = "brotli")] Br(Box>), Identity, } -// should go after write::GzDecoder get implemented -#[derive(Debug)] -pub(crate) struct Wrapper { - pub buf: BytesMut, - pub eof: bool, -} - -impl io::Read for Wrapper { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let len = cmp::min(buf.len(), self.buf.len()); - buf[..len].copy_from_slice(&self.buf[..len]); - self.buf.split_to(len); - if len == 0 { - if self.eof { - Ok(0) - } else { - Err(io::Error::new(io::ErrorKind::WouldBlock, "")) - } - } else { - Ok(len) - } - } -} - -impl io::Write for Wrapper { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.buf.extend_from_slice(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - pub(crate) struct Writer { buf: BytesMut, } @@ -212,28 +175,26 @@ impl io::Write for Writer { /// Payload stream with decompression support pub(crate) struct PayloadStream { decoder: Decoder, - dst: BytesMut, } impl PayloadStream { pub fn new(enc: ContentEncoding) -> PayloadStream { - let dec = match enc { + let decoder = match enc { #[cfg(feature = "brotli")] ContentEncoding::Br => { Decoder::Br(Box::new(BrotliDecoder::new(Writer::new()))) } #[cfg(feature = "flate2")] ContentEncoding::Deflate => { - Decoder::Deflate(Box::new(DeflateDecoder::new(Writer::new()))) + Decoder::Deflate(Box::new(ZlibDecoder::new(Writer::new()))) } #[cfg(feature = "flate2")] - ContentEncoding::Gzip => Decoder::Gzip(None), + ContentEncoding::Gzip => { + Decoder::Gzip(Box::new(GzDecoder::new(Writer::new()))) + } _ => Decoder::Identity, }; - PayloadStream { - decoder: dec, - dst: BytesMut::new(), - } + PayloadStream { decoder } } } @@ -253,22 +214,17 @@ impl PayloadStream { Err(e) => Err(e), }, #[cfg(feature = "flate2")] - Decoder::Gzip(ref mut decoder) => { - if let Some(ref mut decoder) = *decoder { - decoder.as_mut().get_mut().eof = true; - - self.dst.reserve(8192); - match decoder.read(unsafe { self.dst.bytes_mut() }) { - Ok(n) => { - unsafe { self.dst.advance_mut(n) }; - return Ok(Some(self.dst.take().freeze())); - } - Err(e) => return Err(e), + Decoder::Gzip(ref mut decoder) => match decoder.try_finish() { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) } - } else { - Ok(None) } - } + Err(e) => Err(e), + }, #[cfg(feature = "flate2")] Decoder::Deflate(ref mut decoder) => match decoder.try_finish() { Ok(_) => { @@ -301,43 +257,18 @@ impl PayloadStream { Err(e) => Err(e), }, #[cfg(feature = "flate2")] - Decoder::Gzip(ref mut decoder) => { - if decoder.is_none() { - *decoder = Some(Box::new(GzDecoder::new(Wrapper { - buf: BytesMut::from(data), - eof: false, - }))); - } else { - let _ = decoder.as_mut().unwrap().write(&data); - } - - loop { - self.dst.reserve(8192); - match decoder - .as_mut() - .as_mut() - .unwrap() - .read(unsafe { self.dst.bytes_mut() }) - { - Ok(n) => { - if n != 0 { - unsafe { self.dst.advance_mut(n) }; - } - if n == 0 { - return Ok(Some(self.dst.take().freeze())); - } - } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock - && !self.dst.is_empty() - { - return Ok(Some(self.dst.take().freeze())); - } - return Err(e); - } + Decoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { + Ok(_) => { + decoder.flush()?; + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) } } - } + Err(e) => Err(e), + }, #[cfg(feature = "flate2")] Decoder::Deflate(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { diff --git a/src/server/message.rs b/src/server/message.rs index 395d7b7c3..9c4bc1ec4 100644 --- a/src/server/message.rs +++ b/src/server/message.rs @@ -1,5 +1,6 @@ use std::cell::{Cell, Ref, RefCell, RefMut}; use std::collections::VecDeque; +use std::fmt; use std::net::SocketAddr; use std::rc::Rc; @@ -35,6 +36,7 @@ pub(crate) struct InnerRequest { pub(crate) info: RefCell, pub(crate) payload: RefCell>, pub(crate) settings: ServerSettings, + pub(crate) stream_extensions: Option>, pool: &'static RequestPool, } @@ -82,6 +84,7 @@ impl Request { info: RefCell::new(ConnectionInfo::default()), payload: RefCell::new(None), extensions: RefCell::new(Extensions::new()), + stream_extensions: None, }), } } @@ -189,6 +192,12 @@ impl Request { } } + /// Io stream extensions + #[inline] + pub fn stream_extensions(&self) -> Option<&Extensions> { + self.inner().stream_extensions.as_ref().map(|e| e.as_ref()) + } + /// Server settings #[inline] pub fn server_settings(&self) -> &ServerSettings { @@ -212,6 +221,26 @@ impl Request { } } +impl fmt::Debug for Request { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nRequest {:?} {}:{}", + self.version(), + self.method(), + self.path() + )?; + if let Some(q) = self.uri().query().as_ref() { + writeln!(f, " query: ?{:?}", q)?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + pub(crate) struct RequestPool( RefCell>>, RefCell, diff --git a/src/server/mod.rs b/src/server/mod.rs index a302f5e73..0a16f26b9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,5 +1,105 @@ -//! Http server -use std::net::Shutdown; +//! Http server module +//! +//! The module contains everything necessary to setup +//! HTTP server. +//! +//! In order to start HTTP server, first you need to create and configure it +//! using factory that can be supplied to [new](fn.new.html). +//! +//! ## Factory +//! +//! Factory is a function that returns Application, describing how +//! to serve incoming HTTP requests. +//! +//! As the server uses worker pool, the factory function is restricted to trait bounds +//! `Send + Clone + 'static` so that each worker would be able to accept Application +//! without a need for synchronization. +//! +//! If you wish to share part of state among all workers you should +//! wrap it in `Arc` and potentially synchronization primitive like +//! [RwLock](https://doc.rust-lang.org/std/sync/struct.RwLock.html) +//! If the wrapped type is not thread safe. +//! +//! Note though that locking is not advisable for asynchronous programming +//! and you should minimize all locks in your request handlers +//! +//! ## HTTPS Support +//! +//! Actix-web provides support for major crates that provides TLS. +//! Each TLS implementation is provided with [AcceptorService](trait.AcceptorService.html) +//! that describes how HTTP Server accepts connections. +//! +//! For `bind` and `listen` there are corresponding `bind_ssl|tls|rustls` and `listen_ssl|tls|rustls` that accepts +//! these services. +//! +//! **NOTE:** `native-tls` doesn't support `HTTP2` yet +//! +//! ## Signal handling and shutdown +//! +//! By default HTTP Server listens for system signals +//! and, gracefully shuts down at most after 30 seconds. +//! +//! Both signal handling and shutdown timeout can be controlled +//! using corresponding methods. +//! +//! If worker, for some reason, unable to shut down within timeout +//! it is forcibly dropped. +//! +//! ## Example +//! +//! ```rust,ignore +//!extern crate actix; +//!extern crate actix_web; +//!extern crate rustls; +//! +//!use actix_web::{http, middleware, server, App, Error, HttpRequest, HttpResponse, Responder}; +//!use std::io::BufReader; +//!use rustls::internal::pemfile::{certs, rsa_private_keys}; +//!use rustls::{NoClientAuth, ServerConfig}; +//! +//!fn index(req: &HttpRequest) -> Result { +//! Ok(HttpResponse::Ok().content_type("text/plain").body("Welcome!")) +//!} +//! +//!fn load_ssl() -> ServerConfig { +//! use std::io::BufReader; +//! +//! const CERT: &'static [u8] = include_bytes!("../cert.pem"); +//! const KEY: &'static [u8] = include_bytes!("../key.pem"); +//! +//! let mut cert = BufReader::new(CERT); +//! let mut key = BufReader::new(KEY); +//! +//! let mut config = ServerConfig::new(NoClientAuth::new()); +//! let cert_chain = certs(&mut cert).unwrap(); +//! let mut keys = rsa_private_keys(&mut key).unwrap(); +//! config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); +//! +//! config +//!} +//! +//!fn main() { +//! let sys = actix::System::new("http-server"); +//! // load ssl keys +//! let config = load_ssl(); +//! +//! // create and start server at once +//! server::new(|| { +//! App::new() +//! // register simple handler, handle all methods +//! .resource("/index.html", |r| r.f(index)) +//! })) +//! }).bind_rustls("127.0.0.1:8443", config) +//! .unwrap() +//! .start(); +//! +//! println!("Started http server: 127.0.0.1:8080"); +//! //Run system so that server would start accepting connections +//! let _ = sys.run(); +//!} +//! ``` +use std::net::{Shutdown, SocketAddr}; +use std::rc::Rc; use std::{io, time}; use bytes::{BufMut, BytesMut}; @@ -7,6 +107,10 @@ use futures::{Async, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tcp::TcpStream; +pub use actix_net::server::{PauseServer, ResumeServer, StopServer}; + +pub(crate) mod acceptor; +pub(crate) mod builder; mod channel; mod error; pub(crate) mod h1; @@ -14,24 +118,39 @@ pub(crate) mod h1decoder; mod h1writer; mod h2; mod h2writer; +mod handler; pub(crate) mod helpers; +mod http; +pub(crate) mod incoming; pub(crate) mod input; pub(crate) mod message; pub(crate) mod output; +pub(crate) mod service; pub(crate) mod settings; -mod srv; -mod worker; +mod ssl; +pub use self::handler::*; +pub use self::http::HttpServer; pub use self::message::Request; +pub use self::ssl::*; + +pub use self::error::{AcceptorError, HttpDispatchError}; pub use self::settings::ServerSettings; -pub use self::srv::HttpServer; + +#[doc(hidden)] +pub use self::acceptor::AcceptorTimeout; + +#[doc(hidden)] +pub use self::settings::{ServiceConfig, ServiceConfigBuilder}; + +#[doc(hidden)] +pub use self::service::{H1Service, HttpService, StreamConfiguration}; #[doc(hidden)] pub use self::helpers::write_content_length; -use actix::Message; use body::Binary; -use error::Error; +use extensions::Extensions; use header::ContentEncoding; use httpresponse::HttpResponse; @@ -62,15 +181,25 @@ const HW_BUFFER_SIZE: usize = 32_768; /// sys.run(); /// } /// ``` -pub fn new(factory: F) -> HttpServer +pub fn new(factory: F) -> HttpServer where - F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, + F: Fn() -> H + Send + Clone + 'static, H: IntoHttpHandler + 'static, { HttpServer::new(factory) } +#[doc(hidden)] +bitflags! { + ///Flags that can be used to configure HTTP Server. + pub struct ServerFlags: u8 { + ///Use HTTP1 protocol + const HTTP1 = 0b0000_0001; + ///Use HTTP2 protocol + const HTTP2 = 0b0000_0010; + } +} + #[derive(Debug, PartialEq, Clone, Copy)] /// Server keep-alive setting pub enum KeepAlive { @@ -100,84 +229,6 @@ impl From> for KeepAlive { } } -/// Pause accepting incoming connections -/// -/// If socket contains some pending connection, they might be dropped. -/// All opened connection remains active. -#[derive(Message)] -pub struct PauseServer; - -/// Resume accepting incoming connections -#[derive(Message)] -pub struct ResumeServer; - -/// Stop incoming connection processing, stop all workers and exit. -/// -/// If server starts with `spawn()` method, then spawned thread get terminated. -pub struct StopServer { - /// Whether to try and shut down gracefully - pub graceful: bool, -} - -impl Message for StopServer { - type Result = Result<(), ()>; -} - -/// Low level http request handler -#[allow(unused_variables)] -pub trait HttpHandler: 'static { - /// Request handling task - type Task: HttpHandlerTask; - - /// Handle request - fn handle(&self, req: Request) -> Result; -} - -impl HttpHandler for Box>> { - type Task = Box; - - fn handle(&self, req: Request) -> Result, Request> { - self.as_ref().handle(req) - } -} - -/// Low level http request handler -pub trait HttpHandlerTask { - /// Poll task, this method is used before or after *io* object is available - fn poll_completed(&mut self) -> Poll<(), Error> { - Ok(Async::Ready(())) - } - - /// Poll task when *io* object is available - fn poll_io(&mut self, io: &mut Writer) -> Poll; - - /// Connection is disconnected - fn disconnected(&mut self) {} -} - -impl HttpHandlerTask for Box { - fn poll_io(&mut self, io: &mut Writer) -> Poll { - self.as_mut().poll_io(io) - } -} - -/// Conversion helper trait -pub trait IntoHttpHandler { - /// The associated type which is result of conversion. - type Handler: HttpHandler; - - /// Convert into `HttpHandler` object. - fn into_handler(self) -> Self::Handler; -} - -impl IntoHttpHandler for T { - type Handler = T; - - fn into_handler(self) -> Self::Handler { - self - } -} - #[doc(hidden)] #[derive(Debug)] pub enum WriterState { @@ -213,41 +264,79 @@ pub trait Writer { pub trait IoStream: AsyncRead + AsyncWrite + 'static { fn shutdown(&mut self, how: Shutdown) -> io::Result<()>; + /// Returns the socket address of the remote peer of this TCP connection. + fn peer_addr(&self) -> Option { + None + } + + /// Sets the value of the TCP_NODELAY option on this socket. fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()>; fn set_linger(&mut self, dur: Option) -> io::Result<()>; - fn read_available(&mut self, buf: &mut BytesMut) -> Poll { + fn set_keepalive(&mut self, dur: Option) -> io::Result<()>; + + fn read_available(&mut self, buf: &mut BytesMut) -> Poll<(bool, bool), io::Error> { let mut read_some = false; loop { if buf.remaining_mut() < LW_BUFFER_SIZE { buf.reserve(HW_BUFFER_SIZE); } - unsafe { - match self.read(buf.bytes_mut()) { - Ok(n) => { - if n == 0 { - return Ok(Async::Ready(!read_some)); - } else { - read_some = true; + + let read = unsafe { self.read(buf.bytes_mut()) }; + match read { + Ok(n) => { + if n == 0 { + return Ok(Async::Ready((read_some, true))); + } else { + read_some = true; + unsafe { buf.advance_mut(n); } } - Err(e) => { - return if e.kind() == io::ErrorKind::WouldBlock { - if read_some { - Ok(Async::Ready(false)) - } else { - Ok(Async::NotReady) - } + } + Err(e) => { + return if e.kind() == io::ErrorKind::WouldBlock { + if read_some { + Ok(Async::Ready((read_some, false))) } else { - Err(e) - }; - } + Ok(Async::NotReady) + } + } else { + Err(e) + }; } } } } + + /// Extra io stream extensions + fn extensions(&self) -> Option> { + None + } +} + +#[cfg(all(unix, feature = "uds"))] +impl IoStream for ::tokio_uds::UnixStream { + #[inline] + fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { + ::tokio_uds::UnixStream::shutdown(self, how) + } + + #[inline] + fn set_nodelay(&mut self, _nodelay: bool) -> io::Result<()> { + Ok(()) + } + + #[inline] + fn set_linger(&mut self, _dur: Option) -> io::Result<()> { + Ok(()) + } + + #[inline] + fn set_keepalive(&mut self, _dur: Option) -> io::Result<()> { + Ok(()) + } } impl IoStream for TcpStream { @@ -256,6 +345,11 @@ impl IoStream for TcpStream { TcpStream::shutdown(self, how) } + #[inline] + fn peer_addr(&self) -> Option { + TcpStream::peer_addr(self).ok() + } + #[inline] fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { TcpStream::set_nodelay(self, nodelay) @@ -265,48 +359,9 @@ impl IoStream for TcpStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { TcpStream::set_linger(self, dur) } -} - -#[cfg(feature = "alpn")] -use tokio_openssl::SslStream; - -#[cfg(feature = "alpn")] -impl IoStream for SslStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) - } -} - -#[cfg(feature = "tls")] -use tokio_tls::TlsStream; - -#[cfg(feature = "tls")] -impl IoStream for TlsStream { - #[inline] - fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { - let _ = self.get_mut().shutdown(); - Ok(()) - } - - #[inline] - fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - self.get_mut().get_mut().set_nodelay(nodelay) - } - - #[inline] - fn set_linger(&mut self, dur: Option) -> io::Result<()> { - self.get_mut().get_mut().set_linger(dur) + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + TcpStream::set_keepalive(self, dur) } } diff --git a/src/server/output.rs b/src/server/output.rs index 597faf342..4a86ffbb7 100644 --- a/src/server/output.rs +++ b/src/server/output.rs @@ -7,11 +7,11 @@ use std::{cmp, fmt, io, mem}; use brotli2::write::BrotliEncoder; use bytes::BytesMut; #[cfg(feature = "flate2")] -use flate2::write::{DeflateEncoder, GzEncoder}; +use flate2::write::{GzEncoder, ZlibEncoder}; #[cfg(feature = "flate2")] use flate2::Compression; use http::header::{ACCEPT_ENCODING, CONTENT_LENGTH}; -use http::Version; +use http::{StatusCode, Version}; use super::message::InnerRequest; use body::{Binary, Body}; @@ -151,10 +151,9 @@ impl Output { let version = resp.version().unwrap_or_else(|| req.version); let mut len = 0; - #[cfg_attr(feature = "cargo-clippy", allow(match_ref_pats))] let has_body = match resp.body() { - &Body::Empty => false, - &Body::Binary(ref bin) => { + Body::Empty => false, + Body::Binary(ref bin) => { len = bin.len(); !(response_encoding == ContentEncoding::Auto && len < 96) } @@ -190,16 +189,19 @@ impl Output { #[cfg(not(any(feature = "brotli", feature = "flate2")))] let mut encoding = ContentEncoding::Identity; - #[cfg_attr(feature = "cargo-clippy", allow(match_ref_pats))] let transfer = match resp.body() { - &Body::Empty => { - if !info.head { - info.length = ResponseLength::Zero; - } + Body::Empty => { + info.length = match resp.status() { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::SWITCHING_PROTOCOLS + | StatusCode::PROCESSING => ResponseLength::None, + _ => ResponseLength::Zero, + }; *self = Output::Empty(buf); return; } - &Body::Binary(_) => { + Body::Binary(_) => { #[cfg(any(feature = "brotli", feature = "flate2"))] { if !(encoding == ContentEncoding::Identity @@ -210,7 +212,7 @@ impl Output { let mut enc = match encoding { #[cfg(feature = "flate2")] ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::fast()), + ZlibEncoder::new(transfer, Compression::fast()), ), #[cfg(feature = "flate2")] ContentEncoding::Gzip => ContentEncoder::Gzip( @@ -244,7 +246,7 @@ impl Output { } return; } - &Body::Streaming(_) | &Body::Actor(_) => { + Body::Streaming(_) | Body::Actor(_) => { if resp.upgrade() { if version == Version::HTTP_2 { error!("Connection upgrade is forbidden for HTTP/2"); @@ -273,10 +275,9 @@ impl Output { let enc = match encoding { #[cfg(feature = "flate2")] - ContentEncoding::Deflate => ContentEncoder::Deflate(DeflateEncoder::new( - transfer, - Compression::fast(), - )), + ContentEncoding::Deflate => { + ContentEncoder::Deflate(ZlibEncoder::new(transfer, Compression::fast())) + } #[cfg(feature = "flate2")] ContentEncoding::Gzip => { ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::fast())) @@ -298,11 +299,10 @@ impl Output { match resp.chunked() { Some(true) => { // Enable transfer encoding + info.length = ResponseLength::Chunked; if version == Version::HTTP_2 { - info.length = ResponseLength::None; TransferEncoding::eof(buf) } else { - info.length = ResponseLength::Chunked; TransferEncoding::chunked(buf) } } @@ -336,15 +336,11 @@ impl Output { } } else { // Enable transfer encoding - match version { - Version::HTTP_11 => { - info.length = ResponseLength::Chunked; - TransferEncoding::chunked(buf) - } - _ => { - info.length = ResponseLength::None; - TransferEncoding::eof(buf) - } + info.length = ResponseLength::Chunked; + if version == Version::HTTP_2 { + TransferEncoding::eof(buf) + } else { + TransferEncoding::chunked(buf) } } } @@ -354,7 +350,7 @@ impl Output { pub(crate) enum ContentEncoder { #[cfg(feature = "flate2")] - Deflate(DeflateEncoder), + Deflate(ZlibEncoder), #[cfg(feature = "flate2")] Gzip(GzEncoder), #[cfg(feature = "brotli")] diff --git a/src/server/service.rs b/src/server/service.rs new file mode 100644 index 000000000..e3402e305 --- /dev/null +++ b/src/server/service.rs @@ -0,0 +1,272 @@ +use std::marker::PhantomData; +use std::time::Duration; + +use actix_net::service::{NewService, Service}; +use futures::future::{ok, FutureResult}; +use futures::{Async, Poll}; + +use super::channel::{H1Channel, HttpChannel}; +use super::error::HttpDispatchError; +use super::handler::HttpHandler; +use super::settings::ServiceConfig; +use super::IoStream; + +/// `NewService` implementation for HTTP1/HTTP2 transports +pub struct HttpService +where + H: HttpHandler, + Io: IoStream, +{ + settings: ServiceConfig, + _t: PhantomData, +} + +impl HttpService +where + H: HttpHandler, + Io: IoStream, +{ + /// Create new `HttpService` instance. + pub fn new(settings: ServiceConfig) -> Self { + HttpService { + settings, + _t: PhantomData, + } + } +} + +impl NewService for HttpService +where + H: HttpHandler, + Io: IoStream, +{ + type Request = Io; + type Response = (); + type Error = HttpDispatchError; + type InitError = (); + type Service = HttpServiceHandler; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(HttpServiceHandler::new(self.settings.clone())) + } +} + +pub struct HttpServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + settings: ServiceConfig, + _t: PhantomData, +} + +impl HttpServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + fn new(settings: ServiceConfig) -> HttpServiceHandler { + HttpServiceHandler { + settings, + _t: PhantomData, + } + } +} + +impl Service for HttpServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + type Request = Io; + type Response = (); + type Error = HttpDispatchError; + type Future = HttpChannel; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + HttpChannel::new(self.settings.clone(), req) + } +} + +/// `NewService` implementation for HTTP1 transport +pub struct H1Service +where + H: HttpHandler, + Io: IoStream, +{ + settings: ServiceConfig, + _t: PhantomData, +} + +impl H1Service +where + H: HttpHandler, + Io: IoStream, +{ + /// Create new `HttpService` instance. + pub fn new(settings: ServiceConfig) -> Self { + H1Service { + settings, + _t: PhantomData, + } + } +} + +impl NewService for H1Service +where + H: HttpHandler, + Io: IoStream, +{ + type Request = Io; + type Response = (); + type Error = HttpDispatchError; + type InitError = (); + type Service = H1ServiceHandler; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(H1ServiceHandler::new(self.settings.clone())) + } +} + +/// `Service` implementation for HTTP1 transport +pub struct H1ServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + settings: ServiceConfig, + _t: PhantomData, +} + +impl H1ServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + fn new(settings: ServiceConfig) -> H1ServiceHandler { + H1ServiceHandler { + settings, + _t: PhantomData, + } + } +} + +impl Service for H1ServiceHandler +where + H: HttpHandler, + Io: IoStream, +{ + type Request = Io; + type Response = (); + type Error = HttpDispatchError; + type Future = H1Channel; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + H1Channel::new(self.settings.clone(), req) + } +} + +/// `NewService` implementation for stream configuration service +/// +/// Stream configuration service allows to change some socket level +/// parameters. for example `tcp nodelay` or `tcp keep-alive`. +pub struct StreamConfiguration { + no_delay: Option, + tcp_ka: Option>, + _t: PhantomData<(T, E)>, +} + +impl Default for StreamConfiguration { + fn default() -> Self { + Self::new() + } +} + +impl StreamConfiguration { + /// Create new `StreamConfigurationService` instance. + pub fn new() -> Self { + Self { + no_delay: None, + tcp_ka: None, + _t: PhantomData, + } + } + + /// Sets the value of the `TCP_NODELAY` option on this socket. + pub fn nodelay(mut self, nodelay: bool) -> Self { + self.no_delay = Some(nodelay); + self + } + + /// Sets whether keepalive messages are enabled to be sent on this socket. + pub fn tcp_keepalive(mut self, keepalive: Option) -> Self { + self.tcp_ka = Some(keepalive); + self + } +} + +impl NewService for StreamConfiguration { + type Request = T; + type Response = T; + type Error = E; + type InitError = (); + type Service = StreamConfigurationService; + type Future = FutureResult; + + fn new_service(&self) -> Self::Future { + ok(StreamConfigurationService { + no_delay: self.no_delay, + tcp_ka: self.tcp_ka, + _t: PhantomData, + }) + } +} + +/// Stream configuration service +/// +/// Stream configuration service allows to change some socket level +/// parameters. for example `tcp nodelay` or `tcp keep-alive`. +pub struct StreamConfigurationService { + no_delay: Option, + tcp_ka: Option>, + _t: PhantomData<(T, E)>, +} + +impl Service for StreamConfigurationService +where + T: IoStream, +{ + type Request = T; + type Response = T; + type Error = E; + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, mut req: Self::Request) -> Self::Future { + if let Some(no_delay) = self.no_delay { + if req.set_nodelay(no_delay).is_err() { + error!("Can not set socket no-delay option"); + } + } + if let Some(keepalive) = self.tcp_ka { + if req.set_keepalive(keepalive).is_err() { + error!("Can not set socket keep-alive option"); + } + } + + ok(req) + } +} diff --git a/src/server/settings.rs b/src/server/settings.rs index cc2e1c06e..66a4eed88 100644 --- a/src/server/settings.rs +++ b/src/server/settings.rs @@ -1,17 +1,20 @@ -use std::cell::{Cell, RefCell, RefMut, UnsafeCell}; +use std::cell::{Cell, RefCell}; use std::collections::VecDeque; use std::fmt::Write; use std::rc::Rc; +use std::time::{Duration, Instant}; use std::{env, fmt, net}; use bytes::BytesMut; +use futures::{future, Future}; use futures_cpupool::CpuPool; use http::StatusCode; use lazycell::LazyCell; use parking_lot::Mutex; use time; +use tokio_current_thread::spawn; +use tokio_timer::{sleep, Delay}; -use super::channel::Node; use super::message::{Request, RequestPool}; use super::KeepAlive; use body::Body; @@ -39,7 +42,7 @@ lazy_static! { /// Various server settings pub struct ServerSettings { - addr: Option, + addr: net::SocketAddr, secure: bool, host: String, cpu_pool: LazyCell, @@ -61,7 +64,7 @@ impl Clone for ServerSettings { impl Default for ServerSettings { fn default() -> Self { ServerSettings { - addr: None, + addr: "127.0.0.1:8080".parse().unwrap(), secure: false, host: "localhost:8080".to_owned(), responses: HttpResponsePool::get_pool(), @@ -73,15 +76,9 @@ impl Default for ServerSettings { impl ServerSettings { /// Crate server settings instance pub(crate) fn new( - addr: Option, host: &Option, secure: bool, + addr: net::SocketAddr, host: &str, secure: bool, ) -> ServerSettings { - let host = if let Some(ref host) = *host { - host.clone() - } else if let Some(ref addr) = addr { - format!("{}", addr) - } else { - "localhost".to_owned() - }; + let host = host.to_owned(); let cpu_pool = LazyCell::new(); let responses = HttpResponsePool::get_pool(); ServerSettings { @@ -93,23 +90,8 @@ impl ServerSettings { } } - pub(crate) fn parts(&self) -> (Option, String, bool) { - (self.addr, self.host.clone(), self.secure) - } - - pub(crate) fn from_parts(parts: (Option, String, bool)) -> Self { - let (addr, host, secure) = parts; - ServerSettings { - addr, - host, - secure, - cpu_pool: LazyCell::new(), - responses: HttpResponsePool::get_pool(), - } - } - /// Returns the socket address of the local half of this TCP connection - pub fn local_addr(&self) -> Option { + pub fn local_addr(&self) -> net::SocketAddr { self.addr } @@ -144,105 +126,294 @@ impl ServerSettings { // "Sun, 06 Nov 1994 08:49:37 GMT".len() const DATE_VALUE_LENGTH: usize = 29; -pub(crate) struct WorkerSettings { - h: RefCell>, - keep_alive: u64, +/// Http service configuration +pub struct ServiceConfig(Rc>); + +struct Inner { + handler: H, + keep_alive: Option, + client_timeout: u64, + client_shutdown: u64, ka_enabled: bool, bytes: Rc, messages: &'static RequestPool, - channels: Cell, - node: RefCell>, - date: UnsafeCell, + date: Cell>, } -impl WorkerSettings { +impl Clone for ServiceConfig { + fn clone(&self) -> Self { + ServiceConfig(self.0.clone()) + } +} + +impl ServiceConfig { + /// Create instance of `ServiceConfig` pub(crate) fn new( - h: Vec, keep_alive: KeepAlive, settings: ServerSettings, - ) -> WorkerSettings { + handler: H, keep_alive: KeepAlive, client_timeout: u64, client_shutdown: u64, + settings: ServerSettings, + ) -> ServiceConfig { let (keep_alive, ka_enabled) = match keep_alive { KeepAlive::Timeout(val) => (val as u64, true), KeepAlive::Os | KeepAlive::Tcp(_) => (0, true), KeepAlive::Disabled => (0, false), }; + let keep_alive = if ka_enabled && keep_alive > 0 { + Some(Duration::from_secs(keep_alive)) + } else { + None + }; - WorkerSettings { - h: RefCell::new(h), - bytes: Rc::new(SharedBytesPool::new()), - messages: RequestPool::pool(settings), - channels: Cell::new(0), - node: RefCell::new(Node::head()), - date: UnsafeCell::new(Date::new()), + ServiceConfig(Rc::new(Inner { + handler, keep_alive, ka_enabled, - } + client_timeout, + client_shutdown, + bytes: Rc::new(SharedBytesPool::new()), + messages: RequestPool::pool(settings), + date: Cell::new(None), + })) } - pub fn num_channels(&self) -> usize { - self.channels.get() + /// Create worker settings builder. + pub fn build(handler: H) -> ServiceConfigBuilder { + ServiceConfigBuilder::new(handler) } - pub fn head(&self) -> RefMut> { - self.node.borrow_mut() + pub(crate) fn handler(&self) -> &H { + &self.0.handler } - pub fn handlers(&self) -> RefMut> { - self.h.borrow_mut() - } - - pub fn keep_alive(&self) -> u64 { - self.keep_alive + #[inline] + /// Keep alive duration if configured. + pub fn keep_alive(&self) -> Option { + self.0.keep_alive } + #[inline] + /// Return state of connection keep-alive funcitonality pub fn keep_alive_enabled(&self) -> bool { - self.ka_enabled + self.0.ka_enabled } - pub fn get_bytes(&self) -> BytesMut { - self.bytes.get_bytes() + pub(crate) fn get_bytes(&self) -> BytesMut { + self.0.bytes.get_bytes() } - pub fn release_bytes(&self, bytes: BytesMut) { - self.bytes.release_bytes(bytes) + pub(crate) fn release_bytes(&self, bytes: BytesMut) { + self.0.bytes.release_bytes(bytes) } - pub fn get_request(&self) -> Request { - RequestPool::get(self.messages) - } - - pub fn add_channel(&self) { - self.channels.set(self.channels.get() + 1); - } - - pub fn remove_channel(&self) { - let num = self.channels.get(); - if num > 0 { - self.channels.set(num - 1); - } else { - error!("Number of removed channels is bigger than added channel. Bug in actix-web"); - } - } - - pub fn update_date(&self) { - // Unsafe: WorkerSetting is !Sync and !Send - unsafe { &mut *self.date.get() }.update(); - } - - pub fn set_date(&self, dst: &mut BytesMut, full: bool) { - // Unsafe: WorkerSetting is !Sync and !Send - let date_bytes = unsafe { &(*self.date.get()).bytes }; - if full { - let mut buf: [u8; 39] = [0; 39]; - buf[..6].copy_from_slice(b"date: "); - buf[6..35].copy_from_slice(date_bytes); - buf[35..].copy_from_slice(b"\r\n\r\n"); - dst.extend_from_slice(&buf); - } else { - dst.extend_from_slice(date_bytes); - } + pub(crate) fn get_request(&self) -> Request { + RequestPool::get(self.0.messages) } } +impl ServiceConfig { + #[inline] + /// Client timeout for first request. + pub fn client_timer(&self) -> Option { + let delay = self.0.client_timeout; + if delay != 0 { + Some(Delay::new(self.now() + Duration::from_millis(delay))) + } else { + None + } + } + + /// Client timeout for first request. + pub fn client_timer_expire(&self) -> Option { + let delay = self.0.client_timeout; + if delay != 0 { + Some(self.now() + Duration::from_millis(delay)) + } else { + None + } + } + + /// Client shutdown timer + pub fn client_shutdown_timer(&self) -> Option { + let delay = self.0.client_shutdown; + if delay != 0 { + Some(self.now() + Duration::from_millis(delay)) + } else { + None + } + } + + #[inline] + /// Return keep-alive timer delay is configured. + pub fn keep_alive_timer(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(Delay::new(self.now() + ka)) + } else { + None + } + } + + /// Keep-alive expire time + pub fn keep_alive_expire(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(self.now() + ka) + } else { + None + } + } + + fn check_date(&self) { + if unsafe { &*self.0.date.as_ptr() }.is_none() { + self.0.date.set(Some(Date::new())); + + // periodic date update + let s = self.clone(); + spawn(sleep(Duration::from_millis(500)).then(move |_| { + s.0.date.set(None); + future::ok(()) + })); + } + } + + pub(crate) fn set_date(&self, dst: &mut BytesMut, full: bool) { + self.check_date(); + + let date = &unsafe { &*self.0.date.as_ptr() }.as_ref().unwrap().bytes; + + if full { + let mut buf: [u8; 39] = [0; 39]; + buf[..6].copy_from_slice(b"date: "); + buf[6..35].copy_from_slice(date); + buf[35..].copy_from_slice(b"\r\n\r\n"); + dst.extend_from_slice(&buf); + } else { + dst.extend_from_slice(date); + } + } + + #[inline] + pub(crate) fn now(&self) -> Instant { + self.check_date(); + unsafe { &*self.0.date.as_ptr() }.as_ref().unwrap().current + } +} + +/// A service config builder +/// +/// This type can be used to construct an instance of `ServiceConfig` through a +/// builder-like pattern. +pub struct ServiceConfigBuilder { + handler: H, + keep_alive: KeepAlive, + client_timeout: u64, + client_shutdown: u64, + host: String, + addr: net::SocketAddr, + secure: bool, +} + +impl ServiceConfigBuilder { + /// Create instance of `ServiceConfigBuilder` + pub fn new(handler: H) -> ServiceConfigBuilder { + ServiceConfigBuilder { + handler, + keep_alive: KeepAlive::Timeout(5), + client_timeout: 5000, + client_shutdown: 5000, + secure: false, + host: "localhost".to_owned(), + addr: "127.0.0.1:8080".parse().unwrap(), + } + } + + /// Enable secure flag for current server. + /// + /// By default this flag is set to false. + pub fn secure(mut self) -> Self { + self.secure = true; + self + } + + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(mut self, val: T) -> Self { + self.keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } + + /// Set server connection shutdown timeout in milliseconds. + /// + /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete + /// within this time, the request is dropped. This timeout affects only secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_shutdown(mut self, val: u64) -> Self { + self.client_shutdown = val; + self + } + + /// Set server host name. + /// + /// Host name is used by application router aa a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default host name is set to a "localhost" value. + pub fn server_hostname(mut self, val: &str) -> Self { + self.host = val.to_owned(); + self + } + + /// Set server ip address. + /// + /// Host name is used by application router aa a hostname for url + /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. + /// html#method.host) documentation for more information. + /// + /// By default server address is set to a "127.0.0.1:8080" + pub fn server_address(mut self, addr: S) -> Self { + match addr.to_socket_addrs() { + Err(err) => error!("Can not convert to SocketAddr: {}", err), + Ok(mut addrs) => if let Some(addr) = addrs.next() { + self.addr = addr; + }, + } + self + } + + /// Finish service configuration and create `ServiceConfig` object. + pub fn finish(self) -> ServiceConfig { + let settings = ServerSettings::new(self.addr, &self.host, self.secure); + let client_shutdown = if self.secure { self.client_shutdown } else { 0 }; + + ServiceConfig::new( + self.handler, + self.keep_alive, + self.client_timeout, + client_shutdown, + settings, + ) + } +} + +#[derive(Copy, Clone)] struct Date { + current: Instant, bytes: [u8; DATE_VALUE_LENGTH], pos: usize, } @@ -250,6 +421,7 @@ struct Date { impl Date { fn new() -> Date { let mut date = Date { + current: Instant::now(), bytes: [0; DATE_VALUE_LENGTH], pos: 0, }; @@ -258,6 +430,7 @@ impl Date { } fn update(&mut self) { self.pos = 0; + self.current = Instant::now(); write!(self, "{}", time::at_utc(time::get_time()).rfc822()).unwrap(); } } @@ -299,6 +472,8 @@ impl SharedBytesPool { #[cfg(test)] mod tests { use super::*; + use futures::future; + use tokio::runtime::current_thread; #[test] fn test_date_len() { @@ -307,15 +482,22 @@ mod tests { #[test] fn test_date() { - let settings = WorkerSettings::<()>::new( - Vec::new(), - KeepAlive::Os, - ServerSettings::default(), - ); - let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf1, true); - let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); - settings.set_date(&mut buf2, true); - assert_eq!(buf1, buf2); + let mut rt = current_thread::Runtime::new().unwrap(); + + let _ = rt.block_on(future::lazy(|| { + let settings = ServiceConfig::<()>::new( + (), + KeepAlive::Os, + 0, + 0, + ServerSettings::default(), + ); + let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf1, true); + let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf2, true); + assert_eq!(buf1, buf2); + future::ok::<_, ()>(()) + })); } } diff --git a/src/server/srv.rs b/src/server/srv.rs deleted file mode 100644 index 02580d015..000000000 --- a/src/server/srv.rs +++ /dev/null @@ -1,960 +0,0 @@ -use std::rc::Rc; -use std::sync::{mpsc as sync_mpsc, Arc}; -use std::time::Duration; -use std::{io, net, thread}; - -use actix::{ - fut, signal, Actor, ActorFuture, Addr, Arbiter, AsyncContext, Context, Handler, - Response, StreamHandler, System, WrapFuture, -}; - -use futures::sync::mpsc; -use futures::{Future, Sink, Stream}; -use mio; -use net2::TcpBuilder; -use num_cpus; -use slab::Slab; -use tokio_io::{AsyncRead, AsyncWrite}; - -#[cfg(feature = "tls")] -use native_tls::TlsAcceptor; - -#[cfg(feature = "alpn")] -use openssl::ssl::{AlpnError, SslAcceptorBuilder}; - -use super::channel::{HttpChannel, WrapperStream}; -use super::settings::{ServerSettings, WorkerSettings}; -use super::worker::{Conn, SocketInfo, StopWorker, StreamHandlerType, Worker}; -use super::{IntoHttpHandler, IoStream, KeepAlive}; -use super::{PauseServer, ResumeServer, StopServer}; - -#[cfg(feature = "alpn")] -fn configure_alpn(builder: &mut SslAcceptorBuilder) -> io::Result<()> { - builder.set_alpn_protos(b"\x02h2\x08http/1.1")?; - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - Ok(()) -} - -/// An HTTP Server -pub struct HttpServer -where - H: IntoHttpHandler + 'static, -{ - h: Option>>, - threads: usize, - backlog: i32, - host: Option, - keep_alive: KeepAlive, - factory: Arc Vec + Send + Sync>, - #[cfg_attr(feature = "cargo-clippy", allow(type_complexity))] - workers: Vec<(usize, Addr>)>, - sockets: Vec, - accept: Vec<(mio::SetReadiness, sync_mpsc::Sender)>, - exit: bool, - shutdown_timeout: u16, - signals: Option>, - no_http2: bool, - no_signals: bool, -} - -enum ServerCommand { - WorkerDied(usize, Slab), -} - -impl Actor for HttpServer -where - H: IntoHttpHandler, -{ - type Context = Context; -} - -struct Socket { - lst: net::TcpListener, - addr: net::SocketAddr, - tp: StreamHandlerType, -} - -impl HttpServer -where - H: IntoHttpHandler + 'static, -{ - /// Create new http server with application factory - pub fn new(factory: F) -> Self - where - F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, - { - let f = move || (factory)().into_iter().collect(); - - HttpServer { - h: None, - threads: num_cpus::get(), - backlog: 2048, - host: None, - keep_alive: KeepAlive::Os, - factory: Arc::new(f), - workers: Vec::new(), - sockets: Vec::new(), - accept: Vec::new(), - exit: false, - shutdown_timeout: 30, - signals: None, - no_http2: false, - no_signals: false, - } - } - - /// Set number of workers to start. - /// - /// By default http server uses number of available logical cpu as threads - /// count. - pub fn workers(mut self, num: usize) -> Self { - self.threads = num; - self - } - - #[doc(hidden)] - #[deprecated(since = "0.6.0", note = "please use `HttpServer::workers()` instead")] - pub fn threads(self, num: usize) -> Self { - self.workers(num) - } - - /// Set the maximum number of pending connections. - /// - /// This refers to the number of clients that can be waiting to be served. - /// Exceeding this number results in the client getting an error when - /// attempting to connect. It should only affect servers under significant - /// load. - /// - /// Generally set in the 64-2048 range. Default value is 2048. - /// - /// This method should be called before `bind()` method call. - pub fn backlog(mut self, num: i32) -> Self { - self.backlog = num; - self - } - - /// Set server keep-alive setting. - /// - /// By default keep alive is set to a `Os`. - pub fn keep_alive>(mut self, val: T) -> Self { - self.keep_alive = val.into(); - self - } - - /// Set server host name. - /// - /// Host name is used by application router aa a hostname for url - /// generation. Check [ConnectionInfo](./dev/struct.ConnectionInfo. - /// html#method.host) documentation for more information. - pub fn server_hostname(mut self, val: String) -> Self { - self.host = Some(val); - self - } - - /// Stop actix system. - /// - /// `SystemExit` message stops currently running system. - pub fn system_exit(mut self) -> Self { - self.exit = true; - self - } - - /// Set alternative address for `ProcessSignals` actor. - pub fn signals(mut self, addr: Addr) -> Self { - self.signals = Some(addr); - self - } - - /// Disable signal handling - pub fn disable_signals(mut self) -> Self { - self.no_signals = true; - self - } - - /// Timeout for graceful workers shutdown. - /// - /// After receiving a stop signal, workers have this much time to finish - /// serving requests. Workers still alive after the timeout are force - /// dropped. - /// - /// By default shutdown timeout sets to 30 seconds. - pub fn shutdown_timeout(mut self, sec: u16) -> Self { - self.shutdown_timeout = sec; - self - } - - /// Disable `HTTP/2` support - pub fn no_http2(mut self) -> Self { - self.no_http2 = true; - self - } - - /// Get addresses of bound sockets. - pub fn addrs(&self) -> Vec { - self.sockets.iter().map(|s| s.addr).collect() - } - - /// Get addresses of bound sockets and the scheme for it. - /// - /// This is useful when the server is bound from different sources - /// with some sockets listening on http and some listening on https - /// and the user should be presented with an enumeration of which - /// socket requires which protocol. - pub fn addrs_with_scheme(&self) -> Vec<(net::SocketAddr, &str)> { - self.sockets - .iter() - .map(|s| (s.addr, s.tp.scheme())) - .collect() - } - - /// Use listener for accepting incoming connection requests - /// - /// HttpServer does not change any configuration for TcpListener, - /// it needs to be configured before passing it to listen() method. - pub fn listen(mut self, lst: net::TcpListener) -> Self { - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Normal, - }); - self - } - - #[cfg(feature = "tls")] - /// Use listener for accepting incoming tls connection requests - /// - /// HttpServer does not change any configuration for TcpListener, - /// it needs to be configured before passing it to listen() method. - pub fn listen_tls(mut self, lst: net::TcpListener, acceptor: TlsAcceptor) -> Self { - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Tls(acceptor.clone()), - }); - self - } - - #[cfg(feature = "alpn")] - /// Use listener for accepting incoming tls connection requests - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn listen_ssl( - mut self, lst: net::TcpListener, mut builder: SslAcceptorBuilder, - ) -> io::Result { - // alpn support - if !self.no_http2 { - configure_alpn(&mut builder)?; - } - let acceptor = builder.build(); - let addr = lst.local_addr().unwrap(); - self.sockets.push(Socket { - addr, - lst, - tp: StreamHandlerType::Alpn(acceptor.clone()), - }); - Ok(self) - } - - fn bind2(&mut self, addr: S) -> io::Result> { - let mut err = None; - let mut succ = false; - let mut sockets = Vec::new(); - for addr in addr.to_socket_addrs()? { - match create_tcp_listener(addr, self.backlog) { - Ok(lst) => { - succ = true; - let addr = lst.local_addr().unwrap(); - sockets.push(Socket { - lst, - addr, - tp: StreamHandlerType::Normal, - }); - } - Err(e) => err = Some(e), - } - } - - if !succ { - if let Some(e) = err.take() { - Err(e) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - "Can not bind to address.", - )) - } - } else { - Ok(sockets) - } - } - - /// The socket address to bind - /// - /// To bind multiple addresses this method can be called multiple times. - pub fn bind(mut self, addr: S) -> io::Result { - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets); - Ok(self) - } - - #[cfg(feature = "tls")] - /// The ssl socket address to bind - /// - /// To bind multiple addresses this method can be called multiple times. - pub fn bind_tls( - mut self, addr: S, acceptor: TlsAcceptor, - ) -> io::Result { - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets.into_iter().map(|mut s| { - s.tp = StreamHandlerType::Tls(acceptor.clone()); - s - })); - Ok(self) - } - - #[cfg(feature = "alpn")] - /// Start listening for incoming tls connections. - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_ssl( - mut self, addr: S, mut builder: SslAcceptorBuilder, - ) -> io::Result { - // alpn support - if !self.no_http2 { - configure_alpn(&mut builder)?; - } - - let acceptor = builder.build(); - let sockets = self.bind2(addr)?; - self.sockets.extend(sockets.into_iter().map(|mut s| { - s.tp = StreamHandlerType::Alpn(acceptor.clone()); - s - })); - Ok(self) - } - - fn start_workers( - &mut self, settings: &ServerSettings, sockets: &Slab, - ) -> Vec<(usize, mpsc::UnboundedSender>)> { - // start workers - let mut workers = Vec::new(); - for idx in 0..self.threads { - let (tx, rx) = mpsc::unbounded::>(); - - let ka = self.keep_alive; - let socks = sockets.clone(); - let factory = Arc::clone(&self.factory); - let parts = settings.parts(); - - let addr = Arbiter::start(move |ctx: &mut Context<_>| { - let s = ServerSettings::from_parts(parts); - let apps: Vec<_> = - (*factory)().into_iter().map(|h| h.into_handler()).collect(); - ctx.add_message_stream(rx); - Worker::new(apps, socks, ka, s) - }); - workers.push((idx, tx)); - self.workers.push((idx, addr)); - } - info!("Starting {} http workers", self.threads); - workers - } - - // subscribe to os signals - fn subscribe_to_signals(&self) -> Option> { - if !self.no_signals { - if let Some(ref signals) = self.signals { - Some(signals.clone()) - } else { - Some(System::current().registry().get::()) - } - } else { - None - } - } -} - -impl HttpServer { - /// Start listening for incoming connections. - /// - /// This method starts number of http handler workers in separate threads. - /// For each address this method starts separate thread which does - /// `accept()` in a loop. - /// - /// This methods panics if no socket addresses get bound. - /// - /// This method requires to run within properly configured `Actix` system. - /// - /// ```rust - /// extern crate actix_web; - /// use actix_web::{actix, server, App, HttpResponse}; - /// - /// fn main() { - /// let sys = actix::System::new("example"); // <- create Actix system - /// - /// server::new(|| App::new().resource("/", |r| r.h(|_: &_| HttpResponse::Ok()))) - /// .bind("127.0.0.1:0") - /// .expect("Can not bind to 127.0.0.1:0") - /// .start(); - /// # actix::System::current().stop(); - /// sys.run(); // <- Run actix system, this method starts all async processes - /// } - /// ``` - pub fn start(mut self) -> Addr { - if self.sockets.is_empty() { - panic!("HttpServer::bind() has to be called before start()"); - } else { - let (tx, rx) = mpsc::unbounded(); - - let mut socks = Slab::new(); - let mut addrs: Vec<(usize, Socket)> = Vec::new(); - - for socket in self.sockets.drain(..) { - let entry = socks.vacant_entry(); - let token = entry.key(); - entry.insert(SocketInfo { - addr: socket.addr, - htype: socket.tp.clone(), - }); - addrs.push((token, socket)); - } - - let settings = ServerSettings::new(Some(addrs[0].1.addr), &self.host, false); - let workers = self.start_workers(&settings, &socks); - - // start acceptors threads - for (token, sock) in addrs { - info!("Starting server on http://{}", sock.addr); - self.accept.push(start_accept_thread( - token, - sock, - tx.clone(), - socks.clone(), - workers.clone(), - )); - } - - // start http server actor - let signals = self.subscribe_to_signals(); - let addr = Actor::create(move |ctx| { - ctx.add_stream(rx); - self - }); - if let Some(signals) = signals { - signals.do_send(signal::Subscribe(addr.clone().recipient())) - } - addr - } - } - - /// Spawn new thread and start listening for incoming connections. - /// - /// This method spawns new thread and starts new actix system. Other than - /// that it is similar to `start()` method. This method blocks. - /// - /// This methods panics if no socket addresses get bound. - /// - /// ```rust,ignore - /// # extern crate futures; - /// # extern crate actix_web; - /// # use futures::Future; - /// use actix_web::*; - /// - /// fn main() { - /// HttpServer::new(|| App::new().resource("/", |r| r.h(|_| HttpResponse::Ok()))) - /// .bind("127.0.0.1:0") - /// .expect("Can not bind to 127.0.0.1:0") - /// .run(); - /// } - /// ``` - pub fn run(self) { - let sys = System::new("http-server"); - self.start(); - sys.run(); - } -} - -#[doc(hidden)] -#[cfg(feature = "tls")] -#[deprecated( - since = "0.6.0", note = "please use `actix_web::HttpServer::bind_tls` instead" -)] -impl HttpServer { - /// Start listening for incoming tls connections. - pub fn start_tls(mut self, acceptor: TlsAcceptor) -> io::Result> { - for sock in &mut self.sockets { - match sock.tp { - StreamHandlerType::Normal => (), - _ => continue, - } - sock.tp = StreamHandlerType::Tls(acceptor.clone()); - } - Ok(self.start()) - } -} - -#[doc(hidden)] -#[cfg(feature = "alpn")] -#[deprecated( - since = "0.6.0", note = "please use `actix_web::HttpServer::bind_ssl` instead" -)] -impl HttpServer { - /// Start listening for incoming tls connections. - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn start_ssl( - mut self, mut builder: SslAcceptorBuilder, - ) -> io::Result> { - // alpn support - if !self.no_http2 { - builder.set_alpn_protos(b"\x02h2\x08http/1.1")?; - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else { - Err(AlpnError::NOACK) - } - }); - } - - let acceptor = builder.build(); - for sock in &mut self.sockets { - match sock.tp { - StreamHandlerType::Normal => (), - _ => continue, - } - sock.tp = StreamHandlerType::Alpn(acceptor.clone()); - } - Ok(self.start()) - } -} - -impl HttpServer { - /// Start listening for incoming connections from a stream. - /// - /// This method uses only one thread for handling incoming connections. - pub fn start_incoming(mut self, stream: S, secure: bool) -> Addr - where - S: Stream + Send + 'static, - T: AsyncRead + AsyncWrite + Send + 'static, - { - // set server settings - let addr: net::SocketAddr = "127.0.0.1:8080".parse().unwrap(); - let settings = ServerSettings::new(Some(addr), &self.host, secure); - let apps: Vec<_> = (*self.factory)() - .into_iter() - .map(|h| h.into_handler()) - .collect(); - self.h = Some(Rc::new(WorkerSettings::new( - apps, - self.keep_alive, - settings, - ))); - - // start server - let signals = self.subscribe_to_signals(); - let addr = HttpServer::create(move |ctx| { - ctx.add_message_stream(stream.map_err(|_| ()).map(move |t| Conn { - io: WrapperStream::new(t), - token: 0, - peer: None, - http2: false, - })); - self - }); - - if let Some(signals) = signals { - signals.do_send(signal::Subscribe(addr.clone().recipient())) - } - addr - } -} - -/// Signals support -/// Handle `SIGINT`, `SIGTERM`, `SIGQUIT` signals and stop actix system -/// message to `System` actor. -impl Handler for HttpServer { - type Result = (); - - fn handle(&mut self, msg: signal::Signal, ctx: &mut Context) { - match msg.0 { - signal::SignalType::Int => { - info!("SIGINT received, exiting"); - self.exit = true; - Handler::::handle(self, StopServer { graceful: false }, ctx); - } - signal::SignalType::Term => { - info!("SIGTERM received, stopping"); - self.exit = true; - Handler::::handle(self, StopServer { graceful: true }, ctx); - } - signal::SignalType::Quit => { - info!("SIGQUIT received, exiting"); - self.exit = true; - Handler::::handle(self, StopServer { graceful: false }, ctx); - } - _ => (), - } - } -} - -/// Commands from accept threads -impl StreamHandler for HttpServer { - fn finished(&mut self, _: &mut Context) {} - - fn handle(&mut self, msg: ServerCommand, _: &mut Context) { - match msg { - ServerCommand::WorkerDied(idx, socks) => { - let mut found = false; - for i in 0..self.workers.len() { - if self.workers[i].0 == idx { - self.workers.swap_remove(i); - found = true; - break; - } - } - - if found { - error!("Worker has died {:?}, restarting", idx); - let (tx, rx) = mpsc::unbounded::>(); - - let mut new_idx = self.workers.len(); - 'found: loop { - for i in 0..self.workers.len() { - if self.workers[i].0 == new_idx { - new_idx += 1; - continue 'found; - } - } - break; - } - - let ka = self.keep_alive; - let factory = Arc::clone(&self.factory); - let host = self.host.clone(); - let addr = socks[0].addr; - - let addr = Arbiter::start(move |ctx: &mut Context<_>| { - let settings = ServerSettings::new(Some(addr), &host, false); - let apps: Vec<_> = - (*factory)().into_iter().map(|h| h.into_handler()).collect(); - ctx.add_message_stream(rx); - Worker::new(apps, socks, ka, settings) - }); - for item in &self.accept { - let _ = item.1.send(Command::Worker(new_idx, tx.clone())); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - - self.workers.push((new_idx, addr)); - } - } - } - } -} - -impl Handler> for HttpServer -where - T: IoStream, - H: IntoHttpHandler, -{ - type Result = (); - - fn handle(&mut self, msg: Conn, _: &mut Context) -> Self::Result { - Arbiter::spawn(HttpChannel::new( - Rc::clone(self.h.as_ref().unwrap()), - msg.io, - msg.peer, - msg.http2, - )); - } -} - -impl Handler for HttpServer { - type Result = (); - - fn handle(&mut self, _: PauseServer, _: &mut Context) { - for item in &self.accept { - let _ = item.1.send(Command::Pause); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - } -} - -impl Handler for HttpServer { - type Result = (); - - fn handle(&mut self, _: ResumeServer, _: &mut Context) { - for item in &self.accept { - let _ = item.1.send(Command::Resume); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - } -} - -impl Handler for HttpServer { - type Result = Response<(), ()>; - - fn handle(&mut self, msg: StopServer, ctx: &mut Context) -> Self::Result { - // stop accept threads - for item in &self.accept { - let _ = item.1.send(Command::Stop); - let _ = item.0.set_readiness(mio::Ready::readable()); - } - - // stop workers - let (tx, rx) = mpsc::channel(1); - - let dur = if msg.graceful { - Some(Duration::new(u64::from(self.shutdown_timeout), 0)) - } else { - None - }; - for worker in &self.workers { - let tx2 = tx.clone(); - ctx.spawn( - worker - .1 - .send(StopWorker { graceful: dur }) - .into_actor(self) - .then(move |_, slf, ctx| { - slf.workers.pop(); - if slf.workers.is_empty() { - let _ = tx2.send(()); - - // we need to stop system if server was spawned - if slf.exit { - ctx.run_later(Duration::from_millis(300), |_, _| { - System::current().stop(); - }); - } - } - fut::ok(()) - }), - ); - } - - if !self.workers.is_empty() { - Response::async(rx.into_future().map(|_| ()).map_err(|_| ())) - } else { - // we need to stop system if server was spawned - if self.exit { - ctx.run_later(Duration::from_millis(300), |_, _| { - System::current().stop(); - }); - } - Response::reply(Ok(())) - } - } -} - -enum Command { - Pause, - Resume, - Stop, - Worker(usize, mpsc::UnboundedSender>), -} - -fn start_accept_thread( - token: usize, sock: Socket, srv: mpsc::UnboundedSender, - socks: Slab, - mut workers: Vec<(usize, mpsc::UnboundedSender>)>, -) -> (mio::SetReadiness, sync_mpsc::Sender) { - let (tx, rx) = sync_mpsc::channel(); - let (reg, readiness) = mio::Registration::new2(); - - // start accept thread - #[cfg_attr(feature = "cargo-clippy", allow(cyclomatic_complexity))] - let _ = thread::Builder::new() - .name(format!("Accept on {}", sock.addr)) - .spawn(move || { - const SRV: mio::Token = mio::Token(0); - const CMD: mio::Token = mio::Token(1); - - let addr = sock.addr; - let mut server = Some( - mio::net::TcpListener::from_std(sock.lst) - .expect("Can not create mio::net::TcpListener"), - ); - - // Create a poll instance - let poll = match mio::Poll::new() { - Ok(poll) => poll, - Err(err) => panic!("Can not create mio::Poll: {}", err), - }; - - // Start listening for incoming connections - if let Some(ref srv) = server { - if let Err(err) = - poll.register(srv, SRV, mio::Ready::readable(), mio::PollOpt::edge()) - { - panic!("Can not register io: {}", err); - } - } - - // Start listening for incoming commands - if let Err(err) = - poll.register(®, CMD, mio::Ready::readable(), mio::PollOpt::edge()) - { - panic!("Can not register Registration: {}", err); - } - - // Create storage for events - let mut events = mio::Events::with_capacity(128); - - // Sleep on error - let sleep = Duration::from_millis(100); - - let mut next = 0; - loop { - if let Err(err) = poll.poll(&mut events, None) { - panic!("Poll error: {}", err); - } - - for event in events.iter() { - match event.token() { - SRV => if let Some(ref server) = server { - loop { - match server.accept_std() { - Ok((io, addr)) => { - let mut msg = Conn { - io, - token, - peer: Some(addr), - http2: false, - }; - while !workers.is_empty() { - match workers[next].1.unbounded_send(msg) { - Ok(_) => (), - Err(err) => { - let _ = srv.unbounded_send( - ServerCommand::WorkerDied( - workers[next].0, - socks.clone(), - ), - ); - msg = err.into_inner(); - workers.swap_remove(next); - if workers.is_empty() { - error!("No workers"); - thread::sleep(sleep); - break; - } else if workers.len() <= next { - next = 0; - } - continue; - } - } - next = (next + 1) % workers.len(); - break; - } - } - Err(ref e) - if e.kind() == io::ErrorKind::WouldBlock => - { - break - } - Err(ref e) if connection_error(e) => continue, - Err(e) => { - error!("Error accepting connection: {}", e); - // sleep after error - thread::sleep(sleep); - break; - } - } - } - }, - CMD => match rx.try_recv() { - Ok(cmd) => match cmd { - Command::Pause => if let Some(ref server) = server { - if let Err(err) = poll.deregister(server) { - error!( - "Can not deregister server socket {}", - err - ); - } else { - info!( - "Paused accepting connections on {}", - addr - ); - } - }, - Command::Resume => { - if let Some(ref server) = server { - if let Err(err) = poll.register( - server, - SRV, - mio::Ready::readable(), - mio::PollOpt::edge(), - ) { - error!("Can not resume socket accept process: {}", err); - } else { - info!("Accepting connections on {} has been resumed", - addr); - } - } - } - Command::Stop => { - if let Some(server) = server.take() { - let _ = poll.deregister(&server); - } - return; - } - Command::Worker(idx, addr) => { - workers.push((idx, addr)); - } - }, - Err(err) => match err { - sync_mpsc::TryRecvError::Empty => (), - sync_mpsc::TryRecvError::Disconnected => { - if let Some(server) = server.take() { - let _ = poll.deregister(&server); - } - return; - } - }, - }, - _ => unreachable!(), - } - } - } - }); - - (readiness, tx) -} - -fn create_tcp_listener( - addr: net::SocketAddr, backlog: i32, -) -> io::Result { - let builder = match addr { - net::SocketAddr::V4(_) => TcpBuilder::new_v4()?, - net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, - }; - builder.reuse_address(true)?; - builder.bind(addr)?; - Ok(builder.listen(backlog)?) -} - -/// This function defines errors that are per-connection. Which basically -/// means that if we get this error from `accept()` system call it means -/// next connection might be ready to be accepted. -/// -/// All other errors will incur a timeout before next `accept()` is performed. -/// The timeout is useful to handle resource exhaustion errors like ENFILE -/// and EMFILE. Otherwise, could enter into tight loop. -fn connection_error(e: &io::Error) -> bool { - e.kind() == io::ErrorKind::ConnectionRefused - || e.kind() == io::ErrorKind::ConnectionAborted - || e.kind() == io::ErrorKind::ConnectionReset -} diff --git a/src/server/ssl/mod.rs b/src/server/ssl/mod.rs new file mode 100644 index 000000000..c09573fe3 --- /dev/null +++ b/src/server/ssl/mod.rs @@ -0,0 +1,12 @@ +#[cfg(any(feature = "alpn", feature = "ssl"))] +mod openssl; +#[cfg(any(feature = "alpn", feature = "ssl"))] +pub use self::openssl::{openssl_acceptor_with_flags, OpensslAcceptor}; + +#[cfg(feature = "tls")] +mod nativetls; + +#[cfg(feature = "rust-tls")] +mod rustls; +#[cfg(feature = "rust-tls")] +pub use self::rustls::RustlsAcceptor; diff --git a/src/server/ssl/nativetls.rs b/src/server/ssl/nativetls.rs new file mode 100644 index 000000000..a9797ffb3 --- /dev/null +++ b/src/server/ssl/nativetls.rs @@ -0,0 +1,34 @@ +use std::net::{Shutdown, SocketAddr}; +use std::{io, time}; + +use actix_net::ssl::TlsStream; + +use server::IoStream; + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn peer_addr(&self) -> Option { + self.get_ref().get_ref().peer_addr() + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_keepalive(dur) + } +} diff --git a/src/server/ssl/openssl.rs b/src/server/ssl/openssl.rs new file mode 100644 index 000000000..9d370f8be --- /dev/null +++ b/src/server/ssl/openssl.rs @@ -0,0 +1,87 @@ +use std::net::{Shutdown, SocketAddr}; +use std::{io, time}; + +use actix_net::ssl; +use openssl::ssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_openssl::SslStream; + +use server::{IoStream, ServerFlags}; + +/// Support `SSL` connections via openssl package +/// +/// `ssl` feature enables `OpensslAcceptor` type +pub struct OpensslAcceptor { + _t: ssl::OpensslAcceptor, +} + +impl OpensslAcceptor { + /// Create `OpensslAcceptor` with enabled `HTTP/2` and `HTTP1.1` support. + pub fn new(builder: SslAcceptorBuilder) -> io::Result> { + OpensslAcceptor::with_flags(builder, ServerFlags::HTTP1 | ServerFlags::HTTP2) + } + + /// Create `OpensslAcceptor` with custom server flags. + pub fn with_flags( + builder: SslAcceptorBuilder, flags: ServerFlags, + ) -> io::Result> { + let acceptor = openssl_acceptor_with_flags(builder, flags)?; + + Ok(ssl::OpensslAcceptor::new(acceptor)) + } +} + +/// Configure `SslAcceptorBuilder` with custom server flags. +pub fn openssl_acceptor_with_flags( + mut builder: SslAcceptorBuilder, flags: ServerFlags, +) -> io::Result { + let mut protos = Vec::new(); + if flags.contains(ServerFlags::HTTP1) { + protos.extend(b"\x08http/1.1"); + } + if flags.contains(ServerFlags::HTTP2) { + protos.extend(b"\x02h2"); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(AlpnError::NOACK) + } + }); + } + + if !protos.is_empty() { + builder.set_alpn_protos(&protos)?; + } + + Ok(builder.build()) +} + +impl IoStream for SslStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = self.get_mut().shutdown(); + Ok(()) + } + + #[inline] + fn peer_addr(&self) -> Option { + self.get_ref().get_ref().peer_addr() + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().get_mut().set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().get_mut().set_keepalive(dur) + } +} diff --git a/src/server/ssl/rustls.rs b/src/server/ssl/rustls.rs new file mode 100644 index 000000000..a53a53a98 --- /dev/null +++ b/src/server/ssl/rustls.rs @@ -0,0 +1,87 @@ +use std::net::{Shutdown, SocketAddr}; +use std::{io, time}; + +use actix_net::ssl; //::RustlsAcceptor; +use rustls::{ClientSession, ServerConfig, ServerSession}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_rustls::TlsStream; + +use server::{IoStream, ServerFlags}; + +/// Support `SSL` connections via rustls package +/// +/// `rust-tls` feature enables `RustlsAcceptor` type +pub struct RustlsAcceptor { + _t: ssl::RustlsAcceptor, +} + +impl RustlsAcceptor { + /// Create `RustlsAcceptor` with custom server flags. + pub fn with_flags( + mut config: ServerConfig, flags: ServerFlags, + ) -> ssl::RustlsAcceptor { + let mut protos = Vec::new(); + if flags.contains(ServerFlags::HTTP2) { + protos.push("h2".to_string()); + } + if flags.contains(ServerFlags::HTTP1) { + protos.push("http/1.1".to_string()); + } + if !protos.is_empty() { + config.set_protocols(&protos); + } + + ssl::RustlsAcceptor::new(config) + } +} + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = ::shutdown(self); + Ok(()) + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().0.set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_keepalive(dur) + } +} + +impl IoStream for TlsStream { + #[inline] + fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> { + let _ = ::shutdown(self); + Ok(()) + } + + #[inline] + fn peer_addr(&self) -> Option { + self.get_ref().0.peer_addr() + } + + #[inline] + fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { + self.get_mut().0.set_nodelay(nodelay) + } + + #[inline] + fn set_linger(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_linger(dur) + } + + #[inline] + fn set_keepalive(&mut self, dur: Option) -> io::Result<()> { + self.get_mut().0.set_keepalive(dur) + } +} diff --git a/src/server/worker.rs b/src/server/worker.rs deleted file mode 100644 index 8fd3fe601..000000000 --- a/src/server/worker.rs +++ /dev/null @@ -1,252 +0,0 @@ -use futures::sync::oneshot; -use futures::Future; -use net2::TcpStreamExt; -use slab::Slab; -use std::rc::Rc; -use std::{net, time}; -use tokio::executor::current_thread; -use tokio_reactor::Handle; -use tokio_tcp::TcpStream; - -#[cfg(any(feature = "tls", feature = "alpn"))] -use futures::future; - -#[cfg(feature = "tls")] -use native_tls::TlsAcceptor; -#[cfg(feature = "tls")] -use tokio_tls::TlsAcceptorExt; - -#[cfg(feature = "alpn")] -use openssl::ssl::SslAcceptor; -#[cfg(feature = "alpn")] -use tokio_openssl::SslAcceptorExt; - -use actix::msgs::StopArbiter; -use actix::{Actor, Arbiter, AsyncContext, Context, Handler, Message, Response}; - -use server::channel::HttpChannel; -use server::settings::{ServerSettings, WorkerSettings}; -use server::{HttpHandler, KeepAlive}; - -#[derive(Message)] -pub(crate) struct Conn { - pub io: T, - pub token: usize, - pub peer: Option, - pub http2: bool, -} - -#[derive(Clone)] -pub(crate) struct SocketInfo { - pub addr: net::SocketAddr, - pub htype: StreamHandlerType, -} - -/// Stop worker message. Returns `true` on successful shutdown -/// and `false` if some connections still alive. -pub(crate) struct StopWorker { - pub graceful: Option, -} - -impl Message for StopWorker { - type Result = Result; -} - -/// Http worker -/// -/// Worker accepts Socket objects via unbounded channel and start requests -/// processing. -pub(crate) struct Worker -where - H: HttpHandler + 'static, -{ - settings: Rc>, - socks: Slab, - tcp_ka: Option, -} - -impl Worker { - pub(crate) fn new( - h: Vec, socks: Slab, keep_alive: KeepAlive, - settings: ServerSettings, - ) -> Worker { - let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { - Some(time::Duration::new(val as u64, 0)) - } else { - None - }; - - Worker { - settings: Rc::new(WorkerSettings::new(h, keep_alive, settings)), - socks, - tcp_ka, - } - } - - fn update_time(&self, ctx: &mut Context) { - self.settings.update_date(); - ctx.run_later(time::Duration::new(1, 0), |slf, ctx| slf.update_time(ctx)); - } - - fn shutdown_timeout( - &self, ctx: &mut Context, tx: oneshot::Sender, dur: time::Duration, - ) { - // sleep for 1 second and then check again - ctx.run_later(time::Duration::new(1, 0), move |slf, ctx| { - let num = slf.settings.num_channels(); - if num == 0 { - let _ = tx.send(true); - Arbiter::current().do_send(StopArbiter(0)); - } else if let Some(d) = dur.checked_sub(time::Duration::new(1, 0)) { - slf.shutdown_timeout(ctx, tx, d); - } else { - info!("Force shutdown http worker, {} connections", num); - slf.settings.head().traverse::(); - let _ = tx.send(false); - Arbiter::current().do_send(StopArbiter(0)); - } - }); - } -} - -impl Actor for Worker -where - H: HttpHandler + 'static, -{ - type Context = Context; - - fn started(&mut self, ctx: &mut Self::Context) { - self.update_time(ctx); - } -} - -impl Handler> for Worker -where - H: HttpHandler + 'static, -{ - type Result = (); - - fn handle(&mut self, msg: Conn, _: &mut Context) { - if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { - error!("Can not set socket keep-alive option"); - } - self.socks - .get_mut(msg.token) - .unwrap() - .htype - .handle(Rc::clone(&self.settings), msg); - } -} - -/// `StopWorker` message handler -impl Handler for Worker -where - H: HttpHandler + 'static, -{ - type Result = Response; - - fn handle(&mut self, msg: StopWorker, ctx: &mut Context) -> Self::Result { - let num = self.settings.num_channels(); - if num == 0 { - info!("Shutting down http worker, 0 connections"); - Response::reply(Ok(true)) - } else if let Some(dur) = msg.graceful { - info!("Graceful http worker shutdown, {} connections", num); - let (tx, rx) = oneshot::channel(); - self.shutdown_timeout(ctx, tx, dur); - Response::async(rx.map_err(|_| ())) - } else { - info!("Force shutdown http worker, {} connections", num); - self.settings.head().traverse::(); - Response::reply(Ok(false)) - } - } -} - -#[derive(Clone)] -pub(crate) enum StreamHandlerType { - Normal, - #[cfg(feature = "tls")] - Tls(TlsAcceptor), - #[cfg(feature = "alpn")] - Alpn(SslAcceptor), -} - -impl StreamHandlerType { - fn handle( - &mut self, h: Rc>, msg: Conn, - ) { - match *self { - StreamHandlerType::Normal => { - let _ = msg.io.set_nodelay(true); - let io = TcpStream::from_std(msg.io, &Handle::default()) - .expect("failed to associate TCP stream"); - - current_thread::spawn(HttpChannel::new(h, io, msg.peer, msg.http2)); - } - #[cfg(feature = "tls")] - StreamHandlerType::Tls(ref acceptor) => { - let Conn { - io, peer, http2, .. - } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_std(io, &Handle::default()) - .expect("failed to associate TCP stream"); - - current_thread::spawn(TlsAcceptorExt::accept_async(acceptor, io).then( - move |res| { - match res { - Ok(io) => current_thread::spawn(HttpChannel::new( - h, io, peer, http2, - )), - Err(err) => { - trace!("Error during handling tls connection: {}", err) - } - }; - future::result(Ok(())) - }, - )); - } - #[cfg(feature = "alpn")] - StreamHandlerType::Alpn(ref acceptor) => { - let Conn { io, peer, .. } = msg; - let _ = io.set_nodelay(true); - let io = TcpStream::from_std(io, &Handle::default()) - .expect("failed to associate TCP stream"); - - current_thread::spawn(SslAcceptorExt::accept_async(acceptor, io).then( - move |res| { - match res { - Ok(io) => { - let http2 = if let Some(p) = - io.get_ref().ssl().selected_alpn_protocol() - { - p.len() == 2 && &p == b"h2" - } else { - false - }; - current_thread::spawn(HttpChannel::new( - h, io, peer, http2, - )); - } - Err(err) => { - trace!("Error during handling tls connection: {}", err) - } - }; - future::result(Ok(())) - }, - )); - } - } - } - - pub(crate) fn scheme(&self) -> &'static str { - match *self { - StreamHandlerType::Normal => "http", - #[cfg(feature = "tls")] - StreamHandlerType::Tls(_) => "https", - #[cfg(feature = "alpn")] - StreamHandlerType::Alpn(_) => "https", - } - } -} diff --git a/src/test.rs b/src/test.rs index c2e5c7569..d0cfb255a 100644 --- a/src/test.rs +++ b/src/test.rs @@ -13,14 +13,16 @@ use http::{HeaderMap, HttpTryFrom, Method, Uri, Version}; use net2::TcpBuilder; use tokio::runtime::current_thread::Runtime; -#[cfg(feature = "alpn")] +#[cfg(any(feature = "alpn", feature = "ssl"))] use openssl::ssl::SslAcceptorBuilder; +#[cfg(feature = "rust-tls")] +use rustls::ServerConfig; use application::{App, HttpApplication}; use body::Binary; use client::{ClientConnector, ClientRequest, ClientRequestBuilder}; use error::Error; -use handler::{AsyncResultItem, Handler, Responder}; +use handler::{AsyncResult, AsyncResultItem, Handler, Responder}; use header::{Header, IntoHeaderValue}; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -73,13 +75,13 @@ impl TestServer { /// middlewares or set handlers for test application. pub fn new(config: F) -> Self where - F: Sync + Send + 'static + Fn(&mut TestApp<()>), + F: Clone + Send + 'static + Fn(&mut TestApp<()>), { TestServerBuilder::new(|| ()).start(config) } /// Create test server builder - pub fn build() -> TestServerBuilder<()> { + pub fn build() -> TestServerBuilder<(), impl Fn() -> () + Clone + Send + 'static> { TestServerBuilder::new(|| ()) } @@ -88,19 +90,18 @@ impl TestServer { /// This method can be used for constructing application state. /// Also it can be used for external dependency initialization, /// like creating sync actors for diesel integration. - pub fn build_with_state(state: F) -> TestServerBuilder + pub fn build_with_state(state: F) -> TestServerBuilder where - F: Fn() -> S + Sync + Send + 'static, + F: Fn() -> S + Clone + Send + 'static, S: 'static, { TestServerBuilder::new(state) } /// Start new test server with application factory - pub fn with_factory(factory: F) -> Self + pub fn with_factory(factory: F) -> Self where - F: Fn() -> U + Sync + Send + 'static, - U: IntoIterator + 'static, + F: Fn() -> H + Send + Clone + 'static, H: IntoHttpHandler + 'static, { let (tx, rx) = mpsc::channel(); @@ -111,9 +112,10 @@ impl TestServer { let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); let local_addr = tcp.local_addr().unwrap(); - HttpServer::new(factory) + let _ = HttpServer::new(factory) .disable_signals() .listen(tcp) + .keep_alive(5) .start(); tx.send((System::current(), local_addr, TestServer::get_conn())) @@ -132,7 +134,7 @@ impl TestServer { } fn get_conn() -> Addr { - #[cfg(feature = "alpn")] + #[cfg(any(feature = "alpn", feature = "ssl"))] { use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; @@ -140,7 +142,20 @@ impl TestServer { builder.set_verify(SslVerifyMode::NONE); ClientConnector::with_connector(builder.build()).start() } - #[cfg(not(feature = "alpn"))] + #[cfg(all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "ssl")) + ))] + { + use rustls::ClientConfig; + use std::fs::File; + use std::io::BufReader; + let mut config = ClientConfig::new(); + let pem_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + config.root_store.add_pem_file(pem_file).unwrap(); + ClientConnector::with_connector(config).start() + } + #[cfg(not(any(feature = "alpn", feature = "ssl", feature = "rust-tls")))] { ClientConnector::default().start() } @@ -165,16 +180,16 @@ impl TestServer { pub fn url(&self, uri: &str) -> String { if uri.starts_with('/') { format!( - "{}://{}{}", + "{}://localhost:{}{}", if self.ssl { "https" } else { "http" }, - self.addr, + self.addr.port(), uri ) } else { format!( - "{}://{}/{}", + "{}://localhost:{}/{}", if self.ssl { "https" } else { "http" }, - self.addr, + self.addr.port(), uri ) } @@ -193,13 +208,20 @@ impl TestServer { self.rt.block_on(fut) } - /// Connect to websocket server + /// Connect to websocket server at a given path + pub fn ws_at( + &mut self, path: &str, + ) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> { + let url = self.url(path); + self.rt + .block_on(ws::Client::with_connector(url, self.conn.clone()).connect()) + } + + /// Connect to a websocket server pub fn ws( &mut self, ) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> { - let url = self.url("/"); - self.rt - .block_on(ws::Client::with_connector(url, self.conn.clone()).connect()) + self.ws_at("/") } /// Create `GET` request @@ -237,75 +259,105 @@ impl Drop for TestServer { /// /// This type can be used to construct an instance of `TestServer` through a /// builder-like pattern. -pub struct TestServerBuilder { - state: Box S + Sync + Send + 'static>, - #[cfg(feature = "alpn")] +pub struct TestServerBuilder +where + F: Fn() -> S + Send + Clone + 'static, +{ + state: F, + #[cfg(any(feature = "alpn", feature = "ssl"))] ssl: Option, + #[cfg(feature = "rust-tls")] + rust_ssl: Option, } -impl TestServerBuilder { +impl TestServerBuilder +where + F: Fn() -> S + Send + Clone + 'static, +{ /// Create a new test server - pub fn new(state: F) -> TestServerBuilder - where - F: Fn() -> S + Sync + Send + 'static, - { + pub fn new(state: F) -> TestServerBuilder { TestServerBuilder { - state: Box::new(state), - #[cfg(feature = "alpn")] + state, + #[cfg(any(feature = "alpn", feature = "ssl"))] ssl: None, + #[cfg(feature = "rust-tls")] + rust_ssl: None, } } - #[cfg(feature = "alpn")] + #[cfg(any(feature = "alpn", feature = "ssl"))] /// Create ssl server pub fn ssl(mut self, ssl: SslAcceptorBuilder) -> Self { self.ssl = Some(ssl); self } + #[cfg(feature = "rust-tls")] + /// Create rust tls server + pub fn rustls(mut self, ssl: ServerConfig) -> Self { + self.rust_ssl = Some(ssl); + self + } + #[allow(unused_mut)] /// Configure test application and run test server - pub fn start(mut self, config: F) -> TestServer + pub fn start(mut self, config: C) -> TestServer where - F: Sync + Send + 'static + Fn(&mut TestApp), + C: Fn(&mut TestApp) + Clone + Send + 'static, { let (tx, rx) = mpsc::channel(); - #[cfg(feature = "alpn")] - let ssl = self.ssl.is_some(); - #[cfg(not(feature = "alpn"))] - let ssl = false; + let mut has_ssl = false; + + #[cfg(any(feature = "alpn", feature = "ssl"))] + { + has_ssl = has_ssl || self.ssl.is_some(); + } + + #[cfg(feature = "rust-tls")] + { + has_ssl = has_ssl || self.rust_ssl.is_some(); + } // run server in separate thread thread::spawn(move || { - let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); - let local_addr = tcp.local_addr().unwrap(); + let addr = TestServer::unused_addr(); let sys = System::new("actix-test-server"); let state = self.state; - let srv = HttpServer::new(move || { + let mut srv = HttpServer::new(move || { let mut app = TestApp::new(state()); config(&mut app); - vec![app] + app }).workers(1) - .disable_signals(); + .keep_alive(5) + .disable_signals(); - tx.send((System::current(), local_addr, TestServer::get_conn())) + tx.send((System::current(), addr, TestServer::get_conn())) .unwrap(); - #[cfg(feature = "alpn")] + #[cfg(any(feature = "alpn", feature = "ssl"))] { let ssl = self.ssl.take(); if let Some(ssl) = ssl { - srv.listen_ssl(tcp, ssl).unwrap().start(); - } else { - srv.listen(tcp).start(); + let tcp = net::TcpListener::bind(addr).unwrap(); + srv = srv.listen_ssl(tcp, ssl).unwrap(); } } - #[cfg(not(feature = "alpn"))] + #[cfg(feature = "rust-tls")] { - srv.listen(tcp).start(); + let ssl = self.rust_ssl.take(); + if let Some(ssl) = ssl { + let tcp = net::TcpListener::bind(addr).unwrap(); + srv = srv.listen_rustls(tcp, ssl); + } } + if !has_ssl { + let tcp = net::TcpListener::bind(addr).unwrap(); + srv = srv.listen(tcp); + } + srv.start(); + sys.run(); }); @@ -313,8 +365,8 @@ impl TestServerBuilder { System::set_current(system); TestServer { addr, - ssl, conn, + ssl: has_ssl, rt: Runtime::new().unwrap(), } } @@ -549,7 +601,7 @@ impl TestRequest { payload, prefix, } = self; - let router = Router::<()>::new(); + let router = Router::<()>::default(); let pool = RequestPool::pool(ServerSettings::default()); let mut req = RequestPool::get(pool); @@ -629,8 +681,6 @@ impl TestRequest { /// This method generates `HttpRequest` instance and runs handler /// with generated request. - /// - /// This method panics is handler returns actor or async result. pub fn run>(self, h: &H) -> Result { let req = self.finish(); let resp = h.handle(&req); @@ -639,7 +689,10 @@ impl TestRequest { Ok(resp) => match resp.into().into() { AsyncResultItem::Ok(resp) => Ok(resp), AsyncResultItem::Err(err) => Err(err), - AsyncResultItem::Future(_) => panic!("Async handler is not supported."), + AsyncResultItem::Future(fut) => { + let mut sys = System::new("test"); + sys.block_on(fut) + } }, Err(err) => Err(err.into()), } @@ -659,8 +712,8 @@ impl TestRequest { let req = self.finish(); let fut = h(req.clone()); - let mut core = Runtime::new().unwrap(); - match core.block_on(fut) { + let mut sys = System::new("test"); + match sys.block_on(fut) { Ok(r) => match r.respond_to(&req) { Ok(reply) => match reply.into().into() { AsyncResultItem::Ok(resp) => Ok(resp), @@ -671,4 +724,45 @@ impl TestRequest { Err(err) => Err(err), } } + + /// This method generates `HttpRequest` instance and executes handler + pub fn run_async_result(self, f: F) -> Result + where + F: FnOnce(&HttpRequest) -> R, + R: Into>, + { + let req = self.finish(); + let res = f(&req); + + match res.into().into() { + AsyncResultItem::Ok(resp) => Ok(resp), + AsyncResultItem::Err(err) => Err(err), + AsyncResultItem::Future(fut) => { + let mut sys = System::new("test"); + sys.block_on(fut) + } + } + } + + /// This method generates `HttpRequest` instance and executes handler + pub fn execute(self, f: F) -> Result + where + F: FnOnce(&HttpRequest) -> R, + R: Responder + 'static, + { + let req = self.finish(); + let resp = f(&req); + + match resp.respond_to(&req) { + Ok(resp) => match resp.into().into() { + AsyncResultItem::Ok(resp) => Ok(resp), + AsyncResultItem::Err(err) => Err(err), + AsyncResultItem::Future(fut) => { + let mut sys = System::new("test"); + sys.block_on(fut) + } + }, + Err(err) => Err(err.into()), + } + } } diff --git a/src/uri.rs b/src/uri.rs index 752ddad86..c87cb3d5b 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -1,25 +1,12 @@ use http::Uri; use std::rc::Rc; -#[allow(dead_code)] -const GEN_DELIMS: &[u8] = b":/?#[]@"; -#[allow(dead_code)] -const SUB_DELIMS_WITHOUT_QS: &[u8] = b"!$'()*,"; -#[allow(dead_code)] -const SUB_DELIMS: &[u8] = b"!$'()*,+?=;"; -#[allow(dead_code)] -const RESERVED: &[u8] = b":/?#[]@!$'()*,+?=;"; -#[allow(dead_code)] -const UNRESERVED: &[u8] = b"abcdefghijklmnopqrstuvwxyz - ABCDEFGHIJKLMNOPQRSTUVWXYZ - 1234567890 - -._~"; -const ALLOWED: &[u8] = b"abcdefghijklmnopqrstuvwxyz - ABCDEFGHIJKLMNOPQRSTUVWXYZ - 1234567890 - -._~ - !$'()*,"; -const QS: &[u8] = b"+&=;b"; +// https://tools.ietf.org/html/rfc3986#section-2.2 +const RESERVED_PLUS_EXTRA: &[u8] = b":/?#[]@!$&'()*,+?;=%^ <>\"\\`{}|"; + +// https://tools.ietf.org/html/rfc3986#section-2.3 +const UNRESERVED: &[u8] = + b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-._~"; #[inline] fn bit_at(array: &[u8], ch: u8) -> bool { @@ -32,7 +19,8 @@ fn set_bit(array: &mut [u8], ch: u8) { } lazy_static! { - static ref DEFAULT_QUOTER: Quoter = { Quoter::new(b"@:", b"/+") }; + static ref UNRESERVED_QUOTER: Quoter = { Quoter::new(UNRESERVED) }; + pub(crate) static ref RESERVED_QUOTER: Quoter = { Quoter::new(RESERVED_PLUS_EXTRA) }; } #[derive(Default, Clone, Debug)] @@ -43,7 +31,7 @@ pub(crate) struct Url { impl Url { pub fn new(uri: Uri) -> Url { - let path = DEFAULT_QUOTER.requote(uri.path().as_bytes()); + let path = UNRESERVED_QUOTER.requote(uri.path().as_bytes()); Url { uri, path } } @@ -63,36 +51,19 @@ impl Url { pub(crate) struct Quoter { safe_table: [u8; 16], - protected_table: [u8; 16], } impl Quoter { - pub fn new(safe: &[u8], protected: &[u8]) -> Quoter { + pub fn new(safe: &[u8]) -> Quoter { let mut q = Quoter { safe_table: [0; 16], - protected_table: [0; 16], }; // prepare safe table - for i in 0..128 { - if ALLOWED.contains(&i) { - set_bit(&mut q.safe_table, i); - } - if QS.contains(&i) { - set_bit(&mut q.safe_table, i); - } - } - for ch in safe { set_bit(&mut q.safe_table, *ch) } - // prepare protected table - for ch in protected { - set_bit(&mut q.safe_table, *ch); - set_bit(&mut q.protected_table, *ch); - } - q } @@ -115,19 +86,17 @@ impl Quoter { if let Some(ch) = restore_ch(pct[1], pct[2]) { if ch < 128 { - if bit_at(&self.protected_table, ch) { - buf.extend_from_slice(&pct); - idx += 1; - continue; - } - if bit_at(&self.safe_table, ch) { buf.push(ch); idx += 1; continue; } + + buf.extend_from_slice(&pct); + } else { + // Not ASCII, decode it + buf.push(ch); } - buf.push(ch); } else { buf.extend_from_slice(&pct[..]); } @@ -148,7 +117,7 @@ impl Quoter { if let Some(data) = cloned { // Unsafe: we get data from http::Uri, which does utf-8 checks already // this code only decodes valid pct encoded values - Some(unsafe { Rc::new(String::from_utf8_unchecked(data)) }) + Some(Rc::new(unsafe { String::from_utf8_unchecked(data) })) } else { None } @@ -172,3 +141,37 @@ fn from_hex(v: u8) -> Option { fn restore_ch(d1: u8, d2: u8) -> Option { from_hex(d1).and_then(|d1| from_hex(d2).and_then(move |d2| Some(d1 << 4 | d2))) } + + +#[cfg(test)] +mod tests { + use std::rc::Rc; + + use super::*; + + #[test] + fn decode_path() { + assert_eq!(UNRESERVED_QUOTER.requote(b"https://localhost:80/foo"), None); + + assert_eq!( + Rc::try_unwrap(UNRESERVED_QUOTER.requote( + b"https://localhost:80/foo%25" + ).unwrap()).unwrap(), + "https://localhost:80/foo%25".to_string() + ); + + assert_eq!( + Rc::try_unwrap(UNRESERVED_QUOTER.requote( + b"http://cache-service/http%3A%2F%2Flocalhost%3A80%2Ffoo" + ).unwrap()).unwrap(), + "http://cache-service/http%3A%2F%2Flocalhost%3A80%2Ffoo".to_string() + ); + + assert_eq!( + Rc::try_unwrap(UNRESERVED_QUOTER.requote( + b"http://cache/http%3A%2F%2Flocal%3A80%2Ffile%2F%252Fvar%252Flog%0A" + ).unwrap()).unwrap(), + "http://cache/http%3A%2F%2Flocal%3A80%2Ffile%2F%252Fvar%252Flog%0A".to_string() + ); + } +} \ No newline at end of file diff --git a/src/with.rs b/src/with.rs index 0af626c8b..140e086e1 100644 --- a/src/with.rs +++ b/src/with.rs @@ -7,24 +7,57 @@ use handler::{AsyncResult, AsyncResultItem, FromRequest, Handler, Responder}; use httprequest::HttpRequest; use httpresponse::HttpResponse; -pub(crate) struct With +trait FnWith: 'static { + fn call_with(self: &Self, T) -> R; +} + +impl R + 'static> FnWith for F { + fn call_with(self: &Self, arg: T) -> R { + (*self)(arg) + } +} + +#[doc(hidden)] +pub trait WithFactory: 'static +where + T: FromRequest, + R: Responder, +{ + fn create(self) -> With; + + fn create_with_config(self, T::Config) -> With; +} + +#[doc(hidden)] +pub trait WithAsyncFactory: 'static +where + T: FromRequest, + R: Future, + I: Responder, + E: Into, +{ + fn create(self) -> WithAsync; + + fn create_with_config(self, T::Config) -> WithAsync; +} + +#[doc(hidden)] +pub struct With where - F: Fn(T) -> R, T: FromRequest, S: 'static, { - hnd: Rc, + hnd: Rc>, cfg: Rc, _s: PhantomData, } -impl With +impl With where - F: Fn(T) -> R, T: FromRequest, S: 'static, { - pub fn new(f: F, cfg: T::Config) -> Self { + pub fn new R + 'static>(f: F, cfg: T::Config) -> Self { With { cfg: Rc::new(cfg), hnd: Rc::new(f), @@ -33,9 +66,8 @@ where } } -impl Handler for With +impl Handler for With where - F: Fn(T) -> R + 'static, R: Responder + 'static, T: FromRequest + 'static, S: 'static, @@ -54,30 +86,28 @@ where match fut.poll() { Ok(Async::Ready(resp)) => AsyncResult::ok(resp), - Ok(Async::NotReady) => AsyncResult::async(Box::new(fut)), + Ok(Async::NotReady) => AsyncResult::future(Box::new(fut)), Err(e) => AsyncResult::err(e), } } } -struct WithHandlerFut +struct WithHandlerFut where - F: Fn(T) -> R, R: Responder, T: FromRequest + 'static, S: 'static, { started: bool, - hnd: Rc, + hnd: Rc>, cfg: Rc, req: HttpRequest, fut1: Option>>, fut2: Option>>, } -impl Future for WithHandlerFut +impl Future for WithHandlerFut where - F: Fn(T) -> R, R: Responder + 'static, T: FromRequest + 'static, S: 'static, @@ -108,7 +138,7 @@ where } }; - let item = match (*self.hnd)(item).respond_to(&self.req) { + let item = match self.hnd.as_ref().call_with(item).respond_to(&self.req) { Ok(item) => item.into(), Err(e) => return Err(e.into()), }; @@ -124,30 +154,29 @@ where } } -pub(crate) struct WithAsync +#[doc(hidden)] +pub struct WithAsync where - F: Fn(T) -> R, R: Future, I: Responder, E: Into, T: FromRequest, S: 'static, { - hnd: Rc, + hnd: Rc>, cfg: Rc, _s: PhantomData, } -impl WithAsync +impl WithAsync where - F: Fn(T) -> R, R: Future, I: Responder, E: Into, T: FromRequest, S: 'static, { - pub fn new(f: F, cfg: T::Config) -> Self { + pub fn new R + 'static>(f: F, cfg: T::Config) -> Self { WithAsync { cfg: Rc::new(cfg), hnd: Rc::new(f), @@ -156,9 +185,8 @@ where } } -impl Handler for WithAsync +impl Handler for WithAsync where - F: Fn(T) -> R + 'static, R: Future + 'static, I: Responder + 'static, E: Into + 'static, @@ -180,15 +208,14 @@ where match fut.poll() { Ok(Async::Ready(resp)) => AsyncResult::ok(resp), - Ok(Async::NotReady) => AsyncResult::async(Box::new(fut)), + Ok(Async::NotReady) => AsyncResult::future(Box::new(fut)), Err(e) => AsyncResult::err(e), } } } -struct WithAsyncHandlerFut +struct WithAsyncHandlerFut where - F: Fn(T) -> R, R: Future + 'static, I: Responder + 'static, E: Into + 'static, @@ -196,7 +223,7 @@ where S: 'static, { started: bool, - hnd: Rc, + hnd: Rc>, cfg: Rc, req: HttpRequest, fut1: Option>>, @@ -204,9 +231,8 @@ where fut3: Option>>, } -impl Future for WithAsyncHandlerFut +impl Future for WithAsyncHandlerFut where - F: Fn(T) -> R, R: Future + 'static, I: Responder + 'static, E: Into + 'static, @@ -257,7 +283,101 @@ where } }; - self.fut2 = Some((*self.hnd)(item)); + self.fut2 = Some(self.hnd.as_ref().call_with(item)); self.poll() } } + +macro_rules! with_factory_tuple ({$(($n:tt, $T:ident)),+} => { + impl<$($T,)+ State, Func, Res> WithFactory<($($T,)+), State, Res> for Func + where Func: Fn($($T,)+) -> Res + 'static, + $($T: FromRequest + 'static,)+ + Res: Responder + 'static, + State: 'static, + { + fn create(self) -> With<($($T,)+), State, Res> { + With::new(move |($($n,)+)| (self)($($n,)+), ($($T::Config::default(),)+)) + } + + fn create_with_config(self, cfg: ($($T::Config,)+)) -> With<($($T,)+), State, Res> { + With::new(move |($($n,)+)| (self)($($n,)+), cfg) + } + } +}); + +macro_rules! with_async_factory_tuple ({$(($n:tt, $T:ident)),+} => { + impl<$($T,)+ State, Func, Res, Item, Err> WithAsyncFactory<($($T,)+), State, Res, Item, Err> for Func + where Func: Fn($($T,)+) -> Res + 'static, + $($T: FromRequest + 'static,)+ + Res: Future, + Item: Responder + 'static, + Err: Into, + State: 'static, + { + fn create(self) -> WithAsync<($($T,)+), State, Res, Item, Err> { + WithAsync::new(move |($($n,)+)| (self)($($n,)+), ($($T::Config::default(),)+)) + } + + fn create_with_config(self, cfg: ($($T::Config,)+)) -> WithAsync<($($T,)+), State, Res, Item, Err> { + WithAsync::new(move |($($n,)+)| (self)($($n,)+), cfg) + } + } +}); + +with_factory_tuple!((a, A)); +with_factory_tuple!((a, A), (b, B)); +with_factory_tuple!((a, A), (b, B), (c, C)); +with_factory_tuple!((a, A), (b, B), (c, C), (d, D)); +with_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E)); +with_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E), (f, F)); +with_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E), (f, F), (g, G)); +with_factory_tuple!( + (a, A), + (b, B), + (c, C), + (d, D), + (e, E), + (f, F), + (g, G), + (h, H) +); +with_factory_tuple!( + (a, A), + (b, B), + (c, C), + (d, D), + (e, E), + (f, F), + (g, G), + (h, H), + (i, I) +); + +with_async_factory_tuple!((a, A)); +with_async_factory_tuple!((a, A), (b, B)); +with_async_factory_tuple!((a, A), (b, B), (c, C)); +with_async_factory_tuple!((a, A), (b, B), (c, C), (d, D)); +with_async_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E)); +with_async_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E), (f, F)); +with_async_factory_tuple!((a, A), (b, B), (c, C), (d, D), (e, E), (f, F), (g, G)); +with_async_factory_tuple!( + (a, A), + (b, B), + (c, C), + (d, D), + (e, E), + (f, F), + (g, G), + (h, H) +); +with_async_factory_tuple!( + (a, A), + (b, B), + (c, C), + (d, D), + (e, E), + (f, F), + (g, G), + (h, H), + (i, I) +); diff --git a/src/ws/context.rs b/src/ws/context.rs index 4db83df5c..5e207d43e 100644 --- a/src/ws/context.rs +++ b/src/ws/context.rs @@ -231,6 +231,13 @@ where pub fn handle(&self) -> SpawnHandle { self.inner.curr_handle() } + + /// Set mailbox capacity + /// + /// By default mailbox capacity is 16 messages. + pub fn set_mailbox_capacity(&mut self, cap: usize) { + self.inner.set_mailbox_capacity(cap) + } } impl WsWriter for WebsocketContext diff --git a/src/ws/mask.rs b/src/ws/mask.rs index e9bfb3d56..18ce57bb7 100644 --- a/src/ws/mask.rs +++ b/src/ws/mask.rs @@ -50,7 +50,10 @@ pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) { // TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so // inefficient, it could be done better. The compiler does not understand that // a `ShortSlice` must be smaller than a u64. -#[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))] +#[cfg_attr( + feature = "cargo-clippy", + allow(needless_pass_by_value) +)] fn xor_short(buf: ShortSlice, mask: u64) { // Unsafe: we know that a `ShortSlice` fits in a u64 unsafe { diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 6b37bc7e0..c16f8d6d2 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -387,8 +387,7 @@ mod tests { .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .finish(); + ).finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap() @@ -398,12 +397,10 @@ mod tests { .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .finish(); + ).finish(); assert_eq!( HandshakeError::NoVersionHeader, handshake(&req).err().unwrap() @@ -413,16 +410,13 @@ mod tests { .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5"), - ) - .finish(); + ).finish(); assert_eq!( HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap() @@ -432,16 +426,13 @@ mod tests { .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .finish(); + ).finish(); assert_eq!( HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap() @@ -451,20 +442,16 @@ mod tests { .header( header::UPGRADE, header::HeaderValue::from_static("websocket"), - ) - .header( + ).header( header::CONNECTION, header::HeaderValue::from_static("upgrade"), - ) - .header( + ).header( header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13"), - ) - .header( + ).header( header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13"), - ) - .finish(); + ).finish(); assert_eq!( StatusCode::SWITCHING_PROTOCOLS, handshake(&req).unwrap().finish().status() diff --git a/tests/cert.pem b/tests/cert.pem index 159aacea2..db04fbfae 100644 --- a/tests/cert.pem +++ b/tests/cert.pem @@ -1,31 +1,31 @@ -----BEGIN CERTIFICATE----- -MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV -UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww -CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx -NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD -QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY -MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1 -sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U -NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy -voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr -odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND -xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA -CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI -yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U -UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO -vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un -CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN -BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk -3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI -JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD -JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL -d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu -ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC -CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur -y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7 -YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh -g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt -tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y -1QU= +MIIFXTCCA0WgAwIBAgIJAJ3tqfd0MLLNMA0GCSqGSIb3DQEBCwUAMGExCzAJBgNV +BAYTAlVTMQswCQYDVQQIDAJDRjELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBh +bnkxDDAKBgNVBAsMA09yZzEYMBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMB4XDTE4 +MDcyOTE4MDgzNFoXDTE5MDcyOTE4MDgzNFowYTELMAkGA1UEBhMCVVMxCzAJBgNV +BAgMAkNGMQswCQYDVQQHDAJTRjEQMA4GA1UECgwHQ29tcGFueTEMMAoGA1UECwwD +T3JnMRgwFgYDVQQDDA93d3cuZXhhbXBsZS5jb20wggIiMA0GCSqGSIb3DQEBAQUA +A4ICDwAwggIKAoICAQDZbMgDYilVH1Nv0QWEhOXG6ETmtjZrdLqrNg3NBWBIWCDF +cQ+fyTWxARx6vkF8A/3zpJyTcfQW8HgG38jw/A61QKaHBxzwq0HlNwY9Hh+Neeuk +L4wgrlQ0uTC7IEMrOJjNN0GPyRQVfVbGa8QcSCpOg85l8GCxLvVwkBH/M5atoMtJ +EzniNfK+gtk3hOL2tBqBCu9NDjhXPnJwNDLtTG1tQaHUJW/r281Wvv9I46H83DkU +05lYtauh0bKh5znCH2KpFmBGqJNRzou3tXZFZzZfaCPBJPZR8j5TjoinehpDtkPh +4CSio0PF2eIFkDKRUbdz/327HgEARJMXx+w1yHpS2JwHFgy5O76i68/Smx8j3DDA +2WIkOYAJFRMH0CBHKdsvUDOGpCgN+xv3whl+N806nCfC4vCkwA+FuB3ko11logng +dvr+y0jIUSU4THF3dMDEXYayF3+WrUlw0cBnUNJdXky85ZP81aBfBsjNSBDx4iL4 +e4NhfZRS5oHpHy1t3nYfuttS/oet+Ke5KUpaqNJguSIoeTBSmgzDzL1TJxFLOzUT +2c/A9M69FdvSY0JB4EJX0W9K01Vd0JRNPwsY+/zvFIPama3suKOUTqYcsbwxx9xa +TMDr26cIQcgUAUOKZO43sQGWNzXX3FYVNwczKhkB8UX6hOrBJsEYiau4LGdokQID +AQABoxgwFjAUBgNVHREEDTALgglsb2NhbGhvc3QwDQYJKoZIhvcNAQELBQADggIB +AIX+Qb4QRBxHl5X2UjRyLfWVkimtGlwI8P+eJZL3DrHBH/TpqAaCvTf0EbRC32nm +ASDMwIghaMvyrW40QN6V/CWRRi25cXUfsIZr1iHAHK0eZJV8SWooYtt4iNrcUs3g +4OTvDxhNmDyNwV9AXhJsBKf80dCW6/84jItqVAj20/OO4Rkd2tEeI8NomiYBc6a1 +hgwvv02myYF5hG/xZ9YSqeroBCZHwGYoJJnSpMPqJsxbCVnx2/U9FzGwcRmNHFCe +0g7EJZd3//8Plza6nkTBjJ/V7JnLqMU+ltx4mAgZO8rfzIr84qZdt0YN33VJQhYq +seuMySxrsuaAoxAmm8IoK9cW4IPzx1JveBQiroNlq5YJGf2UW7BTc3gz6c2tINZi +7ailBVdhlMnDXAf3/9xiiVlRAHOxgZh/7sRrKU7kDEHM4fGoc0YyZBTQKndPYMwO +3Bd82rlQ4sd46XYutTrB+mBYClVrJs+OzbNedTsR61DVNKKsRG4mNPyKSAIgOfM5 +XmSvCMPN5JK9U0DsNIV2/SnVsmcklQczT35FLTxl9ntx8ys7ZYK+SppD7XuLfWMq +GT9YMWhlpw0aRDg/aayeeOcnsNBhzAFMcOpQj1t6Fgv4+zbS9BM2bT0hbX86xjkr +E6wWgkuCslMgQlEJ+TM5RhYrI5/rVZQhvmgcob/9gPZv -----END CERTIFICATE----- diff --git a/tests/identity.pfx b/tests/identity.pfx new file mode 100644 index 000000000..946e3b8b8 Binary files /dev/null and b/tests/identity.pfx differ diff --git a/tests/test_client.rs b/tests/test_client.rs index cf20fb8b8..9808f3e6f 100644 --- a/tests/test_client.rs +++ b/tests/test_client.rs @@ -5,8 +5,11 @@ extern crate bytes; extern crate flate2; extern crate futures; extern crate rand; +#[cfg(all(unix, feature = "uds"))] +extern crate tokio_uds; -use std::io::Read; +use std::io::{Read, Write}; +use std::{net, thread}; use bytes::Bytes; use flate2::read::GzDecoder; @@ -64,6 +67,16 @@ fn test_simple() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[test] +fn test_connection_close() { + let mut srv = + test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR))); + + let request = srv.get().header("Connection", "close").finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); +} + #[test] fn test_with_query_parameter() { let mut srv = test::TestServer::new(|app| { @@ -116,8 +129,7 @@ fn test_client_gzip_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Deflate) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -146,8 +158,7 @@ fn test_client_gzip_encoding_large() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Deflate) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -168,7 +179,7 @@ fn test_client_gzip_encoding_large() { #[test] fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&rand::distributions::Alphanumeric) .take(100_000) .collect::(); @@ -179,8 +190,7 @@ fn test_client_gzip_encoding_large_random() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Deflate) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -198,6 +208,13 @@ fn test_client_gzip_encoding_large_random() { assert_eq!(bytes, Bytes::from(data)); } +#[cfg(all(unix, feature = "uds"))] +#[test] +fn test_compatible_with_unix_socket_stream() { + let (stream, _) = tokio_uds::UnixStream::pair().unwrap(); + let _ = client::Connection::from_stream(stream); +} + #[cfg(feature = "brotli")] #[test] fn test_client_brotli_encoding() { @@ -208,8 +225,7 @@ fn test_client_brotli_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Gzip) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -231,7 +247,7 @@ fn test_client_brotli_encoding() { #[test] fn test_client_brotli_encoding_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&rand::distributions::Alphanumeric) .take(70_000) .collect::(); @@ -242,8 +258,7 @@ fn test_client_brotli_encoding_large_random() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Gzip) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -272,8 +287,7 @@ fn test_client_deflate_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Br) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -295,7 +309,7 @@ fn test_client_deflate_encoding() { #[test] fn test_client_deflate_encoding_large_random() { let data = rand::thread_rng() - .gen_ascii_chars() + .sample_iter(&rand::distributions::Alphanumeric) .take(70_000) .collect::(); @@ -306,8 +320,7 @@ fn test_client_deflate_encoding_large_random() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Br) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -336,8 +349,7 @@ fn test_client_streaming_explicit() { .chunked() .content_encoding(http::ContentEncoding::Identity) .body(body)) - }) - .responder() + }).responder() }) }); @@ -395,24 +407,29 @@ fn test_client_cookie_handling() { let cookie2 = cookie2b.clone(); app.handler(move |req: &HttpRequest| { // Check cookies were sent correctly - req.cookie("cookie1").ok_or_else(err) - .and_then(|c1| if c1.value() == "value1" { + req.cookie("cookie1") + .ok_or_else(err) + .and_then(|c1| { + if c1.value() == "value1" { Ok(()) } else { Err(err()) - }) - .and_then(|()| req.cookie("cookie2").ok_or_else(err)) - .and_then(|c2| if c2.value() == "value2" { + } + }).and_then(|()| req.cookie("cookie2").ok_or_else(err)) + .and_then(|c2| { + if c2.value() == "value2" { Ok(()) } else { Err(err()) - }) - // Send some cookies back - .map(|_| HttpResponse::Ok() - .cookie(cookie1.clone()) - .cookie(cookie2.clone()) - .finish() - ) + } + }) + // Send some cookies back + .map(|_| { + HttpResponse::Ok() + .cookie(cookie1.clone()) + .cookie(cookie2.clone()) + .finish() + }) }) }); @@ -438,7 +455,7 @@ fn test_default_headers() { let repr = format!("{:?}", request); assert!(repr.contains("\"accept-encoding\": \"gzip, deflate\"")); assert!(repr.contains(concat!( - "\"user-agent\": \"Actix-web/", + "\"user-agent\": \"actix-web/", env!("CARGO_PKG_VERSION"), "\"" ))); @@ -459,3 +476,33 @@ fn test_default_headers() { "\"" ))); } + +#[test] +fn client_read_until_eof() { + let addr = test::TestServer::unused_addr(); + + thread::spawn(move || { + let lst = net::TcpListener::bind(addr).unwrap(); + + for stream in lst.incoming() { + let mut stream = stream.unwrap(); + let mut b = [0; 1000]; + let _ = stream.read(&mut b).unwrap(); + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\nwelcome!"); + } + }); + + let mut sys = actix::System::new("test"); + + // client request + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = sys.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(b"welcome!")); +} diff --git a/tests/test_custom_pipeline.rs b/tests/test_custom_pipeline.rs new file mode 100644 index 000000000..6b5df00e3 --- /dev/null +++ b/tests/test_custom_pipeline.rs @@ -0,0 +1,81 @@ +extern crate actix; +extern crate actix_net; +extern crate actix_web; + +use std::{thread, time}; + +use actix::System; +use actix_net::server::Server; +use actix_net::service::NewServiceExt; +use actix_web::server::{HttpService, KeepAlive, ServiceConfig, StreamConfiguration}; +use actix_web::{client, http, test, App, HttpRequest}; + +#[test] +fn test_custom_pipeline() { + let addr = test::TestServer::unused_addr(); + + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + let app = App::new() + .route("/", http::Method::GET, |_: HttpRequest| "OK") + .finish(); + let settings = ServiceConfig::build(app) + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_shutdown(1000) + .server_hostname("localhost") + .server_address(addr) + .finish(); + + StreamConfiguration::new() + .nodelay(true) + .tcp_keepalive(Some(time::Duration::from_secs(10))) + .and_then(HttpService::new(settings)) + }).unwrap() + .run(); + }); + + let mut sys = System::new("test"); + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + } +} + +#[test] +fn test_h1() { + use actix_web::server::H1Service; + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + Server::new() + .bind("test", addr, move || { + let app = App::new() + .route("/", http::Method::GET, |_: HttpRequest| "OK") + .finish(); + let settings = ServiceConfig::build(app) + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_shutdown(1000) + .server_hostname("localhost") + .server_address(addr) + .finish(); + + H1Service::new(settings) + }).unwrap() + .run(); + }); + + let mut sys = System::new("test"); + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = sys.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + } +} diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs index c86a3e9c0..debc1626a 100644 --- a/tests/test_handlers.rs +++ b/tests/test_handlers.rs @@ -191,8 +191,7 @@ fn test_form_extractor() { .uri(srv.url("/test1/index.html")) .form(FormData { username: "test".to_string(), - }) - .unwrap(); + }).unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -208,7 +207,7 @@ fn test_form_extractor2() { r.route().with_config( |form: Form| format!("{}", form.username), |cfg| { - cfg.error_handler(|err, _| { + cfg.0.error_handler(|err, _| { error::InternalError::from_response( err, HttpResponse::Conflict().finish(), @@ -306,8 +305,7 @@ fn test_path_and_query_extractor2_async() { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", p.username, data.0)) - }) - .responder() + }).responder() }, ) }); @@ -336,8 +334,7 @@ fn test_path_and_query_extractor3_async() { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", p.username, data.0)) - }) - .responder() + }).responder() }) }); }); @@ -361,8 +358,7 @@ fn test_path_and_query_extractor4_async() { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", p.username, data.0)) - }) - .responder() + }).responder() }) }); }); @@ -387,8 +383,7 @@ fn test_path_and_query_extractor2_async2() { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", p.username, data.0)) - }) - .responder() + }).responder() }, ) }); @@ -422,15 +417,13 @@ fn test_path_and_query_extractor2_async2() { fn test_path_and_query_extractor2_async3() { let mut srv = test::TestServer::new(|app| { app.resource("/{username}/index.html", |r| { - r.route().with( - |(data, p, _q): (Json, Path, Query)| { + r.route() + .with(|data: Json, p: Path, _: Query| { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", p.username, data.0)) - }) - .responder() - }, - ) + }).responder() + }) }); }); @@ -467,8 +460,7 @@ fn test_path_and_query_extractor2_async4() { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(move |_| { Ok(format!("Welcome {} - {}!", data.1.username, (data.0).0)) - }) - .responder() + }).responder() }) }); }); @@ -680,6 +672,6 @@ fn test_unsafe_path_route() { let bytes = srv.execute(response.body()).unwrap(); assert_eq!( bytes, - Bytes::from_static(b"success: http:%2F%2Fexample.com") + Bytes::from_static(b"success: http%3A%2F%2Fexample.com") ); } diff --git a/tests/test_middleware.rs b/tests/test_middleware.rs index 4fa1c81da..6cb6ee363 100644 --- a/tests/test_middleware.rs +++ b/tests/test_middleware.rs @@ -84,11 +84,10 @@ fn test_middleware_multiple() { response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), }).middleware(MiddlewareTest { - start: Arc::clone(&act_num1), - response: Arc::clone(&act_num2), - finish: Arc::clone(&act_num3), - }) - .handler(|_| HttpResponse::Ok()) + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }).handler(|_| HttpResponse::Ok()) }); let request = srv.get().finish().unwrap(); @@ -143,11 +142,10 @@ fn test_resource_middleware_multiple() { response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), }).middleware(MiddlewareTest { - start: Arc::clone(&act_num1), - response: Arc::clone(&act_num2), - finish: Arc::clone(&act_num3), - }) - .handler(|_| HttpResponse::Ok()) + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }).handler(|_| HttpResponse::Ok()) }); let request = srv.get().finish().unwrap(); @@ -176,8 +174,7 @@ fn test_scope_middleware() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }) }); @@ -207,13 +204,11 @@ fn test_scope_middleware_multiple() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .middleware(MiddlewareTest { + }).middleware(MiddlewareTest { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }) }); @@ -242,8 +237,7 @@ fn test_middleware_async_handler() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/", |r| { + }).resource("/", |r| { r.route().a(|_| { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(|_| Ok(HttpResponse::Ok())) @@ -312,8 +306,7 @@ fn test_scope_middleware_async_handler() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| { + }).resource("/test", |r| { r.route().a(|_| { Delay::new(Instant::now() + Duration::from_millis(10)) .and_then(|_| Ok(HttpResponse::Ok())) @@ -379,8 +372,7 @@ fn test_scope_middleware_async_error() { start: Arc::clone(&act_req), response: Arc::clone(&act_resp), finish: Arc::clone(&act_fin), - }) - .resource("/test", |r| r.f(index_test_middleware_async_error)) + }).resource("/test", |r| r.f(index_test_middleware_async_error)) }) }); @@ -514,13 +506,11 @@ fn test_async_middleware_multiple() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .middleware(MiddlewareAsyncTest { + }).middleware(MiddlewareAsyncTest { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }); let request = srv.get().uri(srv.url("/test")).finish().unwrap(); @@ -550,13 +540,11 @@ fn test_async_sync_middleware_multiple() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .middleware(MiddlewareTest { + }).middleware(MiddlewareTest { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }); let request = srv.get().uri(srv.url("/test")).finish().unwrap(); @@ -587,8 +575,7 @@ fn test_async_scope_middleware() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }) }); @@ -620,13 +607,11 @@ fn test_async_scope_middleware_multiple() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .middleware(MiddlewareAsyncTest { + }).middleware(MiddlewareAsyncTest { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }) }); @@ -658,13 +643,11 @@ fn test_async_async_scope_middleware_multiple() { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .middleware(MiddlewareTest { + }).middleware(MiddlewareTest { start: Arc::clone(&act_num1), response: Arc::clone(&act_num2), finish: Arc::clone(&act_num3), - }) - .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }).resource("/test", |r| r.f(|_| HttpResponse::Ok())) }) }); @@ -1012,8 +995,7 @@ fn test_session_storage_middleware() { App::new() .middleware(SessionStorage::new( CookieSessionBackend::signed(&[0; 32]).secure(false), - )) - .resource("/index", move |r| { + )).resource("/index", move |r| { r.f(|req| { let res = req.session().set(COMPLEX_NAME, COMPLEX_PAYLOAD); assert!(res.is_ok()); @@ -1033,8 +1015,7 @@ fn test_session_storage_middleware() { HttpResponse::Ok() }) - }) - .resource("/expect_cookie", move |r| { + }).resource("/expect_cookie", move |r| { r.f(|req| { let _cookies = req.cookies().expect("To get cookies"); diff --git a/tests/test_server.rs b/tests/test_server.rs index 82a318e59..f3c9bf9dd 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1,4 +1,5 @@ extern crate actix; +extern crate actix_net; extern crate actix_web; #[cfg(feature = "brotli")] extern crate brotli2; @@ -9,18 +10,27 @@ extern crate h2; extern crate http as modhttp; extern crate rand; extern crate tokio; +extern crate tokio_current_thread; +extern crate tokio_current_thread as current_thread; extern crate tokio_reactor; extern crate tokio_tcp; +#[cfg(feature = "tls")] +extern crate native_tls; +#[cfg(feature = "ssl")] +extern crate openssl; +#[cfg(feature = "rust-tls")] +extern crate rustls; + use std::io::{Read, Write}; -use std::sync::{mpsc, Arc}; -use std::{net, thread, time}; +use std::sync::Arc; +use std::{thread, time}; #[cfg(feature = "brotli")] use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::{Bytes, BytesMut}; use flate2::read::GzDecoder; -use flate2::write::{DeflateDecoder, DeflateEncoder, GzEncoder}; +use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; use flate2::Compression; use futures::stream::once; use futures::{Future, Stream}; @@ -28,11 +38,10 @@ use h2::client as h2client; use modhttp::Request; use rand::distributions::Alphanumeric; use rand::Rng; -use tokio::executor::current_thread; use tokio::runtime::current_thread::Runtime; +use tokio_current_thread::spawn; use tokio_tcp::TcpStream; -use actix::System; use actix_web::*; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ @@ -60,6 +69,9 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[test] #[cfg(unix)] fn test_start() { + use actix::System; + use std::sync::mpsc; + let _ = test::TestServer::unused_addr(); let (tx, rx) = mpsc::channel(); @@ -117,6 +129,10 @@ fn test_start() { #[test] #[cfg(unix)] fn test_shutdown() { + use actix::System; + use std::net; + use std::sync::mpsc; + let _ = test::TestServer::unused_addr(); let (tx, rx) = mpsc::channel(); @@ -153,6 +169,64 @@ fn test_shutdown() { let _ = sys.stop(); } +#[test] +#[cfg(unix)] +fn test_panic() { + use actix::System; + use std::sync::mpsc; + + let _ = test::TestServer::unused_addr(); + let (tx, rx) = mpsc::channel(); + + thread::spawn(|| { + System::run(move || { + let srv = server::new(|| { + App::new() + .resource("/panic", |r| { + r.method(http::Method::GET).f(|_| -> &'static str { + panic!("error"); + }); + }).resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + }) + }).workers(1); + + let srv = srv.bind("127.0.0.1:0").unwrap(); + let addr = srv.addrs()[0]; + srv.start(); + let _ = tx.send((addr, System::current())); + }); + }); + let (addr, sys) = rx.recv().unwrap(); + System::set_current(sys.clone()); + + let mut rt = Runtime::new().unwrap(); + { + let req = client::ClientRequest::get(format!("http://{}/panic", addr).as_str()) + .finish() + .unwrap(); + let response = rt.block_on(req.send()); + assert!(response.is_err()); + } + + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = rt.block_on(req.send()); + assert!(response.is_err()); + } + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()) + .finish() + .unwrap(); + let response = rt.block_on(req.send()).unwrap(); + assert!(response.status().is_success()); + } + + let _ = sys.stop(); +} + #[test] fn test_simple() { let mut srv = test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok())); @@ -472,7 +546,7 @@ fn test_body_chunked_explicit() { #[test] fn test_body_identity() { - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.as_ref()).unwrap(); let enc = e.finish().unwrap(); let enc2 = enc.clone(); @@ -522,7 +596,7 @@ fn test_body_deflate() { let bytes = srv.execute(response.body()).unwrap(); // decode deflate - let mut e = DeflateDecoder::new(Vec::new()); + let mut e = ZlibDecoder::new(Vec::new()); e.write_all(bytes.as_ref()).unwrap(); let dec = e.finish().unwrap(); assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); @@ -563,8 +637,7 @@ fn test_gzip_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -596,8 +669,7 @@ fn test_gzip_encoding_large() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -633,8 +705,7 @@ fn test_reading_gzip_encoding_large_random() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -666,12 +737,11 @@ fn test_reading_deflate_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(STR.as_ref()).unwrap(); let enc = e.finish().unwrap(); @@ -699,12 +769,11 @@ fn test_reading_deflate_encoding_large() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); @@ -736,12 +805,11 @@ fn test_reading_deflate_encoding_large_random() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); - let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); e.write_all(data.as_ref()).unwrap(); let enc = e.finish().unwrap(); @@ -770,8 +838,7 @@ fn test_brotli_encoding() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -804,8 +871,7 @@ fn test_brotli_encoding_large() { Ok(HttpResponse::Ok() .content_encoding(http::ContentEncoding::Identity) .body(bytes)) - }) - .responder() + }).responder() }) }); @@ -827,10 +893,214 @@ fn test_brotli_encoding_large() { assert_eq!(bytes, Bytes::from(data)); } +#[cfg(all(feature = "brotli", feature = "ssl"))] +#[test] +fn test_brotli_encoding_large_ssl() { + use actix::{Actor, System}; + use openssl::ssl::{ + SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, + }; + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + let data = STR.repeat(10); + let srv = test::TestServer::build().ssl(builder).start(|app| { + app.handler(|req: &HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(HttpResponse::Ok() + .content_encoding(http::ContentEncoding::Identity) + .body(bytes)) + }).responder() + }) + }); + let mut rt = System::new("test"); + + // client connector + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let conn = client::ClientConnector::with_connector(builder.build()).start(); + + // body + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + + // client request + let request = client::ClientRequest::build() + .uri(srv.url("/")) + .method(http::Method::POST) + .header(http::header::CONTENT_ENCODING, "br") + .with_connector(conn) + .body(enc) + .unwrap(); + let response = rt.block_on(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = rt.block_on(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from(data)); +} + +#[cfg(all(feature = "rust-tls", feature = "ssl"))] +#[test] +fn test_reading_deflate_encoding_large_random_ssl() { + use actix::{Actor, System}; + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + use rustls::internal::pemfile::{certs, rsa_private_keys}; + use rustls::{NoClientAuth, ServerConfig}; + use std::fs::File; + use std::io::BufReader; + + // load ssl keys + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = rsa_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .collect::(); + + let srv = test::TestServer::build().rustls(config).start(|app| { + app.handler(|req: &HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(HttpResponse::Ok() + .content_encoding(http::ContentEncoding::Identity) + .body(bytes)) + }).responder() + }) + }); + + let mut rt = System::new("test"); + + // client connector + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let conn = client::ClientConnector::with_connector(builder.build()).start(); + + // encode data + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + + // client request + let request = client::ClientRequest::build() + .uri(srv.url("/")) + .method(http::Method::POST) + .header(http::header::CONTENT_ENCODING, "deflate") + .with_connector(conn) + .body(enc) + .unwrap(); + let response = rt.block_on(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = rt.block_on(response.body()).unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); +} + +#[cfg(all(feature = "tls", feature = "ssl"))] +#[test] +fn test_reading_deflate_encoding_large_random_tls() { + use native_tls::{Identity, TlsAcceptor}; + use openssl::ssl::{ + SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode, + }; + use std::fs::File; + use std::sync::mpsc; + + use actix::{Actor, System}; + let (tx, rx) = mpsc::channel(); + + // load ssl keys + let mut file = File::open("tests/identity.pfx").unwrap(); + let mut identity = vec![]; + file.read_to_end(&mut identity).unwrap(); + let identity = Identity::from_pkcs12(&identity, "1").unwrap(); + let acceptor = TlsAcceptor::new(identity).unwrap(); + + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + let data = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(160_000) + .collect::(); + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + System::run(move || { + server::new(|| { + App::new().handler("/", |req: &HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(HttpResponse::Ok() + .content_encoding(http::ContentEncoding::Identity) + .body(bytes)) + }).responder() + }) + }).bind_tls(addr, acceptor) + .unwrap() + .start(); + let _ = tx.send(System::current()); + }); + }); + let sys = rx.recv().unwrap(); + + let mut rt = System::new("test"); + + // client connector + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let conn = client::ClientConnector::with_connector(builder.build()).start(); + + // encode data + let mut e = ZlibEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + + // client request + let request = client::ClientRequest::build() + .uri(format!("https://{}/", addr)) + .method(http::Method::POST) + .header(http::header::CONTENT_ENCODING, "deflate") + .with_connector(conn) + .body(enc) + .unwrap(); + let response = rt.block_on(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = rt.block_on(response.body()).unwrap(); + assert_eq!(bytes.len(), data.len()); + assert_eq!(bytes, Bytes::from(data)); + + let _ = sys.stop(); +} + #[test] fn test_h2() { let srv = test::TestServer::new(|app| app.handler(|_| HttpResponse::Ok().body(STR))); let addr = srv.addr(); + thread::sleep(time::Duration::from_millis(500)); let mut core = Runtime::new().unwrap(); let tcp = TcpStream::connect(&addr); @@ -847,7 +1117,7 @@ fn test_h2() { let (response, _) = client.send_request(request, false).unwrap(); // Spawn a task to run the conn... - current_thread::spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); + spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); response.and_then(|response| { assert_eq!(response.status(), http::StatusCode::OK); @@ -874,3 +1144,257 @@ fn test_application() { let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); } + +#[test] +fn test_default_404_handler_response() { + let mut srv = test::TestServer::with_factory(|| { + App::new() + .prefix("/app") + .resource("", |r| r.f(|_| HttpResponse::Ok())) + .resource("/", |r| r.f(|_| HttpResponse::Ok())) + }); + let addr = srv.addr(); + + let mut buf = [0; 24]; + let request = TcpStream::connect(&addr) + .and_then(|sock| { + tokio::io::write_all(sock, "HEAD / HTTP/1.1\r\nHost: localhost\r\n\r\n") + .and_then(|(sock, _)| tokio::io::read_exact(sock, &mut buf)) + .and_then(|(_, buf)| Ok(buf)) + }).map_err(|e| panic!("{:?}", e)); + let response = srv.execute(request).unwrap(); + let rep = String::from_utf8_lossy(&response[..]); + assert!(rep.contains("HTTP/1.1 404 Not Found")); +} + +#[test] +fn test_server_cookies() { + use actix_web::http; + + let mut srv = test::TestServer::with_factory(|| { + App::new().resource("/", |r| { + r.f(|_| { + HttpResponse::Ok() + .cookie( + http::CookieBuilder::new("first", "first_value") + .http_only(true) + .finish(), + ).cookie(http::Cookie::new("second", "first_value")) + .cookie(http::Cookie::new("second", "second_value")) + .finish() + }) + }) + }); + + let first_cookie = http::CookieBuilder::new("first", "first_value") + .http_only(true) + .finish(); + let second_cookie = http::Cookie::new("second", "second_value"); + + let request = srv.get().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + let cookies = response.cookies().expect("To have cookies"); + assert_eq!(cookies.len(), 2); + if cookies[0] == first_cookie { + assert_eq!(cookies[1], second_cookie); + } else { + assert_eq!(cookies[0], second_cookie); + assert_eq!(cookies[1], first_cookie); + } + + let first_cookie = first_cookie.to_string(); + let second_cookie = second_cookie.to_string(); + //Check that we have exactly two instances of raw cookie headers + let cookies = response + .headers() + .get_all(http::header::SET_COOKIE) + .iter() + .map(|header| header.to_str().expect("To str").to_string()) + .collect::>(); + assert_eq!(cookies.len(), 2); + if cookies[0] == first_cookie { + assert_eq!(cookies[1], second_cookie); + } else { + assert_eq!(cookies[0], second_cookie); + assert_eq!(cookies[1], first_cookie); + } +} + +#[test] +fn test_slow_request() { + use actix::System; + use std::net; + use std::sync::mpsc; + let (tx, rx) = mpsc::channel(); + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + System::run(move || { + let srv = server::new(|| { + vec![App::new().resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + })] + }); + + let srv = srv.bind(addr).unwrap(); + srv.client_timeout(200).start(); + let _ = tx.send(System::current()); + }); + }); + let sys = rx.recv().unwrap(); + + thread::sleep(time::Duration::from_millis(200)); + + let mut stream = net::TcpStream::connect(addr).unwrap(); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + let mut stream = net::TcpStream::connect(addr).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); + + sys.stop(); +} + +#[test] +fn test_malformed_request() { + use actix::System; + use std::net; + use std::sync::mpsc; + let (tx, rx) = mpsc::channel(); + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + System::run(move || { + let srv = server::new(|| { + App::new().resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + }) + }); + + let _ = srv.bind(addr).unwrap().start(); + let _ = tx.send(System::current()); + }); + }); + let sys = rx.recv().unwrap(); + thread::sleep(time::Duration::from_millis(200)); + + let mut stream = net::TcpStream::connect(addr).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + + sys.stop(); +} + +#[test] +fn test_app_404() { + let mut srv = test::TestServer::with_factory(|| { + App::new().prefix("/prefix").resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + }) + }); + + let request = srv.client(http::Method::GET, "/prefix/").finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + let request = srv.client(http::Method::GET, "/").finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), http::StatusCode::NOT_FOUND); +} + +#[test] +#[cfg(feature = "ssl")] +fn test_ssl_handshake_timeout() { + use actix::System; + use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + use std::net; + use std::sync::mpsc; + + let (tx, rx) = mpsc::channel(); + let addr = test::TestServer::unused_addr(); + + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + thread::spawn(move || { + System::run(move || { + let srv = server::new(|| { + App::new().resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + }) + }); + + srv.bind_ssl(addr, builder) + .unwrap() + .workers(1) + .client_timeout(200) + .start(); + let _ = tx.send(System::current()); + }); + }); + let sys = rx.recv().unwrap(); + + let mut stream = net::TcpStream::connect(addr).unwrap(); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.is_empty()); + + let _ = sys.stop(); +} + +#[test] +fn test_content_length() { + use actix_web::http::header::{HeaderName, HeaderValue}; + use http::StatusCode; + + let mut srv = test::TestServer::new(move |app| { + app.resource("/{status}", |r| { + r.f(|req: &HttpRequest| { + let indx: usize = + req.match_info().get("status").unwrap().parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + HttpResponse::new(statuses[indx]) + }) + }); + }); + + let addr = srv.addr(); + let mut get_resp = |i| { + let url = format!("http://{}/{}", addr, i); + let req = srv.get().uri(url).finish().unwrap(); + srv.execute(req.send()).unwrap() + }; + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + for i in 0..4 { + let response = get_resp(i); + assert_eq!(response.headers().get(&header), None); + } + for i in 4..6 { + let response = get_resp(i); + assert_eq!(response.headers().get(&header), Some(&value)); + } +} diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 66a9153dc..cb46bc7e1 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -5,13 +5,19 @@ extern crate futures; extern crate http; extern crate rand; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::{thread, time}; + use bytes::Bytes; use futures::Stream; use rand::distributions::Alphanumeric; use rand::Rng; -#[cfg(feature = "alpn")] +#[cfg(feature = "ssl")] extern crate openssl; +#[cfg(feature = "rust-tls")] +extern crate rustls; use actix::prelude::*; use actix_web::*; @@ -62,6 +68,45 @@ fn test_simple() { ); } +// websocket resource helper function +fn start_ws_resource(req: &HttpRequest) -> Result { + ws::start(req, Ws) +} + +#[test] +fn test_simple_path() { + const PATH: &str = "/v1/ws/"; + + // Create a websocket at a specific path. + let mut srv = test::TestServer::new(|app| { + app.resource(PATH, |r| r.route().f(start_ws_resource)); + }); + // fetch the sockets for the resource at a given path. + let (reader, mut writer) = srv.ws_at(PATH).unwrap(); + + writer.text("text"); + let (item, reader) = srv.execute(reader.into_future()).unwrap(); + assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); + + writer.binary(b"text".as_ref()); + let (item, reader) = srv.execute(reader.into_future()).unwrap(); + assert_eq!( + item, + Some(ws::Message::Binary(Bytes::from_static(b"text").into())) + ); + + writer.ping("ping"); + let (item, reader) = srv.execute(reader.into_future()).unwrap(); + assert_eq!(item, Some(ws::Message::Pong("ping".to_owned()))); + + writer.close(Some(ws::CloseCode::Normal.into())); + let (item, _) = srv.execute(reader.into_future()).unwrap(); + assert_eq!( + item, + Some(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) + ); +} + #[test] fn test_empty_close_code() { let mut srv = test::TestServer::new(|app| app.handler(|req| ws::start(req, Ws))); @@ -172,8 +217,7 @@ impl Ws2 { act.send(ctx); } actix::fut::ok(()) - }) - .wait(ctx); + }).wait(ctx); } } @@ -238,9 +282,8 @@ fn test_server_send_bin() { } #[test] -#[cfg(feature = "alpn")] +#[cfg(feature = "ssl")] fn test_ws_server_ssl() { - extern crate openssl; use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; // load ssl keys @@ -272,3 +315,81 @@ fn test_ws_server_ssl() { assert_eq!(item, data); } } + +#[test] +#[cfg(feature = "rust-tls")] +fn test_ws_server_rust_tls() { + use rustls::internal::pemfile::{certs, rsa_private_keys}; + use rustls::{NoClientAuth, ServerConfig}; + use std::fs::File; + use std::io::BufReader; + + // load ssl keys + let mut config = ServerConfig::new(NoClientAuth::new()); + let cert_file = &mut BufReader::new(File::open("tests/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("tests/key.pem").unwrap()); + let cert_chain = certs(cert_file).unwrap(); + let mut keys = rsa_private_keys(key_file).unwrap(); + config.set_single_cert(cert_chain, keys.remove(0)).unwrap(); + + let mut srv = test::TestServer::build().rustls(config).start(|app| { + app.handler(|req| { + ws::start( + req, + Ws2 { + count: 0, + bin: false, + }, + ) + }) + }); + + let (mut reader, _writer) = srv.ws().unwrap(); + + let data = Some(ws::Message::Text("0".repeat(65_536))); + for _ in 0..10_000 { + let (item, r) = srv.execute(reader.into_future()).unwrap(); + reader = r; + assert_eq!(item, data); + } +} + +struct WsStopped(Arc); + +impl Actor for WsStopped { + type Context = ws::WebsocketContext; + + fn stopped(&mut self, _: &mut Self::Context) { + self.0.fetch_add(1, Ordering::Relaxed); + } +} + +impl StreamHandler for WsStopped { + fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { + match msg { + ws::Message::Text(text) => ctx.text(text), + _ => (), + } + } +} + +#[test] +fn test_ws_stopped() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let mut srv = test::TestServer::new(move |app| { + let num3 = num2.clone(); + app.handler(move |req| ws::start(req, WsStopped(num3.clone()))) + }); + { + let (reader, mut writer) = srv.ws().unwrap(); + writer.text("text"); + writer.close(None); + let (item, _) = srv.execute(reader.into_future()).unwrap(); + assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); + } + thread::sleep(time::Duration::from_millis(100)); + + assert_eq!(num.load(Ordering::Relaxed), 1); +}