diff --git a/CHANGES.md b/CHANGES.md index f450c70d..aa4d7c45 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,8 +1,8 @@ # Changes -## 0.6.6 (2018-05-xx) +## 0.6.6 (2018-05-16) -.. +* Panic during middleware execution #226 ## 0.6.5 (2018-05-15) diff --git a/src/pipeline.rs b/src/pipeline.rs index 4d5d405c..82ec45a7 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -284,12 +284,12 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let reply = unsafe { &mut *self.hnd.get() } - .handle(info.req().clone(), self.htype); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { + loop { + if info.count == len { + let reply = unsafe { &mut *self.hnd.get() } + .handle(info.req().clone(), self.htype); + return Some(WaitingResponse::init(info, reply)); + } else { match info.mws[info.count as usize].start(info.req_mut()) { Ok(Started::Done) => info.count += 1, Ok(Started::Response(resp)) => { diff --git a/src/route.rs b/src/route.rs index 1322d108..b109fd60 100644 --- a/src/route.rs +++ b/src/route.rs @@ -5,13 +5,17 @@ use std::rc::Rc; use futures::{Async, Future, Poll}; use error::Error; -use handler::{AsyncHandler, AsyncResult, AsyncResultItem, FromRequest, Handler, - Responder, RouteHandler, WrapHandler}; +use handler::{ + AsyncHandler, AsyncResult, AsyncResultItem, FromRequest, Handler, Responder, + RouteHandler, WrapHandler, +}; use http::StatusCode; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use middleware::{Finished as MiddlewareFinished, Middleware, - Response as MiddlewareResponse, Started as MiddlewareStarted}; +use middleware::{ + Finished as MiddlewareFinished, Middleware, Response as MiddlewareResponse, + Started as MiddlewareStarted, +}; use pred::Predicate; use with::{ExtractorConfig, With, With2, With3, WithAsync}; @@ -51,7 +55,9 @@ impl<S: 'static> Route<S> { #[inline] pub(crate) fn compose( - &mut self, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, + &mut self, + req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>, ) -> AsyncResult<HttpResponse> { AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone()))) } @@ -242,7 +248,8 @@ impl<S: 'static> Route<S> { /// } /// ``` pub fn with2<T1, T2, F, R>( - &mut self, handler: F, + &mut self, + handler: F, ) -> (ExtractorConfig<S, T1>, ExtractorConfig<S, T2>) where F: Fn(T1, T2) -> R + 'static, @@ -263,7 +270,8 @@ impl<S: 'static> Route<S> { #[doc(hidden)] /// Set handler function, use request extractor for all parameters. pub fn with3<T1, T2, T3, F, R>( - &mut self, handler: F, + &mut self, + handler: F, ) -> ( ExtractorConfig<S, T1>, ExtractorConfig<S, T2>, @@ -296,9 +304,7 @@ struct InnerHandler<S>(Rc<UnsafeCell<Box<RouteHandler<S>>>>); impl<S: 'static> InnerHandler<S> { #[inline] fn new<H: Handler<S>>(h: H) -> Self { - InnerHandler(Rc::new(UnsafeCell::new(Box::new(WrapHandler::new( - h, - ))))) + InnerHandler(Rc::new(UnsafeCell::new(Box::new(WrapHandler::new(h))))) } #[inline] @@ -309,9 +315,7 @@ impl<S: 'static> InnerHandler<S> { R: Responder + 'static, E: Into<Error> + 'static, { - InnerHandler(Rc::new(UnsafeCell::new(Box::new(AsyncHandler::new( - h, - ))))) + InnerHandler(Rc::new(UnsafeCell::new(Box::new(AsyncHandler::new(h))))) } #[inline] @@ -364,7 +368,9 @@ impl<S: 'static> ComposeState<S> { impl<S: 'static> Compose<S> { fn new( - req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, handler: InnerHandler<S>, + req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>, + handler: InnerHandler<S>, ) -> Self { let mut info = ComposeInfo { count: 0, @@ -440,11 +446,11 @@ impl<S: 'static> StartMiddlewares<S> { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let reply = info.handler.handle(info.req.clone()); - return Some(WaitingResponse::init(info, reply)); - } else { - loop { + loop { + if info.count == len { + let reply = info.handler.handle(info.req.clone()); + return Some(WaitingResponse::init(info, reply)); + } else { match info.mws[info.count].start(&mut info.req) { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -479,7 +485,8 @@ struct WaitingResponse<S> { impl<S: 'static> WaitingResponse<S> { #[inline] fn init( - info: &mut ComposeInfo<S>, reply: AsyncResult<HttpResponse>, + info: &mut ComposeInfo<S>, + reply: AsyncResult<HttpResponse>, ) -> ComposeState<S> { match reply.into() { AsyncResultItem::Err(err) => RunMiddlewares::init(info, err.into()), diff --git a/src/scope.rs b/src/scope.rs index aecfd6bf..00bcadad 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -10,8 +10,10 @@ use handler::{AsyncResult, AsyncResultItem, FromRequest, Responder, RouteHandler use http::Method; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use middleware::{Finished as MiddlewareFinished, Middleware, - Response as MiddlewareResponse, Started as MiddlewareStarted}; +use middleware::{ + Finished as MiddlewareFinished, Middleware, Response as MiddlewareResponse, + Started as MiddlewareStarted, +}; use pred::Predicate; use resource::ResourceHandler; use router::Resource; @@ -400,8 +402,7 @@ struct Wrapper<S: 'static> { impl<S: 'static, S2: 'static> RouteHandler<S2> for Wrapper<S> { fn handle(&mut self, req: HttpRequest<S2>) -> AsyncResult<HttpResponse> { - self.scope - .handle(req.change_state(Rc::clone(&self.state))) + self.scope.handle(req.change_state(Rc::clone(&self.state))) } } @@ -458,7 +459,8 @@ impl<S: 'static> ComposeState<S> { impl<S: 'static> Compose<S> { fn new( - req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, + req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>, resource: Rc<UnsafeCell<ResourceHandler<S>>>, default: Option<Rc<UnsafeCell<ResourceHandler<S>>>>, ) -> Self { @@ -543,17 +545,17 @@ impl<S: 'static> StartMiddlewares<S> { if let Some(resp) = resp { return Some(RunMiddlewares::init(info, resp)); } - if info.count == len { - let resource = unsafe { &mut *info.resource.get() }; - let reply = if let Some(ref default) = info.default { - let d = unsafe { &mut *default.as_ref().get() }; - resource.handle(info.req.clone(), Some(d)) + loop { + if info.count == len { + let resource = unsafe { &mut *info.resource.get() }; + let reply = if let Some(ref default) = info.default { + let d = unsafe { &mut *default.as_ref().get() }; + resource.handle(info.req.clone(), Some(d)) + } else { + resource.handle(info.req.clone(), None) + }; + return Some(WaitingResponse::init(info, reply)); } else { - resource.handle(info.req.clone(), None) - }; - return Some(WaitingResponse::init(info, reply)); - } else { - loop { match info.mws[info.count].start(&mut info.req) { Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Response(resp)) => { @@ -583,7 +585,8 @@ struct WaitingResponse<S> { impl<S: 'static> WaitingResponse<S> { #[inline] fn init( - info: &mut ComposeInfo<S>, reply: AsyncResult<HttpResponse>, + info: &mut ComposeInfo<S>, + reply: AsyncResult<HttpResponse>, ) -> ComposeState<S> { match reply.into() { AsyncResultItem::Ok(resp) => RunMiddlewares::init(info, resp), diff --git a/tests/test_middleware.rs b/tests/test_middleware.rs index 99151afd..2c9160b6 100644 --- a/tests/test_middleware.rs +++ b/tests/test_middleware.rs @@ -21,28 +21,24 @@ struct MiddlewareTest { impl<S> middleware::Middleware<S> for MiddlewareTest { fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> { - self.start.store( - self.start.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.start + .store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Started::Done) } fn response( - &self, _: &mut HttpRequest<S>, resp: HttpResponse, + &self, + _: &mut HttpRequest<S>, + resp: HttpResponse, ) -> Result<middleware::Response> { - self.response.store( - self.response.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.response + .store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); Ok(middleware::Response::Done(resp)) } fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished { - self.finish.store( - self.finish.load(Ordering::Relaxed) + 1, - Ordering::Relaxed, - ); + self.finish + .store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); middleware::Finished::Done } } @@ -187,10 +183,7 @@ fn test_scope_middleware() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -226,10 +219,7 @@ fn test_scope_middleware_multiple() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -337,10 +327,7 @@ fn test_scope_middleware_async_handler() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -402,10 +389,7 @@ fn test_scope_middleware_async_error() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); @@ -466,7 +450,9 @@ impl<S> middleware::Middleware<S> for MiddlewareAsyncTest { } fn response( - &self, _: &mut HttpRequest<S>, resp: HttpResponse, + &self, + _: &mut HttpRequest<S>, + resp: HttpResponse, ) -> Result<middleware::Response> { let to = Timeout::new(Duration::from_millis(10), &Arbiter::handle()).unwrap(); @@ -555,6 +541,42 @@ fn test_async_middleware_multiple() { assert_eq!(num3.load(Ordering::Relaxed), 2); } +#[test] +fn test_async_sync_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + App::new() + .middleware(MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .middleware(MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(50)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +} + #[test] fn test_async_scope_middleware() { let num1 = Arc::new(AtomicUsize::new(0)); @@ -577,10 +599,7 @@ fn test_async_scope_middleware() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -618,10 +637,45 @@ fn test_async_scope_middleware_multiple() { }) }); - let request = srv.get() - .uri(srv.url("/scope/test")) - .finish() - .unwrap(); + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(20)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +} + +#[test] +fn test_async_async_scope_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + App::new().scope("/scope", |scope| { + scope + .middleware(MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .middleware(MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }) + .resource("/test", |r| r.f(|_| HttpResponse::Ok())) + }) + }); + + let request = srv.get().uri(srv.url("/scope/test")).finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -703,3 +757,42 @@ fn test_async_resource_middleware_multiple() { thread::sleep(Duration::from_millis(40)); assert_eq!(num3.load(Ordering::Relaxed), 2); } + +#[test] +fn test_async_sync_resource_middleware_multiple() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let mut srv = test::TestServer::with_factory(move || { + let mw1 = MiddlewareAsyncTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + let mw2 = MiddlewareTest { + start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3), + }; + App::new().resource("/test", move |r| { + r.middleware(mw1); + r.middleware(mw2); + r.h(|_| HttpResponse::Ok()); + }) + }); + + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + assert_eq!(num1.load(Ordering::Relaxed), 2); + assert_eq!(num2.load(Ordering::Relaxed), 2); + + thread::sleep(Duration::from_millis(40)); + assert_eq!(num3.load(Ordering::Relaxed), 2); +}