diff --git a/firmware/esp32-csi-node/provision.py b/firmware/esp32-csi-node/provision.py index d450de99..76d57058 100644 --- a/firmware/esp32-csi-node/provision.py +++ b/firmware/esp32-csi-node/provision.py @@ -14,15 +14,35 @@ Requirements: pip install 'esptool>=5.0' nvs-partition-gen (or use the nvs_partition_gen.py bundled with ESP-IDF) -WARNING -- FULL-REPLACE SEMANTICS (issue #391): - Every invocation REPLACES the entire `csi_cfg` NVS namespace on the device. - Any key you don't pass on the CLI is erased. Always include WiFi credentials - (--ssid, --password, --target-ip) unless you pass --force-partial. +ADDITIVE-BY-DEFAULT (issue #391, #574 phase 1): + Earlier versions of this script REPLACED the entire `csi_cfg` NVS namespace + on the device every invocation, wiping any key you didn't pass on the CLI. + That cost customers hours of unnecessary friction. + + The script now MERGES new CLI flags with the per-port state previously + written from this machine (stored under your user config dir; see + `--state-dir` to override or `--state` to inspect). On every invocation: + + 1. Read the prior per-port state file (or treat as empty if absent). + 2. Overlay the new CLI flags on top. + 3. Generate + flash NVS from the merged state. + 4. Write the merged state back to the state file. + + Net effect: partial reconfigure works the way users expect. Pass `--reset` + to wipe both the state file AND the device NVS for first-time provisioning + of a recycled board. + + Caveat: state lives on the controlling machine. Provisioning the same + device from a second machine starts from an empty state — pass the keys + you want to keep on that invocation, or pre-seed the state file. A future + follow-up will add USB-CDC NVS dump for true device-authoritative merging + (tracked in #574). 
""" import argparse import csv import io +import json import os import struct import subprocess @@ -70,6 +90,90 @@ def has_config_value(args): ) +# --------------------------------------------------------------------------- +# Per-port state file (additive-by-default merging, #391 / #574) +# --------------------------------------------------------------------------- +# +# The state file is JSON keyed by `args` attribute name. It captures every +# config value previously written to a given serial port from this machine. +# On the next invocation, missing CLI flags fall back to the stored value. + +# argparse attribute names that participate in the merge. Order doesn't +# matter; this is just the surface area to round-trip. +MERGEABLE_ATTRS = [ + "ssid", "password", "target_ip", "target_port", "node_id", + "tdm_slot", "tdm_total", + "edge_tier", "pres_thresh", "fall_thresh", + "vital_win", "vital_int", "subk_count", + "channel", "filter_mac", + "hop_channels", "hop_dwell", + "seed_url", "seed_token", "zone", "swarm_hb", "swarm_ingest", +] + + +def _default_state_dir() -> str: + """Per-user config dir for provision-state JSON files.""" + env = os.environ + if sys.platform == "win32": + base = env.get("APPDATA") or os.path.expanduser("~") + else: + base = env.get("XDG_CONFIG_HOME") or os.path.join( + os.path.expanduser("~"), ".config" + ) + return os.path.join(base, "wifi-densepose", "esp32-provision-state") + + +def _state_path_for(port: str, state_dir: str) -> str: + """File path for a given serial port. 
Sanitize the port for filesystem use."""
+    safe = "".join(c if c.isalnum() or c in "._-" else "_" for c in port)
+    return os.path.join(state_dir, f"{safe}.json")
+
+
+def load_state(port: str, state_dir: str) -> dict:
+    """Return the stored state dict for `port`, or `{}` if absent / unusable."""
+    path = _state_path_for(port, state_dir)
+    if not os.path.isfile(path):
+        return {}
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+    except (OSError, ValueError) as exc:  # ValueError: malformed JSON or UTF-8
+        print(f"WARNING: could not read state file {path}: {exc}", file=sys.stderr)
+        return {}
+    return data if isinstance(data, dict) else {}  # non-object JSON holds no keys
+
+
+def save_state(port: str, state_dir: str, state: dict) -> str:
+    """Write `state` to the per-port file, creating dirs as needed. Returns path."""
+    os.makedirs(state_dir, exist_ok=True)
+    path = _state_path_for(port, state_dir)
+    tmp = path + ".tmp"
+    # Owner-only (0600 on POSIX): the state holds the WiFi password and seed token.
+    fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+    with os.fdopen(fd, "w", encoding="utf-8") as f:
+        json.dump(state, f, indent=2, sort_keys=True)  # sorted keys: easier to diff
+        f.write("\n")
+    os.replace(tmp, path)
+    return path
+
+
+def merge_state_into_args(args, prior: dict) -> dict:
+    """Overlay `args` onto `prior` for every MERGEABLE_ATTRS attribute.
+
+    CLI values win whenever they were explicitly set (i.e. not `None`).
+    Returns the merged dict (for state persistence) and mutates `args`
+    in place so downstream `build_nvs_csv` sees the merged values.
+ """ + merged = dict(prior) + for name in MERGEABLE_ATTRS: + cli_val = getattr(args, name, None) + if cli_val is not None: + merged[name] = cli_val + elif name in merged: + setattr(args, name, merged[name]) + return merged + + def build_nvs_csv(args): """Build an NVS CSV string for the csi_cfg namespace.""" buf = io.StringIO() @@ -250,19 +354,45 @@ def main(): parser.add_argument("--swarm-ingest", type=int, help="Swarm vector ingest interval in seconds (default 5)") parser.add_argument("--dry-run", action="store_true", help="Generate NVS binary but don't flash") parser.add_argument("--force-partial", action="store_true", - help="Allow partial config without WiFi credentials. " - "WARNING: flashing REPLACES the entire csi_cfg NVS namespace - " - "any key not passed on the CLI will be erased (issue #391).") + help="[deprecated since #391/#574] Suppress the missing-WiFi-trio " + "error when no prior state file exists. The script now merges " + "with prior state by default, so this flag is rarely needed.") + parser.add_argument("--reset", action="store_true", + help="Wipe this machine's per-port state file before merging. " + "Use for first-time provisioning of a recycled board where " + "previously-staged keys should NOT be re-applied.") + parser.add_argument("--state-dir", default=_default_state_dir(), + help="Override the per-user state directory (default: per-OS user config dir).") + parser.add_argument("--state", action="store_true", + help="Print the merged state that WOULD be flashed for this port and exit. 
" + "Useful for debugging which keys are about to land on the device.") args = parser.parse_args() - if not has_config_value(args): - parser.error("At least one config value must be specified") + # --- Per-port state load + merge (additive-by-default, #391 / #574) --- + if args.reset: + path = _state_path_for(args.port, args.state_dir) + if os.path.isfile(path): + os.unlink(path) + print(f"--reset: removed state file {path}", file=sys.stderr) + prior = {} + else: + prior = load_state(args.port, args.state_dir) + merged = merge_state_into_args(args, prior) - # Bug 2 (#391): Prevent silent wipe of WiFi credentials on partial invocations. - # Flashing the generated NVS binary to offset 0x9000 REPLACES the entire - # csi_cfg namespace — there is no merge with existing NVS. Require the full - # WiFi trio unless the user explicitly opts in with --force-partial. + if args.state: + print(json.dumps(merged, indent=2, sort_keys=True)) + return + + if not has_config_value(args): + parser.error( + "At least one config value must be specified (after merging prior state). " + "If you intended to start fresh, pass --reset and the keys you want." + ) + + # WiFi-trio sanity check. After the merge, the trio should be present + # unless the user is intentionally provisioning a brand-new board with + # partial state. Keep --force-partial as the escape hatch for that case. 
wifi_trio_missing = [
         name for name, val in [
             ("--ssid", args.ssid),
@@ -272,20 +402,19 @@ def main():
     ]
     if wifi_trio_missing and not args.force_partial:
         parser.error(
-            f"Missing required WiFi credentials: {', '.join(wifi_trio_missing)}.\n"
+            f"Missing required WiFi credentials after merging prior state: "
+            f"{', '.join(wifi_trio_missing)}.\n"
             f"\n"
-            f" provision.py REPLACES the entire csi_cfg NVS namespace on each run.\n"
-            f" Any key not passed on the CLI will be erased -- including WiFi creds.\n"
-            f"\n"
-            f" Either pass all of --ssid, --password, --target-ip,\n"
-            f" or add --force-partial to acknowledge that other NVS keys will be wiped."
+            f" They were found neither on the CLI nor in the per-port state file\n"
+            f" {_state_path_for(args.port, args.state_dir)}. Either pass --ssid + --password\n"
+            f" + --target-ip on this run, or add --force-partial to flash without WiFi.\n"
         )
     if args.force_partial and wifi_trio_missing:
-        print("WARNING: --force-partial is set. The following NVS keys will be WIPED "
-              "(not present in this invocation):", file=sys.stderr)
-        for k in wifi_trio_missing:
-            print(f" - {k.lstrip('-')}", file=sys.stderr)
-        print(" Plus any other csi_cfg keys not passed on the CLI.\n", file=sys.stderr)
+        print(
+            "WARNING: --force-partial is set and WiFi credentials are missing. "
+            "The device will not connect to WiFi after flashing.",
+            file=sys.stderr,
+        )
 
     # Validate TDM: if one is given, both should be
     if (args.tdm_slot is not None) != (args.tdm_total is not None):
@@ -371,9 +500,18 @@ def main():
         print(f"NVS binary saved to {out} ({len(nvs_bin)} bytes)")
         print(f"Flash manually: python -m esptool --chip {args.chip} --port {args.port} "
               f"write-flash 0x9000 {out}")
+        # Persist merged state even on dry-run so a subsequent real flash from
+        # this machine sees the same staged config.
+ path = save_state(args.port, args.state_dir, merged) + print(f"State persisted to {path}") return flash_nvs(args.port, args.baud, nvs_bin, args.chip) + # Persist merged state after a successful flash so future partial + # invocations from this machine merge on top of what's actually on the + # device. This is the heart of the additive-by-default fix (#391/#574). + path = save_state(args.port, args.state_dir, merged) + print(f"State persisted to {path}") if __name__ == "__main__": diff --git a/firmware/esp32-csi-node/tests/test_provision_state.py b/firmware/esp32-csi-node/tests/test_provision_state.py new file mode 100644 index 00000000..e55270e9 --- /dev/null +++ b/firmware/esp32-csi-node/tests/test_provision_state.py @@ -0,0 +1,129 @@ +"""Tests for provision.py's additive-by-default merge behaviour (#391, #574).""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import tempfile +import unittest + +# Allow `python -m unittest` from anywhere in the repo. 
+HERE = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.dirname(HERE)) + +import provision # noqa: E402 — sibling import after sys.path tweak + + +def _mk_args(**overrides) -> argparse.Namespace: + """Build a Namespace with every mergeable attr set to None unless overridden.""" + base = {name: None for name in provision.MERGEABLE_ATTRS} + base.update(overrides) + return argparse.Namespace(**base) + + +class TestStateFile(unittest.TestCase): + def setUp(self): + self.dir = tempfile.mkdtemp(prefix="provision-state-") + + def tearDown(self): + import shutil + shutil.rmtree(self.dir, ignore_errors=True) + + def test_load_state_empty_when_missing(self): + self.assertEqual(provision.load_state("COM7", self.dir), {}) + + def test_save_then_load_roundtrip(self): + provision.save_state("COM7", self.dir, {"ssid": "x", "password": "y"}) + self.assertEqual( + provision.load_state("COM7", self.dir), + {"ssid": "x", "password": "y"}, + ) + + def test_save_creates_per_port_files(self): + provision.save_state("COM7", self.dir, {"ssid": "a"}) + provision.save_state("/dev/ttyUSB0", self.dir, {"ssid": "b"}) + self.assertEqual(provision.load_state("COM7", self.dir), {"ssid": "a"}) + self.assertEqual(provision.load_state("/dev/ttyUSB0", self.dir), {"ssid": "b"}) + + def test_load_state_handles_corrupt_json(self): + path = provision._state_path_for("COM7", self.dir) + os.makedirs(self.dir, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + f.write("{not valid json") + # Should warn but not raise. 
+ self.assertEqual(provision.load_state("COM7", self.dir), {}) + + +class TestMerge(unittest.TestCase): + def test_cli_wins_over_prior(self): + args = _mk_args(ssid="new-ssid") + prior = {"ssid": "old-ssid", "password": "abc"} + merged = provision.merge_state_into_args(args, prior) + self.assertEqual(args.ssid, "new-ssid") # CLI value preserved + self.assertEqual(args.password, "abc") # filled from prior + self.assertEqual(merged["ssid"], "new-ssid") + self.assertEqual(merged["password"], "abc") + + def test_prior_fills_missing_cli(self): + args = _mk_args() # all None + prior = { + "ssid": "MyWiFi", + "password": "secret", + "target_ip": "192.168.1.20", + "node_id": 3, + } + merged = provision.merge_state_into_args(args, prior) + self.assertEqual(args.ssid, "MyWiFi") + self.assertEqual(args.password, "secret") + self.assertEqual(args.target_ip, "192.168.1.20") + self.assertEqual(args.node_id, 3) + for key, val in prior.items(): + self.assertEqual(merged[key], val) + + def test_partial_invocation_does_not_drop_unrelated_keys(self): + # The exact #391 scenario: user previously provisioned WiFi, now adds + # only --seed-url. Old behaviour wiped SSID. New behaviour keeps it. + args = _mk_args(seed_url="http://10.1.10.236") + prior = { + "ssid": "ruv.net", + "password": "", + "target_ip": "192.168.1.20", + } + merged = provision.merge_state_into_args(args, prior) + self.assertEqual(args.ssid, "ruv.net") + self.assertEqual(args.password, "") + self.assertEqual(args.target_ip, "192.168.1.20") + self.assertEqual(args.seed_url, "http://10.1.10.236") + # And the on-disk merged dict carries all four keys. 
+        self.assertEqual(set(merged.keys()),
+                         {"ssid", "password", "target_ip", "seed_url"})
+
+    def test_empty_prior_is_noop(self):
+        args = _mk_args(ssid="x")
+        merged = provision.merge_state_into_args(args, {})
+        self.assertEqual(merged, {"ssid": "x"})
+
+    def test_falsy_but_not_none_cli_value_overrides_prior(self):
+        # node_id=0 is a legal value; must NOT be replaced by prior["node_id"]=5.
+        args = _mk_args(node_id=0)
+        prior = {"node_id": 5}
+        merged = provision.merge_state_into_args(args, prior)
+        self.assertEqual(args.node_id, 0)
+        self.assertEqual(merged["node_id"], 0)
+
+
+class TestStatePathSanitization(unittest.TestCase):
+    def test_slashes_in_port_are_safe(self):
+        path = provision._state_path_for("/dev/ttyUSB0", "/tmp/x")
+        # basename() never contains "/", so pin the parent directory instead.
+        self.assertEqual(os.path.dirname(path), "/tmp/x")
+
+    def test_windows_com_port_is_safe(self):
+        path = provision._state_path_for("COM7", "/tmp/x")
+        self.assertTrue(path.endswith("COM7.json"))
+
+
+if __name__ == "__main__":
+    unittest.main()