// Source: wifi-densepose/vendor/ruvector/examples/refrag-pipeline/src/expand.rs
// (450 lines, 13 KiB, Rust)

//! Expand Layer - Tensor Projection
//!
//! This module handles dimension adaptation when stored tensor dimensions
//! don't match the target LLM's expected input dimensions.
//!
//! For example, projecting 768-dim RoBERTa embeddings to 4096-dim LLaMA space.
use ndarray::{Array1, Array2};
use rand::Rng;
use std::collections::HashMap;
use std::time::Instant;
use thiserror::Error;
/// Errors reported while constructing, loading, looking up, or applying a
/// projector in this module.
#[derive(Error, Debug)]
pub enum ProjectionError {
/// The input vector length (`actual`) differs from the projector's source
/// dimension (`expected`).
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// No projector is registered under the requested model identifier; the
/// payload is that identifier.
#[error("Projector not found for model: {0}")]
ProjectorNotFound(String),
/// Supplied or serialized weights are malformed; the payload is a
/// human-readable reason.
#[error("Invalid projection weights: {0}")]
InvalidWeights(String),
}
/// Module-wide result alias whose error type is `ProjectionError`.
pub type Result<T> = std::result::Result<T, ProjectionError>;
/// Linear (affine) projector: `y = Wx + b`.
///
/// Maps a vector of `source_dim` elements to a vector of `target_dim`
/// elements. The constructors in this file keep `source_dim` / `target_dim`
/// in sync with the shape of `weights`.
#[derive(Clone)]
pub struct Projector {
/// Weight matrix `W`, shaped `(target_dim, source_dim)`.
weights: Array2<f32>,
/// Bias vector `b`, expected to hold `target_dim` elements.
bias: Array1<f32>,
/// Source dimension: required length of every input vector.
source_dim: usize,
/// Target dimension: number of rows of `weights`.
target_dim: usize,
/// Model identifier; the registry uses it as the lookup key.
model_id: String,
}
impl Projector {
/// Build a projector with randomly initialised weights and a zero bias.
///
/// Each weight is drawn uniformly from the half-open range
/// `-scale..scale`, where `scale = sqrt(2 / (source_dim + target_dim))`
/// (the file labels this "Xavier initialization"). Weights come from
/// `rand::thread_rng()`, so two calls produce different projectors.
///
/// * `source_dim` - length of the vectors the projector accepts.
/// * `target_dim` - length of the vectors the projector produces.
/// * `model_id` - identifier stored with the projector.
pub fn new(source_dim: usize, target_dim: usize, model_id: impl Into<String>) -> Self {
    let fan_total = (source_dim + target_dim) as f32;
    let scale = (2.0 / fan_total).sqrt();

    let mut rng = rand::thread_rng();
    let element_count = target_dim * source_dim;
    let mut weights_data: Vec<f32> = Vec::with_capacity(element_count);
    for _ in 0..element_count {
        weights_data.push(rng.gen_range(-scale..scale));
    }

    // The vector holds exactly target_dim * source_dim values, so the
    // reshape into (target_dim, source_dim) matches its length.
    let weights = Array2::from_shape_vec((target_dim, source_dim), weights_data).unwrap();
    let bias = Array1::zeros(target_dim);
    Self {
        weights,
        bias,
        source_dim,
        target_dim,
        model_id: model_id.into(),
    }
}
/// Build a projector that returns its input unchanged.
///
/// The weight matrix is the `dim x dim` identity and the bias is all
/// zeros, so source and target dimension are both `dim`.
pub fn identity(dim: usize, model_id: impl Into<String>) -> Self {
    Self {
        // Ones on the diagonal, zeros everywhere else.
        weights: Array2::eye(dim),
        bias: Array1::zeros(dim),
        source_dim: dim,
        target_dim: dim,
        model_id: model_id.into(),
    }
}
/// Build a projector from caller-supplied weights and bias.
///
/// The dimensions are taken from `weights`, which is interpreted as
/// `(target_dim, source_dim)`.
///
/// # Errors
///
/// Returns `ProjectionError::InvalidWeights` when `bias` does not have
/// exactly `target_dim` elements.
pub fn with_weights(
    weights: Array2<f32>,
    bias: Array1<f32>,
    model_id: impl Into<String>,
) -> Result<Self> {
    let (target_dim, source_dim) = weights.dim();
    let bias_len = bias.len();

    if bias_len == target_dim {
        return Ok(Self {
            weights,
            bias,
            source_dim,
            target_dim,
            model_id: model_id.into(),
        });
    }

    let reason = format!(
        "Bias length {} doesn't match target dim {}",
        bias_len, target_dim
    );
    Err(ProjectionError::InvalidWeights(reason))
}
/// Apply the projection `W * input + b` to a single vector.
///
/// # Errors
///
/// Returns `ProjectionError::DimensionMismatch` when `input` does not
/// have exactly `source_dim` elements.
pub fn project(&self, input: &[f32]) -> Result<Vec<f32>> {
    let actual = input.len();
    if actual == self.source_dim {
        let column = Array1::from_vec(input.to_vec());
        let product = self.weights.dot(&column);
        let shifted = product + &self.bias;
        Ok(shifted.to_vec())
    } else {
        Err(ProjectionError::DimensionMismatch {
            expected: self.source_dim,
            actual,
        })
    }
}
/// Project `input` and report how long the projection took.
///
/// On success returns the projected vector together with the elapsed
/// wall-clock time in microseconds. Errors from the projection are
/// returned unchanged and no timing is reported for them.
pub fn project_timed(&self, input: &[f32]) -> Result<(Vec<f32>, u64)> {
    let timer = Instant::now();
    match self.project(input) {
        Ok(projected) => {
            let latency_us = timer.elapsed().as_micros() as u64;
            Ok((projected, latency_us))
        }
        Err(err) => Err(err),
    }
}
/// Project every vector in `inputs`, preserving their order.
///
/// Processing stops at the first vector that fails to project and that
/// error is returned; no partial results are handed back.
pub fn project_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<Vec<f32>>> {
    let mut projected = Vec::with_capacity(inputs.len());
    for vector in inputs {
        projected.push(self.project(vector)?);
    }
    Ok(projected)
}
/// Source dimension: the exact length an input vector must have for
/// `project` to accept it.
pub fn source_dim(&self) -> usize {
self.source_dim
}
/// Target dimension: the row count of the weight matrix as recorded when
/// the projector was constructed.
pub fn target_dim(&self) -> usize {
self.target_dim
}
/// Identifier of the model this projector targets, borrowed from the
/// projector.
pub fn model_id(&self) -> &str {
    self.model_id.as_str()
}
/// Serialize the projector into a flat little-endian byte buffer.
///
/// Layout, in order:
/// 1. `source_dim`, `target_dim` and the model id byte length, each
///    written as a `u32` (via `as` casts, so values above `u32::MAX`
///    would be truncated);
/// 2. the model id bytes;
/// 3. every weight as an `f32`, in the array's logical row-major order;
/// 4. every bias element as an `f32`.
pub fn export_weights(&self) -> Vec<u8> {
    let id_bytes = self.model_id.as_bytes();
    let float_count = self.weights.len() + self.bias.len();
    let mut buffer = Vec::with_capacity(12 + id_bytes.len() + 4 * float_count);

    // Header: three u32 fields followed by the raw identifier bytes.
    for field in [self.source_dim, self.target_dim, id_bytes.len()] {
        buffer.extend_from_slice(&(field as u32).to_le_bytes());
    }
    buffer.extend_from_slice(id_bytes);

    // Payload: weights first, then bias.
    let floats = self.weights.iter().chain(self.bias.iter());
    buffer.extend(floats.flat_map(|value| value.to_le_bytes()));
    buffer
}
/// Rebuild a projector from the byte layout written by `export_weights`.
///
/// Expected layout (all values little-endian): `source_dim`,
/// `target_dim` and model id length as `u32`; the model id bytes;
/// `target_dim * source_dim` weights as `f32` in row-major order; then
/// `target_dim` bias values as `f32`. Invalid UTF-8 in the model id is
/// replaced lossily. Bytes after the bias section are ignored.
///
/// The buffer is treated as untrusted: every offset is computed with
/// checked arithmetic and every slice is bounds-checked, so malformed
/// input yields an error instead of a panic.
///
/// # Errors
///
/// Returns `ProjectionError::InvalidWeights` when the buffer is shorter
/// than the header, the model id, or the weight/bias payload it declares,
/// or when the declared sizes overflow `usize`.
pub fn load_weights(data: &[u8]) -> Result<Self> {
    const HEADER_LEN: usize = 12;
    const F32_LEN: usize = 4;

    if data.len() < HEADER_LEN {
        return Err(ProjectionError::InvalidWeights("Data too short".into()));
    }

    // Reads the `index`-th u32 header field; in bounds thanks to the check above.
    let header_field = |index: usize| -> usize {
        let at = index * 4;
        u32::from_le_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]]) as usize
    };
    let source_dim = header_field(0);
    let target_dim = header_field(1);
    let model_id_len = header_field(2);

    let overflow = || ProjectionError::InvalidWeights("Declared sizes overflow".into());

    // Validate the identifier span before slicing it.
    let weights_start = HEADER_LEN.checked_add(model_id_len).ok_or_else(overflow)?;
    let model_id_bytes = data.get(HEADER_LEN..weights_start).ok_or_else(|| {
        ProjectionError::InvalidWeights("Data too short for model id".into())
    })?;
    let model_id = String::from_utf8_lossy(model_id_bytes).into_owned();

    let weights_size = target_dim
        .checked_mul(source_dim)
        .and_then(|count| count.checked_mul(F32_LEN))
        .ok_or_else(overflow)?;
    let bias_size = target_dim.checked_mul(F32_LEN).ok_or_else(overflow)?;
    let bias_start = weights_start.checked_add(weights_size).ok_or_else(overflow)?;
    let payload_end = bias_start.checked_add(bias_size).ok_or_else(overflow)?;

    // Take exactly the declared byte ranges so that trailing bytes cannot
    // leak into the bias vector.
    let sections = (
        data.get(weights_start..bias_start),
        data.get(bias_start..payload_end),
    );
    let (weight_bytes, bias_bytes) = match sections {
        (Some(weight_bytes), Some(bias_bytes)) => (weight_bytes, bias_bytes),
        _ => {
            return Err(ProjectionError::InvalidWeights(
                "Data too short for weights".into(),
            ))
        }
    };

    let decode = |bytes: &[u8]| -> Vec<f32> {
        bytes
            .chunks_exact(F32_LEN)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    };

    let weights = Array2::from_shape_vec((target_dim, source_dim), decode(weight_bytes))
        .map_err(|err| ProjectionError::InvalidWeights(err.to_string()))?;
    Ok(Self {
        weights,
        bias: Array1::from_vec(decode(bias_bytes)),
        source_dim,
        target_dim,
        model_id,
    })
}
}
/// Registry of projectors for different model alignments.
///
/// Each projector is stored under its own model identifier, so a registry
/// holds at most one projector per identifier.
pub struct ProjectorRegistry {
// Keyed by the projector's `model_id`.
projectors: HashMap<String, Projector>,
}
impl ProjectorRegistry {
    /// Create a registry that holds no projectors.
    pub fn new() -> Self {
        let projectors = HashMap::new();
        Self { projectors }
    }

    /// Store `projector` under its own model identifier.
    ///
    /// A projector already registered under the same identifier is replaced.
    pub fn register(&mut self, projector: Projector) {
        let key = projector.model_id.clone();
        self.projectors.insert(key, projector);
    }

    /// Look up the projector registered for `model_id`, if any.
    pub fn get(&self, model_id: &str) -> Option<&Projector> {
        self.projectors.get(model_id)
    }

    /// Project `tensor` with the projector registered for `model_id`.
    ///
    /// # Errors
    ///
    /// Returns `ProjectionError::ProjectorNotFound` when no projector is
    /// registered for `model_id`; otherwise forwards whatever the
    /// projector's own `project` returns.
    pub fn project(&self, tensor: &[f32], model_id: &str) -> Result<Vec<f32>> {
        match self.projectors.get(model_id) {
            Some(projector) => projector.project(tensor),
            None => Err(ProjectionError::ProjectorNotFound(model_id.to_string())),
        }
    }

    /// Report whether a projector is registered for `model_id`.
    pub fn has_projector(&self, model_id: &str) -> bool {
        self.projectors.get(model_id).is_some()
    }

    /// List the identifiers of all registered models.
    ///
    /// The identifiers come straight from the underlying hash map's keys,
    /// so their order is unspecified.
    pub fn models(&self) -> Vec<&str> {
        self.projectors.keys().map(String::as_str).collect()
    }

    /// Create a registry pre-populated with projectors for a fixed list of
    /// LLM identifiers.
    ///
    /// For each listed model, an identity projector is registered when
    /// `source_dim` already equals the listed target dimension; otherwise a
    /// randomly initialised projector from `source_dim` to that dimension
    /// is registered.
    pub fn with_defaults(source_dim: usize) -> Self {
        // (model identifier, target dimension) pairs.
        const DEFAULT_TARGETS: [(&str, usize); 6] = [
            ("llama3-8b", 4096),
            ("llama3-70b", 8192),
            ("gpt-4", 8192),
            ("claude-3", 8192),
            ("mistral-7b", 4096),
            ("phi-3", 3072),
        ];

        let mut registry = Self::new();
        for (model_id, target_dim) in DEFAULT_TARGETS {
            let projector = if target_dim == source_dim {
                Projector::identity(source_dim, model_id)
            } else {
                Projector::new(source_dim, target_dim, model_id)
            };
            registry.register(projector);
        }
        registry
    }
}
impl Default for ProjectorRegistry {
fn default() -> Self {
Self::new()
}
}
/// Expand layer for the REFRAG pipeline.
///
/// Wraps a `ProjectorRegistry` together with the model to use when a caller
/// does not name one.
pub struct ExpandLayer {
/// Projectors available to this layer, looked up by model identifier.
registry: ProjectorRegistry,
/// Default target model, used when a call passes `None` for the model.
default_model: String,
/// Enable auto-projection; when `false`, `expand_auto` returns its input
/// unchanged.
auto_project: bool,
}
impl ExpandLayer {
    /// Create a layer over `registry` that falls back to `default_model`.
    ///
    /// Auto-projection starts enabled.
    pub fn new(registry: ProjectorRegistry, default_model: impl Into<String>) -> Self {
        let default_model = default_model.into();
        Self {
            registry,
            default_model,
            auto_project: true,
        }
    }

