"""
Authentication middleware for WiFi-DensePose API
"""

import functools
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from starlette.middleware.base import BaseHTTPMiddleware

from src.config.settings import Settings
from src.logger import set_request_context

logger = logging.getLogger(__name__)

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# HTTP bearer scheme for FastAPI dependencies (not referenced within this module).
security = HTTPBearer(auto_error=False)

# Paths that bypass authentication. A request path matches an entry when it is
# equal to it or continues it with "/" -- e.g. "/docs/oauth2-redirect" matches
# "/docs", but "/healthcheck-admin" does NOT match "/health". Plain prefix
# matching would let unrelated routes that merely share a prefix skip auth.
_PUBLIC_PATH_PREFIXES = (
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/auth/login",
    "/auth/register",
    "/static",
)


class AuthenticationError(Exception):
    """Authentication failure; the middleware maps it to HTTP 401."""


class AuthorizationError(Exception):
    """Authorization failure; the middleware maps it to HTTP 403."""


class TokenManager:
    """Create and verify JWT access tokens signed with the configured secret."""

    def __init__(self, settings: Settings):
        self.settings = settings
        self.secret_key = settings.secret_key
        self.algorithm = settings.jwt_algorithm
        self.expire_hours = settings.jwt_expire_hours

    def create_access_token(self, data: Dict[str, Any]) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed. The caller's dict is not modified.

        Returns:
            The encoded JWT, which carries ``iat`` (issue time) and ``exp``
            (issue time + ``expire_hours``) in addition to ``data``.
        """
        to_encode = data.copy()
        # One timezone-aware timestamp keeps iat and exp exactly expire_hours apart.
        now = datetime.now(timezone.utc)
        to_encode.update({"exp": now + timedelta(hours=self.expire_hours), "iat": now})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify a JWT and return its claims.

        Args:
            token: The encoded JWT.

        Returns:
            The decoded claims.

        Raises:
            AuthenticationError: If decoding/verification fails ("Invalid token")
                or the token is on the revocation blacklist ("Token has been revoked").
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError as e:
            logger.warning("JWT verification failed: %s", e)
            raise AuthenticationError("Invalid token") from e

        # Check token blacklist (logout invalidation).
        # NOTE(review): imported lazily, presumably to avoid an import cycle with
        # the API-layer auth module -- confirm that module defines token_blacklist.
        from src.api.middleware.auth import token_blacklist

        if token_blacklist.is_blacklisted(token):
            raise AuthenticationError("Token has been revoked")
        return payload

    def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]:
        """Decode a token with signature verification, returning its claims.

        The signature is always verified (python-jose's ``jwt.decode`` also
        validates time claims such as ``exp`` by default). Unlike
        ``verify_token()``, this helper does NOT consult the revocation
        blacklist and returns ``None`` instead of raising on failure, so use
        ``verify_token()`` to decide whether a request is authenticated.

        Returns:
            The claims dict, or ``None`` if decoding/verification fails.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except JWTError:
            return None


