504 lines
17 KiB
Python
504 lines
17 KiB
Python
"""
|
|
Global error handling middleware for WiFi-DensePose API
|
|
"""
|
|
|
|
import json
import logging
import time
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from fastapi import Request, Response, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException

from src.config.settings import Settings
from src.logger import get_request_context
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ErrorResponse:
    """Standardized error response format.

    Holds everything needed to render one error as JSON: a symbolic error
    code, a human-readable message, optional structured details, the HTTP
    status code, the id of the request being served, and any extra HTTP
    headers that must accompany the response.
    """

    def __init__(
        self,
        error_code: str,
        message: str,
        details: Optional[Dict[str, Any]] = None,
        status_code: int = 500,
        request_id: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        """Create an error response.

        Args:
            error_code: Symbolic code placed in ``error.code``.
            message: Human-readable text placed in ``error.message``.
            details: Optional extra data; left out of the body when empty.
            status_code: HTTP status of the rendered response.
            request_id: Request identifier; when set it is added to the body
                and sent back in the ``X-Request-ID`` header.
            headers: Optional extra HTTP headers to send with the response
                (for example ``WWW-Authenticate`` on a 401).
        """
        self.error_code = error_code
        self.message = message
        self.details = details or {}
        self.status_code = status_code
        self.request_id = request_id
        # Copy so later mutation of the caller's mapping cannot leak in.
        self.headers: Dict[str, str] = dict(headers) if headers else {}
        # Same naive-UTC ISO-8601 text datetime.utcnow() produced, without
        # calling the deprecated API.
        self.timestamp = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON response.

        Returns:
            ``{"error": {...}}`` with ``code``, ``message`` and ``timestamp``
            always present; ``details`` and ``request_id`` are added only
            when they are non-empty.
        """
        response: Dict[str, Any] = {
            "error": {
                "code": self.error_code,
                "message": self.message,
                "timestamp": self.timestamp,
            }
        }

        if self.details:
            response["error"]["details"] = self.details

        if self.request_id:
            response["error"]["request_id"] = self.request_id

        return response

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers to send: extra headers plus X-Request-ID."""
        headers = dict(self.headers)
        if self.request_id:
            headers["X-Request-ID"] = self.request_id
        return headers

    def to_response(self) -> JSONResponse:
        """Convert to FastAPI JSONResponse."""
        return JSONResponse(
            status_code=self.status_code,
            content=self.to_dict(),
            headers=self._build_headers(),
        )


class ErrorHandler:
    """Central error handler for the application.

    Every ``handle_*`` method logs the failure (when ``log_errors`` is set)
    and returns an ``ErrorResponse`` describing it. None of them send a
    response themselves.
    """

    # HTTP status code -> symbolic error code used in response bodies.
    _STATUS_ERROR_CODES = {
        400: "BAD_REQUEST",
        401: "UNAUTHORIZED",
        403: "FORBIDDEN",
        404: "NOT_FOUND",
        405: "METHOD_NOT_ALLOWED",
        409: "CONFLICT",
        422: "UNPROCESSABLE_ENTITY",
        429: "TOO_MANY_REQUESTS",
        500: "INTERNAL_SERVER_ERROR",
        502: "BAD_GATEWAY",
        503: "SERVICE_UNAVAILABLE",
        504: "GATEWAY_TIMEOUT",
    }

    def __init__(self, settings: Settings):
        """Store settings and derive the traceback/logging switches.

        Args:
            settings: Application settings; ``debug``, ``is_development``
                and ``is_production`` are read from it.
        """
        self.settings = settings
        # Tracebacks are only put in responses when both flags are set.
        self.include_traceback = settings.debug and settings.is_development
        self.log_errors = True

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _current_request_id() -> Optional[str]:
        """Return the ``request_id`` entry of the current request context."""
        return get_request_context().get("request_id")

    @staticmethod
    def _format_traceback(exc: BaseException) -> List[str]:
        """Render ``exc`` and its traceback as a list of text lines."""
        return traceback.format_exception(type(exc), exc, exc.__traceback__)

    @staticmethod
    def _json_safe(value: Any) -> Any:
        """Return ``value`` unchanged if JSON can encode it, else its repr.

        Validation errors echo the offending input back to the client; raw
        bytes or other objects there would otherwise make rendering the
        error response itself fail.
        """
        try:
            json.dumps(value, allow_nan=False)
        except (TypeError, ValueError):
            return repr(value)
        return value

    @staticmethod
    def _format_validation_errors(
        errors: Iterable[Dict[str, Any]],
        include_input: bool = False,
    ) -> List[Dict[str, Any]]:
        """Flatten validation error entries into response-friendly dicts.

        Args:
            errors: Entries with ``loc``, ``msg`` and ``type`` keys, as
                returned by the exception's ``errors()`` method.
            include_input: When true, add an ``input`` key holding the
                offending value (``None`` when the entry has none).

        Returns:
            One dict per entry with ``field`` (the ``loc`` parts joined by
            dots), ``message`` and ``type``, plus ``input`` if requested.
        """
        formatted = []
        for error in errors:
            entry = {
                "field": ".".join(str(loc) for loc in error["loc"]),
                "message": error["msg"],
                "type": error["type"],
            }
            if include_input:
                entry["input"] = ErrorHandler._json_safe(error.get("input"))
            formatted.append(entry)
        return formatted

    def _resolve_error_code(self, exc: Exception) -> str:
        """Pick the error code for an HTTP exception.

        An ``error_code`` attribute set on the exception wins; otherwise the
        code is derived from ``exc.status_code``.
        """
        custom_code = getattr(exc, "error_code", None)
        return custom_code or self._get_error_code_for_status(exc.status_code)

    def _log_failure(
        self,
        label: str,
        request: Request,
        exc: Exception,
        request_id: Optional[str],
    ) -> None:
        """Log an unexpected failure at ERROR level with its traceback."""
        if not self.log_errors:
            return
        logger.error(
            "%s: %s: %s - %s %s - Request ID: %s",
            label,
            type(exc).__name__,
            exc,
            request.method,
            request.url.path,
            request_id,
            # Pass the exception itself so the traceback is logged even when
            # this runs outside an active ``except`` block.
            exc_info=exc,
        )

    # ------------------------------------------------------------------
    # Handlers
    # ------------------------------------------------------------------

    def handle_http_exception(self, request: Request, exc: HTTPException) -> ErrorResponse:
        """Handle HTTP exceptions.

        The response keeps the exception's status code and detail. Headers
        attached to the exception are sent as real response headers and, as
        before, also listed under ``details.headers``.
        """
        request_id = self._current_request_id()

        if self.log_errors:
            logger.warning(
                "HTTP %s: %s - %s %s - Request ID: %s",
                exc.status_code,
                exc.detail,
                request.method,
                request.url.path,
                request_id,
            )

        exc_headers = getattr(exc, "headers", None)
        header_map = dict(exc_headers) if exc_headers else None

        details: Dict[str, Any] = {}
        if header_map:
            details["headers"] = header_map

        if self.include_traceback:
            details["traceback"] = self._format_traceback(exc)

        return ErrorResponse(
            error_code=self._resolve_error_code(exc),
            message=str(exc.detail),
            details=details,
            status_code=exc.status_code,
            request_id=request_id,
            headers=header_map,
        )

    def handle_validation_error(self, request: Request, exc: RequestValidationError) -> ErrorResponse:
        """Handle request validation errors.

        Returns a 422 ``VALIDATION_ERROR`` response listing every failed
        field together with the offending input.
        """
        request_id = self._current_request_id()
        errors = exc.errors()

        if self.log_errors:
            logger.warning(
                "Validation error: %s - %s %s - Request ID: %s",
                errors,
                request.method,
                request.url.path,
                request_id,
            )

        validation_details = self._format_validation_errors(errors, include_input=True)
        details: Dict[str, Any] = {
            "validation_errors": validation_details,
            "error_count": len(validation_details),
        }

        if self.include_traceback:
            details["traceback"] = self._format_traceback(exc)

        return ErrorResponse(
            error_code="VALIDATION_ERROR",
            message="Request validation failed",
            details=details,
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            request_id=request_id,
        )

    def handle_pydantic_error(self, request: Request, exc: ValidationError) -> ErrorResponse:
        """Handle Pydantic validation errors.

        Returns a 400 ``DATA_VALIDATION_ERROR`` response listing every
        failed field (without echoing the input values).
        """
        request_id = self._current_request_id()
        errors = exc.errors()

        if self.log_errors:
            logger.warning(
                "Pydantic validation error: %s - %s %s - Request ID: %s",
                errors,
                request.method,
                request.url.path,
                request_id,
            )

        validation_details = self._format_validation_errors(errors)
        details: Dict[str, Any] = {
            "validation_errors": validation_details,
            "error_count": len(validation_details),
        }

        if self.include_traceback:
            details["traceback"] = self._format_traceback(exc)

        return ErrorResponse(
            error_code="DATA_VALIDATION_ERROR",
            message="Data validation failed",
            details=details,
            status_code=status.HTTP_400_BAD_REQUEST,
            request_id=request_id,
        )

    def handle_generic_exception(self, request: Request, exc: Exception) -> ErrorResponse:
        """Handle generic exceptions.

        Returns a 500 ``INTERNAL_SERVER_ERROR`` response. The exception
        text is only used as the message outside production.
        """
        request_id = self._current_request_id()
        self._log_failure("Unhandled exception", request, exc, request_id)

        details: Dict[str, Any] = {}
        if self.include_traceback:
            details["exception_type"] = type(exc).__name__
            details["traceback"] = self._format_traceback(exc)

        # Don't expose internal error details in production
        if self.settings.is_production:
            message = "An internal server error occurred"
        else:
            message = str(exc) or "An unexpected error occurred"

        return ErrorResponse(
            error_code="INTERNAL_SERVER_ERROR",
            message=message,
            details=details,
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            request_id=request_id,
        )

    def handle_database_error(self, request: Request, exc: Exception) -> ErrorResponse:
        """Handle database-related errors.

        Returns a 503 ``DATABASE_ERROR`` response. The exception text is
        only used as the message outside production.
        """
        request_id = self._current_request_id()
        self._log_failure("Database error", request, exc, request_id)

        details: Dict[str, Any] = {
            "exception_type": type(exc).__name__,
            "category": "database",
        }
        if self.include_traceback:
            details["traceback"] = self._format_traceback(exc)

        return ErrorResponse(
            error_code="DATABASE_ERROR",
            message="Database operation failed" if self.settings.is_production else str(exc),
            details=details,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            request_id=request_id,
        )

    def handle_external_service_error(self, request: Request, exc: Exception) -> ErrorResponse:
        """Handle external service errors.

        Returns a 502 ``EXTERNAL_SERVICE_ERROR`` response. The exception
        text is only used as the message outside production.
        """
        request_id = self._current_request_id()
        self._log_failure("External service error", request, exc, request_id)

        details: Dict[str, Any] = {
            "exception_type": type(exc).__name__,
            "category": "external_service",
        }
        if self.include_traceback:
            details["traceback"] = self._format_traceback(exc)

        return ErrorResponse(
            error_code="EXTERNAL_SERVICE_ERROR",
            message="External service unavailable" if self.settings.is_production else str(exc),
            details=details,
            status_code=status.HTTP_502_BAD_GATEWAY,
            request_id=request_id,
        )

    def _get_error_code_for_status(self, status_code: int) -> str:
        """Get error code for HTTP status code.

        Unmapped status codes yield ``"HTTP_ERROR"``.
        """
        return self._STATUS_ERROR_CODES.get(status_code, "HTTP_ERROR")
|
|
|
|
|
|
class ErrorHandlingMiddleware:
    """Error handling middleware for FastAPI.

    A plain ASGI middleware: it forwards HTTP requests to the wrapped app
    and, if the app raises, turns the exception into a JSON error response
    built by ``ErrorHandler``. Non-HTTP scopes are passed through untouched.
    """

    # Substrings looked for in the exception's module / class name.
    _DATABASE_LIBRARIES = ("sqlalchemy", "psycopg2", "pymongo", "redis")
    _DATABASE_ERROR_NAMES = ("ConnectionError", "OperationalError", "IntegrityError")
    _HTTP_CLIENT_LIBRARIES = ("requests", "httpx", "aiohttp", "urllib")
    _EXTERNAL_ERROR_NAMES = ("ConnectionError", "TimeoutError", "ConnectTimeout", "ReadTimeout")

    def __init__(self, app, settings: Settings):
        """Wrap ``app`` and create the ``ErrorHandler`` used for responses."""
        self.app = app
        self.settings = settings
        self.error_handler = ErrorHandler(settings)

    async def __call__(self, scope, receive, send):
        """Process request through error handling middleware."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Monotonic clock: immune to wall-clock adjustments while timing.
        start_time = time.perf_counter()
        response_started = False

        async def send_and_track(message):
            # Remember whether the wrapped app already began its response.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_and_track)
        except Exception as exc:
            request = Request(scope, receive)
            # Building the response also logs the failure.
            error_response = self._build_error_response(request, exc)

            if response_started:
                # Status and headers are already on the wire; a second
                # response would violate ASGI, so let the server handle it.
                raise

            response = error_response.to_response()
            await response(scope, receive, send)
        finally:
            # Log request processing time
            processing_time = time.perf_counter() - start_time
            logger.debug(
                "Error handling middleware processing time: %.3fs",
                processing_time,
            )

    def _build_error_response(self, request, exc: Exception):
        """Dispatch ``exc`` to the matching ``ErrorHandler`` method."""
        handler = self.error_handler

        # StarletteHTTPException is the base class of FastAPI's
        # HTTPException, so this covers both.
        if isinstance(exc, StarletteHTTPException):
            return handler.handle_http_exception(request, exc)
        if isinstance(exc, RequestValidationError):
            return handler.handle_validation_error(request, exc)
        if isinstance(exc, ValidationError):
            return handler.handle_pydantic_error(request, exc)
        if self._is_database_error(exc):
            return handler.handle_database_error(request, exc)
        if self._is_external_service_error(exc):
            return handler.handle_external_service_error(request, exc)
        return handler.handle_generic_exception(request, exc)

    @staticmethod
    def _contains_any(text: str, markers) -> bool:
        """Return True if any marker is a substring of ``text``."""
        return any(marker in text for marker in markers)

    def _is_database_error(self, exc: Exception) -> bool:
        """Check if exception is database-related.

        The defining module decides first: a database-library module is a
        match, an HTTP-client module is not. Only then are the generic
        markers (which include the ambiguous ``ConnectionError``) tried
        against the module and class name.
        """
        exc_type = type(exc)
        exc_module = getattr(exc_type, "__module__", "") or ""
        exc_name = exc_type.__name__

        if self._contains_any(exc_module, self._DATABASE_LIBRARIES):
            return True
        if self._contains_any(exc_module, self._HTTP_CLIENT_LIBRARIES):
            # e.g. requests.exceptions.ConnectionError is not a DB failure.
            return False

        markers = self._DATABASE_LIBRARIES + self._DATABASE_ERROR_NAMES
        return self._contains_any(exc_module, markers) or self._contains_any(exc_name, markers)

    def _is_external_service_error(self, exc: Exception) -> bool:
        """Check if exception is external service-related.

        Mirrors ``_is_database_error``: an HTTP-client module is a match, a
        database-library module is not, and otherwise the generic markers
        are tried against the module and class name.
        """
        exc_type = type(exc)
        exc_module = getattr(exc_type, "__module__", "") or ""
        exc_name = exc_type.__name__

        if self._contains_any(exc_module, self._HTTP_CLIENT_LIBRARIES):
            return True
        if self._contains_any(exc_module, self._DATABASE_LIBRARIES):
            return False

        markers = self._HTTP_CLIENT_LIBRARIES + self._EXTERNAL_ERROR_NAMES
        return self._contains_any(exc_module, markers) or self._contains_any(exc_name, markers)
|
|
|
|
|
|
def setup_error_handling(app, settings: Settings):
    """Setup error handling for the application.

    Registers exception handlers on ``app`` for FastAPI and Starlette
    ``HTTPException``, ``RequestValidationError``, pydantic
    ``ValidationError`` and bare ``Exception``. Each handler delegates to a
    shared ``ErrorHandler`` and returns its JSON response.

    Args:
        app: Application object exposing the ``exception_handler`` decorator.
        settings: Application settings passed to ``ErrorHandler``.
    """
    logger.info("Setting up error handling middleware")

    error_handler = ErrorHandler(settings)

    # Add exception handlers
    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        error_response = error_handler.handle_http_exception(request, exc)
        return error_response.to_response()

    @app.exception_handler(StarletteHTTPException)
    async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
        # Convert Starlette HTTPException to FastAPI HTTPException, keeping
        # any headers it carries (e.g. ``Allow`` on a 405) instead of
        # silently dropping them.
        exc_headers = getattr(exc, "headers", None)
        fastapi_exc = HTTPException(
            status_code=exc.status_code,
            detail=exc.detail,
            headers=dict(exc_headers) if exc_headers else None,
        )
        error_response = error_handler.handle_http_exception(request, fastapi_exc)
        return error_response.to_response()

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        error_response = error_handler.handle_validation_error(request, exc)
        return error_response.to_response()

    @app.exception_handler(ValidationError)
    async def pydantic_exception_handler(request: Request, exc: ValidationError):
        error_response = error_handler.handle_pydantic_error(request, exc)
        return error_response.to_response()

    @app.exception_handler(Exception)
    async def generic_exception_handler(request: Request, exc: Exception):
        error_response = error_handler.handle_generic_exception(request, exc)
        return error_response.to_response()

    # Add middleware for additional error handling
    # Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
    # The middleware approach is commented out but kept for reference
    # middleware = ErrorHandlingMiddleware(app, settings)
    # app.add_middleware(ErrorHandlingMiddleware, settings=settings)

    logger.info("Error handling configured")
|
|
|
|
|
|
class CustomHTTPException(HTTPException):
    """HTTP exception that also carries an error code and a context dict.

    ``error_code`` is stored as given (possibly ``None``); ``context`` is
    stored as an empty dict when omitted or empty.
    """

    def __init__(
        self,
        status_code: int,
        detail: str,
        error_code: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        # Let the base class record status, detail and headers first.
        super().__init__(headers=headers, detail=detail, status_code=status_code)
        self.context = context if context else {}
        self.error_code = error_code
|
|
|
|
|
|
class BusinessLogicError(CustomHTTPException):
    """Raised when a request breaks a business rule (HTTP 400)."""

    def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
        # Fixed status and error code; only message and context vary.
        init_kwargs = {
            "status_code": status.HTTP_400_BAD_REQUEST,
            "detail": message,
            "error_code": "BUSINESS_LOGIC_ERROR",
            "context": context,
        }
        super().__init__(**init_kwargs)
|
|
|
|
|
|
class ResourceNotFoundError(CustomHTTPException):
    """Raised when a named resource cannot be found (HTTP 404)."""

    def __init__(self, resource: str, identifier: str):
        # The detail names only the resource; the identifier goes in context.
        lookup_context = {"resource": resource, "identifier": identifier}
        not_found_detail = f"{resource} not found"
        super().__init__(
            context=lookup_context,
            error_code="RESOURCE_NOT_FOUND",
            detail=not_found_detail,
            status_code=status.HTTP_404_NOT_FOUND,
        )
|
|
|
|
|
|
class ConflictError(CustomHTTPException):
    """Raised when a request conflicts with existing state (HTTP 409)."""

    def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
        # Fixed status and error code; only message and context vary.
        init_kwargs = {
            "status_code": status.HTTP_409_CONFLICT,
            "detail": message,
            "error_code": "CONFLICT_ERROR",
            "context": context,
        }
        super().__init__(**init_kwargs)
|
|
|
|
|
|
class ServiceUnavailableError(CustomHTTPException):
    """Raised when a named service cannot be reached (HTTP 503)."""

    def __init__(self, service: str, reason: Optional[str] = None):
        # Append the reason to the detail only when one was given.
        reason_suffix = f": {reason}" if reason else ""
        super().__init__(
            context={"service": service, "reason": reason},
            error_code="SERVICE_UNAVAILABLE",
            detail=f"{service} service is unavailable" + reason_suffix,
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )