/**
 * CNN Embedder — Lightweight MobileNet-V3-style feature extractor.
 *
 * Architecture mirrors ruvector-cnn: Conv2D → BatchNorm → ReLU → Pool → Project → L2 Normalize.
 * Uses pre-seeded random weights (deterministic). Once ruvector-cnn-wasm has been loaded via
 * tryLoadWasm(), extraction transparently delegates to the WASM implementation.
 *
 * Two instances are created: one for video frames, one for CSI pseudo-images.
 */

// Seeded PRNG for deterministic weight initialization
/**
 * Create a mulberry32 pseudo-random number generator.
 *
 * The generator is deterministic: two generators built from the same integer
 * seed produce the same sequence.
 *
 * @param {number} seed - Integer seed; only its low 32 bits influence the stream.
 * @returns {function(): number} Function returning the next value in [0, 1).
 */
function mulberry32(seed) {
  let state = seed;
  return function() {
    // Wrap the running state to int32 on every step. Letting it accumulate as
    // a double loses low bits once it exceeds 2^53, which corrupts the stream
    // after a few million draws. Every use of the state below is a 32-bit
    // operation, so wrapping does not change the output for integer seeds.
    state = (state + 0x6D2B79F5) | 0;
    let t = state;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    // Reinterpret as unsigned 32-bit and scale by 2^32 into [0, 1).
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}
/**
 * Small convolutional feature extractor with a pure-JS forward pass and an
 * optional WASM backend.
 *
 * JS pipeline: resize → per-channel normalization → 3x3 convolution →
 * batch norm → ReLU → global average pool → linear projection →
 * optional L2 normalization.
 */
export class CnnEmbedder {
  /**
   * Build an embedder whose weights are drawn from a seeded PRNG, so two
   * embedders created with the same options have identical weights.
   *
   * @param {object} [opts]
   * @param {number} [opts.inputSize=56] - Square input dimension (default 56 for speed)
   * @param {number} [opts.embeddingDim=128] - Output embedding dimension
   * @param {boolean} [opts.normalize=true] - L2 normalize output
   * @param {number} [opts.seed=42] - PRNG seed for weight init (0 is a valid seed)
   */
  constructor(opts = {}) {
    this.inputSize = opts.inputSize || 56;
    this.embeddingDim = opts.embeddingDim || 128;
    this.normalize = opts.normalize !== false;
    this.wasmEmbedder = null;
    // Set after the first WASM extract failure so the warning is emitted once.
    this.wasmFailureLogged = false;
    // Number of feature channels produced by the convolution stage.
    this.convChannels = 16;

    // `??` rather than `||`: a seed of 0 must not be replaced by the default.
    const rng = mulberry32(opts.seed ?? 42);
    const randRange = (lo, hi) => lo + rng() * (hi - lo);

    // Conv 3x3: 3 input channels → convChannels output channels.
    // Conv weights are drawn before projection weights; that order fixes
    // which PRNG values each tensor receives.
    this.convWeights = new Float32Array(3 * 3 * 3 * this.convChannels);
    for (let i = 0; i < this.convWeights.length; i++) {
      this.convWeights[i] = randRange(-0.15, 0.15);
    }

    // BatchNorm params, one entry per conv channel (identity transform up to epsilon).
    this.bnGamma = new Float32Array(this.convChannels).fill(1.0);
    this.bnBeta = new Float32Array(this.convChannels).fill(0.0);
    this.bnMean = new Float32Array(this.convChannels).fill(0.0);
    this.bnVar = new Float32Array(this.convChannels).fill(1.0);

    // Projection: convChannels → embeddingDim
    this.projWeights = new Float32Array(this.convChannels * this.embeddingDim);
    for (let i = 0; i < this.projWeights.length; i++) {
      this.projWeights[i] = randRange(-0.1, 0.1);
    }
  }

  /**
   * Try to load the WASM embedder from the ruvector-cnn-wasm package.
   * On success, subsequent extract() calls use the WASM embedder first.
   * Never rejects: any failure is logged and reported as `false`.
   *
   * @param {string} wasmPath - Path to the WASM package directory
   * @returns {Promise<boolean>} true if the WASM embedder was created
   */
  async tryLoadWasm(wasmPath) {
    try {
      const mod = await import(`${wasmPath}/ruvector_cnn_wasm.js`);
      await mod.default();
      const config = new mod.EmbedderConfig();
      config.input_size = this.inputSize;
      config.embedding_dim = this.embeddingDim;
      config.normalize = this.normalize;
      this.wasmEmbedder = new mod.WasmCnnEmbedder(config);
      console.log('[CNN] WASM embedder loaded successfully');
      return true;
    } catch (e) {
      // `e` may be any thrown value, not necessarily an Error.
      console.log('[CNN] WASM not available, using JS fallback:', e?.message ?? e);
      return false;
    }
  }

  /**
   * Extract embedding from RGB image data.
   * Uses the WASM embedder when one is loaded; if it throws, falls back to
   * the JS implementation for that call.
   *
   * @param {Uint8Array} rgbData - RGB pixel data (H*W*3)
   * @param {number} width
   * @param {number} height
   * @returns {Float32Array} embedding vector
   */
  extract(rgbData, width, height) {
    if (this.wasmEmbedder) {
      try {
        const result = this.wasmEmbedder.extract(rgbData, width, height);
        return new Float32Array(result);
      } catch (e) {
        // Best-effort fallback, but surface the failure once instead of
        // hiding it; logging on every call could flood the console.
        if (!this.wasmFailureLogged) {
          this.wasmFailureLogged = true;
          console.warn('[CNN] WASM extract failed, using JS fallback:', e?.message ?? e);
        }
      }
    }
    return this._extractJS(rgbData, width, height);
  }

  /**
   * Pure-JS forward pass. Does not modify `rgbData`.
   *
   * @param {Uint8Array} rgbData - Interleaved RGB pixel data
   * @param {number} width - Source width in pixels
   * @param {number} height - Source height in pixels
   * @returns {Float32Array} embedding of length `embeddingDim`
   */
  _extractJS(rgbData, width, height) {
    const sz = this.inputSize;
    const channels = this.convChannels;

    // 1. Resize to inputSize × inputSize if needed; either way scale to [0, 1].
    let input;
    if (width === sz && height === sz) {
      input = new Float32Array(rgbData.length);
      for (let i = 0; i < rgbData.length; i++) input[i] = rgbData[i] / 255.0;
    } else {
      input = this._resize(rgbData, width, height, sz, sz);
    }

    // 2. ImageNet normalization (per-channel mean / std)
    const mean = [0.485, 0.456, 0.406];
    const std = [0.229, 0.224, 0.225];
    const pixels = sz * sz;
    for (let i = 0; i < pixels; i++) {
      const base = i * 3;
      input[base] = (input[base] - mean[0]) / std[0];
      input[base + 1] = (input[base + 1] - mean[1]) / std[1];
      input[base + 2] = (input[base + 2] - mean[2]) / std[2];
    }

    // 3. Conv2D 3x3 (3 → channels), no padding
    const convOut = this._conv2d3x3(input, sz, sz, 3, channels);

    // 4. BatchNorm (in place)
    this._batchNorm(convOut, channels);

    // 5. ReLU (in place)
    for (let i = 0; i < convOut.length; i++) {
      if (convOut[i] < 0) convOut[i] = 0;
    }

    // 6. Global average pooling → one value per channel
    const outH = sz - 2;
    const outW = sz - 2;
    const spatial = outH * outW;
    const pooled = new Float32Array(channels);
    for (let i = 0; i < spatial; i++) {
      for (let c = 0; c < channels; c++) {
        pooled[c] += convOut[i * channels + c];
      }
    }
    for (let c = 0; c < channels; c++) pooled[c] /= spatial;

    // 7. Linear projection → embeddingDim (weights indexed [in][out])
    const emb = new Float32Array(this.embeddingDim);
    for (let o = 0; o < this.embeddingDim; o++) {
      let sum = 0;
      for (let i = 0; i < channels; i++) {
        sum += pooled[i] * this.projWeights[i * this.embeddingDim + o];
      }
      emb[o] = sum;
    }

    // 8. L2 normalize; a (near-)zero vector is returned unscaled.
    if (this.normalize) {
      let norm = 0;
      for (let i = 0; i < emb.length; i++) norm += emb[i] * emb[i];
      norm = Math.sqrt(norm);
      if (norm > 1e-8) {
        for (let i = 0; i < emb.length; i++) emb[i] /= norm;
      }
    }

    return emb;
  }

  /**
   * 3x3 convolution without padding, stride 1, using `this.convWeights`.
   * Input and output are interleaved by channel: index = (y * width + x) * C + c.
   * Weights are indexed as [ky][kx][ci][co].
   *
   * @param {Float32Array} input - H*W*Cin values
   * @param {number} H - Input height
   * @param {number} W - Input width
   * @param {number} Cin - Input channels
   * @param {number} Cout - Output channels
   * @returns {Float32Array} (H-2)*(W-2)*Cout values
   */
  _conv2d3x3(input, H, W, Cin, Cout) {
    const outH = H - 2;
    const outW = W - 2;
    const output = new Float32Array(outH * outW * Cout);
    for (let y = 0; y < outH; y++) {
      for (let x = 0; x < outW; x++) {
        const outBase = (y * outW + x) * Cout;
        for (let co = 0; co < Cout; co++) {
          let sum = 0;
          for (let ky = 0; ky < 3; ky++) {
            for (let kx = 0; kx < 3; kx++) {
              const pxBase = ((y + ky) * W + (x + kx)) * Cin;
              const wtBase = (ky * 3 + kx) * Cin;
              for (let ci = 0; ci < Cin; ci++) {
                sum += input[pxBase + ci] * this.convWeights[(wtBase + ci) * Cout + co];
              }
            }
          }
          output[outBase + co] = sum;
        }
      }
    }
    return output;
  }

  /**
   * Apply batch normalization in place using the stored per-channel
   * gamma / beta / mean / variance, with epsilon 1e-5.
   *
   * @param {Float32Array} data - Channel-interleaved activations (modified in place)
   * @param {number} channels - Number of interleaved channels
   */
  _batchNorm(data, channels) {
    // The denominator depends only on the channel, so compute it once per
    // channel instead of once per element. Kept in double precision so the
    // per-element arithmetic is unchanged.
    const denom = new Float64Array(channels);
    for (let c = 0; c < channels; c++) {
      denom[c] = Math.sqrt(this.bnVar[c] + 1e-5);
    }

    const spatial = data.length / channels;
    for (let i = 0; i < spatial; i++) {
      for (let c = 0; c < channels; c++) {
        const idx = i * channels + c;
        data[idx] = this.bnGamma[c] * (data[idx] - this.bnMean[c]) / denom[c] + this.bnBeta[c];
      }
    }
  }

  /**
   * Nearest-neighbour resize of interleaved RGB data, scaling values to [0, 1].
   *
   * @param {Uint8Array} rgbData - Source pixels, 3 values per pixel
   * @param {number} srcW - Source width
   * @param {number} srcH - Source height
   * @param {number} dstW - Destination width
   * @param {number} dstH - Destination height
   * @returns {Float32Array} dstW*dstH*3 values
   */
  _resize(rgbData, srcW, srcH, dstW, dstH) {
    const output = new Float32Array(dstW * dstH * 3);
    const xRatio = srcW / dstW;
    const yRatio = srcH / dstH;
    for (let y = 0; y < dstH; y++) {
      // Clamp so rounding can never index past the last source row/column.
      const sy = Math.min(Math.floor(y * yRatio), srcH - 1);
      for (let x = 0; x < dstW; x++) {
        const sx = Math.min(Math.floor(x * xRatio), srcW - 1);
        const srcIdx = (sy * srcW + sx) * 3;
        const dstIdx = (y * dstW + x) * 3;
        output[dstIdx] = rgbData[srcIdx] / 255.0;
        output[dstIdx + 1] = rgbData[srcIdx + 1] / 255.0;
        output[dstIdx + 2] = rgbData[srcIdx + 2] / 255.0;
      }
    }
    return output;
  }

  /**
   * Cosine similarity between two embeddings.
   * Iterates over the indices of `a`; returns 0 when either vector's norm
   * is below 1e-8.
   *
   * @param {ArrayLike<number>} a
   * @param {ArrayLike<number>} b
   * @returns {number} dot(a, b) / (|a| * |b|), or 0 for a near-zero vector
   */
  static cosineSimilarity(a, b) {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    normA = Math.sqrt(normA);
    normB = Math.sqrt(normB);
    if (normA < 1e-8 || normB < 1e-8) return 0;
    return dot / (normA * normB);
  }
}