wifi-densepose/scripts/train-ruvllm.js

1072 lines
37 KiB
JavaScript

#!/usr/bin/env node
/**
* WiFi-DensePose CSI Training Pipeline using ruvllm
*
* Complete training, refinement, and quantization pipeline for CSI sensing models.
* Uses ruvllm's ContrastiveTrainer, TrainingPipeline, LoRA, EWC, and SafeTensors export.
*
* Usage (quote glob patterns so the shell passes them through unexpanded;
* the script expands `*` itself):
* node scripts/train-ruvllm.js --data 'data/recordings/pretrain-*.csi.jsonl'
* node scripts/train-ruvllm.js --data data/recordings/pretrain-1775182186.csi.jsonl --benchmark
* node scripts/train-ruvllm.js --data 'data/recordings/*.csi.jsonl' --output models/csi-v1
*
* ADR: docs/adr/ADR-071-ruvllm-training-pipeline.md
*/
'use strict';
const fs = require('fs');
const path = require('path');
const { parseArgs } = require('util');
// ---------------------------------------------------------------------------
// Resolve ruvllm from vendor tree — use compiled JS output
// ---------------------------------------------------------------------------
const RUVLLM_PATH = path.resolve(__dirname, '..', 'vendor', 'ruvector', 'npm', 'packages', 'ruvllm', 'src');
const {
ContrastiveTrainer,
cosineSimilarity,
tripletLoss,
infoNCELoss,
} = require(path.join(RUVLLM_PATH, 'contrastive.js'));
const {
TrainingPipeline,
} = require(path.join(RUVLLM_PATH, 'training.js'));
const {
LoraAdapter,
LoraManager,
} = require(path.join(RUVLLM_PATH, 'lora.js'));
const {
EwcManager,
ReasoningBank,
SonaCoordinator,
} = require(path.join(RUVLLM_PATH, 'sona.js'));
const {
SafeTensorsWriter,
ModelExporter,
DatasetExporter,
} = require(path.join(RUVLLM_PATH, 'export.js'));
// ---------------------------------------------------------------------------
// CLI argument parsing
// ---------------------------------------------------------------------------

/**
 * Parse a CLI option value as a positive base-10 integer, or print an error
 * and exit. Without this check a typo such as `--epochs ten` became NaN and
 * flowed silently into the trainers.
 *
 * @param {string} flag - Option name (without dashes), used in the message.
 * @param {string} raw - Raw option value from parseArgs.
 * @returns {number} The parsed integer (>= 1).
 */
function parsePositiveIntArg(flag, raw) {
  const value = Number.parseInt(raw, 10);
  if (!Number.isInteger(value) || value < 1) {
    console.error(`Invalid --${flag} value "${raw}": expected a positive integer`);
    process.exit(1);
  }
  return value;
}

let parsedArgs;
try {
  parsedArgs = parseArgs({
    options: {
      data: { type: 'string', short: 'd' },
      output: { type: 'string', short: 'o', default: 'models/csi-ruvllm' },
      benchmark: { type: 'boolean', short: 'b', default: false },
      epochs: { type: 'string', short: 'e', default: '20' },
      'batch-size': { type: 'string', default: '32' },
      'lora-rank': { type: 'string', default: '4' },
      'quantize-bits': { type: 'string', default: '4' },
      verbose: { type: 'boolean', short: 'v', default: false },
    },
    strict: true,
  });
} catch (err) {
  // Strict mode rejects unknown options and stray positionals. The common
  // cause of the latter is an unquoted glob (--data data/*.csi.jsonl) that a
  // POSIX shell expanded into several arguments.
  console.error(`Error: ${err.message}`);
  if (err.code === 'ERR_PARSE_ARGS_UNEXPECTED_POSITIONAL') {
    console.error('Hint: quote the --data pattern (e.g. --data "data/recordings/*.csi.jsonl") so the shell does not expand it into several arguments.');
  }
  console.error('Usage: node scripts/train-ruvllm.js --data <path-to-csi-jsonl> [--output dir] [--benchmark]');
  process.exit(1);
}
const { values: args } = parsedArgs;

if (!args.data) {
  console.error('Usage: node scripts/train-ruvllm.js --data <path-to-csi-jsonl> [--output dir] [--benchmark]');
  process.exit(1);
}

const CONFIG = {
  dataGlob: args.data,
  outputDir: args.output,
  benchmark: args.benchmark,
  epochs: parsePositiveIntArg('epochs', args.epochs),
  batchSize: parsePositiveIntArg('batch-size', args['batch-size']),
  loraRank: parsePositiveIntArg('lora-rank', args['lora-rank']),
  quantizeBits: parsePositiveIntArg('quantize-bits', args['quantize-bits']),
  verbose: args.verbose,
  // Contrastive training hyperparameters
  margin: 0.3,
  temperature: 0.07,
  hardNegativeRatio: 0.7,
  learningRate: 0.001,
  // Temporal window thresholds (seconds)
  positiveWindowSec: 1.0,
  negativeWindowSec: 30.0,
  // Feature dimensions
  inputDim: 8, // 8-dim CSI feature vector
  hiddenDim: 64, // intermediate
  embeddingDim: 128, // output embedding
};
// ---------------------------------------------------------------------------
// Data loading
// ---------------------------------------------------------------------------
/**
 * Parse a CSI JSONL recording into typed frame lists.
 *
 * Each non-blank line should be a JSON object whose `type` is 'feature',
 * 'vitals' or 'raw_csi'; objects of any other type are ignored.
 *
 * Some lines are skipped and counted instead of aborting the load. These are
 * lines that are not valid JSON, are not objects, or are known frames without
 * a finite numeric `timestamp`. Feature frames without a `features` array are
 * skipped too. Downstream code sorts and formats frames by timestamp and
 * indexes into `features`, so such frames would crash it later.
 *
 * @param {string} filePath - Path to a .jsonl recording.
 * @returns {{features: object[], vitals: object[], rawCsi: object[], skipped: number}}
 *   Frames in file order, plus the number of malformed/unusable lines.
 */
