Closes #391 (full-replace footgun). Phase 1 of #574 (esp32-csi-node provisioning UX). The mDNS discovery + USB-CDC pairing work in #574 remains future work; this PR handles only the provision.py-side fix. Background: provision.py flashed a fresh NVS partition at 0x9000 every invocation. The previous behaviour built that partition only from the CLI flags passed on the current run — every key you didn't pass was silently erased. We hit it ourselves earlier today: --force-partial only suppressed the safety check but still wiped the SSID. This PR replaces the full-replace semantic with a per-port state file that captures every config value previously flashed from this machine. On each invocation: 1. Read ~/.config/wifi-densepose/esp32-provision-state/<port>.json (or %APPDATA%/... on Windows). 2. Overlay the new CLI flags on top — CLI wins where set. 3. Generate + flash NVS from the merged dict. 4. Persist the merged dict back to the state file. Net effect: the exact scenario from #391 + today's incident now passes (test_partial_invocation_does_not_drop_unrelated_keys): python provision.py --port COM7 --ssid Net --password p --target-ip 10.0.0.5 # later: python provision.py --port COM7 --seed-url http://10.0.0.99:8080 # WiFi creds preserved, seed_url added. New flags: --reset Wipe per-port state before merging (recycled-board path). --state-dir Override per-user state dir (XDG / %APPDATA% by default). --state Print the merged state and exit (debug / inspection). --force-partial preserved as a deprecation-flagged escape hatch. State file caveats (in the module docstring): per-machine, atomic write via .tmp + os.replace, future follow-up to add USB-CDC NVS dump for device-authoritative merging is tracked in #574. Tests: tests/test_provision_state.py — 11 tests covering load/save round-trip, corrupt-JSON resilience, CLI-wins-over-prior, the exact #391 case, falsy-but-not-None CLI override (node_id=0 must survive), and serial-port path sanitization for /dev/ttyUSB0. 
11/11 pass. Live-tested end-to-end with --dry-run + --state inspection: first run: ssid + password + target_ip persisted second run: --seed-url added — WiFi creds intact in final state.
This commit is contained in:
parent
4b1a835107
commit
dc7f6cd096
|
|
@ -14,15 +14,35 @@ Requirements:
|
|||
pip install 'esptool>=5.0' nvs-partition-gen
|
||||
(or use the nvs_partition_gen.py bundled with ESP-IDF)
|
||||
|
||||
WARNING -- FULL-REPLACE SEMANTICS (issue #391):
|
||||
Every invocation REPLACES the entire `csi_cfg` NVS namespace on the device.
|
||||
Any key you don't pass on the CLI is erased. Always include WiFi credentials
|
||||
(--ssid, --password, --target-ip) unless you pass --force-partial.
|
||||
ADDITIVE-BY-DEFAULT (issue #391, #574 phase 1):
|
||||
Earlier versions of this script REPLACED the entire `csi_cfg` NVS namespace
|
||||
on the device every invocation, wiping any key you didn't pass on the CLI.
|
||||
That cost customers hours of unnecessary friction.
|
||||
|
||||
The script now MERGES new CLI flags with the per-port state previously
|
||||
written from this machine (stored under your user config dir; see
|
||||
`--state-dir` to override or `--state` to inspect). On every invocation:
|
||||
|
||||
1. Read the prior per-port state file (or treat as empty if absent).
|
||||
2. Overlay the new CLI flags on top.
|
||||
3. Generate + flash NVS from the merged state.
|
||||
4. Write the merged state back to the state file.
|
||||
|
||||
Net effect: partial reconfigure works the way users expect. Pass `--reset`
|
||||
to wipe both the state file AND the device NVS for first-time provisioning
|
||||
of a recycled board.
|
||||
|
||||
Caveat: state lives on the controlling machine. Provisioning the same
|
||||
device from a second machine starts from an empty state — pass the keys
|
||||
you want to keep on that invocation, or pre-seed the state file. A future
|
||||
follow-up will add USB-CDC NVS dump for true device-authoritative merging
|
||||
(tracked in #574).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import subprocess
|
||||
|
|
@ -70,6 +90,90 @@ def has_config_value(args):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-port state file (additive-by-default merging, #391 / #574)
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# The state file is JSON keyed by `args` attribute name. It captures every
|
||||
# config value previously written to a given serial port from this machine.
|
||||
# On the next invocation, missing CLI flags fall back to the stored value.
|
||||
|
||||
# argparse attribute names that participate in the merge. Order doesn't
|
||||
# matter; this is just the surface area to round-trip.
|
||||
# One attribute name per line so additions/removals show up as one-line diffs.
MERGEABLE_ATTRS = [
    "ssid",
    "password",
    "target_ip",
    "target_port",
    "node_id",
    "tdm_slot",
    "tdm_total",
    "edge_tier",
    "pres_thresh",
    "fall_thresh",
    "vital_win",
    "vital_int",
    "subk_count",
    "channel",
    "filter_mac",
    "hop_channels",
    "hop_dwell",
    "seed_url",
    "seed_token",
    "zone",
    "swarm_hb",
    "swarm_ingest",
]
|
||||
|
||||
|
||||
def _default_state_dir() -> str:
    """Return the default directory for per-port provision-state JSON files.

    On Windows the base directory is ``%APPDATA%``, falling back to the
    user's home directory when that variable is unset or empty. On every
    other platform it is ``$XDG_CONFIG_HOME``, falling back to ``~/.config``.
    The result is ``<base>/wifi-densepose/esp32-provision-state``.
    """
    home = os.path.expanduser("~")
    if sys.platform == "win32":
        root = os.environ.get("APPDATA") or home
    else:
        root = os.environ.get("XDG_CONFIG_HOME") or os.path.join(home, ".config")
    return os.path.join(root, "wifi-densepose", "esp32-provision-state")
|
||||
|
||||
|
||||
def _state_path_for(port: str, state_dir: str) -> str:
    """Map a serial-port name to its JSON state file inside `state_dir`.

    Forward slashes, backslashes and colons in `port` are each replaced by an
    underscore so the port name collapses into a single filename component
    (e.g. ``/dev/ttyUSB0`` becomes ``_dev_ttyUSB0.json``).
    """
    unsafe_to_underscore = str.maketrans({"/": "_", ":": "_", "\\": "_"})
    filename = port.translate(unsafe_to_underscore) + ".json"
    return os.path.join(state_dir, filename)
|
||||
|
||||
|
||||
def load_state(port: str, state_dir: str) -> dict:
    """Return the stored state dict for `port`, or `{}` if absent / unreadable.

    Args:
        port: Serial port name; mapped to a file name by `_state_path_for`.
        state_dir: Directory that holds the per-port state files.

    Returns:
        The decoded JSON object. An empty dict is returned when the file does
        not exist, cannot be read or decoded, or does not contain a JSON
        object at top level. Every case except "file does not exist" prints a
        warning to stderr so the fallback is never silent.
    """
    path = _state_path_for(port, state_dir)
    if not os.path.isfile(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError *and* UnicodeDecodeError, so a
        # state file with invalid UTF-8 bytes degrades to "no prior state"
        # instead of crashing the run.
        print(f"WARNING: could not read state file {path}: {exc}", file=sys.stderr)
        return {}
    if not isinstance(data, dict):
        # Valid JSON, but not the object shape the merge step expects.
        print(
            f"WARNING: ignoring state file {path}: expected a JSON object, "
            f"got {type(data).__name__}",
            file=sys.stderr,
        )
        return {}
    return data
|
||||
|
||||
|
||||
def save_state(port: str, state_dir: str, state: dict) -> str:
    """Write `state` to the per-port file, creating dirs as needed. Returns path.

    The JSON is first written to ``<path>.tmp`` and then moved over the final
    path with `os.replace`, so a reader never sees a half-written state file.
    If serialising or replacing fails, the temporary file is removed before
    the exception propagates.

    Args:
        port: Serial port name; mapped to a file name by `_state_path_for`.
        state_dir: Directory for the state files (created if missing).
        state: JSON-serialisable dict to persist.

    Returns:
        The path of the state file that was written.

    Raises:
        OSError: If the directory or file cannot be created or written.
        TypeError: If `state` contains a value `json.dump` cannot serialise.
    """
    os.makedirs(state_dir, exist_ok=True)
    path = _state_path_for(port, state_dir)
    tmp = path + ".tmp"
    # NOTE(review): the state may carry the "password" / "seed_token" values in
    # plaintext and the file is created with default permissions -- confirm
    # whether it should be restricted to the current user.
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            # Sort keys for deterministic on-disk content (easier to diff).
            json.dump(state, f, indent=2, sort_keys=True)
            f.write("\n")
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup so a failed write does not leave a stray .tmp
        # file behind; the original error is always re-raised.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
    return path
|
||||
|
||||
|
||||
def merge_state_into_args(args, prior: dict) -> dict:
    """Combine previously stored state with the current CLI arguments.

    For each name in MERGEABLE_ATTRS:

    * if the attribute on `args` is not `None`, the CLI value wins and is
      recorded in the result;
    * otherwise, if `prior` has an entry for that name, the stored value is
      copied onto `args` so downstream `build_nvs_csv` sees it.

    `prior` itself is not modified; every key it holds is carried over into
    the result.

    Returns:
        A new dict holding the merged state (for state persistence).
    """
    combined = dict(prior)
    for attr in MERGEABLE_ATTRS:
        from_cli = getattr(args, attr, None)
        if from_cli is None:
            # Nothing on the command line: fall back to the stored value.
            if attr in combined:
                setattr(args, attr, combined[attr])
            continue
        combined[attr] = from_cli
    return combined
|
||||
|
||||
|
||||
def build_nvs_csv(args):
|
||||
"""Build an NVS CSV string for the csi_cfg namespace."""
|
||||
buf = io.StringIO()
|
||||
|
|
@ -250,19 +354,45 @@ def main():
|
|||
parser.add_argument("--swarm-ingest", type=int, help="Swarm vector ingest interval in seconds (default 5)")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Generate NVS binary but don't flash")
|
||||
parser.add_argument("--force-partial", action="store_true",
|
||||
help="Allow partial config without WiFi credentials. "
|
||||
"WARNING: flashing REPLACES the entire csi_cfg NVS namespace - "
|
||||
"any key not passed on the CLI will be erased (issue #391).")
|
||||
help="[deprecated since #391/#574] Suppress the missing-WiFi-trio "
|
||||
"error when no prior state file exists. The script now merges "
|
||||
"with prior state by default, so this flag is rarely needed.")
|
||||
parser.add_argument("--reset", action="store_true",
|
||||
help="Wipe this machine's per-port state file before merging. "
|
||||
"Use for first-time provisioning of a recycled board where "
|
||||
"previously-staged keys should NOT be re-applied.")
|
||||
parser.add_argument("--state-dir", default=_default_state_dir(),
|
||||
help="Override the per-user state directory (default: per-OS user config dir).")
|
||||
parser.add_argument("--state", action="store_true",
|
||||
help="Print the merged state that WOULD be flashed for this port and exit. "
|
||||
"Useful for debugging which keys are about to land on the device.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not has_config_value(args):
|
||||
parser.error("At least one config value must be specified")
|
||||
# --- Per-port state load + merge (additive-by-default, #391 / #574) ---
|
||||
if args.reset:
|
||||
path = _state_path_for(args.port, args.state_dir)
|
||||
if os.path.isfile(path):
|
||||
os.unlink(path)
|
||||
print(f"--reset: removed state file {path}", file=sys.stderr)
|
||||
prior = {}
|
||||
else:
|
||||
prior = load_state(args.port, args.state_dir)
|
||||
merged = merge_state_into_args(args, prior)
|
||||
|
||||
# Bug 2 (#391): Prevent silent wipe of WiFi credentials on partial invocations.
|
||||
# Flashing the generated NVS binary to offset 0x9000 REPLACES the entire
|
||||
# csi_cfg namespace — there is no merge with existing NVS. Require the full
|
||||
# WiFi trio unless the user explicitly opts in with --force-partial.
|
||||
if args.state:
|
||||
print(json.dumps(merged, indent=2, sort_keys=True))
|
||||
return
|
||||
|
||||
if not has_config_value(args):
|
||||
parser.error(
|
||||
"At least one config value must be specified (after merging prior state). "
|
||||
"If you intended to start fresh, pass --reset and the keys you want."
|
||||
)
|
||||
|
||||
# WiFi-trio sanity check. After the merge, the trio should be present
|
||||
# unless the user is intentionally provisioning a brand-new board with
|
||||
# partial state. Keep --force-partial as the escape hatch for that case.
|
||||
wifi_trio_missing = [
|
||||
name for name, val in [
|
||||
("--ssid", args.ssid),
|
||||
|
|
@ -272,20 +402,19 @@ def main():
|
|||
]
|
||||
if wifi_trio_missing and not args.force_partial:
|
||||
parser.error(
|
||||
f"Missing required WiFi credentials: {', '.join(wifi_trio_missing)}.\n"
|
||||
f"Missing required WiFi credentials after merging prior state: "
|
||||
f"{', '.join(wifi_trio_missing)}.\n"
|
||||
f"\n"
|
||||
f" provision.py REPLACES the entire csi_cfg NVS namespace on each run.\n"
|
||||
f" Any key not passed on the CLI will be erased -- including WiFi creds.\n"
|
||||
f"\n"
|
||||
f" Either pass all of --ssid, --password, --target-ip,\n"
|
||||
f" or add --force-partial to acknowledge that other NVS keys will be wiped."
|
||||
f" No per-port state file at {_state_path_for(args.port, args.state_dir)}\n"
|
||||
f" and the CLI didn't include them. Either pass --ssid + --password + --target-ip\n"
|
||||
f" on this run, or add --force-partial to flash without WiFi.\n"
|
||||
)
|
||||
if args.force_partial and wifi_trio_missing:
|
||||
print("WARNING: --force-partial is set. The following NVS keys will be WIPED "
|
||||
"(not present in this invocation):", file=sys.stderr)
|
||||
for k in wifi_trio_missing:
|
||||
print(f" - {k.lstrip('-')}", file=sys.stderr)
|
||||
print(" Plus any other csi_cfg keys not passed on the CLI.\n", file=sys.stderr)
|
||||
print(
|
||||
"WARNING: --force-partial is set and WiFi credentials are missing. "
|
||||
"The device will not connect to WiFi after flashing.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Validate TDM: if one is given, both should be
|
||||
if (args.tdm_slot is not None) != (args.tdm_total is not None):
|
||||
|
|
@ -371,9 +500,18 @@ def main():
|
|||
print(f"NVS binary saved to {out} ({len(nvs_bin)} bytes)")
|
||||
print(f"Flash manually: python -m esptool --chip {args.chip} --port {args.port} "
|
||||
f"write-flash 0x9000 {out}")
|
||||
# Persist merged state even on dry-run so a subsequent real flash from
|
||||
# this machine sees the same staged config.
|
||||
path = save_state(args.port, args.state_dir, merged)
|
||||
print(f"State persisted to {path}")
|
||||
return
|
||||
|
||||
flash_nvs(args.port, args.baud, nvs_bin, args.chip)
|
||||
# Persist merged state after a successful flash so future partial
|
||||
# invocations from this machine merge on top of what's actually on the
|
||||
# device. This is the heart of the additive-by-default fix (#391/#574).
|
||||
path = save_state(args.port, args.state_dir, merged)
|
||||
print(f"State persisted to {path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
"""Tests for provision.py's additive-by-default merge behaviour (#391, #574)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
# Allow `python -m unittest` from anywhere in the repo.
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.dirname(HERE))
|
||||
|
||||
import provision # noqa: E402 — sibling import after sys.path tweak
|
||||
|
||||
|
||||
def _mk_args(**overrides) -> argparse.Namespace:
    """Return a Namespace where every mergeable attr is None unless overridden.

    Keyword arguments replace the `None` default for the matching attribute.
    """
    defaults = dict.fromkeys(provision.MERGEABLE_ATTRS)
    return argparse.Namespace(**{**defaults, **overrides})
|
||||
|
||||
|
||||
class TestStateFile(unittest.TestCase):
    """Load/save behaviour of the per-port state file helpers."""

    def setUp(self):
        import shutil

        # Fresh scratch directory per test, removed again once the test ends.
        self.dir = tempfile.mkdtemp(prefix="provision-state-")
        self.addCleanup(shutil.rmtree, self.dir, ignore_errors=True)

    def test_load_state_empty_when_missing(self):
        loaded = provision.load_state("COM7", self.dir)
        self.assertEqual(loaded, {})

    def test_save_then_load_roundtrip(self):
        provision.save_state("COM7", self.dir, {"ssid": "x", "password": "y"})
        loaded = provision.load_state("COM7", self.dir)
        self.assertEqual(loaded, {"ssid": "x", "password": "y"})

    def test_save_creates_per_port_files(self):
        per_port = {"COM7": {"ssid": "a"}, "/dev/ttyUSB0": {"ssid": "b"}}
        # Write every port first so a shared file would be caught on read-back.
        for port, state in per_port.items():
            provision.save_state(port, self.dir, state)
        for port, state in per_port.items():
            self.assertEqual(provision.load_state(port, self.dir), state)

    def test_load_state_handles_corrupt_json(self):
        os.makedirs(self.dir, exist_ok=True)
        corrupt = provision._state_path_for("COM7", self.dir)
        with open(corrupt, "w", encoding="utf-8") as f:
            f.write("{not valid json")
        # Should warn but not raise.
        loaded = provision.load_state("COM7", self.dir)
        self.assertEqual(loaded, {})
|
||||
|
||||
|
||||
class TestMerge(unittest.TestCase):
    """Overlay semantics of `provision.merge_state_into_args`."""

    def test_cli_wins_over_prior(self):
        stored = {"ssid": "old-ssid", "password": "abc"}
        ns = _mk_args(ssid="new-ssid")
        result = provision.merge_state_into_args(ns, stored)
        # CLI value preserved; missing value filled from prior.
        self.assertEqual((ns.ssid, ns.password), ("new-ssid", "abc"))
        self.assertEqual(result["ssid"], "new-ssid")
        self.assertEqual(result["password"], "abc")

    def test_prior_fills_missing_cli(self):
        stored = {
            "ssid": "MyWiFi",
            "password": "secret",
            "target_ip": "192.168.1.20",
            "node_id": 3,
        }
        ns = _mk_args()  # all None
        result = provision.merge_state_into_args(ns, stored)
        self.assertEqual(ns.ssid, "MyWiFi")
        self.assertEqual(ns.password, "secret")
        self.assertEqual(ns.target_ip, "192.168.1.20")
        self.assertEqual(ns.node_id, 3)
        for key in stored:
            self.assertEqual(result[key], stored[key])

    def test_partial_invocation_does_not_drop_unrelated_keys(self):
        # The exact #391 scenario: WiFi was provisioned earlier and this run
        # passes only --seed-url. The old behaviour wiped the SSID; the merge
        # must keep it.
        stored = {
            "ssid": "ruv.net",
            "password": "<secret>",
            "target_ip": "192.168.1.20",
        }
        ns = _mk_args(seed_url="http://10.1.10.236")
        result = provision.merge_state_into_args(ns, stored)
        self.assertEqual(ns.ssid, "ruv.net")
        self.assertEqual(ns.password, "<secret>")
        self.assertEqual(ns.target_ip, "192.168.1.20")
        self.assertEqual(ns.seed_url, "http://10.1.10.236")
        # And the on-disk merged dict carries all four keys.
        self.assertEqual(
            set(result.keys()),
            {"ssid", "password", "target_ip", "seed_url"},
        )

    def test_empty_prior_is_noop(self):
        ns = _mk_args(ssid="x")
        result = provision.merge_state_into_args(ns, {})
        self.assertEqual(result, {"ssid": "x"})

    def test_falsy_but_not_none_cli_value_overrides_prior(self):
        # node_id=0 is a legal value; must NOT be replaced by prior["node_id"]=5.
        ns = _mk_args(node_id=0)
        result = provision.merge_state_into_args(ns, {"node_id": 5})
        self.assertEqual(ns.node_id, 0)
        self.assertEqual(result["node_id"], 0)
|
||||
|
||||
|
||||
class TestStatePathSanitization(unittest.TestCase):
    """Serial-port names must map to a single file directly inside the state dir."""

    def test_slashes_in_port_are_safe(self):
        path = provision._state_path_for("/dev/ttyUSB0", "/tmp/x")
        # Must not contain a raw slash in the basename
        self.assertNotIn("/", os.path.basename(path))
        # The basename check alone is vacuous: os.path.basename() never returns
        # a separator, so it would pass even with no sanitization at all. What
        # actually proves the port was flattened is that the file sits directly
        # in the state dir and still carries every component of the port name.
        self.assertEqual(os.path.dirname(path), "/tmp/x")
        self.assertIn("dev", os.path.basename(path))
        self.assertIn("ttyUSB0", os.path.basename(path))
        self.assertTrue(path.endswith(".json"))

    def test_windows_com_port_is_safe(self):
        path = provision._state_path_for("COM7", "/tmp/x")
        self.assertTrue(path.endswith("COM7.json"))
        self.assertEqual(os.path.dirname(path), "/tmp/x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue