diff --git a/actix-web/src/middleware/err_handlers.rs b/actix-web/src/middleware/err_handlers.rs index c44b5bcb1..03103ff5a 100644 --- a/actix-web/src/middleware/err_handlers.rs +++ b/actix-web/src/middleware/err_handlers.rs @@ -119,7 +119,6 @@ impl ErrorHandlers { self } - /// Register a default error handler. /// /// Any request with a status code that hasn't been given a specific other handler (by calling @@ -141,7 +140,10 @@ impl ErrorHandlers { where F: Fn(ServiceResponse) -> Result> + 'static, { - Self { default_client: Some(Rc::new(handler)), ..self } + Self { + default_client: Some(Rc::new(handler)), + ..self + } } /// Register a handler on which to fall back for server error status codes (500-599). @@ -149,21 +151,26 @@ impl ErrorHandlers { where F: Fn(ServiceResponse) -> Result> + 'static, { - Self { default_server: Some(Rc::new(handler)), ..self } + Self { + default_server: Some(Rc::new(handler)), + ..self + } } - /// Selects the most appropriate handler for the given status code. /// /// If the `handlers` map has an entry for that status code, that handler is returned. /// Otherwise, fall back on the appropriate default handler. fn get_handler<'a>( status: &StatusCode, - default_client: Option<&'a dyn Fn(ServiceResponse) -> Result>>, - default_server: Option<&'a dyn Fn(ServiceResponse) -> Result>>, + default_client: Option< + &'a dyn Fn(ServiceResponse) -> Result>, + >, + default_server: Option< + &'a dyn Fn(ServiceResponse) -> Result>, + >, handlers: &'a Handlers, - ) -> Option<&'a dyn Fn(ServiceResponse) -> Result>> - { + ) -> Option<&'a dyn Fn(ServiceResponse) -> Result>> { handlers .get(status) .map(|h| h.as_ref()) @@ -188,12 +195,14 @@ where let handlers = self.handlers.clone(); let default_client = self.default_client.clone(); let default_server = self.default_server.clone(); - Box::pin(async move { Ok(ErrorHandlersMiddleware { - service, - default_client, - default_server, - handlers - }) }) + Box::pin(async move { + Ok(ErrorHandlersMiddleware { + service, + default_client, + default_server, + handlers, + }) + }) } } @@ -222,7 +231,12 @@ where let default_client = self.default_client.clone(); let default_server = self.default_server.clone(); let fut = self.service.call(req); - ErrorHandlersFuture::ServiceFuture { fut, default_client, default_server, handlers } + ErrorHandlersFuture::ServiceFuture { + fut, + default_client, + default_server, + handlers, + } } } @@ -253,7 +267,12 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { - ErrorHandlersProj::ServiceFuture { fut, default_client, default_server, handlers } => { + ErrorHandlersProj::ServiceFuture { + fut, + default_client, + default_server, + handlers, + } => { let res = ready!(fut.poll(cx))?; let status = res.status(); @@ -261,7 +280,7 @@ where &status, default_client.as_mut().map(|f| Rc::as_ref(f)), default_server.as_mut().map(|f| Rc::as_ref(f)), - handlers + handlers, ); match handler { Some(handler) => match handler(res)? { @@ -272,7 +291,7 @@ where self.poll(cx) } - } + }, None => Poll::Ready(Ok(res.map_into_left_body())), } } @@ -425,17 +444,21 @@ mod tests { let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await; let mw_client = make_mw(StatusCode::BAD_REQUEST).await; - let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); - let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[actix_rt::test] async fn default_handlers_separate_client_server() { #[allow(clippy::unnecessary_wraps)] - fn error_handler_client(mut res: ServiceResponse) -> Result> { + fn error_handler_client( + mut res: ServiceResponse, + ) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); @@ -443,7 +466,9 @@ mod tests { } #[allow(clippy::unnecessary_wraps)] - fn error_handler_server(mut res: ServiceResponse) -> Result> { + fn error_handler_server( + mut res: ServiceResponse, + ) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0002")); @@ -461,17 +486,21 @@ mod tests { let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await; let mw_client = make_mw(StatusCode::BAD_REQUEST).await; - let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); - let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); } #[actix_rt::test] async fn default_handlers_specialization() { #[allow(clippy::unnecessary_wraps)] - fn error_handler_client(mut res: ServiceResponse) -> Result> { + fn error_handler_client( + mut res: ServiceResponse, + ) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); @@ -479,7 +508,9 @@ mod tests { } #[allow(clippy::unnecessary_wraps)] - fn error_handler_specific(mut res: ServiceResponse) -> Result> { + fn error_handler_specific( + mut res: ServiceResponse, + ) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0003")); @@ -497,10 +528,12 @@ mod tests { let mw_client = make_mw(StatusCode::BAD_REQUEST).await; let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await; - let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); - let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await; + let resp = + test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003"); } }