wifi-densepose/vendor/sublinear-time-solver/crates/temporal-neural-solver-wasm/src/lib.rs

275 lines
8.2 KiB
Rust

//! Temporal Neural Solver - Optimized WASM Implementation
//! Ultra-fast neural network inference for JavaScript/TypeScript
use wasm_bindgen::prelude::*;
use serde::{Serialize, Deserialize};
use std::sync::Arc;
// Enable console.log for debugging
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
}
macro_rules! console_log {
($($t:tt)*) => (log(&format_args!($($t)*).to_string()))
}
#[derive(Serialize, Deserialize)]
pub struct PredictionResult {
pub output: Vec<f32>,
pub latency_ns: u64,
}
#[derive(Serialize, Deserialize)]
pub struct BatchResult {
pub predictions: Vec<Vec<f32>>,
pub total_latency_ms: f64,
pub avg_latency_us: f64,
pub throughput_ops_sec: f64,
}
#[wasm_bindgen]
pub struct TemporalNeuralSolver {
// Optimized weight matrices (flattened for cache efficiency)
weights1_flat: Vec<f32>, // 128x32 = 4096 elements
weights2_flat: Vec<f32>, // 32x4 = 128 elements
bias1: Vec<f32>, // 32 elements
bias2: Vec<f32>, // 4 elements
// Temporal state for Kalman filtering
state: Vec<f32>,
covariance: Vec<f32>,
}
#[wasm_bindgen]
impl TemporalNeuralSolver {
/// Create a new solver instance
#[wasm_bindgen(constructor)]
pub fn new() -> Self {
console_error_panic_hook::set_once();
// Initialize weights with optimized patterns
let mut weights1_flat = vec![0.0f32; 128 * 32];
let mut weights2_flat = vec![0.0f32; 32 * 4];
// Xavier initialization
let scale1 = (2.0 / 128.0_f32).sqrt();
let scale2 = (2.0 / 32.0_f32).sqrt();
for i in 0..weights1_flat.len() {
weights1_flat[i] = ((i as f32 * 0.1337).sin() * scale1).tanh();
}
for i in 0..weights2_flat.len() {
weights2_flat[i] = ((i as f32 * 0.2718).cos() * scale2).tanh();
}
Self {
weights1_flat,
weights2_flat,
bias1: vec![0.01; 32],
bias2: vec![0.01; 4],
state: vec![0.0; 4],
covariance: vec![1.0; 4],
}
}
/// Single prediction with sub-microsecond target latency
#[wasm_bindgen]
pub fn predict(&mut self, input: &[f32]) -> Result<JsValue, JsValue> {
if input.len() != 128 {
return Err(JsValue::from_str("Input must be exactly 128 elements"));
}
let start = web_time::Instant::now();
// Optimized forward pass with loop unrolling
let mut hidden = [0.0f32; 32];
// Layer 1: Matrix multiply with 4x unrolling
for i in 0..32 {
let offset = i * 128;
let mut sum = self.bias1[i];
// Process 4 elements at a time
let mut j = 0;
while j < 128 {
sum += input[j] * self.weights1_flat[offset + j];
sum += input[j + 1] * self.weights1_flat[offset + j + 1];
sum += input[j + 2] * self.weights1_flat[offset + j + 2];
sum += input[j + 3] * self.weights1_flat[offset + j + 3];
j += 4;
}
// ReLU activation
hidden[i] = sum.max(0.0);
}
// Layer 2: Output layer
let mut output = [0.0f32; 4];
for i in 0..4 {
let offset = i * 32;
let mut sum = self.bias2[i];
// Unrolled by 4
let mut j = 0;
while j < 32 {
sum += hidden[j] * self.weights2_flat[offset + j];
sum += hidden[j + 1] * self.weights2_flat[offset + j + 1];
sum += hidden[j + 2] * self.weights2_flat[offset + j + 2];
sum += hidden[j + 3] * self.weights2_flat[offset + j + 3];
j += 4;
}
output[i] = sum;
}
// Apply temporal smoothing (simplified Kalman filter)
for i in 0..4 {
let innovation = output[i] - self.state[i];
let gain = self.covariance[i] / (self.covariance[i] + 0.1);
self.state[i] += gain * innovation;
self.covariance[i] *= 1.0 - gain;
output[i] = self.state[i];
}
let elapsed_nanos = start.elapsed().as_nanos() as u64;
let result = PredictionResult {
output: output.to_vec(),
latency_ns: elapsed_nanos,
};
Ok(serde_wasm_bindgen::to_value(&result)?)
}
/// Batch prediction for high throughput
#[wasm_bindgen]
pub fn predict_batch(&mut self, inputs_flat: &[f32]) -> Result<JsValue, JsValue> {
if inputs_flat.len() % 128 != 0 {
return Err(JsValue::from_str("Input length must be multiple of 128"));
}
let batch_size = inputs_flat.len() / 128;
let start = web_time::Instant::now();
let mut all_outputs = Vec::with_capacity(batch_size);
for batch_idx in 0..batch_size {
let input_offset = batch_idx * 128;
let input = &inputs_flat[input_offset..input_offset + 128];
// Inline the forward pass for maximum performance
let mut hidden = [0.0f32; 32];
// Layer 1
for i in 0..32 {
let weight_offset = i * 128;
let mut sum = self.bias1[i];
for j in 0..128 {
sum += input[j] * self.weights1_flat[weight_offset + j];
}
hidden[i] = sum.max(0.0);
}
// Layer 2
let mut output = [0.0f32; 4];
for i in 0..4 {
let weight_offset = i * 32;
let mut sum = self.bias2[i];
for j in 0..32 {
sum += hidden[j] * self.weights2_flat[weight_offset + j];
}
output[i] = sum;
}
all_outputs.push(output.to_vec());
}
let total_elapsed = start.elapsed();
let avg_latency = total_elapsed.as_secs_f64() / batch_size as f64;
let result = BatchResult {
predictions: all_outputs,
total_latency_ms: total_elapsed.as_secs_f64() * 1000.0,
avg_latency_us: avg_latency * 1_000_000.0,
throughput_ops_sec: 1.0 / avg_latency,
};
Ok(serde_wasm_bindgen::to_value(&result)?)
}
/// Reset temporal state
#[wasm_bindgen]
pub fn reset_state(&mut self) {
self.state = vec![0.0; 4];
self.covariance = vec![1.0; 4];
}
/// Get solver metadata
#[wasm_bindgen]
pub fn info(&self) -> JsValue {
let info = serde_json::json!({
"name": "Temporal Neural Solver",
"version": env!("CARGO_PKG_VERSION"),
"platform": "WebAssembly",
"optimization": "Loop-unrolled WASM",
"features": {
"temporal_filtering": true,
"kalman_smoothing": true,
"loop_unrolling": true,
"cache_optimized": true,
},
"dimensions": {
"input": 128,
"hidden": 32,
"output": 4,
},
"performance_targets": {
"latency_us": 1.0,
"throughput_ops_sec": 1_000_000,
}
});
serde_wasm_bindgen::to_value(&info).unwrap()
}
}
/// Benchmark function for performance testing
#[wasm_bindgen]
pub fn benchmark(iterations: u32) -> JsValue {
let mut solver = TemporalNeuralSolver::new();
let test_input = vec![0.5f32; 128];
let start = web_time::Instant::now();
for _ in 0..iterations {
let _ = solver.predict(&test_input);
}
let elapsed = start.elapsed();
let avg_latency = elapsed.as_secs_f64() / iterations as f64;
let result = serde_json::json!({
"iterations": iterations,
"total_time_ms": elapsed.as_secs_f64() * 1000.0,
"avg_latency_us": avg_latency * 1_000_000.0,
"throughput_ops_sec": 1.0 / avg_latency,
});
serde_wasm_bindgen::to_value(&result).unwrap()
}
/// Get version
#[wasm_bindgen]
pub fn version() -> String {
env!("CARGO_PKG_VERSION").to_string()
}
/// Initialize module
#[wasm_bindgen(start)]
pub fn main() {
console_error_panic_hook::set_once();
console_log!("⚡ Temporal Neural Solver WASM v{} initialized", env!("CARGO_PKG_VERSION"));
}