wifi-densepose/benchmarks/wiflow-std/export_to_safetensors.py

175 lines
6.9 KiB
Python

"""ADR-152 §2.2: export the retrained WiFlow-STD PyTorch checkpoint to
safetensors with tch-rs (VarStore) variable names, plus a numerical-parity
fixture for the Rust port.

Outputs (all under results/, gitignored):

  retrained_wiflow_std.safetensors -- 248 f32 tensors named exactly as the
      Rust WiFlowStdModel VarStore expects (see `dump_variable_names` in
      wiflow_std/model.rs for the authoritative name dump)
  parity_fixture.npz  -- deterministic input (seed 42, shape (2, 540, 20),
      uniform on [0, 1)) and the Python model's eval-mode output
  parity_fixture.json -- the same data as flattened f32 lists, for the
      dependency-free Rust test (tests/test_wiflow_std_parity.rs)

PyTorch -> tch key mapping (derived from the VarStore dump, not guessed):

  tcn.network.{i}.conv1_group.weight     -> tcn{i}.conv1_group.weight
  tcn.network.{i}.bn*_{group,pw}.<leaf>  -> tcn{i}.bn*_{group,pw}.<leaf>
  tcn.network.{i}.downsample.0.weight    -> tcn{i}.ds_conv.weight
  tcn.network.{i}.downsample.1.<leaf>    -> tcn{i}.ds_bn.<leaf>
  up.block.{0,1,4,5,8,9}.<leaf>          -> conv_in.{conv1,bn1,conv2,bn2,conv3,bn3}.<leaf>
  up.downsample.{0,1}.<leaf>             -> conv_in.{ds_conv,ds_bn}.<leaf>
  residual_blocks.{i}.block.{...}.<leaf> -> conv{i}.{conv1..bn3}.<leaf>
  residual_blocks.{i}.downsample.{0,1}   -> conv{i}.{ds_conv,ds_bn}
  attention.{width,height}_axis.qkv_transform.weight
                                         -> attention.{width,height}.qkv.weight
  attention.{width,height}_axis.bn_*     -> attention.{width,height}.bn_*
  decoder.{0,1,3,4}.<leaf>               -> {dec_conv1,dec_bn1,dec_conv2,dec_bn2}.<leaf>
  *.num_batches_tracked                  -> dropped (tch BatchNorm has no such buffer)

Legacy upstream names (att. -> attention., final_conv. -> decoder.) are
remapped first, exactly as eval_repro.py does for the released checkpoint.
Usage:
.venv/Scripts/python.exe export_to_safetensors.py
"""
import json
import os
import re
import numpy as np
import torch
from safetensors.torch import save_file
from _bench_common import RESULTS, import_upstream, remap_legacy_keys
import_upstream() # sys.path + models stub
from models.pose_model import WiFlowPoseModel # noqa: E402
# Retrained PyTorch checkpoint that this script converts.
CHECKPOINT = os.path.join(RESULTS, "retrained_best_pose_model.pth")

# Sequential index -> tch sub-name inside one ConvBlock1/AsymmetricConvBlock.
# The PyTorch Sequential is laid out as
#   [Conv2d(0), BN(1), SiLU(2), Dropout2d(3), Conv2d(4), BN(5), SiLU(6),
#    Dropout2d(7), Conv2d(8), BN(9)]
# SiLU and Dropout2d own no tensors, so only the conv/BN slots are listed.
_BLOCK_IDX = {
    "0": "conv1",
    "1": "bn1",
    "4": "conv2",
    "5": "bn2",
    "8": "conv3",
    "9": "bn3",
}

# Index inside a block's `downsample` Sequential -> tch sub-name.
_DS_IDX = {
    "0": "ds_conv",
    "1": "ds_bn",
}

