// Source: wifi-densepose/vendor/ruvector/crates/ruvllm/src/adapter_manager.rs
// (457 lines, 13 KiB, Rust)

//! LoRA Adapter Manager
//!
//! Manages loading, caching, and hot-swapping of LoRA adapters for
//! efficient model customization at runtime.
//!
//! ## Features
//!
//! - **Hot-swapping**: Switch adapters without model reload
//! - **Memory pooling**: Shared memory pool with KV cache
//! - **Versioning**: Track adapter versions for updates
//! - **Caching**: LRU cache for frequently used adapters
use crate::error::{Result, RuvLLMError};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// LoRA adapter configuration.
///
/// Plain-data description of one adapter's hyperparameters. It is cloneable
/// and (de)serializable through serde; `Default` supplies a baseline value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterConfig {
/// Adapter name/identifier.
///
/// `AdapterManager` keeps a name -> ID index built from this field, so it is
/// the key accepted by `AdapterManager::get_by_name`.
pub name: String,
/// LoRA rank (typically 4, 8, 16, 32).
///
/// NOTE(review): in this file the value is only surfaced through
/// `AdapterInfo::rank`; the matrix math in `LoraLayerWeights::apply` uses
/// the rank stored on each layer instead. Confirm the two are kept in sync
/// by whoever builds the adapter.
pub rank: usize,
/// Alpha scaling factor.
///
/// `LoraAdapter::apply` forwards it to the layer, where the delta is scaled
/// by `alpha / rank`.
pub alpha: f32,
/// Dropout rate (0.0 = no dropout).
///
/// Not read by any code visible in this file.
pub dropout: f32,
/// Target modules (e.g., ["q_proj", "v_proj"]).
///
/// Not read by any code visible in this file; `LoraAdapter::apply` looks
/// modules up in `LoraAdapter::layers` instead.
pub target_modules: Vec<String>,
/// Whether to merge adapter into base weights.
///
/// Not read by any code visible in this file.
pub merge_weights: bool,
}
impl Default for AdapterConfig {
fn default() -> Self {
Self {
name: "default".to_string(),
rank: 8,
alpha: 16.0,
dropout: 0.0,
target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
merge_weights: false,
}
}
}
/// LoRA adapter weights for a single layer.
///
/// Stores the two low-rank factors as flat, row-major `f32` buffers:
/// `lora_a` is `in_features x rank` and `lora_b` is `rank x out_features`.
#[derive(Debug, Clone)]
pub struct LoraLayerWeights {
    /// A matrix (in_features x rank), row-major
    pub lora_a: Vec<f32>,
    /// B matrix (rank x out_features), row-major
    pub lora_b: Vec<f32>,
    /// Input dimension
    pub in_features: usize,
    /// Output dimension
    pub out_features: usize,
    /// LoRA rank
    pub rank: usize,
}

impl LoraLayerWeights {
    /// Create new LoRA layer weights with both factor matrices zero-filled.
    ///
    /// A zero-filled layer produces an all-zero delta from [`apply`](Self::apply)
    /// until the caller writes real values into `lora_a` / `lora_b`.
    pub fn new(in_features: usize, out_features: usize, rank: usize) -> Self {
        Self {
            lora_a: vec![0.0; in_features * rank],
            lora_b: vec![0.0; rank * out_features],
            in_features,
            out_features,
            rank,
        }
    }

