mirror of https://github.com/maderix/ANE.git
Fix dashboard text generation: add KV cache for proper autoregressive attention
This commit is contained in:
parent
19da850fca
commit
06535fc5be
|
|
@ -162,12 +162,17 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
freq = 1.0 / (10000.0 ** (2.0 * i / HD))
|
||||||
freqs[pos, i] = pos * freq
|
freqs[pos, i] = pos * freq
|
||||||
|
|
||||||
|
# KV cache: per-layer, per-head arrays
|
||||||
|
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||||
|
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||||
|
|
||||||
for step in range(max_tokens):
|
for step in range(max_tokens):
|
||||||
seq_len = len(tokens)
|
seq_len = len(tokens)
|
||||||
if seq_len > SEQ:
|
if seq_len > SEQ:
|
||||||
break
|
break
|
||||||
|
|
||||||
x = W['embed'][tokens[-1]].copy()
|
x = W['embed'][tokens[-1]].copy()
|
||||||
|
pos = seq_len - 1
|
||||||
|
|
||||||
for L in range(NLAYERS):
|
for L in range(NLAYERS):
|
||||||
# RMSNorm + QKV
|
# RMSNorm + QKV
|
||||||
|
|
@ -177,7 +182,6 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
v = W[f'Wv{L}'] @ xn
|
v = W[f'Wv{L}'] @ xn
|
||||||
|
|
||||||
# RoPE
|
# RoPE
|
||||||
pos = seq_len - 1
|
|
||||||
for h in range(HEADS):
|
for h in range(HEADS):
|
||||||
for i in range(HD // 2):
|
for i in range(HD // 2):
|
||||||
freq = freqs[pos, i]
|
freq = freqs[pos, i]
|
||||||
|
|
@ -189,14 +193,18 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
||||||
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
|
||||||
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
|
||||||
|
|
||||||
# Attention (single token)
|
# Append to KV cache and compute attention
|
||||||
o = np.zeros(DIM, dtype=np.float32)
|
o = np.zeros(DIM, dtype=np.float32)
|
||||||
for h in range(HEADS):
|
for h in range(HEADS):
|
||||||
qh = q[h * HD:(h + 1) * HD]
|
qh = q[h * HD:(h + 1) * HD]
|
||||||
kh = k[h * HD:(h + 1) * HD]
|
kh = k[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||||
vh = v[h * HD:(h + 1) * HD]
|
vh = v[h * HD:(h + 1) * HD].reshape(1, HD)
|
||||||
score = np.dot(qh, kh) / math.sqrt(HD)
|
k_cache[L][h] = np.vstack([k_cache[L][h], kh])
|
||||||
o[h * HD:(h + 1) * HD] = vh
|
v_cache[L][h] = np.vstack([v_cache[L][h], vh])
|
||||||
|
# scores: (1, HD) @ (HD, seq_len) -> (seq_len,)
|
||||||
|
scores = k_cache[L][h] @ qh / math.sqrt(HD)
|
||||||
|
attn = softmax(scores)
|
||||||
|
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
|
||||||
|
|
||||||
# Residual + output projection
|
# Residual + output projection
|
||||||
x2 = x + W[f'Wo{L}'] @ o
|
x2 = x + W[f'Wo{L}'] @ o
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue