implemented default handlers for error middleware and associated unit tests

This commit is contained in:
erhodes 2022-06-15 13:09:21 -06:00
parent 062127a210
commit 32b3df3f26
1 changed files with 170 additions and 6 deletions

View File

@ -30,6 +30,8 @@ pub enum ErrorHandlerResponse<B> {
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>; type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
type DefaultHandler<B> = Option<Rc<ErrorHandler<B>>>;
/// Middleware for registering custom status code based error handlers. /// Middleware for registering custom status code based error handlers.
/// ///
/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler /// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler
@ -54,6 +56,8 @@ type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse
/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError))); /// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
/// ``` /// ```
pub struct ErrorHandlers<B> { pub struct ErrorHandlers<B> {
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>, handlers: Handlers<B>,
} }
@ -62,6 +66,8 @@ type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
impl<B> Default for ErrorHandlers<B> { impl<B> Default for ErrorHandlers<B> {
fn default() -> Self { fn default() -> Self {
ErrorHandlers { ErrorHandlers {
default_client: Default::default(),
default_server: Default::default(),
handlers: Default::default(), handlers: Default::default(),
} }
} }
@ -83,6 +89,46 @@ impl<B> ErrorHandlers<B> {
.insert(status, Box::new(handler)); .insert(status, Box::new(handler));
self self
} }
pub fn default_handler<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
let handler = Rc::new(handler);
Self {
default_server: Some(handler.clone()),
default_client: Some(handler),
..self
}
}
pub fn default_handler_server<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
Self { default_server: Some(Rc::new(handler)), ..self }
}
pub fn default_handler_client<F>(self, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
Self { default_client: Some(Rc::new(handler)), ..self }
}
fn get_handler<'a>(
status: &StatusCode,
default_client: Option<&'a dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>>,
default_server: Option<&'a dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>>,
handlers: &'a Handlers<B>,
) -> Option<&'a dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>>
{
handlers
.get(status)
.map(|h| h.as_ref())
.or_else(|| status.is_client_error().then(|| default_client).flatten())
.or_else(|| status.is_server_error().then(|| default_server).flatten())
}
} }
impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B> impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
@ -99,13 +145,22 @@ where
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
let handlers = self.handlers.clone(); let handlers = self.handlers.clone();
Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) }) let default_client = self.default_client.clone();
let default_server = self.default_server.clone();
Box::pin(async move { Ok(ErrorHandlersMiddleware {
service,
default_client,
default_server,
handlers
}) })
} }
} }
#[doc(hidden)] #[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> { pub struct ErrorHandlersMiddleware<S, B> {
service: S, service: S,
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>, handlers: Handlers<B>,
} }
@ -123,8 +178,10 @@ where
fn call(&self, req: ServiceRequest) -> Self::Future { fn call(&self, req: ServiceRequest) -> Self::Future {
let handlers = self.handlers.clone(); let handlers = self.handlers.clone();
let default_client = self.default_client.clone();
let default_server = self.default_server.clone();
let fut = self.service.call(req); let fut = self.service.call(req);
ErrorHandlersFuture::ServiceFuture { fut, handlers } ErrorHandlersFuture::ServiceFuture { fut, default_client, default_server, handlers }
} }
} }
@ -137,6 +194,8 @@ pin_project! {
ServiceFuture { ServiceFuture {
#[pin] #[pin]
fut: Fut, fut: Fut,
default_client: DefaultHandler<B>,
default_server: DefaultHandler<B>,
handlers: Handlers<B>, handlers: Handlers<B>,
}, },
ErrorHandlerFuture { ErrorHandlerFuture {
@ -153,10 +212,17 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project() { match self.as_mut().project() {
ErrorHandlersProj::ServiceFuture { fut, handlers } => { ErrorHandlersProj::ServiceFuture { fut, default_client, default_server, handlers } => {
let res = ready!(fut.poll(cx))?; let res = ready!(fut.poll(cx))?;
let status = res.status();
match handlers.get(&res.status()) { let handler = ErrorHandlers::get_handler(
&status,
default_client.as_mut().map(|f| Rc::as_ref(f)),
default_server.as_mut().map(|f| Rc::as_ref(f)),
handlers
);
match handler {
Some(handler) => match handler(res)? { Some(handler) => match handler(res)? {
ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
ErrorHandlerResponse::Future(fut) => { ErrorHandlerResponse::Future(fut) => {
@ -165,8 +231,7 @@ where
self.poll(cx) self.poll(cx)
} }
}, }
None => Poll::Ready(Ok(res.map_into_left_body())), None => Poll::Ready(Ok(res.map_into_left_body())),
} }
} }
@ -298,4 +363,103 @@ mod tests {
"error in error handler" "error in error handler"
); );
} }
#[actix_rt::test]
async fn default_error_handler() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler(error_handler)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
#[actix_rt::test]
async fn default_handlers_separate_client_server() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
#[allow(clippy::unnecessary_wraps)]
fn error_handler_server<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler_server(error_handler_server)
.default_handler_client(error_handler_client)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
}
#[actix_rt::test]
async fn default_handlers_specialization() {
#[allow(clippy::unnecessary_wraps)]
fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
#[allow(clippy::unnecessary_wraps)]
fn error_handler_specific<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
}
let make_mw = |status| async move {
ErrorHandlers::new()
.default_handler_client(error_handler_client)
.handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific)
.new_transform(test::status_service(status).into_service())
.await
.unwrap()
};
let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await;
let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
}
} }