wifi-densepose/firmware/esp32-csi-node/components/ruv_temporal/src/weights.rs

195 lines
6.7 KiB
Rust

// Firmware-side mirror of `wifi-densepose-temporal::weights`. Same wire
// format, same magic, same CRC polynomial — a blob produced by the
// host's `WeightBlob::serialize()` parses here byte-for-byte.
//
// no_std + alloc. The host side keeps weights as `Vec<u8>` because it
// owns the buffer; the firmware loader takes a borrowed `&[u8]` slice
// (the blob lives in flash via EMBED_FILES, or a heap mmap from NVS,
// neither of which the loader should re-allocate).
//
// Stays *byte-exact* in lockstep with `v2/crates/wifi-densepose-temporal/src/weights.rs`.
// When the host format changes, this file changes in the same commit
// and bumps `BLOB_VERSION`; mismatched versions refuse to load.
use core::convert::TryInto;
use core::fmt;
/// Magic number stored in bytes 0..4 of every blob. `WeightBlobView::parse`
/// reads it as a little-endian `u32`, so the on-wire bytes are
/// `45 4E 56 52` ("ENVR"); "RVNE" is the constant's big-endian spelling.
pub const BLOB_MAGIC: u32 = 0x5256_4E45; // "RVNE"
/// Wire-format version. `WeightBlobView::parse` rejects any other value
/// with `WeightLoadError::WrongVersion`.
pub const BLOB_VERSION: u16 = 1;
/// Size in bytes of the fixed header that precedes the weight payload:
/// magic (4), version (2), flags (1), one byte the parser does not read (1),
/// six `u16` dimensions (12) and the `u32` payload length (4).
pub const BLOB_HEADER_LEN: usize = 24;
/// Size in bytes of the trailing CRC32 footer (a little-endian `u32`
/// covering every byte before it).
pub const BLOB_FOOTER_LEN: usize = 4;
/// Element type of the packed weight payload.
///
/// Selected by bit 0 of the header flags byte (offset 6): clear → `F32`,
/// set → `F16`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightDtype {
/// 32-bit float elements, 4 bytes each.
F32,
/// 16-bit float elements, 2 bytes each (presumably IEEE binary16 — confirm
/// against the host serializer).
F16,
}
/// Decoded fixed-size header of a weights blob.
///
/// The dimension fields are read as little-endian `u16`s from bytes 8..20,
/// in declaration order. The constraints noted on each field are the ones
/// checked by `validate`, which `WeightBlobView::parse` runs before handing
/// out a view.
#[derive(Clone, Copy, Debug)]
pub struct WeightBlobHeader {
/// Element type of the weight payload (from flags bit 0).
pub dtype: WeightDtype,
/// Input dimension; must be non-zero.
pub input_dim: u16,
/// Number of query heads; must be non-zero and a multiple of `n_kv_heads`.
pub n_q_heads: u16,
/// Number of key/value heads; must be non-zero and divide `n_q_heads`
/// (grouped-query attention ratio).
pub n_kv_heads: u16,
/// Per-head dimension; must be non-zero.
pub head_dim: u16,
/// Number of layers; must be at least 1.
pub n_layers: u16,
/// Number of output classes; must be at least 2.
pub n_classes: u16,
}
impl WeightBlobHeader {
    /// Number of bytes occupied by one weight element of this header's dtype.
    pub fn elem_bytes(&self) -> usize {
        match self.dtype {
            WeightDtype::F16 => 2,
            WeightDtype::F32 => 4,
        }
    }

