feat(temporal): weight-blob wire format (ADR-095 Phase 1, #513)
The training/firmware boundary needs a stable serialization for the
temporal head's weights, distinct from the kernel scaffold and the
firmware ABI. This commit defines that format on the host side. The
firmware-side mirrored loader lands when the toolchain unblocks.
Format:
- Header (24 B): magic 'RVNE' / version 1 / dtype flag
(FP32 / FP16) / reserved byte / input_dim / n_q_heads / n_kv_heads /
head_dim / n_layers / n_classes / weights_len.
- Body: weights_len bytes of flat per-layer weights.
- Footer (4 B): CRC32 IEEE 802.3 over everything before, same
polynomial used by temporal_task.c so a blob produced here parses
on the firmware unchanged.
Layout decisions:
- Little-endian throughout (Xtensa native).
- Weights kept as Vec<u8> rather than Vec<f32>/Vec<f16> so the no_std
firmware loader (which may not have the `half` crate) can mmap and
read either dtype directly.
- Versioning is hard-break: bumping `version` means firmware refuses
to load. Optional fields go behind reserved flag bits, never by
field reorder. Documented inline.
Validation surface:
- `WeightBlobHeader::validate()` catches zero dims, invalid GQA
ratios (n_q_heads % n_kv_heads != 0), n_layers=0, n_classes<2.
Same checks fire from `WeightBlob::parse()` so the firmware can't
accidentally accept a blob the host should have rejected.
- `WeightBlob::parse()` enforces magic / version / size / CRC
before exposing weights to the caller.
Tests (8/8 passing, alongside 5/5 sparse smoke = 13/13 total):
- roundtrip_fp32, roundtrip_fp16
- parse_rejects_bad_magic, _wrong_version, _size_mismatch,
_crc_corruption, _invalid_gqa_ratio_in_header
- header_constants_match_wire_layout (anchor)
What's deliberately NOT in this commit:
- The firmware-side mirrored loader (deferred to the iteration that
unblocks the esp Rust toolchain — no point shipping a parser that
can't be compiled).
- Per-layer weight ordering. The blob is a flat byte-buffer; the
interpretation of per-layer offsets is the kernel's contract,
documented in the eventual model module (ADR-095 §3.2 follow-up).
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
7994af8221
commit
237325a117
|
|
@ -12,10 +12,15 @@
|
|||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod sparse;
|
||||
pub mod weights;
|
||||
|
||||
pub use config::{TemporalBackendKind, TemporalHeadConfig};
|
||||
pub use error::TemporalError;
|
||||
pub use sparse::SparseGqaHead;
|
||||
pub use weights::{
|
||||
WeightBlob, WeightBlobHeader, WeightDtype, WEIGHT_BLOB_HEADER_LEN, WEIGHT_BLOB_MAGIC,
|
||||
WEIGHT_BLOB_VERSION,
|
||||
};
|
||||
|
||||
// Re-export the upstream Tensor3 so callers don't need a direct
|
||||
// `ruvllm_sparse_attention` dep.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,231 @@
|
|||
// Wire format for the temporal-head weights blob.
|
||||
//
|
||||
// One blob describes one model. Both ends speak it:
|
||||
// - Host-side (this crate): training emits a blob via `WeightBlob::serialize`.
|
||||
// - Firmware-side (`firmware/esp32-csi-node/components/ruv_temporal`):
|
||||
// loads it via a mirrored parser. The blob is the *only* thing
|
||||
// that crosses the host/firmware boundary at deploy time, so the
|
||||
// format must be stable, self-describing, and version-gated.
|
||||
//
|
||||
// Layout (little-endian throughout):
|
||||
//
|
||||
// header 16 B
|
||||
// [0x00..0x04) magic u32 = 0x52564E45 ("RVNE" — RuVector Neural Edge)
|
||||
// [0x04..0x06) version u16 = 1
|
||||
// [0x06..0x07) flags u8 bit 0 = 0:fp32 / 1:fp16 weights
|
||||
// [0x07..0x08) reserved u8
|
||||
// [0x08..0x0A) input_dim u16 per-frame feature dim
|
||||
// [0x0A..0x0C) n_q_heads u16 query head count
|
||||
// [0x0C..0x0E) n_kv_heads u16 key/value head count (≤ n_q_heads, divides it)
|
||||
// [0x0E..0x10) head_dim u16 per-head feature dim
|
||||
//
|
||||
// body variable (starts with an 8 B body header that is counted in WEIGHT_BLOB_HEADER_LEN)
|
||||
// [0x10..0x12) n_layers u16
|
||||
// [0x12..0x14) n_classes u16
|
||||
// [0x14..0x18) weights_len u32 bytes of weights payload (after this header)
|
||||
// [0x18..end-4) weights weights_len bytes — flat per-layer arrays
|
||||
// in the order the kernel reads them
|
||||
// footer 4 B
|
||||
// [end-4..end) crc32 u32 IEEE 802.3, covers everything before
|
||||
//
|
||||
// Total size = 16 (header) + 2+2+4 (body header) + weights_len + 4 (crc) = 28 + weights_len
|
||||
//
|
||||
// Versioning: bumping `version` is a hard break — firmware refuses to
|
||||
// load a blob whose version it doesn't know. Adding a *new* field is
|
||||
// done by reserving a new flag bit and treating the field as
|
||||
// post-weights when the bit is set; never reorder existing fields.
|
||||
|
||||
use crate::error::TemporalError;
|
||||
|
||||
/// Magic number that opens every blob: the ASCII tag "RVNE" read as a
/// big-endian `u32` (`0x5256_4E45`). `serialize` writes it little-endian,
/// so the tag's bytes appear reversed on the wire.
pub const WEIGHT_BLOB_MAGIC: u32 = u32::from_be_bytes(*b"RVNE");

/// Wire-format version written by `serialize` and required by `parse`.
pub const WEIGHT_BLOB_VERSION: u16 = 1;

/// Number of bytes that precede the weights payload: the 16-byte fixed
/// header, then `n_layers` (u16), `n_classes` (u16) and `weights_len` (u32).
pub const WEIGHT_BLOB_HEADER_LEN: usize =
    16 + 2 * core::mem::size_of::<u16>() + core::mem::size_of::<u32>();

/// Size in bytes of the CRC32 footer that closes the blob.
pub const WEIGHT_BLOB_FOOTER_LEN: usize = core::mem::size_of::<u32>();
|
||||
|
||||
/// Element type of the weights payload, carried in bit 0 of the header
/// flags byte (0 = FP32, 1 = FP16).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightDtype {
    /// 32-bit floats, 4 bytes per element.
    F32,
    /// 16-bit floats, 2 bytes per element.
    F16,
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WeightBlobHeader {
|
||||
pub dtype: WeightDtype,
|
||||
pub input_dim: u16,
|
||||
pub n_q_heads: u16,
|
||||
pub n_kv_heads: u16,
|
||||
pub head_dim: u16,
|
||||
pub n_layers: u16,
|
||||
pub n_classes: u16,
|
||||
}
|
||||
|
||||
impl WeightBlobHeader {
|
||||
/// Element size in bytes for the configured dtype.
|
||||
pub fn elem_bytes(&self) -> usize {
|
||||
match self.dtype {
|
||||
WeightDtype::F32 => 4,
|
||||
WeightDtype::F16 => 2,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that the structural numbers make sense — caught here
|
||||
/// rather than at first kernel call so the host-side training
|
||||
/// tool can't accidentally emit a blob the firmware will reject
|
||||
/// at boot.
|
||||
pub fn validate(&self) -> Result<(), TemporalError> {
|
||||
if self.input_dim == 0
|
||||
|| self.n_q_heads == 0
|
||||
|| self.n_kv_heads == 0
|
||||
|| self.head_dim == 0
|
||||
{
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"header: zero-valued dimension(s)",
|
||||
));
|
||||
}
|
||||
if self.n_q_heads % self.n_kv_heads != 0 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"header: n_q_heads must be divisible by n_kv_heads (GQA)",
|
||||
));
|
||||
}
|
||||
if self.n_layers == 0 || self.n_classes < 2 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"header: n_layers must be ≥ 1 and n_classes ≥ 2",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A complete weight blob: header + raw weights bytes.
|
||||
///
|
||||
/// Weights are kept as `Vec<u8>` rather than `Vec<f32>` / `Vec<f16>` so
|
||||
/// the firmware loader (which is no_std and may not have the `half`
|
||||
/// crate) can `mmap` the body and read either dtype directly.
|
||||
pub struct WeightBlob {
|
||||
pub header: WeightBlobHeader,
|
||||
pub weights: Vec<u8>,
|
||||
}
|
||||
|
||||
impl WeightBlob {
|
||||
pub fn new(header: WeightBlobHeader, weights: Vec<u8>) -> Result<Self, TemporalError> {
|
||||
header.validate()?;
|
||||
let elem = header.elem_bytes();
|
||||
if weights.len() % elem != 0 {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"weights length is not a multiple of dtype element size",
|
||||
));
|
||||
}
|
||||
Ok(Self { header, weights })
|
||||
}
|
||||
|
||||
/// Serialize to the wire format. Stable across rebuilds — this is
|
||||
/// the contract the firmware loader speaks.
|
||||
pub fn serialize(&self) -> Vec<u8> {
|
||||
let total = WEIGHT_BLOB_HEADER_LEN + self.weights.len() + WEIGHT_BLOB_FOOTER_LEN;
|
||||
let mut out = Vec::with_capacity(total);
|
||||
|
||||
// header
|
||||
out.extend_from_slice(&WEIGHT_BLOB_MAGIC.to_le_bytes());
|
||||
out.extend_from_slice(&WEIGHT_BLOB_VERSION.to_le_bytes());
|
||||
let flags: u8 = match self.header.dtype {
|
||||
WeightDtype::F32 => 0,
|
||||
WeightDtype::F16 => 1,
|
||||
};
|
||||
out.push(flags);
|
||||
out.push(0); // reserved
|
||||
out.extend_from_slice(&self.header.input_dim.to_le_bytes());
|
||||
out.extend_from_slice(&self.header.n_q_heads.to_le_bytes());
|
||||
out.extend_from_slice(&self.header.n_kv_heads.to_le_bytes());
|
||||
out.extend_from_slice(&self.header.head_dim.to_le_bytes());
|
||||
|
||||
// body header
|
||||
out.extend_from_slice(&self.header.n_layers.to_le_bytes());
|
||||
out.extend_from_slice(&self.header.n_classes.to_le_bytes());
|
||||
out.extend_from_slice(&(self.weights.len() as u32).to_le_bytes());
|
||||
|
||||
// weights
|
||||
out.extend_from_slice(&self.weights);
|
||||
|
||||
// footer: crc32 over everything written so far
|
||||
let crc = crc32_ieee(&out);
|
||||
out.extend_from_slice(&crc.to_le_bytes());
|
||||
out
|
||||
}
|
||||
|
||||
/// Parse a blob, validating magic / version / size / CRC.
|
||||
pub fn parse(buf: &[u8]) -> Result<Self, TemporalError> {
|
||||
if buf.len() < WEIGHT_BLOB_HEADER_LEN + WEIGHT_BLOB_FOOTER_LEN {
|
||||
return Err(TemporalError::InvalidConfig("blob too short"));
|
||||
}
|
||||
|
||||
let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
|
||||
if magic != WEIGHT_BLOB_MAGIC {
|
||||
return Err(TemporalError::InvalidConfig("bad magic"));
|
||||
}
|
||||
let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
|
||||
if version != WEIGHT_BLOB_VERSION {
|
||||
return Err(TemporalError::InvalidConfig("unsupported blob version"));
|
||||
}
|
||||
let flags = buf[6];
|
||||
let dtype = match flags & 0x01 {
|
||||
0 => WeightDtype::F32,
|
||||
_ => WeightDtype::F16,
|
||||
};
|
||||
|
||||
let input_dim = u16::from_le_bytes(buf[8..10].try_into().unwrap());
|
||||
let n_q_heads = u16::from_le_bytes(buf[10..12].try_into().unwrap());
|
||||
let n_kv_heads = u16::from_le_bytes(buf[12..14].try_into().unwrap());
|
||||
let head_dim = u16::from_le_bytes(buf[14..16].try_into().unwrap());
|
||||
|
||||
let n_layers = u16::from_le_bytes(buf[16..18].try_into().unwrap());
|
||||
let n_classes = u16::from_le_bytes(buf[18..20].try_into().unwrap());
|
||||
let weights_len = u32::from_le_bytes(buf[20..24].try_into().unwrap()) as usize;
|
||||
|
||||
// sanity-check size before slicing
|
||||
let expected = WEIGHT_BLOB_HEADER_LEN + weights_len + WEIGHT_BLOB_FOOTER_LEN;
|
||||
if buf.len() != expected {
|
||||
return Err(TemporalError::InvalidConfig(
|
||||
"blob length doesn't match weights_len in header",
|
||||
));
|
||||
}
|
||||
|
||||
// CRC check: cover everything before the trailing 4-byte CRC
|
||||
let stored_crc = u32::from_le_bytes(buf[buf.len() - 4..].try_into().unwrap());
|
||||
let computed = crc32_ieee(&buf[..buf.len() - 4]);
|
||||
if stored_crc != computed {
|
||||
return Err(TemporalError::InvalidConfig("blob CRC mismatch"));
|
||||
}
|
||||
|
||||
let header = WeightBlobHeader {
|
||||
dtype,
|
||||
input_dim,
|
||||
n_q_heads,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
n_layers,
|
||||
n_classes,
|
||||
};
|
||||
header.validate()?;
|
||||
|
||||
let weights_start = WEIGHT_BLOB_HEADER_LEN;
|
||||
let weights_end = weights_start + weights_len;
|
||||
let weights = buf[weights_start..weights_end].to_vec();
|
||||
|
||||
Ok(Self { header, weights })
|
||||
}
|
||||
}
|
||||
|
||||
/// Bitwise (table-free) CRC32 using the reflected IEEE 802.3 polynomial
/// `0xEDB88320`, with an all-ones initial value and a final inversion.
///
/// Intended to use the same polynomial as the firmware-side
/// `temporal_task.c::crc32_ieee`, so that a blob sealed here verifies
/// there.
pub(crate) fn crc32_ieee(data: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB8_8320;

    let register = data.iter().fold(u32::MAX, |acc, &byte| {
        // Mix the byte into the low bits, then shift it out one bit at a
        // time, applying the polynomial whenever a 1 falls off the end.
        (0..8).fold(acc ^ u32::from(byte), |reg, _| {
            if reg & 1 == 1 {
                (reg >> 1) ^ POLY
            } else {
                reg >> 1
            }
        })
    });

    register ^ u32::MAX
}
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
//! Roundtrip + corruption-detection tests for the temporal head's
|
||||
//! weight-blob wire format. The format is the contract between
|
||||
//! host-side training and firmware-side inference — when this test
|
||||
//! file changes, both ends update in lockstep.
|
||||
|
||||
use wifi_densepose_temporal::{
|
||||
WeightBlob, WeightBlobHeader, WeightDtype, WEIGHT_BLOB_HEADER_LEN, WEIGHT_BLOB_MAGIC,
|
||||
WEIGHT_BLOB_VERSION,
|
||||
};
|
||||
|
||||
fn header_default() -> WeightBlobHeader {
|
||||
WeightBlobHeader {
|
||||
dtype: WeightDtype::F32,
|
||||
input_dim: 16,
|
||||
n_q_heads: 4,
|
||||
n_kv_heads: 1,
|
||||
head_dim: 32,
|
||||
n_layers: 2,
|
||||
n_classes: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize then parse an FP32 blob and check that the header fields,
/// the total wire size and the payload bytes all survive unchanged.
#[test]
fn roundtrip_fp32() {
    // 1024 bytes cycling through 0x00..=0xFF, i.e. 256 FP32 elements.
    let weights: Vec<u8> = (0..1024).map(|i| (i & 0xFF) as u8).collect();
    let blob = WeightBlob::new(header_default(), weights.clone()).expect("ok");
    let serialized = blob.serialize();

    // 24-byte header + payload + 4-byte CRC footer.
    assert_eq!(serialized.len(), WEIGHT_BLOB_HEADER_LEN + weights.len() + 4);

    let parsed = WeightBlob::parse(&serialized).expect("parse");
    assert_eq!(parsed.header.input_dim, 16);
    assert_eq!(parsed.header.n_q_heads, 4);
    assert_eq!(parsed.header.n_kv_heads, 1);
    assert_eq!(parsed.header.head_dim, 32);
    assert_eq!(parsed.header.n_layers, 2);
    assert_eq!(parsed.header.n_classes, 4);
    assert_eq!(parsed.header.dtype, WeightDtype::F32);

    // Compare the bytes themselves; a length-only check would miss a
    // payload that was shifted or overwritten.
    assert_eq!(parsed.weights, weights);
}
|
||||
|
||||
/// Serialize then parse an FP16 blob and check that the dtype flag, the
/// remaining header fields and the payload bytes survive unchanged.
#[test]
fn roundtrip_fp16() {
    let header = WeightBlobHeader {
        dtype: WeightDtype::F16,
        ..header_default()
    };
    // FP16 means 2 bytes per element — 512 bytes = 256 elements.
    let weights: Vec<u8> = (0..512).map(|i| (i & 0xFF) as u8).collect();
    let blob = WeightBlob::new(header, weights.clone()).expect("ok");
    let serialized = blob.serialize();
    let parsed = WeightBlob::parse(&serialized).expect("parse");

    assert_eq!(parsed.header.dtype, WeightDtype::F16);
    // Switching dtype must not disturb the other header fields.
    assert_eq!(parsed.header.input_dim, 16);
    assert_eq!(parsed.header.n_q_heads, 4);
    assert_eq!(parsed.header.n_kv_heads, 1);
    assert_eq!(parsed.header.head_dim, 32);
    assert_eq!(parsed.header.n_layers, 2);
    assert_eq!(parsed.header.n_classes, 4);

    // Compare the bytes themselves, not just the length.
    assert_eq!(parsed.weights, weights);
}
|
||||
|
||||
/// A blob whose first magic byte was overwritten must be refused with an
/// error that mentions the magic.
#[test]
fn parse_rejects_bad_magic() {
    let mut wire = WeightBlob::new(header_default(), vec![0u8; 16])
        .expect("ok")
        .serialize();
    // Overwrite the lowest byte of the magic field.
    wire[0] = 0xFF;

    let rejection = WeightBlob::parse(&wire).err().expect("rejected");
    let message = rejection.to_string();
    assert!(message.contains("magic"));
}
|
||||
|
||||
/// A blob carrying an unknown version number must be refused with an
/// error that mentions the version.
#[test]
fn parse_rejects_wrong_version() {
    let mut wire = WeightBlob::new(header_default(), vec![0u8; 16])
        .expect("ok")
        .serialize();
    // The version is the little-endian u16 at offsets 4..6; write 99.
    wire[4..6].copy_from_slice(&99u16.to_le_bytes());

    let rejection = WeightBlob::parse(&wire).err().expect("rejected");
    let message = rejection.to_string();
    assert!(message.contains("version"));
}
|
||||
|
||||
/// A blob whose payload is shorter than its `weights_len` field claims
/// must be refused.
#[test]
fn parse_rejects_size_mismatch() {
    let mut wire = WeightBlob::new(header_default(), vec![0u8; 64])
        .expect("ok")
        .serialize();

    // Remove the first four payload bytes so the total length no longer
    // agrees with the weights_len field.
    let removed = WEIGHT_BLOB_HEADER_LEN..WEIGHT_BLOB_HEADER_LEN + 4;
    bytes_removed(&mut wire, removed);

    let rejection = WeightBlob::parse(&wire).err().expect("rejected");
    let message = rejection.to_string();
    let mentions_length = message.contains("length");
    let mentions_crc = message.contains("CRC");
    assert!(mentions_length || mentions_crc);

    /// Drop `range` from `buf`, discarding the removed bytes.
    fn bytes_removed(buf: &mut Vec<u8>, range: std::ops::Range<usize>) {
        buf.drain(range);
    }
}
|
||||
|
||||
/// A single flipped payload bit must be caught by the CRC check.
#[test]
fn parse_rejects_crc_corruption() {
    let mut wire = WeightBlob::new(header_default(), vec![0xAAu8; 32])
        .expect("ok")
        .serialize();

    // Flip the lowest bit of a byte inside the weights region.
    let target = WEIGHT_BLOB_HEADER_LEN + 5;
    wire[target] ^= 0x01;

    let rejection = WeightBlob::parse(&wire).err().expect("rejected");
    let message = rejection.to_string();
    assert!(message.contains("CRC"));
}
|
||||
|
||||
/// `parse` must run header validation: a blob with a valid CRC but a
/// head ratio where `n_q_heads % n_kv_heads != 0` has to be refused.
#[test]
fn parse_rejects_invalid_gqa_ratio_in_header() {
    // Start from a valid blob and patch it, rather than hand-assembling
    // the bytes.
    let mut wire = WeightBlob::new(header_default(), vec![0u8; 16])
        .expect("ok")
        .serialize();

    // n_kv_heads is the little-endian u16 at offsets 12..14; with
    // n_q_heads = 4, a value of 3 breaks divisibility.
    wire[12..14].copy_from_slice(&3u16.to_le_bytes());

    // Re-seal the CRC so the header validator, not the CRC check, is what
    // rejects the blob.
    let footer_at = wire.len() - 4;
    let (covered, footer) = wire.split_at_mut(footer_at);
    footer.copy_from_slice(&crc32_ieee(covered).to_le_bytes());

    let rejection = WeightBlob::parse(&wire).err().expect("rejected");
    let message = rejection.to_string().to_lowercase();
    assert!(message.contains("gqa"));
}
|
||||
|
||||
/// Pin the public wire-format constants so an accidental change to any of
/// them fails loudly here.
#[test]
fn header_constants_match_wire_layout() {
    let actual = (
        WEIGHT_BLOB_MAGIC,
        WEIGHT_BLOB_VERSION,
        WEIGHT_BLOB_HEADER_LEN,
    );
    let pinned = (0x5256_4E45u32, 1u16, 24usize);
    assert_eq!(actual, pinned);
}
|
||||
|
||||
// Test-local copy of the production CRC32 (reflected polynomial
// 0xEDB88320, all-ones init, final inversion). The GQA test uses it to
// re-seal a blob after patching the header. Kept out of the public API.
fn crc32_ieee(data: &[u8]) -> u32 {
    let mut register = u32::MAX;
    for byte in data.iter().copied() {
        register ^= u32::from(byte);
        let mut remaining_bits = 8;
        while remaining_bits > 0 {
            let carry = register & 1 != 0;
            register >>= 1;
            if carry {
                register ^= 0xEDB8_8320;
            }
            remaining_bits -= 1;
        }
    }
    register ^ u32::MAX
}
|
||||
Loading…
Reference in New Issue