    /// Apply LoRA to input: `output = input @ (A @ B * scale)` with
    /// `scale = alpha / rank`.
    ///
    /// `input` is read as a row-major `(batch, in_features)` matrix where
    /// `batch = input.len() / in_features`; a trailing partial row is ignored.
    /// The result is a row-major `(batch, out_features)` matrix.
    ///
    /// Degenerate shapes are handled without panicking or producing NaN:
    /// * `in_features == 0` yields an empty vector (no rows can be formed);
    /// * `rank == 0` yields an all-zero delta (there is nothing to project
    ///   through, and `alpha / 0` must not leak into the output).
    ///
    /// # Panics
    ///
    /// Panics if `lora_a` or `lora_b` is shorter than the dimensions stored on
    /// `self` imply, because the buffers are indexed directly.
    pub fn apply(&self, input: &[f32], alpha: f32) -> Vec<f32> {
        if self.in_features == 0 {
            return Vec::new();
        }

        let batch_size = input.len() / self.in_features;
        let mut output = vec![0.0_f32; batch_size * self.out_features];
        if self.rank == 0 || self.out_features == 0 {
            return output;
        }

        let scale = alpha / self.rank as f32;
        // Scratch row reused for every batch element: holds `row @ A`.
        let mut projected = vec![0.0_f32; self.rank];

        let rows_in = input.chunks_exact(self.in_features);
        let rows_out = output.chunks_exact_mut(self.out_features);
        for (row_in, row_out) in rows_in.zip(rows_out) {
            // row_in @ A: (in_features) @ (in_features, rank) -> (rank)
            for (r, slot) in projected.iter_mut().enumerate() {
                let mut sum = 0.0_f32;
                for (i, &x) in row_in.iter().enumerate() {
                    sum += x * self.lora_a[i * self.rank + r];
                }
                *slot = sum;
            }
            // projected @ B: (rank) @ (rank, out_features) -> (out_features)
            for (o, out) in row_out.iter_mut().enumerate() {
                let mut sum = 0.0_f32;
                for (r, &p) in projected.iter().enumerate() {
                    sum += p * self.lora_b[r * self.out_features + o];
                }
                *out = sum * scale;
            }
        }
        output
    }