    /// Check the shape invariants the rest of the loader relies on.
    ///
    /// The zero-dimension check runs first on purpose: it guarantees
    /// `n_kv_heads != 0` before the divisibility test divides by it.
    ///
    /// # Errors
    /// - `ZeroDim` if any of `input_dim`, `n_q_heads`, `n_kv_heads`,
    ///   `head_dim` is zero.
    /// - `InvalidGqaRatio` if `n_q_heads` is not a multiple of `n_kv_heads`.
    /// - `DegenerateShape` if `n_layers == 0` or `n_classes < 2`.
    fn validate(&self) -> Result<(), WeightLoadError> {
        let required_dims = [self.input_dim, self.n_q_heads, self.n_kv_heads, self.head_dim];
        if required_dims.iter().any(|&dim| dim == 0) {
            return Err(WeightLoadError::ZeroDim);
        }

        let heads_group_evenly = self.n_q_heads % self.n_kv_heads == 0;
        if !heads_group_evenly {
            return Err(WeightLoadError::InvalidGqaRatio);
        }

        let has_layers = self.n_layers != 0;
        let has_real_choice = self.n_classes >= 2;
        if has_layers && has_real_choice {
            Ok(())
        } else {
            Err(WeightLoadError::DegenerateShape)
        }
    }
}
/// A parsed view into a weights blob. Holds borrowed slices into the
/// caller-owned buffer — no allocation, no copy. The firmware's
/// kernel reads weights directly from this view.
///
/// Both fields are public, so a view can be built by hand; the guarantees
/// noted below hold only for views returned by `WeightBlobView::parse`.
#[derive(Clone, Copy)]
pub struct WeightBlobView<'a> {
/// Decoded header. When produced by `parse`, it has already passed
/// magic/version/size/CRC checks and shape validation.
pub header: WeightBlobHeader,
/// Raw payload bytes between the header and the CRC footer (exactly the
/// header's declared `weights_len` bytes when produced by `parse`).
/// Undecoded — interpret according to `header.dtype`.
pub weights: &'a [u8],
}
impl<'a> WeightBlobView<'a> {
    /// Parse a blob, validating magic / version / size / CRC / shape, and
    /// return a zero-copy view borrowing from `buf` (which must outlive it).
    ///
    /// Wire layout, all integers little-endian:
    /// - `0..4`   magic, must equal `BLOB_MAGIC`
    /// - `4..6`   version, must equal `BLOB_VERSION`
    /// - `6`      flags: bit 0 set ⇒ `F16`, clear ⇒ `F32`; other bits ignored
    /// - `7`      not read by this parser
    /// - `8..20`  `input_dim`, `n_q_heads`, `n_kv_heads`, `head_dim`,
    ///            `n_layers`, `n_classes` (one `u16` each)
    /// - `20..24` `weights_len` (`u32`)
    /// - `24..`   `weights_len` payload bytes
    /// - last 4   CRC32 (IEEE) over every preceding byte
    ///
    /// # Errors
    /// Checks run in this order, and the first failure is returned:
    /// - `TooShort` if `buf` cannot hold a header and footer.
    /// - `BadMagic` / `WrongVersion` on an identity mismatch.
    /// - `SizeMismatch` if `buf.len()` differs from header + `weights_len` +
    ///   footer, including when that sum does not fit in `usize`.
    /// - `CrcMismatch` if the footer does not match the computed CRC.
    /// - `ZeroDim` / `InvalidGqaRatio` / `DegenerateShape` from header
    ///   shape validation.
    pub fn parse(buf: &'a [u8]) -> Result<Self, WeightLoadError> {
        // Fixed-offset little-endian reads. Every call site reads inside a
        // region whose presence has already been length-checked, so the
        // slice-to-array conversions cannot fail.
        fn read_u16(bytes: &[u8], at: usize) -> u16 {
            u16::from_le_bytes(bytes[at..at + 2].try_into().expect("2-byte slice"))
        }
        fn read_u32(bytes: &[u8], at: usize) -> u32 {
            u32::from_le_bytes(bytes[at..at + 4].try_into().expect("4-byte slice"))
        }

        if buf.len() < BLOB_HEADER_LEN + BLOB_FOOTER_LEN {
            return Err(WeightLoadError::TooShort);
        }
        if read_u32(buf, 0) != BLOB_MAGIC {
            return Err(WeightLoadError::BadMagic);
        }
        let version = read_u16(buf, 4);
        if version != BLOB_VERSION {
            return Err(WeightLoadError::WrongVersion(version));
        }
        let dtype = if (buf[6] & 0x01) == 0 {
            WeightDtype::F32
        } else {
            WeightDtype::F16
        };

        // `weights_len` comes straight from untrusted blob bytes. On a 32-bit
        // target (ESP32) `usize` is 32 bits, so adding the header/footer size
        // to a value near `u32::MAX` would overflow — a panic in debug builds.
        // Checked arithmetic turns that into an ordinary `SizeMismatch`.
        let weights_len: usize = read_u32(buf, 20)
            .try_into()
            .map_err(|_| WeightLoadError::SizeMismatch)?;
        let expected = weights_len
            .checked_add(BLOB_HEADER_LEN + BLOB_FOOTER_LEN)
            .ok_or(WeightLoadError::SizeMismatch)?;
        if buf.len() != expected {
            return Err(WeightLoadError::SizeMismatch);
        }

        // `body` is everything the CRC covers: header + payload.
        let (body, footer) = buf.split_at(buf.len() - BLOB_FOOTER_LEN);
        if read_u32(footer, 0) != crc32_ieee(body) {
            return Err(WeightLoadError::CrcMismatch);
        }

        let header = WeightBlobHeader {
            dtype,
            input_dim: read_u16(buf, 8),
            n_q_heads: read_u16(buf, 10),
            n_kv_heads: read_u16(buf, 12),
            head_dim: read_u16(buf, 14),
            n_layers: read_u16(buf, 16),
            n_classes: read_u16(buf, 18),
        };
        header.validate()?;

        // NOTE(review): `weights_len` is not cross-checked against
        // `header.elem_bytes()` or the declared dimensions here; consumers
        // should not assume the payload length is a whole number of elements
        // or matches the model shape.
        Ok(Self {
            header,
            // The size check above makes this exactly `weights_len` bytes.
            weights: &body[BLOB_HEADER_LEN..],
        })
    }
}
/// Loader-side error. Distinct from the host-side `TemporalError` so
/// the firmware can map specific cases to specific `esp_err_t` codes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightLoadError {
/// Buffer is shorter than `BLOB_HEADER_LEN + BLOB_FOOTER_LEN`.
TooShort,
/// Bytes 0..4 (little-endian `u32`) are not `BLOB_MAGIC`.
BadMagic,
/// Bytes 4..6 hold a version other than `BLOB_VERSION`; carries the
/// version found in the blob.
WrongVersion(u16),
/// Buffer length differs from header + declared `weights_len` + footer.
SizeMismatch,
/// Stored CRC32 footer does not match the CRC of the preceding bytes.
CrcMismatch,
/// One of `input_dim`, `n_q_heads`, `n_kv_heads`, `head_dim` is zero.
ZeroDim,
/// `n_q_heads` is not a multiple of `n_kv_heads`.
InvalidGqaRatio,
/// `n_layers` is zero or `n_classes` is below 2.
DegenerateShape,
}
impl fmt::Display for WeightLoadError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooShort => write!(f, "weight blob too short"),
Self::BadMagic => write!(f, "weight blob: bad magic"),
Self::WrongVersion(v) => write!(f, "weight blob: unsupported version {}", v),
Self::SizeMismatch => write!(f, "weight blob: declared length doesn't match buffer"),
Self::CrcMismatch => write!(f, "weight blob: CRC32 mismatch"),
Self::ZeroDim => write!(f, "weight blob: zero-valued dimension(s)"),
Self::InvalidGqaRatio => write!(f, "weight blob: n_q_heads not divisible by n_kv_heads"),
Self::DegenerateShape => write!(f, "weight blob: n_layers=0 or n_classes<2"),
}
}
}
/// Map a loader error to an `esp_err_t`-style code for the C ABI.
///
/// Defined here rather than in lib.rs so the mapping stays adjacent to the
/// error type and can't drift.
///
/// Grouping:
/// - framing / identity failures (`TooShort`, `BadMagic`, `WrongVersion`,
///   `SizeMismatch`) → `ESP_ERR_INVALID_ARG`
/// - `CrcMismatch` → `ESP_ERR_INVALID_CRC`
/// - well-formed blob describing an unusable shape (`ZeroDim`,
///   `InvalidGqaRatio`, `DegenerateShape`) → `ESP_ERR_INVALID_SIZE`
pub const fn weight_load_err_to_esp(err: &WeightLoadError) -> i32 {
    // Values copied from ESP-IDF `esp_err.h`. Spelled out by name so each
    // literal is visibly tied to the code it stands for; re-check them if
    // the ESP-IDF version changes. (0x103 is ESP_ERR_INVALID_STATE and 0x10C
    // is ESP_ERR_NOT_FINISHED — neither belongs in this mapping.)
    const ESP_ERR_INVALID_ARG: i32 = 0x102;
    const ESP_ERR_INVALID_SIZE: i32 = 0x104;
    const ESP_ERR_INVALID_CRC: i32 = 0x109;

    match err {
        WeightLoadError::TooShort
        | WeightLoadError::BadMagic
        | WeightLoadError::WrongVersion(_)
        | WeightLoadError::SizeMismatch => ESP_ERR_INVALID_ARG,
        WeightLoadError::CrcMismatch => ESP_ERR_INVALID_CRC,
        WeightLoadError::ZeroDim
        | WeightLoadError::InvalidGqaRatio
        | WeightLoadError::DegenerateShape => ESP_ERR_INVALID_SIZE,
    }
}
/// Table-free CRC-32/IEEE: reflected polynomial `0xEDB8_8320`, register
/// seeded with all ones and inverted at the end.
///
/// Must stay bit-for-bit identical to `temporal_task.c::crc32_ieee` and the
/// host-side `wifi_densepose_temporal::weights::crc32_ieee`, so a blob's
/// footer verifies at every site without being recomputed.
fn crc32_ieee(data: &[u8]) -> u32 {
    const REFLECTED_POLY: u32 = 0xEDB8_8320;

    let register = data.iter().fold(u32::MAX, |reg, &byte| {
        (0..8).fold(reg ^ u32::from(byte), |r, _| {
            // All ones when the low bit is set, all zeros otherwise, so the
            // polynomial is applied without a data-dependent branch.
            let mask = (r & 1).wrapping_neg();
            (r >> 1) ^ (REFLECTED_POLY & mask)
        })
    });
    !register
}