feat(temporal): weight-blob wire format (ADR-095 Phase 1, #513)

The training/firmware boundary needs a stable serialization for the
temporal head's weights, distinct from the kernel scaffold and the
firmware ABI. This commit defines that format on the host side. The
firmware-side mirrored loader lands when the toolchain unblocks.

Format:
  - Header (24 B): magic 'RVNE' / version 1 / dtype flag
    (FP32 / FP16) / input_dim / n_q_heads / n_kv_heads / head_dim /
    n_layers / n_classes / weights_len.
  - Body: weights_len bytes of flat per-layer weights.
  - Footer (4 B): CRC32 IEEE 802.3 over everything before, same
    polynomial used by temporal_task.c so a blob produced here parses
    on the firmware unchanged.

Layout decisions:
  - Little-endian throughout (Xtensa native).
  - Weights kept as Vec<u8> rather than Vec<f32>/Vec<f16> so the no_std
    firmware loader (which may not have the `half` crate) can mmap and
    read either dtype directly.
  - Versioning is hard-break: bumping `version` means firmware refuses
    to load. Optional fields go behind reserved flag bits, never by
    field reorder. Documented inline.

Validation surface:
  - `WeightBlobHeader::validate()` catches zero dims, invalid GQA
    ratios (n_q_heads % n_kv_heads != 0), n_layers=0, n_classes<2.
    Same checks fire from `WeightBlob::parse()` so the firmware can't
    accidentally accept a blob the host should have rejected.
  - `WeightBlob::parse()` enforces magic / version / size / CRC
    before exposing weights to the caller.

Tests (8/8 passing, alongside 5/5 sparse smoke = 13/13 total):
  - roundtrip_fp32, roundtrip_fp16
  - parse_rejects_bad_magic, _wrong_version, _size_mismatch,
    _crc_corruption, _invalid_gqa_ratio_in_header
  - header_constants_match_wire_layout (anchor)

What's deliberately NOT in this commit:
  - The firmware-side mirrored loader (deferred to the iteration that
    unblocks the esp Rust toolchain — no point shipping a parser that
    can't be compiled).
  - Per-layer weight ordering. The blob is a flat byte-buffer; the
    interpretation of per-layer offsets is the kernel's contract,
    documented in the eventual model module (ADR-095 §3.2 follow-up).

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-05-08 11:43:49 -04:00
parent 7994af8221
commit 237325a117
3 changed files with 376 additions and 0 deletions

View File

@ -12,10 +12,15 @@
pub mod config;
pub mod error;
pub mod sparse;
pub mod weights;
pub use config::{TemporalBackendKind, TemporalHeadConfig};
pub use error::TemporalError;
pub use sparse::SparseGqaHead;
pub use weights::{
WeightBlob, WeightBlobHeader, WeightDtype, WEIGHT_BLOB_HEADER_LEN, WEIGHT_BLOB_MAGIC,
WEIGHT_BLOB_VERSION,
};
// Re-export the upstream Tensor3 so callers don't need a direct
// `ruvllm_sparse_attention` dep.

View File

@ -0,0 +1,231 @@
// Wire format for the temporal-head weights blob.
//
// One blob describes one model. Both ends speak it:
// - Host-side (this crate): training emits a blob via `WeightBlob::serialize`.
// - Firmware-side (`firmware/esp32-csi-node/components/ruv_temporal`):
// loads it via a mirrored parser. The blob is the *only* thing
// that crosses the host/firmware boundary at deploy time, so the
// format must be stable, self-describing, and version-gated.
//
// Layout (little-endian throughout):
//
// header 16 B
// [0x00..0x04) magic u32 = 0x52564E45 ("RVNE" — RuVector Neural Edge)
// [0x04..0x06) version u16 = 1
// [0x06..0x07) flags u8 bit 0 = 0:fp32 / 1:fp16 weights
// [0x07..0x08) reserved u8
// [0x08..0x0A) input_dim u16 per-frame feature dim
// [0x0A..0x0C) n_q_heads u16 query head count
// [0x0C..0x0E) n_kv_heads u16 key/value head count (≤ n_q_heads, divides it)
// [0x0E..0x10) head_dim u16 per-head feature dim
//
// body variable
// [0x10..0x12) n_layers u16
// [0x12..0x14) n_classes u16
// [0x14..0x18) weights_len u32 bytes of weights payload (after this header)
// [0x18..end-4) weights weights_len bytes — flat per-layer arrays
// in the order the kernel reads them
// footer 4 B
// [end-4..end) crc32 u32 IEEE 802.3, covers everything before
//
// Total size = 16 (header) + 2+2+4 (body header) + weights_len + 4 (crc) = 28 + weights_len
//
// Versioning: bumping `version` is a hard break — firmware refuses to
// load a blob whose version it doesn't know. Adding a *new* field is
// done by reserving a new flag bit and treating the field as
// post-weights when the bit is set; never reorder existing fields.
use crate::error::TemporalError;
pub const WEIGHT_BLOB_MAGIC: u32 = 0x5256_4E45; // "RVNE"
pub const WEIGHT_BLOB_VERSION: u16 = 1;
pub const WEIGHT_BLOB_HEADER_LEN: usize = 16 + 2 + 2 + 4; // 24
pub const WEIGHT_BLOB_FOOTER_LEN: usize = 4;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightDtype {
F32,
F16,
}
#[derive(Clone, Debug)]
pub struct WeightBlobHeader {
pub dtype: WeightDtype,
pub input_dim: u16,
pub n_q_heads: u16,
pub n_kv_heads: u16,
pub head_dim: u16,
pub n_layers: u16,
pub n_classes: u16,
}
impl WeightBlobHeader {
/// Element size in bytes for the configured dtype.
pub fn elem_bytes(&self) -> usize {
match self.dtype {
WeightDtype::F32 => 4,
WeightDtype::F16 => 2,
}
}
/// Validate that the structural numbers make sense — caught here
/// rather than at first kernel call so the host-side training
/// tool can't accidentally emit a blob the firmware will reject
/// at boot.
pub fn validate(&self) -> Result<(), TemporalError> {
if self.input_dim == 0
|| self.n_q_heads == 0
|| self.n_kv_heads == 0
|| self.head_dim == 0
{
return Err(TemporalError::InvalidConfig(
"header: zero-valued dimension(s)",
));
}
if self.n_q_heads % self.n_kv_heads != 0 {
return Err(TemporalError::InvalidConfig(
"header: n_q_heads must be divisible by n_kv_heads (GQA)",
));
}
if self.n_layers == 0 || self.n_classes < 2 {
return Err(TemporalError::InvalidConfig(
"header: n_layers must be ≥ 1 and n_classes ≥ 2",
));
}
Ok(())
}
}
/// A complete weight blob: header + raw weights bytes.
///
/// Weights are kept as `Vec<u8>` rather than `Vec<f32>` / `Vec<f16>` so
/// the firmware loader (which is no_std and may not have the `half`
/// crate) can `mmap` the body and read either dtype directly.
pub struct WeightBlob {
pub header: WeightBlobHeader,
pub weights: Vec<u8>,
}
impl WeightBlob {
pub fn new(header: WeightBlobHeader, weights: Vec<u8>) -> Result<Self, TemporalError> {
header.validate()?;
let elem = header.elem_bytes();
if weights.len() % elem != 0 {
return Err(TemporalError::InvalidConfig(
"weights length is not a multiple of dtype element size",
));
}
Ok(Self { header, weights })
}
/// Serialize to the wire format. Stable across rebuilds — this is
/// the contract the firmware loader speaks.
pub fn serialize(&self) -> Vec<u8> {
let total = WEIGHT_BLOB_HEADER_LEN + self.weights.len() + WEIGHT_BLOB_FOOTER_LEN;
let mut out = Vec::with_capacity(total);
// header
out.extend_from_slice(&WEIGHT_BLOB_MAGIC.to_le_bytes());
out.extend_from_slice(&WEIGHT_BLOB_VERSION.to_le_bytes());
let flags: u8 = match self.header.dtype {
WeightDtype::F32 => 0,
WeightDtype::F16 => 1,
};
out.push(flags);
out.push(0); // reserved
out.extend_from_slice(&self.header.input_dim.to_le_bytes());
out.extend_from_slice(&self.header.n_q_heads.to_le_bytes());
out.extend_from_slice(&self.header.n_kv_heads.to_le_bytes());
out.extend_from_slice(&self.header.head_dim.to_le_bytes());
// body header
out.extend_from_slice(&self.header.n_layers.to_le_bytes());
out.extend_from_slice(&self.header.n_classes.to_le_bytes());
out.extend_from_slice(&(self.weights.len() as u32).to_le_bytes());
// weights
out.extend_from_slice(&self.weights);
// footer: crc32 over everything written so far
let crc = crc32_ieee(&out);
out.extend_from_slice(&crc.to_le_bytes());
out
}
/// Parse a blob, validating magic / version / size / CRC.
pub fn parse(buf: &[u8]) -> Result<Self, TemporalError> {
if buf.len() < WEIGHT_BLOB_HEADER_LEN + WEIGHT_BLOB_FOOTER_LEN {
return Err(TemporalError::InvalidConfig("blob too short"));
}
let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
if magic != WEIGHT_BLOB_MAGIC {
return Err(TemporalError::InvalidConfig("bad magic"));
}
let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
if version != WEIGHT_BLOB_VERSION {
return Err(TemporalError::InvalidConfig("unsupported blob version"));
}
let flags = buf[6];
let dtype = match flags & 0x01 {
0 => WeightDtype::F32,
_ => WeightDtype::F16,
};
let input_dim = u16::from_le_bytes(buf[8..10].try_into().unwrap());
let n_q_heads = u16::from_le_bytes(buf[10..12].try_into().unwrap());
let n_kv_heads = u16::from_le_bytes(buf[12..14].try_into().unwrap());
let head_dim = u16::from_le_bytes(buf[14..16].try_into().unwrap());
let n_layers = u16::from_le_bytes(buf[16..18].try_into().unwrap());
let n_classes = u16::from_le_bytes(buf[18..20].try_into().unwrap());
let weights_len = u32::from_le_bytes(buf[20..24].try_into().unwrap()) as usize;
// sanity-check size before slicing
let expected = WEIGHT_BLOB_HEADER_LEN + weights_len + WEIGHT_BLOB_FOOTER_LEN;
if buf.len() != expected {
return Err(TemporalError::InvalidConfig(
"blob length doesn't match weights_len in header",
));
}
// CRC check: cover everything before the trailing 4-byte CRC
let stored_crc = u32::from_le_bytes(buf[buf.len() - 4..].try_into().unwrap());
let computed = crc32_ieee(&buf[..buf.len() - 4]);
if stored_crc != computed {
return Err(TemporalError::InvalidConfig("blob CRC mismatch"));
}
let header = WeightBlobHeader {
dtype,
input_dim,
n_q_heads,
n_kv_heads,
head_dim,
n_layers,
n_classes,
};
header.validate()?;
let weights_start = WEIGHT_BLOB_HEADER_LEN;
let weights_end = weights_start + weights_len;
let weights = buf[weights_start..weights_end].to_vec();
Ok(Self { header, weights })
}
}
/// IEEE 802.3 CRC32 (poly 0xEDB88320), table-free. Same polynomial
/// the firmware-side loader uses (`temporal_task.c::crc32_ieee`) so a
/// blob produced here parses there.
pub(crate) fn crc32_ieee(data: &[u8]) -> u32 {
let mut crc = 0xFFFF_FFFFu32;
for &b in data {
crc ^= b as u32;
for _ in 0..8 {
let mask = 0u32.wrapping_sub(crc & 1);
crc = (crc >> 1) ^ (0xEDB8_8320 & mask);
}
}
!crc
}

View File

@ -0,0 +1,140 @@
//! Roundtrip + corruption-detection tests for the temporal head's
//! weight-blob wire format. The format is the contract between
//! host-side training and firmware-side inference — when this test
//! file changes, both ends update in lockstep.
use wifi_densepose_temporal::{
WeightBlob, WeightBlobHeader, WeightDtype, WEIGHT_BLOB_HEADER_LEN, WEIGHT_BLOB_MAGIC,
WEIGHT_BLOB_VERSION,
};
fn header_default() -> WeightBlobHeader {
WeightBlobHeader {
dtype: WeightDtype::F32,
input_dim: 16,
n_q_heads: 4,
n_kv_heads: 1,
head_dim: 32,
n_layers: 2,
n_classes: 4,
}
}
#[test]
fn roundtrip_fp32() {
let header = header_default();
let weights: Vec<u8> = (0..1024).map(|i| (i & 0xFF) as u8).collect();
let blob = WeightBlob::new(header, weights).expect("ok");
let serialized = blob.serialize();
let parsed = WeightBlob::parse(&serialized).expect("parse");
assert_eq!(parsed.header.input_dim, 16);
assert_eq!(parsed.header.n_q_heads, 4);
assert_eq!(parsed.header.n_kv_heads, 1);
assert_eq!(parsed.header.head_dim, 32);
assert_eq!(parsed.header.n_layers, 2);
assert_eq!(parsed.header.n_classes, 4);
assert_eq!(parsed.header.dtype, WeightDtype::F32);
assert_eq!(parsed.weights.len(), 1024);
}
#[test]
fn roundtrip_fp16() {
let header = WeightBlobHeader {
dtype: WeightDtype::F16,
..header_default()
};
// FP16 means 2 bytes per element — 512 bytes = 256 elements.
let weights: Vec<u8> = (0..512).map(|i| (i & 0xFF) as u8).collect();
let blob = WeightBlob::new(header, weights).expect("ok");
let serialized = blob.serialize();
let parsed = WeightBlob::parse(&serialized).expect("parse");
assert_eq!(parsed.header.dtype, WeightDtype::F16);
assert_eq!(parsed.weights.len(), 512);
}
#[test]
fn parse_rejects_bad_magic() {
let header = header_default();
let blob = WeightBlob::new(header, vec![0u8; 16]).expect("ok");
let mut bytes = blob.serialize();
bytes[0] = 0xFF; // corrupt magic
let err = WeightBlob::parse(&bytes).err().expect("rejected");
assert!(format!("{err}").contains("magic"));
}
#[test]
fn parse_rejects_wrong_version() {
let header = header_default();
let blob = WeightBlob::new(header, vec![0u8; 16]).expect("ok");
let mut bytes = blob.serialize();
bytes[4] = 99; // bump version
bytes[5] = 0;
let err = WeightBlob::parse(&bytes).err().expect("rejected");
assert!(format!("{err}").contains("version"));
}
#[test]
fn parse_rejects_size_mismatch() {
let header = header_default();
let blob = WeightBlob::new(header, vec![0u8; 64]).expect("ok");
let mut bytes = blob.serialize();
// truncate the weights region by 4 bytes — total length now
// doesn't match the weights_len field.
bytes.drain(WEIGHT_BLOB_HEADER_LEN..WEIGHT_BLOB_HEADER_LEN + 4);
let err = WeightBlob::parse(&bytes).err().expect("rejected");
assert!(format!("{err}").contains("length") || format!("{err}").contains("CRC"));
}
#[test]
fn parse_rejects_crc_corruption() {
let header = header_default();
let blob = WeightBlob::new(header, vec![0xAAu8; 32]).expect("ok");
let mut bytes = blob.serialize();
// flip a bit in the middle of the weights region
let mid = WEIGHT_BLOB_HEADER_LEN + 5;
bytes[mid] ^= 0x01;
let err = WeightBlob::parse(&bytes).err().expect("rejected");
assert!(format!("{err}").contains("CRC"));
}
#[test]
fn parse_rejects_invalid_gqa_ratio_in_header() {
// Manually craft bytes where n_q_heads % n_kv_heads != 0 to ensure
// header.validate() fires from inside parse(). Easiest: build a
// valid blob then patch the n_kv_heads field.
let header = header_default();
let blob = WeightBlob::new(header, vec![0u8; 16]).expect("ok");
let mut bytes = blob.serialize();
// n_kv_heads is at offset 12..14; set it to 3 so 4 % 3 != 0.
bytes[12] = 3;
bytes[13] = 0;
// Re-CRC so we can be sure the validator (not the CRC) is what
// rejects this case.
let new_crc = crc32_ieee(&bytes[..bytes.len() - 4]);
let crc_off = bytes.len() - 4;
bytes[crc_off..].copy_from_slice(&new_crc.to_le_bytes());
let err = WeightBlob::parse(&bytes).err().expect("rejected");
assert!(format!("{err}").to_lowercase().contains("gqa"));
}
#[test]
fn header_constants_match_wire_layout() {
// Anchor the public constants so they can't drift silently.
assert_eq!(WEIGHT_BLOB_MAGIC, 0x5256_4E45);
assert_eq!(WEIGHT_BLOB_VERSION, 1);
assert_eq!(WEIGHT_BLOB_HEADER_LEN, 24);
}
// Mirror of the production CRC32 so the size-mismatch / GQA tests can
// re-CRC after their patch. Kept out of the public API.
fn crc32_ieee(data: &[u8]) -> u32 {
let mut crc = 0xFFFF_FFFFu32;
for &b in data {
crc ^= b as u32;
for _ in 0..8 {
let mask = 0u32.wrapping_sub(crc & 1);
crc = (crc >> 1) ^ (0xEDB8_8320 & mask);
}
}
!crc
}