wifi-densepose/v2/crates/ruview-swarm/src/allocation/fnn.rs

98 lines
2.7 KiB
Rust

//! Lightweight 3-layer FNN bid scorer — pure Rust, no ONNX required.
/// 3-layer FNN: 5 inputs → 16 hidden (ReLU) → 8 hidden (ReLU) → 1 output (sigmoid).
pub struct FnnScorer {
pub w1: [[f32; 5]; 16],
pub b1: [f32; 16],
pub w2: [[f32; 16]; 8],
pub b2: [f32; 8],
pub w3: [f32; 8],
pub b3: f32,
}
fn relu(x: f32) -> f32 {
x.max(0.0)
}
fn sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
impl FnnScorer {
/// Score a feature vector. Returns sigmoid(output) ∈ [0, 1].
/// Features: [dist_norm, battery_norm, link_quality, csi_confidence, workload_norm]
pub fn score(&self, features: [f32; 5]) -> f32 {
// Layer 1: 5 → 16 (ReLU)
let mut h1 = [0.0f32; 16];
for (i, row) in self.w1.iter().enumerate() {
let z: f32 = row.iter().zip(features.iter()).map(|(w, x)| w * x).sum();
h1[i] = relu(z + self.b1[i]);
}
// Layer 2: 16 → 8 (ReLU)
let mut h2 = [0.0f32; 8];
for (i, row) in self.w2.iter().enumerate() {
let z: f32 = row.iter().zip(h1.iter()).map(|(w, x)| w * x).sum();
h2[i] = relu(z + self.b2[i]);
}
// Layer 3: 8 → 1 (sigmoid)
let z3: f32 = self.w3.iter().zip(h2.iter()).map(|(w, x)| w * x).sum::<f32>() + self.b3;
sigmoid(z3)
}
/// Default weights initialised to a simple identity-like setup.
pub fn default_weights() -> Self {
// Simple: w1 diagonalish, others small constant
// Index needed: diagonal/strided init uses i for both row and column.
let mut w1 = [[0.0f32; 5]; 16];
#[allow(clippy::needless_range_loop)]
for i in 0..5 {
w1[i][i] = 1.0;
}
for row in w1.iter_mut().take(16).skip(5) {
row[0] = 0.1;
}
let mut w2 = [[0.0f32; 16]; 8];
#[allow(clippy::needless_range_loop)]
for i in 0..8 {
w2[i][i * 2] = 1.0;
}
let w3 = [0.125f32; 8];
Self {
w1,
b1: [0.0; 16],
w2,
b2: [0.0; 8],
w3,
b3: 0.0,
}
}
}
impl Default for FnnScorer {
fn default() -> Self {
Self::default_weights()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_score_in_unit_interval() {
let scorer = FnnScorer::default_weights();
let features = [0.3f32, 0.8, 0.9, 0.75, 0.2];
let s = scorer.score(features);
assert!(s >= 0.0 && s <= 1.0, "score {s} out of [0,1]");
}
#[test]
fn test_score_deterministic() {
let scorer = FnnScorer::default_weights();
let f = [0.5f32; 5];
assert_eq!(scorer.score(f), scorer.score(f));
}
}