wifi-densepose/vendor/sublinear-time-solver/crates/neural-network-implementation/src/inference/mod.rs

722 lines
24 KiB
Rust

//! High-performance inference engine for temporal neural networks
//!
//! This module provides optimized inference capabilities with sub-millisecond
//! latency guarantees and comprehensive performance monitoring.
use crate::{
config::{Config, InferenceConfig},
error::{Result, TemporalNeuralError},
models::{ModelTrait, SystemA, SystemB},
solvers::Certificate,
};
use nalgebra::{DMatrix, DVector};
use serde::{Deserialize, Serialize};
use std::time::Instant;
pub mod quantization;
pub mod simd_ops;
pub mod memory_pool;
pub use quantization::{QuantizedInference, Int8Quantizer};
pub use simd_ops::{SimdAccelerator, VectorOps};
pub use memory_pool::{MemoryPool, PreallocatedBuffer};
/// High-performance predictor with latency guarantees
pub struct Predictor {
/// Model being used for prediction
model: PredictorModel,
/// Inference configuration
config: InferenceConfig,
/// Performance monitor
monitor: PerformanceMonitor,
/// Memory pool for zero-allocation inference
memory_pool: MemoryPool,
/// SIMD accelerator
simd_accelerator: SimdAccelerator,
/// Quantization engine (if enabled)
quantizer: Option<Int8Quantizer>,
/// Inference statistics
stats: InferenceStatistics,
}
/// Model wrapper for unified inference interface
enum PredictorModel {
/// System A (traditional)
SystemA(SystemA),
/// System B (temporal solver)
SystemB(SystemB),
}
/// Prediction result with comprehensive metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
/// Predicted values
pub values: DVector<f64>,
/// Confidence score (0.0 to 1.0)
pub confidence: f64,
/// Prediction latency in microseconds
pub latency_us: f64,
/// Certificate (for System B)
pub certificate: Option<Certificate>,
/// Prediction metadata
pub metadata: PredictionMetadata,
}
/// Metadata about the prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionMetadata {
/// Model type used
pub model_type: String,
/// Prediction timestamp
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Input data quality score
pub input_quality: f64,
/// Whether quantization was used
pub quantized: bool,
/// Whether SIMD was used
pub simd_used: bool,
/// Memory usage for this prediction
pub memory_used_bytes: usize,
/// Detailed timing breakdown
pub timing: TimingBreakdown,
}
/// Detailed timing breakdown for performance analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimingBreakdown {
/// Input preprocessing time (microseconds)
pub preprocessing_us: f64,
/// Core model inference time (microseconds)
pub inference_us: f64,
/// Post-processing time (microseconds)
pub postprocessing_us: f64,
/// Solver verification time (microseconds, System B only)
pub verification_us: Option<f64>,
/// Memory allocation time (microseconds)
pub allocation_us: f64,
/// Total time (microseconds)
pub total_us: f64,
}
/// Performance monitoring and latency tracking
#[derive(Debug)]
struct PerformanceMonitor {
/// Recent latency measurements
recent_latencies: Vec<f64>,
/// Maximum number of recent measurements to keep
max_recent: usize,
/// Target latency threshold
target_latency_us: f64,
/// Latency violations counter
violations: u64,
/// Total predictions made
total_predictions: u64,
}
/// Comprehensive inference statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceStatistics {
/// Total predictions made
pub total_predictions: u64,
/// Average latency in microseconds
pub avg_latency_us: f64,
/// P50 latency in microseconds
pub p50_latency_us: f64,
/// P99 latency in microseconds
pub p99_latency_us: f64,
/// P99.9 latency in microseconds
pub p99_9_latency_us: f64,
/// Maximum latency observed
pub max_latency_us: f64,
/// Minimum latency observed
pub min_latency_us: f64,
/// Latency target violations
pub latency_violations: u64,
/// Latency violation rate
pub violation_rate: f64,
/// Average throughput (predictions per second)
pub throughput_pred_per_sec: f64,
/// Memory usage statistics
pub memory_stats: MemoryStatistics,
/// System B specific statistics
pub system_b_stats: Option<SystemBInferenceStats>,
}
/// Memory usage statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStatistics {
/// Current memory usage in bytes
pub current_usage_bytes: usize,
/// Peak memory usage in bytes
pub peak_usage_bytes: usize,
/// Average memory per prediction
pub avg_memory_per_prediction: f64,
/// Memory pool utilization
pub pool_utilization: f64,
}
/// System B specific inference statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemBInferenceStats {
/// Gate pass rate
pub gate_pass_rate: f64,
/// Average certificate error
pub avg_certificate_error: f64,
/// Fallback usage rate
pub fallback_rate: f64,
/// Average solver work performed
pub avg_solver_work: f64,
}
impl Predictor {
/// Create a new predictor from a trained model
pub fn new_system_a(model: SystemA, config: InferenceConfig) -> Result<Self> {
let monitor = PerformanceMonitor::new(config.target_latency_ms * 1000.0);
let memory_pool = MemoryPool::new(1024 * 1024)?; // 1MB pool
let simd_accelerator = SimdAccelerator::new(config.enable_simd);
let quantizer = if config.enable_simd {
Some(Int8Quantizer::new()?)
} else {
None
};
Ok(Self {
model: PredictorModel::SystemA(model),
config,
monitor,
memory_pool,
simd_accelerator,
quantizer,
stats: InferenceStatistics::new(),
})
}
/// Create a new predictor for System B
pub fn new_system_b(mut model: SystemB, config: InferenceConfig) -> Result<Self> {
// Prepare model for inference
model.prepare_for_inference()?;
let monitor = PerformanceMonitor::new(config.target_latency_ms * 1000.0);
let memory_pool = MemoryPool::new(2 * 1024 * 1024)?; // 2MB pool for System B
let simd_accelerator = SimdAccelerator::new(config.enable_simd);
let quantizer = if config.enable_simd {
Some(Int8Quantizer::new()?)
} else {
None
};
Ok(Self {
model: PredictorModel::SystemB(model),
config,
monitor,
memory_pool,
simd_accelerator,
quantizer,
stats: InferenceStatistics::new(),
})
}
/// Perform prediction with comprehensive monitoring
pub fn predict(&mut self, input: &DMatrix<f64>) -> Result<Prediction> {
let start_time = Instant::now();
// Pre-allocate memory from pool
let _buffer = self.memory_pool.acquire()?;
let allocation_time = start_time.elapsed().as_micros() as f64;
// Validate input
self.validate_input(input)?;
// Preprocessing
let preprocessing_start = Instant::now();
let processed_input = self.preprocess_input(input)?;
let preprocessing_time = preprocessing_start.elapsed().as_micros() as f64;
// Core inference
let inference_start = Instant::now();
let (prediction_values, certificate) = match &mut self.model {
PredictorModel::SystemA(model) => {
let pred = model.forward(&processed_input)?;
(pred, None)
}
PredictorModel::SystemB(model) => {
let pred_result = model.predict_with_solver(&processed_input)?;
// Create certificate from gate result
let certificate = Certificate {
error_bound: pred_result.gate_result.certificate_error,
confidence: pred_result.gate_result.confidence,
work_performed: pred_result.gate_result.work_performed,
algorithm: "temporal_solver".to_string(),
is_valid: pred_result.gate_result.passed,
metadata: crate::solvers::CertificateMetadata {
condition_number: None,
diagonally_dominant: false,
iterations: 0,
residual_norm: pred_result.gate_result.certificate_error,
computation_time_us: pred_result.gate_result.verification_time_us,
},
};
(pred_result.prediction, Some(certificate))
}
};
let inference_time = inference_start.elapsed().as_micros() as f64;
// Post-processing
let postprocessing_start = Instant::now();
let final_prediction = self.postprocess_prediction(&prediction_values)?;
let postprocessing_time = postprocessing_start.elapsed().as_micros() as f64;
let total_time = start_time.elapsed().as_micros() as f64;
// Update performance monitoring
self.monitor.record_latency(total_time);
// Compute confidence score
let confidence = self.compute_confidence(&final_prediction, certificate.as_ref());
// Create timing breakdown
let timing = TimingBreakdown {
preprocessing_us: preprocessing_time,
inference_us: inference_time,
postprocessing_us: postprocessing_time,
verification_us: certificate.as_ref().map(|c| c.metadata.computation_time_us),
allocation_us: allocation_time,
total_us: total_time,
};
// Create metadata
let metadata = PredictionMetadata {
model_type: match &self.model {
PredictorModel::SystemA(_) => "SystemA".to_string(),
PredictorModel::SystemB(_) => "SystemB".to_string(),
},
timestamp: chrono::Utc::now(),
input_quality: self.assess_input_quality(input),
quantized: self.quantizer.is_some(),
simd_used: self.config.enable_simd,
memory_used_bytes: self.memory_pool.current_usage(),
timing,
};
// Update statistics
self.update_statistics(total_time, &metadata, certificate.as_ref());
// Check latency constraints
if total_time > self.config.target_latency_ms * 1000.0 {
log::warn!(
"Latency constraint violated: {:.2}μs > {:.2}μs",
total_time, self.config.target_latency_ms * 1000.0
);
}
Ok(Prediction {
values: final_prediction,
confidence,
latency_us: total_time,
certificate,
metadata,
})
}
/// Batch prediction for higher throughput
pub fn predict_batch(&mut self, inputs: &[DMatrix<f64>]) -> Result<Vec<Prediction>> {
let mut predictions = Vec::with_capacity(inputs.len());
for input in inputs {
let prediction = self.predict(input)?;
predictions.push(prediction);
}
Ok(predictions)
}
/// Validate input data
fn validate_input(&self, input: &DMatrix<f64>) -> Result<()> {
let expected_shape = match &self.model {
PredictorModel::SystemA(model) => model.input_shape(),
PredictorModel::SystemB(model) => model.input_shape(),
};
let actual_shape = (input.nrows(), input.ncols());
if actual_shape != expected_shape {
return Err(TemporalNeuralError::InferenceError {
message: format!(
"Input shape mismatch: expected {:?}, got {:?}",
expected_shape, actual_shape
),
input_shape: Some(vec![actual_shape.0, actual_shape.1]),
latency_exceeded: false,
});
}
// Check for invalid values
for &val in input.iter() {
if !val.is_finite() {
return Err(TemporalNeuralError::InferenceError {
message: "Input contains invalid values (NaN or Inf)".to_string(),
input_shape: Some(vec![actual_shape.0, actual_shape.1]),
latency_exceeded: false,
});
}
}
Ok(())
}
/// Preprocess input for inference
fn preprocess_input(&self, input: &DMatrix<f64>) -> Result<DMatrix<f64>> {
// Apply SIMD optimizations if available
if self.config.enable_simd {
self.simd_accelerator.optimize_matrix(input)
} else {
Ok(input.clone())
}
}
/// Post-process prediction results
fn postprocess_prediction(&self, prediction: &DVector<f64>) -> Result<DVector<f64>> {
// Apply any final transformations
Ok(prediction.clone())
}
/// Compute confidence score for prediction
fn compute_confidence(&self, prediction: &DVector<f64>, certificate: Option<&Certificate>) -> f64 {
match certificate {
Some(cert) => cert.confidence,
None => {
// For System A, use prediction magnitude as rough confidence measure
let mag = prediction.norm();
if mag < 10.0 { 0.9 } else { 0.7 } // Simple heuristic
}
}
}
/// Assess input data quality
fn assess_input_quality(&self, input: &DMatrix<f64>) -> f64 {
// Check for reasonable variance and no outliers
let mut quality: f64 = 1.0;
for i in 0..input.nrows() {
let row_data: Vec<f64> = input.row(i).iter().cloned().collect();
let mean = row_data.iter().sum::<f64>() / row_data.len() as f64;
let variance = row_data.iter()
.map(|x| (x - mean).powi(2))
.sum::<f64>() / row_data.len() as f64;
// Penalize very low variance (flat signals)
if variance < 1e-6 {
quality *= 0.5;
}
// Penalize extreme values
for &val in &row_data {
if val.abs() > 100.0 {
quality *= 0.8;
}
}
}
quality.clamp(0.0, 1.0)
}
/// Update inference statistics
fn update_statistics(&mut self, latency_us: f64, metadata: &PredictionMetadata, certificate: Option<&Certificate>) {
self.stats.total_predictions += 1;
// Update latency statistics
let prev_avg = self.stats.avg_latency_us;
let n = self.stats.total_predictions as f64;
self.stats.avg_latency_us = (prev_avg * (n - 1.0) + latency_us) / n;
// Update min/max
if latency_us > self.stats.max_latency_us {
self.stats.max_latency_us = latency_us;
}
if latency_us < self.stats.min_latency_us || self.stats.min_latency_us == 0.0 {
self.stats.min_latency_us = latency_us;
}
// Update memory statistics
self.stats.memory_stats.current_usage_bytes = metadata.memory_used_bytes;
if metadata.memory_used_bytes > self.stats.memory_stats.peak_usage_bytes {
self.stats.memory_stats.peak_usage_bytes = metadata.memory_used_bytes;
}
// Update System B specific stats
if let Some(cert) = certificate {
if self.stats.system_b_stats.is_none() {
self.stats.system_b_stats = Some(SystemBInferenceStats {
gate_pass_rate: 0.0,
avg_certificate_error: 0.0,
fallback_rate: 0.0,
avg_solver_work: 0.0,
});
}
if let Some(ref mut b_stats) = self.stats.system_b_stats {
let prev_avg_error = b_stats.avg_certificate_error;
b_stats.avg_certificate_error = (prev_avg_error * (n - 1.0) + cert.error_bound) / n;
let prev_avg_work = b_stats.avg_solver_work;
b_stats.avg_solver_work = (prev_avg_work * (n - 1.0) + cert.work_performed as f64) / n;
}
}
// Update percentile statistics periodically
if self.stats.total_predictions % 100 == 0 {
self.update_percentile_statistics();
}
}
/// Update percentile statistics (P50, P99, P99.9)
fn update_percentile_statistics(&mut self) {
let mut latencies = self.monitor.recent_latencies.clone();
if latencies.is_empty() {
return;
}
latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
let len = latencies.len();
self.stats.p50_latency_us = latencies[len / 2];
self.stats.p99_latency_us = latencies[(len as f64 * 0.99) as usize];
self.stats.p99_9_latency_us = latencies[(len as f64 * 0.999) as usize];
// Update violation statistics
self.stats.latency_violations = self.monitor.violations;
self.stats.violation_rate = self.monitor.violations as f64 / self.stats.total_predictions as f64;
// Estimate throughput
if let (Some(&first), Some(&last)) = (latencies.first(), latencies.last()) {
let time_span = last - first;
if time_span > 0.0 {
self.stats.throughput_pred_per_sec = (len as f64) / (time_span / 1_000_000.0);
}
}
}
/// Get current inference statistics
pub fn get_statistics(&self) -> &InferenceStatistics {
&self.stats
}
/// Check if performance targets are being met
pub fn meets_performance_targets(&self) -> bool {
self.stats.p99_9_latency_us <= self.config.target_latency_ms * 1000.0 &&
self.stats.violation_rate <= 0.001 // Less than 0.1% violations
}
/// Reset statistics
pub fn reset_statistics(&mut self) {
self.stats = InferenceStatistics::new();
self.monitor.reset();
}
/// Warm up the predictor (important for latency-critical applications)
pub fn warmup(&mut self, warmup_iterations: usize) -> Result<()> {
log::info!("Warming up predictor with {} iterations", warmup_iterations);
// Create dummy input of the correct shape
let input_shape = match &self.model {
PredictorModel::SystemA(model) => model.input_shape(),
PredictorModel::SystemB(model) => model.input_shape(),
};
let dummy_input = DMatrix::zeros(input_shape.0, input_shape.1);
for _ in 0..warmup_iterations {
let _ = self.predict(&dummy_input)?;
}
// Reset statistics after warmup
self.reset_statistics();
log::info!("Warmup completed");
Ok(())
}
}
impl PerformanceMonitor {
fn new(target_latency_us: f64) -> Self {
Self {
recent_latencies: Vec::with_capacity(1000),
max_recent: 1000,
target_latency_us,
violations: 0,
total_predictions: 0,
}
}
fn record_latency(&mut self, latency_us: f64) {
self.recent_latencies.push(latency_us);
if self.recent_latencies.len() > self.max_recent {
self.recent_latencies.remove(0);
}
if latency_us > self.target_latency_us {
self.violations += 1;
}
self.total_predictions += 1;
}
fn reset(&mut self) {
self.recent_latencies.clear();
self.violations = 0;
self.total_predictions = 0;
}
}
impl InferenceStatistics {
fn new() -> Self {
Self {
total_predictions: 0,
avg_latency_us: 0.0,
p50_latency_us: 0.0,
p99_latency_us: 0.0,
p99_9_latency_us: 0.0,
max_latency_us: 0.0,
min_latency_us: 0.0,
latency_violations: 0,
violation_rate: 0.0,
throughput_pred_per_sec: 0.0,
memory_stats: MemoryStatistics {
current_usage_bytes: 0,
peak_usage_bytes: 0,
avg_memory_per_prediction: 0.0,
pool_utilization: 0.0,
},
system_b_stats: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
config::{ModelConfig, TemporalSolverConfig},
models::{SystemA, SystemB},
};
fn create_test_inference_config() -> InferenceConfig {
InferenceConfig {
target_latency_ms: 0.9,
enable_simd: false, // Disable for tests
num_threads: 1,
pin_memory: false,
cpu_affinity: None,
batch_size: 1,
}
}
fn create_test_model_config() -> ModelConfig {
ModelConfig {
model_type: "micro_gru".to_string(),
hidden_size: 8,
num_layers: 1,
dropout: 0.0,
residual: false,
activation: "tanh".to_string(),
layer_norm: false,
}
}
#[test]
fn test_system_a_predictor() {
let model_config = create_test_model_config();
let inference_config = create_test_inference_config();
let model = SystemA::new(&model_config).unwrap();
let mut predictor = Predictor::new_system_a(model, inference_config).unwrap();
let input = DMatrix::from_element(4, 256, 1.0);
let prediction = predictor.predict(&input).unwrap();
assert_eq!(prediction.values.len(), 2);
assert!(prediction.latency_us > 0.0);
assert!(prediction.confidence >= 0.0 && prediction.confidence <= 1.0);
assert!(prediction.certificate.is_none());
}
#[test]
fn test_system_b_predictor() {
let model_config = create_test_model_config();
let solver_config = TemporalSolverConfig::default();
let inference_config = create_test_inference_config();
let model = SystemB::new(&model_config, &solver_config).unwrap();
let mut predictor = Predictor::new_system_b(model, inference_config).unwrap();
let input = DMatrix::from_element(4, 256, 1.0);
let prediction = predictor.predict(&input).unwrap();
assert_eq!(prediction.values.len(), 2);
assert!(prediction.latency_us > 0.0);
assert!(prediction.confidence >= 0.0 && prediction.confidence <= 1.0);
assert!(prediction.certificate.is_some());
}
#[test]
fn test_batch_prediction() {
let model_config = create_test_model_config();
let inference_config = create_test_inference_config();
let model = SystemA::new(&model_config).unwrap();
let mut predictor = Predictor::new_system_a(model, inference_config).unwrap();
let inputs = vec![
DMatrix::from_element(4, 256, 1.0),
DMatrix::from_element(4, 256, 2.0),
DMatrix::from_element(4, 256, 3.0),
];
let predictions = predictor.predict_batch(&inputs).unwrap();
assert_eq!(predictions.len(), 3);
for prediction in predictions {
assert_eq!(prediction.values.len(), 2);
assert!(prediction.latency_us > 0.0);
}
}
#[test]
fn test_statistics_tracking() {
let model_config = create_test_model_config();
let inference_config = create_test_inference_config();
let model = SystemA::new(&model_config).unwrap();
let mut predictor = Predictor::new_system_a(model, inference_config).unwrap();
let input = DMatrix::from_element(4, 256, 1.0);
// Make several predictions
for _ in 0..10 {
let _ = predictor.predict(&input).unwrap();
}
let stats = predictor.get_statistics();
assert_eq!(stats.total_predictions, 10);
assert!(stats.avg_latency_us > 0.0);
}
#[test]
fn test_input_validation() {
let model_config = create_test_model_config();
let inference_config = create_test_inference_config();
let model = SystemA::new(&model_config).unwrap();
let mut predictor = Predictor::new_system_a(model, inference_config).unwrap();
// Wrong shape
let wrong_input = DMatrix::from_element(3, 100, 1.0);
assert!(predictor.predict(&wrong_input).is_err());
// Invalid values
let mut invalid_input = DMatrix::from_element(4, 256, 1.0);
invalid_input[(0, 0)] = f64::NAN;
assert!(predictor.predict(&invalid_input).is_err());
}
}