function loadCsiData(filePath) {
  const features = [];
  const vitals = [];
  const rawCsi = [];
  let skipped = 0;
  const content = fs.readFileSync(filePath, 'utf-8');
  for (const line of content.split('\n')) {
    if (!line.trim()) continue;

    let frame;
    try {
      frame = JSON.parse(line);
    } catch {
      skipped++;
      continue;
    }
    if (frame === null || typeof frame !== 'object') {
      skipped++;
      continue;
    }

    const isKnownType = frame.type === 'feature' || frame.type === 'vitals' || frame.type === 'raw_csi';
    if (!isKnownType) continue; // other frame types are intentionally ignored
    if (!Number.isFinite(frame.timestamp)) {
      skipped++;
      continue;
    }

    switch (frame.type) {
      case 'feature':
        if (!Array.isArray(frame.features)) {
          skipped++;
          break;
        }
        features.push({
          timestamp: frame.timestamp,
          nodeId: frame.node_id,
          features: frame.features, // 8-dim float array
          rssi: frame.rssi,
          seq: frame.seq,
        });
        break;
      case 'vitals':
        vitals.push({
          timestamp: frame.timestamp,
          nodeId: frame.node_id,
          breathingBpm: frame.breathing_bpm,
          heartrateBpm: frame.heartrate_bpm,
          nPersons: frame.n_persons,
          motionEnergy: frame.motion_energy,
          presenceScore: frame.presence_score,
          rssi: frame.rssi,
        });
        break;
      case 'raw_csi':
        rawCsi.push({
          timestamp: frame.timestamp,
          nodeId: frame.node_id,
          subcarriers: frame.subcarriers,
          iqHex: frame.iq_hex,
          rssi: frame.rssi,
        });
        break;
    }
  }
  if (skipped > 0) {
    console.warn(` Skipped ${skipped} malformed line(s) in ${path.basename(filePath)}`);
  }
  return { features, vitals, rawCsi, skipped };
}
/**
 * Expand a simple file-name glob into a sorted list of existing paths,
 * portably on Unix and Windows and without a glob library.
 *
 * Wildcards are honoured only in the final path component. `*` matches any
 * run of characters and `?` matches one character. Every other character,
 * including regex metacharacters such as `.`, `(`, `+`, matches literally. A
 * pattern without wildcards is returned as-is when the path exists.
 *
 * @param {string} pattern - e.g. "data/recordings/pretrain-*.csi.jsonl".
 * @returns {string[]} Matching paths in lexicographic order (deterministic
 *   regardless of directory listing order); [] when nothing matches or the
 *   directory is missing.
 */
function resolveGlob(pattern) {
  if (!pattern.includes('*') && !pattern.includes('?')) {
    return fs.existsSync(pattern) ? [pattern] : [];
  }
  const dir = path.dirname(pattern);
  const base = path.basename(pattern);
  // Escape regex metacharacters first (everything except the glob wildcards),
  // then translate the wildcards.
  const source = base
    .replace(/[.+^${}()|[\]\\]/g, '\\$&')
    .replace(/\*/g, '.*')
    .replace(/\?/g, '.');
  const regex = new RegExp(`^${source}$`);
  if (!fs.existsSync(dir) || !fs.statSync(dir).isDirectory()) return [];
  return fs.readdirSync(dir)
    .filter((name) => regex.test(name))
    .sort()
    .map((name) => path.join(dir, name));
}
// ---------------------------------------------------------------------------
// Embedding encoder (simulates 8 -> 64 -> 128 FC network)
// ---------------------------------------------------------------------------
/**
 * Two-layer fully connected encoder: inputDim -> hiddenDim (ReLU) -> outputDim,
 * followed by L2 normalisation of the output.
 *
 * Weights come from a seeded xorshift32 generator, so the same seed always
 * produces the same encoder. Each weight is uniform in
 * [-0.5, 0.5) * sqrt(2 / fanIn). That is fan-in scaling in the spirit of
 * He/Kaiming init, but not the exact Kaiming-uniform bound. Biases start at
 * zero. Weight matrices are row-major Float64Array(rows * cols), so
 * w1[i * hiddenDim + j] connects input i to hidden unit j.
 */
class CsiEncoder {
  /**
   * @param {number} inputDim - Length of the input feature vector.
   * @param {number} hiddenDim - Width of the ReLU hidden layer.
   * @param {number} outputDim - Embedding length.
   * @param {number} [seed=42] - PRNG seed. It is coerced to a 32-bit integer.
   *   0 is remapped to a fixed non-zero state, because 0 is a fixed point of
   *   xorshift and would give all-zero weights.
   */
  constructor(inputDim, hiddenDim, outputDim, seed = 42) {
    this.inputDim = inputDim;
    this.hiddenDim = hiddenDim;
    this.outputDim = outputDim;
    const rng = this._createRng(seed);
    this.w1 = this._initMatrix(inputDim, hiddenDim, rng, inputDim);
    this.b1 = new Float64Array(hiddenDim);
    this.w2 = this._initMatrix(hiddenDim, outputDim, rng, hiddenDim);
    this.b2 = new Float64Array(outputDim);
  }

  /**
   * Forward pass: input -> unit-length embedding (plain Array of outputDim numbers).
   * Missing or falsy input entries (undefined, NaN, 0) count as 0. If the
   * pre-normalisation output is all zeros it is returned unchanged instead of
   * dividing by zero.
   */
  encode(input) {
    // Layer 1: input @ w1 + b1, then ReLU
    const hidden = new Float64Array(this.hiddenDim);
    for (let j = 0; j < this.hiddenDim; j++) {
      let sum = this.b1[j];
      for (let i = 0; i < this.inputDim; i++) {
        sum += (input[i] || 0) * this.w1[i * this.hiddenDim + j];
      }
      hidden[j] = Math.max(0, sum); // ReLU
    }
    // Layer 2: hidden @ w2 + b2
    const output = new Float64Array(this.outputDim);
    for (let j = 0; j < this.outputDim; j++) {
      let sum = this.b2[j];
      for (let i = 0; i < this.hiddenDim; i++) {
        sum += hidden[i] * this.w2[i * this.outputDim + j];
      }
      output[j] = sum;
    }
    // L2 normalize
    let norm = 0;
    for (let i = 0; i < output.length; i++) norm += output[i] * output[i];
    norm = Math.sqrt(norm) || 1;
    const result = new Array(this.outputDim);
    for (let i = 0; i < this.outputDim; i++) result[i] = output[i] / norm;
    return result;
  }

  /**
   * Encode a batch of inputs.
   */
  encodeBatch(inputs) {
    return inputs.map(input => this.encode(input));
  }

  /**
   * Marsaglia xorshift32 (shift triple 13/17/5).
   * Returns a function yielding values uniform in [-0.5, 0.5).
   */
  _createRng(seed) {
    let s = seed | 0;
    if (s === 0) s = 0x2545f491; // 0 would stay 0 forever
    return () => {
      s ^= s << 13;
      // Logical shift is required. Arithmetic `>>` copies the sign bit, which
      // makes `s ^= s >> 17` always clear bit 31. That step then loses state
      // and the generator is no longer xorshift32.
      s ^= s >>> 17;
      s ^= s << 5;
      return ((s >>> 0) / 4294967296) - 0.5;
    };
  }

