"""
|
|
Trains a GPT to add n-digit numbers.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data.dataloader import DataLoader
|
|
|
|
from mingpt.model import GPT
|
|
from mingpt.trainer import Trainer
|
|
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
|
|
|
|
# -----------------------------------------------------------------------------
|
|
|
|
def get_config():
|
|
|
|
C = CN()
|
|
|
|
# system
|
|
C.system = CN()
|
|
C.system.seed = 3407
|
|
C.system.work_dir = './out/adder'
|
|
|
|
# data
|
|
C.data = AdditionDataset.get_default_config()
|
|
|
|
# model
|
|
C.model = GPT.get_default_config()
|
|
C.model.model_type = 'gpt-nano'
|
|
|
|
# trainer
|
|
C.trainer = Trainer.get_default_config()
|
|
C.trainer.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
|
|
|
|
return C
|
|
|
|
# -----------------------------------------------------------------------------

class AdditionDataset(Dataset):
    """
    Creates n-digit addition problems. For example, if n=2, then an example
    addition problem would be to add 85 + 50 = 135. This problem would be
    represented as the following string for the GPT:

    "8550531"

    This is because:
    - we are discarding the + and =, which are not necessary. We just encode the digits
      of the input numbers concatenated together.
    - the result 135 is encoded backwards to make the addition easier to learn for the
      GPT model, because of how the addition algorithm works.

    As one more example, the problem 6 + 39 = 45 would be encoded as:

    "0639054"

    where you will notice that we are padding with zeros to make sure that we always
    produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
    At test time, we will feed in an addition problem by giving the first 2n digits,
    and hoping that the GPT model completes the sequence with the next (n+1) digits
    correctly.
    """

    @staticmethod
    def get_default_config():
        C = CN()
        C.ndigit = 2
        return C

    def __init__(self, config, split):
        self.config = config
        self.split = split # train/test

        # split up all addition problems into either training data or test data
        ndigit = self.config.ndigit
        assert ndigit <= 3, "the lines below would be very memory inefficient for larger ndigit, in the future maybe refactor to support it"
        num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers
        rng = torch.Generator()
        rng.manual_seed(1337)
        perm = torch.randperm(num, generator=rng)
        num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]
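        # note: both splits index into the same fixed permutation (seed 1337), so the
        # train and test problems are disjoint and consistent across the two instances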

    def get_vocab_size(self):
        return 10 # digits 0..9

    def get_block_size(self):
        # a, b, a+b, and +1 due to potential carry overflow,
        # but then also -1 because the very last digit never plugs back in as an input,
        # as there is no explicit <EOS> token to predict; it is implied
        return 3*self.config.ndigit + 1 - 1
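        # for reference: with the default ndigit=2 this gives 3*2 + 1 - 1 = 6, which
        # matches len(x) == len("8550531") - 1 == 6 in the worked example above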

    def __len__(self):
        return self.ixes.nelement()

    def __getitem__(self, idx):
        ndigit = self.config.ndigit
        # given a problem index idx, first recover the associated a and b
        idx = self.ixes[idx].item()
        nd = 10**ndigit
        a = idx // nd
        b = idx % nd
        # calculate the "label" of the addition problem a + b
        c = a + b
        # encode the digits of a, b, c into strings
        astr = f'%0{ndigit}d' % a
        bstr = f'%0{ndigit}d' % b
        cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier
        render = astr + bstr + cstr
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be the input to the GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        y[:ndigit*2-1] = -1 # we only train on the output locations; -1 masks the loss to zero
        return x, y

# -----------------------------------------------------------------------------

if __name__ == '__main__':

    # get default config and overrides from the command line, if any
    config = get_config()
    config.merge_from_args(sys.argv[1:])
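    # e.g. a hypothetical invocation, assuming the usual CfgNode-style --key=value
    # overrides: python adder.py --data.ndigit=3 --trainer.learning_rate=3e-4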
    print(config)
    setup_logging(config)
    set_seed(config.system.seed)

    # construct train and test datasets
    train_dataset = AdditionDataset(config.data, split='train')
    test_dataset = AdditionDataset(config.data, split='test')

    # construct the model
    config.model.vocab_size = train_dataset.get_vocab_size()
    config.model.block_size = train_dataset.get_block_size()
    model = GPT(config.model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # helper function for the evaluation of a model
    def eval_split(trainer, split, max_batches=None):
        dataset = {'train':train_dataset, 'test':test_dataset}[split]
        ndigit = config.data.ndigit
        results = []
        mistakes_printed_already = 0
        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
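        # (e.g. with ndigit=2, factors is [[100, 10, 1]]: multiplying digit columns by
        # their place values and summing recovers the integer values below)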
        loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
        for b, (x, y) in enumerate(loader):
            x = x.to(trainer.device)
            # isolate the first 2*ndigit digits of the input sequence, i.e. the two operands
            d1d2 = x[:, :ndigit*2]
            # let the model sample the rest of the sequence
            d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling
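            # (shape check: d1d2 is (B, 2*ndigit); generate appends ndigit+1 new tokens,
            # so d1d2d3 comes back as (B, 3*ndigit+1))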
            # isolate the sampled (ndigit+1)-digit answer at the end of the sequence
            d3 = d1d2d3[:, -(ndigit+1):]
            d3 = d3.flip(1) # reverse the digits to their "normal" order
            # decode the integers from individual digits
            d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
            d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
            d3i_pred = (d3 * factors).sum(1)
            d3i_gt = d1i + d2i # manually calculate the ground truth
            # evaluate the correctness of the results in this batch
            correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
            for i in range(x.size(0)):
                results.append(int(correct[i]))
                if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
                    mistakes_printed_already += 1
                    print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
            if max_batches is not None and b+1 >= max_batches:
                break
        rt = torch.tensor(results, dtype=torch.float)
        print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
        return rt.sum()

    # iteration callback
    top_score = 0
    def batch_end_callback(trainer):
        global top_score

        if trainer.iter_num % 10 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 == 0:
            # evaluate both the train and test score
            train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, otherwise not
            model.eval()
            with torch.no_grad():
                train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
                test_score = eval_split(trainer, 'test', max_batches=None)
            score = train_score + test_score
            # save the model if this is the best score we've seen so far
            if score > top_score:
                top_score = score
                print(f"saving model with new top score of {score}")
                ckpt_path = os.path.join(config.system.work_dir, "model.pt")
                torch.save(model.state_dict(), ckpt_path)
            # revert model to training mode
            model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    # run the optimization
    trainer.run()