diff --git a/CHANGES.md b/CHANGES.md index 87c021b1e..ee9b9308d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,9 @@ # Changes ## Unreleased - 2020-xx-xx +### Changed +* Bumped `rand` to `0.8` + ### Fixed * added the actual parsing error to `test::read_body_json` [#1812] diff --git a/Cargo.toml b/Cargo.toml index 31c4cca7e..6ed327f56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,7 +112,7 @@ tinyvec = { version = "1", features = ["alloc"] } [dev-dependencies] actix = "0.10.0" actix-http = { version = "2.1.0", features = ["actors"] } -rand = "0.7" +rand = "0.8" env_logger = "0.8" serde_derive = "1.0" brotli2 = "0.3.2" diff --git a/actix-http-test/src/lib.rs b/actix-http-test/src/lib.rs index f881dfb4c..3ab3f8a0d 100644 --- a/actix-http-test/src/lib.rs +++ b/actix-http-test/src/lib.rs @@ -53,7 +53,7 @@ pub async fn test_server>(factory: F) -> TestServer test_server_with_addr(tcp, factory).await } -/// Start [`test server`](./fn.test_server.html) on a concrete Address +/// Start [`test server`](test_server()) on a concrete Address pub async fn test_server_with_addr>( tcp: net::TcpListener, factory: F, diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index c602ab2e1..81577688d 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,7 +1,8 @@ # Changes ## Unreleased - 2020-xx-xx - +### Changed +* Bumped `rand` to `0.8` ## 2.2.0 - 2020-11-25 ### Added diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 7375c6eb3..7cf344487 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -72,7 +72,7 @@ log = "0.4" mime = "0.3" percent-encoding = "2.1" pin-project = "1.0.0" -rand = "0.7" +rand = "0.8" regex = "1.3" serde = "1.0" serde_json = "1.0" diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index 08abc6277..a8687dbeb 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -9,8 +9,9 @@ use std::time::{Duration, Instant}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_rt::time::{delay_for, Delay}; use actix_service::Service; -use actix_utils::{oneshot, task::LocalWaker}; +use actix_utils::task::LocalWaker; use bytes::Bytes; +use futures_channel::oneshot; use futures_util::future::{poll_fn, FutureExt, LocalBoxFuture}; use fxhash::FxHashMap; use h2::client::{Connection, SendRequest}; diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs index b64c299fc..0e77c455c 100644 --- a/actix-http/src/cloneable.rs +++ b/actix-http/src/cloneable.rs @@ -4,12 +4,12 @@ use std::task::{Context, Poll}; use actix_service::Service; -#[doc(hidden)] /// Service that allows to turn non-clone service to a service with `Clone` impl /// /// # Panics /// CloneableService might panic with some creative use of thread local storage. /// See https://github.com/actix/actix-web/issues/1295 for example +#[doc(hidden)] pub(crate) struct CloneableService(Rc>); impl CloneableService { diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs index e93c077af..0ebd4c05c 100644 --- a/actix-http/src/error.rs +++ b/actix-http/src/error.rs @@ -25,7 +25,7 @@ pub use crate::cookie::ParseError as CookieParseError; use crate::helpers::Writer; use crate::response::{Response, ResponseBuilder}; -/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) +/// A specialized [`std::result::Result`] /// for actix web operations /// /// This typedef is generally used to avoid writing out diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index 7dda74731..b20dfe11d 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -3,8 +3,8 @@ use std::{fmt, mem}; use fxhash::FxHashMap; -#[derive(Default)] /// A type map of request extensions. +#[derive(Default)] pub struct Extensions { /// Use FxHasher with a std HashMap with for faster /// lookups on the small `TypeId` (u64 equivalent) keys. diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs index 036f16670..c9a62dc30 100644 --- a/actix-http/src/h1/codec.rs +++ b/actix-http/src/h1/codec.rs @@ -58,6 +58,7 @@ impl Codec { } else { Flags::empty() }; + Codec { config, flags, @@ -69,26 +70,26 @@ impl Codec { } } + /// Check if request is upgrade. #[inline] - /// Check if request is upgrade pub fn upgrade(&self) -> bool { self.ctype == ConnectionType::Upgrade } + /// Check if last response is keep-alive. #[inline] - /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { self.ctype == ConnectionType::KeepAlive } + /// Check if keep-alive enabled on server level. #[inline] - /// Check if keep-alive enabled on server level pub fn keepalive_enabled(&self) -> bool { self.flags.contains(Flags::KEEPALIVE_ENABLED) } + /// Check last request's message type. #[inline] - /// Check last request's message type pub fn message_type(&self) -> MessageType { if self.flags.contains(Flags::STREAM) { MessageType::Stream diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index ace4144e3..ea8f91e0d 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -1,8 +1,11 @@ -use std::collections::VecDeque; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, io, net}; +use std::{ + collections::VecDeque, + fmt, + future::Future, + io, mem, net, + pin::Pin, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; use actix_rt::time::{delay_until, Delay, Instant}; @@ -59,6 +62,9 @@ where { #[pin] inner: DispatcherState, + + #[cfg(test)] + poll_count: u64, } #[pin_project(project = DispatcherStateProj)] @@ -124,8 +130,8 @@ where B: MessageBody, { None, - ExpectCall(Pin>), - ServiceCall(Pin>), + ExpectCall(#[pin] X::Future), + ServiceCall(#[pin] S::Future), SendPayload(#[pin] ResponseBody), } @@ -247,6 +253,9 @@ where ka_expire, ka_timer, }), + + #[cfg(test)] + poll_count: 0, } } } @@ -338,7 +347,7 @@ where self: Pin<&mut Self>, message: Response<()>, body: ResponseBody, - ) -> Result, DispatchError> { + ) -> Result<(), DispatchError> { let mut this = self.project(); this.codec .encode(Message::Item((message, body.size())), &mut this.write_buf) @@ -351,9 +360,10 @@ where this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); match body.size() { - BodySize::None | BodySize::Empty => Ok(State::None), - _ => Ok(State::SendPayload(body)), - } + BodySize::None | BodySize::Empty => this.state.set(State::None), + _ => this.state.set(State::SendPayload(body)), + }; + Ok(()) } fn send_continue(self: Pin<&mut Self>) { @@ -368,49 +378,52 @@ where ) -> Result { loop { let mut this = self.as_mut().project(); - let state = match this.state.project() { + // state is not changed on Poll::Pending. + // other variant and conditions always trigger a state change(or an error). + let state_change = match this.state.project() { StateProj::None => match this.messages.pop_front() { Some(DispatcherMessage::Item(req)) => { - Some(self.as_mut().handle_request(req, cx)?) + self.as_mut().handle_request(req, cx)?; + true } - Some(DispatcherMessage::Error(res)) => Some( + Some(DispatcherMessage::Error(res)) => { self.as_mut() - .send_response(res, ResponseBody::Other(Body::Empty))?, - ), + .send_response(res, ResponseBody::Other(Body::Empty))?; + true + } Some(DispatcherMessage::Upgrade(req)) => { return Ok(PollResponse::Upgrade(req)); } - None => None, + None => false, }, - StateProj::ExpectCall(fut) => match fut.as_mut().poll(cx) { + StateProj::ExpectCall(fut) => match fut.poll(cx) { Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); this = self.as_mut().project(); - this.state - .set(State::ServiceCall(Box::pin(this.service.call(req)))); + this.state.set(State::ServiceCall(this.service.call(req))); continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.as_mut().send_response(res, body.into_body())?) + self.as_mut().send_response(res, body.into_body())?; + true } - Poll::Pending => None, + Poll::Pending => false, }, - StateProj::ServiceCall(fut) => match fut.as_mut().poll(cx) { + StateProj::ServiceCall(fut) => match fut.poll(cx) { Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); - let state = self.as_mut().send_response(res, body)?; - this = self.as_mut().project(); - this.state.set(state); + self.as_mut().send_response(res, body)?; continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.as_mut().send_response(res, body.into_body())?) + self.as_mut().send_response(res, body.into_body())?; + true } - Poll::Pending => None, + Poll::Pending => false, }, StateProj::SendPayload(mut stream) => { loop { @@ -445,11 +458,8 @@ where } }; - this = self.as_mut().project(); - - // set new state - if let Some(state) = state { - this.state.set(state); + // state is changed and continue when the state is not Empty + if state_change { if !self.state.is_empty() { continue; } @@ -474,49 +484,77 @@ where mut self: Pin<&mut Self>, req: Request, cx: &mut Context<'_>, - ) -> Result, DispatchError> { + ) -> Result<(), DispatchError> { // Handle `EXPECT: 100-Continue` header - let req = if req.head().expect() { - let mut task = Box::pin(self.as_mut().project().expect.call(req)); - match task.as_mut().poll(cx) { - Poll::Ready(Ok(req)) => { - self.as_mut().send_continue(); - req - } - Poll::Pending => return Ok(State::ExpectCall(task)), - Poll::Ready(Err(e)) => { - let e = e.into(); - let res: Response = e.into(); - let (res, body) = res.replace_body(()); - return self.send_response(res, body.into_body()); - } - } + if req.head().expect() { + // set dispatcher state so the future is pinned. + let task = self.as_mut().project().expect.call(req); + self.as_mut().project().state.set(State::ExpectCall(task)); } else { - req + // the same as above. + let task = self.as_mut().project().service.call(req); + self.as_mut().project().state.set(State::ServiceCall(task)); }; - // Call service - let mut task = Box::pin(self.as_mut().project().service.call(req)); - match task.as_mut().poll(cx) { - Poll::Ready(Ok(res)) => { - let (res, body) = res.into().replace_body(()); - self.send_response(res, body) - } - Poll::Pending => Ok(State::ServiceCall(task)), - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - let (res, body) = res.replace_body(()); - self.send_response(res, body.into_body()) + // eagerly poll the future for once(or twice if expect is resolved immediately). + loop { + match self.as_mut().project().state.project() { + StateProj::ExpectCall(fut) => { + match fut.poll(cx) { + // expect is resolved. continue loop and poll the service call branch. + Poll::Ready(Ok(req)) => { + self.as_mut().send_continue(); + let task = self.as_mut().project().service.call(req); + self.as_mut().project().state.set(State::ServiceCall(task)); + continue; + } + // future is pending. return Ok(()) to notify that a new state is + // set and the outer loop should be continue. + Poll::Pending => return Ok(()), + // future is error. send response and return a result. On success + // to notify the dispatcher a new state is set and the outer loop + // should be continue. + Poll::Ready(Err(e)) => { + let e = e.into(); + let res: Response = e.into(); + let (res, body) = res.replace_body(()); + return self.send_response(res, body.into_body()); + } + } + } + StateProj::ServiceCall(fut) => { + // return no matter the service call future's result. + return match fut.poll(cx) { + // future is resolved. send response and return a result. On success + // to notify the dispatcher a new state is set and the outer loop + // should be continue. + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().replace_body(()); + self.send_response(res, body) + } + // see the comment on ExpectCall state branch's Pending. + Poll::Pending => Ok(()), + // see the comment on ExpectCall state branch's Ready(Err(e)). + Poll::Ready(Err(e)) => { + let res: Response = e.into().into(); + let (res, body) = res.replace_body(()); + self.send_response(res, body.into_body()) + } + }; + } + _ => unreachable!( + "State must be set to ServiceCall or ExceptCall in handle_request" + ), } } } - /// Process one incoming requests + /// Process one incoming request. pub(self) fn poll_request( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { - // limit a mount of non processed requests + // limit amount of non-processed requests if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) { return Ok(false); } @@ -557,9 +595,8 @@ where // handle request early if this.state.is_empty() { - let state = self.as_mut().handle_request(req, cx)?; + self.as_mut().handle_request(req, cx)?; this = self.as_mut().project(); - this.state.set(state); } else { this.messages.push_back(DispatcherMessage::Item(req)); } @@ -725,6 +762,12 @@ where #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut().project(); + + #[cfg(test)] + { + *this.poll_count += 1; + } + match this.inner.project() { DispatcherStateProj::Normal(mut inner) => { inner.as_mut().poll_keepalive(cx)?; @@ -788,10 +831,10 @@ where let inner_p = inner.as_mut().project(); let mut parts = FramedParts::with_read_buf( inner_p.io.take().unwrap(), - std::mem::take(inner_p.codec), - std::mem::take(inner_p.read_buf), + mem::take(inner_p.codec), + mem::take(inner_p.read_buf), ); - parts.write_buf = std::mem::take(inner_p.write_buf); + parts.write_buf = mem::take(inner_p.write_buf); let framed = Framed::from_parts(parts); let upgrade = inner_p.upgrade.take().unwrap().call((req, framed)); @@ -803,8 +846,11 @@ where } // we didn't get WouldBlock from write operation, - // so data get written to kernel completely (OSX) + // so data get written to kernel completely (macOS) // and we have to write again otherwise response can get stuck + // + // TODO: what? is WouldBlock good or bad? + // want to find a reference for this macOS behavior if inner.as_mut().poll_flush(cx)? || !drain { break; } @@ -854,6 +900,11 @@ where } } +/// Returns either: +/// - `Ok(Some(true))` - data was read and done reading all data. +/// - `Ok(Some(false))` - data was read but there should be more to read. +/// - `Ok(None)` - no data was read but there should be more to read later. +/// - Unhandled Errors fn read_available( cx: &mut Context<'_>, io: &mut T, @@ -887,17 +938,17 @@ where read_some = true; } } - Poll::Ready(Err(e)) => { - return if e.kind() == io::ErrorKind::WouldBlock { + Poll::Ready(Err(err)) => { + return if err.kind() == io::ErrorKind::WouldBlock { if read_some { Ok(Some(false)) } else { Ok(None) } - } else if e.kind() == io::ErrorKind::ConnectionReset && read_some { + } else if err.kind() == io::ErrorKind::ConnectionReset && read_some { Ok(Some(true)) } else { - Err(e) + Err(err) } } } @@ -917,25 +968,74 @@ where #[cfg(test)] mod tests { - use actix_service::IntoService; - use futures_util::future::{lazy, ok}; + use std::{marker::PhantomData, str}; + + use actix_service::fn_service; + use futures_util::future::{lazy, ready}; use super::*; - use crate::error::Error; - use crate::h1::{ExpectHandler, UpgradeHandler}; use crate::test::TestBuffer; + use crate::{error::Error, KeepAlive}; + use crate::{ + h1::{ExpectHandler, UpgradeHandler}, + test::TestSeqBuffer, + }; + + fn find_slice(haystack: &[u8], needle: &[u8], from: usize) -> Option { + haystack[from..] + .windows(needle.len()) + .position(|window| window == needle) + } + + fn stabilize_date_header(payload: &mut [u8]) { + let mut from = 0; + + while let Some(pos) = find_slice(&payload, b"date", from) { + payload[(from + pos)..(from + pos + 35)] + .copy_from_slice(b"date: Thu, 01 Jan 1970 12:34:56 UTC"); + from += 35; + } + } + + fn ok_service() -> impl Service + { + fn_service(|_req: Request| ready(Ok::<_, Error>(Response::Ok().finish()))) + } + + fn echo_path_service( + ) -> impl Service { + fn_service(|req: Request| { + let path = req.path().as_bytes(); + ready(Ok::<_, Error>(Response::Ok().body(Body::from_slice(path)))) + }) + } + + fn echo_payload_service( + ) -> impl Service { + fn_service(|mut req: Request| { + Box::pin(async move { + use futures_util::stream::StreamExt as _; + + let mut pl = req.take_payload(); + let mut body = BytesMut::new(); + while let Some(chunk) = pl.next().await { + body.extend_from_slice(chunk.unwrap().bytes()) + } + + Ok::<_, Error>(Response::Ok().body(body)) + }) + }) + } #[actix_rt::test] async fn test_req_parse_err() { lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, ServiceConfig::default(), - CloneableService::new( - (|_| ok::<_, Error>(Response::Ok().finish())).into_service(), - ), + CloneableService::new(ok_service()), CloneableService::new(ExpectHandler), None, None, @@ -943,19 +1043,301 @@ mod tests { None, ); - match Pin::new(&mut h1).poll(cx) { + futures_util::pin_mut!(h1); + + match h1.as_mut().poll(cx) { Poll::Pending => panic!(), Poll::Ready(res) => assert!(res.is_err()), } - if let DispatcherState::Normal(ref mut inner) = h1.inner { + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); assert_eq!( - &inner.io.take().unwrap().write_buf[..26], + &inner.project().io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n" ); } }) .await; } + + #[actix_rt::test] + async fn test_pipelining() { + lazy(|cx| { + let buf = TestBuffer::new( + "\ + GET /abcd HTTP/1.1\r\n\r\n\ + GET /def HTTP/1.1\r\n\r\n\ + ", + ); + + let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + buf, + cfg, + CloneableService::new(echo_path_service()), + CloneableService::new(ExpectHandler), + None, + None, + Extensions::new(), + None, + ); + + futures_util::pin_mut!(h1); + + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + + match h1.as_mut().poll(cx) { + Poll::Pending => panic!("first poll should not be pending"), + Poll::Ready(res) => assert!(res.is_ok()), + } + + // polls: initial => shutdown + assert_eq!(h1.poll_count, 2); + + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + let res = &mut inner.project().io.take().unwrap().write_buf[..]; + stabilize_date_header(res); + + let exp = b"\ + HTTP/1.1 200 OK\r\n\ + content-length: 5\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ + /abcd\ + HTTP/1.1 200 OK\r\n\ + content-length: 4\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ + /def\ + "; + + assert_eq!(res.to_vec(), exp.to_vec()); + } + }) + .await; + + lazy(|cx| { + let buf = TestBuffer::new( + "\ + GET /abcd HTTP/1.1\r\n\r\n\ + GET /def HTTP/1\r\n\r\n\ + ", + ); + + let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); + + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + buf, + cfg, + CloneableService::new(echo_path_service()), + CloneableService::new(ExpectHandler), + None, + None, + Extensions::new(), + None, + ); + + futures_util::pin_mut!(h1); + + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + + match h1.as_mut().poll(cx) { + Poll::Pending => panic!("first poll should not be pending"), + Poll::Ready(res) => assert!(res.is_err()), + } + + // polls: initial => shutdown + assert_eq!(h1.poll_count, 1); + + if let DispatcherStateProj::Normal(inner) = h1.project().inner.project() { + let res = &mut inner.project().io.take().unwrap().write_buf[..]; + stabilize_date_header(res); + + let exp = b"\ + HTTP/1.1 200 OK\r\n\ + content-length: 5\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ + /abcd\ + HTTP/1.1 400 Bad Request\r\n\ + content-length: 0\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\r\n\ + "; + + assert_eq!(res.to_vec(), exp.to_vec()); + } + }) + .await; + } + + #[actix_rt::test] + async fn test_expect() { + lazy(|cx| { + let mut buf = TestSeqBuffer::empty(); + let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + buf.clone(), + cfg, + CloneableService::new(echo_payload_service()), + CloneableService::new(ExpectHandler), + None, + None, + Extensions::new(), + None, + ); + + buf.extend_read_buf( + "\ + POST /upload HTTP/1.1\r\n\ + Content-Length: 5\r\n\ + Expect: 100-continue\r\n\ + \r\n\ + ", + ); + + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_pending()); + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + + // polls: manual + assert_eq!(h1.poll_count, 1); + eprintln!("poll count: {}", h1.poll_count); + + if let DispatcherState::Normal(ref inner) = h1.inner { + let io = inner.io.as_ref().unwrap(); + let res = &io.write_buf()[..]; + assert_eq!( + str::from_utf8(res).unwrap(), + "HTTP/1.1 100 Continue\r\n\r\n" + ); + } + + buf.extend_read_buf("12345"); + assert!(h1.as_mut().poll(cx).is_ready()); + + // polls: manual manual shutdown + assert_eq!(h1.poll_count, 3); + + if let DispatcherState::Normal(ref inner) = h1.inner { + let io = inner.io.as_ref().unwrap(); + let mut res = (&io.write_buf()[..]).to_owned(); + stabilize_date_header(&mut res); + + assert_eq!( + str::from_utf8(&res).unwrap(), + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + HTTP/1.1 200 OK\r\n\ + content-length: 5\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\ + \r\n\ + 12345\ + " + ); + } + }) + .await; + } + + #[actix_rt::test] + async fn test_eager_expect() { + lazy(|cx| { + let mut buf = TestSeqBuffer::empty(); + let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + buf.clone(), + cfg, + CloneableService::new(echo_path_service()), + CloneableService::new(ExpectHandler), + None, + None, + Extensions::new(), + None, + ); + + buf.extend_read_buf( + "\ + POST /upload HTTP/1.1\r\n\ + Content-Length: 5\r\n\ + Expect: 100-continue\r\n\ + \r\n\ + ", + ); + + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_ready()); + assert!(matches!(&h1.inner, DispatcherState::Normal(_))); + + // polls: manual shutdown + assert_eq!(h1.poll_count, 2); + + if let DispatcherState::Normal(ref inner) = h1.inner { + let io = inner.io.as_ref().unwrap(); + let mut res = (&io.write_buf()[..]).to_owned(); + stabilize_date_header(&mut res); + + // Despite the content-length header and even though the request payload has not + // been sent, this test expects a complete service response since the payload + // is not used at all. The service passed to dispatcher is path echo and doesn't + // consume payload bytes. + assert_eq!( + str::from_utf8(&res).unwrap(), + "\ + HTTP/1.1 100 Continue\r\n\ + \r\n\ + HTTP/1.1 200 OK\r\n\ + content-length: 7\r\n\ + connection: close\r\n\ + date: Thu, 01 Jan 1970 12:34:56 UTC\r\n\ + \r\n\ + /upload\ + " + ); + } + }) + .await; + } + + #[actix_rt::test] + async fn test_upgrade() { + lazy(|cx| { + let mut buf = TestSeqBuffer::empty(); + let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler<_>>::new( + buf.clone(), + cfg, + CloneableService::new(ok_service()), + CloneableService::new(ExpectHandler), + Some(CloneableService::new(UpgradeHandler(PhantomData))), + None, + Extensions::new(), + None, + ); + + buf.extend_read_buf( + "\ + GET /ws HTTP/1.1\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + \r\n\ + ", + ); + + futures_util::pin_mut!(h1); + + assert!(h1.as_mut().poll(cx).is_ready()); + assert!(matches!(&h1.inner, DispatcherState::Upgrade(_))); + + // polls: manual shutdown + assert_eq!(h1.poll_count, 2); + }) + .await; + } } diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index 6c08df08e..b89c7ff74 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,7 +1,7 @@ use std::task::{Context, Poll}; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ok, Ready}; +use futures_util::future::{ready, Ready}; use crate::error::Error; use crate::request::Request; @@ -17,8 +17,8 @@ impl ServiceFactory for ExpectHandler { type InitError = Error; type Future = Ready>; - fn new_service(&self, _: ()) -> Self::Future { - ok(ExpectHandler) + fn new_service(&self, _: Self::Config) -> Self::Future { + ready(Ok(ExpectHandler)) } } @@ -33,6 +33,8 @@ impl Service for ExpectHandler { } fn call(&mut self, req: Request) -> Self::Future { - ok(req) + ready(Ok(req)) + // TODO: add some way to trigger error + // Err(error::ErrorExpectationFailed("test")) } } diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs index 6a348810c..d4cfee146 100644 --- a/actix-http/src/h1/payload.rs +++ b/actix-http/src/h1/payload.rs @@ -182,9 +182,7 @@ impl Inner { self.len += data.len(); self.items.push_back(data); self.need_read = self.len < MAX_BUFFER_SIZE; - if let Some(task) = self.task.take() { - task.wake() - } + self.task.wake(); } #[cfg(test)] diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index 22ba99e26..8615f27a8 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -3,13 +3,13 @@ use std::task::{Context, Poll}; use actix_codec::Framed; use actix_service::{Service, ServiceFactory}; -use futures_util::future::Ready; +use futures_util::future::{ready, Ready}; use crate::error::Error; use crate::h1::Codec; use crate::request::Request; -pub struct UpgradeHandler(PhantomData); +pub struct UpgradeHandler(pub(crate) PhantomData); impl ServiceFactory for UpgradeHandler { type Config = (); @@ -36,6 +36,6 @@ impl Service for UpgradeHandler { } fn call(&mut self, _: Self::Request) -> Self::Future { - unimplemented!() + ready(Ok(())) } } diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs index 37da830ca..826cfef63 100644 --- a/actix-http/src/header/common/content_disposition.rs +++ b/actix-http/src/header/common/content_disposition.rs @@ -550,8 +550,7 @@ impl fmt::Display for ContentDisposition { write!(f, "{}", self.disposition)?; self.parameters .iter() - .map(|param| write!(f, "; {}", param)) - .collect() + .try_for_each(|param| write!(f, "; {}", param)) } } diff --git a/actix-http/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs index 83489b864..c3d18613c 100644 --- a/actix-http/src/header/common/mod.rs +++ b/actix-http/src/header/common/mod.rs @@ -3,7 +3,7 @@ //! ## Mime //! //! Several header fields use MIME values for their contents. Keeping with the -//! strongly-typed theme, the [mime](https://docs.rs/mime) crate +//! strongly-typed theme, the [mime] crate //! is used, such as `ContentType(pub Mime)`. #![cfg_attr(rustfmt, rustfmt_skip)] diff --git a/actix-http/src/header/map.rs b/actix-http/src/header/map.rs index 36c050b8f..6ab3509f7 100644 --- a/actix-http/src/header/map.rs +++ b/actix-http/src/header/map.rs @@ -8,8 +8,6 @@ use http::header::{HeaderName, HeaderValue}; /// A set of HTTP headers /// /// `HeaderMap` is an multi-map of [`HeaderName`] to values. -/// -/// [`HeaderName`]: struct.HeaderName.html #[derive(Debug, Clone)] pub struct HeaderMap { pub(crate) inner: FxHashMap, @@ -141,8 +139,6 @@ impl HeaderMap { /// The returned view does not incur any allocations and allows iterating /// the values associated with the key. See [`GetAll`] for more details. /// Returns `None` if there are no values associated with the key. - /// - /// [`GetAll`]: struct.GetAll.html pub fn get_all(&self, name: N) -> GetAll<'_> { GetAll { idx: 0, diff --git a/actix-http/src/header/shared/entity.rs b/actix-http/src/header/shared/entity.rs index 3525a19c6..344cfb864 100644 --- a/actix-http/src/header/shared/entity.rs +++ b/actix-http/src/header/shared/entity.rs @@ -7,10 +7,12 @@ use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValue, Writer}; /// 1. `%x21`, or /// 2. in the range `%x23` to `%x7E`, or /// 3. above `%x80` +fn entity_validate_char(c: u8) -> bool { + c == 0x21 || (0x23..=0x7e).contains(&c) || (c >= 0x80) +} + fn check_slice_validity(slice: &str) -> bool { - slice - .bytes() - .all(|c| c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80')) + slice.bytes().all(entity_validate_char) } /// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index b79f5a73c..4512e72c2 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -1,9 +1,14 @@ -//! Test Various helpers for Actix applications to use during testing. -use std::convert::TryFrom; -use std::io::{self, Read, Write}; -use std::pin::Pin; -use std::str::FromStr; -use std::task::{Context, Poll}; +//! Various testing helpers for use in internal and app tests. + +use std::{ + cell::{Ref, RefCell}, + convert::TryFrom, + io::{self, Read, Write}, + pin::Pin, + rc::Rc, + str::FromStr, + task::{Context, Poll}, +}; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::{Bytes, BytesMut}; @@ -183,7 +188,7 @@ fn parts(parts: &mut Option) -> &mut Inner { parts.as_mut().expect("cannot reuse test request builder") } -/// Async io buffer +/// Async I/O test buffer. pub struct TestBuffer { pub read_buf: BytesMut, pub write_buf: BytesMut, @@ -191,24 +196,24 @@ pub struct TestBuffer { } impl TestBuffer { - /// Create new TestBuffer instance - pub fn new(data: T) -> TestBuffer + /// Create new `TestBuffer` instance with initial read buffer. + pub fn new(data: T) -> Self where - BytesMut: From, + T: Into, { - TestBuffer { - read_buf: BytesMut::from(data), + Self { + read_buf: data.into(), write_buf: BytesMut::new(), err: None, } } - /// Create new empty TestBuffer instance - pub fn empty() -> TestBuffer { - TestBuffer::new("") + /// Create new empty `TestBuffer` instance. + pub fn empty() -> Self { + Self::new("") } - /// Add extra data to read buffer. + /// Add data to read buffer. pub fn extend_read_buf>(&mut self, data: T) { self.read_buf.extend_from_slice(data.as_ref()) } @@ -236,6 +241,7 @@ impl io::Write for TestBuffer { self.write_buf.extend(buf); Ok(buf.len()) } + fn flush(&mut self) -> io::Result<()> { Ok(()) } @@ -268,3 +274,113 @@ impl AsyncWrite for TestBuffer { Poll::Ready(Ok(())) } } + +/// Async I/O test buffer with ability to incrementally add to the read buffer. +#[derive(Clone)] +pub struct TestSeqBuffer(Rc>); + +impl TestSeqBuffer { + /// Create new `TestBuffer` instance with initial read buffer. + pub fn new(data: T) -> Self + where + T: Into, + { + Self(Rc::new(RefCell::new(TestSeqInner { + read_buf: data.into(), + write_buf: BytesMut::new(), + err: None, + }))) + } + + /// Create new empty `TestBuffer` instance. + pub fn empty() -> Self { + Self::new("") + } + + pub fn read_buf(&self) -> Ref<'_, BytesMut> { + Ref::map(self.0.borrow(), |inner| &inner.read_buf) + } + + pub fn write_buf(&self) -> Ref<'_, BytesMut> { + Ref::map(self.0.borrow(), |inner| &inner.write_buf) + } + + pub fn err(&self) -> Ref<'_, Option> { + Ref::map(self.0.borrow(), |inner| &inner.err) + } + + /// Add data to read buffer. + pub fn extend_read_buf>(&mut self, data: T) { + self.0 + .borrow_mut() + .read_buf + .extend_from_slice(data.as_ref()) + } +} + +pub struct TestSeqInner { + read_buf: BytesMut, + write_buf: BytesMut, + err: Option, +} + +impl io::Read for TestSeqBuffer { + fn read(&mut self, dst: &mut [u8]) -> Result { + if self.0.borrow().read_buf.is_empty() { + if self.0.borrow().err.is_some() { + Err(self.0.borrow_mut().err.take().unwrap()) + } else { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + } + } else { + let size = std::cmp::min(self.0.borrow().read_buf.len(), dst.len()); + let b = self.0.borrow_mut().read_buf.split_to(size); + dst[..size].copy_from_slice(&b); + Ok(size) + } + } +} + +impl io::Write for TestSeqBuffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.borrow_mut().write_buf.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsyncRead for TestSeqBuffer { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let r = self.get_mut().read(buf); + match r { + Ok(n) => Poll::Ready(Ok(n)), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(err) => Poll::Ready(Err(err)), + } + } +} + +impl AsyncWrite for TestSeqBuffer { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(self.get_mut().write(buf)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 6ffdecc35..cd212fb7e 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -197,13 +197,13 @@ mod tests { let req = TestRequest::default().method(Method::POST).finish(); assert_eq!( HandshakeError::GetMethodRequired, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default().finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() @@ -211,7 +211,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoWebsocketUpgrade, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() @@ -222,7 +222,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoConnectionUpgrade, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() @@ -237,7 +237,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::NoVersionHeader, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() @@ -256,7 +256,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::UnsupportedVersion, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() @@ -275,7 +275,7 @@ mod tests { .finish(); assert_eq!( HandshakeError::BadWebsocketKey, - verify_handshake(req.head()).err().unwrap() + verify_handshake(req.head()).unwrap_err(), ); let req = TestRequest::default() diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index b9ebf97cc..b476f1791 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -725,9 +725,7 @@ impl Drop for Safety { if Rc::strong_count(&self.payload) != self.level { self.clean.set(true); } - if let Some(task) = self.task.take() { - task.wake() - } + self.task.wake(); } } diff --git a/actix-web-codegen/README.md b/actix-web-codegen/README.md index 6eca847b8..283591e86 100644 --- a/actix-web-codegen/README.md +++ b/actix-web-codegen/README.md @@ -3,7 +3,7 @@ > Helper and convenience macros for Actix Web [![crates.io](https://meritbadge.herokuapp.com/actix-web-codegen)](https://crates.io/crates/actix-web-codegen) -[![Documentation](https://docs.rs/actix-web-codegen/badge.svg)](https://docs.rs/actix-web) +[![Documentation](https://docs.rs/actix-web-codegen/badge.svg)](https://docs.rs/actix-web-codegen/0.4.0/actix_web_codegen/) [![Version](https://img.shields.io/badge/rustc-1.42+-ab6000.svg)](https://blog.rust-lang.org/2020/03/12/Rust-1.42.html) [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index af2bc7f18..50e5be712 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -8,7 +8,7 @@ //! are re-exported. //! //! # Runtime Setup -//! Used for setting up the actix async runtime. See [main] macro docs. +//! Used for setting up the actix async runtime. See [macro@main] macro docs. //! //! ```rust //! #[actix_web_codegen::main] // or `#[actix_web::main]` in Actix Web apps @@ -34,7 +34,7 @@ //! //! # Multiple Method Handlers //! Similar to the single method handler macro but takes one or more arguments for the HTTP methods -//! it should respond to. See [route] macro docs. +//! it should respond to. See [macro@route] macro docs. //! //! ```rust //! # use actix_web::HttpResponse; @@ -46,17 +46,15 @@ //! ``` //! //! [actix-web attributes docs]: https://docs.rs/actix-web/*/actix_web/#attributes -//! [main]: attr.main.html -//! [route]: attr.route.html -//! [GET]: attr.get.html -//! [POST]: attr.post.html -//! [PUT]: attr.put.html -//! [DELETE]: attr.delete.html -//! [HEAD]: attr.head.html -//! [CONNECT]: attr.connect.html -//! [OPTIONS]: attr.options.html -//! [TRACE]: attr.trace.html -//! [PATCH]: attr.patch.html +//! [GET]: macro@get +//! [POST]: macro@post +//! [PUT]: macro@put +//! [HEAD]: macro@head +//! [CONNECT]: macro@macro@connect +//! [OPTIONS]: macro@options +//! [TRACE]: macro@trace +//! [PATCH]: macro@patch +//! [DELETE]: macro@delete #![recursion_limit = "512"] diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 7ca415336..e4f801bbe 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -1,6 +1,8 @@ # Changes ## Unreleased - 2020-xx-xx +### Changed +* Bumped `rand` to `0.8` ## 2.0.3 - 2020-11-29 diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 3c1963d6b..2e92526d2 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -50,7 +50,7 @@ futures-core = { version = "0.3.5", default-features = false } log =" 0.4" mime = "0.3" percent-encoding = "2.1" -rand = "0.7" +rand = "0.8" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 57e80bd46..dd43d08b3 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -1,6 +1,6 @@ //! Websockets client //! -//! Type definitions required to use [`awc::Client`](../struct.Client.html) as a WebSocket client. +//! Type definitions required to use [`awc::Client`](super::Client) as a WebSocket client. //! //! # Example //! @@ -70,9 +70,14 @@ impl WebsocketsRequest { >::Error: Into, { let mut err = None; - let mut head = RequestHead::default(); - head.method = Method::GET; - head.version = Version::HTTP_11; + + #[allow(clippy::field_reassign_with_default)] + let mut head = { + let mut head = RequestHead::default(); + head.method = Method::GET; + head.version = Version::HTTP_11; + head + }; match Uri::try_from(uri) { Ok(uri) => head.uri = uri, diff --git a/awc/tests/test_client.rs b/awc/tests/test_client.rs index a9552d0d5..0024c6652 100644 --- a/awc/tests/test_client.rs +++ b/awc/tests/test_client.rs @@ -480,6 +480,7 @@ async fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(100_000) + .map(char::from) .collect::(); let srv = test::start(|| { @@ -529,6 +530,7 @@ async fn test_client_brotli_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(70_000) + .map(char::from) .collect::(); let srv = test::start(|| { diff --git a/codecov.yml b/codecov.yml index 90cdfab47..e6bc40203 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,4 +1,15 @@ -ignore: # ignore codecoverage on following paths +comment: false + +coverage: + status: + project: + default: + threshold: 10% # make CI green + patch: + default: + threshold: 10% # make CI green + +ignore: # ignore code coverage on following paths - "**/tests" - "test-server" - "**/benches" diff --git a/src/config.rs b/src/config.rs index 03ba82732..01959daa1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -141,7 +141,7 @@ impl AppConfig { /// Server host name. /// /// Host name is used by application router as a hostname for url generation. - /// Check [ConnectionInfo](./struct.ConnectionInfo.html#method.host) + /// Check [ConnectionInfo](super::dev::ConnectionInfo::host()) /// documentation for more information. /// /// By default host name is set to a "localhost" value. diff --git a/src/extract.rs b/src/extract.rs index df9c34cb3..5916b1bc5 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -4,7 +4,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use actix_http::error::Error; -use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use futures_util::future::{ready, Ready}; +use futures_util::ready; use crate::dev::Payload; use crate::request::HttpRequest; @@ -95,21 +96,41 @@ where T: FromRequest, T::Future: 'static, { - type Config = T::Config; type Error = Error; - type Future = LocalBoxFuture<'static, Result, Error>>; + type Future = FromRequestOptFuture; + type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - T::from_request(req, payload) - .then(|r| match r { - Ok(v) => ok(Some(v)), - Err(e) => { - log::debug!("Error for Option extractor: {}", e.into()); - ok(None) - } - }) - .boxed_local() + FromRequestOptFuture { + fut: T::from_request(req, payload), + } + } +} + +#[pin_project::pin_project] +pub struct FromRequestOptFuture { + #[pin] + fut: Fut, +} + +impl Future for FromRequestOptFuture +where + Fut: Future>, + E: Into, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.fut.poll(cx)); + match res { + Ok(t) => Poll::Ready(Ok(Some(t))), + Err(e) => { + log::debug!("Error for Option extractor: {}", e.into()); + Poll::Ready(Ok(None)) + } + } } } @@ -165,29 +186,45 @@ where T::Error: 'static, T::Future: 'static, { - type Config = T::Config; type Error = Error; - type Future = LocalBoxFuture<'static, Result, Error>>; + type Future = FromRequestResFuture; + type Config = T::Config; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - T::from_request(req, payload) - .then(|res| match res { - Ok(v) => ok(Ok(v)), - Err(e) => ok(Err(e)), - }) - .boxed_local() + FromRequestResFuture { + fut: T::from_request(req, payload), + } + } +} + +#[pin_project::pin_project] +pub struct FromRequestResFuture { + #[pin] + fut: Fut, +} + +impl Future for FromRequestResFuture +where + Fut: Future>, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.fut.poll(cx)); + Poll::Ready(Ok(res)) } } #[doc(hidden)] impl FromRequest for () { - type Config = (); type Error = Error; type Future = Ready>; + type Config = (); fn from_request(_: &HttpRequest, _: &mut Payload) -> Self::Future { - ok(()) + ready(Ok(())) } } diff --git a/src/handler.rs b/src/handler.rs index 669512ab3..3d0a2382e 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,3 @@ -use std::convert::Infallible; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -6,7 +5,7 @@ use std::task::{Context, Poll}; use actix_http::{Error, Response}; use actix_service::{Service, ServiceFactory}; -use futures_util::future::{ok, Ready}; +use futures_util::future::{ready, Ready}; use futures_util::ready; use pin_project::pin_project; @@ -36,9 +35,11 @@ where } #[doc(hidden)] +/// Extract arguments from request, run factory function and make response. pub struct Handler where F: Factory, + T: FromRequest, R: Future, O: Responder, { @@ -49,6 +50,7 @@ where impl Handler where F: Factory, + T: FromRequest, R: Future, O: Responder, { @@ -63,6 +65,7 @@ where impl Clone for Handler where F: Factory, + T: FromRequest, R: Future, O: Responder, { @@ -74,188 +77,103 @@ where } } -impl Service for Handler +impl ServiceFactory for Handler where F: Factory, + T: FromRequest, R: Future, O: Responder, { - type Request = (T, HttpRequest); + type Request = ServiceRequest; type Response = ServiceResponse; - type Error = Infallible; - type Future = HandlerServiceResponse; + type Error = Error; + type Config = (); + type Service = Self; + type InitError = (); + type Future = Ready>; + + fn new_service(&self, _: ()) -> Self::Future { + ready(Ok(self.clone())) + } +} + +// Handler is both it's ServiceFactory and Service Type. +impl Service for Handler +where + F: Factory, + T: FromRequest, + R: Future, + O: Responder, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = HandlerServiceFuture; fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, (param, req): (T, HttpRequest)) -> Self::Future { - HandlerServiceResponse { - fut: self.hnd.call(param), - fut2: None, - req: Some(req), - } + fn call(&mut self, req: Self::Request) -> Self::Future { + let (req, mut payload) = req.into_parts(); + let fut = T::from_request(&req, &mut payload); + HandlerServiceFuture::Extract(fut, Some(req), self.hnd.clone()) } } #[doc(hidden)] -#[pin_project] -pub struct HandlerServiceResponse +#[pin_project(project = HandlerProj)] +pub enum HandlerServiceFuture where - T: Future, - R: Responder, + F: Factory, + T: FromRequest, + R: Future, + O: Responder, { - #[pin] - fut: T, - #[pin] - fut2: Option, - req: Option, + Extract(#[pin] T::Future, Option, F), + Handle(#[pin] R, Option), + Respond(#[pin] O::Future, Option), } -impl Future for HandlerServiceResponse +impl Future for HandlerServiceFuture where - T: Future, - R: Responder, + F: Factory, + T: FromRequest, + R: Future, + O: Responder, { - type Output = Result; + // Error type in this future is a placeholder type. + // all instances of error must be converted to ServiceResponse and return in Ok. + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - if let Some(fut) = this.fut2.as_pin_mut() { - return match fut.poll(cx) { - Poll::Ready(Ok(res)) => { - Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + loop { + match self.as_mut().project() { + HandlerProj::Extract(fut, req, handle) => { + match ready!(fut.poll(cx)) { + Ok(item) => { + let fut = handle.call(item); + let state = HandlerServiceFuture::Handle(fut, req.take()); + self.as_mut().set(state); + } + Err(e) => { + let res: Response = e.into().into(); + let req = req.take().unwrap(); + return Poll::Ready(Ok(ServiceResponse::new(req, res))); + } + }; } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let res: Response = e.into().into(); - Poll::Ready(Ok(ServiceResponse::new(this.req.take().unwrap(), res))) + HandlerProj::Handle(fut, req) => { + let res = ready!(fut.poll(cx)); + let fut = res.respond_to(req.as_ref().unwrap()); + let state = HandlerServiceFuture::Respond(fut, req.take()); + self.as_mut().set(state); + } + HandlerProj::Respond(fut, req) => { + let res = ready!(fut.poll(cx)).unwrap_or_else(|e| e.into().into()); + let req = req.take().unwrap(); + return Poll::Ready(Ok(ServiceResponse::new(req, res))); } - }; - } - - match this.fut.poll(cx) { - Poll::Ready(res) => { - let fut = res.respond_to(this.req.as_ref().unwrap()); - self.as_mut().project().fut2.set(Some(fut)); - self.poll(cx) - } - Poll::Pending => Poll::Pending, - } - } -} - -/// Extract arguments from request -pub struct Extract { - service: S, - _t: PhantomData, -} - -impl Extract { - pub fn new(service: S) -> Self { - Extract { - service, - _t: PhantomData, - } - } -} - -impl ServiceFactory for Extract -where - S: Service< - Request = (T, HttpRequest), - Response = ServiceResponse, - Error = Infallible, - > + Clone, -{ - type Config = (); - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = (Error, ServiceRequest); - type InitError = (); - type Service = ExtractService; - type Future = Ready>; - - fn new_service(&self, _: ()) -> Self::Future { - ok(ExtractService { - _t: PhantomData, - service: self.service.clone(), - }) - } -} - -pub struct ExtractService { - service: S, - _t: PhantomData, -} - -impl Service for ExtractService -where - S: Service< - Request = (T, HttpRequest), - Response = ServiceResponse, - Error = Infallible, - > + Clone, -{ - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = (Error, ServiceRequest); - type Future = ExtractResponse; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: ServiceRequest) -> Self::Future { - let (req, mut payload) = req.into_parts(); - let fut = T::from_request(&req, &mut payload); - - ExtractResponse { - fut, - req, - fut_s: None, - service: self.service.clone(), - } - } -} - -#[pin_project] -pub struct ExtractResponse { - req: HttpRequest, - service: S, - #[pin] - fut: T::Future, - #[pin] - fut_s: Option, -} - -impl Future for ExtractResponse -where - S: Service< - Request = (T, HttpRequest), - Response = ServiceResponse, - Error = Infallible, - >, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); - - if let Some(fut) = this.fut_s.as_pin_mut() { - return fut.poll(cx).map_err(|_| panic!()); - } - - match ready!(this.fut.poll(cx)) { - Err(e) => { - let req = ServiceRequest::new(this.req.clone()); - Poll::Ready(Err((e.into(), req))) - } - Ok(item) => { - let fut = Some(this.service.call((item, this.req.clone()))); - self.as_mut().project().fut_s.set(fut); - self.poll(cx) } } } diff --git a/src/info.rs b/src/info.rs index 1d9b402a7..975604041 100644 --- a/src/info.rs +++ b/src/info.rs @@ -174,7 +174,7 @@ impl ConnectionInfo { /// Do not use this function for security purposes, unless you can ensure the Forwarded and /// X-Forwarded-For headers cannot be spoofed by the client. If you want the client's socket /// address explicitly, use - /// [`HttpRequest::peer_addr()`](../web/struct.HttpRequest.html#method.peer_addr) instead. + /// [`HttpRequest::peer_addr()`](super::web::HttpRequest::peer_addr()) instead. #[inline] pub fn realip_remote_addr(&self) -> Option<&str> { if let Some(ref r) = self.realip_remote_addr { diff --git a/src/lib.rs b/src/lib.rs index a8fc50d83..b8346d966 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,16 +29,16 @@ //! //! To get started navigating the API docs, you may consider looking at the following pages first: //! -//! * [App](struct.App.html): This struct represents an Actix web application and is used to +//! * [App]: This struct represents an Actix web application and is used to //! configure routes and other common application settings. //! -//! * [HttpServer](struct.HttpServer.html): This struct represents an HTTP server instance and is +//! * [HttpServer]: This struct represents an HTTP server instance and is //! used to instantiate and configure servers. //! -//! * [web](web/index.html): This module provides essential types for route registration as well as +//! * [web]: This module provides essential types for route registration as well as //! common utilities for request handlers. //! -//! * [HttpRequest](struct.HttpRequest.html) and [HttpResponse](struct.HttpResponse.html): These +//! * [HttpRequest] and [HttpResponse]: These //! structs represent HTTP requests and responses and expose methods for creating, inspecting, //! and otherwise utilizing them. //! diff --git a/src/middleware/compress.rs b/src/middleware/compress.rs index fe3ba841c..7575d7455 100644 --- a/src/middleware/compress.rs +++ b/src/middleware/compress.rs @@ -192,10 +192,7 @@ impl AcceptEncoding { }; let quality = match parts.len() { 1 => encoding.quality(), - _ => match f64::from_str(parts[1]) { - Ok(q) => q, - Err(_) => 0.0, - }, + _ => f64::from_str(parts[1]).unwrap_or(0.0), }; Some(AcceptEncoding { encoding, quality }) } diff --git a/src/middleware/condition.rs b/src/middleware/condition.rs index ab1c69746..9061c7458 100644 --- a/src/middleware/condition.rs +++ b/src/middleware/condition.rs @@ -105,6 +105,7 @@ mod tests { use crate::test::{self, TestRequest}; use crate::HttpResponse; + #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index 6d43aba95..a6f1a4336 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -1,10 +1,14 @@ //! Middleware for setting default response headers use std::convert::TryFrom; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use actix_service::{Service, Transform}; -use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready}; +use futures_util::future::{ready, Ready}; +use futures_util::ready; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::{Error as HttpError, HeaderMap}; @@ -97,15 +101,15 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type InitError = (); type Transform = DefaultHeadersMiddleware; + type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(DefaultHeadersMiddleware { + ready(Ok(DefaultHeadersMiddleware { service, inner: self.inner.clone(), - }) + })) } } @@ -122,36 +126,56 @@ where type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = DefaultHeaderFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - #[allow(clippy::borrow_interior_mutable_const)] fn call(&mut self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); let fut = self.service.call(req); - async move { - let mut res = fut.await?; - - // set response headers - for (key, value) in inner.headers.iter() { - if !res.headers().contains_key(key) { - res.headers_mut().insert(key.clone(), value.clone()); - } - } - // default content-type - if inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { - res.headers_mut().insert( - CONTENT_TYPE, - HeaderValue::from_static("application/octet-stream"), - ); - } - Ok(res) + DefaultHeaderFuture { + fut, + inner, + _body: PhantomData, } - .boxed_local() + } +} + +#[pin_project::pin_project] +pub struct DefaultHeaderFuture { + #[pin] + fut: S::Future, + inner: Rc, + _body: PhantomData, +} + +impl Future for DefaultHeaderFuture +where + S: Service, Error = Error>, +{ + type Output = ::Output; + + #[allow(clippy::borrow_interior_mutable_const)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut res = ready!(this.fut.poll(cx))?; + // set response headers + for (key, value) in this.inner.headers.iter() { + if !res.headers().contains_key(key) { + res.headers_mut().insert(key.clone(), value.clone()); + } + } + // default content-type + if this.inner.ct && !res.headers().contains_key(&CONTENT_TYPE) { + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + } + Poll::Ready(Ok(res)) } } diff --git a/src/middleware/errhandlers.rs b/src/middleware/errhandlers.rs index 93a5d3f22..c0cb9594e 100644 --- a/src/middleware/errhandlers.rs +++ b/src/middleware/errhandlers.rs @@ -154,6 +154,7 @@ mod tests { use crate::test::{self, TestRequest}; use crate::HttpResponse; + #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index b2e5c791f..563cb6c32 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -82,11 +82,10 @@ use crate::HttpResponse; /// /// # Security /// **\*** It is calculated using -/// [`ConnectionInfo::realip_remote_addr()`](../dev/struct.ConnectionInfo.html#method.realip_remote_addr) +/// [`ConnectionInfo::realip_remote_addr()`](crate::dev::ConnectionInfo::realip_remote_addr()) /// /// If you use this value ensure that all requests come from trusted hosts, since it is trivial /// for the remote client to simulate being another client. -/// pub struct Logger(Rc); struct Inner { diff --git a/src/middleware/normalize.rs b/src/middleware/normalize.rs index ac8ad71d5..ad9f51079 100644 --- a/src/middleware/normalize.rs +++ b/src/middleware/normalize.rs @@ -1,10 +1,11 @@ -//! `Middleware` to normalize request's URI +//! For middleware documentation, see [`NormalizePath`]. + use std::task::{Context, Poll}; use actix_http::http::{PathAndQuery, Uri}; use actix_service::{Service, Transform}; use bytes::Bytes; -use futures_util::future::{ok, Ready}; +use futures_util::future::{ready, Ready}; use regex::Regex; use crate::service::{ServiceRequest, ServiceResponse}; @@ -17,10 +18,12 @@ pub enum TrailingSlash { /// Always add a trailing slash to the end of the path. /// This will require all routes to end in a trailing slash for them to be accessible. Always, + /// Only merge any present multiple trailing slashes. /// - /// Note: This option provides the best compatibility with the v2 version of this middlware. + /// Note: This option provides the best compatibility with the v2 version of this middleware. MergeOnly, + /// Trim trailing slashes from the end of the path. Trim, } @@ -32,28 +35,53 @@ impl Default for TrailingSlash { } #[derive(Default, Clone, Copy)] -/// `Middleware` to normalize request's URI in place +/// Middleware to normalize a request's path so that routes can be matched less strictly. /// -/// Performs following: -/// -/// - Merges multiple slashes into one. +/// # Normalization Steps +/// - Merges multiple consecutive slashes into one. (For example, `/path//one` always +/// becomes `/path/one`.) /// - Appends a trailing slash if one is not present, removes one if present, or keeps trailing -/// slashes as-is, depending on the supplied `TrailingSlash` variant. +/// slashes as-is, depending on which [`TrailingSlash`] variant is supplied +/// to [`new`](NormalizePath::new()). /// +/// # Default Behavior +/// The default constructor chooses to strip trailing slashes from the end +/// ([`TrailingSlash::Trim`]), the effect is that route definitions should be defined without +/// trailing slashes or else they will be inaccessible. +/// +/// # Example /// ```rust -/// use actix_web::{web, http, middleware, App, HttpResponse}; +/// use actix_web::{web, middleware, App}; /// -/// # fn main() { +/// # #[actix_rt::test] +/// # async fn normalize() { /// let app = App::new() /// .wrap(middleware::NormalizePath::default()) -/// .service( -/// web::resource("/test") -/// .route(web::get().to(|| HttpResponse::Ok())) -/// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) -/// ); +/// .route("/test", web::get().to(|| async { "test" })) +/// .route("/unmatchable/", web::get().to(|| async { "unmatchable" })); +/// +/// use actix_web::http::StatusCode; +/// use actix_web::test::{call_service, init_service, TestRequest}; +/// +/// let mut app = init_service(app).await; +/// +/// let req = TestRequest::with_uri("/test").to_request(); +/// let res = call_service(&mut app, req).await; +/// assert_eq!(res.status(), StatusCode::OK); +/// +/// let req = TestRequest::with_uri("/test/").to_request(); +/// let res = call_service(&mut app, req).await; +/// assert_eq!(res.status(), StatusCode::OK); +/// +/// let req = TestRequest::with_uri("/unmatchable").to_request(); +/// let res = call_service(&mut app, req).await; +/// assert_eq!(res.status(), StatusCode::NOT_FOUND); +/// +/// let req = TestRequest::with_uri("/unmatchable/").to_request(); +/// let res = call_service(&mut app, req).await; +/// assert_eq!(res.status(), StatusCode::NOT_FOUND); /// # } /// ``` - pub struct NormalizePath(TrailingSlash); impl NormalizePath { @@ -76,11 +104,11 @@ where type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(NormalizePathNormalization { + ready(Ok(NormalizePathNormalization { service, merge_slash: Regex::new("//+").unwrap(), trailing_slash_behavior: self.0, - }) + })) } } @@ -160,9 +188,11 @@ mod tests { use actix_service::IntoService; use super::*; - use crate::dev::ServiceRequest; - use crate::test::{call_service, init_service, TestRequest}; - use crate::{web, App, HttpResponse}; + use crate::{ + dev::ServiceRequest, + test::{call_service, init_service, TestRequest}, + web, App, HttpResponse, + }; #[actix_rt::test] async fn test_wrap() { @@ -244,7 +274,7 @@ mod tests { } #[actix_rt::test] - async fn keep_trailing_slash_unchange() { + async fn keep_trailing_slash_unchanged() { let mut app = init_service( App::new() .wrap(NormalizePath(TrailingSlash::MergeOnly)) @@ -279,7 +309,7 @@ mod tests { async fn test_in_place_normalization() { let srv = |req: ServiceRequest| { assert_eq!("/v1/something/", req.path()); - ok(req.into_response(HttpResponse::Ok().finish())) + ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; let mut normalize = NormalizePath::default() @@ -310,7 +340,7 @@ mod tests { let srv = |req: ServiceRequest| { assert_eq!(URI, req.path()); - ok(req.into_response(HttpResponse::Ok().finish())) + ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; let mut normalize = NormalizePath::default() @@ -324,12 +354,12 @@ mod tests { } #[actix_rt::test] - async fn should_normalize_notrail() { + async fn should_normalize_no_trail() { const URI: &str = "/v1/something"; let srv = |req: ServiceRequest| { assert_eq!(URI.to_string() + "/", req.path()); - ok(req.into_response(HttpResponse::Ok().finish())) + ready(Ok(req.into_response(HttpResponse::Ok().finish()))) }; let mut normalize = NormalizePath::default() diff --git a/src/route.rs b/src/route.rs index 8cc1edfc2..439ae6c4a 100644 --- a/src/route.rs +++ b/src/route.rs @@ -11,29 +11,29 @@ use futures_util::future::{ready, FutureExt, LocalBoxFuture}; use crate::extract::FromRequest; use crate::guard::{self, Guard}; -use crate::handler::{Extract, Factory, Handler}; +use crate::handler::{Factory, Handler}; use crate::responder::Responder; use crate::service::{ServiceRequest, ServiceResponse}; use crate::HttpResponse; -type BoxedRouteService = Box< +type BoxedRouteService = Box< dyn Service< - Request = Req, - Response = Res, + Request = ServiceRequest, + Response = ServiceResponse, Error = Error, - Future = LocalBoxFuture<'static, Result>, + Future = LocalBoxFuture<'static, Result>, >, >; -type BoxedRouteNewService = Box< +type BoxedRouteNewService = Box< dyn ServiceFactory< Config = (), - Request = Req, - Response = Res, + Request = ServiceRequest, + Response = ServiceResponse, Error = Error, InitError = (), - Service = BoxedRouteService, - Future = LocalBoxFuture<'static, Result, ()>>, + Service = BoxedRouteService, + Future = LocalBoxFuture<'static, Result>, >, >; @@ -42,7 +42,7 @@ type BoxedRouteNewService = Box< /// Route uses builder-like pattern for configuration. /// If handler is not explicitly set, default *404 Not Found* handler is used. pub struct Route { - service: BoxedRouteNewService, + service: BoxedRouteNewService, guards: Rc>>, } @@ -51,9 +51,9 @@ impl Route { #[allow(clippy::new_without_default)] pub fn new() -> Route { Route { - service: Box::new(RouteNewService::new(Extract::new(Handler::new(|| { + service: Box::new(RouteNewService::new(Handler::new(|| { ready(HttpResponse::NotFound()) - })))), + }))), guards: Rc::new(Vec::new()), } } @@ -80,15 +80,8 @@ impl ServiceFactory for Route { } } -type RouteFuture = LocalBoxFuture< - 'static, - Result, ()>, ->; - -#[pin_project::pin_project] pub struct CreateRouteService { - #[pin] - fut: RouteFuture, + fut: LocalBoxFuture<'static, Result>, guards: Rc>>, } @@ -96,9 +89,9 @@ impl Future for CreateRouteService { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); + let this = self.get_mut(); - match this.fut.poll(cx)? { + match this.fut.as_mut().poll(cx)? { Poll::Ready(service) => Poll::Ready(Ok(RouteService { service, guards: this.guards.clone(), @@ -109,7 +102,7 @@ impl Future for CreateRouteService { } pub struct RouteService { - service: BoxedRouteService, + service: BoxedRouteService, guards: Rc>>, } @@ -233,15 +226,14 @@ impl Route { R: Future + 'static, U: Responder + 'static, { - self.service = - Box::new(RouteNewService::new(Extract::new(Handler::new(handler)))); + self.service = Box::new(RouteNewService::new(Handler::new(handler))); self } } struct RouteNewService where - T: ServiceFactory, + T: ServiceFactory, { service: T, } @@ -252,7 +244,7 @@ where Config = (), Request = ServiceRequest, Response = ServiceResponse, - Error = (Error, ServiceRequest), + Error = Error, >, T::Future: 'static, T::Service: 'static, @@ -269,18 +261,18 @@ where Config = (), Request = ServiceRequest, Response = ServiceResponse, - Error = (Error, ServiceRequest), + Error = Error, >, T::Future: 'static, T::Service: 'static, ::Future: 'static, { - type Config = (); type Request = ServiceRequest; type Response = ServiceResponse; type Error = Error; + type Config = (); + type Service = BoxedRouteService; type InitError = (); - type Service = BoxedRouteService; type Future = LocalBoxFuture<'static, Result>; fn new_service(&self, _: ()) -> Self::Future { @@ -288,8 +280,7 @@ where .new_service(()) .map(|result| match result { Ok(service) => { - let service: BoxedRouteService<_, _> = - Box::new(RouteServiceWrapper { service }); + let service = Box::new(RouteServiceWrapper { service }) as _; Ok(service) } Err(_) => Err(()), @@ -305,11 +296,7 @@ struct RouteServiceWrapper { impl Service for RouteServiceWrapper where T::Future: 'static, - T: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = (Error, ServiceRequest), - >, + T: Service, { type Request = ServiceRequest; type Response = ServiceResponse; @@ -317,27 +304,11 @@ where type Future = LocalBoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx).map_err(|(e, _)| e) + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { - // let mut fut = self.service.call(req); - self.service - .call(req) - .map(|res| match res { - Ok(res) => Ok(res), - Err((err, req)) => Ok(req.error_response(err)), - }) - .boxed_local() - - // match fut.poll() { - // Poll::Ready(Ok(res)) => Either::Left(ok(res)), - // Poll::Ready(Err((e, req))) => Either::Left(ok(req.error_response(e))), - // Poll::Pending => Either::Right(Box::new(fut.then(|res| match res { - // Ok(res) => Ok(res), - // Err((err, req)) => Ok(req.error_response(err)), - // }))), - // } + Box::pin(self.service.call(req)) } } diff --git a/src/server.rs b/src/server.rs index 3badb6e8d..be97e8a0d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -213,7 +213,7 @@ where /// Set server host name. /// /// Host name is used by application router as a hostname for url generation. - /// Check [ConnectionInfo](./dev/struct.ConnectionInfo.html#method.host) + /// Check [ConnectionInfo](super::dev::ConnectionInfo::host()) /// documentation for more information. /// /// By default host name is set to a "localhost" value. diff --git a/src/service.rs b/src/service.rs index a861ba38c..189ba5554 100644 --- a/src/service.rs +++ b/src/service.rs @@ -195,13 +195,13 @@ impl ServiceRequest { self.0.match_info() } - /// Counterpart to [`HttpRequest::match_name`](../struct.HttpRequest.html#method.match_name). + /// Counterpart to [`HttpRequest::match_name`](super::HttpRequest::match_name()). #[inline] pub fn match_name(&self) -> Option<&str> { self.0.match_name() } - /// Counterpart to [`HttpRequest::match_pattern`](../struct.HttpRequest.html#method.match_pattern). + /// Counterpart to [`HttpRequest::match_pattern`](super::HttpRequest::match_pattern()). #[inline] pub fn match_pattern(&self) -> Option { self.0.match_pattern() @@ -225,7 +225,7 @@ impl ServiceRequest { self.0.app_config() } - /// Counterpart to [`HttpRequest::app_data`](../struct.HttpRequest.html#method.app_data). + /// Counterpart to [`HttpRequest::app_data`](super::HttpRequest::app_data()). pub fn app_data(&self) -> Option<&T> { for container in (self.0).0.app_data.iter().rev() { if let Some(data) = container.get::() { diff --git a/src/types/form.rs b/src/types/form.rs index 2a7101287..82ea73216 100644 --- a/src/types/form.rs +++ b/src/types/form.rs @@ -35,7 +35,7 @@ use crate::{responder::Responder, web}; /// To extract typed information from request's body, the type `T` must /// implement the `Deserialize` trait from *serde*. /// -/// [**FormConfig**](struct.FormConfig.html) allows to configure extraction +/// [**FormConfig**](FormConfig) allows to configure extraction /// process. /// /// ### Example diff --git a/src/types/json.rs b/src/types/json.rs index 081a022e8..dc0870a6e 100644 --- a/src/types/json.rs +++ b/src/types/json.rs @@ -1,14 +1,16 @@ //! Json extractor/responder use std::future::Future; +use std::marker::PhantomData; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{fmt, ops}; use bytes::BytesMut; -use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; +use futures_util::future::{ready, Ready}; +use futures_util::ready; +use futures_util::stream::Stream; use serde::de::DeserializeOwned; use serde::Serialize; @@ -31,7 +33,7 @@ use crate::{responder::Responder, web}; /// To extract typed information from request's body, the type `T` must /// implement the `Deserialize` trait from *serde*. /// -/// [**JsonConfig**](struct.JsonConfig.html) allows to configure extraction +/// [**JsonConfig**](JsonConfig) allows to configure extraction /// process. /// /// ## Example @@ -127,12 +129,12 @@ impl Responder for Json { fn respond_to(self, _: &HttpRequest) -> Self::Future { let body = match serde_json::to_string(&self.0) { Ok(body) => body, - Err(e) => return err(e.into()), + Err(e) => return ready(Err(e.into())), }; - ok(Response::build(StatusCode::OK) + ready(Ok(Response::build(StatusCode::OK) .content_type("application/json") - .body(body)) + .body(body))) } } @@ -142,7 +144,7 @@ impl Responder for Json { /// To extract typed information from request's body, the type `T` must /// implement the `Deserialize` trait from *serde*. /// -/// [**JsonConfig**](struct.JsonConfig.html) allows to configure extraction +/// [**JsonConfig**](JsonConfig) allows to configure extraction /// process. /// /// ## Example @@ -173,37 +175,64 @@ where T: DeserializeOwned + 'static, { type Error = Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = JsonExtractFut; type Config = JsonConfig; #[inline] fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - let req2 = req.clone(); let config = JsonConfig::from_req(req); let limit = config.limit; - let ctype = config.content_type.clone(); + let ctype = config.content_type.as_deref(); let err_handler = config.err_handler.clone(); - JsonBody::new(req, payload, ctype) - .limit(limit) - .map(move |res| match res { - Err(e) => { - log::debug!( - "Failed to deserialize Json from payload. \ - Request path: {}", - req2.path() - ); + JsonExtractFut { + req: Some(req.clone()), + fut: JsonBody::new(req, payload, ctype).limit(limit), + err_handler, + } + } +} - if let Some(err) = err_handler { - Err((*err)(e, &req2)) - } else { - Err(e.into()) - } +type JsonErrorHandler = + Option Error + Send + Sync>>; + +pub struct JsonExtractFut { + req: Option, + fut: JsonBody, + err_handler: JsonErrorHandler, +} + +impl Future for JsonExtractFut +where + T: DeserializeOwned + 'static, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let res = ready!(Pin::new(&mut this.fut).poll(cx)); + + let res = match res { + Err(e) => { + let req = this.req.take().unwrap(); + log::debug!( + "Failed to deserialize Json from payload. \ + Request path: {}", + req.path() + ); + + if let Some(err) = this.err_handler.as_ref() { + Err((*err)(e, &req)) + } else { + Err(e.into()) } - Ok(data) => Ok(Json(data)), - }) - .boxed_local() + } + Ok(data) => Ok(Json(data)), + }; + + Poll::Ready(res) } } @@ -248,8 +277,7 @@ where #[derive(Clone)] pub struct JsonConfig { limit: usize, - err_handler: - Option Error + Send + Sync>>, + err_handler: JsonErrorHandler, content_type: Option bool + Send + Sync>>, } @@ -306,19 +334,24 @@ impl Default for JsonConfig { /// Returns error: /// /// * content type is not `application/json` -/// (unless specified in [`JsonConfig`](struct.JsonConfig.html)) +/// (unless specified in [`JsonConfig`]) /// * content length is greater than 256k -pub struct JsonBody { - limit: usize, - length: Option, - #[cfg(feature = "compress")] - stream: Option>, - #[cfg(not(feature = "compress"))] - stream: Option, - err: Option, - fut: Option>>, +pub enum JsonBody { + Error(Option), + Body { + limit: usize, + length: Option, + #[cfg(feature = "compress")] + payload: Decompress, + #[cfg(not(feature = "compress"))] + payload: Payload, + buf: BytesMut, + _res: PhantomData, + }, } +impl Unpin for JsonBody {} + impl JsonBody where U: DeserializeOwned + 'static, @@ -328,7 +361,7 @@ where pub fn new( req: &HttpRequest, payload: &mut Payload, - ctype: Option bool + Send + Sync>>, + ctype: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>, ) -> Self { // check content-type let json = if let Ok(Some(mime)) = req.mime_type() { @@ -340,39 +373,58 @@ where }; if !json { - return JsonBody { - limit: 262_144, - length: None, - stream: None, - fut: None, - err: Some(JsonPayloadError::ContentType), - }; + return JsonBody::Error(Some(JsonPayloadError::ContentType)); } - let len = req + let length = req .headers() .get(&CONTENT_LENGTH) .and_then(|l| l.to_str().ok()) .and_then(|s| s.parse::().ok()); + // Notice the content_length is not checked against limit of json config here. + // As the internal usage always call JsonBody::limit after JsonBody::new. + // And limit check to return an error variant of JsonBody happens there. + #[cfg(feature = "compress")] let payload = Decompress::from_headers(payload.take(), req.headers()); #[cfg(not(feature = "compress"))] let payload = payload.take(); - JsonBody { + JsonBody::Body { limit: 262_144, - length: len, - stream: Some(payload), - fut: None, - err: None, + length, + payload, + buf: BytesMut::with_capacity(8192), + _res: PhantomData, } } /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self + pub fn limit(self, limit: usize) -> Self { + match self { + JsonBody::Body { + length, + payload, + buf, + .. + } => { + if let Some(len) = length { + if len > limit { + return JsonBody::Error(Some(JsonPayloadError::Overflow)); + } + } + + JsonBody::Body { + limit, + length, + payload, + buf, + _res: PhantomData, + } + } + JsonBody::Error(e) => JsonBody::Error(e), + } } } @@ -382,41 +434,34 @@ where { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(ref mut fut) = self.fut { - return Pin::new(fut).poll(cx); - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - if let Some(err) = self.err.take() { - return Poll::Ready(Err(err)); - } - - let limit = self.limit; - if let Some(len) = self.length.take() { - if len > limit { - return Poll::Ready(Err(JsonPayloadError::Overflow)); - } - } - let mut stream = self.stream.take().unwrap(); - - self.fut = Some( - async move { - let mut body = BytesMut::with_capacity(8192); - - while let Some(item) = stream.next().await { - let chunk = item?; - if (body.len() + chunk.len()) > limit { - return Err(JsonPayloadError::Overflow); - } else { - body.extend_from_slice(&chunk); + match this { + JsonBody::Body { + limit, + buf, + payload, + .. + } => loop { + let res = ready!(Pin::new(&mut *payload).poll_next(cx)); + match res { + Some(chunk) => { + let chunk = chunk?; + if (buf.len() + chunk.len()) > *limit { + return Poll::Ready(Err(JsonPayloadError::Overflow)); + } else { + buf.extend_from_slice(&chunk); + } + } + None => { + let json = serde_json::from_slice::(&buf)?; + return Poll::Ready(Ok(json)); } } - Ok(serde_json::from_slice::(&body)?) - } - .boxed_local(), - ); - - self.poll(cx) + }, + JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())), + } } } diff --git a/src/types/path.rs b/src/types/path.rs index dbb5f3ee0..640ff4346 100644 --- a/src/types/path.rs +++ b/src/types/path.rs @@ -15,7 +15,7 @@ use crate::FromRequest; #[derive(PartialEq, Eq, PartialOrd, Ord)] /// Extract typed information from the request's path. /// -/// [**PathConfig**](struct.PathConfig.html) allows to configure extraction process. +/// [**PathConfig**](PathConfig) allows to configure extraction process. /// /// ## Example /// diff --git a/src/types/payload.rs b/src/types/payload.rs index 4ff5ef4b4..9228b37aa 100644 --- a/src/types/payload.rs +++ b/src/types/payload.rs @@ -7,10 +7,12 @@ use std::task::{Context, Poll}; use actix_http::error::{Error, ErrorBadRequest, PayloadError}; use actix_http::HttpMessage; use bytes::{Bytes, BytesMut}; -use encoding_rs::UTF_8; +use encoding_rs::{Encoding, UTF_8}; use futures_core::stream::Stream; -use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready}; -use futures_util::StreamExt; +use futures_util::{ + future::{err, ok, Either, ErrInto, Ready, TryFutureExt as _}, + ready, +}; use mime::Mime; use crate::extract::FromRequest; @@ -111,7 +113,7 @@ impl FromRequest for Payload { /// /// Loads request's payload and construct Bytes instance. /// -/// [**PayloadConfig**](struct.PayloadConfig.html) allows to configure +/// [**PayloadConfig**](PayloadConfig) allows to configure /// extraction process. /// /// ## Example @@ -135,10 +137,7 @@ impl FromRequest for Payload { impl FromRequest for Bytes { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either, Ready>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -151,7 +150,7 @@ impl FromRequest for Bytes { let limit = cfg.limit; let fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left(async move { Ok(fut.await?) }.boxed_local()) + Either::Left(fut.err_into()) } } @@ -159,7 +158,7 @@ impl FromRequest for Bytes { /// /// Text extractor automatically decode body according to the request's charset. /// -/// [**PayloadConfig**](struct.PayloadConfig.html) allows to configure +/// [**PayloadConfig**](PayloadConfig) allows to configure /// extraction process. /// /// ## Example @@ -185,10 +184,7 @@ impl FromRequest for Bytes { impl FromRequest for String { type Config = PayloadConfig; type Error = Error; - type Future = Either< - LocalBoxFuture<'static, Result>, - Ready>, - >; + type Future = Either>>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { @@ -205,25 +201,40 @@ impl FromRequest for String { Err(e) => return Either::Right(err(e.into())), }; let limit = cfg.limit; - let fut = HttpMessageBody::new(req, payload).limit(limit); + let body_fut = HttpMessageBody::new(req, payload).limit(limit); - Either::Left( - async move { - let body = fut.await?; + Either::Left(StringExtractFut { body_fut, encoding }) + } +} - if encoding == UTF_8 { - Ok(str::from_utf8(body.as_ref()) - .map_err(|_| ErrorBadRequest("Can not decode body"))? - .to_owned()) - } else { - Ok(encoding - .decode_without_bom_handling_and_without_replacement(&body) - .map(|s| s.into_owned()) - .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) - } - } - .boxed_local(), - ) +pub struct StringExtractFut { + body_fut: HttpMessageBody, + encoding: &'static Encoding, +} + +impl<'a> Future for StringExtractFut { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let encoding = self.encoding; + + Pin::new(&mut self.body_fut).poll(cx).map(|out| { + let body = out?; + bytes_to_string(body, encoding) + }) + } +} + +fn bytes_to_string(body: Bytes, encoding: &'static Encoding) -> Result { + if encoding == UTF_8 { + Ok(str::from_utf8(body.as_ref()) + .map_err(|_| ErrorBadRequest("Can not decode body"))? + .to_owned()) + } else { + Ok(encoding + .decode_without_bom_handling_and_without_replacement(&body) + .map(|s| s.into_owned()) + .ok_or_else(|| ErrorBadRequest("Can not decode body"))?) } } @@ -241,9 +252,10 @@ pub struct PayloadConfig { impl PayloadConfig { /// Create `PayloadConfig` instance and set max size of payload. pub fn new(limit: usize) -> Self { - let mut cfg = Self::default(); - cfg.limit = limit; - cfg + Self { + limit, + ..Default::default() + } } /// Change max size of payload. By default max size is 256Kb @@ -290,10 +302,12 @@ impl PayloadConfig { // Allow shared refs to default. const DEFAULT_CONFIG: PayloadConfig = PayloadConfig { - limit: 262_144, // 2^18 bytes (~256kB) + limit: DEFAULT_CONFIG_LIMIT, mimetype: None, }; +const DEFAULT_CONFIG_LIMIT: usize = 262_144; // 2^18 bytes (~256kB) + impl Default for PayloadConfig { fn default() -> Self { DEFAULT_CONFIG.clone() @@ -311,99 +325,83 @@ pub struct HttpMessageBody { limit: usize, length: Option, #[cfg(feature = "compress")] - stream: Option>, + stream: dev::Decompress, #[cfg(not(feature = "compress"))] - stream: Option, + stream: dev::Payload, + buf: BytesMut, err: Option, - fut: Option>>, } impl HttpMessageBody { /// Create `MessageBody` for request. #[allow(clippy::borrow_interior_mutable_const)] pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> HttpMessageBody { - let mut len = None; + let mut length = None; + let mut err = None; + if let Some(l) = req.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } else { - return Self::err(PayloadError::UnknownLength); - } - } else { - return Self::err(PayloadError::UnknownLength); + match l.to_str() { + Ok(s) => match s.parse::() { + Ok(l) if l > DEFAULT_CONFIG_LIMIT => { + err = Some(PayloadError::Overflow) + } + Ok(l) => length = Some(l), + Err(_) => err = Some(PayloadError::UnknownLength), + }, + Err(_) => err = Some(PayloadError::UnknownLength), } } #[cfg(feature = "compress")] - let stream = Some(dev::Decompress::from_headers(payload.take(), req.headers())); + let stream = dev::Decompress::from_headers(payload.take(), req.headers()); #[cfg(not(feature = "compress"))] - let stream = Some(payload.take()); + let stream = payload.take(); HttpMessageBody { stream, - limit: 262_144, - length: len, - fut: None, - err: None, + limit: DEFAULT_CONFIG_LIMIT, + length, + buf: BytesMut::with_capacity(8192), + err, } } /// Change max size of payload. By default max size is 256Kb pub fn limit(mut self, limit: usize) -> Self { + if let Some(l) = self.length { + if l > limit { + self.err = Some(PayloadError::Overflow); + } + } self.limit = limit; self } - - fn err(e: PayloadError) -> Self { - HttpMessageBody { - stream: None, - limit: 262_144, - fut: None, - err: Some(e), - length: None, - } - } } impl Future for HttpMessageBody { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(ref mut fut) = self.fut { - return Pin::new(fut).poll(cx); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if let Some(e) = this.err.take() { + return Poll::Ready(Err(e)); } - if let Some(err) = self.err.take() { - return Poll::Ready(Err(err)); - } - - if let Some(len) = self.length.take() { - if len > self.limit { - return Poll::Ready(Err(PayloadError::Overflow)); - } - } - - // future - let limit = self.limit; - let mut stream = self.stream.take().unwrap(); - self.fut = Some( - async move { - let mut body = BytesMut::with_capacity(8192); - - while let Some(item) = stream.next().await { - let chunk = item?; - if body.len() + chunk.len() > limit { - return Err(PayloadError::Overflow); + loop { + let res = ready!(Pin::new(&mut this.stream).poll_next(cx)); + match res { + Some(chunk) => { + let chunk = chunk?; + if this.buf.len() + chunk.len() > this.limit { + return Poll::Ready(Err(PayloadError::Overflow)); } else { - body.extend_from_slice(&chunk); + this.buf.extend_from_slice(&chunk); } } - Ok(body.freeze()) + None => return Poll::Ready(Ok(this.buf.split().freeze())), } - .boxed_local(), - ); - self.poll(cx) + } } } diff --git a/src/types/query.rs b/src/types/query.rs index 7eded49c5..27df220fc 100644 --- a/src/types/query.rs +++ b/src/types/query.rs @@ -18,7 +18,7 @@ use crate::request::HttpRequest; /// be decoded into any type which depends upon data ordering e.g. tuples or tuple-structs. /// Attempts to do so will *fail at runtime*. /// -/// [**QueryConfig**](struct.QueryConfig.html) allows to configure extraction process. +/// [**QueryConfig**](QueryConfig) allows to configure extraction process. /// /// ## Example /// diff --git a/tests/test_server.rs b/tests/test_server.rs index f8a9ab86d..c6c316f0d 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -248,6 +248,7 @@ async fn test_body_gzip_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(70_000) + .map(char::from) .collect::(); let srv_data = data.clone(); @@ -529,6 +530,7 @@ async fn test_reading_gzip_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(60_000) + .map(char::from) .collect::(); let srv = test::start_with(test::config().h1(), || { @@ -614,6 +616,7 @@ async fn test_reading_deflate_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(160_000) + .map(char::from) .collect::(); let srv = test::start_with(test::config().h1(), || { @@ -672,6 +675,7 @@ async fn test_brotli_encoding_large() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(320_000) + .map(char::from) .collect::(); let srv = test::start_with(test::config().h1(), || { @@ -753,6 +757,7 @@ async fn test_reading_deflate_encoding_large_random_rustls() { let data = rand::thread_rng() .sample_iter(&Alphanumeric) .take(160_000) + .map(char::from) .collect::(); // load ssl keys