"""WiFi-CSI pose model + LoRA adapter for the RuView calibration service.

Architecture matches the published flagship checkpoint
[`ruvnet/wifi-densepose-mmfi-pose`](https://huggingface.co/ruvnet/wifi-densepose-mmfi-pose)
(`pose_mmfi_best.pt`): transformer encoder + temporal attention pooling + skeleton-graph head.

The calibration service freezes this base and fits a tiny per-room **LoRA adapter** (rank 8 on the
input projection + pose head ≈ 11 KB) from ~100–200 labeled in-room samples. Empirically that lifts
cross-subject 64→72% and cross-environment 11→73% (ADR-150 §3.3–3.6).
"""
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
# COCO-17 skeleton edges for the graph-refinement head.
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
         (5, 11), (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)]

# Adjacency over the 17 joints: start from the identity (self-loops), add one
# symmetric entry per skeleton edge, then scale every row to sum to 1 so that
# multiplying by _A averages a joint's features with those of its neighbors.
_A = np.eye(17, dtype=np.float32)
for _i, _j in EDGES:
    _A[_i, _j] = _A[_j, _i] = 1.0
_A = _A / _A.sum(1, keepdims=True)
|
||
|
||
|
||
class LoRA(nn.Module):
    """Low-rank adapter wrapping a frozen Linear: y = W·x + (x·A·B)·(alpha/r).

    Args:
        base: Linear layer to wrap. All of its parameters are frozen
            (``requires_grad = False``) by the constructor.
        r: Rank of the low-rank update (inner dimension of ``A`` and ``B``).
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.

    Attributes:
        A: Trainable ``[in_features, r]`` factor, initialized from N(0, 0.02²).
        B: Trainable ``[r, out_features]`` factor, initialized to zeros, so a
            freshly wrapped layer produces exactly the base layer's output.
        scale: The constant ``alpha / r``.
    """

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        # Allocate the factors from the base weight so they inherit its device
        # and dtype; otherwise wrapping a layer that has already been moved
        # (e.g. to an accelerator or to float64/half) fails in forward() with
        # a device/dtype mismatch.
        weight = base.weight
        self.A = nn.Parameter(weight.new_zeros((base.in_features, r)))
        self.B = nn.Parameter(weight.new_zeros((r, base.out_features)))
        nn.init.normal_(self.A, std=0.02)
        self.scale = alpha / r

    def forward(self, x):
        """Return the frozen base output plus the scaled low-rank update."""
        return self.base(x) + (x @ self.A @ self.B) * self.scale
|
||
|
||
|
||
class GR(nn.Module):
    """Skeleton-graph refinement: nudges joints toward anatomically consistent positions.

    Every joint gets a feature vector built from the pooled embedding, a
    learned 32-dim joint embedding and the joint's initial 2-D estimate. Two
    rounds of neighbor averaging over the adjacency buffer ``A`` follow, and
    the result is turned into a bounded offset added to the initial estimate.

    Args:
        d: Width of the pooled embedding ``z``.
        h: Hidden width of the per-joint features.
    """

    def __init__(self, d=256, h=96):
        super().__init__()
        # One learned 32-dim embedding per joint (17 joints).
        self.je = nn.Parameter(torch.randn(17, 32) * 0.02)
        # Input width d + 34 = embedding (d) + joint embedding (32) + (x, y).
        self.inp = nn.Linear(d + 34, h)
        self.g1 = nn.Linear(h, h)
        self.g2 = nn.Linear(h, h)
        self.out = nn.Linear(h, 2)
        # Buffer, not a parameter: stored in state_dict but never trained.
        self.register_buffer("A", torch.tensor(_A))

    def forward(self, z, kp0):
        """Refine initial keypoints.

        Args:
            z: Pooled embedding of shape ``[B, d]``.
            kp0: Initial keypoints of shape ``[B, 17, 2]``.

        Returns:
            Tensor of shape ``[B, 17, 2]``: ``kp0`` plus a per-coordinate
            offset limited to the open interval (-0.3, 0.3) by the tanh.
        """
        B = z.shape[0]
        f = torch.relu(self.inp(torch.cat(
            [z.unsqueeze(1).expand(-1, 17, -1), self.je.unsqueeze(0).expand(B, -1, -1), kp0], -1)))
        # Mix each joint's features with its skeleton neighbors, twice.
        f = torch.relu(self.g1(torch.einsum('ij,bjh->bih', self.A, f)))
        f = torch.relu(self.g2(torch.einsum('ij,bjh->bih', self.A, f)))
        return kp0 + 0.3 * torch.tanh(self.out(f))
