From ea379d532a1b3fab8765ca7bc0c5df7299d9fa1f Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Wed, 21 Apr 2021 17:32:08 +0100 Subject: [PATCH] introduce encoder error type --- actix-http/src/encoding/encoder.rs | 74 +++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/actix-http/src/encoding/encoder.rs b/actix-http/src/encoding/encoder.rs index c41f79242..addd6b7a8 100644 --- a/actix-http/src/encoding/encoder.rs +++ b/actix-http/src/encoding/encoder.rs @@ -1,6 +1,7 @@ //! Stream encoders. use std::{ + error::Error as StdError, future::Future, io::{self, Write as _}, pin::Pin, @@ -10,6 +11,7 @@ use std::{ use actix_rt::task::{spawn_blocking, JoinHandle}; use brotli2::write::BrotliEncoder; use bytes::Bytes; +use derive_more::Display; use flate2::write::{GzEncoder, ZlibEncoder}; use futures_core::ready; use pin_project::pin_project; @@ -95,8 +97,12 @@ enum EncoderBody { BoxedStream(BoxAnyBody), } -impl> MessageBody for EncoderBody { - type Error = Error; +impl MessageBody for EncoderBody +where + B: MessageBody, + B::Error: Into, +{ + type Error = EncoderError; fn size(&self) -> BodySize { match self { @@ -109,7 +115,7 @@ impl> MessageBody for EncoderBody { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match self.project() { EncoderBodyProj::Bytes(b) => { if b.is_empty() { @@ -118,10 +124,17 @@ impl> MessageBody for EncoderBody { Poll::Ready(Some(Ok(std::mem::take(b)))) } } - EncoderBodyProj::Stream(b) => b.poll_next(cx), + // TODO: MSRV 1.51: poll_map_err + EncoderBodyProj::Stream(b) => match ready!(b.poll_next(cx)) { + Some(Err(err)) => Poll::Ready(Some(Err(EncoderError::Body(err)))), + Some(Ok(val)) => Poll::Ready(Some(Ok(val))), + None => Poll::Ready(None), + }, EncoderBodyProj::BoxedStream(ref mut b) => { match ready!(b.as_mut().poll_next(cx)) { - Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), + Some(Err(err)) => { + Poll::Ready(Some(Err(EncoderError::Boxed(err.into())))) + } Some(Ok(val)) => Poll::Ready(Some(Ok(val))), None => Poll::Ready(None), } @@ -130,8 +143,12 @@ impl> MessageBody for EncoderBody { } } -impl> MessageBody for Encoder { - type Error = Error; +impl MessageBody for Encoder +where + B: MessageBody, + B::Error: Into, +{ + type Error = EncoderError; fn size(&self) -> BodySize { if self.encoder.is_none() { @@ -144,7 +161,7 @@ impl> MessageBody for Encoder { fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let mut this = self.project(); loop { if *this.eof { @@ -152,8 +169,9 @@ impl> MessageBody for Encoder { } if let Some(ref mut fut) = this.fut { - let mut encoder = - ready!(Pin::new(fut).poll(cx)).map_err(|_| BlockingError)??; + let mut encoder = ready!(Pin::new(fut).poll(cx)) + .map_err(|_| EncoderError::Blocking(BlockingError))? + .map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -172,7 +190,7 @@ impl> MessageBody for Encoder { Some(Ok(chunk)) => { if let Some(mut encoder) = this.encoder.take() { if chunk.len() < MAX_CHUNK_SIZE_ENCODE_IN_PLACE { - encoder.write(&chunk)?; + encoder.write(&chunk).map_err(EncoderError::Io)?; let chunk = encoder.take(); *this.encoder = Some(encoder); @@ -192,7 +210,7 @@ impl> MessageBody for Encoder { None => { if let Some(encoder) = this.encoder.take() { - let chunk = encoder.finish()?; + let chunk = encoder.finish().map_err(EncoderError::Io)?; if chunk.is_empty() { return Poll::Ready(None); } else { @@ -291,3 +309,35 @@ impl ContentEncoder { } } } + +#[derive(Debug, Display)] +pub enum EncoderError { + #[display(fmt = "body")] + Body(E), + + #[display(fmt = "boxed")] + Boxed(Error), + + #[display(fmt = "blocking")] + Blocking(BlockingError), + + #[display(fmt = "io")] + Io(io::Error), +} + +impl StdError for EncoderError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + None + } +} + +impl> From> for Error { + fn from(err: EncoderError) -> Self { + match err { + EncoderError::Body(err) => err.into(), + EncoderError::Boxed(err) => err, + EncoderError::Blocking(err) => err.into(), + EncoderError::Io(err) => err.into(), + } + } +}