feat: scalable WiFlow model with 4 size presets (#362)

Add --scale flag with 4 presets for dataset-appropriate sizing:

  lite:   ~190K params, 2 TCN blocks k=3  (trains in seconds)
  small:  ~200K params, 4 TCN blocks k=5  (trains in minutes)
  medium: ~800K params, 4 TCN blocks k=7  (trains in ~15 min)
  full:   ~7.7M params, 4 TCN blocks k=7  (trains in hours)

Refactored model to use dynamic TCN block count, kernel size,
channel widths, hidden dim, and SPSA perturbation count — all
driven by the scale preset. Default is 'lite' for fast iteration.

Validated: lite model completes 30 epochs on 265 samples in ~2 min
on Windows CPU (vs stuck at epoch 1 with full model).

Scale up with: --scale small|medium|full as dataset grows.

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruv 2026-04-06 14:55:35 -04:00
parent d09baa6a09
commit 327d0d13f6
1 changed files with 57 additions and 32 deletions

View File

@ -73,6 +73,7 @@ const { values: args } = parseArgs({
lr: { type: 'string', default: '0.0001' },
'skip-contrastive': { type: 'boolean', default: false },
'eval-split': { type: 'string', default: '0.2' },
scale: { type: 'string', short: 's', default: 'lite' },
verbose: { type: 'boolean', short: 'v', default: false },
},
strict: true,
@ -123,6 +124,24 @@ const CONFIG = {
temporalWeight: 0.1,
};
// ---------------------------------------------------------------------------
// Model scale presets: lite → small → medium → full
// lite: ~45K params, trains in seconds (good for <1K samples)
// small: ~200K params, trains in minutes (good for 1K-10K samples)
// medium: ~800K params, trains in ~15 min (good for 10K-50K samples)
// full: ~7.7M params, trains in hours (good for 50K+ samples)
// ---------------------------------------------------------------------------
const SCALE_PRESETS = {
lite: { tcnChannels: [32, 32, 32, 32], hiddenDim: 256, tcnBlocks: 2, kernel: 3, spsaK: 1 },
small: { tcnChannels: [64, 64, 48, 32], hiddenDim: 512, tcnBlocks: 4, kernel: 5, spsaK: 2 },
medium: { tcnChannels: [128, 128, 96, 64], hiddenDim: 1024, tcnBlocks: 4, kernel: 7, spsaK: 3 },
full: { tcnChannels: [256, 256, 192, 128], hiddenDim: 2048, tcnBlocks: 4, kernel: 7, spsaK: 3 },
};
const scaleKey = args.scale || 'lite';
const SCALE = SCALE_PRESETS[scaleKey] || SCALE_PRESETS.lite;
console.log(`Model scale: ${scaleKey} (${JSON.stringify(SCALE)})`);
// Compute phase epochs
const totalForPhases = CONFIG.skipContrastive
? CONFIG.totalEpochs
@ -853,33 +872,40 @@ class Linear {
* Sigmoid to [0, 1]
*/
class WiFlowSupervisedModel {
constructor(inputDim, timeSteps, numKeypoints, seed) {
constructor(inputDim, timeSteps, numKeypoints, seed, scale) {
this.inputDim = inputDim;
this.timeSteps = timeSteps;
this.numKeypoints = numKeypoints || 17;
this.outDim = this.numKeypoints * 2;
this.scale = scale || SCALE;
const rng = createRng(seed || 42);
const ch = this.scale.tcnChannels;
const k = this.scale.kernel;
// TCN blocks: inputDim -> 256 -> 256 -> 192 -> 128
this.tcn1 = new TCNBlock(inputDim, 256, 7, 1, rng);
this.tcn2 = new TCNBlock(256, 256, 7, 2, rng);
this.tcn3 = new TCNBlock(256, 192, 7, 4, rng);
this.tcn4 = new TCNBlock(192, 128, 7, 8, rng);
// TCN blocks: inputDim -> ch[0] -> ch[1] -> ch[2] -> ch[3]
this.tcnBlocks = [];
let prevCh = inputDim;
const dilations = [1, 2, 4, 8];
const nBlocks = Math.min(this.scale.tcnBlocks, ch.length);
for (let i = 0; i < nBlocks; i++) {
this.tcnBlocks.push(new TCNBlock(prevCh, ch[i], k, dilations[i], rng));
prevCh = ch[i];
}
// Flatten: 128 * timeSteps -> linear -> 34
const flatDim = 128 * timeSteps;
this.fc1 = new Linear(flatDim, 2048, rng);
this.fc2 = new Linear(2048, this.outDim, rng);
// Flatten: lastCh * timeSteps -> hidden -> 34
const flatDim = prevCh * timeSteps;
const hiddenDim = this.scale.hiddenDim;
this.fc1 = new Linear(flatDim, hiddenDim, rng);
this.fc2 = new Linear(hiddenDim, this.outDim, rng);
this._totalParams = null;
}
totalParams() {
if (this._totalParams === null) {
this._totalParams = this.tcn1.numParams() + this.tcn2.numParams() +
this.tcn3.numParams() + this.tcn4.numParams() +
this.fc1.numParams() + this.fc2.numParams();
this._totalParams = this.fc1.numParams() + this.fc2.numParams();
for (const b of this.tcnBlocks) this._totalParams += b.numParams();
}
return this._totalParams;
}
@ -892,14 +918,11 @@ class WiFlowSupervisedModel {
forward(csi) {
const T = this.timeSteps;
// TCN stages
let x = this.tcn1.forward(csi, T);
x = this.tcn2.forward(x, T);
x = this.tcn3.forward(x, T);
x = this.tcn4.forward(x, T);
// Flatten: [128, T] -> [128*T]
// x is already flat as [128 * T]
// TCN stages (dynamic block count based on scale)
let x = csi;
for (const block of this.tcnBlocks) {
x = block.forward(x, T);
}
// FC layers with ReLU
let h = this.fc1.forward(x);
@ -920,10 +943,10 @@ class WiFlowSupervisedModel {
*/
encode(csi) {
const T = this.timeSteps;
let x = this.tcn1.forward(csi, T);
x = this.tcn2.forward(x, T);
x = this.tcn3.forward(x, T);
x = this.tcn4.forward(x, T);
let x = csi;
for (const block of this.tcnBlocks) {
x = block.forward(x, T);
}
let h = this.fc1.forward(x);
relu(h);
@ -963,10 +986,9 @@ class WiFlowSupervisedModel {
params.push({ weight: linear.bias, mom: linear.biasMom, name: `${prefix}.bias` });
};
addTCN(this.tcn1, 'tcn1');
addTCN(this.tcn2, 'tcn2');
addTCN(this.tcn3, 'tcn3');
addTCN(this.tcn4, 'tcn4');
for (let i = 0; i < this.tcnBlocks.length; i++) {
addTCN(this.tcnBlocks[i], `tcn${i}`);
}
addLinear(this.fc1, 'fc1');
addLinear(this.fc2, 'fc2');
@ -1259,9 +1281,12 @@ async function main() {
// Step 2: Initialize model
// -----------------------------------------------------------------------
console.log('[2/6] Initializing WiFlow supervised model...');
const model = new WiFlowSupervisedModel(inputDim, T, CONFIG.numKeypoints, 42);
const model = new WiFlowSupervisedModel(inputDim, T, CONFIG.numKeypoints, 42, SCALE);
const ch = SCALE.tcnChannels.slice(0, SCALE.tcnBlocks);
const lastCh = ch[ch.length - 1];
console.log(` Scale: ${scaleKey}`);
console.log(` Parameters: ${model.totalParams().toLocaleString()}`);
console.log(` Architecture: TCN(${inputDim}->256->256->192->128, k=7, d=[1,2,4,8]) -> FC(${128 * T}->2048->34)`);
console.log(` Architecture: TCN(${inputDim}->${ch.join('->')}, k=${SCALE.kernel}, d=[1,2,4,8]) -> FC(${lastCh * T}->${SCALE.hiddenDim}->34)`);
console.log('');
const trainingLog = {
@ -1330,7 +1355,7 @@ async function main() {
};
const batch = shuffledTrain.slice(b, batchEnd);
const grad = multiSpsaGrad(model, batch, lossFn, p, rng, 3);
const grad = multiSpsaGrad(model, batch, lossFn, p, rng, SCALE.spsaK);
sgdStep(p, grad, lr, CONFIG.momentum);
}