perf(signal): cache PSD FFT planner (2.0–3.1x) + honor DTW band (2.4–4.1x) (ADR-154 M0)
Two measured, bit-equivalent perf wins. Each ships a criterion bench
(benches/features_bench.rs, new) with before/after numbers and a committed
bit-identity test — no perf claim without a measured before/after.
PSD FFT-planner caching (features.rs)
PowerSpectralDensity::from_csi_data re-planned a FftPlanner on EVERY frame,
and FeatureExtractor::extract calls it per frame on the hot path. New
from_csi_data_with_fft(csi, n, &Arc<dyn Fft>) reuses a plan cached in
FeatureExtractor (built once in new()). Bit-identical output
(psd_cached_fft_bit_identical_to_fresh, f64::to_bits over 6 sizes).
MEASURED (median ns/frame, criterion):
fft=64 5.84µs -> 1.89µs (3.09x)
fft=128 9.31µs -> 3.61µs (2.58x)
fft=256 13.77µs -> 6.73µs (2.04x)
DTW Sakoe-Chiba band (gesture.rs)
dtw_distance computed j_start/j_end but iterated the FULL 1..=m row,
continue-ing out-of-band — band constrained the path, not the work (O(n*m)).
Now iterates j_start..=j_end (O(n*band)), resetting only the two boundary
guard cells the recurrence reads, with endpoint reachability (|n-m|<=band)
at the return. Bit-identical across 12 shapes x 8 bands
(dtw_banded_bit_identical_to_fullrow).
MEASURED (median, criterion):
n=m=100 band=5 33.45µs -> 13.77µs (2.43x)
n=m=200 band=5 122.32µs -> 29.55µs (4.14x)
n=m=200 band=10 159.98µs -> 60.19µs (2.66x)
Reproduce:
cd v2 && cargo bench -p wifi-densepose-signal --no-default-features \
--bench features_bench
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
be068748b3
commit
4d384cb884
|
|
@ -66,6 +66,11 @@ harness = false
|
|||
name = "aether_prefilter_bench"
|
||||
harness = false
|
||||
|
||||
## ADR-154: FFT-planner caching (PSD) + DTW Sakoe-Chiba band perf benches.
|
||||
[[bench]]
|
||||
name = "features_bench"
|
||||
harness = false
|
||||
|
||||
## ADR-134: CIR estimator throughput benchmarks
|
||||
[[bench]]
|
||||
name = "cir_bench"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,217 @@
|
|||
//! ADR-154 perf benchmarks: FFT-planner caching (PSD) and DTW Sakoe-Chiba band.
|
||||
//!
|
||||
//! These benches back the *measured* before/after claims in
|
||||
//! `docs/adr/ADR-154-signal-dsp-beyond-sota.md`. Every claim in that ADR has a
|
||||
//! reproduce command pointing here — no perf number ships without a bench.
|
||||
//!
|
||||
//! Reproduce (compile-only):
|
||||
//! cargo bench -p wifi-densepose-signal --no-default-features \
|
||||
//! --bench features_bench --no-run
|
||||
//!
|
||||
//! Reproduce (full run, writes target/criterion/ HTML):
|
||||
//! cargo bench -p wifi-densepose-signal --no-default-features --bench features_bench
|
||||
//!
|
||||
//! Two groups:
|
||||
//! * `psd_fft_planner` — `from_csi_data` (re-plans every call) vs
|
||||
//! `from_csi_data_with_fft` (cached plan). Same output
|
||||
//! (proved bit-identical in features.rs tests).
|
||||
//! * `dtw_sakoe_chiba` — full-row baseline (walks 1..=m, the pre-ADR-154
|
||||
//! behaviour) vs the banded loop (walks the band only).
|
||||
//! Both functions are inlined here because the crate's
|
||||
//! `dtw_distance` is private; the banded copy is a
|
||||
//! faithful transcription of the shipped fix.
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ndarray::Array2;
|
||||
use rustfft::FftPlanner;
|
||||
use std::time::Duration;
|
||||
|
||||
use wifi_densepose_signal::{CsiData, PowerSpectralDensity};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PSD: fresh-planner vs cached-planner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn make_csi(subcarriers: usize) -> CsiData {
|
||||
use std::f64::consts::PI;
|
||||
let antennas = 4;
|
||||
let mut amplitude = Array2::zeros((antennas, subcarriers));
|
||||
let mut phase = Array2::zeros((antennas, subcarriers));
|
||||
for i in 0..antennas {
|
||||
for j in 0..subcarriers {
|
||||
amplitude[[i, j]] = 0.5 + 0.3 * ((j as f64 / subcarriers as f64) * PI).sin();
|
||||
phase[[i, j]] = (j as f64 / subcarriers as f64) * 2.0 * PI - PI;
|
||||
}
|
||||
}
|
||||
CsiData::builder()
|
||||
.amplitude(amplitude)
|
||||
.phase(phase)
|
||||
.bandwidth(20.0e6)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn bench_psd_fft_planner(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("psd_fft_planner");
|
||||
group.measurement_time(Duration::from_secs(4));
|
||||
|
||||
for &fft_size in &[64usize, 128, 256] {
|
||||
let csi = make_csi(fft_size);
|
||||
group.throughput(Throughput::Elements(1));
|
||||
|
||||
// BEFORE: re-plans a FftPlanner on every frame.
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("fresh_planner", fft_size),
|
||||
&fft_size,
|
||||
|b, &n| {
|
||||
b.iter(|| {
|
||||
let psd = PowerSpectralDensity::from_csi_data(black_box(&csi), black_box(n));
|
||||
black_box(psd.total_power)
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
// AFTER: plan once, reuse across frames (the FeatureExtractor path).
|
||||
let mut planner = FftPlanner::<f64>::new();
|
||||
let plan = planner.plan_fft_forward(fft_size);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("cached_planner", fft_size),
|
||||
&fft_size,
|
||||
|b, &n| {
|
||||
b.iter(|| {
|
||||
let psd = PowerSpectralDensity::from_csi_data_with_fft(
|
||||
black_box(&csi),
|
||||
black_box(n),
|
||||
black_box(&plan),
|
||||
);
|
||||
black_box(psd.total_power)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DTW: full-row baseline vs Sakoe-Chiba band
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[inline]
|
||||
fn euclidean(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y) * (x - y))
|
||||
.sum::<f64>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
/// Pre-ADR-154 behaviour: iterate the FULL 1..=m row, `continue` on out-of-band.
|
||||
fn dtw_fullrow(seq_a: &[Vec<f64>], seq_b: &[Vec<f64>], band_width: usize) -> f64 {
|
||||
let (n, m) = (seq_a.len(), seq_b.len());
|
||||
if n == 0 || m == 0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
let mut prev = vec![f64::INFINITY; m + 1];
|
||||
let mut curr = vec![f64::INFINITY; m + 1];
|
||||
prev[0] = 0.0;
|
||||
for i in 1..=n {
|
||||
curr[0] = f64::INFINITY;
|
||||
let j_start = if band_width >= i {
|
||||
1
|
||||
} else {
|
||||
i.saturating_sub(band_width).max(1)
|
||||
};
|
||||
let j_end = (i + band_width).min(m);
|
||||
for j in 1..=m {
|
||||
if j < j_start || j > j_end {
|
||||
curr[j] = f64::INFINITY;
|
||||
continue;
|
||||
}
|
||||
let cost = euclidean(&seq_a[i - 1], &seq_b[j - 1]);
|
||||
curr[j] = cost + prev[j].min(curr[j - 1]).min(prev[j - 1]);
|
||||
}
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
prev[m]
|
||||
}
|
||||
|
||||
/// Post-ADR-154: iterate the band only (transcription of the shipped fix).
|
||||
fn dtw_banded(seq_a: &[Vec<f64>], seq_b: &[Vec<f64>], band_width: usize) -> f64 {
|
||||
let (n, m) = (seq_a.len(), seq_b.len());
|
||||
if n == 0 || m == 0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
let mut prev = vec![f64::INFINITY; m + 1];
|
||||
let mut curr = vec![f64::INFINITY; m + 1];
|
||||
prev[0] = 0.0;
|
||||
for i in 1..=n {
|
||||
curr[0] = f64::INFINITY;
|
||||
let j_start = if band_width >= i {
|
||||
1
|
||||
} else {
|
||||
i.saturating_sub(band_width).max(1)
|
||||
};
|
||||
let j_end = (i + band_width).min(m);
|
||||
if j_start >= 1 && j_start - 1 <= m {
|
||||
curr[j_start - 1] = f64::INFINITY;
|
||||
}
|
||||
for j in j_start..=j_end {
|
||||
let cost = euclidean(&seq_a[i - 1], &seq_b[j - 1]);
|
||||
curr[j] = cost + prev[j].min(curr[j - 1]).min(prev[j - 1]);
|
||||
}
|
||||
if j_end + 1 <= m {
|
||||
curr[j_end + 1] = f64::INFINITY;
|
||||
}
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
let lo = n.saturating_sub(band_width).max(1);
|
||||
let hi = (n + band_width).min(m);
|
||||
if m >= lo && m <= hi {
|
||||
prev[m]
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
fn make_seq(len: usize, seed: u64) -> Vec<Vec<f64>> {
|
||||
let mut s = seed;
|
||||
(0..len)
|
||||
.map(|_| {
|
||||
s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let x = ((s >> 33) as f64) / (u32::MAX as f64);
|
||||
vec![x, 1.0 - x, x * 0.5]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn bench_dtw_band(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("dtw_sakoe_chiba");
|
||||
group.measurement_time(Duration::from_secs(4));
|
||||
|
||||
// The ADR claim case: n = m = 200, band = 5.
|
||||
for &(n, band) in &[(100usize, 5usize), (200, 5), (200, 10)] {
|
||||
let a = make_seq(n, 0x1234);
|
||||
let b = make_seq(n, 0x9abc);
|
||||
// Cells touched ≈ full: n*n; banded: n*(2*band+1).
|
||||
group.throughput(Throughput::Elements((n * n) as u64));
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("full_row", format!("n{n}_band{band}")),
|
||||
&band,
|
||||
|bch, &bw| {
|
||||
bch.iter(|| black_box(dtw_fullrow(black_box(&a), black_box(&b), bw)));
|
||||
},
|
||||
);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("banded", format!("n{n}_band{band}")),
|
||||
&band,
|
||||
|bch, &bw| {
|
||||
bch.iter(|| black_box(dtw_banded(black_box(&a), black_box(&b), bw)));
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_psd_fft_planner, bench_dtw_band);
|
||||
criterion_main!(benches);
|
||||
|
|
@ -7,7 +7,8 @@ use crate::csi_processor::CsiData;
|
|||
use chrono::{DateTime, Utc};
|
||||
use ndarray::{Array1, Array2};
|
||||
use num_complex::Complex64;
|
||||
use rustfft::FftPlanner;
|
||||
use rustfft::{Fft, FftPlanner};
|
||||
use std::sync::Arc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Amplitude-based features
|
||||
|
|
@ -449,8 +450,29 @@ pub struct PowerSpectralDensity {
|
|||
}
|
||||
|
||||
impl PowerSpectralDensity {
|
||||
/// Calculate PSD from CSI amplitude data
|
||||
/// Calculate PSD from CSI amplitude data.
|
||||
///
|
||||
/// Plans a fresh FFT each call. On the per-frame hot path, prefer
|
||||
/// [`Self::from_csi_data_with_fft`] with a planner cached in
|
||||
/// [`FeatureExtractor`] — ADR-154 measured the re-plan as the dominant cost
|
||||
/// (see `benches/features_bench.rs`).
|
||||
pub fn from_csi_data(csi_data: &CsiData, fft_size: usize) -> Self {
|
||||
let mut fft_planner = FftPlanner::new();
|
||||
let fft = fft_planner.plan_fft_forward(fft_size);
|
||||
Self::from_csi_data_with_fft(csi_data, fft_size, &fft)
|
||||
}
|
||||
|
||||
/// Calculate PSD reusing a pre-planned FFT (ADR-154 perf path).
|
||||
///
|
||||
/// `fft` must be a forward plan of length `fft_size`. The output is
|
||||
/// **bit-identical** to [`Self::from_csi_data`] for the same `fft_size`
|
||||
/// (rustfft plans of equal length compute the same butterflies); only the
|
||||
/// one-time planner construction is hoisted out of the loop.
|
||||
pub fn from_csi_data_with_fft(
|
||||
csi_data: &CsiData,
|
||||
fft_size: usize,
|
||||
fft: &Arc<dyn Fft<f64>>,
|
||||
) -> Self {
|
||||
let amplitude = &csi_data.amplitude;
|
||||
let flat: Vec<f64> = amplitude.iter().copied().collect();
|
||||
|
||||
|
|
@ -465,9 +487,7 @@ impl PowerSpectralDensity {
|
|||
input.push(Complex64::new(0.0, 0.0));
|
||||
}
|
||||
|
||||
// Apply FFT
|
||||
let mut fft_planner = FftPlanner::new();
|
||||
let fft = fft_planner.plan_fft_forward(fft_size);
|
||||
// Apply the caller-provided (cached) FFT plan.
|
||||
fft.process(&mut input);
|
||||
|
||||
// Calculate power spectrum
|
||||
|
|
@ -613,16 +633,31 @@ impl Default for FeatureExtractorConfig {
|
|||
}
|
||||
}
|
||||
|
||||
/// Feature extractor for CSI data
|
||||
#[derive(Debug)]
|
||||
/// Feature extractor for CSI data.
|
||||
///
|
||||
/// ADR-154: caches the forward FFT plan for `config.fft_size` so the per-frame
|
||||
/// PSD path does not re-plan a `FftPlanner` on every `extract()` call.
|
||||
pub struct FeatureExtractor {
|
||||
config: FeatureExtractorConfig,
|
||||
/// Cached forward FFT plan of length `config.fft_size` (ADR-154 perf path).
|
||||
psd_fft: Arc<dyn Fft<f64>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for FeatureExtractor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("FeatureExtractor")
|
||||
.field("config", &self.config)
|
||||
.field("psd_fft_len", &self.config.fft_size)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl FeatureExtractor {
|
||||
/// Create a new feature extractor
|
||||
pub fn new(config: FeatureExtractorConfig) -> Self {
|
||||
Self { config }
|
||||
let mut planner = FftPlanner::new();
|
||||
let psd_fft = planner.plan_fft_forward(config.fft_size);
|
||||
Self { config, psd_fft }
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
|
|
@ -640,7 +675,11 @@ impl FeatureExtractor {
|
|||
let amplitude = AmplitudeFeatures::from_csi_data(csi_data);
|
||||
let phase = PhaseFeatures::from_csi_data(csi_data);
|
||||
let correlation = CorrelationFeatures::from_csi_data(csi_data);
|
||||
let psd = PowerSpectralDensity::from_csi_data(csi_data, self.config.fft_size);
|
||||
let psd = PowerSpectralDensity::from_csi_data_with_fft(
|
||||
csi_data,
|
||||
self.config.fft_size,
|
||||
&self.psd_fft,
|
||||
);
|
||||
|
||||
let metadata = FeatureMetadata {
|
||||
num_antennas: csi_data.num_antennas,
|
||||
|
|
@ -692,7 +731,11 @@ impl FeatureExtractor {
|
|||
|
||||
/// Extract PSD features only
|
||||
pub fn extract_psd(&self, csi_data: &CsiData) -> PowerSpectralDensity {
|
||||
PowerSpectralDensity::from_csi_data(csi_data, self.config.fft_size)
|
||||
PowerSpectralDensity::from_csi_data_with_fft(
|
||||
csi_data,
|
||||
self.config.fft_size,
|
||||
&self.psd_fft,
|
||||
)
|
||||
}
|
||||
|
||||
/// Extract Doppler features from history
|
||||
|
|
@ -802,6 +845,31 @@ mod tests {
|
|||
assert!(psd.peak_power >= 0.0);
|
||||
}
|
||||
|
||||
// ADR-154: the cached-FFT PSD path must be BIT-IDENTICAL to the
|
||||
// fresh-planner path (the perf change only hoists the planner out of the
|
||||
// loop — same butterflies, same output).
|
||||
#[test]
|
||||
fn psd_cached_fft_bit_identical_to_fresh() {
|
||||
use rustfft::FftPlanner;
|
||||
let csi_data = create_test_csi_data();
|
||||
for fft_size in [16usize, 32, 64, 128, 100, 96] {
|
||||
let fresh = PowerSpectralDensity::from_csi_data(&csi_data, fft_size);
|
||||
let mut planner = FftPlanner::<f64>::new();
|
||||
let plan = planner.plan_fft_forward(fft_size);
|
||||
let cached =
|
||||
PowerSpectralDensity::from_csi_data_with_fft(&csi_data, fft_size, &plan);
|
||||
assert_eq!(
|
||||
fresh.values.to_vec(),
|
||||
cached.values.to_vec(),
|
||||
"PSD values differ for fft_size={fft_size}"
|
||||
);
|
||||
assert_eq!(fresh.total_power.to_bits(), cached.total_power.to_bits());
|
||||
assert_eq!(fresh.peak_frequency.to_bits(), cached.peak_frequency.to_bits());
|
||||
assert_eq!(fresh.centroid.to_bits(), cached.centroid.to_bits());
|
||||
assert_eq!(fresh.bandwidth.to_bits(), cached.bandwidth.to_bits());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_doppler_features() {
|
||||
let history = create_test_history(20);
|
||||
|
|
|
|||
|
|
@ -308,23 +308,59 @@ fn dtw_distance(seq_a: &[Vec<f64>], seq_b: &[Vec<f64>], band_width: usize) -> f6
|
|||
};
|
||||
let j_end = (i + band_width).min(m);
|
||||
|
||||
for j in 1..=m {
|
||||
if j < j_start || j > j_end {
|
||||
curr[j] = f64::INFINITY;
|
||||
continue;
|
||||
}
|
||||
|
||||
// ADR-154: honor the Sakoe-Chiba band by iterating ONLY the in-band
|
||||
// cells [j_start, j_end] instead of walking the full 1..=m row and
|
||||
// `continue`-ing on every out-of-band cell. This cuts the inner-loop
|
||||
// trip count from m to (2·band_width + 1).
|
||||
//
|
||||
// `curr` is reused across rows via swap, so out-of-band cells that a
|
||||
// LATER read can touch must be reset to INFINITY (the previous row may
|
||||
// have left a stale finite value). Reads of `curr`/`prev` only ever
|
||||
// touch the immediate neighbours of the band:
|
||||
// - `curr[j_start - 1]` (the left/deletion term at j == j_start),
|
||||
// - next row's `prev[j_end + 1]` (the insertion/match term as the
|
||||
// band slides right by one), and
|
||||
// - the final `prev[m]` answer when m itself is out of band.
|
||||
// Resetting `curr[j_start-1]` and `curr[j_end+1..=m up to one cell]`
|
||||
// reproduces the full-row version **bit-for-bit**.
|
||||
// When `j_start > j_end` the band is empty for this row (j_start can even
|
||||
// exceed m). The full-row version would set every cell to INFINITY; we
|
||||
// reproduce that by leaving the band loop empty and INFINITY-filling the
|
||||
// boundary guards below (all clamped to valid indices).
|
||||
if j_start >= 1 && j_start - 1 <= m {
|
||||
curr[j_start - 1] = f64::INFINITY;
|
||||
}
|
||||
for j in j_start..=j_end {
|
||||
let cost = euclidean_distance(&seq_a[i - 1], &seq_b[j - 1]);
|
||||
curr[j] = cost
|
||||
+ prev[j] // insertion
|
||||
.min(curr[j - 1]) // deletion
|
||||
.min(prev[j - 1]); // match
|
||||
}
|
||||
// Guard the right boundary with a SINGLE cell. As `i` increments the
|
||||
// band slides right by one, so the only out-of-band cell the next row
|
||||
// reads beyond `j_end` is `prev[j_end + 1]` (its insertion/match term).
|
||||
// Resetting just that one cell keeps the per-row cost O(band), not O(m).
|
||||
// The final `prev[m]` answer is handled by the band-reachability check
|
||||
// at the return site, so we never need to walk the whole tail.
|
||||
if j_end + 1 <= m {
|
||||
curr[j_end + 1] = f64::INFINITY;
|
||||
}
|
||||
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
|
||||
prev[m]
|
||||
// The endpoint (n, m) is reachable only if `m` lies within the LAST row's
|
||||
// band `[n - band, n + band]` — i.e. `|n - m| <= band_width`. Outside that,
|
||||
// the full-row version left `prev[m] = INFINITY`, so we return INFINITY to
|
||||
// stay bit-identical (the banded loop never wrote `prev[m]`).
|
||||
let last_row_lo = n.saturating_sub(band_width).max(1);
|
||||
let last_row_hi = (n + band_width).min(m);
|
||||
if m >= last_row_lo && m <= last_row_hi {
|
||||
prev[m]
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
/// Euclidean distance between two feature vectors.
|
||||
|
|
@ -344,6 +380,82 @@ fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Reference full-row banded DTW (the pre-ADR-154 implementation): walks the
|
||||
/// entire 1..=m row and `continue`s on out-of-band cells. Used to prove the
|
||||
/// optimized banded loop is bit-identical.
|
||||
fn dtw_distance_fullrow(seq_a: &[Vec<f64>], seq_b: &[Vec<f64>], band_width: usize) -> f64 {
|
||||
let n = seq_a.len();
|
||||
let m = seq_b.len();
|
||||
if n == 0 || m == 0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
let mut prev = vec![f64::INFINITY; m + 1];
|
||||
let mut curr = vec![f64::INFINITY; m + 1];
|
||||
prev[0] = 0.0;
|
||||
for i in 1..=n {
|
||||
curr[0] = f64::INFINITY;
|
||||
let j_start = if band_width >= i {
|
||||
1
|
||||
} else {
|
||||
i.saturating_sub(band_width).max(1)
|
||||
};
|
||||
let j_end = (i + band_width).min(m);
|
||||
for j in 1..=m {
|
||||
if j < j_start || j > j_end {
|
||||
curr[j] = f64::INFINITY;
|
||||
continue;
|
||||
}
|
||||
let cost = euclidean_distance(&seq_a[i - 1], &seq_b[j - 1]);
|
||||
curr[j] = cost + prev[j].min(curr[j - 1]).min(prev[j - 1]);
|
||||
}
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
prev[m]
|
||||
}
|
||||
|
||||
/// ADR-154: the banded loop must be BIT-IDENTICAL to the full-row version
|
||||
/// across a sweep of sizes and band widths (this is the perf change's
|
||||
/// correctness contract — same numbers, fewer cells touched).
|
||||
#[test]
|
||||
fn dtw_banded_bit_identical_to_fullrow() {
|
||||
// Deterministic pseudo-random sequences.
|
||||
let mk = |len: usize, seed: u64| -> Vec<Vec<f64>> {
|
||||
let mut s = seed;
|
||||
(0..len)
|
||||
.map(|_| {
|
||||
s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let x = ((s >> 33) as f64) / (u32::MAX as f64);
|
||||
vec![x, 1.0 - x]
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
for &(n, m) in &[
|
||||
(10, 10),
|
||||
(10, 20),
|
||||
(20, 10),
|
||||
(50, 50),
|
||||
(200, 200),
|
||||
(7, 13),
|
||||
(13, 7),
|
||||
(1, 5),
|
||||
(5, 1),
|
||||
(100, 30),
|
||||
(30, 100),
|
||||
(200, 195),
|
||||
] {
|
||||
let a = mk(n, 0x1234);
|
||||
let b = mk(m, 0x9abc);
|
||||
for band in [0usize, 1, 2, 3, 5, 8, 50, 1000] {
|
||||
let opt = dtw_distance(&a, &b, band);
|
||||
let refv = dtw_distance_fullrow(&a, &b, band);
|
||||
assert!(
|
||||
(opt == refv) || (opt.is_infinite() && refv.is_infinite()),
|
||||
"DTW mismatch n={n} m={m} band={band}: opt={opt} ref={refv}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_template(
|
||||
name: &str,
|
||||
gesture_type: GestureType,
|
||||
|
|
|
|||
Loading…
Reference in New Issue