use crate::quantization::{QuantizedMlp, QuantizedWeights}; use rand::Rng; /// Quantized MLP wrapper for temporal-compare /// Provides 4x model size reduction with minimal accuracy loss pub struct QuantizedMlpBackend { quantized: Option, // Training happens in FP32, quantize after weights1: Vec, bias1: Vec, weights2: Vec, bias2: Vec, input_dim: usize, hidden_dim: usize, output_dim: usize, // Track compression stats original_size: usize, quantized_size: usize, } impl QuantizedMlpBackend { pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self { let mut rng = rand::thread_rng(); // Xavier initialization let scale1 = (2.0 / input_dim as f32).sqrt(); let scale2 = (2.0 / hidden_dim as f32).sqrt(); let weights1: Vec = (0..hidden_dim * input_dim) .map(|_| rng.gen_range(-scale1..scale1)) .collect(); let weights2: Vec = (0..output_dim * hidden_dim) .map(|_| rng.gen_range(-scale2..scale2)) .collect(); let original_size = (weights1.len() + weights2.len() + hidden_dim + output_dim) * 4; Self { quantized: None, weights1, bias1: vec![0.0; hidden_dim], weights2, bias2: vec![0.0; output_dim], input_dim, hidden_dim, output_dim, original_size, quantized_size: 0, } } /// Train in FP32 for best accuracy pub fn train(&mut self, x: &[Vec], y: &[f32], epochs: usize, lr: f32) { for epoch in 0..epochs { let mut total_loss = 0.0; for (xi, &yi) in x.iter().zip(y.iter()) { // Forward pass (FP32) let mut hidden = vec![0.0f32; self.hidden_dim]; // Layer 1 for i in 0..self.hidden_dim { let mut sum = self.bias1[i]; for j in 0..self.input_dim { sum += self.weights1[i * self.input_dim + j] * xi[j]; } hidden[i] = sum.max(0.0); // ReLU } // Layer 2 let mut output = self.bias2[0]; for i in 0..self.hidden_dim { output += self.weights2[i] * hidden[i]; } // Loss (MSE) let error = output - yi; total_loss += error * error; // Backward pass // Output layer gradients for i in 0..self.hidden_dim { self.weights2[i] -= lr * error * hidden[i]; } self.bias2[0] -= lr * error; // Hidden layer gradients for i in 0..self.hidden_dim { if hidden[i] > 0.0 { let grad = error * self.weights2[i]; for j in 0..self.input_dim { self.weights1[i * self.input_dim + j] -= lr * grad * xi[j]; } self.bias1[i] -= lr * grad; } } } if epoch % 10 == 0 { println!("Epoch {}: Loss = {:.6}", epoch, total_loss / x.len() as f32); } } // Quantize after training self.quantize(); } /// Quantize the trained FP32 model to INT8 pub fn quantize(&mut self) { let qmlp = QuantizedMlp::from_float_mlp( &self.weights1, &self.bias1, &self.weights2, &self.bias2, self.input_dim, self.hidden_dim, self.output_dim ); self.quantized_size = qmlp.model_size(); self.quantized = Some(qmlp); } /// Predict using quantized weights (fast) pub fn predict(&self, x: &[Vec]) -> Vec { match &self.quantized { Some(qmlp) => { x.iter().map(|xi| { let mut output = vec![0.0f32; self.output_dim]; #[cfg(target_arch = "x86_64")] { qmlp.forward_avx2(xi, &mut output); } #[cfg(not(target_arch = "x86_64"))] { qmlp.forward(xi, &mut output); } output[0] }).collect() } None => { // Fallback to FP32 if not quantized self.predict_fp32(x) } } } /// Predict using FP32 weights (for comparison) pub fn predict_fp32(&self, x: &[Vec]) -> Vec { x.iter().map(|xi| { let mut hidden = vec![0.0f32; self.hidden_dim]; // Layer 1 for i in 0..self.hidden_dim { let mut sum = self.bias1[i]; for j in 0..self.input_dim { sum += self.weights1[i * self.input_dim + j] * xi[j]; } hidden[i] = sum.max(0.0); } // Layer 2 let mut output = self.bias2[0]; for i in 0..self.hidden_dim { output += self.weights2[i] * hidden[i]; } output }).collect() } /// Classification prediction pub fn predict_class(&self, x: &[Vec]) -> Vec { self.predict(x).iter().map(|&y| { if y < -0.25 { 0 } else if y > 0.25 { 2 } else { 1 } }).collect() } /// Get compression statistics pub fn get_compression_stats(&self) -> (usize, usize, f32) { let ratio = if self.quantized_size > 0 { self.original_size as f32 / self.quantized_size as f32 } else { 1.0 }; (self.original_size, self.quantized_size, ratio) } /// Compare FP32 vs INT8 performance pub fn benchmark_inference(&self, x: &[Vec], iterations: usize) { use std::time::Instant; // Warm up let _ = self.predict_fp32(&x[..1.min(x.len())]); let _ = self.predict(&x[..1.min(x.len())]); // Benchmark FP32 let start = Instant::now(); for _ in 0..iterations { let _ = self.predict_fp32(x); } let fp32_time = start.elapsed(); // Benchmark INT8 let start = Instant::now(); for _ in 0..iterations { let _ = self.predict(x); } let int8_time = start.elapsed(); let speedup = fp32_time.as_secs_f32() / int8_time.as_secs_f32(); println!("\n=== Quantization Benchmark ==="); println!("FP32 time: {:.3}s", fp32_time.as_secs_f32()); println!("INT8 time: {:.3}s", int8_time.as_secs_f32()); println!("Speedup: {:.2}x", speedup); let (orig, quant, ratio) = self.get_compression_stats(); println!("\n=== Model Size ==="); println!("Original: {} bytes", orig); println!("Quantized: {} bytes", quant); println!("Compression: {:.2}x", ratio); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_quantized_mlp() { let mut model = QuantizedMlpBackend::new(32, 64, 1); // Create dummy data let x: Vec> = (0..100) .map(|_| (0..32).map(|_| rand::random()).collect()) .collect(); let y: Vec = (0..100).map(|_| rand::random()).collect(); // Train and quantize model.train(&x, &y, 10, 0.01); // Check predictions work let pred_fp32 = model.predict_fp32(&x[..10]); let pred_int8 = model.predict(&x[..10]); // Should be similar but not identical for (p32, p8) in pred_fp32.iter().zip(&pred_int8) { let diff = (p32 - p8).abs(); assert!(diff < 0.1, "Quantization error too large: {}", diff); } // Check compression let (_, _, ratio) = model.get_compression_stats(); assert!(ratio > 3.0, "Compression ratio too low: {}", ratio); } }