fix(provision): recognize swarm/hopping flags as config values (#617)

This commit is contained in:
NgoQuocViet2001 2026-05-19 21:03:58 +07:00 committed by GitHub
parent c00f45e296
commit 3439fb1402
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 11 deletions

View File

@ -37,6 +37,39 @@ NVS_PARTITION_OFFSET = 0x9000
NVS_PARTITION_SIZE = 0x6000 # 24 KiB
CONFIG_VALUE_CHECKS = [
("ssid", bool),
("password", lambda value: value is not None),
("target_ip", bool),
("target_port", lambda value: value is not None),
("node_id", lambda value: value is not None),
("tdm_slot", lambda value: value is not None),
("tdm_total", lambda value: value is not None),
("edge_tier", lambda value: value is not None),
("pres_thresh", lambda value: value is not None),
("fall_thresh", lambda value: value is not None),
("vital_win", lambda value: value is not None),
("vital_int", lambda value: value is not None),
("subk_count", lambda value: value is not None),
("channel", lambda value: value is not None),
("filter_mac", lambda value: value is not None),
("hop_channels", lambda value: value is not None),
("seed_url", lambda value: value is not None),
("seed_token", lambda value: value is not None),
("zone", lambda value: value is not None),
("swarm_hb", lambda value: value is not None),
("swarm_ingest", lambda value: value is not None),
]
def has_config_value(args):
"""Return True when args include at least one NVS-writing config value."""
return any(
check(getattr(args, name, None))
for name, check in CONFIG_VALUE_CHECKS
)
def build_nvs_csv(args):
"""Build an NVS CSV string for the csi_cfg namespace."""
buf = io.StringIO()
@ -223,17 +256,7 @@ def main():
args = parser.parse_args()
has_value = any([
args.ssid, args.password is not None, args.target_ip,
args.target_port, args.node_id is not None,
args.tdm_slot is not None, args.tdm_total is not None,
args.edge_tier is not None, args.pres_thresh is not None,
args.fall_thresh is not None, args.vital_win is not None,
args.vital_int is not None, args.subk_count is not None,
args.channel is not None, args.filter_mac is not None,
args.seed_url is not None, args.zone is not None,
])
if not has_value:
if not has_config_value(args):
parser.error("At least one config value must be specified")
# Bug 2 (#391): Prevent silent wipe of WiFi credentials on partial invocations.

View File

@ -0,0 +1,63 @@
import csv
import importlib.util
import io
import types
import unittest
from pathlib import Path
PROVISION_PATH = Path(__file__).resolve().parents[1] / "provision.py"
SPEC = importlib.util.spec_from_file_location("provision", PROVISION_PATH)
provision = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(provision)
def make_args(**overrides):
values = {name: None for name, _ in provision.CONFIG_VALUE_CHECKS}
values["hop_dwell"] = 200
values.update(overrides)
return types.SimpleNamespace(**values)
def csv_rows(content):
return list(csv.DictReader(io.StringIO(content)))
class ProvisionConfigValueTests(unittest.TestCase):
def test_swarm_and_hopping_flags_count_as_config_values(self):
cases = [
{"hop_channels": "1,6,11"},
{"seed_token": "token-123"},
{"swarm_hb": 15},
{"swarm_ingest": 3},
]
for values in cases:
with self.subTest(values=values):
self.assertTrue(provision.has_config_value(make_args(**values)))
def test_operational_flags_alone_do_not_count_as_config_values(self):
self.assertFalse(provision.has_config_value(make_args()))
def test_swarm_and_hopping_values_are_written_to_csv(self):
args = make_args(
hop_channels="1,6,11",
hop_dwell=250,
seed_token="token-123",
swarm_hb=15,
swarm_ingest=3,
)
rows = csv_rows(provision.build_nvs_csv(args))
values_by_key = {row["key"]: row["value"] for row in rows}
self.assertEqual(values_by_key["hop_count"], "3")
self.assertEqual(values_by_key["chan_list"], "01060b")
self.assertEqual(values_by_key["dwell_ms"], "250")
self.assertEqual(values_by_key["seed_token"], "token-123")
self.assertEqual(values_by_key["swarm_hb"], "15")
self.assertEqual(values_by_key["swarm_ingest"], "3")
if __name__ == "__main__":
unittest.main()