wifi-densepose/scripts/ota-deploy.sh

276 lines
9.5 KiB
Bash
Executable File

#!/usr/bin/env python3
"""
scripts/ota-deploy.sh — push esp32-csi-node.bin to one or more sensor nodes
over WiFi. Talks to the on-device /ota endpoint (ADR-045, port 8032,
handler in firmware/esp32-csi-node/main/ota_update.c).

Usage:
    scripts/ota-deploy.sh                                # auto-discover via ARP, deploy to all
    scripts/ota-deploy.sh 192.168.0.100                  # one node
    scripts/ota-deploy.sh 192.168.0.100 192.168.0.101    # several nodes
    scripts/ota-deploy.sh --build                        # idf.py build first, then deploy
    scripts/ota-deploy.sh --no-verify ...                # skip post-reboot /ota/status check

Auth: set env OTA_PSK=<token> to send "Authorization: Bearer <token>"
(matches the on-device check in ota_update.c::ota_check_auth).

Exit codes:
    0 — all targeted nodes confirmed running_partition flipped
        (with --no-verify: all uploads succeeded)
    1 — one or more nodes failed verification or were unreachable,
        or no nodes were given and none were discovered
    2 — build or argument error
"""
from __future__ import annotations
import argparse
import concurrent.futures as cf
import json
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
from typing import Iterable
# Repository root: two directory levels above this script's resolved path.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Directory of the sensor-node firmware project (where `idf.py build` is run).
FW_DIR = REPO_ROOT / "firmware" / "esp32-csi-node"
# Firmware image that gets uploaded; main() refuses to run if it is missing.
BIN_PATH = FW_DIR / "build" / "esp32-csi-node.bin"
# TCP port of the on-device HTTP endpoints (/ota and /ota/status).
PORT = 8032
# Timeout, in seconds, for the firmware upload POST.
UPLOAD_TIMEOUT_S = 120
# Seconds to sleep after the uploads before the first verification poll.
REBOOT_WAIT_S = 10
# Maximum number of /ota/status polls per node during verification.
VERIFY_RETRIES = 6
# Seconds to sleep after each unsuccessful verification poll.
VERIFY_DELAY_S = 3
# ---- ANSI logging helpers ----------------------------------------------------
def _c(code: str, msg: str, stream=None) -> str:
    """Wrap *msg* in an ANSI SGR colour sequence when *stream* is a terminal.

    Args:
        code: SGR parameter string, e.g. ``"31"`` for red.
        msg: Text to colourise.
        stream: File object the text will be written to. Defaults to the
            current ``sys.stdout`` so two-argument callers behave as before.

    Returns:
        *msg* unchanged when the stream is missing or not a TTY, otherwise
        *msg* bracketed by the colour and reset escape sequences.
    """
    if stream is None:
        stream = sys.stdout
    # A stream can be None (no console attached) or lack isatty(); treat both
    # as "not a terminal" rather than raising from a logging helper.
    isatty = getattr(stream, "isatty", None)
    if isatty is None or not isatty():
        return msg
    return f"\033[{code}m{msg}\033[0m"


def log(msg: str) -> None:
    """Print an informational message to stdout behind a cyan tag."""
    print(_c("36", "[ota-deploy] ", sys.stdout) + msg, flush=True)


def warn(msg: str) -> None:
    """Print a warning to stderr behind a yellow tag."""
    # Colour is decided by stderr, the stream actually written to, so that
    # redirecting stdout or stderr independently gives the right result.
    print(_c("33", "[ota-deploy] ", sys.stderr) + msg, file=sys.stderr, flush=True)


def err(msg: str) -> None:
    """Print an error to stderr behind a red tag."""
    print(_c("31", "[ota-deploy] ", sys.stderr) + msg, file=sys.stderr, flush=True)
# ---- helpers -----------------------------------------------------------------
def http_get(url: str, timeout: float = 4.0) -> str | None:
    """Fetch *url* with an HTTP GET and return the response body as text.

    Args:
        url: Absolute URL to request.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        The body decoded as UTF-8 (undecodable bytes replaced), or ``None``
        if the request fails for any transport- or protocol-level reason.
    """
    # Function-scope import so the block does not depend on a file-level
    # import being added elsewhere.
    import http.client

    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.read().decode("utf-8", errors="replace")
    except (OSError, http.client.HTTPException):
        # OSError is the common base of URLError, HTTPError, TimeoutError and
        # raw socket errors. HTTPException is NOT an OSError; it covers
        # InvalidURL, BadStatusLine and IncompleteRead (e.g. a node that
        # reboots halfway through its reply), which would otherwise escape.
        return None