    /// Create a layer with the default projectors for a 768-dim source,
    /// defaulting to the `"llama3-8b"` model.
    pub fn for_roberta() -> Self {
        let registry = ProjectorRegistry::with_defaults(768);
        Self::new(registry, "llama3-8b")
    }

    /// Create a layer with the default projectors for a 1536-dim source
    /// (OpenAI ada-002), defaulting to the `"gpt-4"` model.
    pub fn for_openai() -> Self {
        let registry = ProjectorRegistry::with_defaults(1536);
        Self::new(registry, "gpt-4")
    }

    /// Builder-style setter for the default target model.
    pub fn with_default_model(self, model: impl Into<String>) -> Self {
        Self {
            default_model: model.into(),
            ..self
        }
    }

    /// Builder-style switch that enables or disables auto-projection.
    pub fn with_auto_project(self, enabled: bool) -> Self {
        Self {
            auto_project: enabled,
            ..self
        }
    }

    /// Project `tensor` into the space of `target_model`, or of the default
    /// model when `target_model` is `None`.
    ///
    /// # Errors
    ///
    /// Forwards the registry's errors: no projector for the chosen model,
    /// or an input whose length does not match the projector.
    pub fn expand(&self, tensor: &[f32], target_model: Option<&str>) -> Result<Vec<f32>> {
        let model = match target_model {
            Some(name) => name,
            None => self.default_model.as_str(),
        };
        self.registry.project(tensor, model)
    }