  /**
   * Row-major rows x cols matrix with entries rng() * sqrt(2 / fanIn).
   */
  _initMatrix(rows, cols, rng, fanIn) {
    const scale = Math.sqrt(2.0 / fanIn);
    const arr = new Float64Array(rows * cols);
    for (let i = 0; i < arr.length; i++) {
      arr[i] = rng() * scale;
    }
    return arr;
  }
}
// ---------------------------------------------------------------------------
// Triplet generation
// ---------------------------------------------------------------------------
/**
 * Generate contrastive triplets from feature frames.
 *
 * Strategies, in output order:
 *   1+2. Temporal, per node. The anchor and positive are within
 *        `positiveWindowSec`, looking at most 19 frames ahead. The negative
 *        is the first frame, in time order, that is at least
 *        `negativeWindowSec` from the anchor. The triplet is marked hard when
 *        that negative is closer than 2 * negativeWindowSec.
 *   3.   Cross-node. Each frame of the first node (in Object.keys order) is
 *        paired with the nearest-in-time frame of the second node, if that
 *        frame is closer than `positiveWindowSec`. The negative is the first
 *        second-node frame at least `negativeWindowSec` away. Any further
 *        nodes are not paired.
 *   5.   Transition hard negatives. A transition time is recorded when
 *        motionEnergy jumps by more than `transitionEnergyDelta` (default
 *        2.0) between consecutive vitals readings of the same node. For each
 *        of the first `maxTransitions` (default 50) transition times and
 *        every node:
 *          - the positive is the last frame before the transition;
 *          - the negative is the first frame after it;
 *          - the anchor is the frame 5 positions before the positive, or the
 *            first frame.
 *        The triplet is skipped when the anchor would be the positive itself.
 *
 * NOTE(review): the temporal and cross-node negative is the *first*
 * qualifying frame. So every anchor more than negativeWindowSec into a
 * recording shares one negative (the node's earliest frame). Consider
 * sampling negatives more widely.
 *
 * @param {object[]} features - Frames with nodeId, timestamp, features.
 * @param {object[]} vitals - Frames with nodeId, timestamp, motionEnergy.
 * @param {{positiveWindowSec: number, negativeWindowSec: number,
 *          transitionEnergyDelta?: number, maxTransitions?: number}} config
 * @returns {object[]} Triplets: {anchor, positive, negative, isHard, type,
 *   anchorLabel, posLabel, negLabel}.
 */
function generateTriplets(features, vitals, config) {
  const posWin = config.positiveWindowSec;
  const negWin = config.negativeWindowSec;
  const energyDelta = config.transitionEnergyDelta ?? 2.0;
  const maxTransitions = config.maxTransitions ?? 50;
  const triplets = [];

  // Lower-bound binary search over a timestamp-sorted array. Returns the
  // index of the first element satisfying `pred`, which must be monotonic
  // (false...false, true...true). Returns frames.length when none does.
  const firstIndexWhere = (frames, pred) => {
    let lo = 0;
    let hi = frames.length;
    while (lo < hi) {
      const mid = (lo + hi) >>> 1;
      if (pred(frames[mid])) hi = mid;
      else lo = mid + 1;
    }
    return lo;
  };

  // First frame, in time order, at least negWin seconds from t, or null.
  // This gives the same answer as a linear scan from index 0. If the
  // earliest frame is far enough, it wins. Otherwise no frame before t can
  // be far enough, so the answer is the first frame at or after t + negWin.
  const firstDistantFrame = (frames, t) => {
    if (frames.length === 0) return null;
    if (Math.abs(frames[0].timestamp - t) >= negWin) return frames[0];
    const k = firstIndexWhere(frames, (f) => f.timestamp - t >= negWin);
    return k < frames.length ? frames[k] : null;
  };

  // Frame closest in time to t, with the earliest frame winning on ties,
  // together with its distance. O(log n) instead of a full scan per query.
  const nearestFrame = (frames, t) => {
    const right = firstIndexWhere(frames, (f) => f.timestamp >= t);
    let best = null;
    let bestDist = Infinity;
    if (right > 0) {
      // Closest earlier frame; step back to the first of any equal-timestamp run.
      const leftTime = frames[right - 1].timestamp;
      best = frames[firstIndexWhere(frames, (f) => f.timestamp >= leftTime)];
      bestDist = Math.abs(leftTime - t);
    }
    if (right < frames.length) {
      const d = Math.abs(frames[right].timestamp - t);
      if (d < bestDist) {
        best = frames[right];
        bestDist = d;
      }
    }
    return { frame: best, dist: bestDist };
  };

  // Index features by node. The object has a null prototype, so keys such as
  // "__proto__" are safe, and Object.keys keeps the usual ordering
  // (integer-like IDs ascending). Keys are used as-is instead of being
  // coerced with Number(), which turned non-numeric node IDs into NaN.
  const byNode = Object.create(null);
  for (const f of features) {
    if (!byNode[f.nodeId]) byNode[f.nodeId] = [];
    byNode[f.nodeId].push(f);
  }
  const nodeIds = Object.keys(byNode);
  for (const nid of nodeIds) {
    byNode[nid].sort((a, b) => a.timestamp - b.timestamp);
  }

  // Strategy 1 + 2: Temporal positive/negative within same node
  for (const nid of nodeIds) {
    const frames = byNode[nid];
    for (let i = 0; i < frames.length; i++) {
      const anchor = frames[i];
      // The negative depends only on the anchor, so look it up once.
      const neg = firstDistantFrame(frames, anchor.timestamp);
      if (!neg) continue;
      const isHard = Math.abs(neg.timestamp - anchor.timestamp) < negWin * 2;
      const lastJ = Math.min(frames.length, i + 20);
      for (let j = i + 1; j < lastJ; j++) {
        const candidate = frames[j];
        // Frames are sorted: once one falls outside the window, all later ones do.
        if (!(Math.abs(candidate.timestamp - anchor.timestamp) <= posWin)) break;
        triplets.push({
          anchor: anchor.features,
          positive: candidate.features,
          negative: neg.features,
          isHard,
          type: 'temporal',
          anchorLabel: `node${nid}-t${anchor.timestamp.toFixed(2)}`,
          posLabel: `node${nid}-t${candidate.timestamp.toFixed(2)}`,
          negLabel: `node${nid}-t${neg.timestamp.toFixed(2)}`,
        });
      }
    }
  }

  // Strategy 3: Cross-node positive (same timestamp, different nodes)
  if (nodeIds.length >= 2) {
    const node1Frames = byNode[nodeIds[0]];
    const node2Frames = byNode[nodeIds[1]];
    for (const f1 of node1Frames) {
      const { frame: bestMatch, dist } = nearestFrame(node2Frames, f1.timestamp);
      if (!bestMatch || !(dist < posWin)) continue;
      // Cross-node negative: a second-node frame far away in time.
      const f2neg = firstDistantFrame(node2Frames, f1.timestamp);
      if (!f2neg) continue;
      triplets.push({
        anchor: f1.features,
        positive: bestMatch.features,
        negative: f2neg.features,
        isHard: false,
        type: 'cross-node',
        anchorLabel: `node${f1.nodeId}-t${f1.timestamp.toFixed(2)}`,
        posLabel: `node${bestMatch.nodeId}-t${bestMatch.timestamp.toFixed(2)}`,
        negLabel: `node${f2neg.nodeId}-t${f2neg.timestamp.toFixed(2)}`,
      });
    }
  }

  // Strategy 5: Hard negatives near scenario transitions.
  // Motion-energy deltas are taken between consecutive readings of the SAME
  // node. Diffing an interleaved multi-node series compared different sensors
  // and reported spurious transitions.
  const vitalsByNode = new Map();
  for (const v of vitals) {
    let series = vitalsByNode.get(v.nodeId);
    if (!series) {
      series = [];
      vitalsByNode.set(v.nodeId, series);
    }
    series.push(v);
  }
  const transitionTimes = [];
  for (const series of vitalsByNode.values()) {
    series.sort((a, b) => a.timestamp - b.timestamp);
    for (let i = 1; i < series.length; i++) {
      if (Math.abs(series[i].motionEnergy - series[i - 1].motionEnergy) > energyDelta) {
        transitionTimes.push(series[i].timestamp);
      }
    }
  }
  transitionTimes.sort((a, b) => a - b);

  for (const transTime of transitionTimes.slice(0, maxTransitions)) {
    for (const nid of nodeIds) {
      const frames = byNode[nid];
      // Last frame strictly before / first frame strictly after the transition.
      const beforeIdx = firstIndexWhere(frames, (f) => f.timestamp >= transTime) - 1;
      const afterIdx = firstIndexWhere(frames, (f) => f.timestamp > transTime);
      if (beforeIdx < 0 || afterIdx >= frames.length) continue;
      const anchorIdx = Math.max(0, beforeIdx - 5);
      // Too close to the start: the anchor would duplicate the positive.
      if (anchorIdx === beforeIdx) continue;
      triplets.push({
        anchor: frames[anchorIdx].features,
        positive: frames[beforeIdx].features,
        negative: frames[afterIdx].features,
        isHard: true,
        type: 'transition-hard',
        anchorLabel: `node${nid}-pre-transition`,
        posLabel: `node${nid}-before`,
        negLabel: `node${nid}-after`,
      });
    }
  }
  return triplets;
}
// ---------------------------------------------------------------------------
// Quantization (TurboQuant simulation)
// ---------------------------------------------------------------------------
/**
 * Asymmetric (min/max affine) quantization of weights to N-bit integer codes.
 *
 * Each weight maps to q = round(w / scale) + zeroPoint. q is clamped to the
 * signed range [-2^(bits-1), 2^(bits-1) - 1] and stored as the unsigned
 * offset q - minVal. Codes are NOT bit-packed: every value takes one byte
 * regardless of `bits`, so storage is 4x smaller than fp32 for all supported
 * widths, not 32/bits.
 *
 * @param {Float32Array|number[]} weights - Values to quantize.
 * @param {number} bits - Code width, an integer in [1, 8].
 * @returns {{quantized: Uint8Array, scale: number, zeroPoint: number, bits: number,
 *            originalSize: number, quantizedSize: number, compressionRatio: number}}
 *   originalSize/quantizedSize are in bytes. compressionRatio is 1 for empty input.
 * @throws {RangeError} If bits is not an integer in [1, 8]. Wider codes do
 *   not fit the one-byte-per-value layout; they were previously truncated
 *   silently with `& 0xFF`.
 */