class UserManager:
    """In-memory user store with bcrypt password hashing."""

    def __init__(self):
        # In a real application, this would connect to a database.
        # No default users are created -- users must be provisioned
        # through the create_user() method or an external identity provider.
        self._users: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def hash_password(password: str) -> str:
        """Return the bcrypt hash of ``password``."""
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``."""
        return pwd_context.verify(plain_password, hashed_password)

    def get_user(self, username: str) -> Optional[Dict[str, Any]]:
        """Return the stored user record for ``username``, or None if unknown."""
        return self._users.get(username)

    def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
        """Check credentials and return the user record on success.

        Returns:
            The user record if the user exists, the password matches and the
            account is active; otherwise None.
        """
        user = self.get_user(username)
        if not user:
            # Spend comparable hashing time for unknown users so response
            # latency does not reveal which usernames exist.
            pwd_context.dummy_verify()
            return None
        if not self.verify_password(password, user["hashed_password"]):
            return None
        if not user.get("is_active", False):
            return None
        return user

    def create_user(
        self, username: str, email: str, password: str, roles: Optional[list] = None
    ) -> Dict[str, Any]:
        """Create and store a new active user.

        Args:
            username: Unique user name.
            email: User email address.
            password: Plain-text password; only its hash is stored.
            roles: Roles to grant; defaults to ``["user"]`` when empty/None.

        Returns:
            The stored user record.

        Raises:
            ValueError: If ``username`` already exists.
        """
        if username in self._users:
            raise ValueError("User already exists")
        user = {
            "username": username,
            "email": email,
            "hashed_password": self.hash_password(password),
            # Copy so later mutation of the caller's list cannot alter stored roles.
            "roles": list(roles) if roles else ["user"],
            "is_active": True,
            "created_at": datetime.utcnow(),  # naive UTC timestamp
        }
        self._users[username] = user
        return user

    def update_user(self, username: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Apply ``updates`` to a user, ignoring protected fields.

        Returns:
            The updated user record, or None if the user does not exist.
        """
        user = self._users.get(username)
        if not user:
            return None
        # Don't allow updating certain fields
        protected_fields = {"username", "created_at", "hashed_password"}
        user.update({k: v for k, v in updates.items() if k not in protected_fields})
        return user

    def deactivate_user(self, username: str) -> bool:
        """Mark a user inactive. Returns True if the user existed."""
        user = self._users.get(username)
        if user:
            user["is_active"] = False
            return True
        return False


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Bearer-token authentication middleware for FastAPI.

    Populates ``request.state.user`` for authenticated requests and rejects
    unauthenticated requests to ``/api/`` and ``/ws/`` paths.
    """

    def __init__(self, app, settings: Settings):
        super().__init__(app)
        self.settings = settings
        self.token_manager = TokenManager(settings)
        self.user_manager = UserManager()
        self.enabled = settings.enable_authentication

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request, then forward it to the application.

        Authentication failures are returned directly as JSON error responses:
        an exception raised from a BaseHTTPMiddleware is not routed through
        FastAPI's HTTPException handlers and would surface as a 500. Errors
        raised by downstream handlers are not intercepted here.
        """
        start_time = time.perf_counter()
        try:
            # Public paths and a disabled auth layer pass straight through.
            if self._should_skip_auth(request) or not self.enabled:
                return await call_next(request)

            try:
                user_info = await self._authenticate_request(request)
            except AuthenticationError as e:
                logger.warning("Authentication failed: %s", e)
                return self._error_response(
                    status.HTTP_401_UNAUTHORIZED,
                    str(e),
                    headers={"WWW-Authenticate": "Bearer"},
                )
            except AuthorizationError as e:
                logger.warning("Authorization failed: %s", e)
                return self._error_response(status.HTTP_403_FORBIDDEN, str(e))
            except Exception:
                logger.exception("Authentication middleware error")
                return self._error_response(
                    status.HTTP_500_INTERNAL_SERVER_ERROR,
                    "Authentication service error",
                )

            # Set user context
            if user_info:
                request.state.user = user_info
                set_request_context(user_id=user_info.get("username"))

            response = await call_next(request)
            self._add_auth_headers(response, user_info)
            return response
        finally:
            processing_time = time.perf_counter() - start_time
            logger.debug("Auth middleware processing time: %.3fs", processing_time)

    @staticmethod
    def _error_response(
        status_code: int, detail: str, headers: Optional[Dict[str, str]] = None
    ) -> JSONResponse:
        """Build a ``{"detail": ...}`` JSON error response (the HTTPException body shape)."""
        return JSONResponse(status_code=status_code, content={"detail": detail}, headers=headers)

    def _should_skip_auth(self, request: Request) -> bool:
        """Return True if the request path is public (see _PUBLIC_PATH_PREFIXES)."""
        path = request.url.path
        return any(
            path == prefix or path.startswith(prefix + "/")
            for prefix in _PUBLIC_PATH_PREFIXES
        )

    async def _authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
        """Validate the Authorization header and return non-sensitive user info.

        Returns:
            A dict with username, email, roles and is_active; or None when no
            header is present on a path that does not require authentication.

        Raises:
            AuthenticationError: For a missing header on a protected path, a
                malformed header, a non-bearer scheme, an invalid token, or an
                unknown/disabled user.
        """
        authorization = request.headers.get("Authorization")
        if not authorization:
            if self._requires_auth(request):
                raise AuthenticationError("Missing authorization header")
            return None

        # Expect exactly "<scheme> <token>".
        parts = authorization.split()
        if len(parts) != 2:
            raise AuthenticationError("Invalid authorization header format")
        scheme, token = parts
        if scheme.lower() != "bearer":
            raise AuthenticationError("Invalid authentication scheme")

        try:
            payload = self.token_manager.verify_token(token)
            username = payload.get("sub")
            if not username:
                raise AuthenticationError("Invalid token payload")

            user = self.user_manager.get_user(username)
            if not user:
                raise AuthenticationError("User not found")
            if not user.get("is_active", False):
                raise AuthenticationError("User account is disabled")

            # Return user info without sensitive data; copy roles so callers
            # cannot mutate the stored record through request.state.user.
            return {
                "username": user["username"],
                "email": user["email"],
                "roles": list(user["roles"]),
                "is_active": user["is_active"],
            }
        except AuthenticationError:
            raise
        except Exception as e:
            logger.error("Token verification error: %s", e)
            raise AuthenticationError("Token verification failed") from e

    def _requires_auth(self, request: Request) -> bool:
        """Return True for paths that must carry credentials (API and WebSocket)."""
        # All API endpoints require authentication by default
        path = request.url.path
        return path.startswith(("/api/", "/ws/"))

    def _add_auth_headers(self, response: Response, user_info: Optional[Dict[str, Any]]):
        """Add X-User / X-User-Roles headers for authenticated requests."""
        if user_info:
            response.headers["X-User"] = user_info["username"]
            response.headers["X-User-Roles"] = ",".join(user_info["roles"])

    def _issue_token_response(
        self, user: Dict[str, Any], *, include_user: bool = True
    ) -> Dict[str, Any]:
        """Mint an access token for ``user`` and build the token response body."""
        token_data = {
            "sub": user["username"],
            "email": user["email"],
            "roles": list(user["roles"]),
        }
        result: Dict[str, Any] = {
            "access_token": self.token_manager.create_access_token(token_data),
            "token_type": "bearer",
            "expires_in": self.settings.jwt_expire_hours * 3600,
        }
        if include_user:
            result["user"] = {
                "username": user["username"],
                "email": user["email"],
                "roles": list(user["roles"]),
            }
        return result

    async def login(self, username: str, password: str) -> Dict[str, Any]:
        """Authenticate credentials and return a token response.

        Raises:
            AuthenticationError: If the credentials are invalid or the user is inactive.
        """
        user = self.user_manager.authenticate_user(username, password)
        if not user:
            raise AuthenticationError("Invalid username or password")
        return self._issue_token_response(user)

    async def register(self, username: str, email: str, password: str) -> Dict[str, Any]:
        """Create a user and return a token response for it.

        Raises:
            AuthenticationError: If the username already exists.
        """
        try:
            user = self.user_manager.create_user(username, email, password)
        except ValueError as e:
            raise AuthenticationError(str(e)) from e
        return self._issue_token_response(user)

    async def refresh_token(self, token: str) -> Dict[str, Any]:
        """Exchange a valid token for a fresh one (response omits user info).

        Raises:
            AuthenticationError: "Token refresh failed" for any failure; the
                underlying cause is chained as ``__cause__``.
        """
        try:
            payload = self.token_manager.verify_token(token)
            user = self.user_manager.get_user(payload.get("sub"))
            if not user or not user.get("is_active", False):
                raise AuthenticationError("User not found or inactive")
            return self._issue_token_response(user, include_user=False)
        except Exception as e:
            raise AuthenticationError("Token refresh failed") from e

    def check_permission(self, user_info: Dict[str, Any], required_role: str) -> bool:
        """Return True if the user has ``required_role`` or the "admin" role."""
        user_roles = user_info.get("roles", [])
        # Admin role has all permissions
        if "admin" in user_roles:
            return True
        return required_role in user_roles

    def require_role(self, required_role: str):
        """Decorator for async handlers taking ``request`` first; enforces a role.

        Raises:
            AuthorizationError: If no user is on ``request.state`` or the role is missing.
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(request: Request, *args, **kwargs):
                user_info = getattr(request.state, "user", None)
                if not user_info:
                    raise AuthorizationError("Authentication required")
                if not self.check_permission(user_info, required_role):
                    raise AuthorizationError(f"Role '{required_role}' required")
                return await func(request, *args, **kwargs)

            return wrapper

        return decorator


# Global authentication middleware instance
_auth_middleware: Optional[AuthenticationMiddleware] = None


def get_auth_middleware(settings: Settings) -> AuthenticationMiddleware:
    """Return the lazily created module-level middleware instance.

    The instance is not mounted in an ASGI stack (no wrapped app), so it is
    only useful for its helper methods. ``settings`` is used on the first call
    only. NOTE(review): this instance owns its own UserManager, separate from
    any middleware instance added to the app -- confirm that is intended.
    """
    global _auth_middleware
    if _auth_middleware is None:
        # BaseHTTPMiddleware requires an app argument; none is wrapped here.
        _auth_middleware = AuthenticationMiddleware(None, settings)
    return _auth_middleware


def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
    """Return the authenticated user stored by the middleware, or None."""
    return getattr(request.state, "user", None)


def require_authentication(request: Request) -> Dict[str, Any]:
    """Return the current user or raise HTTP 401 if the request is unauthenticated."""
    user = get_current_user(request)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


def require_role(role: str):
    """Build a FastAPI dependency enforcing ``role`` (admins always pass)."""

    def dependency(request: Request) -> Dict[str, Any]:
        user = require_authentication(request)
        auth_middleware = get_auth_middleware(request.app.state.settings)
        if not auth_middleware.check_permission(user, role):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role '{role}' required",
            )
        return user

    return dependency