ANE Training -- System Architecture

Training neural networks directly on Apple's Neural Engine via reverse-engineered private APIs (_ANEClient, _ANECompiler). No CoreML training APIs, no Metal, no GPU.

Project Structure

ANE/
+-- api_exploration.m          # ANE private API discovery
+-- inmem_basic.m              # In-memory MIL compilation proof-of-concept
+-- inmem_bench.m              # ANE dispatch latency across model sizes
+-- inmem_peak.m               # Peak TFLOPS via deep conv chains (self-contained)
+-- sram_bench.m               # SRAM capacity probing (performance cliff detection)
+-- sram_probe.m               # Fine-grained SRAM size exploration
+-- bridge/
|   +-- ane_bridge.h           # C-callable API for Python ctypes
|   +-- ane_bridge.m           # Bridge implementation
|   +-- Makefile               # Builds libane_bridge.dylib
|   +-- libane_bridge.dylib    # Pre-built shared library
+-- training/
|   +-- train_large.m          # Main: 12-layer training (CPU classifier)
|   +-- train_large_ane.m      # Variant: classifier + softmax on ANE
|   +-- stories_config.h       # Model constants, structs, alloc helpers
|   +-- stories_io.h           # IOSurface I/O, NEON fp16, compile/run
|   +-- stories_mil.h          # MIL generators for 6 fused ANE kernels
|   +-- stories_cpu_ops.h      # vDSP RMSNorm, cross-entropy, Adam, embedding
|   +-- ane_runtime.h          # Generalized ANE wrapper (multi-I/O)
|   +-- ane_mil_gen.h          # Composable MIL helpers (conv, matmul, fused QKV)
|   +-- ane_rmsnorm_bwd.h      # RMSNorm backward MIL (train_large_ane only)
|   +-- ane_classifier.h       # Classifier/softmax MIL (train_large_ane only)
|   +-- forward.h              # Gen1 forward pass (per-linear-kernel, all-CPU)
|   +-- backward.h             # Gen1 backward pass (all-CPU reference)
|   +-- model.h                # Gen1 Model struct, per-kernel compile
|   +-- dashboard.py           # TUI monitoring (loss, power, text generation)
|   +-- tokenize.py            # Extract pretokenized TinyStories data
|   +-- download_data.sh       # Download TinyStories from HuggingFace
|   +-- Makefile               # Build targets for training + tests
|   +-- test_*.m               # 12 unit test files
+-- docs/                      # This documentation
+-- scripts/                   # Automation scripts

Two Generations of Training Code

Gen1: model.h + forward.h + backward.h

The original correctness reference. It compiles one ANE kernel per linear projection (7 per layer x 12 layers + 1 classifier = 85 kernels total). Forward and backward passes run sequentially on the CPU, with the matmuls optionally dispatched to ANE. There is no kernel fusion and no async overlap. It is used to verify that Gen2's fused kernels produce correct results.

Gen2: train_large.m + stories_*.h (production)

The performance-optimized system. Uses 6 fused ANE kernel types per layer (5 weight-bearing, 1 weight-free), each performing multiple operations in a single dispatch. Weight gradients (dW) run asynchronously on CPU via GCD to overlap with ANE. All data is channel-first [C, S] fp16 on IOSurfaces.

The rest of this document describes Gen2.


Model Configuration

Stories110M -- a Llama2-architecture transformer:

| Parameter | Value | Macro |
|---|---|---|
| Hidden dimension | 768 | DIM |
| FFN intermediate | 2048 | HIDDEN |
| Attention heads | 12 | HEADS |
| Head dimension | 64 | HD |
| Sequence length | 256 | SEQ |
| Layers | 12 | NLAYERS |
| Vocabulary | 32000 | VOCAB |
| Total parameters | 109.53M | TOTAL_PARAMS |
| Accumulation steps | 10 | ACCUM_STEPS |
| Max ANE compiles | 100 | MAX_COMPILES |
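The total parameter count can be reproduced from the other constants in the table. The sketch below is illustrative only: the function name is hypothetical, and it assumes the classifier shares the embedding matrix (as the swim-lane diagram's use of embed for the logits suggests), two RMSNorm weight vectors per layer, and one final RMSNorm vector.

```c
#include <assert.h>
#include <stdint.h>

/* Constants copied from the configuration table. */
enum { DIM = 768, HIDDEN = 2048, NLAYERS = 12, VOCAB = 32000 };

/* Hypothetical helper: counts parameters under the assumptions stated
 * in the lead-in (tied embedding/classifier, 2 norms per layer + 1 final). */
int64_t stories_param_count(void) {
    int64_t attn      = 4LL * DIM * DIM;      /* Wq, Wk, Wv, Wo */
    int64_t ffn       = 3LL * DIM * HIDDEN;   /* W1, W2, W3 */
    int64_t norms     = 2LL * DIM;            /* rms_att, rms_ffn */
    int64_t per_layer = attn + ffn + norms;
    int64_t embed     = (int64_t)VOCAB * DIM; /* shared with the classifier */
    int64_t total     = NLAYERS * per_layer + embed + DIM; /* + final norm */
    assert(total > 0);
    return total;
}
```

Under these assumptions the function returns 109,529,856, which rounds to the 109.53M in the table.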

ANE Kernel Fusion Map

Each training step dispatches 6 kernel types per layer. Five are weight-bearing and are recompiled each batch; one (sdpaBwd2, which bakes in no weights) is weight-free and is compiled once.

| Kernel | Generator | Fused Operations | Baked Weights | Input Shape | Output Shape |
|---|---|---|---|---|---|
| fwdAttn | gen_sdpa_fwd_taps() | RMSNorm1, Wq/Wk/Wv conv, reshape, transpose, Q @ K^T matmul, scale, causal mask, softmax, scores @ V matmul, Wo conv | rms_att, Wq, Wk, Wv, Wo, mask | [1,DIM,1,SEQ] | [1,6*DIM,1,SEQ] |
| fwdFFN | gen_ffn_fwd_taps() | RMSNorm2, W1/W3 conv, sigmoid, SiLU gating, W2 conv | rms_ffn, W1, W3, W2 | [1,DIM,1,SEQ] | [1,2D+3H,1,SEQ] |
| ffnBwd | gen_ffn_bwd() | W2^T conv, SiLU derivative, W1^T/W3^T conv, add | W2^T, W1^T, W3^T | [1,D+2H,1,SEQ] | [1,D+2H,1,SEQ] |
| sdpaBwd1 | gen_sdpa_bwd1() | Wo^T conv, reshape, Q @ K^T recompute, softmax, dV matmul, dP matmul | Wo^T, mask | [1,4*DIM,1,SEQ] | [1,D+2*SC,1,SEQ] |
| sdpaBwd2 | gen_sdpa_bwd2() | softmax Jacobian, scale, dQ=dS @ K matmul, dK=dS^T @ Q matmul | (none) | [1,2SC+2D,1,SEQ] | [1,2*DIM,1,SEQ] |
| qkvBwd | gen_qkvb() | Wq^T/Wk^T/Wv^T conv, sum | Wq^T, Wk^T, Wv^T | [1,3*DIM,1,SEQ] | [1,DIM,1,SEQ] |

Where D=DIM=768, H=HIDDEN=2048, SC=SCORE_CH=HEADS*SEQ=3072.
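To make the shape column concrete, the following sketch tabulates the input and output channel counts of the six kernels and sizes the matching fp16 buffer. The struct, array, and function names are hypothetical, and the byte count assumes tightly packed fp16 data with no row padding, which the source does not confirm for IOSurfaces.

```c
#include <assert.h>

enum { DIM = 768, HIDDEN = 2048, HEADS = 12, SEQ = 256 };
enum { SCORE_CH = HEADS * SEQ };   /* 3072 */

/* Channel counts of a kernel's [1, C, 1, SEQ] input and output. */
typedef struct { const char *name; int in_ch; int out_ch; } KernelIO;

const KernelIO KERNEL_IO[6] = {
    { "fwdAttn",  DIM,                    6 * DIM              },
    { "fwdFFN",   DIM,                    2 * DIM + 3 * HIDDEN },
    { "ffnBwd",   DIM + 2 * HIDDEN,       DIM + 2 * HIDDEN     },
    { "sdpaBwd1", 4 * DIM,                DIM + 2 * SCORE_CH   },
    { "sdpaBwd2", 2 * SCORE_CH + 2 * DIM, 2 * DIM              },
    { "qkvBwd",   3 * DIM,                DIM                  },
};

/* Bytes of fp16 data in a [1, channels, 1, SEQ] tensor (2 bytes each),
 * assuming no padding between channels. */
long surface_bytes(int channels) {
    assert(channels > 0);
    return (long)channels * SEQ * 2;
}
```

For example, the fwdAttn output has 4608 channels, so its packed fp16 payload would be 2,359,296 bytes under the no-padding assumption.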

"Taps" in forward kernels: intermediate values (Q, K, V, attention output, norms) are concatenated onto the output via concat(axis=1) so backward kernels can read them without CPU recomputation.


CPU vs ANE Operation Split

| Operation | Location | Reason |
|---|---|---|
| Embedding lookup/backward | CPU | Scatter/gather by token index |
| RMSNorm forward | ANE | Fused into fwdAttn/fwdFFN kernels |
| QKV projections | ANE | 1x1 conv = matmul |
| Multi-head attention (SDPA) | ANE | Decomposed Q @ K^T + mask + softmax + scores @ V |
| FFN (SwiGLU) | ANE | W1,W3 conv + sigmoid + gate + W2 conv |
| Residual connections | CPU | Simple vDSP_vadd |
| Final RMSNorm | CPU (or ANE in _ane variant) | Standalone, not fused with other ops |
| Classifier matmul | CPU cblas (or ANE in _ane variant) | [VOCAB,DIM] x [DIM,SEQ] |
| Cross-entropy + softmax | CPU (partially ANE in _ane) | Target indexing requires CPU |
| dW weight gradients | CPU (async cblas) | Outer products, independent of backward data flow |
| RMSNorm backward | CPU (or ANE in _ane variant) | vDSP vectorized |
| Adam optimizer | CPU | In-place weight mutation |

Training Step Swim-Lane Diagram

One complete training step showing CPU, ANE, and async GCD operations interleaved:

sequenceDiagram
    participant CPU
    participant ANE
    participant GCD as GCD Async Queue

    Note over CPU: FORWARD PASS (per layer L=0..11)

    CPU->>CPU: embed_lookup(tokens to x_cur)

    loop Layer L = 0..11
        CPU->>CPU: wait for prior async dW
        CPU->>CPU: save layer_in, write fp16 to IOSurface
        CPU->>ANE: run fwdAttn kernel
        ANE-->>CPU: concat(o_out, Q, K, V, attn_out, xnorm)
        CPU->>CPU: read fp16 taps, residual add to x2

        CPU->>CPU: write fp16 x2 to IOSurface
        CPU->>ANE: run fwdFFN kernel
        ANE-->>CPU: concat(ffn_out, h1, h3, silu_out, x2norm)
        CPU->>CPU: read fp16 taps, residual add to x_cur
    end

    Note over CPU: CLASSIFIER + LOSS
    CPU->>CPU: rmsnorm(x_cur to x_final)
    CPU->>CPU: cblas_sgemm(embed x x_final to logits)
    CPU->>CPU: cross_entropy_loss(logits to loss, dlogits)

    Note over CPU: BACKWARD PASS
    CPU->>CPU: cblas_sgemm(embed^T x dlogits to dy)
    CPU->>GCD: async dEmbed += dlogits x x_final^T
    CPU->>CPU: rmsnorm_bwd(dy to dx)

    loop Layer L = 11..0
        Note over CPU,GCD: FFN Backward
        CPU->>CPU: write dffn + copy h1,h3 from fwd taps
        CPU->>ANE: run ffnBwd kernel
        ANE-->>CPU: concat(dx_ffn, dh1, dh3)
        CPU->>GCD: async dW2, dW1, dW3 accumulation

        Note over CPU,GCD: RMSNorm2 Backward + Residual
        CPU->>CPU: rmsnorm_bwd, add residual gradient

        Note over CPU,GCD: SDPA Backward
        CPU->>GCD: async dWo accumulation
        CPU->>CPU: copy Q,K,V from fwd taps, write dx2
        CPU->>ANE: run sdpaBwd1 kernel
        ANE-->>CPU: concat(dV, probs, dP)

        CPU->>CPU: copy probs,dP,Q,K
        CPU->>ANE: run sdpaBwd2 kernel
        ANE-->>CPU: concat(dQ, dK)

        CPU->>GCD: async dWq, dWk, dWv accumulation

        Note over CPU,GCD: QKV Backward
        CPU->>CPU: copy dQ,dK,dV
        CPU->>ANE: run qkvBwd kernel
        ANE-->>CPU: dx_attn

        Note over CPU,GCD: RMSNorm1 Backward + Residual
        CPU->>CPU: rmsnorm_bwd, add both skip gradients
    end

    CPU->>CPU: dispatch_group_wait(all async dW)
    CPU->>CPU: embed_backward(dy to d_embed)

Async CPU/ANE Overlap Strategy

The key insight: weight gradients (dW) are independent of the backward data flow. Each one is a matmul dW += dy @ x^T (a sum of per-position outer products) that only accumulates into a gradient buffer, so no later backward kernel depends on its result. The data-path gradients (dx) flow backward through the network on ANE.

Timeline for one backward layer:
  ANE:  [ffnBwd]    [sdpaBwd1]   [sdpaBwd2]   [qkvBwd]
  CPU:       [dW_FFN (3x sgemm)]    [dWo]    [dWqkv (3x sgemm)]

A serial GCD dispatch queue ("dw_cblas") ensures that dW operations do not overlap each other, since they share scratch buffers. The dispatch_group_wait at the start of each forward layer ensures that async dW work from the previous step's backward pass has finished before the IOSurfaces are reused.
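The accumulation that each async dW task performs can be written as a plain loop. This is a sketch of the math only, with a hypothetical function name: the text above indicates the real code uses cblas sgemm calls on the GCD queue, and the channel-first [C, S] float layout here is an assumption about how the CPU-side buffers are arranged.

```c
#include <assert.h>
#include <stddef.h>

/* dW[o][i] += sum over s of dy[o][s] * x[i][s]
 *   dy: [c_out, seq] upstream gradient, channel-first
 *   x:  [c_in,  seq] saved layer input, channel-first
 *   dW: [c_out, c_in] gradient accumulator, row-major
 * Nothing is written except dW, which is why the task can run
 * concurrently with the data-path backward kernels. */
void accum_dw(float *dW, const float *dy, const float *x,
              size_t c_out, size_t c_in, size_t seq) {
    assert(dW != NULL && dy != NULL && x != NULL);
    for (size_t o = 0; o < c_out; o++) {
        for (size_t i = 0; i < c_in; i++) {
            float acc = 0.0f;
            for (size_t s = 0; s < seq; s++)
                acc += dy[o * seq + s] * x[i * seq + s];
            dW[o * c_in + i] += acc;   /* accumulate, never overwrite */
        }
    }
}
```

Because the function only adds into dW, calling it once per accumulation step leaves the summed gradient in the buffer, ready to be scaled by 1/ACCUM_STEPS before the Adam update.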


Compile/Restart Lifecycle

The ANE runtime leaks resources internally, limiting compiles to ~119 per process. The system manages this with checkpoint-and-restart:

flowchart TD
    Start["Process starts (fresh or --resume)"] --> LoadCkpt{"--resume flag?"}
    LoadCkpt -->|Yes| Resume["Load checkpoint: weights, Adam state, step counter"]
    LoadCkpt -->|No| Init["Xavier init weights, zero Adam state"]
    Resume --> CompileCheck
    Init --> CompileCheck

    CompileCheck{"g_compile_count + 60 > MAX_COMPILES?"} -->|Yes| SaveCheckpoint["Save checkpoint to ane_stories110M_ckpt.bin"]
    SaveCheckpoint --> FreeAll["Free all ANE kernels"]
    FreeAll --> RestartProcess["Re-launch process with --resume flag"]
    RestartProcess --> Start

    CompileCheck -->|No| Compile["Compile 60 weight-bearing kernels (5 per layer x 12)"]
    Compile --> ZeroGrads["Zero gradient accumulators"]
    ZeroGrads --> AccumLoop

    subgraph AccumLoop ["Gradient Accumulation (10 steps)"]
        SingleStep["Forward + Backward + async dW"] --> MoreSteps{"More accum steps?"}
        MoreSteps -->|Yes| SingleStep
    end

    MoreSteps -->|No| WaitDW["dispatch_group_wait (all async dW)"]
    WaitDW --> ScaleGrad["Scale gradients by 1/ACCUM_STEPS"]
    ScaleGrad --> AdamUpdate["Adam update (mutates weights in-place)"]
    AdamUpdate --> FreeKernels["Free all weight-bearing kernels"]
    FreeKernels --> CompileCheck

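With MAX_COMPILES=100 and 60 weight-bearing kernels per batch, only 1 batch (10 accumulation steps) fits per process lifetime: the first compile check passes (0 + 60 <= 100), and the second fails (60 + 60 > 100) and triggers a restart. The checkpoint preserves: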

  • Training step and total_steps
  • All weights and Adam (m, v) state per layer
  • Cumulative timing statistics
  • Adam timestep counter
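The restart decision in the flowchart reduces to a small amount of arithmetic. The sketch below mirrors the CompileCheck condition; the function names are hypothetical, and it assumes only the per-batch weight-bearing kernels count against the budget (compiles of weight-free kernels are ignored).

```c
#include <assert.h>

/* Same condition as the CompileCheck node: restart when compiling the
 * next batch of kernels would push the count past the cap. */
int needs_restart(int compile_count, int kernels_per_batch, int max_compiles) {
    return compile_count + kernels_per_batch > max_compiles;
}

/* Number of batches a fresh process can compile before it must
 * checkpoint and re-launch. */
int batches_per_process(int kernels_per_batch, int max_compiles) {
    assert(kernels_per_batch > 0);
    int count = 0;
    int batches = 0;
    while (!needs_restart(count, kernels_per_batch, max_compiles)) {
        count += kernels_per_batch;
        batches++;
    }
    return batches;
}
```

With 60 kernels per batch and a cap of 100 this yields one batch per process; the cap would have to reach 120 before a second batch fits.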

Data Flow Through One Layer

Tensor shapes as they flow through forward and backward passes:

flowchart LR
    subgraph fwdAttnKernel ["fwdAttn Kernel (ANE)"]
        xIn["x_in\n[1,768,1,256]"] --> RMS1["RMSNorm1"]
        RMS1 --> QKVConv["Wq,Wk,Wv conv\n[768,768,1,1]"]
        QKVConv --> ReshapeHeads["reshape\n[1,12,64,256]"]
        ReshapeHeads --> TransposeHeads["transpose\n[1,12,256,64]"]
        TransposeHeads --> QKT["Q x K^T\n[1,12,256,256]"]
        QKT --> ScaleMask["scale + mask\n+ softmax"]
        ScaleMask --> AV["scores x V\n[1,12,256,64]"]
        AV --> ReshapeBackFlat["reshape\n[1,768,1,256]"]
        ReshapeBackFlat --> WoConv["Wo conv\n[768,768,1,1]"]
    end

    subgraph taps1 ["Taps via concat"]
        WoConv --> T1["o_out [768]"]
        QKVConv --> T2["Q,K,V [768 each]"]
        AV --> T3["attn_out [768]"]
        RMS1 --> T4["xnorm [768]"]
    end

    subgraph cpuResid1 ["CPU"]
        T1 --> ResAdd1["x + o_out = x2"]
    end

    subgraph fwdFFNKernel ["fwdFFN Kernel (ANE)"]
        ResAdd1 --> RMS2["RMSNorm2"]
        RMS2 --> W1W3["W1,W3 conv\n[2048,768,1,1]"]
        W1W3 --> SiLUGate["sigmoid + SiLU\n+ gating"]
        SiLUGate --> W2Conv["W2 conv\n[768,2048,1,1]"]
    end

    subgraph taps2 ["Taps via concat"]
        W2Conv --> T5["ffn_out [768]"]
        W1W3 --> T6["h1,h3 [2048 each]"]
        SiLUGate --> T7["silu_out [2048]"]
        RMS2 --> T8["x2norm [768]"]
    end

    subgraph cpuResid2 ["CPU"]
        T5 --> ResAdd2["x2 + ffn_out = x_next"]
    end

IOSurface Memory Layout

All tensors use channel-first [1, C, 1, S] fp16 layout on IOSurfaces, matching ANE's native format:

IOSurface memory (contiguous fp16):
  channel_0:   [pos_0, pos_1, ..., pos_255]   (256 values)
  channel_1:   [pos_0, pos_1, ..., pos_255]
  ...
  channel_767: [pos_0, pos_1, ..., pos_255]
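In this layout, the element for channel c at position s sits at index c * S + s. The sketch below shows that index rule and a transpose from a token-major [S, C] buffer; both function names are hypothetical, and it assumes the data is tightly packed with no per-row padding.

```c
#include <assert.h>
#include <stddef.h>

/* Element index of (channel c, position s) in a [1, C, 1, S] tensor. */
size_t cf_index(size_t c, size_t s, size_t seq) {
    assert(s < seq);
    return c * seq + s;
}

/* Transpose token-major [S, C] (one row per token) into
 * channel-first [C, S] (one row per channel). */
void to_channel_first(const float *src, float *dst,
                      size_t seq, size_t channels) {
    assert(src != NULL && dst != NULL);
    for (size_t s = 0; s < seq; s++)
        for (size_t c = 0; c < channels; c++)
            dst[cf_index(c, s, seq)] = src[s * channels + c];
}
```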

Fused kernel outputs use concat(axis=1) to pack multiple tensors into a single IOSurface:

fwdAttn output [1, 6*768, 1, 256]:
  channels    0-767:  o_out (Wo projection output)
  channels  768-1535: Q (query projection)
  channels 1536-2303: K (key projection)
  channels 2304-3071: V (value projection)
  channels 3072-3839: attn_out (pre-Wo attention output)
  channels 3840-4607: xnorm (RMSNorm1 output)

CPU reads specific taps via io_read_fp16(surface, data, ch_offset, n_channels, spatial).
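As a rough illustration of what a tap read involves, the sketch below converts a channel range from a raw fp16 buffer into floats. It is not the project's io_read_fp16: read_tap, attn_tap_channel, and the AttnTap enum are hypothetical, the buffer is a plain pointer rather than an IOSurface, and the conversion is portable scalar code where the source mentions NEON.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

enum { DIM = 768 };

/* Tap order in the fwdAttn output, as listed in the layout above. */
typedef enum { TAP_O_OUT, TAP_Q, TAP_K, TAP_V, TAP_ATTN_OUT, TAP_XNORM } AttnTap;

/* First channel of a tap: each tap is DIM channels wide. */
size_t attn_tap_channel(AttnTap tap) { return (size_t)tap * DIM; }

/* Convert one IEEE 754 binary16 value to float. */
float half_to_float(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t man  = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0 && man == 0) {
        bits = sign;                                 /* signed zero */
    } else if (exp == 0) {                           /* subnormal */
        int shift = 0;
        while (!(man & 0x400u)) { man <<= 1; shift++; }
        man &= 0x3FFu;
        bits = sign | ((uint32_t)(113 - shift) << 23) | (man << 13);
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (man << 13);     /* inf or NaN */
    } else {
        bits = sign | ((exp + 112u) << 23) | (man << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Copy n_channels rows of `spatial` fp16 values, starting at channel
 * ch_offset, from a channel-first buffer into a float array. */
void read_tap(const uint16_t *surface, float *dst,
              size_t ch_offset, size_t n_channels, size_t spatial) {
    assert(surface != NULL && dst != NULL);
    const uint16_t *src = surface + ch_offset * spatial;
    for (size_t i = 0; i < n_channels * spatial; i++)
        dst[i] = half_to_float(src[i]);
}
```

Reading V from the fwdAttn output would then start at attn_tap_channel(TAP_V), which is channel 2304, and cover DIM channels.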


Weight Blob Format

ANE weight blobs follow a binary format with a 128-byte header:

Offset  Size   Content
------  -----  -------
0       1      0x01 (format marker)
4       1      0x02 (format marker)
5-63    59     zeros (padding)
64      4      0xDEADBEEF (chunk magic, little-endian)
68      1      0x01 (chunk marker)
72      4      uint32 data_size (fp16 weight bytes)
80      4      uint32 data_offset (always 128)
84-127  44     zeros (padding)
128+    N      fp16 weight data, row-major [out_ch, in_ch]
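A header with this layout could be filled in as follows. This is a sketch with hypothetical function names: it assumes the bytes the table does not list (1-3, 69-71, 76-79) are zero, and it reads "little-endian" as meaning the 32-bit fields are stored least-significant byte first.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

enum { BLOB_HEADER_BYTES = 128 };

/* Store a uint32 least-significant byte first, independent of host order. */
void put_u32le(uint8_t *p, uint32_t v) {
    p[0] = (uint8_t)(v & 0xFFu);
    p[1] = (uint8_t)((v >> 8) & 0xFFu);
    p[2] = (uint8_t)((v >> 16) & 0xFFu);
    p[3] = (uint8_t)((v >> 24) & 0xFFu);
}

/* Fill the 128-byte header of a single-chunk blob whose payload is
 * data_size bytes of fp16 weights. Offsets follow the table above. */
void blob_write_header(uint8_t *hdr, uint32_t data_size) {
    assert(hdr != NULL);
    memset(hdr, 0, BLOB_HEADER_BYTES);        /* padding and unlisted bytes */
    hdr[0] = 0x01;                            /* format marker */
    hdr[4] = 0x02;                            /* format marker */
    put_u32le(hdr + 64, 0xDEADBEEFu);         /* chunk magic */
    hdr[68] = 0x01;                           /* chunk marker */
    put_u32le(hdr + 72, data_size);           /* fp16 weight bytes */
    put_u32le(hdr + 80, BLOB_HEADER_BYTES);   /* data_offset, always 128 */
}
```

For a 768 x 768 fp16 matrix, data_size would be 768 * 768 * 2 = 1,179,648 bytes.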

Multi-weight blobs (fused QKV, FFN up) concatenate chunks: [64B global header] [64B chunk0 header] [chunk0 data] [64B chunk1 header] [chunk1 data] ...

MIL programs reference weights via BLOBFILE(path="@model_path/weights/name.bin", offset=uint64(64)) where offset 64 points to the chunk header within the file.
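For a multi-weight blob, the BLOBFILE offset of each weight follows from the chunk sizes. The sketch below computes it under the layout described above; the function name is hypothetical, and it assumes no alignment padding between one chunk's data and the next chunk's header.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

enum { GLOBAL_HEADER_BYTES = 64, CHUNK_HEADER_BYTES = 64 };

/* File offset of chunk k's header, given the fp16 payload size (in
 * bytes) of every chunk that precedes it. Chunk 0 is always at 64. */
uint64_t chunk_header_offset(const uint32_t *data_sizes, size_t k) {
    assert(data_sizes != NULL || k == 0);
    uint64_t off = GLOBAL_HEADER_BYTES;
    for (size_t i = 0; i < k; i++)
        off += CHUNK_HEADER_BYTES + (uint64_t)data_sizes[i];
    return off;
}
```

For a fused QKV blob with three 1,179,648-byte chunks, this gives header offsets 64, 1,179,776, and 2,359,488.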


Key Constraints

| Constraint | Impact | Workaround |
|---|---|---|
| ~119 compile limit per process | ANE compiler leaks resources | checkpoint + re-launch with --resume |
| Weights baked at compile time | Cannot hot-swap weights; must recompile | Gradient accumulation amortizes compile cost |
| SDPA ignores attn_mask | Causal attention cannot use native SDPA mask | Decompose into Q @ K^T + explicit mask + softmax + scores @ V |
| ANE SRAM capacity ~32 MB | Large weight matrices spill to DRAM | Performance cliff above ~3072 channels |
| 32000 input channels rejected | ANE refuses conv with VOCAB input channels | Classifier backward uses matmul op with reshape instead of conv |
| fp16 compute only | Precision limited on ANE | fp32 on CPU for loss, Adam; fp16 for ANE forward/backward |
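The explicit mask used in the decomposed attention can be pictured as an additive [SEQ, SEQ] matrix applied to the scores before softmax. The sketch below builds one; the function name is hypothetical, and the fill value (the most negative finite fp16 number) is an assumption, since the source does not state which value the baked mask uses.

```c
#include <assert.h>
#include <stddef.h>

/* Assumed fill value: most negative finite fp16 number. */
#define MASK_NEG (-65504.0f)

/* mask[i * seq + j] is added to the score of query i against key j.
 * Keys at or before the query position are left untouched (0);
 * future keys get a large negative value so softmax drives them to ~0. */
void build_causal_mask(float *mask, size_t seq) {
    assert(mask != NULL);
    for (size_t i = 0; i < seq; i++)
        for (size_t j = 0; j < seq; j++)
            mask[i * seq + j] = (j <= i) ? 0.0f : MASK_NEG;
}
```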

train_large.m vs train_large_ane.m

train_large_ane.m moves additional operations from CPU to ANE:

| Operation | train_large.m | train_large_ane.m |
|---|---|---|
| Final RMSNorm | CPU (rmsnorm() via vDSP) | ANE (gen_final_rmsnorm()) |
| Classifier forward | CPU (cblas_sgemm) | ANE (gen_classifier_fwd(), 32000-ch conv) |
| Softmax | CPU (inside cross_entropy_loss()) | ANE (gen_softmax_vocab()) |
| Per-layer RMSNorm backward | CPU (rmsnorm_bwd() via vDSP) | ANE (gen_rmsnorm_bwd()) |

This increases compile budget pressure: 86 weight-bearing kernels per batch (vs 60), leaving less headroom within MAX_COMPILES=100.