function quantizeWeights(weights, bits) {
  if (!Number.isInteger(bits) || bits < 1 || bits > 8) {
    throw new RangeError(`quantizeWeights: bits must be an integer in [1, 8], got ${bits}`);
  }
  const maxVal = 2 ** (bits - 1) - 1;
  const minVal = -(2 ** (bits - 1));

  let wMin = Infinity;
  let wMax = -Infinity;
  for (const w of weights) {
    if (w < wMin) wMin = w;
    if (w > wMax) wMax = w;
  }
  if (weights.length === 0) {
    // Keep scale/zeroPoint finite instead of deriving them from +/-Infinity.
    wMin = 0;
    wMax = 0;
  }

  // Constant inputs give a zero range; fall back to a tiny non-zero scale.
  const scale = (wMax - wMin) / (maxVal - minVal) || 1e-10;
  const zeroPoint = Math.round(-wMin / scale + minVal);

  const quantized = new Uint8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    let q = Math.round(weights[i] / scale) + zeroPoint;
    q = Math.max(minVal, Math.min(maxVal, q));
    quantized[i] = (q - minVal) & 0xFF; // offset into [0, 2^bits - 1]
  }

  const originalSize = weights.length * 4; // fp32 bytes
  return {
    quantized,
    scale,
    zeroPoint,
    bits,
    originalSize,
    quantizedSize: quantized.length,
    compressionRatio: quantized.length > 0 ? originalSize / quantized.length : 1,
  };
}
/**
 * Map stored quantization codes back to approximate float weights; the
 * inverse of quantizeWeights, used to measure reconstruction error.
 *
 * Each stored byte u becomes ((u + minVal) - zeroPoint) * scale, where
 * minVal = -2^(bits-1). Produces one Float32 element per stored byte.
 *
 * @param {Uint8Array} quantized - Stored unsigned codes.
 * @param {number} scale - Step size used when quantizing.
 * @param {number} zeroPoint - Integer zero point used when quantizing.
 * @param {number} bits - Code width used when quantizing.
 * @returns {Float32Array}
 */
function dequantizeWeights(quantized, scale, zeroPoint, bits) {
  const signedFloor = -(2 ** (bits - 1));
  return Float32Array.from(quantized, (stored) => ((stored + signedFloor) - zeroPoint) * scale);
}
/**
 * Root-mean-square error between original weights and their dequantized
 * reconstruction, computed over the overlapping prefix of the two arrays.
 *
 * @param {ArrayLike<number>} original
 * @param {ArrayLike<number>} dequantized
 * @returns {number} RMSE. Returns 0 when there is nothing to compare; this
 *   used to be NaN (0/0), which printed as "NaN" and became null in the
 *   metrics JSON.
 */
function quantizationQuality(original, dequantized) {
  const n = Math.min(original.length, dequantized.length);
  if (n === 0) return 0;
  let sumSqErr = 0;
  for (let i = 0; i < n; i++) {
    const diff = original[i] - dequantized[i];
    sumSqErr += diff * diff;
  }
  return Math.sqrt(sumSqErr / n);
}
// ---------------------------------------------------------------------------
// Training labels from vitals data
// ---------------------------------------------------------------------------
/**
 * Derive task-head labels for one feature frame from the closest vitals
 * reading of the same node.
 *
 * "Closest" means smallest |Δt|; on ties the reading appearing first in
 * `vitals` wins. When the node has no reading, or the closest one is more
 * than 2 s away, the frame is unlabeled and null is returned.
 *
 * @param {{nodeId: *, timestamp: number}} featureFrame
 * @param {object[]} vitals - Readings with nodeId, timestamp, presenceScore,
 *   motionEnergy, breathingBpm and heartrateBpm.
 * @returns {{presence: number, activity: number[], vitalsTarget: number[]} | null}
 *   - presence: 1 when presenceScore > 0.3, else 0.
 *   - activity: one-hot [still, moving, empty]. It is empty when
 *     presenceScore <= 0.1, moving when motionEnergy > 2.0, and still
 *     otherwise.
 *   - vitalsTarget: [breathingBpm / 30, heartrateBpm / 120].
 *
 * NOTE(review): presence and activity use different thresholds. A reading
 * with 0.1 < presenceScore <= 0.3 is therefore labeled presence=0 but
 * activity=still or moving.
 */
