wifi-densepose/aether-arena/calibration/model.py

108 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""WiFi-CSI pose model + LoRA adapter for the RuView calibration service.
Architecture matches the published flagship checkpoint
[`ruvnet/wifi-densepose-mmfi-pose`](https://huggingface.co/ruvnet/wifi-densepose-mmfi-pose)
(`pose_mmfi_best.pt`): transformer encoder + temporal attention pooling + skeleton-graph head.
The calibration service freezes this base and fits a tiny per-room **LoRA adapter** (rank 8 on the
input projection + pose head ≈ 11 KB) from ~100–200 labeled in-room samples. Empirically that lifts
cross-subject 64→72% and cross-environment 11→73% (ADR-150 §3.3–3.6).
"""
import numpy as np
import torch
import torch.nn as nn
# COCO-17 skeleton edges for the graph-refinement head (pairs of joint indices).
EDGES = [
    (0, 1), (0, 2), (1, 3), (2, 4),
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    (5, 11), (6, 12), (11, 12),
    (11, 13), (13, 15), (12, 14), (14, 16),
]

# Row-normalized adjacency with self-loops: start from the identity, mark every edge in
# both directions, then divide each row by its sum so that multiplying by _A averages a
# joint's features with those of its direct neighbours.
_src, _dst = (list(_ends) for _ends in zip(*EDGES))
_A = np.eye(17, dtype=np.float32)
_A[_src, _dst] = 1.0
_A[_dst, _src] = 1.0
_A = _A / _A.sum(axis=1, keepdims=True)
class LoRA(nn.Module):
    """Low-rank adapter wrapping a frozen Linear: y = base(x) + (x @ A @ B) * (alpha / r).

    The wrapped layer's parameters are frozen; only ``A`` (in_features x r) and
    ``B`` (r x out_features) are trainable. ``A`` starts as small Gaussian noise and
    ``B`` as zeros, so a freshly attached adapter leaves the base output unchanged.

    Args:
        base: The linear layer to adapt. Its parameters get ``requires_grad = False``.
        r: Rank of the low-rank update; must be at least 1.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.

    Raises:
        ValueError: If ``r`` is smaller than 1.
    """

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        if r < 1:
            raise ValueError(f"LoRA rank r must be >= 1, got {r}")
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        # Allocate the adapter next to the base weight. Otherwise wrapping a layer that
        # already lives on another device or uses another dtype breaks in forward().
        placement = {"device": base.weight.device, "dtype": base.weight.dtype}
        self.A = nn.Parameter(torch.zeros(base.in_features, r, **placement))
        self.B = nn.Parameter(torch.zeros(r, base.out_features, **placement))
        # Only A is randomized; B stays zero so the initial low-rank update is zero.
        nn.init.normal_(self.A, std=0.02)
        self.scale = alpha / r

    def forward(self, x):
        """Return the frozen base projection plus the scaled low-rank correction."""
        return self.base(x) + (x @ self.A @ self.B) * self.scale
class GR(nn.Module):
    """Skeleton-graph refinement: nudges joints toward anatomically consistent positions.

    Each of the 17 joints gets a feature vector built from the pooled encoder feature,
    a learned 32-dim joint embedding and the joint's initial 2-D estimate. Two rounds of
    neighbour averaging over the skeleton adjacency follow, and the result is mapped to
    a bounded 2-D offset that is added to the initial estimate.

    Args:
        d: Width of the pooled feature vector ``z``.
        h: Hidden width of the per-joint layers.
    """

    def __init__(self, d=256, h=96):
        super().__init__()
        joint_embedding = torch.randn(17, 32) * 0.02
        self.je = nn.Parameter(joint_embedding)
        # Input width: d (pooled feature) + 32 (joint embedding) + 2 (initial x, y).
        self.inp = nn.Linear(d + 34, h)
        self.g1 = nn.Linear(h, h)
        self.g2 = nn.Linear(h, h)
        self.out = nn.Linear(h, 2)
        # Fixed, non-trainable adjacency; registered as a buffer so it moves with the
        # module and is stored in its state dict.
        self.register_buffer("A", torch.tensor(_A))

    def forward(self, z, kp0):
        """Refine ``kp0`` (B, 17, 2) using the pooled feature ``z`` (B, d).

        Returns a (B, 17, 2) tensor; every coordinate moves by at most 0.3 because the
        offset is ``0.3 * tanh(...)``.
        """
        batch = z.shape[0]
        pooled = z.unsqueeze(1).expand(-1, 17, -1)
        embeddings = self.je.unsqueeze(0).expand(batch, -1, -1)
        feats = torch.cat([pooled, embeddings, kp0], -1)
        feats = torch.relu(self.inp(feats))
        # Two propagation rounds: average over graph neighbours, then a per-joint layer.
        for layer in (self.g1, self.g2):
            neighbourhood = torch.einsum('ij,bjh->bih', self.A, feats)
            feats = torch.relu(layer(neighbourhood))
        offset = 0.3 * torch.tanh(self.out(feats))
        return kp0 + offset
class PoseNet(nn.Module):
    """Flagship pose model. Input [B,3,114,10] CSI amplitude (per-sample standardized) -> [B,34].

    Pipeline: per-time-step linear projection, learned positional embedding, transformer
    encoder, attention pooling over time, an MLP head producing 17 (x, y) pairs squashed
    by a sigmoid, then skeleton-graph refinement.

    Args:
        na: Size of input dim 1 (3 in the documented input shape).
        nsc: Size of input dim 2 (114 in the documented input shape).
        nt: Number of time steps, i.e. size of input dim 3.
        d: Model width.
        L: Number of transformer encoder layers.
        H: Number of attention heads.
    """

    def __init__(self, na=3, nsc=114, nt=10, d=256, L=4, H=8):
        super().__init__()
        self.proj = nn.Linear(na * nsc, d)
        self.pos = nn.Parameter(torch.randn(1, nt, d) * 0.02)
        enc = nn.TransformerEncoderLayer(d, H, d * 2, dropout=0.2, batch_first=True, activation='gelu')
        self.tf = nn.TransformerEncoder(enc, L)
        self.att = nn.Linear(d, 1)
        self.head = nn.Sequential(nn.Linear(d, 256), nn.GELU(), nn.Dropout(0.3), nn.Linear(256, 34))
        self.gr = GR(d)
        self.na, self.nsc, self.nt = na, nsc, nt

    def forward(self, x):
        """Map a [B, na, nsc, nt] batch to flattened keypoints of shape [B, 34]."""
        B = x.shape[0]
        # Time-major token sequence: [B, na, nsc, nt] -> [B, nt, na * nsc].
        t = x.permute(0, 3, 1, 2).reshape(B, self.nt, self.na * self.nsc)
        h = self.tf(self.proj(t) + self.pos)
        # Attention pooling: softmax over the time axis, then a weighted sum of tokens.
        w = torch.softmax(self.att(h), 1)
        z = (h * w).sum(1)
        kp0 = torch.sigmoid(self.head(z)).reshape(B, 17, 2)
        return self.gr(z, kp0).reshape(B, 34)

    def add_lora(self, r=8, alpha=16):
        """Wrap the input projection and both pose-head Linear layers with LoRA adapters.

        With the default dimensions and ``r=8`` this adds 11,200 trainable values.
        Call once, on a model whose layers are not wrapped yet. Returns ``self``.
        """
        self.proj = LoRA(self.proj, r, alpha)
        self.head[0] = LoRA(self.head[0], r, alpha)
        self.head[3] = LoRA(self.head[3], r, alpha)
        return self

    def lora_state(self) -> dict:
        """Extract just the LoRA A/B tensors (the per-room adapter to save).

        Returns a dict mapping state-dict key to a NumPy array. Only *parameters* are
        considered: the graph head's adjacency buffer is also stored under a key ending
        in ``.A`` but is not part of the adapter, so buffers are deliberately skipped.
        """
        return {
            name: param.detach().cpu().numpy()
            for name, param in self.named_parameters()
            if name.endswith((".A", ".B"))
        }

    def load_lora(self, adapter: dict):
        """Load adapter tensors produced by ``lora_state`` into this model.

        Args:
            adapter: Mapping of state-dict key to array-like or tensor. Keys other than
                LoRA tensors are accepted as long as they exist in the model.

        Returns:
            ``self``.

        Raises:
            RuntimeError: If the adapter names keys the model does not have (typically
                because ``add_lora`` was not called first) or a shape does not match.
        """
        merged = self.state_dict()
        unknown = [key for key in adapter if key not in merged]
        if unknown:
            raise RuntimeError(
                f"adapter keys not present in the model: {unknown}; "
                "call add_lora() with the matching rank before load_lora()"
            )
        for key, value in adapter.items():
            # Tensors are used as-is; re-wrapping them with torch.tensor() only warns.
            if isinstance(value, torch.Tensor):
                merged[key] = value.detach()
            else:
                merged[key] = torch.tensor(value)
        self.load_state_dict(merged)
        return self
def standardize(x: torch.Tensor) -> torch.Tensor:
    """Per-sample standardization used in training/inference.

    Mean and standard deviation are computed jointly over dims 1-3 of a 4-D batch, so
    each sample along dim 0 is shifted and scaled independently of the others.
    """
    sample_dims = (1, 2, 3)
    centered = x - x.mean(sample_dims, keepdim=True)
    # The epsilon keeps the division finite when a sample has zero deviation.
    spread = x.std(sample_dims, keepdim=True) + 1e-6
    return centered / spread