def get_ota_status(ip: str) -> dict | None:
    """Query a node's ``/ota/status`` endpoint and return the parsed JSON object.

    Args:
        ip: Address of the node; the request goes to ``http://{ip}:{PORT}/ota/status``.

    Returns:
        The decoded JSON object, or ``None`` when the node does not answer,
        the body is empty, the body is not valid JSON, or the JSON value is
        not an object.
    """
    body = http_get(f"http://{ip}:{PORT}/ota/status")
    if not body:
        return None
    try:
        status = json.loads(body)
    except json.JSONDecodeError:
        return None
    # Callers use .get() on the result; a JSON list/string/number from some
    # unrelated service on the same port must not be mistaken for a status.
    if not isinstance(status, dict):
        return None
    return status
def local_subnet_prefix() -> str | None:
    """Return the first three octets of this host's IPv4 address, e.g. '192.168.0'.

    Tries ``ipconfig getifaddr en0`` (macOS) first and ``hostname -I``
    (Linux) second, and uses the first non-loopback IPv4 address either
    command prints.

    Returns:
        The dotted three-octet prefix, or ``None`` if neither command is
        available, both fail, or neither prints a usable IPv4 address.
    """
    probes = (
        ["ipconfig", "getifaddr", "en0"],  # macOS
        ["hostname", "-I"],  # Linux fallback
    )
    for cmd in probes:
        try:
            # stderr is silenced for both probes: on the "wrong" platform the
            # command exists but prints a usage error.
            out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, text=True)
        except (subprocess.CalledProcessError, OSError):
            # OSError includes FileNotFoundError (command not installed).
            continue
        for token in out.split():
            # `hostname -I` may list IPv6 addresses too; only a dotted quad
            # can be turned into a /24 prefix.
            if not re.fullmatch(r"\d{1,3}(?:\.\d{1,3}){3}", token):
                continue
            if token.startswith("127."):
                continue
            return token.rsplit(".", 1)[0]
    return None