function createLabels(featureFrame, vitals) {
  const { nodeId, timestamp } = featureFrame;

  let match = null;
  let gap = Infinity;
  for (const reading of vitals) {
    if (reading.nodeId === nodeId) {
      const d = Math.abs(reading.timestamp - timestamp);
      if (d < gap) {
        gap = d;
        match = reading;
      }
    }
  }
  // No matching vitals within 2 seconds
  if (match === null || gap > 2.0) return null;

  const { presenceScore, motionEnergy, breathingBpm, heartrateBpm } = match;
  const presence = presenceScore > 0.3 ? 1.0 : 0.0;
  const activity = presenceScore <= 0.1
    ? [0, 0, 1] // empty
    : motionEnergy > 2.0
      ? [0, 1, 0] // moving
      : [1, 0, 0]; // still

  return {
    presence,
    activity,
    // Normalize to roughly the 0-1 range.
    vitalsTarget: [breathingBpm / 30.0, heartrateBpm / 120.0],
  };
}
// ---------------------------------------------------------------------------
// Main pipeline
// ---------------------------------------------------------------------------
/**
 * End-to-end training pipeline driven by the module-level CONFIG:
 * load -> triplets -> encode -> contrastive pretraining -> task heads ->
 * per-node LoRA -> quantization -> EWC -> export -> (optional) benchmark.
 *
 * Terminates the process with exit status 1 when no input files match, no
 * feature frames are loaded, or no triplets can be generated.
 */
