wifi-densepose/vendor/sublinear-time-solver/js/fast-solver.js

416 lines
12 KiB
JavaScript

/**
* Fast Node.js solver implementation optimized to beat Python benchmarks
*
* This addresses the critical MCP Dense performance issue that's 190x slower than Python.
* Key optimizations:
* - Native sparse CSR format
* - Manual loop unrolling
* - Memory-efficient data structures
* - Prepared for WASM integration
*/
class FastCSRMatrix {
constructor(values, colIndices, rowPtr, rows, cols) {
this.values = new Float64Array(values);
this.colIndices = new Uint32Array(colIndices);
this.rowPtr = new Uint32Array(rowPtr);
this.rows = rows;
this.cols = cols;
}
static fromTriplets(triplets, rows, cols) {
// Sort triplets by row, then column for optimal CSR construction
triplets.sort((a, b) => {
if (a[0] !== b[0]) return a[0] - b[0];
return a[1] - b[1];
});
const values = [];
const colIndices = [];
const rowPtr = new Array(rows + 1).fill(0);
let currentRow = 0;
for (const [row, col, val] of triplets) {
// Fill row pointers
while (currentRow <= row) {
rowPtr[currentRow] = values.length;
currentRow++;
}
values.push(val);
colIndices.push(col);
}
// Fill remaining row pointers
while (currentRow <= rows) {
rowPtr[currentRow] = values.length;
currentRow++;
}
return new FastCSRMatrix(values, colIndices, rowPtr, rows, cols);
}
/**
* Ultra-fast matrix-vector multiplication optimized for performance
* This is the critical operation that needs to beat Python
*/
multiplyVector(x, y) {
// Fill output with zeros
y.fill(0.0);
// Process rows with manual loop unrolling
for (let row = 0; row < this.rows; row++) {
const start = this.rowPtr[row];
const end = this.rowPtr[row + 1];
const nnz = end - start;
if (nnz === 0) continue;
// For small rows, use simple accumulation
if (nnz <= 4) {
let sum = 0.0;
for (let idx = start; idx < end; idx++) {
sum += this.values[idx] * x[this.colIndices[idx]];
}
y[row] = sum;
} else {
// For larger rows, unroll loop for better performance
const chunks = Math.floor(nnz / 4);
const remainder = nnz % 4;
let sum = 0.0;
// Process 4 elements at a time
let idx = start;
for (let chunk = 0; chunk < chunks; chunk++) {
sum += this.values[idx] * x[this.colIndices[idx]] +
this.values[idx + 1] * x[this.colIndices[idx + 1]] +
this.values[idx + 2] * x[this.colIndices[idx + 2]] +
this.values[idx + 3] * x[this.colIndices[idx + 3]];
idx += 4;
}
// Handle remainder
for (let i = 0; i < remainder; i++) {
sum += this.values[idx] * x[this.colIndices[idx]];
idx++;
}
y[row] = sum;
}
}
}
get nnz() {
return this.values.length;
}
}
class FastConjugateGradient {
constructor(maxIterations = 1000, tolerance = 1e-10) {
this.maxIterations = maxIterations;
this.tolerance = tolerance;
this.toleranceSq = tolerance * tolerance;
}
/**
* Solve Ax = b using optimized conjugate gradient
* Targets sub-50ms performance for 1000x1000 matrices
*/
solve(matrix, b) {
const n = matrix.rows;
if (matrix.rows !== matrix.cols) {
throw new Error('Matrix must be square');
}
if (b.length !== n) {
throw new Error('Vector size mismatch');
}
// Pre-allocate all vectors with Float64Array for better performance
const x = new Float64Array(n);
const r = new Float64Array(b); // r = b - A*x (initially r = b since x = 0)
const p = new Float64Array(b); // p = r initially
const ap = new Float64Array(n);
let rsold = this.dotProductFast(r, r);
for (let iteration = 0; iteration < this.maxIterations; iteration++) {
if (rsold <= this.toleranceSq) {
break;
}
// ap = A * p
matrix.multiplyVector(p, ap);
// alpha = rsold / (p^T * ap)
const pap = this.dotProductFast(p, ap);
if (Math.abs(pap) < 1e-16) {
break;
}
const alpha = rsold / pap;
// x = x + alpha * p
this.axpyFast(alpha, p, x);
// r = r - alpha * ap
this.axpyFast(-alpha, ap, r);
const rsnew = this.dotProductFast(r, r);
const beta = rsnew / rsold;
// p = r + beta * p
for (let i = 0; i < n; i++) {
p[i] = r[i] + beta * p[i];
}
rsold = rsnew;
}
return Array.from(x);
}
/**
* Fast dot product with manual unrolling
*/
dotProductFast(x, y) {
const n = x.length;
const chunks = Math.floor(n / 4);
const remainder = n % 4;
let sum = 0.0;
// Process 4 elements at a time
let i = 0;
for (let chunk = 0; chunk < chunks; chunk++) {
sum += x[i] * y[i] +
x[i + 1] * y[i + 1] +
x[i + 2] * y[i + 2] +
x[i + 3] * y[i + 3];
i += 4;
}
// Handle remainder
for (let j = 0; j < remainder; j++) {
sum += x[i] * y[i];
i++;
}
return sum;
}
/**
* Fast AXPY operation: y = alpha * x + y
*/
axpyFast(alpha, x, y) {
const n = x.length;
const chunks = Math.floor(n / 4);
const remainder = n % 4;
// Process 4 elements at a time
let i = 0;
for (let chunk = 0; chunk < chunks; chunk++) {
y[i] += alpha * x[i];
y[i + 1] += alpha * x[i + 1];
y[i + 2] += alpha * x[i + 2];
y[i + 3] += alpha * x[i + 3];
i += 4;
}
// Handle remainder
for (let j = 0; j < remainder; j++) {
y[i] += alpha * x[i];
i++;
}
}
}
/**
* Memory-efficient buffer pool for vector reuse
*/
class VectorPool {
constructor(size, capacity = 8) {
this.size = size;
this.buffers = [];
// Pre-allocate buffers
for (let i = 0; i < capacity; i++) {
this.buffers.push(new Float64Array(size));
}
}
getBuffer() {
return this.buffers.pop() || new Float64Array(this.size);
}
returnBuffer(buffer) {
if (buffer.length === this.size && this.buffers.length < 8) {
buffer.fill(0.0);
this.buffers.push(buffer);
}
}
}
/**
* WASM-ready solver interface
* Prepares for WASM integration when the module becomes available
*/
class FastSolver {
constructor(config = {}) {
this.maxIterations = config.maxIterations || 1000;
this.tolerance = config.tolerance || 1e-10;
this.useWasm = config.useWasm && this.isWasmAvailable();
this.vectorPool = null;
}
isWasmAvailable() {
// Check if WASM module is loaded
try {
return typeof WebAssembly !== 'undefined' &&
global.wasmSolver !== undefined;
} catch (e) {
return false;
}
}
/**
* Create optimized sparse matrix from triplets
*/
createMatrix(triplets, rows, cols) {
if (this.useWasm) {
// Use WASM implementation when available
return this.createWasmMatrix(triplets, rows, cols);
} else {
// Use fast JavaScript implementation
return FastCSRMatrix.fromTriplets(triplets, rows, cols);
}
}
createWasmMatrix(triplets, rows, cols) {
// Placeholder for WASM integration
// This would call the actual WASM module
console.log('WASM matrix creation not yet implemented, falling back to JS');
return FastCSRMatrix.fromTriplets(triplets, rows, cols);
}
/**
* Solve linear system with optimal method selection
*/
solve(matrix, b, options = {}) {
const startTime = process.hrtime.bigint();
// Initialize vector pool if needed
if (!this.vectorPool) {
this.vectorPool = new VectorPool(matrix.rows);
}
let result;
if (this.useWasm) {
result = this.solveWasm(matrix, b, options);
} else {
const solver = new FastConjugateGradient(this.maxIterations, this.tolerance);
result = solver.solve(matrix, b);
}
const endTime = process.hrtime.bigint();
const executionTime = Number(endTime - startTime) / 1e6; // Convert to milliseconds
return {
solution: result,
executionTime,
iterations: this.lastIterations || 0,
method: this.useWasm ? 'wasm' : 'javascript'
};
}
solveWasm(matrix, b, options) {
// Placeholder for WASM solver integration
console.log('WASM solver not yet implemented, falling back to JS');
const solver = new FastConjugateGradient(this.maxIterations, this.tolerance);
return solver.solve(matrix, b);
}
/**
* Generate test matrices for benchmarking
*/
generateTestMatrix(size, sparsity = 0.01) {
const triplets = [];
// Generate diagonally dominant sparse matrix
for (let i = 0; i < size; i++) {
// Diagonal element (make it dominant)
const diagVal = 5.0 + i * 0.01;
triplets.push([i, i, diagVal]);
// Off-diagonal elements
const nnzPerRow = Math.max(1, Math.floor(size * sparsity));
for (let k = 0; k < Math.min(nnzPerRow, 5); k++) {
const j = Math.floor(Math.random() * size);
if (i !== j) {
const val = Math.random() * 0.5; // Keep small for diagonal dominance
triplets.push([i, j, val]);
}
}
}
const matrix = this.createMatrix(triplets, size, size);
const b = new Array(size).fill(1.0); // Simple right-hand side
return { matrix, b };
}
/**
* Benchmark against Python baseline
* Target: beat 40ms for 1000x1000 matrices (Python baseline)
*/
benchmark(sizes = [100, 1000]) {
console.log('🚀 Fast Solver Benchmark - Targeting Python performance');
console.log('=' * 60);
const results = [];
for (const size of sizes) {
console.log(`\n📊 Testing ${size}x${size} matrix...`);
const { matrix, b } = this.generateTestMatrix(size, 0.001);
// Warm up
this.solve(matrix, b);
// Benchmark
const startTime = process.hrtime.bigint();
const result = this.solve(matrix, b);
const endTime = process.hrtime.bigint();
const timeMs = Number(endTime - startTime) / 1e6;
// Python baseline times (from performance analysis)
const pythonBaseline = size === 100 ? 2 : (size === 1000 ? 40 : 200);
const speedup = pythonBaseline / timeMs;
const status = speedup > 1 ? '✅ FASTER' : '❌ SLOWER';
console.log(` Time: ${timeMs.toFixed(2)}ms`);
console.log(` Python baseline: ${pythonBaseline}ms`);
console.log(` Speedup: ${speedup.toFixed(2)}x ${status}`);
console.log(` NNZ: ${matrix.nnz}`);
console.log(` Method: ${result.method}`);
results.push({
size,
timeMs,
pythonBaseline,
speedup,
nnz: matrix.nnz,
method: result.method
});
}
return results;
}
}
export {
FastCSRMatrix,
FastConjugateGradient,
VectorPool,
FastSolver
};