//! Temporal Neural Solver - Optimized WASM Implementation
//! Ultra-fast neural network inference for JavaScript/TypeScript
|
|
|
|
use wasm_bindgen::prelude::*;
|
|
use serde::{Serialize, Deserialize};
|
|
use std::sync::Arc;
|
|
|
|
// Enable console.log for debugging
|
|
#[wasm_bindgen]
|
|
extern "C" {
|
|
#[wasm_bindgen(js_namespace = console)]
|
|
fn log(s: &str);
|
|
}
|
|
|
|
/// `println!`-style logging: formats the arguments into a `String` and passes
/// it to the `log` function that is in scope at the call site.
macro_rules! console_log {
    ($($arg:tt)*) => {
        log(&format!($($arg)*))
    };
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
pub struct PredictionResult {
|
|
pub output: Vec<f32>,
|
|
pub latency_ns: u64,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
pub struct BatchResult {
|
|
pub predictions: Vec<Vec<f32>>,
|
|
pub total_latency_ms: f64,
|
|
pub avg_latency_us: f64,
|
|
pub throughput_ops_sec: f64,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
pub struct TemporalNeuralSolver {
|
|
// Optimized weight matrices (flattened for cache efficiency)
|
|
weights1_flat: Vec<f32>, // 128x32 = 4096 elements
|
|
weights2_flat: Vec<f32>, // 32x4 = 128 elements
|
|
bias1: Vec<f32>, // 32 elements
|
|
bias2: Vec<f32>, // 4 elements
|
|
|
|
// Temporal state for Kalman filtering
|
|
state: Vec<f32>,
|
|
covariance: Vec<f32>,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl TemporalNeuralSolver {
|
|
/// Create a new solver instance
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new() -> Self {
|
|
console_error_panic_hook::set_once();
|
|
|
|
// Initialize weights with optimized patterns
|
|
let mut weights1_flat = vec![0.0f32; 128 * 32];
|
|
let mut weights2_flat = vec![0.0f32; 32 * 4];
|
|
|
|
// Xavier initialization
|
|
let scale1 = (2.0 / 128.0_f32).sqrt();
|
|
let scale2 = (2.0 / 32.0_f32).sqrt();
|
|
|
|
for i in 0..weights1_flat.len() {
|
|
weights1_flat[i] = ((i as f32 * 0.1337).sin() * scale1).tanh();
|
|
}
|
|
|
|
for i in 0..weights2_flat.len() {
|
|
weights2_flat[i] = ((i as f32 * 0.2718).cos() * scale2).tanh();
|
|
}
|
|
|
|
Self {
|
|
weights1_flat,
|
|
weights2_flat,
|
|
bias1: vec![0.01; 32],
|
|
bias2: vec![0.01; 4],
|
|
state: vec![0.0; 4],
|
|
covariance: vec![1.0; 4],
|
|
}
|
|
}
|
|
|
|
/// Single prediction with sub-microsecond target latency
|
|
#[wasm_bindgen]
|
|
pub fn predict(&mut self, input: &[f32]) -> Result<JsValue, JsValue> {
|
|
if input.len() != 128 {
|
|
return Err(JsValue::from_str("Input must be exactly 128 elements"));
|
|
}
|
|
|
|
let start = web_time::Instant::now();
|
|
|
|
// Optimized forward pass with loop unrolling
|
|
let mut hidden = [0.0f32; 32];
|
|
|
|
// Layer 1: Matrix multiply with 4x unrolling
|
|
for i in 0..32 {
|
|
let offset = i * 128;
|
|
let mut sum = self.bias1[i];
|
|
|
|
// Process 4 elements at a time
|
|
let mut j = 0;
|
|
while j < 128 {
|
|
sum += input[j] * self.weights1_flat[offset + j];
|
|
sum += input[j + 1] * self.weights1_flat[offset + j + 1];
|
|
sum += input[j + 2] * self.weights1_flat[offset + j + 2];
|
|
sum += input[j + 3] * self.weights1_flat[offset + j + 3];
|
|
j += 4;
|
|
}
|
|
|
|
// ReLU activation
|
|
hidden[i] = sum.max(0.0);
|
|
}
|
|
|
|
// Layer 2: Output layer
|
|
let mut output = [0.0f32; 4];
|
|
for i in 0..4 {
|
|
let offset = i * 32;
|
|
let mut sum = self.bias2[i];
|
|
|
|
// Unrolled by 4
|
|
let mut j = 0;
|
|
while j < 32 {
|
|
sum += hidden[j] * self.weights2_flat[offset + j];
|
|
sum += hidden[j + 1] * self.weights2_flat[offset + j + 1];
|
|
sum += hidden[j + 2] * self.weights2_flat[offset + j + 2];
|
|
sum += hidden[j + 3] * self.weights2_flat[offset + j + 3];
|
|
j += 4;
|
|
}
|
|
|
|
output[i] = sum;
|
|
}
|
|
|
|
// Apply temporal smoothing (simplified Kalman filter)
|
|
for i in 0..4 {
|
|
let innovation = output[i] - self.state[i];
|
|
let gain = self.covariance[i] / (self.covariance[i] + 0.1);
|
|
self.state[i] += gain * innovation;
|
|
self.covariance[i] *= 1.0 - gain;
|
|
output[i] = self.state[i];
|
|
}
|
|
|
|
let elapsed_nanos = start.elapsed().as_nanos() as u64;
|
|
|
|
let result = PredictionResult {
|
|
output: output.to_vec(),
|
|
latency_ns: elapsed_nanos,
|
|
};
|
|
|
|
Ok(serde_wasm_bindgen::to_value(&result)?)
|
|
}
|
|
|
|
/// Batch prediction for high throughput
|
|
#[wasm_bindgen]
|
|
pub fn predict_batch(&mut self, inputs_flat: &[f32]) -> Result<JsValue, JsValue> {
|
|
if inputs_flat.len() % 128 != 0 {
|
|
return Err(JsValue::from_str("Input length must be multiple of 128"));
|
|
}
|
|
|
|
let batch_size = inputs_flat.len() / 128;
|
|
let start = web_time::Instant::now();
|
|
let mut all_outputs = Vec::with_capacity(batch_size);
|
|
|
|
for batch_idx in 0..batch_size {
|
|
let input_offset = batch_idx * 128;
|
|
let input = &inputs_flat[input_offset..input_offset + 128];
|
|
|
|
// Inline the forward pass for maximum performance
|
|
let mut hidden = [0.0f32; 32];
|
|
|
|
// Layer 1
|
|
for i in 0..32 {
|
|
let weight_offset = i * 128;
|
|
let mut sum = self.bias1[i];
|
|
for j in 0..128 {
|
|
sum += input[j] * self.weights1_flat[weight_offset + j];
|
|
}
|
|
hidden[i] = sum.max(0.0);
|
|
}
|
|
|
|
// Layer 2
|
|
let mut output = [0.0f32; 4];
|
|
for i in 0..4 {
|
|
let weight_offset = i * 32;
|
|
let mut sum = self.bias2[i];
|
|
for j in 0..32 {
|
|
sum += hidden[j] * self.weights2_flat[weight_offset + j];
|
|
}
|
|
output[i] = sum;
|
|
}
|
|
|
|
all_outputs.push(output.to_vec());
|
|
}
|
|
|
|
let total_elapsed = start.elapsed();
|
|
let avg_latency = total_elapsed.as_secs_f64() / batch_size as f64;
|
|
|
|
let result = BatchResult {
|
|
predictions: all_outputs,
|
|
total_latency_ms: total_elapsed.as_secs_f64() * 1000.0,
|
|
avg_latency_us: avg_latency * 1_000_000.0,
|
|
throughput_ops_sec: 1.0 / avg_latency,
|
|
};
|
|
|
|
Ok(serde_wasm_bindgen::to_value(&result)?)
|
|
}
|
|
|
|
/// Reset temporal state
|
|
#[wasm_bindgen]
|
|
pub fn reset_state(&mut self) {
|
|
self.state = vec![0.0; 4];
|
|
self.covariance = vec![1.0; 4];
|
|
}
|
|
|
|
/// Get solver metadata
|
|
#[wasm_bindgen]
|
|
pub fn info(&self) -> JsValue {
|
|
let info = serde_json::json!({
|
|
"name": "Temporal Neural Solver",
|
|
"version": env!("CARGO_PKG_VERSION"),
|
|
"platform": "WebAssembly",
|
|
"optimization": "Loop-unrolled WASM",
|
|
"features": {
|
|
"temporal_filtering": true,
|
|
"kalman_smoothing": true,
|
|
"loop_unrolling": true,
|
|
"cache_optimized": true,
|
|
},
|
|
"dimensions": {
|
|
"input": 128,
|
|
"hidden": 32,
|
|
"output": 4,
|
|
},
|
|
"performance_targets": {
|
|
"latency_us": 1.0,
|
|
"throughput_ops_sec": 1_000_000,
|
|
}
|
|
});
|
|
|
|
serde_wasm_bindgen::to_value(&info).unwrap()
|
|
}
|
|
}
|
|
|
|
/// Benchmark function for performance testing
|
|
#[wasm_bindgen]
|
|
pub fn benchmark(iterations: u32) -> JsValue {
|
|
let mut solver = TemporalNeuralSolver::new();
|
|
let test_input = vec![0.5f32; 128];
|
|
|
|
let start = web_time::Instant::now();
|
|
|
|
for _ in 0..iterations {
|
|
let _ = solver.predict(&test_input);
|
|
}
|
|
|
|
let elapsed = start.elapsed();
|
|
let avg_latency = elapsed.as_secs_f64() / iterations as f64;
|
|
|
|
let result = serde_json::json!({
|
|
"iterations": iterations,
|
|
"total_time_ms": elapsed.as_secs_f64() * 1000.0,
|
|
"avg_latency_us": avg_latency * 1_000_000.0,
|
|
"throughput_ops_sec": 1.0 / avg_latency,
|
|
});
|
|
|
|
serde_wasm_bindgen::to_value(&result).unwrap()
|
|
}
|
|
|
|
/// Get version
|
|
#[wasm_bindgen]
|
|
pub fn version() -> String {
|
|
env!("CARGO_PKG_VERSION").to_string()
|
|
}
|
|
|
|
/// Initialize module
|
|
#[wasm_bindgen(start)]
|
|
pub fn main() {
|
|
console_error_panic_hook::set_once();
|
|
console_log!("⚡ Temporal Neural Solver WASM v{} initialized", env!("CARGO_PKG_VERSION"));
|
|
} |