async function main() {
  const startTime = Date.now();
  console.log('=== WiFi-DensePose CSI Training Pipeline (ruvllm) ===');
  console.log(`Config: epochs=${CONFIG.epochs} batch=${CONFIG.batchSize} lora_rank=${CONFIG.loraRank} quant=${CONFIG.quantizeBits}bit`);
  console.log('');

  // Task-head target vector shared by Phase 2 and Phase 3:
  // [presence(1), activity(3), vitals(2), padding] - embeddingDim entries in
  // total so it matches the adapters' output dimension.
  const buildTaskTarget = (labels) => {
    const target = new Array(CONFIG.embeddingDim).fill(0);
    target[0] = labels.presence;
    target[1] = labels.activity[0]; // still
    target[2] = labels.activity[1]; // moving
    target[3] = labels.activity[2]; // empty
    target[4] = labels.vitalsTarget[0]; // breathing normalized
    target[5] = labels.vitalsTarget[1]; // heartrate normalized
    return target;
  };

  // -----------------------------------------------------------------------
  // Step 1: Load CSI data
  // -----------------------------------------------------------------------
  console.log('[1/9] Loading CSI data...');
  const files = resolveGlob(CONFIG.dataGlob);
  if (files.length === 0) {
    console.error(`No files found matching: ${CONFIG.dataGlob}`);
    process.exit(1);
  }
  const allFeatures = [];
  const allVitals = [];
  const allRawCsi = [];
  for (const file of files) {
    console.log(` Loading: ${path.basename(file)}`);
    const { features, vitals, rawCsi } = loadCsiData(file);
    // Append in place. Repeated concat() would re-copy everything loaded so
    // far for every file, and a spread push can overflow the call stack on
    // large recordings.
    for (const f of features) allFeatures.push(f);
    for (const v of vitals) allVitals.push(v);
    for (const r of rawCsi) allRawCsi.push(r);
  }
  console.log(` Loaded: ${allFeatures.length} features, ${allVitals.length} vitals, ${allRawCsi.length} raw CSI frames`);
  console.log(` Nodes: ${[...new Set(allFeatures.map(f => f.nodeId))].join(', ')}`);
  if (allFeatures.length === 0) {
    console.error('No feature frames found in data. Ensure data contains type="feature" frames.');
    process.exit(1);
  }

  // -----------------------------------------------------------------------
  // Step 2: Generate contrastive triplets
  // -----------------------------------------------------------------------
  console.log('\n[2/9] Generating contrastive triplets...');
  const triplets = generateTriplets(allFeatures, allVitals, CONFIG);
  const temporalCount = triplets.filter(t => t.type === 'temporal').length;
  const crossNodeCount = triplets.filter(t => t.type === 'cross-node').length;
  const transitionCount = triplets.filter(t => t.type === 'transition-hard').length;
  const hardCount = triplets.filter(t => t.isHard).length;
  console.log(` Total triplets: ${triplets.length}`);
  console.log(` Temporal: ${temporalCount}, Cross-node: ${crossNodeCount}, Transition: ${transitionCount}, Hard: ${hardCount}`);
  // Check before computing the ratio below, which would otherwise print "NaN%".
  if (triplets.length === 0) {
    console.error('No triplets generated. Data may lack temporal diversity (need >30s span).');
    process.exit(1);
  }
  console.log(` Hard negative ratio: ${(hardCount / triplets.length * 100).toFixed(1)}%`);

  // -----------------------------------------------------------------------
  // Step 3: Build encoder and encode features
  // -----------------------------------------------------------------------
  console.log('\n[3/9] Building CSI encoder (8 -> 64 -> 128)...');
  const encoder = new CsiEncoder(CONFIG.inputDim, CONFIG.hiddenDim, CONFIG.embeddingDim);
  // Pre-encode all features
  console.log(' Encoding feature vectors...');
  const encodingStart = Date.now();
  const encodedFeatures = allFeatures.map(f => ({
    ...f,
    embedding: encoder.encode(f.features),
  }));
  console.log(` Encoded ${encodedFeatures.length} frames in ${Date.now() - encodingStart}ms`);

  // -----------------------------------------------------------------------
  // Phase 1: Contrastive pretraining
  // -----------------------------------------------------------------------
  console.log('\n[4/9] Phase 1: Contrastive pretraining...');
  const contrastiveTrainer = new ContrastiveTrainer({
    epochs: CONFIG.epochs,
    batchSize: CONFIG.batchSize,
    margin: CONFIG.margin,
    temperature: CONFIG.temperature,
    hardNegativeRatio: CONFIG.hardNegativeRatio,
    learningRate: CONFIG.learningRate,
    outputPath: path.join(CONFIG.outputDir, 'contrastive'),
  });
  // Add triplets with encoded embeddings
  for (const triplet of triplets) {
    const anchorEmb = encoder.encode(triplet.anchor);
    const posEmb = encoder.encode(triplet.positive);
    const negEmb = encoder.encode(triplet.negative);
    contrastiveTrainer.addTriplet(
      triplet.anchorLabel,
      anchorEmb,
      triplet.posLabel,
      posEmb,
      triplet.negLabel,
      negEmb,
      triplet.isHard
    );
  }
  console.log(` Triplets loaded: ${contrastiveTrainer.getTripletCount()}`);
  const contrastiveResult = contrastiveTrainer.train();
  console.log(` Epochs: ${contrastiveResult.history.length}`);
  console.log(` Initial loss: ${contrastiveResult.initialLoss.toFixed(6)}`);
  console.log(` Final loss: ${contrastiveResult.finalLoss.toFixed(6)}`);
  console.log(` Improvement: ${contrastiveResult.improvement.toFixed(1)}%`);
  console.log(` Duration: ${contrastiveResult.durationMs}ms`);
  // Export contrastive training data
  const contrastiveOutDir = contrastiveTrainer.exportTrainingData();
  console.log(` Training data exported to: ${contrastiveOutDir}`);

  // -----------------------------------------------------------------------
  // Phase 2: Task head training via TrainingPipeline
  // -----------------------------------------------------------------------
  console.log('\n[5/9] Phase 2: Task head training...');
  // Create LoRA adapter for the task heads: 128-dim input, 128-dim output
  const taskAdapter = new LoraAdapter(
    { rank: CONFIG.loraRank * 2, alpha: CONFIG.loraRank * 4, dropout: 0.05, targetModules: ['encoder', 'task_heads'] },
    CONFIG.embeddingDim,
    CONFIG.embeddingDim
  );
  const taskPipeline = new TrainingPipeline({
    learningRate: CONFIG.learningRate,
    batchSize: CONFIG.batchSize,
    epochs: Math.max(5, Math.floor(CONFIG.epochs / 2)),
    scheduler: 'cosine',
    warmupSteps: 50,
    earlyStoppingPatience: 5,
    checkpointInterval: 2,
    ewcLambda: 2000,
    validationSplit: 0.1,
  }, taskAdapter);
  // Build training data: input = encoded feature, target = task labels
  let labeledCount = 0;
  const taskTrainingData = [];
  for (const ef of encodedFeatures) {
    const labels = createLabels(ef, allVitals);
    if (!labels) continue;
    taskTrainingData.push({
      input: ef.embedding,
      target: buildTaskTarget(labels),
      quality: 1.0,
    });
    labeledCount++;
  }
  console.log(` Labeled samples: ${labeledCount} / ${encodedFeatures.length} (${(labeledCount / encodedFeatures.length * 100).toFixed(1)}%)`);
  if (taskTrainingData.length > 0) {
    taskPipeline.addData(taskTrainingData);
    const taskResult = taskPipeline.train();
    console.log(` Epochs completed: ${taskResult.epochs}`);
    console.log(` Final loss: ${taskResult.finalLoss.toFixed(6)}`);
    console.log(` Best val loss: ${taskResult.bestValLoss.toFixed(6)}`);
    console.log(` Early stopped: ${taskResult.earlyStopped}`);
    console.log(` Duration: ${taskResult.durationMs}ms`);
  } else {
    console.log(' WARN: No labeled data available, skipping task head training.');
  }

  // -----------------------------------------------------------------------
  // Phase 3: LoRA refinement (per-node room adaptation)
  // -----------------------------------------------------------------------
  console.log('\n[6/9] Phase 3: LoRA refinement (per-node adaptation)...');
  const loraManager = new LoraManager({
    rank: CONFIG.loraRank,
    alpha: CONFIG.loraRank * 2,
    dropout: 0.1,
    targetModules: ['room_adapt'],
  });
  const nodeIds = [...new Set(allFeatures.map(f => f.nodeId))];
  for (const nodeId of nodeIds) {
    console.log(` Training LoRA adapter for node ${nodeId}...`);
    const nodeAdapter = loraManager.create(
      `node-${nodeId}`,
      { rank: CONFIG.loraRank, alpha: CONFIG.loraRank * 2, dropout: 0.1 },
      CONFIG.embeddingDim,
      CONFIG.embeddingDim
    );
    // Train on node-specific data
    const nodeFeatures = encodedFeatures.filter(f => f.nodeId === nodeId);
    const nodePipeline = new TrainingPipeline({
      learningRate: CONFIG.learningRate * 0.5,
      batchSize: Math.min(CONFIG.batchSize, nodeFeatures.length),
      epochs: 5,
      scheduler: 'cosine',
      ewcLambda: 3000,
    }, nodeAdapter);
    const nodeData = [];
    for (const nf of nodeFeatures) {
      const labels = createLabels(nf, allVitals);
      if (!labels) continue;
      nodeData.push({ input: nf.embedding, target: buildTaskTarget(labels), quality: 1.0 });
    }
    if (nodeData.length > 0) {
      nodePipeline.addData(nodeData);
      const nodeResult = nodePipeline.train();
      console.log(` Node ${nodeId}: ${nodeData.length} samples, loss=${nodeResult.finalLoss.toFixed(6)}, ${nodeResult.durationMs}ms`);
    }
  }
  console.log(` LoRA adapters: ${loraManager.list().join(', ')}`);
  console.log(` Total LoRA parameters: ${loraManager.stats().totalParameters}`);

  // -----------------------------------------------------------------------
  // Phase 4: Quantization (TurboQuant)
  // -----------------------------------------------------------------------
  console.log('\n[7/9] Phase 4: Quantization (TurboQuant)...');
  const mergedWeights = taskAdapter.merge();
  const flatWeights = new Float32Array(mergedWeights.flat());
  const quantResults = {};
  for (const bits of [2, 4, 8]) {
    const qr = quantizeWeights(flatWeights, bits);
    const deq = dequantizeWeights(qr.quantized, qr.scale, qr.zeroPoint, bits);
    const rmse = quantizationQuality(flatWeights, deq);
    quantResults[bits] = { ...qr, rmse };
    console.log(` ${bits}-bit: compression=${qr.compressionRatio.toFixed(1)}x, RMSE=${rmse.toFixed(6)}, size=${(qr.quantizedSize / 1024).toFixed(1)}KB`);
  }

  // -----------------------------------------------------------------------
  // Phase 5: EWC consolidation
  // -----------------------------------------------------------------------
  console.log('\n[8/9] Phase 5: EWC consolidation...');
  const ewcManager = taskPipeline.getEwcManager();
  const ewcWeights = taskAdapter.merge().flat();
  ewcManager.registerTask('csi-pretraining-v1', ewcWeights);
  // Register per-node tasks for EWC protection
  for (const nodeId of nodeIds) {
    const nodeAdapter = loraManager.get(`node-${nodeId}`);
    if (nodeAdapter) {
      const nodeWeights = nodeAdapter.merge().flat();
      ewcManager.registerTask(`node-${nodeId}-adaptation`, nodeWeights);
    }
  }
  const ewcStats = ewcManager.stats();
  console.log(` Tasks learned: ${ewcStats.tasksLearned}`);
  console.log(` Fisher computed: ${ewcStats.fisherComputed}`);
  console.log(` Protection strength: ${ewcStats.protectionStrength}`);
  console.log(` Forgetting rate: ${ewcStats.forgettingRate.toFixed(4)}`);

  // -----------------------------------------------------------------------
  // Step 9: Export
  // -----------------------------------------------------------------------
  console.log('\n[9/9] Exporting models...');
  // Ensure output directory exists
  fs.mkdirSync(CONFIG.outputDir, { recursive: true });
  // 9a: SafeTensors export via ModelExporter
  const exporter = new ModelExporter();
  const exportModel = {
    metadata: {
      name: 'wifi-densepose-csi-embedding',
      version: '1.0.0',
      architecture: 'csi-encoder-8-64-128',
      training: {
        steps: contrastiveResult.history.length * contrastiveTrainer.getTripletCount(),
        loss: contrastiveResult.finalLoss,
        learningRate: CONFIG.learningRate,
      },
      custom: {
        inputDim: CONFIG.inputDim,
        hiddenDim: CONFIG.hiddenDim,
        embeddingDim: CONFIG.embeddingDim,
        totalFrames: allFeatures.length,
        totalTriplets: triplets.length,
        nodes: nodeIds,
        quantizationBits: CONFIG.quantizeBits,
      },
    },
    loraWeights: taskAdapter.getWeights(),
    loraConfig: taskAdapter.getConfig(),
    ewcStats: ewcStats,
    tensors: new Map(),
  };
  // Add encoder weights as tensors
  exportModel.tensors.set('encoder.w1', new Float32Array(encoder.w1));
  exportModel.tensors.set('encoder.b1', new Float32Array(encoder.b1));
  exportModel.tensors.set('encoder.w2', new Float32Array(encoder.w2));
  exportModel.tensors.set('encoder.b2', new Float32Array(encoder.b2));
  // SafeTensors
  const safetensorsBuffer = exporter.toSafeTensors(exportModel);
  fs.writeFileSync(path.join(CONFIG.outputDir, 'model.safetensors'), safetensorsBuffer);
  console.log(` SafeTensors: ${path.join(CONFIG.outputDir, 'model.safetensors')} (${(safetensorsBuffer.length / 1024).toFixed(1)} KB)`);
  // HuggingFace export
  const hfExport = exporter.toHuggingFace(exportModel);
  fs.writeFileSync(path.join(CONFIG.outputDir, 'config.json'), hfExport.config);
  console.log(` HF config: ${path.join(CONFIG.outputDir, 'config.json')}`);
  // JSON export
  const jsonExport = exporter.toJSON(exportModel);
  fs.writeFileSync(path.join(CONFIG.outputDir, 'model.json'), jsonExport);
  // 9b: Quantized models
  const quantDir = path.join(CONFIG.outputDir, 'quantized');
  fs.mkdirSync(quantDir, { recursive: true });
  for (const [bits, qr] of Object.entries(quantResults)) {
    const qPath = path.join(quantDir, `model-q${bits}.bin`);
    fs.writeFileSync(qPath, Buffer.from(qr.quantized));
    console.log(` Quantized ${bits}-bit: ${qPath} (${(qr.quantizedSize / 1024).toFixed(1)} KB)`);
  }
  // 9c: Per-node LoRA adapters
  const loraDir = path.join(CONFIG.outputDir, 'lora');
  fs.mkdirSync(loraDir, { recursive: true });
  for (const adapterId of loraManager.list()) {
    const adapter = loraManager.get(adapterId);
    const loraPath = path.join(loraDir, `${adapterId}.json`);
    fs.writeFileSync(loraPath, adapter.toJSON());
    console.log(` LoRA adapter: ${loraPath}`);
  }
  // 9d: RVF (RuVector Format) - JSONL for Cognitum Seed ingest
  const rvfPath = path.join(CONFIG.outputDir, 'model.rvf.jsonl');
  const rvfLines = [
    JSON.stringify({ type: 'metadata', ...exportModel.metadata }),
    JSON.stringify({ type: 'encoder', w1_shape: [CONFIG.inputDim, CONFIG.hiddenDim], w2_shape: [CONFIG.hiddenDim, CONFIG.embeddingDim] }),
    JSON.stringify({ type: 'lora', config: taskAdapter.getConfig(), parameters: taskAdapter.numParameters() }),
    JSON.stringify({ type: 'ewc', stats: ewcStats }),
    JSON.stringify({ type: 'quantization', default_bits: CONFIG.quantizeBits, variants: Object.keys(quantResults).map(Number) }),
  ];
  fs.writeFileSync(rvfPath, rvfLines.join('\n'));
  console.log(` RVF manifest: ${rvfPath}`);
  // 9e: Training metrics
  const metricsPath = path.join(CONFIG.outputDir, 'training-metrics.json');
  const metrics = {
    timestamp: new Date().toISOString(),
    totalDurationMs: Date.now() - startTime,
    data: {
      files: files.map(f => path.basename(f)),
      totalFeatures: allFeatures.length,
      totalVitals: allVitals.length,
      totalRawCsi: allRawCsi.length,
      nodes: nodeIds,
    },
    contrastive: {
      triplets: triplets.length,
      temporal: temporalCount,
      crossNode: crossNodeCount,
      hardNegatives: hardCount,
      initialLoss: contrastiveResult.initialLoss,
      finalLoss: contrastiveResult.finalLoss,
      improvement: contrastiveResult.improvement,
      durationMs: contrastiveResult.durationMs,
      lossHistory: contrastiveResult.history,
    },
    taskHeads: taskTrainingData.length > 0 ? {
      samples: labeledCount,
      finalLoss: taskPipeline.getMetrics().trainLoss,
    } : null,
    lora: {
      adapters: loraManager.list(),
      totalParameters: loraManager.stats().totalParameters,
    },
    quantization: Object.fromEntries(
      Object.entries(quantResults).map(([bits, qr]) => [
        `q${bits}`,
        { compressionRatio: qr.compressionRatio, rmse: qr.rmse, sizeKB: qr.quantizedSize / 1024 },
      ])
    ),
    ewc: ewcStats,
    config: CONFIG,
  };
  fs.writeFileSync(metricsPath, JSON.stringify(metrics, null, 2));
  console.log(` Metrics: ${metricsPath}`);

  // -----------------------------------------------------------------------
  // Summary
  // -----------------------------------------------------------------------
  const totalDuration = Date.now() - startTime;
  console.log('\n=== Training Complete ===');
  console.log(` Total duration: ${(totalDuration / 1000).toFixed(1)}s`);
  console.log(` Output directory: ${path.resolve(CONFIG.outputDir)}`);
  console.log(` Model size (fp32): ${(safetensorsBuffer.length / 1024).toFixed(1)} KB`);
  console.log(` Model size (q${CONFIG.quantizeBits}): ${(quantResults[CONFIG.quantizeBits]?.quantizedSize / 1024 || 0).toFixed(1)} KB`);
  console.log(` LoRA adapters: ${loraManager.count()}`);
  console.log(` EWC tasks protected: ${ewcStats.tasksLearned}`);

  // -----------------------------------------------------------------------
  // Optional benchmark
  // -----------------------------------------------------------------------
  if (CONFIG.benchmark) {
    console.log('\n=== Benchmark Mode ===');
    runBenchmark(encoder, taskAdapter, allFeatures, allVitals, quantResults);
  }
}
// ---------------------------------------------------------------------------
// Benchmark
// ---------------------------------------------------------------------------
/**
 * Print inference latency, embedding-similarity, presence-accuracy and
 * quantization-size statistics for a trained encoder plus task adapter.
 *
 * @param {CsiEncoder} encoder - Feature encoder.
 * @param {object} adapter - Task adapter exposing forward(embedding); output[0]
 *   is read as the presence score.
 * @param {object[]} features - Feature frames (nodeId, timestamp, features).
 * @param {object[]} vitals - Vitals frames used to derive presence labels.
 * @param {Object<string, object>} quantResults - quantizeWeights results (plus
 *   rmse) keyed by bit width.
 */
