mirror of https://github.com/maderix/ANE.git
72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
import os
|
|
import json
|
|
from collections import Counter
|
|
|
|
# Minimal BPE trainer for TinyStories.
#
# Configuration is read once at import time. Each value can be overridden via
# an environment variable; the defaults reproduce the original (machine-specific)
# setup, so existing runs behave exactly as before.

# Raw TinyStories text the merges are learned from.
RAW_TEXT_PATH = os.environ.get(
    "BPE_RAW_TEXT_PATH",
    "/Users/andy.huang/lab/research/ANE/training/tinystories_raw.txt",
)

# Output JSON file holding the learned merges and vocabulary.
VOCAB_PATH = os.environ.get(
    "BPE_VOCAB_PATH",
    "/Users/andy.huang/lab/research/ANE/training/vocab.json",
)

# Target vocabulary size: 256 base byte tokens plus (VOCAB_SIZE - 256) merges.
# Kept small so a verification run finishes quickly.
VOCAB_SIZE = int(os.environ.get("BPE_VOCAB_SIZE", "5000"))

# Maximum number of characters read from RAW_TEXT_PATH. The file is read in
# text mode, so this equals ~200 KB only for ASCII text. Limited for speed.
SUBSET_SIZE = int(os.environ.get("BPE_SUBSET_SIZE", "200000"))
|
|
|
|
def get_stats(ids):
    """Count every adjacent (left, right) token pair in ``ids``.

    Overlapping pairs are all counted, so ``[a, a, a]`` yields ``(a, a)`` twice.

    Args:
        ids: Sliceable sequence of integer token ids.

    Returns:
        A ``Counter`` mapping each consecutive pair to its number of
        occurrences, with keys inserted in first-occurrence order. Empty when
        ``ids`` has fewer than two elements.
    """
    # Zipping the sequence with itself shifted by one yields each neighbour pair.
    return Counter(zip(ids, ids[1:]))
|
|
|
|
def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` becomes ``idx``.

    Matching is greedy and left-to-right without overlap: after a pair is
    replaced, scanning resumes past its second element, so ``[a, a, a]``
    merged on ``(a, a)`` becomes ``[idx, a]``. ``ids`` itself is not modified.

    Args:
        ids: Sequence of integer token ids.
        pair: The (left, right) token pair to replace.
        idx: Token id emitted in place of each matched pair.

    Returns:
        A new list of token ids.
    """
    merged = []
    last = len(ids) - 1
    skip_next = False
    for pos, tok in enumerate(ids):
        if skip_next:
            # This token was consumed as the right half of the previous match.
            skip_next = False
            continue
        if pos < last and tok == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            skip_next = True
        else:
            merged.append(tok)
    return merged
|
|
|
|
def _learn_merges(ids, num_merges):
    """Run up to ``num_merges`` greedy BPE merges starting from byte ids ``ids``.

    Returns ``(merges, vocab)``: ``merges`` maps each merged pair to its new id
    (in merge order); ``vocab`` maps every id to the bytes it decodes to.
    """
    merges = {}
    # Base vocabulary: one token per possible byte value.
    vocab = {i: bytes([i]) for i in range(256)}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break  # fewer than two tokens remain; nothing left to merge
        # max() returns the first maximal key it encounters, so ties go to the
        # pair inserted into ``stats`` earliest.
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
        if (i + 1) % 100 == 0:
            print(f"Merge {i+1}/{num_merges}: {pair} -> {idx} (count {stats[pair]})")
    return merges, vocab


def _save_vocab(vocab_path, merges, vocab):
    """Write ``merges`` and ``vocab`` to ``vocab_path`` as JSON, creating parent dirs."""
    # JSON object keys must be strings: encode pair (a, b) as "a,b".
    serializable_merges = {f"{p[0]},{p[1]}": idx for p, idx in merges.items()}
    # bytes are not JSON-serializable; store each token as a list of ints.
    serializable_vocab = {idx: list(b) for idx, b in vocab.items()}

    out_dir = os.path.dirname(vocab_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(vocab_path, "w", encoding="utf-8") as f:
        json.dump({
            "merges": serializable_merges,
            "vocab": serializable_vocab
        }, f)


def train(raw_text_path=None, vocab_path=None, vocab_size=None, subset_size=None):
    """Learn byte-level BPE merges from a text prefix and save them as JSON.

    Reads the first ``subset_size`` characters of ``raw_text_path``, UTF-8
    encodes them into byte ids (0-255), then performs up to
    ``vocab_size - 256`` merges. Each merge replaces the most frequent
    adjacent pair with a new id (256, 257, ...). Training stops early if
    fewer than two tokens remain.

    The output file has the shape::

        {"merges": {"<left>,<right>": new_id, ...},
         "vocab":  {"<id>": [byte, ...], ...}}

    JSON object keys are always strings, so vocab ids appear as strings.

    Args:
        raw_text_path: UTF-8 text file to train on. Defaults to
            ``RAW_TEXT_PATH``.
        vocab_path: Output JSON path; missing parent directories are created.
            Defaults to ``VOCAB_PATH``.
        vocab_size: Target vocabulary size including the 256 byte tokens.
            Values <= 256 perform no merges. Defaults to ``VOCAB_SIZE``.
        subset_size: Maximum number of characters to read. Defaults to
            ``SUBSET_SIZE``.

        All defaults are looked up at call time.

    Returns:
        ``(merges, vocab)``: ``merges`` maps each merged (left, right) pair to
        its new id; ``vocab`` maps every id to the bytes it decodes to.
    """
    if raw_text_path is None:
        raw_text_path = RAW_TEXT_PATH
    if vocab_path is None:
        vocab_path = VOCAB_PATH
    if vocab_size is None:
        vocab_size = VOCAB_SIZE
    if subset_size is None:
        subset_size = SUBSET_SIZE

    print(f"Loading raw text (subset {subset_size} chars) from {raw_text_path}...")
    # Text mode: read(n) returns at most n characters, not n bytes.
    with open(raw_text_path, "r", encoding="utf-8") as f:
        text = f.read(subset_size)

    print("Initial byte-encoding...")
    # Start with raw bytes (0-255); each byte is its own base token.
    ids = list(text.encode("utf-8"))

    # Clamp so a vocab_size below 256 means "no merges" rather than a negative count.
    num_merges = max(vocab_size - 256, 0)
    print(f"Training BPE for {num_merges} merges...")
    merges, vocab = _learn_merges(ids, num_merges)

    _save_vocab(vocab_path, merges, vocab)
    print(f"Vocab saved to {vocab_path}")
    return merges, vocab
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train with the module-level configuration.
    train()
|