diff --git a/README.md b/README.md index 508736279..c85c0652f 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ You may consider checking out ## Benchmarks One of the fastest web frameworks available according to the -[TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r19). +[TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r20&test=composite). ## License diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index f0f0a0255..84d6617f7 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -3,12 +3,14 @@ ## Unreleased - 2021-xx-xx ### Added * `impl MessageBody for Pin>`. [#2152] +* `Response::{ok, bad_request, not_found, internal_server_error}`. [#2159] * Helper `body::to_bytes` for async collecting message body into Bytes. [#2158] ### Changes * The type parameter of `Response` no longer has a default. [#2152] * The `Message` variant of `body::Body` is now `Pin>`. [#2152] * `BodyStream` and `SizedStream` are no longer restricted to Unpin types. [#2152] +* Error enum types are marked `#[non_exhaustive]`. [#2161] ### Removed * `cookies` feature flag. [#2065] @@ -19,11 +21,15 @@ * `ResponseBuilder::json`. [#2148] * `ResponseBuilder::{set_header, header}`. [#2148] * `impl From for Body`. [#2148] +* `Response::build_from`. [#2159] +* Most of the status code builders on `Response`. [#2159] [#2065]: https://github.com/actix/actix-web/pull/2065 [#2148]: https://github.com/actix/actix-web/pull/2148 [#2152]: https://github.com/actix/actix-web/pull/2152 +[#2159]: https://github.com/actix/actix-web/pull/2159 [#2158]: https://github.com/actix/actix-web/pull/2158 +[#2161]: https://github.com/actix/actix-web/pull/2161 ## 3.0.0-beta.5 - 2021-04-02 diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 4bef9e37c..361cae62f 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -62,6 +62,7 @@ local-channel = "0.1" once_cell = "1.5" log = "0.4" mime = "0.3" +paste = "1" percent-encoding = "2.1" pin-project = "1.0.0" pin-project-lite = "0.2" diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs index 176ac5c2b..b2cdb0be1 100644 --- a/actix-http/examples/echo.rs +++ b/actix-http/examples/echo.rs @@ -1,6 +1,6 @@ use std::{env, io}; -use actix_http::{Error, HttpService, Request, Response}; +use actix_http::{http::StatusCode, Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; use futures_util::StreamExt as _; @@ -25,7 +25,7 @@ async fn main() -> io::Result<()> { info!("request body: {:?}", body); Ok::<_, Error>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header(( "x-head", HeaderValue::from_static("dummy value!"), diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs index 483a79aac..9acf4bbae 100644 --- a/actix-http/examples/echo2.rs +++ b/actix-http/examples/echo2.rs @@ -1,6 +1,6 @@ use std::{env, io}; -use actix_http::{body::Body, http::HeaderValue}; +use actix_http::{body::Body, http::HeaderValue, http::StatusCode}; use actix_http::{Error, HttpService, Request, Response}; use actix_server::Server; use bytes::BytesMut; @@ -14,7 +14,7 @@ async fn handle_request(mut req: Request) -> Result, Error> { } info!("request body: {:?}", body); - Ok(Response::Ok() + Ok(Response::build(StatusCode::OK) .insert_header(("x-head", HeaderValue::from_static("dummy value!"))) .body(body)) } diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs index a99ddae46..85994556d 100644 --- a/actix-http/examples/hello-world.rs +++ b/actix-http/examples/hello-world.rs @@ -1,6 +1,6 @@ use std::{env, io}; -use actix_http::{HttpService, Response}; +use actix_http::{http::StatusCode, HttpService, Response}; use actix_server::Server; use actix_utils::future; use http::header::HeaderValue; @@ -18,7 +18,7 @@ async fn main() -> io::Result<()> { .client_disconnect(1000) .finish(|_req| { info!("{:?}", _req); - let mut res = Response::Ok(); + let mut res = Response::build(StatusCode::OK); res.insert_header(( "x-head", HeaderValue::from_static("dummy value!"), diff --git a/actix-http/src/body/body.rs b/actix-http/src/body/body.rs index 5fc461d41..4fe18338a 100644 --- a/actix-http/src/body/body.rs +++ b/actix-http/src/body/body.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, fmt, mem, pin::Pin, task::{Context, Poll}, @@ -118,12 +119,23 @@ impl From for Body { } } -impl<'a> From<&'a String> for Body { - fn from(s: &'a String) -> Body { +impl From<&'_ String> for Body { + fn from(s: &String) -> Body { Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(&s))) } } +impl From> for Body { + fn from(s: Cow<'_, str>) -> Body { + match s { + Cow::Owned(s) => Body::from(s), + Cow::Borrowed(s) => { + Body::Bytes(Bytes::copy_from_slice(AsRef::<[u8]>::as_ref(s))) + } + } + } +} + impl From for Body { fn from(s: Bytes) -> Body { Body::Bytes(s) diff --git a/actix-http/src/body/message_body.rs b/actix-http/src/body/message_body.rs index ea2cfd22d..894a5fa98 100644 --- a/actix-http/src/body/message_body.rs +++ b/actix-http/src/body/message_body.rs @@ -12,10 +12,12 @@ use crate::error::Error; use super::BodySize; -/// Type that implement this trait can be streamed to a peer. +/// An interface for response bodies. pub trait MessageBody { + /// Body size hint. fn size(&self) -> BodySize; + /// Attempt to pull out the next chunk of body bytes. fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/actix-http/src/body/mod.rs b/actix-http/src/body/mod.rs index c298dda11..f26d6a8cf 100644 --- a/actix-http/src/body/mod.rs +++ b/actix-http/src/body/mod.rs @@ -56,7 +56,7 @@ pub async fn to_bytes(body: impl MessageBody) -> Result { let body = body.as_mut(); match ready!(body.poll_next(cx)) { - Some(Ok(bytes)) => buf.extend(bytes), + Some(Ok(bytes)) => buf.extend_from_slice(&*bytes), None => return Poll::Ready(Ok(())), Some(Err(err)) => return Poll::Ready(Err(err)), } diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index 01c4beeba..68ad709a1 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -140,8 +140,8 @@ impl From for Error { } } -#[derive(Debug, Display)] -#[display(fmt = "UnknownError")] +#[derive(Debug, Display, Error)] +#[display(fmt = "Unknown Error")] struct UnitError; /// Returns [`StatusCode::INTERNAL_SERVER_ERROR`] for [`UnitError`]. @@ -190,38 +190,47 @@ impl ResponseError for header::InvalidHeaderValue { } } -/// A set of errors that can occur during parsing HTTP streams -#[derive(Debug, Display)] +/// A set of errors that can occur during parsing HTTP streams. +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum ParseError { /// An invalid `Method`, such as `GE.T`. #[display(fmt = "Invalid Method specified")] Method, + /// An invalid `Uri`, such as `exam ple.domain`. #[display(fmt = "Uri error: {}", _0)] Uri(InvalidUri), + /// An invalid `HttpVersion`, such as `HTP/1.1` #[display(fmt = "Invalid HTTP version specified")] Version, + /// An invalid `Header`. #[display(fmt = "Invalid Header provided")] Header, + /// A message head is too large to be reasonable. #[display(fmt = "Message head is too large")] TooLarge, + /// A message reached EOF, but is not complete. #[display(fmt = "Message is incomplete")] Incomplete, + /// An invalid `Status`, such as `1337 ELITE`. #[display(fmt = "Invalid Status provided")] Status, + /// A timeout occurred waiting for an IO event. #[allow(dead_code)] #[display(fmt = "Timeout")] Timeout, - /// An `io::Error` that occurred while trying to read or write to a network - /// stream. + + /// An `io::Error` that occurred while trying to read or write to a network stream. #[display(fmt = "IO error: {}", _0)] Io(io::Error), + /// Parsing a field as string failed #[display(fmt = "UTF8 error: {}", _0)] Utf8(Utf8Error), @@ -273,17 +282,16 @@ impl From for ParseError { } /// A set of errors that can occur running blocking tasks in thread pool. -#[derive(Debug, Display)] +#[derive(Debug, Display, Error)] #[display(fmt = "Blocking thread pool is gone")] pub struct BlockingError; -impl std::error::Error for BlockingError {} - /// `InternalServerError` for `BlockingError` impl ResponseError for BlockingError {} -#[derive(Display, Debug)] -/// A set of errors that can occur during payload parsing +/// A set of errors that can occur during payload parsing. +#[derive(Debug, Display)] +#[non_exhaustive] pub enum PayloadError { /// A payload reached EOF, but is not complete. #[display( @@ -367,8 +375,9 @@ impl ResponseError for PayloadError { } } -#[derive(Debug, Display, From)] -/// A set of errors that can occur during dispatching HTTP requests +/// A set of errors that can occur during dispatching HTTP requests. +#[derive(Debug, Display, Error, From)] +#[non_exhaustive] pub enum DispatchError { /// Service error Service(Error), @@ -414,8 +423,9 @@ pub enum DispatchError { Unknown, } -/// A set of error that can occur during parsing content type -#[derive(Debug, PartialEq, Display, Error)] +/// A set of error that can occur during parsing content type. +#[derive(Debug, Display, Error)] +#[non_exhaustive] pub enum ContentTypeError { /// Can not parse content type #[display(fmt = "Can not parse content type")] @@ -426,6 +436,22 @@ pub enum ContentTypeError { UnknownEncoding, } +#[cfg(test)] +mod content_type_test_impls { + use super::*; + + impl std::cmp::PartialEq for ContentTypeError { + fn eq(&self, other: &Self) -> bool { + match self { + Self::ParseError => matches!(other, ContentTypeError::ParseError), + Self::UnknownEncoding => { + matches!(other, ContentTypeError::UnknownEncoding) + } + } + } + } +} + /// Return `BadRequest` for `ContentTypeError` impl ResponseError for ContentTypeError { fn status_code(&self) -> StatusCode { @@ -533,395 +559,72 @@ where } } -/// Helper function that creates wrapper of any error and generate *BAD -/// REQUEST* response. -#[allow(non_snake_case)] -pub fn ErrorBadRequest(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::BAD_REQUEST).into() +macro_rules! error_helper { + ($name:ident, $status:ident) => { + paste::paste! { + #[doc = "Helper function that wraps any error and generates a `" $status "` response."] + #[allow(non_snake_case)] + pub fn $name(err: T) -> Error + where + T: fmt::Debug + fmt::Display + 'static, + { + InternalError::new(err, StatusCode::$status).into() + } + } + } } -/// Helper function that creates wrapper of any error and generate -/// *UNAUTHORIZED* response. -#[allow(non_snake_case)] -pub fn ErrorUnauthorized(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNAUTHORIZED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PAYMENT_REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorPaymentRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PAYMENT_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate *FORBIDDEN* -/// response. -#[allow(non_snake_case)] -pub fn ErrorForbidden(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::FORBIDDEN).into() -} - -/// Helper function that creates wrapper of any error and generate *NOT FOUND* -/// response. -#[allow(non_snake_case)] -pub fn ErrorNotFound(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_FOUND).into() -} - -/// Helper function that creates wrapper of any error and generate *METHOD NOT -/// ALLOWED* response. -#[allow(non_snake_case)] -pub fn ErrorMethodNotAllowed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() -} - -/// Helper function that creates wrapper of any error and generate *NOT -/// ACCEPTABLE* response. -#[allow(non_snake_case)] -pub fn ErrorNotAcceptable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_ACCEPTABLE).into() -} - -/// Helper function that creates wrapper of any error and generate *PROXY -/// AUTHENTICATION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorProxyAuthenticationRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PROXY_AUTHENTICATION_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate *REQUEST -/// TIMEOUT* response. -#[allow(non_snake_case)] -pub fn ErrorRequestTimeout(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() -} - -/// Helper function that creates wrapper of any error and generate *CONFLICT* -/// response. -#[allow(non_snake_case)] -pub fn ErrorConflict(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::CONFLICT).into() -} - -/// Helper function that creates wrapper of any error and generate *GONE* -/// response. -#[allow(non_snake_case)] -pub fn ErrorGone(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::GONE).into() -} - -/// Helper function that creates wrapper of any error and generate *LENGTH -/// REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorLengthRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LENGTH_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PAYLOAD TOO LARGE* response. -#[allow(non_snake_case)] -pub fn ErrorPayloadTooLarge(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PAYLOAD_TOO_LARGE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *URI TOO LONG* response. -#[allow(non_snake_case)] -pub fn ErrorUriTooLong(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::URI_TOO_LONG).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNSUPPORTED MEDIA TYPE* response. -#[allow(non_snake_case)] -pub fn ErrorUnsupportedMediaType(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNSUPPORTED_MEDIA_TYPE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *RANGE NOT SATISFIABLE* response. -#[allow(non_snake_case)] -pub fn ErrorRangeNotSatisfiable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::RANGE_NOT_SATISFIABLE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *IM A TEAPOT* response. -#[allow(non_snake_case)] -pub fn ErrorImATeapot(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::IM_A_TEAPOT).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *MISDIRECTED REQUEST* response. -#[allow(non_snake_case)] -pub fn ErrorMisdirectedRequest(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::MISDIRECTED_REQUEST).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNPROCESSABLE ENTITY* response. -#[allow(non_snake_case)] -pub fn ErrorUnprocessableEntity(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNPROCESSABLE_ENTITY).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *LOCKED* response. -#[allow(non_snake_case)] -pub fn ErrorLocked(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LOCKED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *FAILED DEPENDENCY* response. -#[allow(non_snake_case)] -pub fn ErrorFailedDependency(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::FAILED_DEPENDENCY).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UPGRADE REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorUpgradeRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UPGRADE_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PRECONDITION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorPreconditionFailed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *PRECONDITION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorPreconditionRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::PRECONDITION_REQUIRED).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *TOO MANY REQUESTS* response. -#[allow(non_snake_case)] -pub fn ErrorTooManyRequests(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::TOO_MANY_REQUESTS).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *REQUEST HEADER FIELDS TOO LARGE* response. -#[allow(non_snake_case)] -pub fn ErrorRequestHeaderFieldsTooLarge(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *UNAVAILABLE FOR LEGAL REASONS* response. -#[allow(non_snake_case)] -pub fn ErrorUnavailableForLegalReasons(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS).into() -} - -/// Helper function that creates wrapper of any error and generate -/// *EXPECTATION FAILED* response. -#[allow(non_snake_case)] -pub fn ErrorExpectationFailed(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *INTERNAL SERVER ERROR* response. -#[allow(non_snake_case)] -pub fn ErrorInternalServerError(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NOT IMPLEMENTED* response. -#[allow(non_snake_case)] -pub fn ErrorNotImplemented(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_IMPLEMENTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *BAD GATEWAY* response. -#[allow(non_snake_case)] -pub fn ErrorBadGateway(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::BAD_GATEWAY).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *SERVICE UNAVAILABLE* response. -#[allow(non_snake_case)] -pub fn ErrorServiceUnavailable(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::SERVICE_UNAVAILABLE).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *GATEWAY TIMEOUT* response. -#[allow(non_snake_case)] -pub fn ErrorGatewayTimeout(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::GATEWAY_TIMEOUT).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *HTTP VERSION NOT SUPPORTED* response. -#[allow(non_snake_case)] -pub fn ErrorHttpVersionNotSupported(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::HTTP_VERSION_NOT_SUPPORTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *VARIANT ALSO NEGOTIATES* response. -#[allow(non_snake_case)] -pub fn ErrorVariantAlsoNegotiates(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::VARIANT_ALSO_NEGOTIATES).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *INSUFFICIENT STORAGE* response. -#[allow(non_snake_case)] -pub fn ErrorInsufficientStorage(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::INSUFFICIENT_STORAGE).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *LOOP DETECTED* response. -#[allow(non_snake_case)] -pub fn ErrorLoopDetected(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::LOOP_DETECTED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NOT EXTENDED* response. -#[allow(non_snake_case)] -pub fn ErrorNotExtended(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NOT_EXTENDED).into() -} - -/// Helper function that creates wrapper of any error and -/// generate *NETWORK AUTHENTICATION REQUIRED* response. -#[allow(non_snake_case)] -pub fn ErrorNetworkAuthenticationRequired(err: T) -> Error -where - T: fmt::Debug + fmt::Display + 'static, -{ - InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() -} +error_helper!(ErrorBadRequest, BAD_REQUEST); +error_helper!(ErrorUnauthorized, UNAUTHORIZED); +error_helper!(ErrorPaymentRequired, PAYMENT_REQUIRED); +error_helper!(ErrorForbidden, FORBIDDEN); +error_helper!(ErrorNotFound, NOT_FOUND); +error_helper!(ErrorMethodNotAllowed, METHOD_NOT_ALLOWED); +error_helper!(ErrorNotAcceptable, NOT_ACCEPTABLE); +error_helper!( + ErrorProxyAuthenticationRequired, + PROXY_AUTHENTICATION_REQUIRED +); +error_helper!(ErrorRequestTimeout, REQUEST_TIMEOUT); +error_helper!(ErrorConflict, CONFLICT); +error_helper!(ErrorGone, GONE); +error_helper!(ErrorLengthRequired, LENGTH_REQUIRED); +error_helper!(ErrorPayloadTooLarge, PAYLOAD_TOO_LARGE); +error_helper!(ErrorUriTooLong, URI_TOO_LONG); +error_helper!(ErrorUnsupportedMediaType, UNSUPPORTED_MEDIA_TYPE); +error_helper!(ErrorRangeNotSatisfiable, RANGE_NOT_SATISFIABLE); +error_helper!(ErrorImATeapot, IM_A_TEAPOT); +error_helper!(ErrorMisdirectedRequest, MISDIRECTED_REQUEST); +error_helper!(ErrorUnprocessableEntity, UNPROCESSABLE_ENTITY); +error_helper!(ErrorLocked, LOCKED); +error_helper!(ErrorFailedDependency, FAILED_DEPENDENCY); +error_helper!(ErrorUpgradeRequired, UPGRADE_REQUIRED); +error_helper!(ErrorPreconditionFailed, PRECONDITION_FAILED); +error_helper!(ErrorPreconditionRequired, PRECONDITION_REQUIRED); +error_helper!(ErrorTooManyRequests, TOO_MANY_REQUESTS); +error_helper!( + ErrorRequestHeaderFieldsTooLarge, + REQUEST_HEADER_FIELDS_TOO_LARGE +); +error_helper!( + ErrorUnavailableForLegalReasons, + UNAVAILABLE_FOR_LEGAL_REASONS +); +error_helper!(ErrorExpectationFailed, EXPECTATION_FAILED); +error_helper!(ErrorInternalServerError, INTERNAL_SERVER_ERROR); +error_helper!(ErrorNotImplemented, NOT_IMPLEMENTED); +error_helper!(ErrorBadGateway, BAD_GATEWAY); +error_helper!(ErrorServiceUnavailable, SERVICE_UNAVAILABLE); +error_helper!(ErrorGatewayTimeout, GATEWAY_TIMEOUT); +error_helper!(ErrorHttpVersionNotSupported, HTTP_VERSION_NOT_SUPPORTED); +error_helper!(ErrorVariantAlsoNegotiates, VARIANT_ALSO_NEGOTIATES); +error_helper!(ErrorInsufficientStorage, INSUFFICIENT_STORAGE); +error_helper!(ErrorLoopDetected, LOOP_DETECTED); +error_helper!(ErrorNotExtended, NOT_EXTENDED); +error_helper!( + ErrorNetworkAuthenticationRequired, + NETWORK_AUTHENTICATION_REQUIRED +); #[cfg(test)] mod tests { @@ -1021,8 +724,7 @@ mod tests { #[test] fn test_internal_error() { - let err = - InternalError::from_response(ParseError::Method, Response::Ok().into()); + let err = InternalError::from_response(ParseError::Method, Response::ok()); let resp: Response = err.error_response(); assert_eq!(resp.status(), StatusCode::OK); } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 7ce033fca..23c2ae8d5 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -25,6 +25,7 @@ use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::error::{ParseError, PayloadError}; +use crate::http::StatusCode; use crate::request::Request; use crate::response::Response; use crate::service::HttpFlow; @@ -596,7 +597,7 @@ where ); this.flags.insert(Flags::READ_DISCONNECT); this.messages.push_back(DispatcherMessage::Error( - Response::InternalServerError().finish().drop_body(), + Response::internal_server_error().drop_body(), )); *this.error = Some(DispatchError::InternalError); break; @@ -609,7 +610,7 @@ where error!("Internal server error: unexpected eof"); this.flags.insert(Flags::READ_DISCONNECT); this.messages.push_back(DispatcherMessage::Error( - Response::InternalServerError().finish().drop_body(), + Response::internal_server_error().drop_body(), )); *this.error = Some(DispatchError::InternalError); break; @@ -632,7 +633,8 @@ where } // Requests overflow buffer size should be responded with 431 this.messages.push_back(DispatcherMessage::Error( - Response::RequestHeaderFieldsTooLarge().finish().drop_body(), + Response::new(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE) + .drop_body(), )); this.flags.insert(Flags::READ_DISCONNECT); *this.error = Some(ParseError::TooLarge.into()); @@ -645,7 +647,7 @@ where // Malformed requests should be responded with 400 this.messages.push_back(DispatcherMessage::Error( - Response::BadRequest().finish().drop_body(), + Response::bad_request().drop_body(), )); this.flags.insert(Flags::READ_DISCONNECT); *this.error = Some(err.into()); @@ -681,11 +683,6 @@ where // go into Some> branch this.ka_timer.set(Some(sleep_until(deadline))); return self.poll_keepalive(cx); - } else { - this.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = this.payload.take() { - payload.set_error(PayloadError::Incomplete(None)); - } } } } @@ -732,18 +729,14 @@ where } } else { // timeout on first request (slow request) return 408 - if !this.flags.contains(Flags::STARTED) { - trace!("Slow request timeout"); - let _ = self.as_mut().send_response( - Response::RequestTimeout().finish().drop_body(), - ResponseBody::Other(Body::Empty), - ); - this = self.project(); - } else { - trace!("Keep-alive connection timeout"); - } + trace!("Slow request timeout"); + let _ = self.as_mut().send_response( + Response::new(StatusCode::REQUEST_TIMEOUT) + .drop_body(), + ResponseBody::Other(Body::Empty), + ); + this = self.project(); this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); - this.state.set(State::None); } // still have unfinished task. try to reset and register keep-alive. } else if let Some(deadline) = @@ -1031,7 +1024,7 @@ mod tests { } fn ok_service() -> impl Service, Error = Error> { - fn_service(|_req: Request| ready(Ok::<_, Error>(Response::Ok().finish()))) + fn_service(|_req: Request| ready(Ok::<_, Error>(Response::ok()))) } fn pending_service() -> impl Service, Error = Error> { @@ -1056,12 +1049,14 @@ mod tests { ) -> impl Service, Error = Error> { fn_service(|req: Request| { let path = req.path().as_bytes(); - ready(Ok::<_, Error>(Response::Ok().body(Body::from_slice(path)))) + ready(Ok::<_, Error>( + Response::ok().set_body(Body::from_slice(path)), + )) }) } fn echo_payload_service( - ) -> impl Service, Error = Error> { + ) -> impl Service, Error = Error> { fn_service(|mut req: Request| { Box::pin(async move { use futures_util::stream::StreamExt as _; @@ -1072,7 +1067,7 @@ mod tests { body.extend_from_slice(chunk.unwrap().chunk()) } - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body.freeze())) }) }) } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index bba7af4c6..4547f3ef2 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -37,12 +37,12 @@ pub mod encoding; mod extensions; mod header; mod helpers; -mod http_codes; mod http_message; mod message; mod payload; mod request; mod response; +mod response_builder; mod service; mod time_parser; @@ -60,7 +60,8 @@ pub use self::http_message::HttpMessage; pub use self::message::{Message, RequestHead, RequestHeadType, ResponseHead}; pub use self::payload::{Payload, PayloadStream}; pub use self::request::Request; -pub use self::response::{Response, ResponseBuilder}; +pub use self::response::Response; +pub use self::response_builder::ResponseBuilder; pub use self::service::HttpService; pub mod http { diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs index 0c6272485..a3ab1175c 100644 --- a/actix-http/src/response.rs +++ b/actix-http/src/response.rs @@ -1,4 +1,4 @@ -//! HTTP responses. +//! HTTP response. use std::{ cell::{Ref, RefMut}, @@ -10,37 +10,24 @@ use std::{ }; use bytes::{Bytes, BytesMut}; -use futures_core::Stream; use crate::{ - body::{Body, BodyStream, MessageBody, ResponseBody}, + body::{Body, MessageBody, ResponseBody}, error::Error, extensions::Extensions, - header::{IntoHeaderPair, IntoHeaderValue}, - http::{header, Error as HttpError, HeaderMap, StatusCode}, - message::{BoxedResponseHead, ConnectionType, ResponseHead}, + http::{HeaderMap, StatusCode}, + message::{BoxedResponseHead, ResponseHead}, + ResponseBuilder, }; -/// An HTTP Response +/// An HTTP response. pub struct Response { - head: BoxedResponseHead, - body: ResponseBody, - error: Option, + pub(crate) head: BoxedResponseHead, + pub(crate) body: ResponseBody, + pub(crate) error: Option, } impl Response { - /// Create HTTP response builder with specific status. - #[inline] - pub fn build(status: StatusCode) -> ResponseBuilder { - ResponseBuilder::new(status) - } - - /// Create HTTP response builder - #[inline] - pub fn build_from>(source: T) -> ResponseBuilder { - source.into() - } - /// Constructs a response #[inline] pub fn new(status: StatusCode) -> Response { @@ -51,6 +38,41 @@ impl Response { } } + /// Create HTTP response builder with specific status. + #[inline] + pub fn build(status: StatusCode) -> ResponseBuilder { + ResponseBuilder::new(status) + } + + // just a couple frequently used shortcuts + // this list should not grow larger than a few + + /// Creates a new response with status 200 OK. + #[inline] + pub fn ok() -> Response { + Response::new(StatusCode::OK) + } + + /// Creates a new response with status 400 Bad Request. + #[inline] + pub fn bad_request() -> Response { + Response::new(StatusCode::BAD_REQUEST) + } + + /// Creates a new response with status 404 Not Found. + #[inline] + pub fn not_found() -> Response { + Response::new(StatusCode::NOT_FOUND) + } + + /// Creates a new response with status 500 Internal Server Error. + #[inline] + pub fn internal_server_error() -> Response { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + + // end shortcuts + /// Constructs an error response #[inline] pub fn from_error(error: Error) -> Response { @@ -250,295 +272,6 @@ impl Future for Response { } } -/// An HTTP response builder. -/// -/// This type can be used to construct an instance of `Response` through a builder-like pattern. -pub struct ResponseBuilder { - head: Option, - err: Option, -} - -impl ResponseBuilder { - #[inline] - /// Create response builder - pub fn new(status: StatusCode) -> Self { - ResponseBuilder { - head: Some(BoxedResponseHead::new(status)), - err: None, - } - } - - /// Set HTTP status code of this response. - #[inline] - pub fn status(&mut self, status: StatusCode) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.status = status; - } - self - } - - /// Insert a header, replacing any that were set with an equivalent field name. - /// - /// ``` - /// # use actix_http::Response; - /// use actix_http::http::header; - /// - /// Response::Ok() - /// .insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) - /// .insert_header(("X-TEST", "value")) - /// .finish(); - /// ``` - pub fn insert_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match header.try_into_header_pair() { - Ok((key, value)) => { - parts.headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - }; - } - - self - } - - /// Append a header, keeping any that were set with an equivalent field name. - /// - /// ``` - /// # use actix_http::Response; - /// use actix_http::http::header; - /// - /// Response::Ok() - /// .append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) - /// .append_header(("X-TEST", "value1")) - /// .append_header(("X-TEST", "value2")) - /// .finish(); - /// ``` - pub fn append_header(&mut self, header: H) -> &mut Self - where - H: IntoHeaderPair, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match header.try_into_header_pair() { - Ok((key, value)) => parts.headers.append(key, value), - Err(e) => self.err = Some(e.into()), - }; - } - - self - } - - /// Set the custom reason for the response. - #[inline] - pub fn reason(&mut self, reason: &'static str) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.reason = Some(reason); - } - self - } - - /// Set connection type to KeepAlive - #[inline] - pub fn keep_alive(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::KeepAlive); - } - self - } - - /// Set connection type to Upgrade - #[inline] - pub fn upgrade(&mut self, value: V) -> &mut Self - where - V: IntoHeaderValue, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::Upgrade); - } - - if let Ok(value) = value.try_into_value() { - self.insert_header((header::UPGRADE, value)); - } - - self - } - - /// Force close connection, even if it is marked as keep-alive - #[inline] - pub fn force_close(&mut self) -> &mut Self { - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.set_connection_type(ConnectionType::Close); - } - self - } - - /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. - #[inline] - pub fn no_chunking(&mut self, len: u64) -> &mut Self { - let mut buf = itoa::Buffer::new(); - self.insert_header((header::CONTENT_LENGTH, buf.format(len))); - - if let Some(parts) = parts(&mut self.head, &self.err) { - parts.no_chunking(true); - } - self - } - - /// Set response content type. - #[inline] - pub fn content_type(&mut self, value: V) -> &mut Self - where - V: IntoHeaderValue, - { - if let Some(parts) = parts(&mut self.head, &self.err) { - match value.try_into_value() { - Ok(value) => { - parts.headers.insert(header::CONTENT_TYPE, value); - } - Err(e) => self.err = Some(e.into()), - }; - } - self - } - - /// Responses extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - let head = self.head.as_ref().expect("cannot reuse response builder"); - head.extensions.borrow() - } - - /// Mutable reference to a the response's extensions - #[inline] - pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { - let head = self.head.as_ref().expect("cannot reuse response builder"); - head.extensions.borrow_mut() - } - - /// Set a body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - #[inline] - pub fn body>(&mut self, body: B) -> Response { - self.message_body(body.into()) - } - - /// Set a body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - pub fn message_body(&mut self, body: B) -> Response { - if let Some(e) = self.err.take() { - return Response::from(Error::from(e)).into_body(); - } - - let response = self.head.take().expect("cannot reuse response builder"); - - Response { - head: response, - body: ResponseBody::Body(body), - error: None, - } - } - - /// Set a streaming body and generate `Response`. - /// - /// `ResponseBuilder` can not be used after this call. - #[inline] - pub fn streaming(&mut self, stream: S) -> Response - where - S: Stream> + Unpin + 'static, - E: Into + 'static, - { - self.body(Body::from_message(BodyStream::new(stream))) - } - - /// Set an empty body and generate `Response` - /// - /// `ResponseBuilder` can not be used after this call. - #[inline] - pub fn finish(&mut self) -> Response { - self.body(Body::Empty) - } - - /// This method construct new `ResponseBuilder` - pub fn take(&mut self) -> ResponseBuilder { - ResponseBuilder { - head: self.head.take(), - err: self.err.take(), - } - } -} - -#[inline] -fn parts<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut ResponseHead> { - if err.is_some() { - return None; - } - parts.as_mut().map(|r| &mut **r) -} - -/// Convert `Response` to a `ResponseBuilder`. Body get dropped. -impl From> for ResponseBuilder { - fn from(res: Response) -> ResponseBuilder { - ResponseBuilder { - head: Some(res.head), - err: None, - } - } -} - -/// Convert `ResponseHead` to a `ResponseBuilder` -impl<'a> From<&'a ResponseHead> for ResponseBuilder { - fn from(head: &'a ResponseHead) -> ResponseBuilder { - let mut msg = BoxedResponseHead::new(head.status); - msg.version = head.version; - msg.reason = head.reason; - - for (k, v) in head.headers.iter() { - msg.headers.append(k.clone(), v.clone()); - } - - msg.no_chunking(!head.chunked()); - - ResponseBuilder { - head: Some(msg), - err: None, - } - } -} - -impl Future for ResponseBuilder { - type Output = Result, Error>; - - fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(Ok(self.finish())) - } -} - -impl fmt::Debug for ResponseBuilder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let head = self.head.as_ref().unwrap(); - - let res = writeln!( - f, - "\nResponseBuilder {:?} {}{}", - head.version, - head.status, - head.reason.unwrap_or(""), - ); - let _ = writeln!(f, " headers:"); - for (key, val) in head.headers.iter() { - let _ = writeln!(f, " {:?}: {:?}", key, val); - } - res - } -} - /// Helper converters impl>, E: Into> From> for Response { fn from(res: Result) -> Self { @@ -557,7 +290,7 @@ impl From for Response { impl From<&'static str> for Response { fn from(val: &'static str) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } @@ -565,7 +298,7 @@ impl From<&'static str> for Response { impl From<&'static [u8]> for Response { fn from(val: &'static [u8]) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } @@ -573,7 +306,7 @@ impl From<&'static [u8]> for Response { impl From for Response { fn from(val: String) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } @@ -581,7 +314,7 @@ impl From for Response { impl<'a> From<&'a String> for Response { fn from(val: &'a String) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::TEXT_PLAIN_UTF_8) .body(val) } @@ -589,7 +322,7 @@ impl<'a> From<&'a String> for Response { impl From for Response { fn from(val: Bytes) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } @@ -597,7 +330,7 @@ impl From for Response { impl From for Response { fn from(val: BytesMut) -> Self { - Response::Ok() + Response::build(StatusCode::OK) .content_type(mime::APPLICATION_OCTET_STREAM) .body(val) } @@ -607,11 +340,11 @@ impl From for Response { mod tests { use super::*; use crate::body::Body; - use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE, COOKIE}; + use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE}; #[test] fn test_debug() { - let resp = Response::Ok() + let resp = Response::build(StatusCode::OK) .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) .finish(); @@ -619,38 +352,6 @@ mod tests { assert!(dbg.contains("Response")); } - #[test] - fn test_basic_builder() { - let resp = Response::Ok().insert_header(("X-TEST", "value")).finish(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[test] - fn test_upgrade() { - let resp = Response::build(StatusCode::OK) - .upgrade("websocket") - .finish(); - assert!(resp.upgrade()); - assert_eq!( - resp.headers().get(header::UPGRADE).unwrap(), - HeaderValue::from_static("websocket") - ); - } - - #[test] - fn test_force_close() { - let resp = Response::build(StatusCode::OK).force_close().finish(); - assert!(!resp.keep_alive()) - } - - #[test] - fn test_content_type() { - let resp = Response::build(StatusCode::OK) - .content_type("text/plain") - .body(Body::Empty); - assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") - } - #[test] fn test_into_response() { let resp: Response = "test".into(); @@ -720,72 +421,4 @@ mod tests { assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().get_ref(), b"test"); } - - #[test] - fn test_into_builder() { - let mut resp: Response = "test".into(); - assert_eq!(resp.status(), StatusCode::OK); - - resp.headers_mut().insert( - HeaderName::from_static("cookie"), - HeaderValue::from_static("cookie1=val100"), - ); - - let mut builder: ResponseBuilder = resp.into(); - let resp = builder.status(StatusCode::BAD_REQUEST).finish(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - - let cookie = resp.headers().get_all("Cookie").next().unwrap(); - assert_eq!(cookie.to_str().unwrap(), "cookie1=val100"); - } - - #[test] - fn response_builder_header_insert_kv() { - let mut res = Response::Ok(); - res.insert_header(("Content-Type", "application/octet-stream")); - let res = res.finish(); - - assert_eq!( - res.headers().get("Content-Type"), - Some(&HeaderValue::from_static("application/octet-stream")) - ); - } - - #[test] - fn response_builder_header_insert_typed() { - let mut res = Response::Ok(); - res.insert_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); - let res = res.finish(); - - assert_eq!( - res.headers().get("Content-Type"), - Some(&HeaderValue::from_static("application/octet-stream")) - ); - } - - #[test] - fn response_builder_header_append_kv() { - let mut res = Response::Ok(); - res.append_header(("Content-Type", "application/octet-stream")); - res.append_header(("Content-Type", "application/json")); - let res = res.finish(); - - let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); - assert_eq!(headers.len(), 2); - assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); - assert!(headers.contains(&HeaderValue::from_static("application/json"))); - } - - #[test] - fn response_builder_header_append_typed() { - let mut res = Response::Ok(); - res.append_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); - res.append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); - let res = res.finish(); - - let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); - assert_eq!(headers.len(), 2); - assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); - assert!(headers.contains(&HeaderValue::from_static("application/json"))); - } } diff --git a/actix-http/src/response_builder.rs b/actix-http/src/response_builder.rs new file mode 100644 index 000000000..4d8cb4429 --- /dev/null +++ b/actix-http/src/response_builder.rs @@ -0,0 +1,468 @@ +//! HTTP response builder. + +use std::{ + cell::{Ref, RefMut}, + fmt, + future::Future, + pin::Pin, + str, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures_core::Stream; + +use crate::{ + body::{Body, BodyStream, ResponseBody}, + error::Error, + extensions::Extensions, + header::{IntoHeaderPair, IntoHeaderValue}, + http::{header, Error as HttpError, StatusCode}, + message::{BoxedResponseHead, ConnectionType, ResponseHead}, + Response, +}; + +/// An HTTP response builder. +/// +/// Used to construct an instance of `Response` using a builder pattern. Response builders are often +/// created using [`Response::build`]. +/// +/// # Examples +/// ``` +/// use actix_http::{Response, ResponseBuilder, body, http::StatusCode, http::header}; +/// +/// # actix_rt::System::new().block_on(async { +/// let mut res: Response<_> = Response::build(StatusCode::OK) +/// .content_type(mime::APPLICATION_JSON) +/// .insert_header((header::SERVER, "my-app/1.0")) +/// .append_header((header::SET_COOKIE, "a=1")) +/// .append_header((header::SET_COOKIE, "b=2")) +/// .body("1234"); +/// +/// assert_eq!(res.status(), StatusCode::OK); +/// assert_eq!(body::to_bytes(res.take_body()).await.unwrap(), &b"1234"[..]); +/// +/// assert!(res.headers().contains_key("server")); +/// assert_eq!(res.headers().get_all("set-cookie").count(), 2); +/// # }) +/// ``` +pub struct ResponseBuilder { + head: Option, + err: Option, +} + +impl ResponseBuilder { + /// Create response builder + /// + /// # Examples + /// ``` + /// use actix_http::{Response, ResponseBuilder, http::StatusCode}; + /// + /// let res: Response<_> = ResponseBuilder::default().finish(); + /// assert_eq!(res.status(), StatusCode::OK); + /// ``` + #[inline] + pub fn new(status: StatusCode) -> Self { + ResponseBuilder { + head: Some(BoxedResponseHead::new(status)), + err: None, + } + } + + /// Set HTTP status code of this response. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, http::StatusCode}; + /// + /// let res = ResponseBuilder::default().status(StatusCode::NOT_FOUND).finish(); + /// assert_eq!(res.status(), StatusCode::NOT_FOUND); + /// ``` + #[inline] + pub fn status(&mut self, status: StatusCode) -> &mut Self { + if let Some(parts) = self.inner() { + parts.status = status; + } + self + } + + /// Insert a header, replacing any that were set with an equivalent field name. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, http::header}; + /// + /// let res = ResponseBuilder::default() + /// .insert_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) + /// .insert_header(("X-TEST", "value")) + /// .finish(); + /// + /// assert!(res.headers().contains_key("content-type")); + /// assert!(res.headers().contains_key("x-test")); + /// ``` + pub fn insert_header(&mut self, header: H) -> &mut Self + where + H: IntoHeaderPair, + { + if let Some(parts) = self.inner() { + match header.try_into_header_pair() { + Ok((key, value)) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Append a header, keeping any that were set with an equivalent field name. + /// + /// # Examples + /// ``` + /// use actix_http::{ResponseBuilder, http::header}; + /// + /// let res = ResponseBuilder::default() + /// .append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)) + /// .append_header(("X-TEST", "value1")) + /// .append_header(("X-TEST", "value2")) + /// .finish(); + /// + /// assert_eq!(res.headers().get_all("content-type").count(), 1); + /// assert_eq!(res.headers().get_all("x-test").count(), 2); + /// ``` + pub fn append_header(&mut self, header: H) -> &mut Self + where + H: IntoHeaderPair, + { + if let Some(parts) = self.inner() { + match header.try_into_header_pair() { + Ok((key, value)) => parts.headers.append(key, value), + Err(e) => self.err = Some(e.into()), + }; + } + + self + } + + /// Set the custom reason for the response. + #[inline] + pub fn reason(&mut self, reason: &'static str) -> &mut Self { + if let Some(parts) = self.inner() { + parts.reason = Some(reason); + } + self + } + + /// Set connection type to KeepAlive + #[inline] + pub fn keep_alive(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::KeepAlive); + } + self + } + + /// Set connection type to Upgrade + #[inline] + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Upgrade); + } + + if let Ok(value) = value.try_into_value() { + self.insert_header((header::UPGRADE, value)); + } + + self + } + + /// Force close connection, even if it is marked as keep-alive + #[inline] + pub fn force_close(&mut self) -> &mut Self { + if let Some(parts) = self.inner() { + parts.set_connection_type(ConnectionType::Close); + } + self + } + + /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. + #[inline] + pub fn no_chunking(&mut self, len: u64) -> &mut Self { + let mut buf = itoa::Buffer::new(); + self.insert_header((header::CONTENT_LENGTH, buf.format(len))); + + if let Some(parts) = self.inner() { + parts.no_chunking(true); + } + self + } + + /// Set response content type. + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { + if let Some(parts) = self.inner() { + match value.try_into_value() { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + let head = self.head.as_ref().expect("cannot reuse response builder"); + head.extensions.borrow_mut() + } + + /// Generate response with a wrapped body. + /// + /// This `ResponseBuilder` will be left in a useless state. + #[inline] + pub fn body>(&mut self, body: B) -> Response { + self.message_body(body.into()) + } + + /// Generate response with a body. + /// + /// This `ResponseBuilder` will be left in a useless state. + pub fn message_body(&mut self, body: B) -> Response { + if let Some(e) = self.err.take() { + return Response::from(Error::from(e)).into_body(); + } + + let response = self.head.take().expect("cannot reuse response builder"); + + Response { + head: response, + body: ResponseBody::Body(body), + error: None, + } + } + + /// Generate response with a streaming body. + /// + /// This `ResponseBuilder` will be left in a useless state. + #[inline] + pub fn streaming(&mut self, stream: S) -> Response + where + S: Stream> + Unpin + 'static, + E: Into + 'static, + { + self.body(Body::from_message(BodyStream::new(stream))) + } + + /// Generate response with an empty body. + /// + /// This `ResponseBuilder` will be left in a useless state. + #[inline] + pub fn finish(&mut self) -> Response { + self.body(Body::Empty) + } + + /// Create an owned `ResponseBuilder`, leaving the original in a useless state. + pub fn take(&mut self) -> ResponseBuilder { + ResponseBuilder { + head: self.head.take(), + err: self.err.take(), + } + } + + /// Get access to the inner response head if there has been no error. + fn inner(&mut self) -> Option<&mut ResponseHead> { + if self.err.is_some() { + return None; + } + + self.head.as_mut().map(|r| &mut **r) + } +} + +impl Default for ResponseBuilder { + fn default() -> Self { + Self::new(StatusCode::OK) + } +} + +/// Convert `Response` to a `ResponseBuilder`. Body get dropped. +impl From> for ResponseBuilder { + fn from(res: Response) -> ResponseBuilder { + ResponseBuilder { + head: Some(res.head), + err: None, + } + } +} + +/// Convert `ResponseHead` to a `ResponseBuilder` +impl<'a> From<&'a ResponseHead> for ResponseBuilder { + fn from(head: &'a ResponseHead) -> ResponseBuilder { + let mut msg = BoxedResponseHead::new(head.status); + msg.version = head.version; + msg.reason = head.reason; + + for (k, v) in head.headers.iter() { + msg.headers.append(k.clone(), v.clone()); + } + + msg.no_chunking(!head.chunked()); + + ResponseBuilder { + head: Some(msg), + err: None, + } + } +} + +impl Future for ResponseBuilder { + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Ok(self.finish())) + } +} + +impl fmt::Debug for ResponseBuilder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let head = self.head.as_ref().unwrap(); + + let res = writeln!( + f, + "\nResponseBuilder {:?} {}{}", + head.version, + head.status, + head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + res + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::body::Body; + use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; + + #[test] + fn test_basic_builder() { + let resp = Response::build(StatusCode::OK) + .insert_header(("X-TEST", "value")) + .finish(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_upgrade() { + let resp = Response::build(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); + } + + #[test] + fn test_force_close() { + let resp = Response::build(StatusCode::OK).force_close().finish(); + assert!(!resp.keep_alive()) + } + + #[test] + fn test_content_type() { + let resp = Response::build(StatusCode::OK) + .content_type("text/plain") + .body(Body::Empty); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") + } + + #[test] + fn test_into_builder() { + let mut resp: Response = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + + resp.headers_mut().insert( + HeaderName::from_static("cookie"), + HeaderValue::from_static("cookie1=val100"), + ); + + let mut builder: ResponseBuilder = resp.into(); + let resp = builder.status(StatusCode::BAD_REQUEST).finish(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let cookie = resp.headers().get_all("Cookie").next().unwrap(); + assert_eq!(cookie.to_str().unwrap(), "cookie1=val100"); + } + + #[test] + fn response_builder_header_insert_kv() { + let mut res = Response::build(StatusCode::OK); + res.insert_header(("Content-Type", "application/octet-stream")); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_insert_typed() { + let mut res = Response::build(StatusCode::OK); + res.insert_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + let res = res.finish(); + + assert_eq!( + res.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/octet-stream")) + ); + } + + #[test] + fn response_builder_header_append_kv() { + let mut res = Response::build(StatusCode::OK); + res.append_header(("Content-Type", "application/octet-stream")); + res.append_header(("Content-Type", "application/json")); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } + + #[test] + fn response_builder_header_append_typed() { + let mut res = Response::build(StatusCode::OK); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM)); + res.append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON)); + let res = res.finish(); + + let headers: Vec<_> = res.headers().get_all("Content-Type").cloned().collect(); + assert_eq!(headers.len(), 2); + assert!(headers.contains(&HeaderValue::from_static("application/octet-stream"))); + assert!(headers.contains(&HeaderValue::from_static("application/json"))); + } +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 5b18044b2..22df2b4ff 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -101,29 +101,37 @@ pub enum HandshakeError { impl ResponseError for HandshakeError { fn error_response(&self) -> Response { match self { - HandshakeError::GetMethodRequired => Response::MethodNotAllowed() - .insert_header((header::ALLOW, "GET")) - .finish(), + HandshakeError::GetMethodRequired => { + Response::build(StatusCode::METHOD_NOT_ALLOWED) + .insert_header((header::ALLOW, "GET")) + .finish() + } - HandshakeError::NoWebsocketUpgrade => Response::BadRequest() - .reason("No WebSocket Upgrade header found") - .finish(), + HandshakeError::NoWebsocketUpgrade => { + Response::build(StatusCode::BAD_REQUEST) + .reason("No WebSocket Upgrade header found") + .finish() + } - HandshakeError::NoConnectionUpgrade => Response::BadRequest() - .reason("No Connection upgrade") - .finish(), + HandshakeError::NoConnectionUpgrade => { + Response::build(StatusCode::BAD_REQUEST) + .reason("No Connection upgrade") + .finish() + } - HandshakeError::NoVersionHeader => Response::BadRequest() + HandshakeError::NoVersionHeader => Response::build(StatusCode::BAD_REQUEST) .reason("WebSocket version header is required") .finish(), - HandshakeError::UnsupportedVersion => Response::BadRequest() - .reason("Unsupported WebSocket version") - .finish(), - - HandshakeError::BadWebsocketKey => { - Response::BadRequest().reason("Handshake error").finish() + HandshakeError::UnsupportedVersion => { + Response::build(StatusCode::BAD_REQUEST) + .reason("Unsupported WebSocket version") + .finish() } + + HandshakeError::BadWebsocketKey => Response::build(StatusCode::BAD_REQUEST) + .reason("Handshake error") + .finish(), } } } diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index b5f8d54b9..0a06d90e5 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -33,7 +33,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h1_v2() { let srv = test_server(move || { HttpService::build() - .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| future::ok::<_, ()>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -61,7 +61,7 @@ async fn test_h1_v2() { async fn test_connection_close() { let srv = test_server(move || { HttpService::build() - .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| future::ok::<_, ()>(Response::ok().set_body(STR))) .tcp() .map(|_| ()) }) @@ -77,9 +77,9 @@ async fn test_with_query_parameter() { HttpService::build() .finish(|req: Request| { if req.uri().query().unwrap().contains("qp=") { - future::ok::<_, ()>(Response::Ok().finish()) + future::ok::<_, ()>(Response::ok()) } else { - future::ok::<_, ()>(Response::BadRequest().finish()) + future::ok::<_, ()>(Response::bad_request()) } }) .tcp() @@ -112,7 +112,7 @@ async fn test_h1_expect() { let str = std::str::from_utf8(&buf).unwrap(); assert_eq!(str, "expect body"); - Ok::<_, ()>(Response::Ok().finish()) + Ok::<_, ()>(Response::ok()) }) .tcp() }) diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs index dcf05e8d8..7cbd58518 100644 --- a/actix-http/tests/test_openssl.rs +++ b/actix-http/tests/test_openssl.rs @@ -71,7 +71,7 @@ fn tls_config() -> SslAcceptor { async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .h2(|_| ok::<_, Error>(Response::ok())) .openssl(tls_config()) .map_err(|_| ()) }) @@ -89,7 +89,7 @@ async fn test_h2_1() -> io::Result<()> { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_2); - ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .openssl(tls_config()) .map_err(|_| ()) @@ -108,7 +108,7 @@ async fn test_h2_body() -> io::Result<()> { HttpService::build() .h2(|mut req: Request<_>| async move { let body = load_body(req.take_payload()).await?; - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body)) }) .openssl(tls_config()) .map_err(|_| ()) @@ -186,7 +186,7 @@ async fn test_h2_headers() { let mut srv = test_server(move || { let data = data.clone(); HttpService::build().h2(move |_| { - let mut builder = Response::Ok(); + let mut builder = Response::build(StatusCode::OK); for idx in 0..90 { builder.insert_header( (format!("X-TEST-{}", idx).as_str(), @@ -245,7 +245,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -263,7 +263,7 @@ async fn test_h2_body2() { async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| ok::<_, ()>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -287,7 +287,7 @@ async fn test_h2_head_empty() { async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -310,7 +310,7 @@ async fn test_h2_head_binary() { async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .openssl(tls_config()) .map_err(|_| ()) }) @@ -332,7 +332,7 @@ async fn test_h2_body_length() { .h2(|_| { let body = once(ok(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok().body(SizedStream::new(STR.len() as u64, body)), + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .openssl(tls_config()) @@ -355,7 +355,7 @@ async fn test_h2_body_chunked_explicit() { .h2(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) @@ -383,7 +383,7 @@ async fn test_h2_response_http_error_handling() { .h2(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((header::CONTENT_TYPE, broken_header)) .body(STR), ) @@ -428,7 +428,7 @@ async fn test_h2_on_connect() { }) .h2(|req: Request| { assert!(req.extensions().contains::()); - ok::<_, ()>(Response::Ok().finish()) + ok::<_, ()>(Response::ok()) }) .openssl(tls_config()) .map_err(|_| ()) diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs index 538a2b005..a122ab847 100644 --- a/actix-http/tests/test_rustls.rs +++ b/actix-http/tests/test_rustls.rs @@ -56,7 +56,7 @@ fn tls_config() -> RustlsServerConfig { async fn test_h1() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h1(|_| ok::<_, Error>(Response::Ok().finish())) + .h1(|_| ok::<_, Error>(Response::ok())) .rustls(tls_config()) }) .await; @@ -70,7 +70,7 @@ async fn test_h1() -> io::Result<()> { async fn test_h2() -> io::Result<()> { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, Error>(Response::Ok().finish())) + .h2(|_| ok::<_, Error>(Response::ok())) .rustls(tls_config()) }) .await; @@ -87,7 +87,7 @@ async fn test_h1_1() -> io::Result<()> { .h1(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_11); - ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .rustls(tls_config()) }) @@ -105,7 +105,7 @@ async fn test_h2_1() -> io::Result<()> { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_2); - ok::<_, Error>(Response::Ok().finish()) + ok::<_, Error>(Response::ok()) }) .rustls(tls_config()) }) @@ -123,7 +123,7 @@ async fn test_h2_body1() -> io::Result<()> { HttpService::build() .h2(|mut req: Request<_>| async move { let body = load_body(req.take_payload()).await?; - Ok::<_, Error>(Response::Ok().body(body)) + Ok::<_, Error>(Response::ok().set_body(body)) }) .rustls(tls_config()) }) @@ -199,7 +199,7 @@ async fn test_h2_headers() { let mut srv = test_server(move || { let data = data.clone(); HttpService::build().h2(move |_| { - let mut config = Response::Ok(); + let mut config = Response::build(StatusCode::OK); for idx in 0..90 { config.insert_header(( format!("X-TEST-{}", idx).as_str(), @@ -257,7 +257,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h2_body2() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -274,7 +274,7 @@ async fn test_h2_body2() { async fn test_h2_head_empty() { let mut srv = test_server(move || { HttpService::build() - .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .finish(|_| ok::<_, ()>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -300,7 +300,7 @@ async fn test_h2_head_empty() { async fn test_h2_head_binary() { let mut srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -325,7 +325,7 @@ async fn test_h2_head_binary() { async fn test_h2_head_binary2() { let srv = test_server(move || { HttpService::build() - .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h2(|_| ok::<_, ()>(Response::ok().set_body(STR))) .rustls(tls_config()) }) .await; @@ -349,7 +349,7 @@ async fn test_h2_body_length() { .h2(|_| { let body = once(ok(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok().body(SizedStream::new(STR.len() as u64, body)), + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .rustls(tls_config()) @@ -371,7 +371,7 @@ async fn test_h2_body_chunked_explicit() { .h2(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) @@ -399,7 +399,7 @@ async fn test_h2_response_http_error_handling() { ok::<_, ()>(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs index 80ec0335b..9b8b039c3 100644 --- a/actix-http/tests/test_server.rs +++ b/actix-http/tests/test_server.rs @@ -14,8 +14,8 @@ use regex::Regex; use actix_http::HttpMessage; use actix_http::{ body::{Body, SizedStream}, - error, http, - http::header, + error, + http::{self, header, StatusCode}, Error, HttpService, KeepAlive, Request, Response, }; @@ -28,7 +28,7 @@ async fn test_h1() { .client_disconnect(1000) .h1(|req: Request| { assert!(req.peer_addr().is_some()); - ok::<_, ()>(Response::Ok().finish()) + ok::<_, ()>(Response::ok()) }) .tcp() }) @@ -48,7 +48,7 @@ async fn test_h1_2() { .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), http::Version::HTTP_11); - ok::<_, ()>(Response::Ok().finish()) + ok::<_, ()>(Response::ok()) }) .tcp() }) @@ -69,7 +69,7 @@ async fn test_expect_continue() { err(error::ErrorPreconditionFailed("error")) } })) - .finish(|_| ok::<_, ()>(Response::Ok().finish())) + .finish(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -100,7 +100,7 @@ async fn test_expect_continue_h1() { } }) })) - .h1(fn_service(|_| ok::<_, ()>(Response::Ok().finish()))) + .h1(fn_service(|_| ok::<_, ()>(Response::ok()))) .tcp() }) .await; @@ -134,7 +134,9 @@ async fn test_chunked_payload() { }) .fold(0usize, |acc, chunk| ready(acc + chunk.len())) .map(|req_size| { - Ok::<_, Error>(Response::Ok().body(format!("size={}", req_size))) + Ok::<_, Error>( + Response::ok().set_body(format!("size={}", req_size)), + ) }) })) .tcp() @@ -179,7 +181,7 @@ async fn test_slow_request() { let srv = test_server(|| { HttpService::build() .client_timeout(100) - .finish(|_| ok::<_, ()>(Response::Ok().finish())) + .finish(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -195,7 +197,7 @@ async fn test_slow_request() { async fn test_http1_malformed_request() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -211,7 +213,7 @@ async fn test_http1_malformed_request() { async fn test_http1_keepalive() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -233,7 +235,7 @@ async fn test_http1_keepalive_timeout() { let srv = test_server(|| { HttpService::build() .keep_alive(1) - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -254,7 +256,7 @@ async fn test_http1_keepalive_timeout() { async fn test_http1_keepalive_close() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -275,7 +277,7 @@ async fn test_http1_keepalive_close() { async fn test_http10_keepalive_default_close() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -295,7 +297,7 @@ async fn test_http10_keepalive_default_close() { async fn test_http10_keepalive() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -323,7 +325,7 @@ async fn test_http1_keepalive_disabled() { let srv = test_server(|| { HttpService::build() .keep_alive(KeepAlive::Disabled) - .h1(|_| ok::<_, ()>(Response::Ok().finish())) + .h1(|_| ok::<_, ()>(Response::ok())) .tcp() }) .await; @@ -394,7 +396,7 @@ async fn test_h1_headers() { let mut srv = test_server(move || { let data = data.clone(); HttpService::build().h1(move |_| { - let mut builder = Response::Ok(); + let mut builder = Response::build(StatusCode::OK); for idx in 0..90 { builder.insert_header(( format!("X-TEST-{}", idx).as_str(), @@ -451,7 +453,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ async fn test_h1_body() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, ()>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -468,7 +470,7 @@ async fn test_h1_body() { async fn test_h1_head_empty() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, ()>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -493,7 +495,7 @@ async fn test_h1_head_empty() { async fn test_h1_head_binary() { let mut srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, ()>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -518,7 +520,7 @@ async fn test_h1_head_binary() { async fn test_h1_head_binary2() { let srv = test_server(|| { HttpService::build() - .h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + .h1(|_| ok::<_, ()>(Response::ok().set_body(STR))) .tcp() }) .await; @@ -542,7 +544,7 @@ async fn test_h1_body_length() { .h1(|_| { let body = once(ok(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok().body(SizedStream::new(STR.len() as u64, body)), + Response::ok().set_body(SizedStream::new(STR.len() as u64, body)), ) }) .tcp() @@ -564,7 +566,7 @@ async fn test_h1_body_chunked_explicit() { .h1(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((header::TRANSFER_ENCODING, "chunked")) .streaming(body), ) @@ -598,7 +600,7 @@ async fn test_h1_body_chunked_implicit() { HttpService::build() .h1(|_| { let body = once(ok::<_, Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, ()>(Response::Ok().streaming(body)) + ok::<_, ()>(Response::build(StatusCode::OK).streaming(body)) }) .tcp() }) @@ -628,7 +630,7 @@ async fn test_h1_response_http_error_handling() { .h1(fn_service(|_| { let broken_header = Bytes::from_static(b"\0\0\0"); ok::<_, ()>( - Response::Ok() + Response::build(StatusCode::OK) .insert_header((http::header::CONTENT_TYPE, broken_header)) .body(STR), ) @@ -671,7 +673,7 @@ async fn test_h1_on_connect() { }) .h1(|req: Request| { assert!(req.extensions().contains::()); - ok::<_, ()>(Response::Ok().finish()) + ok::<_, ()>(Response::ok()) }) .tcp() }) diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 9a2e57711..72870bab5 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -91,7 +91,7 @@ async fn test_simple() { let ws_service = ws_service.clone(); HttpService::build() .upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone()))) - .finish(|_| future::ok::<_, ()>(Response::NotFound())) + .finish(|_| future::ok::<_, ()>(Response::not_found())) .tcp() } }) diff --git a/awc/tests/test_ws.rs b/awc/tests/test_ws.rs index 3f19ac4e8..bfc81afbc 100644 --- a/awc/tests/test_ws.rs +++ b/awc/tests/test_ws.rs @@ -36,7 +36,7 @@ async fn test_simple() { ws::Dispatcher::with(framed, ws_service).await } }) - .finish(|_| ok::<_, Error>(Response::NotFound())) + .finish(|_| ok::<_, Error>(Response::not_found())) .tcp() }) .await; diff --git a/src/app_service.rs b/src/app_service.rs index be4ccf22f..32c779a32 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,20 +1,23 @@ use std::cell::RefCell; use std::rc::Rc; -use actix_http::{Extensions, Request, Response}; +use actix_http::{Extensions, Request}; use actix_router::{Path, ResourceDef, Router, Url}; use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{fn_service, Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; -use crate::config::{AppConfig, AppService}; use crate::data::FnDataFactory; use crate::error::Error; use crate::guard::Guard; use crate::request::{HttpRequest, HttpRequestPool}; use crate::rmap::ResourceMap; use crate::service::{AppServiceFactory, ServiceRequest, ServiceResponse}; +use crate::{ + config::{AppConfig, AppService}, + HttpResponse, +}; type Guards = Vec>; type HttpService = BoxService; @@ -64,7 +67,7 @@ where // if no user defined default service exists. let default = self.default.clone().unwrap_or_else(|| { Rc::new(boxed::factory(fn_service(|req: ServiceRequest| async { - Ok(req.into_response(Response::NotFound().finish())) + Ok(req.into_response(HttpResponse::NotFound())) }))) }); diff --git a/src/resource.rs b/src/resource.rs index e868bb547..049e56291 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -3,7 +3,7 @@ use std::fmt; use std::future::Future; use std::rc::Rc; -use actix_http::{Error, Extensions, Response}; +use actix_http::{Error, Extensions}; use actix_router::IntoPattern; use actix_service::boxed::{self, BoxService, BoxServiceFactory}; use actix_service::{ @@ -13,7 +13,6 @@ use actix_service::{ use futures_core::future::LocalBoxFuture; use futures_util::future::join_all; -use crate::data::Data; use crate::dev::{insert_slash, AppService, HttpServiceFactory, ResourceDef}; use crate::extract::FromRequest; use crate::guard::Guard; @@ -21,6 +20,7 @@ use crate::handler::Handler; use crate::responder::Responder; use crate::route::{Route, RouteService}; use crate::service::{ServiceRequest, ServiceResponse}; +use crate::{data::Data, HttpResponse}; type HttpService = BoxService; type HttpNewService = BoxServiceFactory<(), ServiceRequest, ServiceResponse, Error, ()>; @@ -71,7 +71,7 @@ impl Resource { guards: Vec::new(), app_data: None, default: boxed::factory(fn_service(|req: ServiceRequest| async { - Ok(req.into_response(Response::MethodNotAllowed().finish())) + Ok(req.into_response(HttpResponse::MethodNotAllowed())) })), } } diff --git a/src/responder.rs b/src/responder.rs index 2348e9276..7b8288ed8 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{borrow::Cow, fmt}; use actix_http::{ body::Body, @@ -117,53 +117,29 @@ impl Responder for (T, StatusCode) { } } -impl Responder for &'static str { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } +macro_rules! impl_responder { + ($res: ty, $ct: path) => { + impl Responder for $res { + fn respond_to(self, _: &HttpRequest) -> HttpResponse { + HttpResponse::Ok().content_type($ct).body(self) + } + } + }; } -impl Responder for &'static [u8] { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} +impl_responder!(&'static str, mime::TEXT_PLAIN_UTF_8); -impl Responder for String { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } -} +impl_responder!(String, mime::TEXT_PLAIN_UTF_8); -impl<'a> Responder for &'a String { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::TEXT_PLAIN_UTF_8) - .body(self) - } -} +impl_responder!(&'_ String, mime::TEXT_PLAIN_UTF_8); -impl Responder for Bytes { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} +impl_responder!(Cow<'_, str>, mime::TEXT_PLAIN_UTF_8); -impl Responder for BytesMut { - fn respond_to(self, _: &HttpRequest) -> HttpResponse { - HttpResponse::Ok() - .content_type(mime::APPLICATION_OCTET_STREAM) - .body(self) - } -} +impl_responder!(&'static [u8], mime::APPLICATION_OCTET_STREAM); + +impl_responder!(Bytes, mime::APPLICATION_OCTET_STREAM); + +impl_responder!(BytesMut, mime::APPLICATION_OCTET_STREAM); /// Allows overriding status code and headers for a responder. pub struct CustomResponder { @@ -358,6 +334,31 @@ pub(crate) mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); + let s = String::from("test"); + let resp = Cow::Borrowed(s.as_str()).respond_to(&req); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp = Cow::<'_, str>::Owned(s).respond_to(&req); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + + let resp = Cow::Borrowed("test").respond_to(&req); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().bin_ref(), b"test"); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + let resp = Bytes::from_static(b"test").respond_to(&req); assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.body().bin_ref(), b"test"); diff --git a/src/response.rs b/src/response/builder.rs similarity index 55% rename from src/response.rs rename to src/response/builder.rs index 23244e6a5..2c04f3f64 100644 --- a/src/response.rs +++ b/src/response/builder.rs @@ -1,17 +1,15 @@ use std::{ cell::{Ref, RefMut}, convert::TryInto, - fmt, future::Future, - mem, pin::Pin, task::{Context, Poll}, }; use actix_http::{ - body::{Body, BodyStream, MessageBody, ResponseBody}, + body::{Body, BodyStream}, http::{ - header::{self, HeaderMap, HeaderName, IntoHeaderPair, IntoHeaderValue}, + header::{self, HeaderName, IntoHeaderPair, IntoHeaderValue}, ConnectionType, Error as HttpError, StatusCode, }, Extensions, Response, ResponseHead, @@ -25,282 +23,10 @@ use actix_http::http::header::HeaderValue; #[cfg(feature = "cookies")] use cookie::{Cookie, CookieJar}; -use crate::error::{Error, JsonPayloadError}; - -/// An HTTP Response -pub struct HttpResponse { - res: Response, - error: Option, -} - -impl HttpResponse { - /// Create HTTP response builder with specific status. - #[inline] - pub fn build(status: StatusCode) -> HttpResponseBuilder { - HttpResponseBuilder::new(status) - } - - /// Create HTTP response builder - #[inline] - pub fn build_from>(source: T) -> HttpResponseBuilder { - source.into() - } - - /// Create a response. - #[inline] - pub fn new(status: StatusCode) -> Self { - Self { - res: Response::new(status), - error: None, - } - } - - /// Create an error response. - #[inline] - pub fn from_error(error: Error) -> Self { - let res = error.as_response_error().error_response(); - - Self { - res, - error: Some(error), - } - } - - /// Convert response to response with body - pub fn into_body(self) -> HttpResponse { - HttpResponse { - res: self.res.into_body(), - error: self.error, - } - } -} - -impl HttpResponse { - /// Constructs a response with body - #[inline] - pub fn with_body(status: StatusCode, body: B) -> Self { - Self { - res: Response::with_body(status, body), - error: None, - } - } - - /// Returns a reference to response head. - #[inline] - pub fn head(&self) -> &ResponseHead { - self.res.head() - } - - /// Returns a mutable reference to response head. - #[inline] - pub fn head_mut(&mut self) -> &mut ResponseHead { - self.res.head_mut() - } - - /// The source `error` for this response - #[inline] - pub fn error(&self) -> Option<&Error> { - self.error.as_ref() - } - - /// Get the response status code - #[inline] - pub fn status(&self) -> StatusCode { - self.res.status() - } - - /// Set the `StatusCode` for this response - #[inline] - pub fn status_mut(&mut self) -> &mut StatusCode { - self.res.status_mut() - } - - /// Get the headers from the response - #[inline] - pub fn headers(&self) -> &HeaderMap { - self.res.headers() - } - - /// Get a mutable reference to the headers - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - self.res.headers_mut() - } - - /// Get an iterator for the cookies set by this response. - #[cfg(feature = "cookies")] - pub fn cookies(&self) -> CookieIter<'_> { - CookieIter { - iter: self.headers().get_all(header::SET_COOKIE), - } - } - - /// Add a cookie to this response - #[cfg(feature = "cookies")] - pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { - HeaderValue::from_str(&cookie.to_string()) - .map(|c| { - self.headers_mut().append(header::SET_COOKIE, c); - }) - .map_err(|e| e.into()) - } - - /// Remove all cookies with the given name from this response. Returns - /// the number of cookies removed. - #[cfg(feature = "cookies")] - pub fn del_cookie(&mut self, name: &str) -> usize { - let headers = self.headers_mut(); - - let vals: Vec = headers - .get_all(header::SET_COOKIE) - .map(|v| v.to_owned()) - .collect(); - - headers.remove(header::SET_COOKIE); - - let mut count: usize = 0; - for v in vals { - if let Ok(s) = v.to_str() { - if let Ok(c) = Cookie::parse_encoded(s) { - if c.name() == name { - count += 1; - continue; - } - } - } - - // put set-cookie header head back if it does not validate - headers.append(header::SET_COOKIE, v); - } - - count - } - - /// Connection upgrade status - #[inline] - pub fn upgrade(&self) -> bool { - self.res.upgrade() - } - - /// Keep-alive status for this connection - pub fn keep_alive(&self) -> bool { - self.res.keep_alive() - } - - /// Responses extensions - #[inline] - pub fn extensions(&self) -> Ref<'_, Extensions> { - self.res.extensions() - } - - /// Mutable reference to a the response's extensions - #[inline] - pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { - self.res.extensions_mut() - } - - /// Get body of this response - #[inline] - pub fn body(&self) -> &ResponseBody { - self.res.body() - } - - /// Set a body - pub fn set_body(self, body: B2) -> HttpResponse { - HttpResponse { - res: self.res.set_body(body), - error: None, - // error: self.error, ?? - } - } - - /// Split response and body - pub fn into_parts(self) -> (HttpResponse<()>, ResponseBody) { - let (head, body) = self.res.into_parts(); - - ( - HttpResponse { - res: head, - error: None, - }, - body, - ) - } - - /// Drop request's body - pub fn drop_body(self) -> HttpResponse<()> { - HttpResponse { - res: self.res.drop_body(), - error: None, - } - } - - /// Set a body and return previous body value - pub fn map_body(self, f: F) -> HttpResponse - where - F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, - { - HttpResponse { - res: self.res.map_body(f), - error: self.error, - } - } - - /// Extract response body - pub fn take_body(&mut self) -> ResponseBody { - self.res.take_body() - } -} - -impl fmt::Debug for HttpResponse { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HttpResponse") - .field("error", &self.error) - .field("res", &self.res) - .finish() - } -} - -impl From> for HttpResponse { - fn from(res: Response) -> Self { - HttpResponse { res, error: None } - } -} - -impl From for HttpResponse { - fn from(err: Error) -> Self { - HttpResponse::from_error(err) - } -} - -impl From> for Response { - fn from(res: HttpResponse) -> Self { - // this impl will always be called as part of dispatcher - - // TODO: expose cause somewhere? - // if let Some(err) = res.error { - // eprintln!("impl From> for Response let Some(err)"); - // return Response::from_error(err).into_body(); - // } - - res.res - } -} - -impl Future for HttpResponse { - type Output = Result, Error>; - - fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - if let Some(err) = self.error.take() { - return Poll::Ready(Ok(Response::from_error(err).into_body())); - } - - Poll::Ready(Ok(mem::replace( - &mut self.res, - Response::new(StatusCode::default()), - ))) - } -} +use crate::{ + error::{Error, JsonPayloadError}, + HttpResponse, +}; /// An HTTP response builder. /// @@ -695,146 +421,18 @@ impl Future for HttpResponseBuilder { } } -#[cfg(feature = "cookies")] -pub struct CookieIter<'a> { - iter: header::GetAll<'a>, -} - -#[cfg(feature = "cookies")] -impl<'a> Iterator for CookieIter<'a> { - type Item = Cookie<'a>; - - #[inline] - fn next(&mut self) -> Option> { - for v in self.iter.by_ref() { - if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { - return Some(c); - } - } - None - } -} - -mod http_codes { - //! Status code based HTTP response builders. - - use actix_http::http::StatusCode; - - use super::{HttpResponse, HttpResponseBuilder}; - - macro_rules! static_resp { - ($name:ident, $status:expr) => { - #[allow(non_snake_case, missing_docs)] - pub fn $name() -> HttpResponseBuilder { - HttpResponseBuilder::new($status) - } - }; - } - - impl HttpResponse { - static_resp!(Continue, StatusCode::CONTINUE); - static_resp!(SwitchingProtocols, StatusCode::SWITCHING_PROTOCOLS); - static_resp!(Processing, StatusCode::PROCESSING); - - static_resp!(Ok, StatusCode::OK); - static_resp!(Created, StatusCode::CREATED); - static_resp!(Accepted, StatusCode::ACCEPTED); - static_resp!( - NonAuthoritativeInformation, - StatusCode::NON_AUTHORITATIVE_INFORMATION - ); - - static_resp!(NoContent, StatusCode::NO_CONTENT); - static_resp!(ResetContent, StatusCode::RESET_CONTENT); - static_resp!(PartialContent, StatusCode::PARTIAL_CONTENT); - static_resp!(MultiStatus, StatusCode::MULTI_STATUS); - static_resp!(AlreadyReported, StatusCode::ALREADY_REPORTED); - - static_resp!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); - static_resp!(MovedPermanently, StatusCode::MOVED_PERMANENTLY); - static_resp!(Found, StatusCode::FOUND); - static_resp!(SeeOther, StatusCode::SEE_OTHER); - static_resp!(NotModified, StatusCode::NOT_MODIFIED); - static_resp!(UseProxy, StatusCode::USE_PROXY); - static_resp!(TemporaryRedirect, StatusCode::TEMPORARY_REDIRECT); - static_resp!(PermanentRedirect, StatusCode::PERMANENT_REDIRECT); - - static_resp!(BadRequest, StatusCode::BAD_REQUEST); - static_resp!(NotFound, StatusCode::NOT_FOUND); - static_resp!(Unauthorized, StatusCode::UNAUTHORIZED); - static_resp!(PaymentRequired, StatusCode::PAYMENT_REQUIRED); - static_resp!(Forbidden, StatusCode::FORBIDDEN); - static_resp!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED); - static_resp!(NotAcceptable, StatusCode::NOT_ACCEPTABLE); - static_resp!( - ProxyAuthenticationRequired, - StatusCode::PROXY_AUTHENTICATION_REQUIRED - ); - static_resp!(RequestTimeout, StatusCode::REQUEST_TIMEOUT); - static_resp!(Conflict, StatusCode::CONFLICT); - static_resp!(Gone, StatusCode::GONE); - static_resp!(LengthRequired, StatusCode::LENGTH_REQUIRED); - static_resp!(PreconditionFailed, StatusCode::PRECONDITION_FAILED); - static_resp!(PreconditionRequired, StatusCode::PRECONDITION_REQUIRED); - static_resp!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE); - static_resp!(UriTooLong, StatusCode::URI_TOO_LONG); - static_resp!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE); - static_resp!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE); - static_resp!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); - static_resp!(UnprocessableEntity, StatusCode::UNPROCESSABLE_ENTITY); - static_resp!(TooManyRequests, StatusCode::TOO_MANY_REQUESTS); - static_resp!( - RequestHeaderFieldsTooLarge, - StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE - ); - static_resp!( - UnavailableForLegalReasons, - StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS - ); - - static_resp!(InternalServerError, StatusCode::INTERNAL_SERVER_ERROR); - static_resp!(NotImplemented, StatusCode::NOT_IMPLEMENTED); - static_resp!(BadGateway, StatusCode::BAD_GATEWAY); - static_resp!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE); - static_resp!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT); - static_resp!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED); - static_resp!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES); - static_resp!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE); - static_resp!(LoopDetected, StatusCode::LOOP_DETECTED); - } - - #[cfg(test)] - mod tests { - use crate::dev::Body; - use crate::http::StatusCode; - use crate::HttpResponse; - - #[test] - fn test_build() { - let resp = HttpResponse::Ok().body(Body::Empty); - assert_eq!(resp.status(), StatusCode::OK); - } - } -} - #[cfg(test)] mod tests { - use bytes::{Bytes, BytesMut}; + use actix_http::body; - use super::{HttpResponse, HttpResponseBuilder}; - use crate::dev::{Body, MessageBody, ResponseBody}; - use crate::http::header::{self, HeaderValue, CONTENT_TYPE, COOKIE}; - use crate::http::StatusCode; - - #[test] - fn test_debug() { - let resp = HttpResponse::Ok() - .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) - .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) - .finish(); - let dbg = format!("{:?}", resp); - assert!(dbg.contains("HttpResponse")); - } + use super::*; + use crate::{ + dev::Body, + http::{ + header::{self, HeaderValue, CONTENT_TYPE}, + StatusCode, + }, + }; #[test] fn test_basic_builder() { @@ -872,26 +470,13 @@ mod tests { assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") } - pub async fn read_body(mut body: ResponseBody) -> Bytes - where - B: MessageBody + Unpin, - { - use futures_util::StreamExt as _; - - let mut bytes = BytesMut::new(); - while let Some(item) = body.next().await { - bytes.extend_from_slice(&item.unwrap()); - } - bytes.freeze() - } - #[actix_rt::test] async fn test_json() { let mut resp = HttpResponse::Ok().json(vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("application/json")); assert_eq!( - read_body(resp.take_body()).await.as_ref(), + body::to_bytes(resp.take_body()).await.unwrap().as_ref(), br#"["v1","v2","v3"]"# ); @@ -899,7 +484,7 @@ mod tests { let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("application/json")); assert_eq!( - read_body(resp.take_body()).await.as_ref(), + body::to_bytes(resp.take_body()).await.unwrap().as_ref(), br#"["v1","v2","v3"]"# ); @@ -910,7 +495,7 @@ mod tests { let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("text/json")); assert_eq!( - read_body(resp.take_body()).await.as_ref(), + body::to_bytes(resp.take_body()).await.unwrap().as_ref(), br#"["v1","v2","v3"]"# ); } @@ -922,7 +507,7 @@ mod tests { ); assert_eq!( - read_body(resp.take_body()).await.as_ref(), + body::to_bytes(resp.take_body()).await.unwrap().as_ref(), br#"{"test-key":"test-value"}"# ); } diff --git a/actix-http/src/http_codes.rs b/src/response/http_codes.rs similarity index 90% rename from actix-http/src/http_codes.rs rename to src/response/http_codes.rs index dc4f964de..d67ef3f92 100644 --- a/actix-http/src/http_codes.rs +++ b/src/response/http_codes.rs @@ -1,24 +1,19 @@ //! Status code based HTTP response builders. -#![allow(non_upper_case_globals)] +use actix_http::http::StatusCode; -use http::StatusCode; - -use crate::{ - body::Body, - response::{Response, ResponseBuilder}, -}; +use crate::{HttpResponse, HttpResponseBuilder}; macro_rules! static_resp { ($name:ident, $status:expr) => { #[allow(non_snake_case, missing_docs)] - pub fn $name() -> ResponseBuilder { - ResponseBuilder::new($status) + pub fn $name() -> HttpResponseBuilder { + HttpResponseBuilder::new($status) } }; } -impl Response { +impl HttpResponse { static_resp!(Continue, StatusCode::CONTINUE); static_resp!(SwitchingProtocols, StatusCode::SWITCHING_PROTOCOLS); static_resp!(Processing, StatusCode::PROCESSING); @@ -92,13 +87,13 @@ impl Response { #[cfg(test)] mod tests { - use crate::body::Body; - use crate::response::Response; - use http::StatusCode; + use crate::dev::Body; + use crate::http::StatusCode; + use crate::HttpResponse; #[test] fn test_build() { - let resp = Response::Ok().body(Body::Empty); + let resp = HttpResponse::Ok().body(Body::Empty); assert_eq!(resp.status(), StatusCode::OK); } } diff --git a/src/response/mod.rs b/src/response/mod.rs new file mode 100644 index 000000000..8401db9d2 --- /dev/null +++ b/src/response/mod.rs @@ -0,0 +1,10 @@ +mod builder; +mod http_codes; +#[allow(clippy::module_inception)] +mod response; + +pub use self::builder::HttpResponseBuilder; +pub use self::response::HttpResponse; + +#[cfg(feature = "cookies")] +pub use self::response::CookieIter; diff --git a/src/response/response.rs b/src/response/response.rs new file mode 100644 index 000000000..31868fe0b --- /dev/null +++ b/src/response/response.rs @@ -0,0 +1,330 @@ +use std::{ + cell::{Ref, RefMut}, + fmt, + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use actix_http::{ + body::{Body, MessageBody, ResponseBody}, + http::{header::HeaderMap, StatusCode}, + Extensions, Response, ResponseHead, +}; + +#[cfg(feature = "cookies")] +use { + actix_http::http::{ + header::{self, HeaderValue}, + Error as HttpError, + }, + cookie::Cookie, +}; + +use crate::{error::Error, HttpResponseBuilder}; + +/// An HTTP Response +pub struct HttpResponse { + res: Response, + error: Option, +} + +impl HttpResponse { + /// Create HTTP response builder with specific status. + #[inline] + pub fn build(status: StatusCode) -> HttpResponseBuilder { + HttpResponseBuilder::new(status) + } + + /// Create a response. + #[inline] + pub fn new(status: StatusCode) -> Self { + Self { + res: Response::new(status), + error: None, + } + } + + /// Create an error response. + #[inline] + pub fn from_error(error: Error) -> Self { + let res = error.as_response_error().error_response(); + + Self { + res, + error: Some(error), + } + } + + /// Convert response to response with body + pub fn into_body(self) -> HttpResponse { + HttpResponse { + res: self.res.into_body(), + error: self.error, + } + } +} + +impl HttpResponse { + /// Constructs a response with body + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Self { + Self { + res: Response::with_body(status, body), + error: None, + } + } + + /// Returns a reference to response head. + #[inline] + pub fn head(&self) -> &ResponseHead { + self.res.head() + } + + /// Returns a mutable reference to response head. + #[inline] + pub fn head_mut(&mut self) -> &mut ResponseHead { + self.res.head_mut() + } + + /// The source `error` for this response + #[inline] + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.res.status() + } + + /// Set the `StatusCode` for this response + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + self.res.status_mut() + } + + /// Get the headers from the response + #[inline] + pub fn headers(&self) -> &HeaderMap { + self.res.headers() + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + self.res.headers_mut() + } + + /// Get an iterator for the cookies set by this response. + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> CookieIter<'_> { + CookieIter { + iter: self.headers().get_all(header::SET_COOKIE), + } + } + + /// Add a cookie to this response + #[cfg(feature = "cookies")] + pub fn add_cookie(&mut self, cookie: &Cookie<'_>) -> Result<(), HttpError> { + HeaderValue::from_str(&cookie.to_string()) + .map(|c| { + self.headers_mut().append(header::SET_COOKIE, c); + }) + .map_err(|e| e.into()) + } + + /// Remove all cookies with the given name from this response. Returns + /// the number of cookies removed. + #[cfg(feature = "cookies")] + pub fn del_cookie(&mut self, name: &str) -> usize { + let headers = self.headers_mut(); + + let vals: Vec = headers + .get_all(header::SET_COOKIE) + .map(|v| v.to_owned()) + .collect(); + + headers.remove(header::SET_COOKIE); + + let mut count: usize = 0; + for v in vals { + if let Ok(s) = v.to_str() { + if let Ok(c) = Cookie::parse_encoded(s) { + if c.name() == name { + count += 1; + continue; + } + } + } + + // put set-cookie header head back if it does not validate + headers.append(header::SET_COOKIE, v); + } + + count + } + + /// Connection upgrade status + #[inline] + pub fn upgrade(&self) -> bool { + self.res.upgrade() + } + + /// Keep-alive status for this connection + pub fn keep_alive(&self) -> bool { + self.res.keep_alive() + } + + /// Responses extensions + #[inline] + pub fn extensions(&self) -> Ref<'_, Extensions> { + self.res.extensions() + } + + /// Mutable reference to a the response's extensions + #[inline] + pub fn extensions_mut(&mut self) -> RefMut<'_, Extensions> { + self.res.extensions_mut() + } + + /// Get body of this response + #[inline] + pub fn body(&self) -> &ResponseBody { + self.res.body() + } + + /// Set a body + pub fn set_body(self, body: B2) -> HttpResponse { + HttpResponse { + res: self.res.set_body(body), + error: None, + // error: self.error, ?? + } + } + + /// Split response and body + pub fn into_parts(self) -> (HttpResponse<()>, ResponseBody) { + let (head, body) = self.res.into_parts(); + + ( + HttpResponse { + res: head, + error: None, + }, + body, + ) + } + + /// Drop request's body + pub fn drop_body(self) -> HttpResponse<()> { + HttpResponse { + res: self.res.drop_body(), + error: None, + } + } + + /// Set a body and return previous body value + pub fn map_body(self, f: F) -> HttpResponse + where + F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, + { + HttpResponse { + res: self.res.map_body(f), + error: self.error, + } + } + + /// Extract response body + pub fn take_body(&mut self) -> ResponseBody { + self.res.take_body() + } +} + +impl fmt::Debug for HttpResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpResponse") + .field("error", &self.error) + .field("res", &self.res) + .finish() + } +} + +impl From> for HttpResponse { + fn from(res: Response) -> Self { + HttpResponse { res, error: None } + } +} + +impl From for HttpResponse { + fn from(err: Error) -> Self { + HttpResponse::from_error(err) + } +} + +impl From> for Response { + fn from(res: HttpResponse) -> Self { + // this impl will always be called as part of dispatcher + + // TODO: expose cause somewhere? + // if let Some(err) = res.error { + // eprintln!("impl From> for Response let Some(err)"); + // return Response::from_error(err).into_body(); + // } + + res.res + } +} + +impl Future for HttpResponse { + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + if let Some(err) = self.error.take() { + return Poll::Ready(Ok(Response::from_error(err).into_body())); + } + + Poll::Ready(Ok(mem::replace( + &mut self.res, + Response::new(StatusCode::default()), + ))) + } +} + +#[cfg(feature = "cookies")] +pub struct CookieIter<'a> { + iter: header::GetAll<'a>, +} + +#[cfg(feature = "cookies")] +impl<'a> Iterator for CookieIter<'a> { + type Item = Cookie<'a>; + + #[inline] + fn next(&mut self) -> Option> { + for v in self.iter.by_ref() { + if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { + return Some(c); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::header::{HeaderValue, COOKIE}; + + #[test] + fn test_debug() { + let resp = HttpResponse::Ok() + .append_header((COOKIE, HeaderValue::from_static("cookie1=value1; "))) + .append_header((COOKIE, HeaderValue::from_static("cookie2=value2; "))) + .finish(); + let dbg = format!("{:?}", resp); + assert!(dbg.contains("HttpResponse")); + } +}