function runBenchmark(encoder, adapter, features, vitals, quantResults) {
  const N = Math.min(1000, features.length);
  const testFeatures = features.slice(0, N);

  // Inference latency
  console.log(`\nInference latency (${N} samples):`);
  const latencies = [];
  for (const f of testFeatures) {
    const start = process.hrtime.bigint();
    const emb = encoder.encode(f.features);
    adapter.forward(emb);
    const elapsed = Number(process.hrtime.bigint() - start) / 1e6;
    latencies.push(elapsed);
  }
  if (latencies.length > 0) {
    latencies.sort((a, b) => a - b);
    const mean = latencies.reduce((a, b) => a + b, 0) / latencies.length;
    // Nearest-rank percentile. The previous floor(n * p) index sat one rank
    // too high (e.g. the 96th of 100 samples for P95).
    const percentile = (p) =>
      latencies[Math.min(latencies.length - 1, Math.max(0, Math.ceil(p * latencies.length) - 1))];
    const p95 = percentile(0.95);
    const p99 = percentile(0.99);
    console.log(` Mean: ${mean.toFixed(3)}ms`);
    console.log(` P95: ${p95.toFixed(3)}ms`);
    console.log(` P99: ${p99.toFixed(3)}ms`);
    console.log(` Throughput: ${(1000 / mean).toFixed(0)} embeddings/sec`);
  }

  // Embedding quality: cosine similarity for temporal pairs.
  // Pairs are drawn within a single node, in time order. Adjacent entries of
  // the raw `features` list may come from different nodes. They are also
  // essentially never negativeWindowSec apart, so the old adjacent-pair
  // scheme almost never produced a negative pair.
  console.log('\nEmbedding quality (temporal pairs):');
  const posWindow = CONFIG.positiveWindowSec;
  const negWindow = CONFIG.negativeWindowSec;
  const framesByNode = new Map();
  for (const f of features) {
    let list = framesByNode.get(f.nodeId);
    if (!list) {
      list = [];
      framesByNode.set(f.nodeId, list);
    }
    list.push(f);
  }
  const posSimilarities = [];
  const negSimilarities = [];
  // Spread the ~200 anchors of the original benchmark across nodes.
  const anchorsPerNode = Math.ceil(200 / Math.max(1, framesByNode.size));
  for (const frames of framesByNode.values()) {
    frames.sort((a, b) => a.timestamp - b.timestamp);
    let far = 0; // first frame >= negWindow after the current anchor; only moves forward
    const limit = Math.min(frames.length - 1, anchorsPerNode);
    for (let i = 0; i < limit; i++) {
      const anchor = frames[i];
      const anchorEmb = encoder.encode(anchor.features);
      const next = frames[i + 1];
      if (next.timestamp - anchor.timestamp <= posWindow) {
        posSimilarities.push(cosineSimilarity(anchorEmb, encoder.encode(next.features)));
      }
      far = Math.max(far, i + 1);
      while (far < frames.length && frames[far].timestamp - anchor.timestamp < negWindow) far++;
      if (far < frames.length) {
        negSimilarities.push(cosineSimilarity(anchorEmb, encoder.encode(frames[far].features)));
      }
    }
  }
  if (posSimilarities.length > 0) {
    const avgPos = posSimilarities.reduce((a, b) => a + b, 0) / posSimilarities.length;
    console.log(` Positive pair avg similarity: ${avgPos.toFixed(4)} (n=${posSimilarities.length})`);
  }
  if (negSimilarities.length > 0) {
    const avgNeg = negSimilarities.reduce((a, b) => a + b, 0) / negSimilarities.length;
    console.log(` Negative pair avg similarity: ${avgNeg.toFixed(4)} (n=${negSimilarities.length})`);
  }

  // Presence detection accuracy
  console.log('\nPresence detection accuracy:');
  let correct = 0;
  let total = 0;
  for (const f of testFeatures) {
    const labels = createLabels(f, vitals);
    if (!labels) continue;
    const emb = encoder.encode(f.features);
    const out = adapter.forward(emb);
    const predicted = out[0] > 0.5 ? 1 : 0;
    if (predicted === labels.presence) correct++;
    total++;
  }
  if (total > 0) {
    console.log(` Accuracy: ${(correct / total * 100).toFixed(1)}% (${correct}/${total})`);
  }

  // Memory usage per quantization level
  console.log('\nMemory usage per quantization level:');
  console.log(' Bits | Size (KB) | Compression | RMSE');
  console.log(' -----|-----------|-------------|------');
  for (const [bits, qr] of Object.entries(quantResults)) {
    console.log(` ${bits.padStart(4)} | ${(qr.quantizedSize / 1024).toFixed(1).padStart(9)} | ${qr.compressionRatio.toFixed(1).padStart(11)}x | ${qr.rmse.toFixed(6)}`);
  }
  // All variants share the same fp32 source, so any entry gives the baseline size.
  const fp32Reference = Object.values(quantResults)[0];
  if (fp32Reference) {
    console.log(` fp32 | ${(fp32Reference.originalSize / 1024).toFixed(1).padStart(9)} | ${' '.padStart(10)}1x | 0.000000`);
  }
}
// Entry point: run the pipeline, and turn any rejection into a logged error
// plus a non-zero exit status so shell scripts and CI notice the failure.
main().then(null, (failure) => {
  console.error('Training pipeline failed:', failure);
  process.exit(1);
});