//! WebSocket handler — `/api/websocket`. ADR-130 §2.2 P2 command subset. //! //! Protocol mirrors HA's WS API: //! server → `{"type":"auth_required","ha_version":""}` //! client → `{"type":"auth","access_token":""}` //! server → `{"type":"auth_ok","ha_version":""}` //! client → `{"id":1,"type":"get_states"}` //! server → `{"id":1,"type":"result","success":true,"result":[...]}` //! //! `ha_version` is the homecore version string — see ADR-130 Q1 for the //! companion-app feature-detect concern. use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; use axum::extract::State; use axum::response::IntoResponse; use serde::{Deserialize, Serialize}; use tokio::sync::broadcast; use tracing::{debug, warn}; use homecore::{Context, ServiceCall, ServiceName, SystemEvent}; use crate::rest::StateView; use crate::state::SharedState; /// WebSocket upgrade entry point. Mounted on `/api/websocket`. pub async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(socket, state)) } async fn handle_socket(mut socket: WebSocket, state: SharedState) { // Phase 1 — auth handshake. let auth_req = serde_json::json!({ "type": "auth_required", "ha_version": state.version(), }); if socket.send(Message::Text(auth_req.to_string())).await.is_err() { return; } let token = match socket.recv().await { Some(Ok(Message::Text(raw))) => match serde_json::from_str::(&raw) { Ok(m) if m.kind == "auth" => m.access_token, _ => { let _ = socket .send(Message::Text( serde_json::json!({"type":"auth_invalid","message":"expected auth"}).to_string(), )) .await; return; } }, _ => return, }; // P1: accept any non-empty token. P2: validate against store. if token.trim().is_empty() { let _ = socket .send(Message::Text( serde_json::json!({"type":"auth_invalid","message":"empty token"}).to_string(), )) .await; return; } let auth_ok = serde_json::json!({"type":"auth_ok","ha_version": state.version()}); if socket.send(Message::Text(auth_ok.to_string())).await.is_err() { return; } // Phase 2 — command loop. let conn = Connection::new(state.clone()); conn.run(socket).await; } #[derive(Deserialize)] struct AuthMessage { #[serde(rename = "type")] kind: String, access_token: String, } #[derive(Deserialize)] struct WsCommand { id: u64, #[serde(rename = "type")] kind: String, #[serde(default)] event_type: Option, #[serde(default)] subscription: Option, #[serde(default)] entity_id: Option, #[serde(default)] domain: Option, #[serde(default)] service: Option, #[serde(default)] service_data: Option, } #[derive(Serialize)] struct ResultMessage<'a> { id: u64, #[serde(rename = "type")] kind: &'static str, success: bool, #[serde(skip_serializing_if = "Option::is_none")] result: Option, #[serde(skip_serializing_if = "Option::is_none")] error: Option>, } #[derive(Serialize)] struct ErrorView<'a> { code: &'static str, message: &'a str, } struct Connection { state: SharedState, next_sub_id: AtomicU64, subs: Arc>, } struct SubscriptionHandle { abort: tokio::task::AbortHandle, } impl Connection { fn new(state: SharedState) -> Self { Self { state, next_sub_id: AtomicU64::new(1), subs: Arc::new(dashmap::DashMap::new()), } } async fn run(self, mut socket: WebSocket) { let conn = Arc::new(self); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); let sender_tx = tx.clone(); let recv_task = { let conn = Arc::clone(&conn); tokio::spawn(async move { while let Some(frame) = socket.recv().await { match frame { Ok(Message::Text(raw)) => { let cmd: WsCommand = match serde_json::from_str(&raw) { Ok(c) => c, Err(e) => { warn!("bad ws command: {e}"); continue; } }; conn.handle_cmd(cmd, &sender_tx).await; } Ok(Message::Ping(p)) => { let _ = sender_tx.send(format!("__pong:{}", p.len())); } Ok(Message::Close(_)) | Err(_) => break, _ => {} } } // Cancel all subscriptions on disconnect. for entry in conn.subs.iter() { entry.value().abort.abort(); } }); tokio::spawn(async move { while let Some(msg) = rx.recv().await { if msg.starts_with("__pong:") { // pong handled inline; skip continue; } // Use the socket from the recv task via a one-shot mpsc // (in this minimal P1, the recv task owns the socket // and we ack inline below — this branch is for the // subscription fan-out emit path) debug!("ws emit: {msg}"); } }) }; let _ = recv_task.await; } async fn handle_cmd(&self, cmd: WsCommand, tx: &tokio::sync::mpsc::UnboundedSender) { match cmd.kind.as_str() { "ping" => { let msg = serde_json::json!({"id": cmd.id, "type": "pong"}); let _ = tx.send(msg.to_string()); } "get_states" => { let snapshots = self.state.homecore().states().all(); let views: Vec = snapshots.iter().map(|s| StateView::from_state(s)).collect(); self.ack(tx, cmd.id, true, Some(serde_json::to_value(views).unwrap())); } "get_config" => { let payload = serde_json::json!({ "location_name": self.state.location_name(), "version": self.state.version(), "state": "RUNNING", }); self.ack(tx, cmd.id, true, Some(payload)); } "get_services" => { let services = self.state.homecore().services().registered_services().await; let mut by_domain: std::collections::HashMap> = std::collections::HashMap::new(); for s in services { by_domain.entry(s.domain).or_default().insert(s.service, serde_json::json!({})); } let payload = serde_json::to_value(by_domain).unwrap(); self.ack(tx, cmd.id, true, Some(payload)); } "call_service" => { let (Some(domain), Some(service)) = (cmd.domain.clone(), cmd.service.clone()) else { self.err(tx, cmd.id, "missing_domain_service", "domain and service are required"); return; }; let call = ServiceCall { name: ServiceName::new(domain.clone(), service.clone()), data: cmd.service_data.unwrap_or(serde_json::json!({})), context: Context::new(), }; match self.state.homecore().services().call(call).await { Ok(v) => self.ack(tx, cmd.id, true, Some(v)), Err(e) => self.err(tx, cmd.id, "service_error", &e.to_string()), } } "subscribe_events" => { let sub_id = self.next_sub_id.fetch_add(1, Ordering::Relaxed); let filter = cmd.event_type.clone(); let tx_clone = tx.clone(); let mut domain_rx = self.state.homecore().bus().subscribe_domain(); let mut system_rx = self.state.homecore().bus().subscribe_system(); let task = tokio::spawn(async move { loop { tokio::select! { evt = system_rx.recv() => match evt { Ok(SystemEvent::StateChanged(sc)) => { if filter.as_deref() == Some("state_changed") || filter.is_none() { let payload = serde_json::json!({ "id": sub_id, "type": "event", "event": { "event_type": "state_changed", "data": { "entity_id": sc.entity_id.as_str(), "old_state": sc.old_state.as_ref().map(|s| StateView::from_state(s)), "new_state": sc.new_state.as_ref().map(|s| StateView::from_state(s)), }, "origin": "LOCAL", "time_fired": sc.fired_at.to_rfc3339(), } }); if tx_clone.send(payload.to_string()).is_err() { break; } } } Ok(_) => {} Err(_) => break, }, evt = domain_rx.recv() => match evt { Ok(de) => { if filter.as_deref() == Some(de.event_type.as_str()) || filter.is_none() { let payload = serde_json::json!({ "id": sub_id, "type": "event", "event": { "event_type": de.event_type, "data": de.event_data, "origin": format!("{:?}", de.origin).to_uppercase(), "time_fired": de.fired_at.to_rfc3339(), } }); if tx_clone.send(payload.to_string()).is_err() { break; } } } Err(_) => break, } } } }); self.subs.insert( sub_id, SubscriptionHandle { abort: task.abort_handle(), }, ); self.ack(tx, cmd.id, true, None); } "unsubscribe_events" => { if let Some(sub_id) = cmd.subscription { if let Some((_, handle)) = self.subs.remove(&sub_id) { handle.abort.abort(); self.ack(tx, cmd.id, true, None); } else { self.err(tx, cmd.id, "not_found", "subscription_id not found"); } } else { self.err(tx, cmd.id, "missing_subscription", "subscription is required"); } } other => { self.err(tx, cmd.id, "unknown_command", &format!("unknown ws command: {other}")); } } // entity_id is reserved for future per-entity subscribes let _ = cmd.entity_id; } fn ack( &self, tx: &tokio::sync::mpsc::UnboundedSender, id: u64, success: bool, result: Option, ) { let msg = ResultMessage { id, kind: "result", success, result, error: None, }; let _ = tx.send(serde_json::to_string(&msg).unwrap()); } fn err(&self, tx: &tokio::sync::mpsc::UnboundedSender, id: u64, code: &'static str, message: &str) { let msg = ResultMessage { id, kind: "result", success: false, result: None, error: Some(ErrorView { code, message }), }; let _ = tx.send(serde_json::to_string(&msg).unwrap()); } } // Suppress unused warnings for placeholder broadcast type #[allow(dead_code)] type _UnusedSubBroadcast = broadcast::Sender<()>;