# Index inside the `decoder` Sequential -> top-level tch name.  Indices not
# listed here are expected to carry no tensors (not verifiable from this file).
_DECODER_IDX = {
    "0": "dec_conv1",
    "1": "dec_bn1",
    "3": "dec_conv2",
    "4": "dec_bn2",
}
def _conv_block(new_prefix: str, rest: str) -> str:
    """Map the in-block remainder of a conv-block key onto its tch name.

    ``rest`` is the PyTorch key with the owning module's prefix already
    stripped, e.g. ``"block.4.weight"`` or ``"downsample.1.running_var"``.
    The result is ``"<new_prefix>.<tch sub-name>.<leaf>"``.

    Raises:
        KeyError: if ``rest`` is neither a ``block.<i>.`` nor a
            ``downsample.<i>.`` key, or if ``<i>`` has no tch counterpart.
    """
    for pattern, index_to_name in (
        (r"block\.(\d+)\.(.+)", _BLOCK_IDX),
        (r"downsample\.(\d+)\.(.+)", _DS_IDX),
    ):
        hit = re.fullmatch(pattern, rest)
        if hit:
            seq_index, leaf = hit.groups()
            return f"{new_prefix}.{index_to_name[seq_index]}.{leaf}"
    raise KeyError(f"unmapped conv-block key: {new_prefix} / {rest}")
def map_key(key: str) -> str:
    """Return the tch VarStore name for one PyTorch ``state_dict`` key.

    Rules are tried in order; the first pattern that fully matches ``key``
    decides the translation.

    Raises:
        KeyError: if no rule matches ``key``, or if a matched rule's
            Sequential index has no tch counterpart.
    """

    def _tcn(layer: str, tail: str) -> str:
        # Only the residual `downsample` Sequential is renamed; every other
        # TCN leaf keeps its PyTorch name.
        tail = tail.replace("downsample.0.", "ds_conv.")
        tail = tail.replace("downsample.1.", "ds_bn.")
        return f"tcn{layer}.{tail}"

    def _attention(axis: str, tail: str) -> str:
        tail = tail.replace("qkv_transform.", "qkv.")
        return f"attention.{axis}.{tail}"

    rules = (
        (r"tcn\.network\.(\d+)\.(.+)", _tcn),
        (r"up\.(.+)", lambda tail: _conv_block("conv_in", tail)),
        (r"residual_blocks\.(\d+)\.(.+)",
         lambda idx, tail: _conv_block(f"conv{idx}", tail)),
        (r"attention\.(width|height)_axis\.(.+)", _attention),
        (r"decoder\.(\d+)\.(.+)",
         lambda idx, tail: f"{_DECODER_IDX[idx]}.{tail}"),
    )
    for pattern, translate in rules:
        hit = re.fullmatch(pattern, key)
        if hit:
            return translate(*hit.groups())
    raise KeyError(f"unmapped checkpoint key: {key}")
# Layout the Rust WiFlowStdModel VarStore expects; checked before writing.
_EXPECTED_TCH_VARS = 248
_EXPECTED_PARAMS = 2_225_042
# Key that is present in a bare (unwrapped) state_dict.
_PROBE_KEY = "tcn.network.0.conv1_group.weight"
# Trainer wrapper keys, tried in this order.
_WRAPPER_KEYS = ("model_state_dict", "state_dict", "model")
# Deterministic parity-fixture input.
_FIXTURE_SEED = 42
_FIXTURE_SHAPE = (2, 540, 20)


def _unwrap_state_dict(state):
    """Return the bare state_dict held by a loaded checkpoint object.

    A dict that already contains ``_PROBE_KEY`` is returned as-is; otherwise
    one trainer wrapper level (``_WRAPPER_KEYS``, first hit wins) is peeled.

    Raises:
        TypeError: if the result is not a mapping.
    """
    from collections.abc import Mapping

    if isinstance(state, dict) and _PROBE_KEY not in state:
        for wrapper in _WRAPPER_KEYS:
            if wrapper in state:
                state = state[wrapper]
                break
    if not isinstance(state, Mapping):
        raise TypeError(
            f"{CHECKPOINT}: expected a state_dict mapping, "
            f"got {type(state).__name__}")
    return state


