diff --git a/actix-http/src/body/body_stream.rs b/actix-http/src/body/body_stream.rs index aaf0dd8f6..ba32e56e1 100644 --- a/actix-http/src/body/body_stream.rs +++ b/actix-http/src/body/body_stream.rs @@ -65,11 +65,16 @@ where #[cfg(test)] mod tests { - use std::convert::Infallible; + use std::{convert::Infallible, time::Duration}; - use actix_rt::pin; + use actix_rt::{ + pin, + time::{sleep, Sleep}, + }; use actix_utils::future::poll_fn; - use futures_util::stream; + use derive_more::{Display, Error}; + use futures_core::ready; + use futures_util::{stream, FutureExt}; use super::*; use crate::body::to_bytes; @@ -109,4 +114,60 @@ mod tests { assert_eq!(to_bytes(body).await.ok(), Some(Bytes::from("12"))); } + #[derive(Debug, Display, Error)] + #[display(fmt = "stream error")] + struct StreamErr; + + #[actix_rt::test] + async fn stream_immediate_error() { + let body = BodyStream::new(stream::once(async { Err(StreamErr) })); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + } + + #[actix_rt::test] + async fn stream_delayed_error() { + let body = + BodyStream::new(stream::iter(vec![Ok(Bytes::from("1")), Err(StreamErr)])); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + + #[pin_project::pin_project(project = TimeDelayStreamProj)] + #[derive(Debug)] + enum TimeDelayStream { + Start, + Sleep(#[pin] Sleep), + Done, + } + + impl Stream for TimeDelayStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.as_mut().project(); + + match this { + TimeDelayStreamProj::Start => { + let sleep = sleep(Duration::from_millis(1)); + self.set(TimeDelayStream::Sleep(sleep)); + cx.waker().wake_by_ref(); + Poll::Pending + } + + TimeDelayStreamProj::Sleep(mut delay) => { + ready!(delay.poll_unpin(cx)); + self.set(TimeDelayStream::Done); + cx.waker().wake_by_ref(); + Poll::Pending + } + + TimeDelayStreamProj::Done => Poll::Ready(Some(Err(StreamErr))), + } + } + } + + let body = BodyStream::new(TimeDelayStream::Start); + assert!(matches!(to_bytes(body).await, Err(StreamErr))); + } }