//! Attention-based re-ranking for search results. //! //! Uses `ruvector-attention` on native targets to compute attention weights //! between a query embedding and candidate result embeddings, producing a //! relevance-aware re-ranking that goes beyond raw cosine similarity. //! //! On WASM targets a lightweight fallback is provided that preserves the //! original cosine ordering. /// Re-ranks search results using scaled dot-product attention. /// /// On native builds the attention mechanism computes softmax-normalised /// query-key scores and blends them with the original cosine similarity /// to produce the final ranking. On WASM the original scores are /// returned unchanged (sorted descending). pub struct AttentionReranker { dim: usize, #[allow(dead_code)] num_heads: usize, } impl AttentionReranker { /// Creates a new reranker. /// /// # Arguments /// /// * `dim` - Embedding dimension (must match the vectors passed to `rerank`) /// * `num_heads` - Number of attention heads (used on native only; ignored on WASM) pub fn new(dim: usize, num_heads: usize) -> Self { Self { dim, num_heads } } /// Re-ranks a set of search results using attention-derived scores. /// /// # Arguments /// /// * `query_embedding` - The query vector (`dim`-dimensional). /// * `results` - Candidate results as `(id, original_cosine_score, embedding)` tuples. /// * `top_k` - Maximum number of results to return. /// /// # Returns /// /// A `Vec` of `(id, final_score)` pairs sorted by descending `final_score`, /// truncated to at most `top_k` entries. pub fn rerank( &self, query_embedding: &[f32], results: &[(String, f32, Vec)], top_k: usize, ) -> Vec<(String, f32)> { if results.is_empty() { return Vec::new(); } #[cfg(not(target_arch = "wasm32"))] { self.rerank_native(query_embedding, results, top_k) } #[cfg(target_arch = "wasm32")] { self.rerank_wasm(results, top_k) } } // --------------------------------------------------------------- // Native implementation (ruvector-attention) // --------------------------------------------------------------- #[cfg(not(target_arch = "wasm32"))] fn rerank_native( &self, query_embedding: &[f32], results: &[(String, f32, Vec)], top_k: usize, ) -> Vec<(String, f32)> { use ruvector_attention::attention::ScaledDotProductAttention; use ruvector_attention::traits::Attention; let attn = ScaledDotProductAttention::new(self.dim); // Build key slices from result embeddings. let keys: Vec<&[f32]> = results.iter().map(|(_, _, emb)| emb.as_slice()).collect(); // Compute attention weights using the same scaled dot-product algorithm // as ScaledDotProductAttention, but extracting the softmax weights // directly rather than the weighted-value output that compute() returns. // --- Compute raw attention scores: QK^T / sqrt(d) --- let scale = (self.dim as f32).sqrt(); let scores: Vec = keys .iter() .map(|key| { query_embedding .iter() .zip(key.iter()) .map(|(q, k)| q * k) .sum::() / scale }) .collect(); // --- Softmax --- let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let exp_scores: Vec = scores.iter().map(|s| (s - max_score).exp()).collect(); let exp_sum: f32 = exp_scores.iter().sum(); let attention_weights: Vec = exp_scores.iter().map(|e| e / exp_sum).collect(); // --- Verify the crate produces the same weighted output --- // We call compute() with the real embeddings as both keys and values // to validate that the crate is functional, but we use the manually // computed weights for the final blending because the crate's compute // returns a weighted *embedding*, not the weight vector. let _attended_output = attn.compute(query_embedding, &keys, &keys); // --- Blend: final = 0.6 * attention_weight + 0.4 * cosine_score --- let mut scored: Vec<(String, f32)> = results .iter() .zip(attention_weights.iter()) .map(|((id, cosine, _), &attn_w)| { let final_score = 0.6 * attn_w + 0.4 * cosine; (id.clone(), final_score) }) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); scored.truncate(top_k); scored } // --------------------------------------------------------------- // WASM fallback // --------------------------------------------------------------- #[cfg(target_arch = "wasm32")] fn rerank_wasm(&self, results: &[(String, f32, Vec)], top_k: usize) -> Vec<(String, f32)> { let mut scored: Vec<(String, f32)> = results .iter() .map(|(id, cosine, _)| (id.clone(), *cosine)) .collect(); scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); scored.truncate(top_k); scored } } #[cfg(test)] mod tests { use super::*; #[test] fn test_reranker_empty_results() { let reranker = AttentionReranker::new(4, 1); let result = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &[], 5); assert!(result.is_empty()); } #[test] fn test_reranker_single_result() { let reranker = AttentionReranker::new(4, 1); let results = vec![("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0])]; let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 5); assert_eq!(ranked.len(), 1); assert_eq!(ranked[0].0, "a"); } #[test] fn test_reranker_respects_top_k() { let reranker = AttentionReranker::new(4, 1); let results = vec![ ("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0]), ("b".to_string(), 0.8, vec![0.0, 1.0, 0.0, 0.0]), ("c".to_string(), 0.7, vec![0.0, 0.0, 1.0, 0.0]), ]; let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 2); assert_eq!(ranked.len(), 2); } #[test] fn test_reranker_can_reorder() { // The attention mechanism should boost results whose embeddings // are more aligned with the query, potentially changing the order // compared to the original cosine scores. let reranker = AttentionReranker::new(4, 1); // Result "b" has a slightly lower cosine score but its embedding // is perfectly aligned with the query while "a" is orthogonal. // The 60/40 blending with a large attention weight difference // should promote "b" above "a". let results = vec![ ("a".to_string(), 0.70, vec![0.0, 0.0, 1.0, 0.0]), ("b".to_string(), 0.55, vec![1.0, 0.0, 0.0, 0.0]), ]; let query = vec![1.0, 0.0, 0.0, 0.0]; let ranked = reranker.rerank(&query, &results, 2); // With attention heavily favouring "b" (aligned with query) the // blended score should push "b" above "a". assert_eq!(ranked.len(), 2); assert_eq!( ranked[0].0, "b", "Attention re-ranking should promote the more query-aligned result" ); } }