mirror of https://github.com/maderix/ANE.git
Fix dashboard text generation: add KV cache for proper autoregressive attention
This commit is contained in:
parent
19da850fca
commit
06535fc5be
|
|
@ -162,12 +162,17 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
||||
freqs[pos, i] = pos * freq
|
||||
|
||||
# KV cache: per-layer, per-head arrays
|
||||
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
|
||||
for step in range(max_tokens):
|
||||
seq_len = len(tokens)
|
||||
if seq_len > SEQ:
|
||||
break
|
||||
|
||||
x = W['embed'][tokens[-1]].copy()
|
||||
pos = seq_len - 1
|
||||
|
||||
for L in range(NLAYERS):
|
||||
# RMSNorm + QKV
|
||||
|
|
@ -177,7 +182,6 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
v = W[f'Wv{L}'] @ xn
|
||||
|
||||
# RoPE
|
||||
pos = seq_len - 1
|
||||
for h in range(HEADS):
|
||||
for i in range(HD // 2):
|
||||
freq = freqs[pos, i]
|
||||
|
|
@ -189,14 +193,18 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
||||
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
||||
|
||||
# Attention (single token)
|
||||
# Append to KV cache and compute attention
|
||||
o = np.zeros(DIM, dtype=np.float32)
|
||||
for h in range(HEADS):
|
||||
qh = q[h * HD:(h + 1) * HD]
|
||||
kh = k[h * HD:(h + 1) * HD]
|
||||
vh = v[h * HD:(h + 1) * HD]
|
||||
score = np.dot(qh, kh) / math.sqrt(HD)
|
||||
o[h * HD:(h + 1) * HD] = vh
|
||||
kh = k[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||
vh = v[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||
k_cache[L][h] = np.vstack([k_cache[L][h], kh])
|
||||
v_cache[L][h] = np.vstack([v_cache[L][h], vh])
|
||||
# scores: (1, HD) @ (HD, seq_len) -> (seq_len,)
|
||||
scores = k_cache[L][h] @ qh / math.sqrt(HD)
|
||||
attn = softmax(scores)
|
||||
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
|
||||
|
||||
# Residual + output projection
|
||||
x2 = x + W[f'Wo{L}'] @ o
|
||||
|
|
|
|||
Loading…
Reference in New Issue