|
||
|
||
|
||
class PoseNet(nn.Module):
    """Flagship pose model. Input [B,3,114,10] CSI amplitude (per-sample standardized) -> [B,34].

    Pipeline: per-time-step linear projection of the flattened
    (antenna x subcarrier) vector, learned positional offsets, a transformer
    encoder, softmax attention pooling over time, an MLP head with sigmoid
    producing 17 initial (x, y) keypoints, and a skeleton-graph refinement.

    Args:
        na: Number of antennas (input dim 1).
        nsc: Number of subcarriers (input dim 2).
        nt: Number of time steps (input dim 3).
        d: Transformer model width.
        L: Number of transformer encoder layers.
        H: Number of attention heads.
    """

    def __init__(self, na=3, nsc=114, nt=10, d=256, L=4, H=8):
        super().__init__()
        self.proj = nn.Linear(na * nsc, d)
        self.pos = nn.Parameter(torch.randn(1, nt, d) * 0.02)
        enc = nn.TransformerEncoderLayer(d, H, d * 2, dropout=0.2, batch_first=True, activation='gelu')
        self.tf = nn.TransformerEncoder(enc, L)
        self.att = nn.Linear(d, 1)
        self.head = nn.Sequential(nn.Linear(d, 256), nn.GELU(), nn.Dropout(0.3), nn.Linear(256, 34))
        self.gr = GR(d)
        self.na, self.nsc, self.nt = na, nsc, nt

    def forward(self, x):
        """Map CSI input ``[B, na, nsc, nt]`` to flattened keypoints ``[B, 34]``."""
        B = x.shape[0]
        # Time becomes the sequence axis; antennas x subcarriers are flattened
        # into the per-step feature vector.
        t = x.permute(0, 3, 1, 2).reshape(B, self.nt, self.na * self.nsc)
        h = self.tf(self.proj(t) + self.pos)
        # Attention pooling: softmax over the time axis, then weighted sum.
        w = torch.softmax(self.att(h), 1)
        z = (h * w).sum(1)
        kp0 = torch.sigmoid(self.head(z)).reshape(B, 17, 2)
        return self.gr(z, kp0).reshape(B, 34)

    def add_lora(self, r=8, alpha=16):
        """Wrap the input projection + pose head with LoRA adapters (the ~11 KB calibration set)."""
        self.proj = LoRA(self.proj, r, alpha)
        self.head[0] = LoRA(self.head[0], r, alpha)
        self.head[3] = LoRA(self.head[3], r, alpha)
        return self

    def lora_state(self) -> dict:
        """Extract just the LoRA A/B tensors (the per-room adapter to save).

        Returns:
            Dict mapping parameter name to a NumPy copy on CPU.
        """
        # Walk parameters rather than state_dict(): state_dict() also holds
        # registered buffers, and the refinement head's adjacency buffer is
        # named "gr.A", which a plain ".A" suffix filter would wrongly export
        # as part of the adapter.
        return {name: p.detach().cpu().numpy() for name, p in self.named_parameters()
                if name.endswith((".A", ".B"))}

    def load_lora(self, adapter: dict):
        """Load adapter tensors produced by :meth:`lora_state` into this model.

        Args:
            adapter: Dict mapping state-dict key to array-like values. Every
                key must already exist in this model's state dict, so
                ``add_lora`` has to be called first; otherwise
                ``load_state_dict`` rejects the unexpected keys.

        Returns:
            ``self``, to allow chaining.
        """
        # Going through the full state dict keeps adapters that carry extra
        # valid keys (such as a saved "gr.A" buffer) loadable.
        sd = self.state_dict()
        for k, v in adapter.items():
            sd[k] = torch.tensor(v)
        self.load_state_dict(sd)
        return self
|
||
|
||
|
||
def standardize(x: torch.Tensor) -> torch.Tensor:
    """Per-sample standardization used in training/inference.

    Each sample (index along dim 0) is shifted by its own mean and divided by
    its own standard deviation, both taken over all remaining dims. For the
    4-D ``[B, na, nsc, nt]`` input this is the reduction over dims (1, 2, 3).

    Args:
        x: Tensor with a leading batch dim and at least one more dim.

    Returns:
        Tensor of the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than 2 dims.
    """
    if x.dim() < 2:
        raise ValueError(f"standardize expects a batched tensor with >= 2 dims, got {x.dim()}")
    dims = tuple(range(1, x.dim()))
    mean = x.mean(dims, keepdim=True)
    std = x.std(dims, keepdim=True)
    # The small epsilon keeps a constant (zero-variance) sample from dividing by zero.
    return (x - mean) / (std + 1e-6)
|