Commit Graph

3 Commits

Author SHA1 Message Date
maderix c3c5094865 Fixed the dynamic pipeline logit generation 2026-03-06 04:51:32 -08:00
maderix 926f977b40 Fix backward pass: global loss scaling, weight transpose, AdamW, activation clipping
Three bugs prevented loss from converging below 5.5 (unigram plateau):

1. FP16 underflow in ANE backward matmuls: gradient (~8e-5) × weight (~0.036)
   products flushed to zero in fp16. Fixed with global loss scaling (256×)
   applied once to dlogits, divided out before Adam update.

2. Backward weight staging used raw weights instead of transposed — all 4
   backward kernels (wotBwd, qkvBwd, ffnBwdW2t, ffnBwdW13t) now use
   pre-transposed buffers (Wot_buf, Wqt_buf, etc.).

3. Added AdamW (decoupled weight decay, wd=0.1 for weights, 0.0 for norms),
   activation clipping (act_clip=20), gradient clipping, cosine LR schedule,
   per-layer IOSurface weight pre-staging, and vocab compaction.

Loss now drops 9.14 → 5.74 in 500 steps from random init (87ms/step).
2026-03-05 07:23:08 -08:00
maderix cb474e1537 Add dynamic weight training pipeline — 110ms/step without recompilation
Dynamic weight pipeline that eliminates the ~3.7s recompile-every-10-steps
bottleneck. Weights are passed via IOSurface spatial dimension instead of
baked as constants, so kernels compile once at startup (345ms) and run
indefinitely without exec() restart.

Key components:
- training_dynamic/ — full pipeline (config, IO, MIL generators, train loop)
  - 9 dynamic kernels shared across all 12 layers
  - Vocab compaction 32K→9.2K for faster classifier
  - Vectorized cross-entropy with vDSP/NEON
  - Adam optimizer with gradient clipping + cosine LR schedule
  - Checkpoint save/resume

- test_dynamic_matmul.m — validates dynamic weight matmul vs cblas
- test_weight_patch.m — tests weight update via IOSurface

- dashboard.py — updated with --dynamic flag for v2 pipeline support,
  improved step regex parsing, --scratch/--lr/--accum CLI args

Performance: 110ms/step steady-state (no recompile overhead)
  ane_fwd=21 ane_bwd=28 io_fwd=12 io_bwd=15 silu=10 cls=13 rms=5 ms
2026-03-03 04:34:55 -08:00