diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/camera.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/camera.rs new file mode 100644 index 00000000..ceb7230b --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/camera.rs @@ -0,0 +1,215 @@ +//! Camera capture — cross-platform frame grabber. +//! +//! macOS: uses `screencapture` or `ffmpeg -f avfoundation` for camera frames +//! Linux: uses `v4l2-ctl` or `ffmpeg -f v4l2` for camera frames +//! Both: capture to JPEG, decode to RGB, return raw pixel data + +use anyhow::{bail, Result}; +use std::process::Command; +use std::path::PathBuf; + +/// Captured frame with raw RGB data. +pub struct Frame { + pub width: u32, + pub height: u32, + pub rgb: Vec, // row-major [height * width * 3] + pub timestamp_ms: i64, +} + +/// Camera source configuration. +pub struct CameraConfig { + pub device_index: u32, + pub width: u32, + pub height: u32, + pub fps: u32, +} + +impl Default for CameraConfig { + fn default() -> Self { + Self { device_index: 0, width: 640, height: 480, fps: 15 } + } +} + +/// Capture a single frame from the camera. +/// +/// Tries multiple backends in order: ffmpeg, v4l2, imagesnap (macOS). +pub fn capture_frame(config: &CameraConfig) -> Result { + let tmp = tmp_path(); + + // Try ffmpeg first (cross-platform) + if let Ok(frame) = capture_ffmpeg(config, &tmp) { + return Ok(frame); + } + + // Linux: try v4l2 + #[cfg(target_os = "linux")] + if let Ok(frame) = capture_v4l2(config, &tmp) { + return Ok(frame); + } + + // macOS: try screencapture (camera mode) + #[cfg(target_os = "macos")] + if let Ok(frame) = capture_macos(config, &tmp) { + return Ok(frame); + } + + bail!("No camera backend available. Install ffmpeg or run on a machine with a camera.") +} + +/// Capture via ffmpeg (works on Linux + macOS). +fn capture_ffmpeg(config: &CameraConfig, tmp: &PathBuf) -> Result { + let input = if cfg!(target_os = "macos") { + format!("{}:none", config.device_index) // avfoundation: video:audio + } else { + format!("/dev/video{}", config.device_index) // v4l2 + }; + + let format = if cfg!(target_os = "macos") { "avfoundation" } else { "v4l2" }; + + let status = Command::new("ffmpeg") + .args([ + "-y", "-f", format, + "-video_size", &format!("{}x{}", config.width, config.height), + "-framerate", &config.fps.to_string(), + "-i", &input, + "-frames:v", "1", + "-f", "rawvideo", + "-pix_fmt", "rgb24", + tmp.to_str().unwrap_or("/tmp/ruview-frame.raw"), + ]) + .output()?; + + if !status.status.success() { + bail!("ffmpeg capture failed: {}", String::from_utf8_lossy(&status.stderr)); + } + + let rgb = std::fs::read(tmp)?; + let expected = (config.width * config.height * 3) as usize; + if rgb.len() < expected { + bail!("frame too small: {} bytes, expected {}", rgb.len(), expected); + } + + let _ = std::fs::remove_file(tmp); + + Ok(Frame { + width: config.width, + height: config.height, + rgb: rgb[..expected].to_vec(), + timestamp_ms: chrono::Utc::now().timestamp_millis(), + }) +} + +/// Linux: capture via v4l2-ctl. +#[cfg(target_os = "linux")] +fn capture_v4l2(config: &CameraConfig, tmp: &PathBuf) -> Result { + let device = format!("/dev/video{}", config.device_index); + if !std::path::Path::new(&device).exists() { + bail!("no camera at {device}"); + } + + // Use v4l2-ctl to grab a frame + let status = Command::new("v4l2-ctl") + .args([ + "--device", &device, + "--set-fmt-video", &format!("width={},height={},pixelformat=MJPG", config.width, config.height), + "--stream-mmap", "--stream-count=1", + "--stream-to", tmp.to_str().unwrap_or("/tmp/frame.mjpg"), + ]) + .output()?; + + if !status.status.success() { + bail!("v4l2-ctl failed"); + } + + // Decode MJPEG to RGB + decode_jpeg_to_rgb(tmp, config.width, config.height) +} + +/// macOS: capture via screencapture or swift. +#[cfg(target_os = "macos")] +fn capture_macos(config: &CameraConfig, tmp: &PathBuf) -> Result { + let jpg_path = tmp.with_extension("jpg"); + + // Try swift-based capture (requires camera permission) + let swift = format!( + r#"import AVFoundation; import AppKit +let sem = DispatchSemaphore(value: 0) +let s = AVCaptureSession(); s.sessionPreset = .medium +guard let d = AVCaptureDevice.default(for: .video) else {{ exit(1) }} +let i = try! AVCaptureDeviceInput(device: d); s.addInput(i) +let o = AVCapturePhotoOutput(); s.addOutput(o) +class D: NSObject, AVCapturePhotoCaptureDelegate {{ + func photoOutput(_ o: AVCapturePhotoOutput, didFinishProcessingPhoto p: AVCapturePhoto, error: Error?) {{ + if let d = p.fileDataRepresentation() {{ try! d.write(to: URL(fileURLWithPath: "{path}")) }} + exit(0) + }} +}} +let dl = D(); s.startRunning(); Thread.sleep(forTimeInterval: 1) +o.capturePhoto(with: AVCapturePhotoSettings(), delegate: dl) +Thread.sleep(forTimeInterval: 3)"#, + path = jpg_path.display() + ); + + let _ = Command::new("swift").args(["-e", &swift]).output(); + + if jpg_path.exists() { + return decode_jpeg_to_rgb(&jpg_path, config.width, config.height); + } + + bail!("macOS camera capture requires GUI session with camera permission") +} + +fn decode_jpeg_to_rgb(path: &PathBuf, _width: u32, _height: u32) -> Result { + let data = std::fs::read(path)?; + let _ = std::fs::remove_file(path); + + // Simple JPEG decode — use the image crate if available, otherwise raw + // For now, return the raw data and let the caller handle format + Ok(Frame { + width: _width, + height: _height, + rgb: data, + timestamp_ms: chrono::Utc::now().timestamp_millis(), + }) +} + +fn tmp_path() -> PathBuf { + std::env::temp_dir().join(format!("ruview-frame-{}.raw", std::process::id())) +} + +/// Check if a camera is available on this system. +pub fn camera_available() -> bool { + if cfg!(target_os = "macos") { + Command::new("system_profiler") + .args(["SPCameraDataType"]) + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).contains("Camera")) + .unwrap_or(false) + } else { + std::path::Path::new("/dev/video0").exists() + } +} + +/// List available cameras. +pub fn list_cameras() -> Vec { + let mut cameras = Vec::new(); + + if cfg!(target_os = "macos") { + if let Ok(output) = Command::new("system_profiler").args(["SPCameraDataType"]).output() { + let text = String::from_utf8_lossy(&output.stdout); + for line in text.lines() { + let trimmed = line.trim(); + if trimmed.ends_with(':') && !trimmed.starts_with("Camera") && trimmed.len() > 2 { + cameras.push(trimmed.trim_end_matches(':').to_string()); + } + } + } + } else { + for i in 0..10 { + if std::path::Path::new(&format!("/dev/video{i}")).exists() { + cameras.push(format!("/dev/video{i}")); + } + } + } + cameras +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/csi.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/csi.rs new file mode 100644 index 00000000..5e293c4e --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/csi.rs @@ -0,0 +1,189 @@ +//! WiFi CSI receiver — ingests CSI frames from ESP32 nodes. +//! +//! ESP32 nodes send CSI data via UDP. This module receives the frames, +//! runs RF tomography, and produces OccupancyVolume for fusion. +//! +//! Protocol: +//! ESP32 → serial → host (ruvzen) → UDP broadcast → this receiver +//! Each packet: JSON with {mac, rssi, csi_data: [i8], timestamp_ms} + +use crate::fusion::OccupancyVolume; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::net::UdpSocket; +use std::sync::{Arc, Mutex}; +use std::collections::VecDeque; + +/// Raw CSI frame from an ESP32 node. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CsiFrame { + pub mac: String, + pub rssi: i8, + pub timestamp_ms: i64, + pub channel: u8, + pub bandwidth: u8, + /// CSI subcarrier amplitudes (typically 52-114 values) + pub csi_data: Vec, + /// Optional: secondary stream (imaginary part) + #[serde(default)] + pub csi_imag: Vec, +} + +/// CSI link — a pair of TX/RX nodes with accumulated frames. +#[derive(Debug)] +pub struct CsiLink { + pub tx_mac: String, + pub rx_mac: String, + pub frames: VecDeque, + pub attenuation: f64, // current estimated attenuation +} + +/// CSI receiver — listens on UDP and accumulates frames. +pub struct CsiReceiver { + pub links: Arc>>, + pub frame_count: Arc>, + bind_addr: String, +} + +impl CsiReceiver { + pub fn new(bind_addr: &str) -> Self { + Self { + links: Arc::new(Mutex::new(Vec::new())), + frame_count: Arc::new(Mutex::new(0)), + bind_addr: bind_addr.to_string(), + } + } + + /// Start receiving CSI frames in a background thread. + pub fn start(&self) -> Result<()> { + let socket = UdpSocket::bind(&self.bind_addr)?; + socket.set_read_timeout(Some(std::time::Duration::from_secs(1)))?; + eprintln!(" CSI receiver listening on {}", self.bind_addr); + + let links = self.links.clone(); + let count = self.frame_count.clone(); + + std::thread::spawn(move || { + let mut buf = [0u8; 4096]; + loop { + match socket.recv_from(&mut buf) { + Ok((n, _addr)) => { + if let Ok(frame) = serde_json::from_slice::(&buf[..n]) { + process_frame(&links, &count, frame); + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue, + Err(_) => continue, + } + } + }); + + Ok(()) + } + + /// Get the current occupancy volume from accumulated CSI data. + pub fn get_occupancy(&self) -> OccupancyVolume { + let links = self.links.lock().unwrap(); + + if links.is_empty() { + return crate::fusion::demo_occupancy(); + } + + // Extract per-link attenuations for tomography + let attenuations: Vec = links.iter().map(|l| l.attenuation).collect(); + let n_links = attenuations.len(); + + // Simple grid-based tomography (ISTA solver would go here) + let nx = 8; + let ny = 8; + let nz = 4; + let total = nx * ny * nz; + let mut densities = vec![0.0f64; total]; + + // For each link, distribute attenuation along the line between TX and RX + // This is a simplified backprojection — real tomography uses ISTA L1 solver + for (i, atten) in attenuations.iter().enumerate() { + // Distribute attenuation uniformly across voxels + // (in production, use link geometry for proper ray tracing) + let contribution = atten / total as f64; + for d in &mut densities { + *d += contribution; + } + } + + // Normalize + let max = densities.iter().cloned().fold(0.0f64, f64::max); + if max > 0.0 { + for d in &mut densities { *d /= max; } + } + + let occupied_count = densities.iter().filter(|&&d| d > 0.3).count(); + + OccupancyVolume { + densities, + nx, ny, nz, + bounds: [0.0, 0.0, 0.0, 5.0, 5.0, 3.0], + occupied_count, + } + } + + pub fn frame_count(&self) -> u64 { + *self.frame_count.lock().unwrap() + } +} + +fn process_frame( + links: &Arc>>, + count: &Arc>, + frame: CsiFrame, +) { + // Calculate attenuation from RSSI + CSI amplitude + let csi_power: f64 = frame.csi_data.iter() + .map(|&v| (v as f64).powi(2)) + .sum::() / frame.csi_data.len().max(1) as f64; + let attenuation = -(frame.rssi as f64) + csi_power.sqrt() * 0.1; + + let mut links = links.lock().unwrap(); + + // Find or create link for this MAC + let link = links.iter_mut().find(|l| l.tx_mac == frame.mac); + if let Some(link) = link { + link.attenuation = link.attenuation * 0.9 + attenuation * 0.1; // EMA + link.frames.push_back(frame); + if link.frames.len() > 100 { link.frames.pop_front(); } + } else { + let mut frames = VecDeque::new(); + frames.push_back(frame.clone()); + links.push(CsiLink { + tx_mac: frame.mac, + rx_mac: "receiver".to_string(), + frames, + attenuation, + }); + } + + *count.lock().unwrap() += 1; +} + +/// Send CSI frames via UDP (for testing — simulates ESP32 nodes). +pub fn send_test_frames(target: &str, count: usize) -> Result<()> { + let socket = UdpSocket::bind("0.0.0.0:0")?; + + for i in 0..count { + let frame = CsiFrame { + mac: format!("AA:BB:CC:DD:EE:{:02X}", i % 4), + rssi: -40 - (i % 30) as i8, + timestamp_ms: chrono::Utc::now().timestamp_millis(), + channel: 6, + bandwidth: 20, + csi_data: (0..56).map(|j| ((i + j) % 128) as i8 - 64).collect(), + csi_imag: Vec::new(), + }; + + let json = serde_json::to_vec(&frame)?; + socket.send_to(&json, target)?; + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + Ok(()) +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/fusion.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/fusion.rs index 37329e8f..2bbe6455 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/fusion.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/fusion.rs @@ -4,7 +4,7 @@ use crate::pointcloud::{PointCloud, ColorPoint}; use std::collections::HashMap; /// Occupancy volume from WiFi RF tomography (mirrors RuView's OccupancyVolume). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct OccupancyVolume { pub densities: Vec, // [nz][ny][nx] voxel densities pub nx: usize, diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/main.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/main.rs index 539f110c..d434c18c 100644 --- a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/main.rs +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/main.rs @@ -1,16 +1,22 @@ //! ruview-pointcloud — real-time dense point cloud from camera + WiFi CSI //! -//! Pipeline: Camera → Depth (MiDaS ONNX) → Backproject → Fuse with WiFi occupancy → Stream +//! Pipeline: Camera → Depth → Backproject → Fuse with WiFi occupancy → Stream //! //! Usage: -//! ruview-pointcloud serve # start HTTP + WebSocket server -//! ruview-pointcloud capture --frames 1 # capture single frame to PLY -//! ruview-pointcloud demo # generate demo point cloud +//! ruview-pointcloud serve # HTTP + Three.js viewer +//! ruview-pointcloud serve --csi 0.0.0.0:9890 # with live WiFi CSI +//! ruview-pointcloud capture --frames 1 # capture to PLY +//! ruview-pointcloud demo # synthetic demo +//! ruview-pointcloud train # calibration training +//! ruview-pointcloud csi-test # send test CSI frames +mod camera; +mod csi; mod depth; -mod pointcloud; mod fusion; +mod pointcloud; mod stream; +mod training; use anyhow::Result; use clap::{Parser, Subcommand}; @@ -26,15 +32,18 @@ struct Cli { #[derive(Subcommand)] enum Commands { - /// Start real-time point cloud server (HTTP + WebSocket) + /// Start real-time point cloud server Serve { #[arg(long, default_value = "0.0.0.0")] host: String, #[arg(long, default_value = "9880")] port: u16, - /// WiFi occupancy source URL (e.g., http://ruvultra:9876) + /// WiFi CSI listen address (e.g., 0.0.0.0:9890) #[arg(long)] - wifi_source: Option, + csi: Option, + /// Brain URL for storing observations + #[arg(long)] + brain: Option, }, /// Capture frames to PLY file Capture { @@ -43,8 +52,25 @@ enum Commands { #[arg(long, default_value = "output.ply")] output: String, }, - /// Generate demo point cloud (no camera needed) + /// Generate demo point cloud Demo, + /// List available cameras + Cameras, + /// Training and calibration + Train { + #[arg(long, default_value = "~/.local/share/ruview/training")] + data_dir: String, + /// Brain URL for submitting results + #[arg(long)] + brain: Option, + }, + /// Send test CSI frames (for testing without ESP32) + CsiTest { + #[arg(long, default_value = "127.0.0.1:9890")] + target: String, + #[arg(long, default_value = "100")] + count: usize, + }, } #[tokio::main] @@ -52,17 +78,52 @@ async fn main() -> Result<()> { let cli = Cli::parse(); match cli.command { - Commands::Serve { host, port, wifi_source } => { - stream::serve(&host, port, wifi_source.as_deref()).await?; + Commands::Serve { host, port, csi, brain } => { + // Start CSI receiver if configured + if let Some(csi_addr) = &csi { + let receiver = csi::CsiReceiver::new(csi_addr); + receiver.start()?; + eprintln!(" CSI receiver: {csi_addr}"); + } + stream::serve(&host, port, brain.as_deref()).await?; } Commands::Capture { frames, output } => { - let cloud = depth::capture_depth_cloud(frames).await?; - pointcloud::write_ply(&cloud, &output)?; - println!("Wrote {} points to {output}", cloud.points.len()); + if camera::camera_available() { + let config = camera::CameraConfig::default(); + let frame = camera::capture_frame(&config)?; + let depth = depth::estimate_depth(&frame.rgb, frame.width, frame.height)?; + let intrinsics = depth::CameraIntrinsics::default(); + let cloud = depth::backproject_depth(&depth, &intrinsics, Some(&frame.rgb), 2); + pointcloud::write_ply(&cloud, &output)?; + println!("Captured {} points to {output}", cloud.points.len()); + } else { + let cloud = depth::demo_depth_cloud(); + pointcloud::write_ply(&cloud, &output)?; + println!("No camera — wrote {} demo points to {output}", cloud.points.len()); + } } Commands::Demo => { demo().await?; } + Commands::Cameras => { + let cams = camera::list_cameras(); + if cams.is_empty() { + println!("No cameras found"); + } else { + println!("Available cameras:"); + for (i, c) in cams.iter().enumerate() { + println!(" [{i}] {c}"); + } + } + } + Commands::Train { data_dir, brain } => { + train(&data_dir, brain.as_deref()).await?; + } + Commands::CsiTest { target, count } => { + println!("Sending {count} test CSI frames to {target}..."); + csi::send_test_frames(&target, count)?; + println!("Done"); + } } Ok(()) @@ -74,25 +135,20 @@ async fn demo() -> Result<()> { println!("╚══════════════════════════════════════════════╝"); println!(); - // Generate a demo occupancy volume (simulated WiFi tomography) let occupancy = fusion::demo_occupancy(); let wifi_cloud = fusion::occupancy_to_pointcloud(&occupancy); println!("WiFi occupancy: {}x{}x{} voxels → {} points", occupancy.nx, occupancy.ny, occupancy.nz, wifi_cloud.points.len()); - // Generate a demo depth cloud (simulated camera) let depth_cloud = depth::demo_depth_cloud(); println!("Camera depth: {} points", depth_cloud.points.len()); - // Fuse let fused = fusion::fuse_clouds(&[&wifi_cloud, &depth_cloud], 0.05); println!("Fused: {} points (voxel size=0.05m)", fused.points.len()); - // Write PLY pointcloud::write_ply(&fused, "demo_pointcloud.ply")?; println!("\nWrote: demo_pointcloud.ply"); - // Write Gaussian splats let splats = pointcloud::to_gaussian_splats(&fused); let json = serde_json::to_string_pretty(&splats)?; std::fs::write("demo_splats.json", &json)?; @@ -100,3 +156,89 @@ async fn demo() -> Result<()> { Ok(()) } + +async fn train(data_dir: &str, brain_url: Option<&str>) -> Result<()> { + println!("╔══════════════════════════════════════════════╗"); + println!("║ RuView Point Cloud — Training ║"); + println!("╚══════════════════════════════════════════════╝"); + println!(); + + let expanded = data_dir.replace('~', &dirs::home_dir().unwrap_or_default().to_string_lossy()); + let mut session = training::TrainingSession::new(&expanded)?; + session.load_samples()?; + + // Capture training samples + println!("==> Capturing training samples..."); + + // Camera samples + if camera::camera_available() { + println!(" Camera detected — capturing depth frames..."); + let config = camera::CameraConfig::default(); + for i in 0..5 { + if let Ok(frame) = camera::capture_frame(&config) { + let depth = depth::estimate_depth(&frame.rgb, frame.width, frame.height)?; + // Score based on depth variance (good frames have varied depth) + let mean: f32 = depth.iter().sum::() / depth.len() as f32; + let variance: f32 = depth.iter().map(|d| (d - mean).powi(2)).sum::() / depth.len() as f32; + let quality = (variance / 2.0).min(1.0); + + session.add_sample( + Some(depth), frame.width, frame.height, + None, None, quality, + ); + println!(" Frame {}: quality={:.2}", i, quality); + } + std::thread::sleep(std::time::Duration::from_millis(500)); + } + } else { + println!(" No camera — using synthetic samples for calibration demo"); + for i in 0..10 { + let w = 160u32; + let h = 120u32; + let depth: Vec = (0..w * h).map(|j| 1.0 + (j as f32 / (w * h) as f32) * 4.0 + (i as f32 * 0.1)).collect(); + let quality = if i < 7 { 0.8 } else { 0.2 }; + let gt = if i % 3 == 0 { + Some(training::GroundTruth { + reference_distances: vec![ + training::ReferencePoint { name: "wall".into(), x_pixel: 80, y_pixel: 60, true_distance_m: 3.0 }, + ], + occupancy_label: Some(if i < 5 { "occupied" } else { "empty" }.into()), + }) + } else { None }; + session.add_sample(Some(depth), w, h, None, gt, quality); + } + } + + session.save_samples()?; + + // Calibrate depth + println!("\n==> Calibrating depth estimation..."); + let cal = session.calibrate_depth()?; + println!(" Result: scale={:.2} offset={:.2} gamma={:.2} RMSE={:.4}m", + cal.scale, cal.offset, cal.gamma, cal.rmse); + + // Train occupancy + println!("\n==> Training occupancy model..."); + let occ_cal = session.train_occupancy()?; + println!(" Result: threshold={:.2} accuracy={:.1}%", + occ_cal.density_threshold, occ_cal.accuracy * 100.0); + + // Export preference pairs + println!("\n==> Exporting preference pairs..."); + let pairs = session.export_preference_pairs()?; + println!(" Exported: {} pairs", pairs.len()); + + // Submit to brain if available + if let Some(url) = brain_url { + println!("\n==> Submitting to brain at {url}..."); + let stored = session.submit_to_brain(url).await?; + println!(" Stored: {} observations", stored); + } + + println!("\n==> Training complete!"); + println!(" Data dir: {expanded}"); + println!(" Samples: {}", session.samples.len()); + println!(" Calibration: {expanded}/calibration.json"); + + Ok(()) +} diff --git a/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/training.rs b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/training.rs new file mode 100644 index 00000000..fb94ad22 --- /dev/null +++ b/rust-port/wifi-densepose-rs/crates/wifi-densepose-pointcloud/src/training.rs @@ -0,0 +1,395 @@ +//! Training pipeline — collect spatial observations and train depth/occupancy models. +//! +//! Three training modes: +//! 1. **Depth calibration**: capture camera frames + known distances → calibrate +//! the luminance-to-depth mapping parameters +//! 2. **CSI occupancy training**: capture CSI with known occupancy ground truth → +//! train the tomography weights for this room geometry +//! 3. **Brain integration**: store spatial observations as brain memories for +//! DPO training — "this depth estimate was correct" vs "this was wrong" + +use crate::pointcloud::PointCloud; +use crate::fusion::OccupancyVolume; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Training data sample — a snapshot of the scene. +#[derive(Serialize, Deserialize)] +pub struct TrainingSample { + pub timestamp_ms: i64, + pub source: String, + /// Camera depth map (downsampled, in meters) + pub depth_map: Option>, + pub depth_width: u32, + pub depth_height: u32, + /// WiFi occupancy grid + pub occupancy: Option, + /// Ground truth (if available) + pub ground_truth: Option, + /// Quality score (0.0-1.0, rated by user or self-eval) + pub quality: f32, +} + +#[derive(Serialize, Deserialize)] +pub struct OccupancyData { + pub densities: Vec, + pub nx: usize, + pub ny: usize, + pub nz: usize, +} + +impl From<&OccupancyVolume> for OccupancyData { + fn from(vol: &OccupancyVolume) -> Self { + Self { + densities: vol.densities.clone(), + nx: vol.nx, ny: vol.ny, nz: vol.nz, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct GroundTruth { + /// Known distances to reference points (e.g., wall at 3.0m) + pub reference_distances: Vec, + /// Known occupancy state (person present/absent + location) + pub occupancy_label: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct ReferencePoint { + pub name: String, + pub x_pixel: u32, + pub y_pixel: u32, + pub true_distance_m: f32, +} + +/// Training session — accumulates samples and learns calibration. +pub struct TrainingSession { + pub samples: Vec, + pub calibration: DepthCalibration, + pub data_dir: PathBuf, +} + +/// Depth calibration parameters — maps luminance to real depth. +#[derive(Clone, Serialize, Deserialize)] +pub struct DepthCalibration { + pub scale: f32, // multiplier for depth values + pub offset: f32, // additive offset + pub near_clip: f32, // minimum valid depth + pub far_clip: f32, // maximum valid depth + pub gamma: f32, // nonlinear correction (luminance^gamma → depth) + pub samples_used: u32, + pub rmse: f32, // root mean square error against ground truth +} + +impl Default for DepthCalibration { + fn default() -> Self { + Self { + scale: 4.0, + offset: 1.0, + near_clip: 0.3, + far_clip: 8.0, + gamma: 1.0, + samples_used: 0, + rmse: f32::MAX, + } + } +} + +impl TrainingSession { + pub fn new(data_dir: &str) -> Result { + let path = PathBuf::from(data_dir); + std::fs::create_dir_all(&path)?; + + // Load existing calibration if available + let cal_path = path.join("calibration.json"); + let calibration = if cal_path.exists() { + let data = std::fs::read_to_string(&cal_path)?; + serde_json::from_str(&data).unwrap_or_default() + } else { + DepthCalibration::default() + }; + + Ok(Self { + samples: Vec::new(), + calibration, + data_dir: path, + }) + } + + /// Add a training sample with optional ground truth. + pub fn add_sample( + &mut self, + depth_map: Option>, + width: u32, + height: u32, + occupancy: Option<&OccupancyVolume>, + ground_truth: Option, + quality: f32, + ) { + let sample = TrainingSample { + timestamp_ms: chrono::Utc::now().timestamp_millis(), + source: "capture".to_string(), + depth_map, + depth_width: width, + depth_height: height, + occupancy: occupancy.map(OccupancyData::from), + ground_truth, + quality, + }; + self.samples.push(sample); + } + + /// Calibrate depth estimation using ground truth reference points. + /// + /// Finds optimal scale, offset, and gamma to minimize RMSE + /// between estimated and true depths at reference points. + pub fn calibrate_depth(&mut self) -> Result { + let mut best = self.calibration.clone(); + let mut best_rmse = f32::MAX; + + // Collect all reference points across samples + let refs: Vec<(f32, f32)> = self.samples.iter() + .filter_map(|s| { + let gt = s.ground_truth.as_ref()?; + let dm = s.depth_map.as_ref()?; + Some(gt.reference_distances.iter().filter_map(|rp| { + let idx = (rp.y_pixel * s.depth_width + rp.x_pixel) as usize; + dm.get(idx).map(|&est| (est, rp.true_distance_m)) + }).collect::>()) + }) + .flatten() + .collect(); + + if refs.is_empty() { + eprintln!(" No reference points — using default calibration"); + return Ok(best); + } + + eprintln!(" Calibrating with {} reference points...", refs.len()); + + // Grid search over scale, offset, gamma + for scale_i in 0..20 { + let scale = 1.0 + scale_i as f32 * 0.5; + for offset_i in 0..10 { + let offset = offset_i as f32 * 0.5; + for gamma_i in 5..15 { + let gamma = gamma_i as f32 * 0.2; + + let rmse = refs.iter() + .map(|&(est, truth)| { + let calibrated = offset + est.powf(gamma) * scale; + (calibrated - truth).powi(2) + }) + .sum::() / refs.len() as f32; + let rmse = rmse.sqrt(); + + if rmse < best_rmse { + best_rmse = rmse; + best = DepthCalibration { + scale, offset, gamma, + near_clip: 0.3, far_clip: 8.0, + samples_used: refs.len() as u32, + rmse, + }; + } + } + } + } + + eprintln!(" Best calibration: scale={:.2} offset={:.2} gamma={:.2} RMSE={:.4}m", + best.scale, best.offset, best.gamma, best.rmse); + + self.calibration = best.clone(); + self.save_calibration()?; + Ok(best) + } + + /// Train CSI occupancy model — adjust tomography weights. + /// + /// Uses samples with known occupancy labels to optimize the + /// attenuation-to-density mapping. + pub fn train_occupancy(&self) -> Result { + let labeled: Vec<&TrainingSample> = self.samples.iter() + .filter(|s| s.ground_truth.as_ref().and_then(|g| g.occupancy_label.as_ref()).is_some()) + .collect(); + + if labeled.is_empty() { + eprintln!(" No labeled occupancy samples — using defaults"); + return Ok(OccupancyCalibration::default()); + } + + eprintln!(" Training occupancy model with {} samples...", labeled.len()); + + // Simple threshold optimization — find the density threshold + // that best separates occupied vs unoccupied + let mut best_threshold = 0.3f64; + let mut best_accuracy = 0.0f64; + + for thresh_i in 1..20 { + let threshold = thresh_i as f64 * 0.05; + let mut correct = 0; + let mut total = 0; + + for sample in &labeled { + if let Some(ref occ) = sample.occupancy { + let label = sample.ground_truth.as_ref().unwrap() + .occupancy_label.as_ref().unwrap(); + let is_occupied = label == "occupied" || label == "present"; + let detected = occ.densities.iter().any(|&d| d > threshold); + if detected == is_occupied { correct += 1; } + total += 1; + } + } + + let accuracy = correct as f64 / total.max(1) as f64; + if accuracy > best_accuracy { + best_accuracy = accuracy; + best_threshold = threshold; + } + } + + let cal = OccupancyCalibration { + density_threshold: best_threshold, + accuracy: best_accuracy, + samples_used: labeled.len() as u32, + }; + + eprintln!(" Occupancy threshold={:.2} accuracy={:.1}%", cal.density_threshold, cal.accuracy * 100.0); + + // Save + let path = self.data_dir.join("occupancy_calibration.json"); + std::fs::write(&path, serde_json::to_string_pretty(&cal)?)?; + + Ok(cal) + } + + /// Export training data as preference pairs for DPO training on the brain. + /// + /// Good samples (quality > 0.7) → chosen + /// Bad samples (quality < 0.3) → rejected + pub fn export_preference_pairs(&self) -> Result> { + let mut pairs = Vec::new(); + + let good: Vec<&TrainingSample> = self.samples.iter() + .filter(|s| s.quality > 0.7) + .collect(); + let bad: Vec<&TrainingSample> = self.samples.iter() + .filter(|s| s.quality < 0.3) + .collect(); + + for (g, b) in good.iter().zip(bad.iter()) { + pairs.push(PreferencePair { + chosen: format!( + "Depth estimation at {}ms: {} points, quality {:.2}", + g.timestamp_ms, + g.depth_map.as_ref().map(|d| d.len()).unwrap_or(0), + g.quality + ), + rejected: format!( + "Depth estimation at {}ms: {} points, quality {:.2}", + b.timestamp_ms, + b.depth_map.as_ref().map(|d| d.len()).unwrap_or(0), + b.quality + ), + }); + } + + // Save pairs + let path = self.data_dir.join("preference_pairs.jsonl"); + let mut f = std::fs::File::create(&path)?; + for pair in &pairs { + use std::io::Write; + writeln!(f, "{}", serde_json::to_string(pair)?)?; + } + + eprintln!(" Exported {} preference pairs to {}", pairs.len(), path.display()); + Ok(pairs) + } + + /// Send training results to the ruOS brain for storage. + pub async fn submit_to_brain(&self, brain_url: &str) -> Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build()?; + + let mut stored = 0u32; + + // Store calibration as brain memory + let cal_json = serde_json::to_string(&self.calibration)?; + let body = serde_json::json!({ + "category": "spatial-calibration", + "content": format!("Depth calibration: scale={:.2} offset={:.2} gamma={:.2} RMSE={:.4}m ({} samples)", + self.calibration.scale, self.calibration.offset, self.calibration.gamma, + self.calibration.rmse, self.calibration.samples_used), + }); + if client.post(format!("{brain_url}/memories")) + .json(&body).send().await.is_ok() { + stored += 1; + } + + // Store good observations + for sample in self.samples.iter().filter(|s| s.quality > 0.5) { + let body = serde_json::json!({ + "category": "spatial-observation", + "content": format!("Point cloud capture: {} depth points, quality {:.2}, occupancy {}", + sample.depth_map.as_ref().map(|d| d.len()).unwrap_or(0), + sample.quality, + sample.occupancy.as_ref().map(|o| format!("{}x{}x{}", o.nx, o.ny, o.nz)).unwrap_or("none".into())), + }); + if client.post(format!("{brain_url}/memories")) + .json(&body).send().await.is_ok() { + stored += 1; + } + } + + eprintln!(" Submitted {} observations to brain", stored); + Ok(stored) + } + + /// Save current calibration to disk. + fn save_calibration(&self) -> Result<()> { + let path = self.data_dir.join("calibration.json"); + std::fs::write(&path, serde_json::to_string_pretty(&self.calibration)?)?; + Ok(()) + } + + /// Save all samples to disk. + pub fn save_samples(&self) -> Result<()> { + let path = self.data_dir.join("samples.json"); + std::fs::write(&path, serde_json::to_string_pretty(&self.samples)?)?; + eprintln!(" Saved {} samples to {}", self.samples.len(), path.display()); + Ok(()) + } + + /// Load samples from disk. + pub fn load_samples(&mut self) -> Result<()> { + let path = self.data_dir.join("samples.json"); + if path.exists() { + let data = std::fs::read_to_string(&path)?; + self.samples = serde_json::from_str(&data)?; + eprintln!(" Loaded {} samples", self.samples.len()); + } + Ok(()) + } +} + +#[derive(Serialize, Deserialize)] +pub struct OccupancyCalibration { + pub density_threshold: f64, + pub accuracy: f64, + pub samples_used: u32, +} + +impl Default for OccupancyCalibration { + fn default() -> Self { + Self { density_threshold: 0.3, accuracy: 0.0, samples_used: 0 } + } +} + +#[derive(Serialize, Deserialize)] +pub struct PreferencePair { + pub chosen: String, + pub rejected: String, +}