"""
WebSocket connection manager for WiFi-DensePose API
"""

import asyncio
import json
import logging
import uuid
from typing import Dict, List, Optional, Any, Set
from datetime import datetime, timedelta
from collections import defaultdict

from fastapi import WebSocket, WebSocketDisconnect

logger = logging.getLogger(__name__)


class WebSocketConnection:
    """Represents a WebSocket connection with metadata.

    Wraps a FastAPI ``WebSocket`` together with the client's subscription
    (stream type, zone IDs, free-form config) and simple bookkeeping
    (connection time, last ping time, message count, active flag).
    """

    def __init__(
        self,
        websocket: WebSocket,
        client_id: str,
        stream_type: str,
        zone_ids: Optional[List[str]] = None,
        **config
    ):
        """Create a connection record.

        Args:
            websocket: The underlying WebSocket used for sending.
            client_id: Unique identifier assigned by the manager.
            stream_type: Name of the stream the client subscribed to.
            zone_ids: Zones the client listens to; ``None`` or an empty list
                means the client listens to all zones.
            **config: Extra subscription options; compared against keyword
                filters in ``matches_filter``.
        """
        self.websocket = websocket
        self.client_id = client_id
        self.stream_type = stream_type
        self.zone_ids = zone_ids or []
        self.config = config
        self.connected_at = datetime.utcnow()
        # In this module, refreshed only after ConnectionManager.ping_clients
        # successfully sends a ping (client replies are not tracked here).
        self.last_ping = datetime.utcnow()
        self.message_count = 0
        self.is_active = True

    async def send_json(self, data: Dict[str, Any]):
        """Send JSON data to client.

        Raises:
            Exception: Whatever the underlying websocket raised; the
                connection is marked inactive before re-raising.
        """
        try:
            await self.websocket.send_json(data)
            self.message_count += 1
        except Exception as e:
            logger.error(f"Error sending to client {self.client_id}: {e}")
            self.is_active = False
            raise

    async def send_text(self, message: str):
        """Send text message to client.

        Raises:
            Exception: Whatever the underlying websocket raised; the
                connection is marked inactive before re-raising.
        """
        try:
            await self.websocket.send_text(message)
            self.message_count += 1
        except Exception as e:
            logger.error(f"Error sending text to client {self.client_id}: {e}")
            self.is_active = False
            raise

    def update_config(self, config: Dict[str, Any]):
        """Merge ``config`` into the connection configuration.

        A ``zone_ids`` key also replaces the connection's zone list
        (a falsy value resets it to "all zones").
        """
        self.config.update(config)

        # Update zone IDs if provided
        if "zone_ids" in config:
            self.zone_ids = config["zone_ids"] or []

    def matches_filter(
        self,
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> bool:
        """Check if connection matches given filters.

        Args:
            stream_type: If given, must equal the connection's stream type.
            zone_ids: If given, at least one must be in the connection's
                zones, unless the connection listens to all zones.
            **filters: Each key present in the connection config must equal
                the given value; keys absent from the config are ignored.

        Returns:
            True when every supplied criterion matches.
        """
        # Check stream type
        if stream_type and self.stream_type != stream_type:
            return False

        # Check zone IDs. An empty zone list means "listens to all zones",
        # so only restrict when the connection has explicit zones. Do not
        # return early here: the additional filters below must still apply.
        if zone_ids and self.zone_ids:
            if not any(zone in self.zone_ids for zone in zone_ids):
                return False

        # Check additional filters
        for key, value in filters.items():
            if key in self.config and self.config[key] != value:
                return False

        return True

    def get_info(self) -> Dict[str, Any]:
        """Get connection information as a plain dict (timestamps as ISO strings)."""
        return {
            "client_id": self.client_id,
            "stream_type": self.stream_type,
            "zone_ids": self.zone_ids,
            "config": self.config,
            "connected_at": self.connected_at.isoformat(),
            "last_ping": self.last_ping.isoformat(),
            "message_count": self.message_count,
            "is_active": self.is_active,
            "uptime_seconds": (datetime.utcnow() - self.connected_at).total_seconds()
        }


class ConnectionManager:
    """Manages WebSocket connections for real-time streaming.

    Keeps connections keyed by client ID, plus secondary indexes by stream
    type and by zone for broadcast filtering, and runs an optional background
    task that pings clients and removes inactive or stale ones.
    """

    def __init__(self):
        self.connections: Dict[str, WebSocketConnection] = {}
        self.connections_by_type: Dict[str, Set[str]] = defaultdict(set)
        self.connections_by_zone: Dict[str, Set[str]] = defaultdict(set)
        self.metrics = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_sent": 0,
            "errors": 0,
            "start_time": datetime.utcnow()
        }
        self._cleanup_task = None
        self._started = False

    def _discard_from_zone(self, zone_id: str, client_id: str) -> None:
        """Remove a client from a zone index, dropping the zone once empty."""
        # Zone IDs may come from clients, so pruning empty sets keeps the
        # index from growing without bound.
        clients = self.connections_by_zone.get(zone_id)
        if clients is None:
            return
        clients.discard(client_id)
        if not clients:
            del self.connections_by_zone[zone_id]

    async def connect(
        self,
        websocket: WebSocket,
        stream_type: str,
        zone_ids: Optional[List[str]] = None,
        **config
    ) -> str:
        """Register a new WebSocket connection.

        This method does not call ``websocket.accept()``; presumably the
        caller has already accepted the socket — verify against callers.

        Returns:
            The newly generated client ID (a UUID4 string).
        """
        client_id = str(uuid.uuid4())

        try:
            # Create connection object
            connection = WebSocketConnection(
                websocket=websocket,
                client_id=client_id,
                stream_type=stream_type,
                zone_ids=zone_ids,
                **config
            )

            # Store connection
            self.connections[client_id] = connection
            self.connections_by_type[stream_type].add(client_id)

            # Index by zones
            if zone_ids:
                for zone_id in zone_ids:
                    self.connections_by_zone[zone_id].add(client_id)

            # Update metrics
            self.metrics["total_connections"] += 1
            self.metrics["active_connections"] = len(self.connections)

            logger.info(f"WebSocket client {client_id} connected for {stream_type}")

            return client_id

        except Exception as e:
            logger.error(f"Error connecting WebSocket client: {e}")
            raise

    async def disconnect(self, client_id: str) -> bool:
        """Disconnect a WebSocket client and remove it from all indexes.

        Returns:
            True if the client was known and removed, False otherwise.
        """
        if client_id not in self.connections:
            return False

        try:
            connection = self.connections[client_id]

            # Remove from indexes
            self.connections_by_type[connection.stream_type].discard(client_id)

            for zone_id in connection.zone_ids:
                self._discard_from_zone(zone_id, client_id)

            # Close WebSocket if still active
            if connection.is_active:
                try:
                    await connection.websocket.close()
                except Exception:
                    pass  # Connection might already be closed

            # Remove connection (pop tolerates a concurrent removal during the await above)
            self.connections.pop(client_id, None)

            # Update metrics
            self.metrics["active_connections"] = len(self.connections)

            logger.info(f"WebSocket client {client_id} disconnected")

            return True

        except Exception as e:
            logger.error(f"Error disconnecting client {client_id}: {e}")
            return False

    async def disconnect_all(self):
        """Disconnect all WebSocket clients."""
        # Snapshot the keys: disconnect() mutates self.connections.
        client_ids = list(self.connections.keys())

        for client_id in client_ids:
            await self.disconnect(client_id)

        logger.info("All WebSocket clients disconnected")

    async def send_to_client(self, client_id: str, data: Dict[str, Any]) -> bool:
        """Send data to a specific client.

        Returns:
            True on success; False if the client is unknown or the send
            failed (in which case the connection is marked inactive).
        """
        if client_id not in self.connections:
            return False

        connection = self.connections[client_id]

        try:
            await connection.send_json(data)
            self.metrics["messages_sent"] += 1
            return True

        except Exception as e:
            logger.error(f"Error sending to client {client_id}: {e}")
            self.metrics["errors"] += 1

            # Mark connection as inactive and schedule for cleanup
            connection.is_active = False
            return False

    async def broadcast(
        self,
        data: Dict[str, Any],
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> int:
        """Broadcast data to matching clients.

        Clients whose send fails are disconnected afterwards.

        Returns:
            The number of clients the data was successfully sent to.
        """
        sent_count = 0
        failed_clients = []

        # Get matching connections
        matching_clients = self._get_matching_clients(
            stream_type=stream_type,
            zone_ids=zone_ids,
            **filters
        )

        # Send to all matching clients
        for client_id in matching_clients:
            try:
                success = await self.send_to_client(client_id, data)
                if success:
                    sent_count += 1
                else:
                    failed_clients.append(client_id)
            except Exception as e:
                logger.error(f"Error broadcasting to client {client_id}: {e}")
                failed_clients.append(client_id)

        # Clean up failed connections
        for client_id in failed_clients:
            await self.disconnect(client_id)

        return sent_count

    async def update_client_config(self, client_id: str, config: Dict[str, Any]) -> bool:
        """Update client configuration and keep the zone index in sync.

        Returns:
            True if the client exists, False otherwise.
        """
        if client_id not in self.connections:
            return False

        connection = self.connections[client_id]
        old_zones = set(connection.zone_ids)

        # Update configuration
        connection.update_config(config)

        # Update zone indexes if zones changed
        new_zones = set(connection.zone_ids)

        # Remove from old zones
        for zone_id in old_zones - new_zones:
            self._discard_from_zone(zone_id, client_id)

        # Add to new zones
        for zone_id in new_zones - old_zones:
            self.connections_by_zone[zone_id].add(client_id)

        return True

    async def get_client_status(self, client_id: str) -> Optional[Dict[str, Any]]:
        """Get status of a specific client, or None if unknown."""
        if client_id not in self.connections:
            return None

        return self.connections[client_id].get_info()

    async def get_connected_clients(self) -> List[Dict[str, Any]]:
        """Get list of all connected clients."""
        return [conn.get_info() for conn in self.connections.values()]

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Get connection statistics (counts by type, zone and activity)."""
        stats = {
            "total_clients": len(self.connections),
            "clients_by_type": {
                stream_type: len(clients)
                for stream_type, clients in self.connections_by_type.items()
            },
            "clients_by_zone": {
                zone_id: len(clients)
                for zone_id, clients in self.connections_by_zone.items()
                if clients  # Only include zones with active clients
            },
            "active_clients": sum(1 for conn in self.connections.values() if conn.is_active),
            "inactive_clients": sum(1 for conn in self.connections.values() if not conn.is_active)
        }

        return stats

    async def get_metrics(self) -> Dict[str, Any]:
        """Get detailed metrics, including derived throughput and error rate."""
        uptime = (datetime.utcnow() - self.metrics["start_time"]).total_seconds()

        return {
            **self.metrics,
            "active_connections": len(self.connections),
            "uptime_seconds": uptime,
            # max(..., 1) guards against division by zero right after start
            # and before any message was sent.
            "messages_per_second": self.metrics["messages_sent"] / max(uptime, 1),
            "error_rate": self.metrics["errors"] / max(self.metrics["messages_sent"], 1)
        }

    def _get_matching_clients(
        self,
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> List[str]:
        """Get IDs of active clients that match the given filters."""
        candidates = set(self.connections.keys())

        # Filter by stream type
        if stream_type:
            type_clients = self.connections_by_type.get(stream_type, set())
            candidates &= type_clients

        # Filter by zones
        if zone_ids:
            zone_clients = set()
            for zone_id in zone_ids:
                zone_clients.update(self.connections_by_zone.get(zone_id, set()))

            # Also include clients listening to all zones (empty zone list)
            all_zone_clients = {
                client_id for client_id, conn in self.connections.items()
                if not conn.zone_ids
            }
            zone_clients.update(all_zone_clients)

            candidates &= zone_clients

        # Apply additional filters
        matching_clients = []
        for client_id in candidates:
            connection = self.connections[client_id]
            if connection.is_active and connection.matches_filter(**filters):
                matching_clients.append(client_id)

        return matching_clients

    async def ping_clients(self):
        """Send a ping message to all connected clients.

        On success the client's ``last_ping`` is refreshed; clients whose
        ping fails are disconnected.
        """
        ping_data = {
            "type": "ping",
            "timestamp": datetime.utcnow().isoformat()
        }

        failed_clients = []

        # Iterate a snapshot: each send awaits, and other coroutines may
        # connect/disconnect clients meanwhile, which would otherwise raise
        # "dictionary changed size during iteration".
        for client_id, connection in list(self.connections.items()):
            if client_id not in self.connections:
                continue  # Removed while an earlier ping was in flight
            try:
                await connection.send_json(ping_data)
                connection.last_ping = datetime.utcnow()
            except Exception as e:
                logger.warning(f"Ping failed for client {client_id}: {e}")
                failed_clients.append(client_id)

        # Clean up failed connections
        for client_id in failed_clients:
            await self.disconnect(client_id)

    async def cleanup_inactive_connections(self):
        """Disconnect inactive connections and those with no recent ping."""
        now = datetime.utcnow()
        stale_threshold = timedelta(minutes=5)  # 5 minutes without ping

        stale_clients = []

        for client_id, connection in self.connections.items():
            # Check if connection is inactive
            if not connection.is_active:
                stale_clients.append(client_id)
                continue

            # Check if connection is stale: no successful ping recorded
            # within the threshold (in this module last_ping tracks pings
            # sent by ping_clients, not client responses).
            if now - connection.last_ping > stale_threshold:
                logger.warning(f"Client {client_id} appears stale, disconnecting")
                stale_clients.append(client_id)

        # Clean up stale connections
        for client_id in stale_clients:
            await self.disconnect(client_id)

        if stale_clients:
            logger.info(f"Cleaned up {len(stale_clients)} stale connections")

    async def start(self):
        """Start the connection manager (idempotent until shutdown)."""
        if not self._started:
            self._start_cleanup_task()
            self._started = True
            logger.info("Connection manager started")

    def _start_cleanup_task(self):
        """Schedule the background cleanup/ping loop on the running event loop.

        If no event loop is running, only a debug message is logged and no
        task is scheduled.
        NOTE(review): nothing in this module retries later in that case.
        """
        async def cleanup_loop():
            while True:
                try:
                    await asyncio.sleep(60)  # Run every minute
                    await self.cleanup_inactive_connections()

                    # Ping when the wall-clock minute is even, i.e. roughly
                    # every other iteration of this 60-second loop.
                    if datetime.utcnow().minute % 2 == 0:
                        await self.ping_clients()

                except Exception as e:
                    logger.error(f"Error in cleanup task: {e}")

        # Probe for a running loop before creating the coroutine, so a
        # missing loop does not leave a never-awaited coroutine behind.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running, will start later
            logger.debug("No event loop running, cleanup task will start later")
            return

        self._cleanup_task = loop.create_task(cleanup_loop())

    async def shutdown(self):
        """Cancel the cleanup task and disconnect all clients.

        Resets the started state, so ``start()`` can be called again.
        """
        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        # Disconnect all clients
        await self.disconnect_all()

        # Allow a later start() to schedule a fresh cleanup task.
        self._started = False

        logger.info("Connection manager shutdown complete")


# Global connection manager instance
connection_manager = ConnectionManager()