def _map_to_tch(state):
    """Rename ``state`` to tch VarStore names as contiguous f32 tensors.

    Returns:
        ``(mapped, dropped)``: the renamed tensors, and how many
        ``num_batches_tracked`` entries were skipped (tch BatchNorm has no
        such buffer).

    Raises:
        KeyError: if a key is unmapped or two keys map to the same name.
    """
    mapped = {}
    dropped = 0
    for key, tensor in state.items():
        if key.endswith("num_batches_tracked"):
            dropped += 1
            continue
        tch_key = map_key(key)
        if tch_key in mapped:
            raise KeyError(f"duplicate mapped key: {key} -> {tch_key}")
        mapped[tch_key] = tensor.detach().to(torch.float32).contiguous()
    return mapped, dropped


def _write_parity_fixture(state):
    """Run the PyTorch model in eval mode on a seeded input; save in/out.

    Writes ``parity_fixture.npz`` and ``parity_fixture.json`` under RESULTS.
    """
    model = WiFlowPoseModel(dropout=0.5)
    model.load_state_dict(state, strict=True)
    model.eval()
    gen = torch.Generator().manual_seed(_FIXTURE_SEED)
    x = torch.rand(*_FIXTURE_SHAPE, generator=gen, dtype=torch.float32)
    with torch.no_grad():
        y = model(x)
    print(f"fixture input {tuple(x.shape)} -> output {tuple(y.shape)}, "
          f"output range [{y.min().item():.6f}, {y.max().item():.6f}]")

    npz_path = os.path.join(RESULTS, "parity_fixture.npz")
    np.savez(npz_path, input=x.numpy(), output=y.numpy())

    fixture = {
        "seed": _FIXTURE_SEED,
        "input_shape": list(x.shape),
        "input": x.flatten().tolist(),
        "output_shape": list(y.shape),
        "output": y.flatten().tolist(),
    }
    json_path = os.path.join(RESULTS, "parity_fixture.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(fixture, f)
    print(f"wrote {npz_path}")
    print(f"wrote {json_path}")


def main():
    """Export CHECKPOINT to tch-named safetensors and write the parity fixture.

    Raises:
        TypeError: if the checkpoint does not hold a state_dict mapping.
        KeyError: if a checkpoint key has no tch counterpart or two keys
            collide after mapping.
        RuntimeError: if the mapped variable or parameter count differs
            from what the Rust VarStore expects (nothing is written then).
    """
    raw = torch.load(CHECKPOINT, map_location="cpu", weights_only=True)
    state = _unwrap_state_dict(raw)
    # Legacy upstream names predate the published code (_bench_common).
    state = remap_legacy_keys(state)

    mapped, dropped = _map_to_tch(state)
    # BatchNorm running_* buffers are state, not trainable parameters.
    n_params = sum(t.numel() for k, t in mapped.items()
                   if "running_" not in k)
    print(f"checkpoint tensors: {len(state)} "
          f"(dropped {dropped} num_batches_tracked)")
    print(f"mapped tensors: {len(mapped)}, "
          f"non-buffer params: {n_params/1e6:.6f}M")
    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if len(mapped) != _EXPECTED_TCH_VARS:
        raise RuntimeError(f"expected {_EXPECTED_TCH_VARS} tch variables, "
                           f"got {len(mapped)}")
    if n_params != _EXPECTED_PARAMS:
        raise RuntimeError(f"param count mismatch: expected "
                           f"{_EXPECTED_PARAMS}, got {n_params}")

    st_path = os.path.join(RESULTS, "retrained_wiflow_std.safetensors")
    save_file(mapped, st_path)
    print(f"wrote {st_path}")

    _write_parity_fixture(state)


if __name__ == "__main__":
    main()