ANE/training/train_bpe.py

73 lines
2.1 KiB
Python

import os
import json
from collections import Counter
# Minimal byte-level BPE trainer for TinyStories.
#
# Input and output paths are resolved next to this script, independent of the
# current working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
RAW_TEXT_PATH = os.path.join(BASE_DIR, "tinystories_raw.txt")  # UTF-8 training text
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.json")  # merges + vocab are written here

# Final vocabulary size: 256 single-byte tokens plus (VOCAB_SIZE - 256) merges.
# Kept small so a verification run finishes quickly.
VOCAB_SIZE = 5_000
# How much of the corpus to train on. The file is read in text mode, so this
# counts characters, not bytes (roughly 200 KB for mostly-ASCII text).
SUBSET_SIZE = 200_000
def get_stats(ids):
    """Count how often each adjacent pair of token ids occurs in ``ids``.

    Args:
        ids: Sequence of integer token ids.

    Returns:
        A ``Counter`` mapping ``(left, right)`` tuples to their number of
        occurrences. Keys are inserted in order of first appearance, so a
        ``max`` over the result breaks ties in favour of the earliest pair.
        An empty or single-element ``ids`` yields an empty ``Counter``.
    """
    # Let Counter consume the pair iterator in one call instead of a
    # Python-level increment loop; this is called once per merge round.
    return Counter(zip(ids, ids[1:]))
def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` in ``ids`` becomes ``idx``.

    Matching is greedy, left to right and non-overlapping: a matched pair
    consumes both positions, so ``[a, a, a]`` with pair ``(a, a)`` becomes
    ``[idx, a]``. ``ids`` itself is not modified.
    """
    length = len(ids)
    result = []
    pos = 0
    while pos < length:
        if pos + 1 < length and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            # Replace the two-element match with the single merged id.
            result.append(idx)
            pos += 2
            continue
        result.append(ids[pos])
        pos += 1
    return result
def train(raw_text_path=None, vocab_path=None, vocab_size=None, subset_size=None):
    """Learn byte-level BPE merges from a prefix of the corpus and save them as JSON.

    The text is UTF-8 encoded, so the starting tokens are the byte values
    0-255. Each round merges the most frequent adjacent pair into a new id
    (256, 257, ...). Training stops when ``vocab_size`` ids exist or when no
    adjacent pairs remain.

    Args:
        raw_text_path: UTF-8 text file to train on. Defaults to ``RAW_TEXT_PATH``.
        vocab_path: Output JSON file. Defaults to ``VOCAB_PATH``.
        vocab_size: Target vocabulary size (256 byte tokens + merges).
            Defaults to ``VOCAB_SIZE``. Values <= 256 learn no merges.
        subset_size: Number of *characters* to read from the start of the
            corpus (text-mode ``read``). Defaults to ``SUBSET_SIZE``. Pass -1
            to read the whole file.

    Output JSON layout::

        {"merges": {"<left>,<right>": new_id, ...},
         "vocab":  {"<id>": [byte, byte, ...], ...}}

    ``json`` serializes the integer vocab ids as string keys.
    """
    # Resolve defaults at call time so callers can override any setting
    # without editing the module constants; train() keeps its old behavior.
    if raw_text_path is None:
        raw_text_path = RAW_TEXT_PATH
    if vocab_path is None:
        vocab_path = VOCAB_PATH
    if vocab_size is None:
        vocab_size = VOCAB_SIZE
    if subset_size is None:
        subset_size = SUBSET_SIZE

    print(f"Loading raw text (subset {subset_size} characters) from {raw_text_path}...")
    with open(raw_text_path, "r", encoding="utf-8") as f:
        # On a text-mode file, read(n) returns up to n characters, not bytes,
        # so the byte-level corpus can be larger than n for non-ASCII text.
        text = f.read(subset_size)

    print("Initial byte-encoding...")
    # Start with raw bytes (0-255): every input is representable before any merge.
    ids = list(text.encode("utf-8"))

    merges = {}
    vocab = {i: bytes([i]) for i in range(256)}
    num_merges = vocab_size - 256
    print(f"Training BPE for {num_merges} merges...")
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            # Fewer than two ids remain, so there is nothing left to merge;
            # the saved vocab will be smaller than requested.
            print(f"No adjacent pairs left; stopping after {i} merges.")
            break
        # Most frequent pair; get_stats keys are in first-occurrence order,
        # so ties go to the pair that appears earliest in ids.
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
        if (i + 1) % 100 == 0:
            print(f"Merge {i+1}/{num_merges}: {pair} -> {idx} (count {stats[pair]})")

    # JSON has no tuple keys, so each merge pair is stored as "left,right".
    serializable_merges = {
        f"{left},{right}": new_id for (left, right), new_id in merges.items()
    }
    # bytes are not JSON-serializable; store each token as a list of byte values.
    serializable_vocab = {
        token_id: list(token_bytes) for token_id, token_bytes in vocab.items()
    }
    with open(vocab_path, "w", encoding="utf-8") as f:
        json.dump({
            "merges": serializable_merges,
            "vocab": serializable_vocab
        }, f)
    print(f"Vocab saved to {vocab_path}")
if __name__ == "__main__":
    # Run only when executed as a script (not on import), using the
    # module-level path and size constants.
    train()