335 lines
10 KiB
Rust
335 lines
10 KiB
Rust
//! Winner-Take-All (WTA) WASM bindings
//!
//! Instant decisions via neural competition:
//! - Single winner: <1us for 1000 neurons
//! - K-WTA: <10us for k=50
|
|
|
|
use wasm_bindgen::prelude::*;
|
|
|
|
/// Winner-Take-All competition layer
|
|
///
|
|
/// Implements neural competition where the highest-activation neuron
|
|
/// wins and suppresses others through lateral inhibition.
|
|
///
|
|
/// # Performance
|
|
/// - <1us winner selection for 1000 neurons
|
|
#[wasm_bindgen]
|
|
pub struct WTALayer {
|
|
membranes: Vec<f32>,
|
|
threshold: f32,
|
|
inhibition_strength: f32,
|
|
refractory_period: u32,
|
|
refractory_counters: Vec<u32>,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl WTALayer {
|
|
/// Create a new WTA layer
|
|
///
|
|
/// # Arguments
|
|
/// * `size` - Number of competing neurons
|
|
/// * `threshold` - Activation threshold for firing
|
|
/// * `inhibition` - Lateral inhibition strength (0.0-1.0)
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(size: usize, threshold: f32, inhibition: f32) -> Result<WTALayer, JsValue> {
|
|
if size == 0 {
|
|
return Err(JsValue::from_str("Size must be > 0"));
|
|
}
|
|
|
|
Ok(Self {
|
|
membranes: vec![0.0; size],
|
|
threshold,
|
|
inhibition_strength: inhibition.clamp(0.0, 1.0),
|
|
refractory_period: 10,
|
|
refractory_counters: vec![0; size],
|
|
})
|
|
}
|
|
|
|
/// Run winner-take-all competition
|
|
///
|
|
/// Returns the index of the winning neuron, or -1 if no neuron exceeds threshold.
|
|
#[wasm_bindgen]
|
|
pub fn compete(&mut self, inputs: &[f32]) -> Result<i32, JsValue> {
|
|
if inputs.len() != self.membranes.len() {
|
|
return Err(JsValue::from_str(&format!(
|
|
"Input size mismatch: expected {}, got {}",
|
|
self.membranes.len(),
|
|
inputs.len()
|
|
)));
|
|
}
|
|
|
|
// Single-pass: update membrane potentials and find max
|
|
let mut best_idx: Option<usize> = None;
|
|
let mut best_val = f32::NEG_INFINITY;
|
|
|
|
for (i, &input) in inputs.iter().enumerate() {
|
|
if self.refractory_counters[i] == 0 {
|
|
self.membranes[i] = input;
|
|
if input > best_val {
|
|
best_val = input;
|
|
best_idx = Some(i);
|
|
}
|
|
} else {
|
|
self.refractory_counters[i] = self.refractory_counters[i].saturating_sub(1);
|
|
}
|
|
}
|
|
|
|
let winner_idx = match best_idx {
|
|
Some(idx) => idx,
|
|
None => return Ok(-1),
|
|
};
|
|
|
|
// Check if winner exceeds threshold
|
|
if best_val < self.threshold {
|
|
return Ok(-1);
|
|
}
|
|
|
|
// Apply lateral inhibition
|
|
for (i, membrane) in self.membranes.iter_mut().enumerate() {
|
|
if i != winner_idx {
|
|
*membrane *= 1.0 - self.inhibition_strength;
|
|
}
|
|
}
|
|
|
|
// Set refractory period for winner
|
|
self.refractory_counters[winner_idx] = self.refractory_period;
|
|
|
|
Ok(winner_idx as i32)
|
|
}
|
|
|
|
/// Soft competition with normalized activations
|
|
///
|
|
/// Returns activation levels for all neurons after softmax-like normalization.
|
|
#[wasm_bindgen]
|
|
pub fn compete_soft(&mut self, inputs: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
|
if inputs.len() != self.membranes.len() {
|
|
return Err(JsValue::from_str(&format!(
|
|
"Input size mismatch: expected {}, got {}",
|
|
self.membranes.len(),
|
|
inputs.len()
|
|
)));
|
|
}
|
|
|
|
// Update membrane potentials
|
|
self.membranes.copy_from_slice(inputs);
|
|
|
|
// Find max for numerical stability
|
|
let max_val = self
|
|
.membranes
|
|
.iter()
|
|
.copied()
|
|
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
|
.unwrap_or(0.0);
|
|
|
|
// Softmax with temperature
|
|
let temperature = 1.0 / (1.0 + self.inhibition_strength);
|
|
let mut activations: Vec<f32> = self
|
|
.membranes
|
|
.iter()
|
|
.map(|&x| ((x - max_val) / temperature).exp())
|
|
.collect();
|
|
|
|
// Normalize
|
|
let sum: f32 = activations.iter().sum();
|
|
if sum > 0.0 {
|
|
for a in &mut activations {
|
|
*a /= sum;
|
|
}
|
|
}
|
|
|
|
Ok(js_sys::Float32Array::from(activations.as_slice()))
|
|
}
|
|
|
|
/// Reset layer state
|
|
#[wasm_bindgen]
|
|
pub fn reset(&mut self) {
|
|
self.membranes.fill(0.0);
|
|
self.refractory_counters.fill(0);
|
|
}
|
|
|
|
/// Get current membrane potentials
|
|
#[wasm_bindgen]
|
|
pub fn get_membranes(&self) -> js_sys::Float32Array {
|
|
js_sys::Float32Array::from(self.membranes.as_slice())
|
|
}
|
|
|
|
/// Set refractory period
|
|
#[wasm_bindgen]
|
|
pub fn set_refractory_period(&mut self, period: u32) {
|
|
self.refractory_period = period;
|
|
}
|
|
|
|
/// Get layer size
|
|
#[wasm_bindgen(getter)]
|
|
pub fn size(&self) -> usize {
|
|
self.membranes.len()
|
|
}
|
|
}
|
|
|
|
/// K-Winner-Take-All layer for sparse distributed coding
|
|
///
|
|
/// Selects top-k neurons with highest activations.
|
|
///
|
|
/// # Performance
|
|
/// - O(n + k log k) using partial sorting
|
|
/// - <10us for 1000 neurons, k=50
|
|
#[wasm_bindgen]
|
|
pub struct KWTALayer {
|
|
size: usize,
|
|
k: usize,
|
|
threshold: Option<f32>,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl KWTALayer {
|
|
/// Create a new K-WTA layer
|
|
///
|
|
/// # Arguments
|
|
/// * `size` - Total number of neurons
|
|
/// * `k` - Number of winners to select
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn new(size: usize, k: usize) -> Result<KWTALayer, JsValue> {
|
|
if k == 0 {
|
|
return Err(JsValue::from_str("k must be > 0"));
|
|
}
|
|
if k > size {
|
|
return Err(JsValue::from_str("k cannot exceed layer size"));
|
|
}
|
|
|
|
Ok(Self {
|
|
size,
|
|
k,
|
|
threshold: None,
|
|
})
|
|
}
|
|
|
|
/// Set activation threshold
|
|
#[wasm_bindgen]
|
|
pub fn with_threshold(&mut self, threshold: f32) {
|
|
self.threshold = Some(threshold);
|
|
}
|
|
|
|
/// Select top-k neurons
|
|
///
|
|
/// Returns indices of k neurons with highest activations, sorted descending.
|
|
#[wasm_bindgen]
|
|
pub fn select(&self, inputs: &[f32]) -> Result<js_sys::Uint32Array, JsValue> {
|
|
if inputs.len() != self.size {
|
|
return Err(JsValue::from_str(&format!(
|
|
"Input size mismatch: expected {}, got {}",
|
|
self.size,
|
|
inputs.len()
|
|
)));
|
|
}
|
|
|
|
// Create (index, value) pairs
|
|
let mut indexed: Vec<(usize, f32)> =
|
|
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
|
|
|
// Filter by threshold if set
|
|
if let Some(threshold) = self.threshold {
|
|
indexed.retain(|(_, v)| *v >= threshold);
|
|
}
|
|
|
|
if indexed.is_empty() {
|
|
return Ok(js_sys::Uint32Array::new_with_length(0));
|
|
}
|
|
|
|
// Partial sort to get top-k
|
|
let k_actual = self.k.min(indexed.len());
|
|
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
|
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
|
|
// Take top k and sort descending
|
|
let mut winners: Vec<(usize, f32)> = indexed[..k_actual].to_vec();
|
|
winners.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
// Return only indices as u32
|
|
let indices: Vec<u32> = winners.into_iter().map(|(i, _)| i as u32).collect();
|
|
Ok(js_sys::Uint32Array::from(indices.as_slice()))
|
|
}
|
|
|
|
/// Select top-k neurons with their activation values
|
|
///
|
|
/// Returns array of [index, value] pairs.
|
|
#[wasm_bindgen]
|
|
pub fn select_with_values(&self, inputs: &[f32]) -> Result<JsValue, JsValue> {
|
|
if inputs.len() != self.size {
|
|
return Err(JsValue::from_str(&format!(
|
|
"Input size mismatch: expected {}, got {}",
|
|
self.size,
|
|
inputs.len()
|
|
)));
|
|
}
|
|
|
|
let mut indexed: Vec<(usize, f32)> =
|
|
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
|
|
|
if let Some(threshold) = self.threshold {
|
|
indexed.retain(|(_, v)| *v >= threshold);
|
|
}
|
|
|
|
if indexed.is_empty() {
|
|
return serde_wasm_bindgen::to_value(&Vec::<(usize, f32)>::new())
|
|
.map_err(|e| JsValue::from_str(&e.to_string()));
|
|
}
|
|
|
|
let k_actual = self.k.min(indexed.len());
|
|
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
|
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
|
|
let mut winners: Vec<(usize, f32)> = indexed[..k_actual].to_vec();
|
|
winners.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
serde_wasm_bindgen::to_value(&winners).map_err(|e| JsValue::from_str(&e.to_string()))
|
|
}
|
|
|
|
/// Create sparse activation vector (only top-k preserved)
|
|
#[wasm_bindgen]
|
|
pub fn sparse_activations(&self, inputs: &[f32]) -> Result<js_sys::Float32Array, JsValue> {
|
|
if inputs.len() != self.size {
|
|
return Err(JsValue::from_str(&format!(
|
|
"Input size mismatch: expected {}, got {}",
|
|
self.size,
|
|
inputs.len()
|
|
)));
|
|
}
|
|
|
|
let mut indexed: Vec<(usize, f32)> =
|
|
inputs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
|
|
|
|
if let Some(threshold) = self.threshold {
|
|
indexed.retain(|(_, v)| *v >= threshold);
|
|
}
|
|
|
|
let mut sparse = vec![0.0; self.size];
|
|
|
|
if !indexed.is_empty() {
|
|
let k_actual = self.k.min(indexed.len());
|
|
indexed.select_nth_unstable_by(k_actual - 1, |a, b| {
|
|
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
|
|
for (idx, value) in &indexed[..k_actual] {
|
|
sparse[*idx] = *value;
|
|
}
|
|
}
|
|
|
|
Ok(js_sys::Float32Array::from(sparse.as_slice()))
|
|
}
|
|
|
|
/// Get number of winners
|
|
#[wasm_bindgen(getter)]
|
|
pub fn k(&self) -> usize {
|
|
self.k
|
|
}
|
|
|
|
/// Get layer size
|
|
#[wasm_bindgen(getter)]
|
|
pub fn size(&self) -> usize {
|
|
self.size
|
|
}
|
|
}
|