def discover_nodes() -> list[str]:
    """Find nodes on the local /24 that answer ``/ota/status``.

    Candidate addresses come from the ARP table (entries inside the local
    prefix, excluding incomplete entries and the ``.1`` address). If that
    yields nothing, ``.100``-``.110`` are used instead. Every candidate is
    then probed in parallel with ``get_ota_status``.

    Returns:
        Addresses that returned a status, sorted numerically by octet;
        an empty list if the local prefix cannot be determined.
    """
    subnet = local_subnet_prefix()
    if not subnet:
        err("could not determine local /24 — pass node IPs explicitly")
        return []
    log(f"scanning {subnet}.0/24 for /ota/status responders ...")

    # Best effort: a missing or failing `arp` just means no ARP candidates.
    try:
        arp_table = subprocess.check_output(
            ["arp", "-a", "-n"], text=True, stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        arp_table = ""

    entry_re = re.compile(rf"\(({re.escape(subnet)}\.\d+)\)")
    arp_hits: list[str] = []
    for row in arp_table.splitlines():
        hit = entry_re.search(row)
        if hit is None or "incomplete" in row:
            continue
        addr = hit.group(1)
        if addr.endswith(".1"):  # skip gateway
            continue
        arp_hits.append(addr)

    if not arp_hits:
        warn(f"no ARP hits — falling back to {subnet}.100-110 ping sweep")
        arp_hits = [f"{subnet}.{host}" for host in range(100, 111)]
    probe_list = sorted(set(arp_hits))

    responders: list[str] = []
    with cf.ThreadPoolExecutor(max_workers=32) as pool:
        pending = {pool.submit(get_ota_status, addr): addr for addr in probe_list}
        for done in cf.as_completed(pending):
            addr = pending[done]
            try:
                alive = done.result()
            except Exception:
                # A probe that blows up is treated like a node that is absent.
                continue
            if alive:
                responders.append(addr)

    def _octets(address: str) -> tuple[int, ...]:
        return tuple(int(part) for part in address.split("."))

    return sorted(responders, key=_octets)
def upload_one(ip: str, payload: bytes, psk: str | None) -> tuple[bool, float, str]:
    """POST the firmware image to one node's ``/ota`` endpoint.

    Args:
        ip: Address of the node.
        payload: Raw firmware image bytes, sent as ``application/octet-stream``.
        psk: Optional token; when truthy it is sent as
            ``Authorization: Bearer <psk>``.

    Returns:
        ``(success, elapsed_s, body_snippet)``. On a normal reply the snippet
        is the first 200 characters of the response body. On an error it is
        ``"<ExceptionType>: <message>"``, and ``success`` is True only for
        errors that are consistent with the node rebooting mid-response.
    """
    # Function-scope import so the block does not depend on a file-level
    # import being added elsewhere.
    import http.client

    req = urllib.request.Request(
        f"http://{ip}:{PORT}/ota",
        data=payload,
        headers={"Content-Type": "application/octet-stream"},
        method="POST",
    )
    if psk:
        req.add_header("Authorization", f"Bearer {psk}")

    t0 = time.monotonic()
    try:
        with urllib.request.urlopen(req, timeout=UPLOAD_TIMEOUT_S) as r:
            body = r.read().decode("utf-8", errors="replace")[:200]
        return True, time.monotonic() - t0, body
    except (OSError, http.client.HTTPException) as e:
        # OSError is the base of HTTPError, URLError, TimeoutError and
        # ConnectionResetError. HTTPException is caught as well because it is
        # not an OSError and would otherwise propagate out of the worker
        # thread and abort the whole deployment.
        #
        # ConnectionReset is *expected* when the chip restarts before flushing
        # the response; IncompleteRead is the same event seen after the status
        # line already arrived. Both are a soft pass, confirmed later via
        # /ota/status.
        soft_pass = isinstance(e, (ConnectionResetError, http.client.IncompleteRead))
        return soft_pass, time.monotonic() - t0, f"{type(e).__name__}: {e}"
def build_firmware() -> int:
    """Run ``idf.py build`` in the firmware directory.

    When ``IDF_PATH`` is already in the environment, ``idf.py`` is invoked
    directly. Otherwise ``~/esp/esp-idf-v5.2/export.sh`` is sourced in a
    child ``bash`` login shell before building.

    Returns:
        0 if the build succeeded; 2 if the ESP-IDF environment could not be
        found, the build tool could not be launched, or the build failed.
    """
    # Function-scope import so the block does not depend on a file-level
    # import being added elsewhere.
    import shlex

    log("building firmware via idf.py ...")
    try:
        if "IDF_PATH" not in os.environ:
            export = Path.home() / "esp" / "esp-idf-v5.2" / "export.sh"
            if not export.is_file():
                err("IDF_PATH not set and ~/esp/esp-idf-v5.2/export.sh not found")
                return 2
            # Source the env in a child shell. shlex.quote keeps paths that
            # contain quotes or other shell metacharacters from breaking, or
            # being interpreted by, the command line.
            script = (
                f". {shlex.quote(str(export))} >/dev/null 2>&1"
                f" && cd {shlex.quote(str(FW_DIR))}"
                " && idf.py build"
            )
            rc = subprocess.call(["bash", "-lc", script])
        else:
            rc = subprocess.call(["idf.py", "build"], cwd=str(FW_DIR))
    except OSError as exc:
        # e.g. FileNotFoundError when bash / idf.py is not on PATH, or the
        # firmware directory does not exist.
        err(f"could not launch build: {exc}")
        return 2
    if rc != 0:
        err("build failed")
        return 2
    return 0
# ---- main --------------------------------------------------------------------
def main(argv: list[str]) -> int:
    """Deploy the firmware image to the requested (or discovered) nodes.

    Steps: optionally build, read the image, resolve the target list, record
    each node's running partition, upload to all nodes in parallel, then,
    unless ``--no-verify`` is given, wait for the reboot and check that each
    node's ``running_partition`` changed.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        0 on success; 1 if any node failed or no nodes were found; 2 if the
        build failed or the firmware binary is missing. (Invalid arguments
        make argparse exit with status 2 via ``SystemExit``.)
    """
    ap = argparse.ArgumentParser(
        prog="ota-deploy.sh",
        description="Push esp32-csi-node.bin to one or more sensor nodes over WiFi.",
    )
    ap.add_argument("targets", nargs="*",
                    help="node IPs; auto-discover if omitted")
    ap.add_argument("--build", action="store_true",
                    help="idf.py build before deploying")
    ap.add_argument("--no-verify", action="store_true",
                    help="skip post-reboot /ota/status confirmation")
    args = ap.parse_args(argv)

    if args.build:
        rc = build_firmware()
        if rc != 0:
            return rc
    if not BIN_PATH.is_file():
        err(f"firmware binary not found: {BIN_PATH} — pass --build first")
        return 2
    payload = BIN_PATH.read_bytes()
    log(f"firmware: {BIN_PATH} ({len(payload)} bytes)")

    # dict.fromkeys drops repeated IPs while keeping the given order, so a
    # node named twice does not receive two concurrent uploads.
    targets = list(dict.fromkeys(args.targets)) or discover_nodes()
    if not targets:
        err("no nodes given and none discovered")
        return 1
    log(f"targets: {' '.join(targets)}")

    # snapshot before
    before: dict[str, str] = {}
    for ip in targets:
        st = get_ota_status(ip)
        if not st:
            warn(f"{ip}: not reachable before upload")
            before[ip] = "UNREACHABLE"
            continue
        before[ip] = st.get("running_partition", "UNKNOWN")
        log(f"{ip} before: running_partition={before[ip]} time={st.get('time')}")

    psk = os.environ.get("OTA_PSK") or None
    if psk:
        log("OTA_PSK set — sending Bearer token")

    # upload in parallel
    log("uploading in parallel ...")
    results: dict[str, tuple[bool, float, str]] = {}
    with cf.ThreadPoolExecutor(max_workers=max(2, len(targets))) as pool:
        futs = {pool.submit(upload_one, ip, payload, psk): ip for ip in targets}
        for fut in cf.as_completed(futs):
            ip = futs[fut]
            try:
                ok, dt, body = fut.result()
            except Exception as exc:
                # An unexpected error on one node is recorded as that node's
                # failure instead of aborting the uploads to every other node.
                ok, dt, body = False, 0.0, f"{type(exc).__name__}: {exc}"
            results[ip] = (ok, dt, body)
            tag = _c("32", "ok") if ok else _c("31", "ERR")
            log(f"{ip} upload {tag} in {dt:.1f}s body={body[:120]}")

    if args.no_verify:
        log("--no-verify — done")
        return 0 if all(v[0] for v in results.values()) else 1

    # verify
    log(f"waiting {REBOOT_WAIT_S}s for reboot ...")
    time.sleep(REBOOT_WAIT_S)
    # Placeholder values stored in `before` when no real partition was read.
    no_baseline = ("UNREACHABLE", "UNKNOWN")
    fail = False
    for ip in targets:
        new_st: dict | None = None
        for _ in range(VERIFY_RETRIES):
            new_st = get_ota_status(ip)
            if new_st:
                break
            time.sleep(VERIFY_DELAY_S)
        if not new_st:
            err(f"{ip}: not reachable after reboot — DEAD or panic loop")
            fail = True
            continue
        new_part = new_st.get("running_partition", "?")
        new_time = new_st.get("time", "?")
        old_part = before[ip]
        uploaded = results[ip][0]
        if new_part == old_part:
            err(f"{ip}: running_partition still {new_part} — OTA did NOT take "
                "(likely panic on first boot from new slot)")
            fail = True
        elif old_part in no_baseline and not uploaded:
            # With no real "before" value any partition looks like a change;
            # combined with a failed upload there is no evidence the OTA
            # happened, so do not report success.
            err(f"{ip}: upload failed and no pre-upload partition to compare "
                f"(now {new_part}) — cannot confirm OTA")
            fail = True
        else:
            log(f"{ip}: {old_part} → {_c('32', new_part)} (time={new_time}) ✓")
    return 1 if fail else 0
if __name__ == "__main__":
    # Script entry point: exit with main()'s return value, or with 130 when
    # the user interrupts the run with Ctrl-C.
    try:
        exit_status = main(sys.argv[1:])
    except KeyboardInterrupt:
        err("interrupted")
        exit_status = 130
    sys.exit(exit_status)