diff --git a/src/application.rs b/src/application.rs index 1e4d8273..8cf5db26 100644 --- a/src/application.rs +++ b/src/application.rs @@ -43,8 +43,8 @@ impl<S: 'static> PipelineHandler<S> for Inner<S> { path.split_at(prefix.len()).1.starts_with('/')) }; if m { - let path: &'static str = unsafe{ - mem::transmute(&req.path()[self.prefix+prefix.len()..])}; + let path: &'static str = unsafe { + mem::transmute(&req.path()[self.prefix+prefix.len()..]) }; if path.is_empty() { req.match_info_mut().add("tail", ""); } else { @@ -321,9 +321,7 @@ impl<S> Application<S> where S: 'static { } /// Register a middleware - pub fn middleware<T>(mut self, mw: T) -> Application<S> - where T: Middleware<S> + 'static - { + pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> Application<S> { self.parts.as_mut().expect("Use after finish") .middlewares.push(Box::new(mw)); self diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index b9798c97..70f5712e 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -46,7 +46,7 @@ pub enum Finished { /// Middleware definition #[allow(unused_variables)] -pub trait Middleware<S> { +pub trait Middleware<S>: 'static { /// Method is called when request is ready. It may return /// future, which should resolve before next middleware get called. diff --git a/src/pipeline.rs b/src/pipeline.rs index 44c50310..9873958b 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -74,7 +74,7 @@ impl<S> PipelineInfo<S> { } } -impl<S, H: PipelineHandler<S>> Pipeline<S, H> { +impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> { pub fn new(req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, @@ -101,7 +101,7 @@ impl Pipeline<(), Inner<()>> { } } -impl<S, H> Pipeline<S, H> { +impl<S: 'static, H> Pipeline<S, H> { fn is_done(&self) -> bool { match self.1 { @@ -114,7 +114,7 @@ impl<S, H> Pipeline<S, H> { } } -impl<S, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> { +impl<S: 'static, H: PipelineHandler<S>> HttpHandlerTask for Pipeline<S, H> { fn disconnected(&mut self) { self.0.disconnected = Some(true); @@ -277,7 +277,7 @@ struct StartMiddlewares<S, H> { _s: PhantomData<S>, } -impl<S, H: PipelineHandler<S>> StartMiddlewares<S, H> { +impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> { fn init(info: &mut PipelineInfo<S>, handler: Rc<RefCell<H>>) -> PipelineState<S, H> { @@ -364,7 +364,7 @@ struct WaitingResponse<S, H> { _h: PhantomData<H>, } -impl<S, H> WaitingResponse<S, H> { +impl<S: 'static, H> WaitingResponse<S, H> { #[inline] fn init(info: &mut PipelineInfo<S>, reply: Reply) -> PipelineState<S, H> @@ -399,7 +399,7 @@ struct RunMiddlewares<S, H> { _h: PhantomData<H>, } -impl<S, H> RunMiddlewares<S, H> { +impl<S: 'static, H> RunMiddlewares<S, H> { fn init(info: &mut PipelineInfo<S>, mut resp: HttpResponse) -> PipelineState<S, H> { @@ -510,7 +510,7 @@ enum IOState { Done, } -impl<S, H> ProcessResponse<S, H> { +impl<S: 'static, H> ProcessResponse<S, H> { #[inline] fn init(resp: HttpResponse) -> PipelineState<S, H> @@ -550,19 +550,6 @@ impl<S, H> ProcessResponse<S, H> { result }, IOState::Payload(mut body) => { - // always poll context - if self.running == RunningState::Running { - match info.poll_context() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => - self.running = RunningState::Done, - Err(err) => { - info.error = Some(err); - return Ok(FinishingMiddlewares::init(info, self.resp)) - } - } - } - match body.poll() { Ok(Async::Ready(None)) => { self.iostate = IOState::Done; @@ -706,7 +693,7 @@ struct FinishingMiddlewares<S, H> { _h: PhantomData<H>, } -impl<S, H> FinishingMiddlewares<S, H> { +impl<S: 'static, H> FinishingMiddlewares<S, H> { fn init(info: &mut PipelineInfo<S>, resp: HttpResponse) -> PipelineState<S, H> { if info.count == 0 { diff --git a/src/resource.rs b/src/resource.rs index ee6d682e..c9e1251c 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -1,3 +1,4 @@ +use std::rc::Rc; use std::marker::PhantomData; use http::{Method, StatusCode}; @@ -6,6 +7,7 @@ use pred; use body::Body; use route::Route; use handler::{Reply, Handler, Responder}; +use middleware::Middleware; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -33,6 +35,7 @@ pub struct Resource<S=()> { name: String, state: PhantomData<S>, routes: Vec<Route<S>>, + middlewares: Rc<Vec<Box<Middleware<S>>>>, } impl<S> Default for Resource<S> { @@ -40,7 +43,8 @@ impl<S> Default for Resource<S> { Resource { name: String::new(), state: PhantomData, - routes: Vec::new() } + routes: Vec::new(), + middlewares: Rc::new(Vec::new()) } } } @@ -50,7 +54,8 @@ impl<S> Resource<S> { Resource { name: String::new(), state: PhantomData, - routes: Vec::new() } + routes: Vec::new(), + middlewares: Rc::new(Vec::new()) } } /// Set resource name @@ -126,12 +131,25 @@ impl<S: 'static> Resource<S> { self.routes.last_mut().unwrap().f(handler) } - pub(crate) fn handle(&mut self, mut req: HttpRequest<S>, default: Option<&mut Resource<S>>) - -> Reply + /// Register a middleware + /// + /// This is similar to `Application's` middlewares, but + /// middlewares get invoked on resource level. + pub fn middleware<M: Middleware<S>>(&mut self, mw: M) { + Rc::get_mut(&mut self.middlewares).unwrap().push(Box::new(mw)); + } + + pub(crate) fn handle(&mut self, + mut req: HttpRequest<S>, + default: Option<&mut Resource<S>>) -> Reply { for route in &mut self.routes { if route.check(&mut req) { - return route.handle(req) + return if self.middlewares.is_empty() { + route.handle(req) + } else { + route.compose(req, Rc::clone(&self.middlewares)) + }; } } if let Some(resource) = default { diff --git a/src/route.rs b/src/route.rs index 64b60603..acef0fd4 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,10 +1,16 @@ -use futures::Future; +use std::mem; +use std::rc::Rc; +use std::marker::PhantomData; +use futures::{Async, Future, Poll}; use error::Error; use pred::Predicate; -use handler::{Reply, Handler, Responder, RouteHandler, AsyncHandler, WrapHandler}; +use handler::{Reply, ReplyItem, Handler, + Responder, RouteHandler, AsyncHandler, WrapHandler}; +use middleware::{Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted}; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; +use httpresponse::HttpResponse; /// Resource route definition /// @@ -12,7 +18,7 @@ use httprequest::HttpRequest; /// If handler is not explicitly set, default *404 Not Found* handler is used. pub struct Route<S> { preds: Vec<Box<Predicate<S>>>, - handler: Box<RouteHandler<S>>, + handler: InnerHandler<S>, } impl<S: 'static> Default for Route<S> { @@ -20,13 +26,14 @@ impl<S: 'static> Default for Route<S> { fn default() -> Route<S> { Route { preds: Vec::new(), - handler: Box::new(WrapHandler::new(|_| HTTPNotFound)), + handler: InnerHandler::new(|_| HTTPNotFound), } } } impl<S: 'static> Route<S> { + #[inline] pub(crate) fn check(&self, req: &mut HttpRequest<S>) -> bool { for pred in &self.preds { if !pred.check(req) { @@ -36,10 +43,18 @@ impl<S: 'static> Route<S> { true } + #[inline] pub(crate) fn handle(&mut self, req: HttpRequest<S>) -> Reply { self.handler.handle(req) } + #[inline] + pub(crate) fn compose(&mut self, + req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>) -> Reply { + Reply::async(Compose::new(req, mws, self.handler.clone())) + } + /// Add match predicate to route. /// /// ```rust @@ -65,7 +80,7 @@ impl<S: 'static> Route<S> { /// Set handler object. Usually call to this method is last call /// during route configuration, because it does not return reference to self. pub fn h<H: Handler<S>>(&mut self, handler: H) { - self.handler = Box::new(WrapHandler::new(handler)); + self.handler = InnerHandler::new(handler); } /// Set handler function. Usually call to this method is last call @@ -74,7 +89,7 @@ impl<S: 'static> Route<S> { where F: Fn(HttpRequest<S>) -> R + 'static, R: Responder + 'static, { - self.handler = Box::new(WrapHandler::new(handler)); + self.handler = InnerHandler::new(handler); } /// Set async handler function. @@ -84,6 +99,315 @@ impl<S: 'static> Route<S> { R: Responder + 'static, E: Into<Error> + 'static { - self.handler = Box::new(AsyncHandler::new(handler)); + self.handler = InnerHandler::async(handler); + } +} + +/// RouteHandler wrapper. This struct is required because it needs to be shared +/// for resource level middlewares. +struct InnerHandler<S>(Rc<Box<RouteHandler<S>>>); + +impl<S: 'static> InnerHandler<S> { + + #[inline] + fn new<H: Handler<S>>(h: H) -> Self { + InnerHandler(Rc::new(Box::new(WrapHandler::new(h)))) + } + + #[inline] + fn async<H, R, F, E>(h: H) -> Self + where H: Fn(HttpRequest<S>) -> F + 'static, + F: Future<Item=R, Error=E> + 'static, + R: Responder + 'static, + E: Into<Error> + 'static + { + InnerHandler(Rc::new(Box::new(AsyncHandler::new(h)))) + } + + #[inline] + pub fn handle(&self, req: HttpRequest<S>) -> Reply { + // reason: handler is unique per thread, + // handler get called from async code, and handler doesnt have side effects + #[allow(mutable_transmutes)] + #[cfg_attr(feature = "cargo-clippy", allow(borrowed_box))] + let h: &mut Box<RouteHandler<S>> = unsafe { mem::transmute(self.0.as_ref()) }; + h.handle(req) + } +} + +impl<S> Clone for InnerHandler<S> { + #[inline] + fn clone(&self) -> Self { + InnerHandler(Rc::clone(&self.0)) + } +} + + +/// Compose resource level middlewares with route handler. +struct Compose<S: 'static> { + info: ComposeInfo<S>, + state: ComposeState<S>, +} + +struct ComposeInfo<S: 'static> { + count: usize, + req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>, + handler: InnerHandler<S>, +} + +enum ComposeState<S: 'static> { + Starting(StartMiddlewares<S>), + Handler(WaitingResponse<S>), + RunMiddlewares(RunMiddlewares<S>), + Response(Response<S>), +} + +impl<S: 'static> ComposeState<S> { + fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { + match *self { + ComposeState::Starting(ref mut state) => state.poll(info), + ComposeState::Handler(ref mut state) => state.poll(info), + ComposeState::RunMiddlewares(ref mut state) => state.poll(info), + ComposeState::Response(_) => None, + } + } +} + +impl<S: 'static> Compose<S> { + fn new(req: HttpRequest<S>, + mws: Rc<Vec<Box<Middleware<S>>>>, + handler: InnerHandler<S>) -> Self + { + let mut info = ComposeInfo { + count: 0, + req: req, + mws: mws, + handler: handler }; + let state = StartMiddlewares::init(&mut info); + + Compose {state: state, info: info} + } +} + +impl<S> Future for Compose<S> { + type Item = HttpResponse; + type Error = Error; + + fn poll(&mut self) -> Poll<Self::Item, Self::Error> { + loop { + if let ComposeState::Response(ref mut resp) = self.state { + let resp = resp.resp.take().unwrap(); + return Ok(Async::Ready(resp)) + } + if let Some(state) = self.state.poll(&mut self.info) { + self.state = state; + } else { + return Ok(Async::NotReady) + } + } + } +} + +/// Middlewares start executor +struct StartMiddlewares<S> { + fut: Option<Fut>, + _s: PhantomData<S>, +} + +type Fut = Box<Future<Item=Option<HttpResponse>, Error=Error>>; + +impl<S: 'static> StartMiddlewares<S> { + + fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> { + let len = info.mws.len(); + loop { + if info.count == len { + let reply = info.handler.handle(info.req.clone()); + return WaitingResponse::init(info, reply) + } else { + match info.mws[info.count].start(&mut info.req) { + MiddlewareStarted::Done => + info.count += 1, + MiddlewareStarted::Response(resp) => + return RunMiddlewares::init(info, resp), + MiddlewareStarted::Future(mut fut) => + match fut.poll() { + Ok(Async::NotReady) => + return ComposeState::Starting(StartMiddlewares { + fut: Some(fut), + _s: PhantomData}), + Ok(Async::Ready(resp)) => { + if let Some(resp) = resp { + return RunMiddlewares::init(info, resp); + } + info.count += 1; + } + Err(err) => + return Response::init(err.into()), + }, + MiddlewareStarted::Err(err) => + return Response::init(err.into()), + } + } + } + } + + fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> + { + let len = info.mws.len(); + 'outer: loop { + match self.fut.as_mut().unwrap().poll() { + Ok(Async::NotReady) => + return None, + Ok(Async::Ready(resp)) => { + info.count += 1; + if let Some(resp) = resp { + return Some(RunMiddlewares::init(info, resp)); + } + if info.count == len { + let reply = info.handler.handle(info.req.clone()); + return Some(WaitingResponse::init(info, reply)); + } else { + loop { + match info.mws[info.count].start(&mut info.req) { + MiddlewareStarted::Done => + info.count += 1, + MiddlewareStarted::Response(resp) => { + return Some(RunMiddlewares::init(info, resp)); + }, + MiddlewareStarted::Future(fut) => { + self.fut = Some(fut); + continue 'outer + }, + MiddlewareStarted::Err(err) => + return Some(Response::init(err.into())) + } + } + } + } + Err(err) => + return Some(Response::init(err.into())) + } + } + } +} + +// waiting for response +struct WaitingResponse<S> { + fut: Box<Future<Item=HttpResponse, Error=Error>>, + _s: PhantomData<S>, +} + +impl<S: 'static> WaitingResponse<S> { + + #[inline] + fn init(info: &mut ComposeInfo<S>, reply: Reply) -> ComposeState<S> { + match reply.into() { + ReplyItem::Message(resp) => + RunMiddlewares::init(info, resp), + ReplyItem::Future(fut) => + ComposeState::Handler( + WaitingResponse { fut: fut, _s: PhantomData }), + } + } + + fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { + match self.fut.poll() { + Ok(Async::NotReady) => None, + Ok(Async::Ready(response)) => + Some(RunMiddlewares::init(info, response)), + Err(err) => + Some(Response::init(err.into())), + } + } +} + + +/// Middlewares response executor +struct RunMiddlewares<S> { + curr: usize, + fut: Option<Box<Future<Item=HttpResponse, Error=Error>>>, + _s: PhantomData<S>, +} + +impl<S: 'static> RunMiddlewares<S> { + + fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> { + let mut curr = 0; + let len = info.mws.len(); + + loop { + resp = match info.mws[curr].response(&mut info.req, resp) { + MiddlewareResponse::Err(err) => { + info.count = curr + 1; + return Response::init(err.into()) + }, + MiddlewareResponse::Done(r) => { + curr += 1; + if curr == len { + return Response::init(r) + } else { + r + } + }, + MiddlewareResponse::Future(fut) => { + return ComposeState::RunMiddlewares( + RunMiddlewares { curr: curr, fut: Some(fut), _s: PhantomData }) + }, + }; + } + } + + fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> + { + let len = info.mws.len(); + + loop { + // poll latest fut + let mut resp = match self.fut.as_mut().unwrap().poll() { + Ok(Async::NotReady) => { + return None + } + Ok(Async::Ready(resp)) => { + self.curr += 1; + resp + } + Err(err) => + return Some(Response::init(err.into())), + }; + + loop { + if self.curr == len { + return Some(Response::init(resp)); + } else { + match info.mws[self.curr].response(&mut info.req, resp) { + MiddlewareResponse::Err(err) => + return Some(Response::init(err.into())), + MiddlewareResponse::Done(r) => { + self.curr += 1; + resp = r + }, + MiddlewareResponse::Future(fut) => { + self.fut = Some(fut); + break + }, + } + } + } + } + } +} + +struct Response<S> { + resp: Option<HttpResponse>, + _s: PhantomData<S>, +} + +impl<S: 'static> Response<S> { + + fn init(resp: HttpResponse) -> ComposeState<S> { + ComposeState::Response( + Response{resp: Some(resp), _s: PhantomData}) } } diff --git a/src/test.rs b/src/test.rs index 4f2433a9..22b09b29 100644 --- a/src/test.rs +++ b/src/test.rs @@ -192,6 +192,16 @@ impl<S: 'static> TestApp<S> { self.app = Some(self.app.take().unwrap().resource("/", |r| r.h(handler))); } + /// Register handler for "/" with resource middleware + pub fn handler2<H, M>(&mut self, handler: H, mw: M) + where H: Handler<S>, M: Middleware<S> + { + self.app = Some(self.app.take().unwrap() + .resource("/", |r| { + r.middleware(mw); + r.h(handler)})); + } + /// Register middleware pub fn middleware<T>(&mut self, mw: T) -> &mut TestApp<S> where T: Middleware<S> + 'static diff --git a/tests/test_server.rs b/tests/test_server.rs index 51919cd5..1399879b 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -135,3 +135,28 @@ fn test_middlewares() { assert_eq!(num2.load(Ordering::Relaxed), 1); assert_eq!(num3.load(Ordering::Relaxed), 1); } + + +#[test] +fn test_resource_middlewares() { + let num1 = Arc::new(AtomicUsize::new(0)); + let num2 = Arc::new(AtomicUsize::new(0)); + let num3 = Arc::new(AtomicUsize::new(0)); + + let act_num1 = Arc::clone(&num1); + let act_num2 = Arc::clone(&num2); + let act_num3 = Arc::clone(&num3); + + let srv = test::TestServer::new( + move |app| app.handler2( + httpcodes::HTTPOk, + MiddlewareTest{start: Arc::clone(&act_num1), + response: Arc::clone(&act_num2), + finish: Arc::clone(&act_num3)}) + ); + + assert!(reqwest::get(&srv.url("/")).unwrap().status().is_success()); + assert_eq!(num1.load(Ordering::Relaxed), 1); + assert_eq!(num2.load(Ordering::Relaxed), 1); + // assert_eq!(num3.load(Ordering::Relaxed), 1); +}