532 lines
15 KiB
Rust
532 lines
15 KiB
Rust
//! Quantization Techniques for iOS/Browser WASM
|
|
//!
|
|
//! Memory-efficient vector compression for mobile devices.
|
|
//! - Scalar Quantization: 4x compression (f32 → u8)
|
|
//! - Binary Quantization: 32x compression (f32 → 1 bit)
|
|
//! - Product Quantization: 8-16x compression
|
|
|
|
use std::vec::Vec;
|
|
|
|
// ============================================
|
|
// Scalar Quantization (4x compression)
|
|
// ============================================
|
|
|
|
/// Scalar-quantized vector (f32 → u8)
|
|
#[derive(Clone, Debug)]
|
|
pub struct ScalarQuantized {
|
|
/// Quantized values
|
|
pub data: Vec<u8>,
|
|
/// Minimum value for reconstruction
|
|
pub min: f32,
|
|
/// Scale factor for reconstruction
|
|
pub scale: f32,
|
|
}
|
|
|
|
impl ScalarQuantized {
|
|
/// Quantize a float vector to u8
|
|
pub fn quantize(vector: &[f32]) -> Self {
|
|
if vector.is_empty() {
|
|
return Self {
|
|
data: vec![],
|
|
min: 0.0,
|
|
scale: 1.0,
|
|
};
|
|
}
|
|
|
|
let min = vector.iter().cloned().fold(f32::INFINITY, f32::min);
|
|
let max = vector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
|
|
|
let scale = if (max - min).abs() < f32::EPSILON {
|
|
1.0
|
|
} else {
|
|
(max - min) / 255.0
|
|
};
|
|
|
|
let data = vector
|
|
.iter()
|
|
.map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
|
|
.collect();
|
|
|
|
Self { data, min, scale }
|
|
}
|
|
|
|
/// Reconstruct approximate float vector
|
|
pub fn reconstruct(&self) -> Vec<f32> {
|
|
self.data
|
|
.iter()
|
|
.map(|&v| self.min + (v as f32) * self.scale)
|
|
.collect()
|
|
}
|
|
|
|
/// Fast distance calculation in quantized space
|
|
pub fn distance(&self, other: &Self) -> f32 {
|
|
let mut sum = 0i32;
|
|
for (&a, &b) in self.data.iter().zip(other.data.iter()) {
|
|
let diff = a as i32 - b as i32;
|
|
sum += diff * diff;
|
|
}
|
|
(sum as f32).sqrt() * self.scale.max(other.scale)
|
|
}
|
|
|
|
/// Asymmetric distance (query is float, database is quantized)
|
|
pub fn asymmetric_distance(&self, query: &[f32]) -> f32 {
|
|
let len = self.data.len().min(query.len());
|
|
let mut sum = 0.0f32;
|
|
|
|
for i in 0..len {
|
|
let reconstructed = self.min + (self.data[i] as f32) * self.scale;
|
|
let diff = reconstructed - query[i];
|
|
sum += diff * diff;
|
|
}
|
|
|
|
sum.sqrt()
|
|
}
|
|
|
|
/// Get memory size in bytes
|
|
pub fn memory_size(&self) -> usize {
|
|
self.data.len() + 8 // data + min + scale
|
|
}
|
|
|
|
/// Serialize to bytes
|
|
pub fn serialize(&self) -> Vec<u8> {
|
|
let mut bytes = Vec::with_capacity(8 + self.data.len());
|
|
bytes.extend_from_slice(&self.min.to_le_bytes());
|
|
bytes.extend_from_slice(&self.scale.to_le_bytes());
|
|
bytes.extend_from_slice(&self.data);
|
|
bytes
|
|
}
|
|
|
|
/// Deserialize from bytes
|
|
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
|
|
if bytes.len() < 8 {
|
|
return None;
|
|
}
|
|
let min = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
|
|
let scale = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
|
|
let data = bytes[8..].to_vec();
|
|
Some(Self { data, min, scale })
|
|
}
|
|
|
|
/// Estimate serialized size
|
|
pub fn serialized_size(&self) -> usize {
|
|
8 + self.data.len()
|
|
}
|
|
}
|
|
|
|
// ============================================
|
|
// Binary Quantization (32x compression)
|
|
// ============================================
|
|
|
|
/// Binary-quantized vector (f32 → 1 bit)
|
|
#[derive(Clone, Debug)]
|
|
pub struct BinaryQuantized {
|
|
/// Packed bits (8 dimensions per byte)
|
|
pub bits: Vec<u8>,
|
|
/// Original dimension count
|
|
pub dimensions: usize,
|
|
}
|
|
|
|
impl BinaryQuantized {
|
|
/// Quantize float vector to binary (sign-based)
|
|
pub fn quantize(vector: &[f32]) -> Self {
|
|
let dimensions = vector.len();
|
|
let num_bytes = (dimensions + 7) / 8;
|
|
let mut bits = vec![0u8; num_bytes];
|
|
|
|
for (i, &v) in vector.iter().enumerate() {
|
|
if v > 0.0 {
|
|
let byte_idx = i / 8;
|
|
let bit_idx = i % 8;
|
|
bits[byte_idx] |= 1 << bit_idx;
|
|
}
|
|
}
|
|
|
|
Self { bits, dimensions }
|
|
}
|
|
|
|
/// Quantize with threshold (not just sign)
|
|
pub fn quantize_with_threshold(vector: &[f32], threshold: f32) -> Self {
|
|
let dimensions = vector.len();
|
|
let num_bytes = (dimensions + 7) / 8;
|
|
let mut bits = vec![0u8; num_bytes];
|
|
|
|
for (i, &v) in vector.iter().enumerate() {
|
|
if v > threshold {
|
|
let byte_idx = i / 8;
|
|
let bit_idx = i % 8;
|
|
bits[byte_idx] |= 1 << bit_idx;
|
|
}
|
|
}
|
|
|
|
Self { bits, dimensions }
|
|
}
|
|
|
|
/// Hamming distance between two binary vectors
|
|
pub fn distance(&self, other: &Self) -> u32 {
|
|
let mut distance = 0u32;
|
|
for (&a, &b) in self.bits.iter().zip(other.bits.iter()) {
|
|
distance += (a ^ b).count_ones();
|
|
}
|
|
distance
|
|
}
|
|
|
|
/// Asymmetric distance to float query
|
|
pub fn asymmetric_distance(&self, query: &[f32]) -> f32 {
|
|
let mut distance = 0u32;
|
|
for (i, &q) in query.iter().take(self.dimensions).enumerate() {
|
|
let byte_idx = i / 8;
|
|
let bit_idx = i % 8;
|
|
let bit = (self.bits.get(byte_idx).unwrap_or(&0) >> bit_idx) & 1;
|
|
|
|
let query_bit = if q > 0.0 { 1 } else { 0 };
|
|
if bit != query_bit {
|
|
distance += 1;
|
|
}
|
|
}
|
|
distance as f32
|
|
}
|
|
|
|
/// Reconstruct to +1/-1 vector
|
|
pub fn reconstruct(&self) -> Vec<f32> {
|
|
let mut result = Vec::with_capacity(self.dimensions);
|
|
for i in 0..self.dimensions {
|
|
let byte_idx = i / 8;
|
|
let bit_idx = i % 8;
|
|
let bit = (self.bits.get(byte_idx).unwrap_or(&0) >> bit_idx) & 1;
|
|
result.push(if bit == 1 { 1.0 } else { -1.0 });
|
|
}
|
|
result
|
|
}
|
|
|
|
/// Get memory size in bytes
|
|
pub fn memory_size(&self) -> usize {
|
|
self.bits.len() + 8 // bits + dimensions (as usize)
|
|
}
|
|
|
|
/// Serialize to bytes
|
|
pub fn serialize(&self) -> Vec<u8> {
|
|
let mut bytes = Vec::with_capacity(4 + self.bits.len());
|
|
bytes.extend_from_slice(&(self.dimensions as u32).to_le_bytes());
|
|
bytes.extend_from_slice(&self.bits);
|
|
bytes
|
|
}
|
|
|
|
/// Deserialize from bytes
|
|
pub fn deserialize(bytes: &[u8]) -> Option<Self> {
|
|
if bytes.len() < 4 {
|
|
return None;
|
|
}
|
|
let dimensions = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
|
|
let bits = bytes[4..].to_vec();
|
|
Some(Self { bits, dimensions })
|
|
}
|
|
|
|
/// Estimate serialized size
|
|
pub fn serialized_size(&self) -> usize {
|
|
4 + self.bits.len()
|
|
}
|
|
}
|
|
|
|
// ============================================
|
|
// Simple Product Quantization (8-16x compression)
|
|
// ============================================
|
|
|
|
/// Product-quantized vector
|
|
#[derive(Clone, Debug)]
|
|
pub struct ProductQuantized {
|
|
/// Quantized codes (one per subspace)
|
|
pub codes: Vec<u8>,
|
|
/// Number of subspaces
|
|
pub num_subspaces: usize,
|
|
}
|
|
|
|
/// Product quantization codebook
|
|
#[derive(Clone, Debug)]
|
|
pub struct PQCodebook {
|
|
/// Centroids for each subspace [subspace][centroid][dim]
|
|
pub centroids: Vec<Vec<Vec<f32>>>,
|
|
/// Number of subspaces
|
|
pub num_subspaces: usize,
|
|
/// Dimension per subspace
|
|
pub subspace_dim: usize,
|
|
/// Number of centroids (usually 256 for u8 codes)
|
|
pub num_centroids: usize,
|
|
}
|
|
|
|
impl PQCodebook {
|
|
/// Train a PQ codebook using k-means
|
|
pub fn train(
|
|
vectors: &[Vec<f32>],
|
|
num_subspaces: usize,
|
|
num_centroids: usize,
|
|
iterations: usize,
|
|
) -> Self {
|
|
if vectors.is_empty() {
|
|
return Self {
|
|
centroids: vec![],
|
|
num_subspaces,
|
|
subspace_dim: 0,
|
|
num_centroids,
|
|
};
|
|
}
|
|
|
|
let dim = vectors[0].len();
|
|
let subspace_dim = dim / num_subspaces;
|
|
let mut centroids = Vec::with_capacity(num_subspaces);
|
|
|
|
// Train each subspace independently
|
|
for s in 0..num_subspaces {
|
|
let start = s * subspace_dim;
|
|
let end = start + subspace_dim;
|
|
|
|
// Extract subvectors
|
|
let subvectors: Vec<Vec<f32>> = vectors
|
|
.iter()
|
|
.map(|v| v[start..end].to_vec())
|
|
.collect();
|
|
|
|
// Run k-means
|
|
let subspace_centroids = kmeans(&subvectors, num_centroids, iterations);
|
|
centroids.push(subspace_centroids);
|
|
}
|
|
|
|
Self {
|
|
centroids,
|
|
num_subspaces,
|
|
subspace_dim,
|
|
num_centroids,
|
|
}
|
|
}
|
|
|
|
/// Encode a vector using this codebook
|
|
pub fn encode(&self, vector: &[f32]) -> ProductQuantized {
|
|
let mut codes = Vec::with_capacity(self.num_subspaces);
|
|
|
|
for (s, subspace_centroids) in self.centroids.iter().enumerate() {
|
|
let start = s * self.subspace_dim;
|
|
let end = start + self.subspace_dim;
|
|
let subvector = &vector[start..end];
|
|
|
|
// Find nearest centroid
|
|
let code = subspace_centroids
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, c)| {
|
|
let dist = euclidean_squared(subvector, c);
|
|
(i, dist)
|
|
})
|
|
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
|
.map(|(i, _)| i as u8)
|
|
.unwrap_or(0);
|
|
|
|
codes.push(code);
|
|
}
|
|
|
|
ProductQuantized {
|
|
codes,
|
|
num_subspaces: self.num_subspaces,
|
|
}
|
|
}
|
|
|
|
/// Decode a PQ vector back to approximate floats
|
|
pub fn decode(&self, pq: &ProductQuantized) -> Vec<f32> {
|
|
let mut result = Vec::with_capacity(self.num_subspaces * self.subspace_dim);
|
|
|
|
for (s, &code) in pq.codes.iter().enumerate() {
|
|
if s < self.centroids.len() && (code as usize) < self.centroids[s].len() {
|
|
result.extend_from_slice(&self.centroids[s][code as usize]);
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
/// Compute distance using precomputed distance table (ADC)
|
|
pub fn asymmetric_distance(&self, pq: &ProductQuantized, query: &[f32]) -> f32 {
|
|
let mut dist = 0.0f32;
|
|
|
|
for (s, &code) in pq.codes.iter().enumerate() {
|
|
let start = s * self.subspace_dim;
|
|
let end = start + self.subspace_dim;
|
|
let query_sub = &query[start..end];
|
|
|
|
if s < self.centroids.len() && (code as usize) < self.centroids[s].len() {
|
|
let centroid = &self.centroids[s][code as usize];
|
|
dist += euclidean_squared(query_sub, centroid);
|
|
}
|
|
}
|
|
|
|
dist.sqrt()
|
|
}
|
|
}
|
|
|
|
// ============================================
|
|
// Helper Functions
|
|
// ============================================
|
|
|
|
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
|
|
a.iter()
|
|
.zip(b.iter())
|
|
.map(|(&x, &y)| {
|
|
let d = x - y;
|
|
d * d
|
|
})
|
|
.sum()
|
|
}
|
|
|
|
fn kmeans(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
|
|
if vectors.is_empty() || k == 0 {
|
|
return vec![];
|
|
}
|
|
|
|
let dim = vectors[0].len();
|
|
|
|
// Initialize centroids (first k vectors or random subset)
|
|
let mut centroids: Vec<Vec<f32>> = vectors.iter().take(k).cloned().collect();
|
|
|
|
// Pad if not enough vectors
|
|
while centroids.len() < k {
|
|
centroids.push(vec![0.0; dim]);
|
|
}
|
|
|
|
for _ in 0..iterations {
|
|
// Assign vectors to clusters
|
|
let mut assignments: Vec<Vec<Vec<f32>>> = vec![vec![]; k];
|
|
|
|
for vector in vectors {
|
|
let nearest = centroids
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, c)| (i, euclidean_squared(vector, c)))
|
|
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
|
|
.map(|(i, _)| i)
|
|
.unwrap_or(0);
|
|
|
|
assignments[nearest].push(vector.clone());
|
|
}
|
|
|
|
// Update centroids
|
|
for (centroid, assigned) in centroids.iter_mut().zip(assignments.iter()) {
|
|
if !assigned.is_empty() {
|
|
for (i, c) in centroid.iter_mut().enumerate() {
|
|
*c = assigned.iter().map(|v| v[i]).sum::<f32>() / assigned.len() as f32;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
centroids
|
|
}
|
|
|
|
// ============================================
|
|
// WASM Exports
|
|
// ============================================
|
|
|
|
/// Scalar quantize a vector
|
|
#[no_mangle]
|
|
pub extern "C" fn scalar_quantize(
|
|
input_ptr: *const f32,
|
|
len: u32,
|
|
out_data: *mut u8,
|
|
out_min: *mut f32,
|
|
out_scale: *mut f32,
|
|
) {
|
|
unsafe {
|
|
let input = core::slice::from_raw_parts(input_ptr, len as usize);
|
|
let sq = ScalarQuantized::quantize(input);
|
|
|
|
let out = core::slice::from_raw_parts_mut(out_data, sq.data.len());
|
|
out.copy_from_slice(&sq.data);
|
|
|
|
*out_min = sq.min;
|
|
*out_scale = sq.scale;
|
|
}
|
|
}
|
|
|
|
/// Binary quantize a vector
|
|
#[no_mangle]
|
|
pub extern "C" fn binary_quantize(
|
|
input_ptr: *const f32,
|
|
len: u32,
|
|
out_bits: *mut u8,
|
|
) -> u32 {
|
|
unsafe {
|
|
let input = core::slice::from_raw_parts(input_ptr, len as usize);
|
|
let bq = BinaryQuantized::quantize(input);
|
|
|
|
let out = core::slice::from_raw_parts_mut(out_bits, bq.bits.len());
|
|
out.copy_from_slice(&bq.bits);
|
|
|
|
bq.bits.len() as u32
|
|
}
|
|
}
|
|
|
|
/// Hamming distance between two binary vectors
|
|
#[no_mangle]
|
|
pub extern "C" fn hamming_distance(
|
|
a_ptr: *const u8,
|
|
b_ptr: *const u8,
|
|
len: u32,
|
|
) -> u32 {
|
|
unsafe {
|
|
let a = core::slice::from_raw_parts(a_ptr, len as usize);
|
|
let b = core::slice::from_raw_parts(b_ptr, len as usize);
|
|
|
|
let mut distance = 0u32;
|
|
for (&x, &y) in a.iter().zip(b.iter()) {
|
|
distance += (x ^ y).count_ones();
|
|
}
|
|
distance
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_scalar_quantization() {
|
|
let v = vec![0.0, 0.5, 1.0, 0.25, 0.75];
|
|
let sq = ScalarQuantized::quantize(&v);
|
|
let reconstructed = sq.reconstruct();
|
|
|
|
for (orig, recon) in v.iter().zip(reconstructed.iter()) {
|
|
assert!((orig - recon).abs() < 0.01);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_binary_quantization() {
|
|
let v = vec![1.0, -1.0, 0.5, -0.5];
|
|
let bq = BinaryQuantized::quantize(&v);
|
|
|
|
assert_eq!(bq.dimensions, 4);
|
|
assert_eq!(bq.bits.len(), 1);
|
|
assert_eq!(bq.bits[0], 0b0101); // positions 0 and 2 are positive
|
|
}
|
|
|
|
#[test]
|
|
fn test_hamming_distance() {
|
|
let v1 = vec![1.0, 1.0, 1.0, 1.0];
|
|
let v2 = vec![1.0, -1.0, 1.0, -1.0];
|
|
|
|
let bq1 = BinaryQuantized::quantize(&v1);
|
|
let bq2 = BinaryQuantized::quantize(&v2);
|
|
|
|
assert_eq!(bq1.distance(&bq2), 2);
|
|
}
|
|
|
|
#[test]
|
|
fn test_pq_encode_decode() {
|
|
let vectors: Vec<Vec<f32>> = (0..100)
|
|
.map(|i| vec![i as f32 / 100.0; 8])
|
|
.collect();
|
|
|
|
let codebook = PQCodebook::train(&vectors, 2, 16, 10);
|
|
let pq = codebook.encode(&vectors[50]);
|
|
let decoded = codebook.decode(&pq);
|
|
|
|
assert_eq!(decoded.len(), 8);
|
|
}
|
|
}
|