wifi-densepose/v2/crates/wifi-densepose-temporal/src/config.rs

71 lines
2.5 KiB
Rust

use crate::TemporalError;
/// Backend choice per ADR-096 §4.4.
///
/// * `Dense` — back-compat path against `ruvector-attention`. Reserved;
/// not yet implemented in this crate (returns a typed error so callers
/// can fail loudly during config validation rather than at forward()).
/// * `SparseGqa` — `ruvllm_sparse_attention` `forward_gqa` for prefill,
/// `decode_step` against `KvCache` for streaming inference.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TemporalBackendKind {
Dense,
SparseGqa,
}
#[derive(Clone, Debug)]
pub struct TemporalHeadConfig {
pub backend: TemporalBackendKind,
/// Number of query heads. For pure MHA, equals `kv_heads`.
pub q_heads: usize,
/// Number of key/value heads. Must divide `q_heads`. GQA group size
/// is `q_heads / kv_heads`.
pub kv_heads: usize,
/// Per-head feature dimension.
pub head_dim: usize,
/// Local attention window radius (sparse pattern primitive #1, ADR-096 §3).
pub window: usize,
/// Landmark block size (sparse pattern primitive #3).
pub block_size: usize,
/// Whether the attention is causal. AETHER temporal aggregation is
/// causal (cannot peek at future CSI frames during streaming re-ID).
pub causal: bool,
}
impl TemporalHeadConfig {
/// Default config sized for the AETHER training default
/// (`window_frames = 100`) but with the sparse machinery wired up
/// so the long-window roadmap (10 s / 1000 frames) only requires
/// changing `window` at the call site, not re-architecting.
pub fn default_aether() -> Self {
Self {
backend: TemporalBackendKind::SparseGqa,
q_heads: 4,
kv_heads: 1, // MQA — collapses to one shared K/V across query heads
head_dim: 32,
window: 32,
block_size: 16,
causal: true,
}
}
pub fn validate(&self) -> Result<(), TemporalError> {
if self.q_heads == 0 || self.kv_heads == 0 || self.head_dim == 0 {
return Err(TemporalError::InvalidConfig(
"q_heads, kv_heads, head_dim must all be > 0",
));
}
if self.q_heads % self.kv_heads != 0 {
return Err(TemporalError::InvalidConfig(
"q_heads must be divisible by kv_heads (GQA constraint)",
));
}
if self.block_size == 0 {
return Err(TemporalError::InvalidConfig("block_size must be > 0"));
}
Ok(())
}
}