//! Temporal Neural Solver - Optimized WASM Implementation //! Ultra-fast neural network inference for JavaScript/TypeScript use wasm_bindgen::prelude::*; use serde::{Serialize, Deserialize}; use std::sync::Arc; // Enable console.log for debugging #[wasm_bindgen] extern "C" { #[wasm_bindgen(js_namespace = console)] fn log(s: &str); } macro_rules! console_log { ($($t:tt)*) => (log(&format_args!($($t)*).to_string())) } #[derive(Serialize, Deserialize)] pub struct PredictionResult { pub output: Vec, pub latency_ns: u64, } #[derive(Serialize, Deserialize)] pub struct BatchResult { pub predictions: Vec>, pub total_latency_ms: f64, pub avg_latency_us: f64, pub throughput_ops_sec: f64, } #[wasm_bindgen] pub struct TemporalNeuralSolver { // Optimized weight matrices (flattened for cache efficiency) weights1_flat: Vec, // 128x32 = 4096 elements weights2_flat: Vec, // 32x4 = 128 elements bias1: Vec, // 32 elements bias2: Vec, // 4 elements // Temporal state for Kalman filtering state: Vec, covariance: Vec, } #[wasm_bindgen] impl TemporalNeuralSolver { /// Create a new solver instance #[wasm_bindgen(constructor)] pub fn new() -> Self { console_error_panic_hook::set_once(); // Initialize weights with optimized patterns let mut weights1_flat = vec![0.0f32; 128 * 32]; let mut weights2_flat = vec![0.0f32; 32 * 4]; // Xavier initialization let scale1 = (2.0 / 128.0_f32).sqrt(); let scale2 = (2.0 / 32.0_f32).sqrt(); for i in 0..weights1_flat.len() { weights1_flat[i] = ((i as f32 * 0.1337).sin() * scale1).tanh(); } for i in 0..weights2_flat.len() { weights2_flat[i] = ((i as f32 * 0.2718).cos() * scale2).tanh(); } Self { weights1_flat, weights2_flat, bias1: vec![0.01; 32], bias2: vec![0.01; 4], state: vec![0.0; 4], covariance: vec![1.0; 4], } } /// Single prediction with sub-microsecond target latency #[wasm_bindgen] pub fn predict(&mut self, input: &[f32]) -> Result { if input.len() != 128 { return Err(JsValue::from_str("Input must be exactly 128 elements")); } let start = web_time::Instant::now(); // Optimized forward pass with loop unrolling let mut hidden = [0.0f32; 32]; // Layer 1: Matrix multiply with 4x unrolling for i in 0..32 { let offset = i * 128; let mut sum = self.bias1[i]; // Process 4 elements at a time let mut j = 0; while j < 128 { sum += input[j] * self.weights1_flat[offset + j]; sum += input[j + 1] * self.weights1_flat[offset + j + 1]; sum += input[j + 2] * self.weights1_flat[offset + j + 2]; sum += input[j + 3] * self.weights1_flat[offset + j + 3]; j += 4; } // ReLU activation hidden[i] = sum.max(0.0); } // Layer 2: Output layer let mut output = [0.0f32; 4]; for i in 0..4 { let offset = i * 32; let mut sum = self.bias2[i]; // Unrolled by 4 let mut j = 0; while j < 32 { sum += hidden[j] * self.weights2_flat[offset + j]; sum += hidden[j + 1] * self.weights2_flat[offset + j + 1]; sum += hidden[j + 2] * self.weights2_flat[offset + j + 2]; sum += hidden[j + 3] * self.weights2_flat[offset + j + 3]; j += 4; } output[i] = sum; } // Apply temporal smoothing (simplified Kalman filter) for i in 0..4 { let innovation = output[i] - self.state[i]; let gain = self.covariance[i] / (self.covariance[i] + 0.1); self.state[i] += gain * innovation; self.covariance[i] *= 1.0 - gain; output[i] = self.state[i]; } let elapsed_nanos = start.elapsed().as_nanos() as u64; let result = PredictionResult { output: output.to_vec(), latency_ns: elapsed_nanos, }; Ok(serde_wasm_bindgen::to_value(&result)?) } /// Batch prediction for high throughput #[wasm_bindgen] pub fn predict_batch(&mut self, inputs_flat: &[f32]) -> Result { if inputs_flat.len() % 128 != 0 { return Err(JsValue::from_str("Input length must be multiple of 128")); } let batch_size = inputs_flat.len() / 128; let start = web_time::Instant::now(); let mut all_outputs = Vec::with_capacity(batch_size); for batch_idx in 0..batch_size { let input_offset = batch_idx * 128; let input = &inputs_flat[input_offset..input_offset + 128]; // Inline the forward pass for maximum performance let mut hidden = [0.0f32; 32]; // Layer 1 for i in 0..32 { let weight_offset = i * 128; let mut sum = self.bias1[i]; for j in 0..128 { sum += input[j] * self.weights1_flat[weight_offset + j]; } hidden[i] = sum.max(0.0); } // Layer 2 let mut output = [0.0f32; 4]; for i in 0..4 { let weight_offset = i * 32; let mut sum = self.bias2[i]; for j in 0..32 { sum += hidden[j] * self.weights2_flat[weight_offset + j]; } output[i] = sum; } all_outputs.push(output.to_vec()); } let total_elapsed = start.elapsed(); let avg_latency = total_elapsed.as_secs_f64() / batch_size as f64; let result = BatchResult { predictions: all_outputs, total_latency_ms: total_elapsed.as_secs_f64() * 1000.0, avg_latency_us: avg_latency * 1_000_000.0, throughput_ops_sec: 1.0 / avg_latency, }; Ok(serde_wasm_bindgen::to_value(&result)?) } /// Reset temporal state #[wasm_bindgen] pub fn reset_state(&mut self) { self.state = vec![0.0; 4]; self.covariance = vec![1.0; 4]; } /// Get solver metadata #[wasm_bindgen] pub fn info(&self) -> JsValue { let info = serde_json::json!({ "name": "Temporal Neural Solver", "version": env!("CARGO_PKG_VERSION"), "platform": "WebAssembly", "optimization": "Loop-unrolled WASM", "features": { "temporal_filtering": true, "kalman_smoothing": true, "loop_unrolling": true, "cache_optimized": true, }, "dimensions": { "input": 128, "hidden": 32, "output": 4, }, "performance_targets": { "latency_us": 1.0, "throughput_ops_sec": 1_000_000, } }); serde_wasm_bindgen::to_value(&info).unwrap() } } /// Benchmark function for performance testing #[wasm_bindgen] pub fn benchmark(iterations: u32) -> JsValue { let mut solver = TemporalNeuralSolver::new(); let test_input = vec![0.5f32; 128]; let start = web_time::Instant::now(); for _ in 0..iterations { let _ = solver.predict(&test_input); } let elapsed = start.elapsed(); let avg_latency = elapsed.as_secs_f64() / iterations as f64; let result = serde_json::json!({ "iterations": iterations, "total_time_ms": elapsed.as_secs_f64() * 1000.0, "avg_latency_us": avg_latency * 1_000_000.0, "throughput_ops_sec": 1.0 / avg_latency, }); serde_wasm_bindgen::to_value(&result).unwrap() } /// Get version #[wasm_bindgen] pub fn version() -> String { env!("CARGO_PKG_VERSION").to_string() } /// Initialize module #[wasm_bindgen(start)] pub fn main() { console_error_panic_hook::set_once(); console_log!("⚡ Temporal Neural Solver WASM v{} initialized", env!("CARGO_PKG_VERSION")); }