    /// Project `tensor` only when auto-projection is enabled.
    ///
    /// With auto-projection disabled the tensor is returned as an unchanged
    /// copy. Otherwise this behaves like `expand`: `alignment_model`
    /// selects the projector and the default model is used when it is
    /// `None`. No model is inferred from the tensor itself.
    pub fn expand_auto(&self, tensor: &[f32], alignment_model: Option<&str>) -> Result<Vec<f32>> {
        if self.auto_project {
            self.expand(tensor, alignment_model)
        } else {
            Ok(tensor.to_vec())
        }
    }

    /// Report whether a tensor of `tensor_dim` elements differs in size from
    /// the target dimension of the projector registered for `target_model`.
    ///
    /// Returns `false` when no projector is registered for `target_model`.
    pub fn needs_expansion(&self, tensor_dim: usize, target_model: &str) -> bool {
        self.registry
            .get(target_model)
            .map_or(false, |projector| projector.target_dim() != tensor_dim)
    }

    /// Mutable access to the underlying registry, e.g. to register
    /// additional projectors.
    pub fn registry_mut(&mut self) -> &mut ProjectorRegistry {
        &mut self.registry
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The accessors report exactly what the constructor was given.
    #[test]
    fn test_projector_dimensions() {
        let p = Projector::new(768, 4096, "test-model");
        assert_eq!(p.source_dim(), 768);
        assert_eq!(p.target_dim(), 4096);
        assert_eq!(p.model_id(), "test-model");
    }

    /// An identity projector hands its input back unchanged.
    #[test]
    fn test_identity_projector() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = Projector::identity(4, "identity")
            .project(&input)
            .unwrap();
        assert_eq!(input, output);
    }

    /// A 4 -> 8 projector produces 8 output values.
    #[test]
    fn test_projection() {
        let output = Projector::new(4, 8, "test")
            .project(&[1.0, 2.0, 3.0, 4.0])
            .unwrap();
        assert_eq!(output.len(), 8);
    }

    /// An input of the wrong length is rejected with `DimensionMismatch`.
    #[test]
    fn test_dimension_mismatch() {
        let projector = Projector::new(4, 8, "test");
        let too_short = [1.0, 2.0, 3.0];
        let is_mismatch = matches!(
            projector.project(&too_short),
            Err(ProjectionError::DimensionMismatch { .. })
        );
        assert!(is_mismatch);
    }

    /// Registered models are discoverable; unknown ones are not.
    #[test]
    fn test_projector_registry() {
        let mut registry = ProjectorRegistry::new();
        for (model_id, target_dim) in [("llama3-8b", 4096), ("gpt-4", 8192)] {
            registry.register(Projector::new(768, target_dim, model_id));
        }

        assert!(registry.has_projector("llama3-8b"));
        assert!(registry.has_projector("gpt-4"));
        assert!(!registry.has_projector("unknown"));
        assert_eq!(registry.models().len(), 2);
    }

    /// The RoBERTa preset expands 768-dim tensors to 4096 for llama3-8b.
    #[test]
    fn test_expand_layer() {
        let layer = ExpandLayer::for_roberta();
        let tensor = vec![0.1f32; 768];
        let expanded = layer.expand(&tensor, Some("llama3-8b")).unwrap();
        assert_eq!(expanded.len(), 4096);
    }

    /// Exporting and re-importing a projector preserves its metadata and
    /// its projection results.
    #[test]
    fn test_weight_export_import() {
        let original = Projector::new(4, 8, "test-model");
        let restored = Projector::load_weights(&original.export_weights()).unwrap();

        assert_eq!(original.source_dim(), restored.source_dim());
        assert_eq!(original.target_dim(), restored.target_dim());
        assert_eq!(original.model_id(), restored.model_id());

        // Verify same projection behavior
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let expected = original.project(&input).unwrap();
        let actual = restored.project(&input).unwrap();
        assert!(expected
            .iter()
            .zip(actual.iter())
            .all(|(a, b)| (a - b).abs() < f32::EPSILON));
    }
}