    /// Get memory usage in bytes (combined length of both matrices times the
    /// size of an `f32`).
    pub fn memory_bytes(&self) -> usize {
        (self.lora_a.len() + self.lora_b.len()) * std::mem::size_of::<f32>()
    }
}
/// Complete LoRA adapter with all layer weights.
///
/// Bundles an [`AdapterConfig`] with the per-module [`LoraLayerWeights`] and
/// bookkeeping metadata (ID, version, creation time, reference count).
///
/// Because the struct derives `Clone` and the counter lives behind an `Arc`,
/// a cloned adapter shares the same reference counter (and the same `id`) as
/// the value it was cloned from.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
/// Unique adapter ID (a random v4 UUID assigned by `LoraAdapter::new`)
pub id: Uuid,
/// Configuration
pub config: AdapterConfig,
/// Layer weights by module name; the key is what `LoraAdapter::apply`
/// looks up
pub layers: HashMap<String, LoraLayerWeights>,
/// Version number (`LoraAdapter::new` starts it at 1)
pub version: u64,
/// Creation timestamp (UTC)
pub created_at: chrono::DateTime<chrono::Utc>,
/// Reference count, manipulated through `inc_ref` / `dec_ref`.
/// `LoraAdapter::new` starts it at 1, and `AdapterManager` treats a value
/// above 1 as "in use" when choosing eviction victims.
ref_count: Arc<std::sync::atomic::AtomicUsize>,
}
impl LoraAdapter {
/// Create a new LoRA adapter
pub fn new(config: AdapterConfig) -> Self {
Self {
id: Uuid::new_v4(),
config,
layers: HashMap::new(),
version: 1,
created_at: chrono::Utc::now(),
ref_count: Arc::new(std::sync::atomic::AtomicUsize::new(1)),
}
}
/// Add a layer to the adapter
pub fn add_layer(&mut self, module_name: String, weights: LoraLayerWeights) {
self.layers.insert(module_name, weights);
}
/// Get total memory usage
pub fn memory_bytes(&self) -> usize {
self.layers.values().map(|l| l.memory_bytes()).sum()
}
/// Apply adapter to a specific module's output
pub fn apply(&self, module_name: &str, input: &[f32], base_output: &mut [f32]) -> Result<()> {
if let Some(layer) = self.layers.get(module_name) {
let delta = layer.apply(input, self.config.alpha);
if delta.len() != base_output.len() {
return Err(RuvLLMError::Adapter(format!(
"Output size mismatch: expected {}, got {}",
base_output.len(),
delta.len()
)));
}
for (out, d) in base_output.iter_mut().zip(delta.iter()) {
*out += d;
}
}
Ok(())
}
/// Increment reference count
pub fn inc_ref(&self) {
self.ref_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
/// Decrement reference count, returns true if count reached zero
pub fn dec_ref(&self) -> bool {
self.ref_count
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
}
/// Get current reference count
pub fn ref_count(&self) -> usize {
self.ref_count.load(std::sync::atomic::Ordering::SeqCst)
}
}
/// Adapter cache entry.
///
/// One record per loaded adapter in `AdapterManager`'s recency list.
struct CacheEntry {
// Shared handle to the loaded adapter this entry tracks.
adapter: Arc<LoraAdapter>,
// Time of the most recent access: set when the adapter is loaded and
// refreshed by `AdapterManager::get`. Eviction considers the oldest first.
last_accessed: chrono::DateTime<chrono::Utc>,
}
/// LoRA adapter manager.
///
/// Keeps loaded adapters addressable by ID and by name, tracks how recently
/// each one was used, and evicts idle adapters when a new load would exceed
/// the memory budget. All methods take `&self`; the state lives in concurrent
/// maps, a lock-protected list and an atomic counter.
pub struct AdapterManager {
/// Loaded adapters by ID
adapters: DashMap<Uuid, Arc<LoraAdapter>>,
/// Name to ID mapping (keyed by `AdapterConfig::name`)
name_to_id: DashMap<String, Uuid>,
/// LRU cache for eviction: one entry per loaded adapter, sorted by
/// `last_accessed` when room has to be made
cache: RwLock<Vec<CacheEntry>>,
/// Maximum number of adapters to keep loaded.
///
/// NOTE(review): this value is stored and reported by `memory_stats`, but
/// no code in this file evicts or rejects adapters based on it.
max_loaded: usize,
/// Maximum total memory for adapters, in bytes
max_memory_bytes: usize,
/// Current memory usage: running total of `LoraAdapter::memory_bytes` for
/// the adapters currently loaded
current_memory: std::sync::atomic::AtomicUsize,
}
impl AdapterManager {
/// Create a new adapter manager
pub fn new() -> Self {
Self {
adapters: DashMap::new(),
name_to_id: DashMap::new(),
cache: RwLock::new(Vec::new()),
max_loaded: 16,
max_memory_bytes: 512 * 1024 * 1024, // 512MB
current_memory: std::sync::atomic::AtomicUsize::new(0),
}
}
/// Create with custom limits
pub fn with_limits(max_loaded: usize, max_memory_bytes: usize) -> Self {
Self {
adapters: DashMap::new(),
name_to_id: DashMap::new(),
cache: RwLock::new(Vec::new()),
max_loaded,
max_memory_bytes,
current_memory: std::sync::atomic::AtomicUsize::new(0),
}
}
/// Load an adapter
pub fn load(&self, adapter: LoraAdapter) -> Result<Uuid> {
let memory_needed = adapter.memory_bytes();
// Check memory limits
self.ensure_memory(memory_needed)?;
let id = adapter.id;
let name = adapter.config.name.clone();
let adapter = Arc::new(adapter);
self.adapters.insert(id, adapter.clone());
self.name_to_id.insert(name, id);
// Add to cache
let mut cache = self.cache.write();
cache.push(CacheEntry {
adapter,
last_accessed: chrono::Utc::now(),
});
self.current_memory
.fetch_add(memory_needed, std::sync::atomic::Ordering::SeqCst);
Ok(id)
}
/// Ensure there's enough memory for a new adapter
fn ensure_memory(&self, needed: usize) -> Result<()> {
let current = self
.current_memory
.load(std::sync::atomic::Ordering::SeqCst);
if current + needed <= self.max_memory_bytes {
return Ok(());
}
// Need to evict some adapters
let mut cache = self.cache.write();
// Sort by last accessed (oldest first)
cache.sort_by(|a, b| a.last_accessed.cmp(&b.last_accessed));
let mut freed = 0;
while freed < needed && !cache.is_empty() {
if let Some(entry) = cache.first() {
if entry.adapter.ref_count() <= 1 {
let id = entry.adapter.id;
let size = entry.adapter.memory_bytes();
// Remove from maps
self.adapters.remove(&id);
self.name_to_id.remove(&entry.adapter.config.name);
cache.remove(0);
freed += size;
self.current_memory
.fetch_sub(size, std::sync::atomic::Ordering::SeqCst);
} else {
// Adapter is in use, move to end
let entry = cache.remove(0);
cache.push(entry);
}
}
}
if freed < needed {
return Err(RuvLLMError::OutOfMemory(
"Cannot free enough memory for new adapter".to_string(),
));
}
Ok(())
}
/// Get adapter by ID
pub fn get(&self, id: &Uuid) -> Option<Arc<LoraAdapter>> {
if let Some(adapter) = self.adapters.get(id) {
// Update last accessed
let mut cache = self.cache.write();
if let Some(entry) = cache.iter_mut().find(|e| e.adapter.id == *id) {
entry.last_accessed = chrono::Utc::now();
}
Some(adapter.clone())
} else {
None
}
}
/// Get adapter by name
pub fn get_by_name(&self, name: &str) -> Option<Arc<LoraAdapter>> {
self.name_to_id.get(name).and_then(|id| self.get(&id))
}
/// Unload an adapter
pub fn unload(&self, id: &Uuid) -> Result<()> {
if let Some((_, adapter)) = self.adapters.remove(id) {
self.name_to_id.remove(&adapter.config.name);
let mut cache = self.cache.write();
cache.retain(|e| e.adapter.id != *id);
self.current_memory
.fetch_sub(adapter.memory_bytes(), std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
/// List all loaded adapters
pub fn list(&self) -> Vec<AdapterInfo> {
self.adapters
.iter()
.map(|entry| {
let adapter = entry.value();
AdapterInfo {
id: adapter.id,
name: adapter.config.name.clone(),
rank: adapter.config.rank,
version: adapter.version,
memory_bytes: adapter.memory_bytes(),
ref_count: adapter.ref_count(),
}
})
.collect()
}
/// Get memory statistics
pub fn memory_stats(&self) -> AdapterMemoryStats {
AdapterMemoryStats {
total_budget: self.max_memory_bytes,
used_bytes: self
.current_memory
.load(std::sync::atomic::Ordering::SeqCst),
adapter_count: self.adapters.len(),
max_adapters: self.max_loaded,
}
}
}
impl Default for AdapterManager {
fn default() -> Self {
Self::new()
}
}
/// Adapter information summary.
///
/// Lightweight, serializable snapshot of one loaded adapter, as produced by
/// `AdapterManager::list`. Values are copied at the time of the call and do
/// not track later changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterInfo {
/// Adapter ID
pub id: Uuid,
/// Adapter name (copied from `AdapterConfig::name`)
pub name: String,
/// LoRA rank (copied from `AdapterConfig::rank`)
pub rank: usize,
/// Version number
pub version: u64,
/// Memory usage in bytes (sum over the adapter's layers)
pub memory_bytes: usize,
/// Current reference count
pub ref_count: usize,
}
/// Adapter memory statistics.
///
/// Serializable snapshot returned by `AdapterManager::memory_stats`, pairing
/// the manager's configured limits with its usage at the time of the call.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AdapterMemoryStats {
/// Total memory budget in bytes (the manager's `max_memory_bytes`)
pub total_budget: usize,
/// Currently used bytes (the manager's running memory counter)
pub used_bytes: usize,
/// Number of loaded adapters
pub adapter_count: usize,
/// Maximum number of adapters (the manager's `max_loaded` setting)
pub max_adapters: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new layer sizes both factor matrices from its dimensions.
    #[test]
    fn test_lora_layer_weights() {
        let (in_features, out_features, rank) = (4, 4, 2);
        let layer = LoraLayerWeights::new(in_features, out_features, rank);

        assert_eq!(layer.lora_a.len(), 8); // in_features * rank
        assert_eq!(layer.lora_b.len(), 8); // rank * out_features
    }

    /// Adding a layer registers it and makes the adapter report memory use.
    #[test]
    fn test_lora_adapter() {
        let mut adapter = LoraAdapter::new(AdapterConfig {
            name: String::from("test"),
            rank: 4,
            ..AdapterConfig::default()
        });

        let weights = LoraLayerWeights::new(64, 64, 4);
        adapter.add_layer(String::from("q_proj"), weights);

        assert_eq!(adapter.layers.len(), 1);
        assert!(adapter.memory_bytes() > 0);
    }

    /// Load -> lookup by ID and by name -> unload round trip.
    #[test]
    fn test_adapter_manager() {
        let manager = AdapterManager::new();
        let id = manager
            .load(LoraAdapter::new(AdapterConfig::default()))
            .unwrap();

        assert!(manager.get(&id).is_some());
        assert!(manager.get_by_name("default").is_some());

        manager.unload(&id).unwrap();
        assert!(manager.get(&id).is_none());
    }
}