wifi-densepose/vendor/ruvector/examples/ruvLLM/src/inference.rs

348 lines
10 KiB
Rust

//! LFM2 inference pool for model management
//!
//! Supports both mock inference (for testing/benchmarking orchestration) and
//! real SIMD-optimized CPU inference.
use crate::config::InferenceConfig;
use crate::error::{Error, InferenceError, Result};
use crate::simd_inference::{SimdGenerationConfig, SimdInferenceEngine};
use crate::types::ModelSize;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
/// Generation configuration
#[derive(Debug, Clone)]
pub struct GenerationConfig {
/// Maximum tokens to generate
pub max_tokens: usize,
/// Temperature
pub temperature: f32,
/// Top-p (nucleus sampling)
pub top_p: f32,
/// Top-k sampling
pub top_k: usize,
/// Repeat penalty
pub repeat_penalty: f32,
}
impl Default for GenerationConfig {
fn default() -> Self {
Self {
max_tokens: 256,
temperature: 0.7,
top_p: 0.9,
top_k: 40,
repeat_penalty: 1.1,
}
}
}
impl From<&GenerationConfig> for SimdGenerationConfig {
fn from(config: &GenerationConfig) -> Self {
SimdGenerationConfig {
max_tokens: config.max_tokens,
temperature: config.temperature,
top_p: config.top_p,
top_k: config.top_k,
repeat_penalty: config.repeat_penalty,
}
}
}
/// Result of generation
#[derive(Debug, Clone)]
pub struct GenerationResult {
/// Generated text
pub text: String,
/// Tokens generated
pub tokens_generated: usize,
/// Model used
pub model_used: ModelSize,
/// Whether KV cache was hit
pub cache_hit: bool,
/// Inference time in milliseconds
pub inference_time_ms: f64,
/// Tokens per second
pub tokens_per_second: f64,
}
/// Inference mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceMode {
/// Mock inference (fast, for orchestration benchmarks)
Mock,
/// Real SIMD-optimized CPU inference
RealSimd,
}
/// Pool of LFM2 models with lazy loading
pub struct InferencePool {
/// Loaded mock models (for orchestration benchmarks)
models: DashMap<ModelSize, Arc<MockModel>>,
/// LRU tracking
lru: RwLock<Vec<(ModelSize, Instant)>>,
/// Configuration
config: InferenceConfig,
/// Real SIMD inference engine
simd_engine: Option<Arc<SimdInferenceEngine>>,
/// Current inference mode
mode: InferenceMode,
}
/// Mock model for testing (measures orchestration overhead only)
struct MockModel {
size: ModelSize,
}
impl InferencePool {
/// Create a new inference pool with mock inference (fast orchestration benchmarks)
pub async fn new(config: &InferenceConfig) -> Result<Self> {
Ok(Self {
models: DashMap::new(),
lru: RwLock::new(Vec::new()),
config: config.clone(),
simd_engine: None,
mode: InferenceMode::Mock,
})
}
/// Create a new inference pool with real SIMD-optimized inference
pub async fn new_with_real_inference(config: &InferenceConfig) -> Result<Self> {
let engine = SimdInferenceEngine::new_demo();
Ok(Self {
models: DashMap::new(),
lru: RwLock::new(Vec::new()),
config: config.clone(),
simd_engine: Some(Arc::new(engine)),
mode: InferenceMode::RealSimd,
})
}
/// Set inference mode
pub fn set_mode(&mut self, mode: InferenceMode) {
if mode == InferenceMode::RealSimd && self.simd_engine.is_none() {
self.simd_engine = Some(Arc::new(SimdInferenceEngine::new_demo()));
}
self.mode = mode;
}
/// Get current inference mode
pub fn mode(&self) -> InferenceMode {
self.mode
}
/// Generate response from a model
pub async fn generate(
&self,
model_size: ModelSize,
prompt: &str,
config: GenerationConfig,
session_key: Option<&str>,
) -> Result<GenerationResult> {
let start = Instant::now();
match self.mode {
InferenceMode::Mock => {
// Get or load mock model
let _model = self.get_or_load(model_size).await?;
// Mock generation (measures orchestration overhead only)
let response = self.mock_generate(prompt, &config, model_size);
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
Ok(GenerationResult {
text: response,
tokens_generated: config.max_tokens / 2,
model_used: model_size,
cache_hit: false,
inference_time_ms: elapsed,
tokens_per_second: (config.max_tokens as f64 / 2.0) / (elapsed / 1000.0),
})
}
InferenceMode::RealSimd => {
// Use real SIMD-optimized inference
let engine = self.simd_engine.as_ref().ok_or_else(|| {
Error::Inference(InferenceError::InitFailed(
"SIMD engine not initialized".to_string(),
))
})?;
let simd_config: SimdGenerationConfig = (&config).into();
let (text, tokens_generated, inference_time_ms) =
engine.generate(prompt, &simd_config, session_key);
let tokens_per_second = if inference_time_ms > 0.0 {
(tokens_generated as f64 / inference_time_ms) * 1000.0
} else {
0.0
};
Ok(GenerationResult {
text,
tokens_generated,
model_used: model_size,
cache_hit: session_key.is_some(),
inference_time_ms,
tokens_per_second,
})
}
}
}
/// Health check
pub async fn health_check(&self) -> Result<HealthInfo> {
let (simd_vocab, simd_layers) = if let Some(engine) = &self.simd_engine {
engine.model_info()
} else {
(0, 0)
};
Ok(HealthInfo {
latency: 0.0,
loaded_models: self.models.len(),
available_memory: 0,
inference_mode: format!("{:?}", self.mode),
simd_vocab_size: simd_vocab,
simd_num_layers: simd_layers,
})
}
async fn get_or_load(&self, size: ModelSize) -> Result<Arc<MockModel>> {
// Check if already loaded
if let Some(model) = self.models.get(&size) {
self.update_lru(size);
return Ok(model.clone());
}
// Evict if needed
while self.models.len() >= self.config.max_loaded_models {
if let Some((evict_size, _)) = self.get_lru_oldest() {
self.models.remove(&evict_size);
}
}
// Load model
let model = Arc::new(MockModel { size });
self.models.insert(size, model.clone());
self.update_lru(size);
Ok(model)
}
fn update_lru(&self, size: ModelSize) {
let mut lru = self.lru.write();
lru.retain(|(s, _)| *s != size);
lru.push((size, Instant::now()));
}
fn get_lru_oldest(&self) -> Option<(ModelSize, Instant)> {
let lru = self.lru.read();
lru.first().cloned()
}
fn mock_generate(
&self,
prompt: &str,
config: &GenerationConfig,
model_size: ModelSize,
) -> String {
// Simple mock response based on prompt
let model_name = match model_size {
ModelSize::M350 => "350M",
ModelSize::M700 => "700M",
ModelSize::B1_2 => "1.2B",
ModelSize::B2_6 => "2.6B",
};
// Extract question from prompt
let question = if let Some(q_start) = prompt.find("Question:") {
let q = &prompt[q_start + 9..];
if let Some(end) = q.find('\n') {
q[..end].trim()
} else {
q.trim()
}
} else {
"your question"
};
format!(
"Based on the provided context, I can answer {}. \
[This is a mock response from {} model with temperature {:.1}]",
question, model_name, config.temperature
)
}
}
/// Health information
#[derive(Debug, Clone)]
pub struct HealthInfo {
/// Check latency in ms
pub latency: f32,
/// Number of loaded models
pub loaded_models: usize,
/// Available memory in bytes
pub available_memory: usize,
/// Current inference mode
pub inference_mode: String,
/// SIMD engine vocabulary size
pub simd_vocab_size: usize,
/// SIMD engine number of layers
pub simd_num_layers: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_inference_pool_creation() {
let config = InferenceConfig::default();
let pool = InferencePool::new(&config).await.unwrap();
assert_eq!(pool.models.len(), 0);
}
#[tokio::test]
async fn test_generate() {
let config = InferenceConfig::default();
let pool = InferencePool::new(&config).await.unwrap();
let result = pool
.generate(
ModelSize::M700,
"Question: What is Rust?\n\nAnswer:",
GenerationConfig::default(),
None,
)
.await
.unwrap();
assert!(!result.text.is_empty());
assert_eq!(result.model_used, ModelSize::M700);
}
#[tokio::test]
async fn test_model_eviction() {
let mut config = InferenceConfig::default();
config.max_loaded_models = 2;
let pool = InferencePool::new(&config).await.unwrap();
// Load 3 models
pool.generate(ModelSize::M350, "test", GenerationConfig::default(), None)
.await
.unwrap();
pool.generate(ModelSize::M700, "test", GenerationConfig::default(), None)
.await
.unwrap();
pool.generate(ModelSize::B1_2, "test", GenerationConfig::default(), None)
.await
.unwrap();
// Should only have 2 models loaded
assert!(pool.models.len() <= 2);
}
}