//! GNN Layer Implementation for HNSW Topology //! //! This module implements graph neural network layers that operate on HNSW graph structure, //! including attention mechanisms, normalization, and gated recurrent updates. use crate::error::GnnError; use ndarray::{Array1, Array2, ArrayView1}; use rand::Rng; use rand_distr::{Distribution, Normal}; use serde::{Deserialize, Serialize}; /// Linear transformation layer (weight matrix multiplication) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Linear { weights: Array2, bias: Array1, } impl Linear { /// Create a new linear layer with Xavier/Glorot initialization pub fn new(input_dim: usize, output_dim: usize) -> Self { let mut rng = rand::thread_rng(); // Xavier initialization: scale = sqrt(2.0 / (input_dim + output_dim)) let scale = (2.0 / (input_dim + output_dim) as f32).sqrt(); let normal = Normal::new(0.0, scale as f64).unwrap(); let weights = Array2::from_shape_fn((output_dim, input_dim), |_| normal.sample(&mut rng) as f32); let bias = Array1::zeros(output_dim); Self { weights, bias } } /// Forward pass: y = Wx + b pub fn forward(&self, input: &[f32]) -> Vec { let x = ArrayView1::from(input); let output = self.weights.dot(&x) + &self.bias; output.to_vec() } /// Get output dimension pub fn output_dim(&self) -> usize { self.weights.shape()[0] } } /// Layer normalization #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LayerNorm { gamma: Array1, beta: Array1, eps: f32, } impl LayerNorm { /// Create a new layer normalization layer pub fn new(dim: usize, eps: f32) -> Self { Self { gamma: Array1::ones(dim), beta: Array1::zeros(dim), eps, } } /// Forward pass: normalize and scale pub fn forward(&self, input: &[f32]) -> Vec { let x = ArrayView1::from(input); // Compute mean and variance let mean = x.mean().unwrap_or(0.0); let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::() / x.len() as f32; // Normalize let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt()); // Scale and shift let output = &self.gamma * &normalized + &self.beta; output.to_vec() } } /// Multi-head attention mechanism #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MultiHeadAttention { num_heads: usize, head_dim: usize, q_linear: Linear, k_linear: Linear, v_linear: Linear, out_linear: Linear, } impl MultiHeadAttention { /// Create a new multi-head attention layer /// /// # Errors /// Returns `GnnError::LayerConfig` if `embed_dim` is not divisible by `num_heads`. pub fn new(embed_dim: usize, num_heads: usize) -> Result { if embed_dim % num_heads != 0 { return Err(GnnError::layer_config(format!( "Embedding dimension ({}) must be divisible by number of heads ({})", embed_dim, num_heads ))); } let head_dim = embed_dim / num_heads; Ok(Self { num_heads, head_dim, q_linear: Linear::new(embed_dim, embed_dim), k_linear: Linear::new(embed_dim, embed_dim), v_linear: Linear::new(embed_dim, embed_dim), out_linear: Linear::new(embed_dim, embed_dim), }) } /// Forward pass: compute multi-head attention /// /// # Arguments /// * `query` - Query vector /// * `keys` - Key vectors from neighbors /// * `values` - Value vectors from neighbors /// /// # Returns /// Attention-weighted output vector pub fn forward(&self, query: &[f32], keys: &[Vec], values: &[Vec]) -> Vec { if keys.is_empty() || values.is_empty() { return query.to_vec(); } // Project query, keys, and values let q = self.q_linear.forward(query); let k: Vec> = keys.iter().map(|k| self.k_linear.forward(k)).collect(); let v: Vec> = values.iter().map(|v| self.v_linear.forward(v)).collect(); // Reshape for multi-head attention let q_heads = self.split_heads(&q); let k_heads: Vec>> = k.iter().map(|k_vec| self.split_heads(k_vec)).collect(); let v_heads: Vec>> = v.iter().map(|v_vec| self.split_heads(v_vec)).collect(); // Compute attention for each head let mut head_outputs = Vec::new(); for h in 0..self.num_heads { let q_h = &q_heads[h]; let k_h: Vec<&Vec> = k_heads.iter().map(|heads| &heads[h]).collect(); let v_h: Vec<&Vec> = v_heads.iter().map(|heads| &heads[h]).collect(); let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h); head_outputs.push(head_output); } // Concatenate heads let concat: Vec = head_outputs.into_iter().flatten().collect(); // Final linear projection self.out_linear.forward(&concat) } /// Split vector into multiple heads fn split_heads(&self, x: &[f32]) -> Vec> { let mut heads = Vec::new(); for h in 0..self.num_heads { let start = h * self.head_dim; let end = start + self.head_dim; heads.push(x[start..end].to_vec()); } heads } /// Scaled dot-product attention fn scaled_dot_product_attention( &self, query: &[f32], keys: &[&Vec], values: &[&Vec], ) -> Vec { if keys.is_empty() { return query.to_vec(); } let scale = (self.head_dim as f32).sqrt(); // Compute attention scores let scores: Vec = keys .iter() .map(|k| { let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum(); dot / scale }) .collect(); // Softmax with epsilon guard against division by zero let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max); let exp_scores: Vec = scores.iter().map(|&s| (s - max_score).exp()).collect(); let sum_exp: f32 = exp_scores.iter().sum::().max(1e-10); let attention_weights: Vec = exp_scores.iter().map(|&e| e / sum_exp).collect(); // Weighted sum of values let mut output = vec![0.0; self.head_dim]; for (weight, value) in attention_weights.iter().zip(values.iter()) { for (out, &val) in output.iter_mut().zip(value.iter()) { *out += weight * val; } } output } } /// Gated Recurrent Unit (GRU) cell for state updates #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GRUCell { // Update gate w_z: Linear, u_z: Linear, // Reset gate w_r: Linear, u_r: Linear, // Candidate hidden state w_h: Linear, u_h: Linear, } impl GRUCell { /// Create a new GRU cell pub fn new(input_dim: usize, hidden_dim: usize) -> Self { Self { // Update gate w_z: Linear::new(input_dim, hidden_dim), u_z: Linear::new(hidden_dim, hidden_dim), // Reset gate w_r: Linear::new(input_dim, hidden_dim), u_r: Linear::new(hidden_dim, hidden_dim), // Candidate hidden state w_h: Linear::new(input_dim, hidden_dim), u_h: Linear::new(hidden_dim, hidden_dim), } } /// Forward pass: update hidden state /// /// # Arguments /// * `input` - Current input /// * `hidden` - Previous hidden state /// /// # Returns /// Updated hidden state pub fn forward(&self, input: &[f32], hidden: &[f32]) -> Vec { // Update gate: z_t = sigmoid(W_z * x_t + U_z * h_{t-1}) let z = self.sigmoid_vec(&self.add_vecs(&self.w_z.forward(input), &self.u_z.forward(hidden))); // Reset gate: r_t = sigmoid(W_r * x_t + U_r * h_{t-1}) let r = self.sigmoid_vec(&self.add_vecs(&self.w_r.forward(input), &self.u_r.forward(hidden))); // Candidate hidden state: h_tilde = tanh(W_h * x_t + U_h * (r_t ⊙ h_{t-1})) let r_hidden = self.mul_vecs(&r, hidden); let h_tilde = self.tanh_vec(&self.add_vecs(&self.w_h.forward(input), &self.u_h.forward(&r_hidden))); // Final hidden state: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h_tilde let one_minus_z: Vec = z.iter().map(|&zval| 1.0 - zval).collect(); let term1 = self.mul_vecs(&one_minus_z, hidden); let term2 = self.mul_vecs(&z, &h_tilde); self.add_vecs(&term1, &term2) } /// Sigmoid activation with numerical stability fn sigmoid(&self, x: f32) -> f32 { if x > 0.0 { 1.0 / (1.0 + (-x).exp()) } else { let ex = x.exp(); ex / (1.0 + ex) } } /// Sigmoid for vectors fn sigmoid_vec(&self, v: &[f32]) -> Vec { v.iter().map(|&x| self.sigmoid(x)).collect() } /// Tanh activation fn tanh(&self, x: f32) -> f32 { x.tanh() } /// Tanh for vectors fn tanh_vec(&self, v: &[f32]) -> Vec { v.iter().map(|&x| self.tanh(x)).collect() } /// Element-wise addition fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec { a.iter().zip(b.iter()).map(|(x, y)| x + y).collect() } /// Element-wise multiplication fn mul_vecs(&self, a: &[f32], b: &[f32]) -> Vec { a.iter().zip(b.iter()).map(|(x, y)| x * y).collect() } } /// Main GNN layer operating on HNSW topology #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RuvectorLayer { /// Message weight matrix w_msg: Linear, /// Aggregation weight matrix w_agg: Linear, /// GRU update cell w_update: GRUCell, /// Multi-head attention attention: MultiHeadAttention, /// Layer normalization norm: LayerNorm, /// Dropout rate dropout: f32, } impl RuvectorLayer { /// Create a new Ruvector GNN layer /// /// # Arguments /// * `input_dim` - Dimension of input node embeddings /// * `hidden_dim` - Dimension of hidden representations /// * `heads` - Number of attention heads /// * `dropout` - Dropout rate (0.0 to 1.0) /// /// # Errors /// Returns `GnnError::LayerConfig` if `dropout` is outside `[0.0, 1.0]` or /// if `hidden_dim` is not divisible by `heads`. pub fn new( input_dim: usize, hidden_dim: usize, heads: usize, dropout: f32, ) -> Result { if !(0.0..=1.0).contains(&dropout) { return Err(GnnError::layer_config(format!( "Dropout must be between 0.0 and 1.0, got {}", dropout ))); } Ok(Self { w_msg: Linear::new(input_dim, hidden_dim), w_agg: Linear::new(hidden_dim, hidden_dim), w_update: GRUCell::new(hidden_dim, hidden_dim), attention: MultiHeadAttention::new(hidden_dim, heads)?, norm: LayerNorm::new(hidden_dim, 1e-5), dropout, }) } /// Forward pass through the GNN layer /// /// # Arguments /// * `node_embedding` - Current node's embedding /// * `neighbor_embeddings` - Embeddings of neighbor nodes /// * `edge_weights` - Weights of edges to neighbors (e.g., distances) /// /// # Returns /// Updated node embedding pub fn forward( &self, node_embedding: &[f32], neighbor_embeddings: &[Vec], edge_weights: &[f32], ) -> Vec { if neighbor_embeddings.is_empty() { // No neighbors: return normalized projection let projected = self.w_msg.forward(node_embedding); return self.norm.forward(&projected); } // Step 1: Message passing - transform node and neighbor embeddings let node_msg = self.w_msg.forward(node_embedding); let neighbor_msgs: Vec> = neighbor_embeddings .iter() .map(|n| self.w_msg.forward(n)) .collect(); // Step 2: Attention-based aggregation let attention_output = self .attention .forward(&node_msg, &neighbor_msgs, &neighbor_msgs); // Step 3: Weighted aggregation using edge weights let weighted_msgs = self.aggregate_messages(&neighbor_msgs, edge_weights); // Step 4: Combine attention and weighted aggregation let combined = self.add_vecs(&attention_output, &weighted_msgs); let aggregated = self.w_agg.forward(&combined); // Step 5: GRU update let updated = self.w_update.forward(&aggregated, &node_msg); // Step 6: Apply dropout (simplified - always apply scaling) let dropped = self.apply_dropout(&updated); // Step 7: Layer normalization self.norm.forward(&dropped) } /// Aggregate neighbor messages with edge weights fn aggregate_messages(&self, messages: &[Vec], weights: &[f32]) -> Vec { if messages.is_empty() || weights.is_empty() { return vec![0.0; self.w_msg.output_dim()]; } // Normalize weights to sum to 1 let weight_sum: f32 = weights.iter().sum(); let normalized_weights: Vec = if weight_sum > 0.0 { weights.iter().map(|&w| w / weight_sum).collect() } else { vec![1.0 / weights.len() as f32; weights.len()] }; // Weighted sum let dim = messages[0].len(); let mut aggregated = vec![0.0; dim]; for (msg, &weight) in messages.iter().zip(normalized_weights.iter()) { for (agg, &m) in aggregated.iter_mut().zip(msg.iter()) { *agg += weight * m; } } aggregated } /// Apply dropout (simplified version - just scales by (1-dropout)) fn apply_dropout(&self, input: &[f32]) -> Vec { let scale = 1.0 - self.dropout; input.iter().map(|&x| x * scale).collect() } /// Element-wise vector addition fn add_vecs(&self, a: &[f32], b: &[f32]) -> Vec { a.iter().zip(b.iter()).map(|(x, y)| x + y).collect() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_linear_layer() { let linear = Linear::new(4, 2); let input = vec![1.0, 2.0, 3.0, 4.0]; let output = linear.forward(&input); assert_eq!(output.len(), 2); } #[test] fn test_layer_norm() { let norm = LayerNorm::new(4, 1e-5); let input = vec![1.0, 2.0, 3.0, 4.0]; let output = norm.forward(&input); // Check that output has zero mean (approximately) let mean: f32 = output.iter().sum::() / output.len() as f32; assert!((mean).abs() < 1e-5); } #[test] fn test_multihead_attention() { let attention = MultiHeadAttention::new(8, 2).unwrap(); let query = vec![0.5; 8]; let keys = vec![vec![0.3; 8], vec![0.7; 8]]; let values = vec![vec![0.2; 8], vec![0.8; 8]]; let output = attention.forward(&query, &keys, &values); assert_eq!(output.len(), 8); } #[test] fn test_multihead_attention_invalid_dims() { let result = MultiHeadAttention::new(10, 3); assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!(err.contains("divisible")); } #[test] fn test_gru_cell() { let gru = GRUCell::new(4, 8); let input = vec![1.0; 4]; let hidden = vec![0.5; 8]; let new_hidden = gru.forward(&input, &hidden); assert_eq!(new_hidden.len(), 8); } #[test] fn test_ruvector_layer() { let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap(); let node = vec![1.0, 2.0, 3.0, 4.0]; let neighbors = vec![vec![0.5, 1.0, 1.5, 2.0], vec![2.0, 3.0, 4.0, 5.0]]; let weights = vec![0.3, 0.7]; let output = layer.forward(&node, &neighbors, &weights); assert_eq!(output.len(), 8); } #[test] fn test_ruvector_layer_no_neighbors() { let layer = RuvectorLayer::new(4, 8, 2, 0.1).unwrap(); let node = vec![1.0, 2.0, 3.0, 4.0]; let neighbors: Vec> = vec![]; let weights: Vec = vec![]; let output = layer.forward(&node, &neighbors, &weights); assert_eq!(output.len(), 8); } #[test] fn test_ruvector_layer_invalid_dropout() { let result = RuvectorLayer::new(4, 8, 2, 1.5); assert!(result.is_err()); } #[test] fn test_ruvector_layer_invalid_heads() { let result = RuvectorLayer::new(4, 7, 3, 0.1); assert!(result.is_err()); } }