
early work, refactoring the adder first

Andrej Karpathy 2022-05-27 10:04:52 -07:00
parent 3ed14b2cec
commit 8425759c24
6 changed files with 278 additions and 130 deletions

mingpt/model.py
@@ -14,25 +14,10 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from mingpt.utils import CfgNode as CN
logger = logging.getLogger(__name__)
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k,v in kwargs.items():
setattr(self, k, v)
class GPT1Config(GPTConfig):
""" GPT-1 like network roughly 125M params """
n_layer = 12
n_head = 12
n_embd = 768
# -----------------------------------------------------------------------------
class CausalSelfAttention(nn.Module):
"""
@@ -101,6 +86,23 @@ class Block(nn.Module):
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
@classmethod
def get_default_config(self, type):
C = CN(**{
'GPT-1': dict(n_layer=12, n_head=12, n_embd=768),
'Gopher-44M': dict(n_layer=8, n_head=16, n_embd=512),
'GPT-Micro': dict(n_layer=4, n_head=4, n_embd=64), # I made this one up...
}[type]
)
# these options must be filled in externally
C.vocab_size = None
C.block_size = None
# dropout hyperparameters
C.embd_pdrop = 0.1
C.resid_pdrop = 0.1
C.attn_pdrop = 0.1
return C
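A minimal sketch of how this new per-type default config is meant to be consumed (not part of the diff; the 'GPT-Micro' preset and the concrete values mirror the adder project added later in this commit):

from mingpt.model import GPT

config = GPT.get_default_config('GPT-Micro')  # n_layer=4, n_head=4, n_embd=64
config.vocab_size = 10                        # must be filled in externally (digits 0..9 in the adder)
config.block_size = 6                         # must be filled in externally
model = GPT(config)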
def __init__(self, config):
super().__init__()
@@ -117,6 +119,7 @@ class GPT(nn.Module):
self.block_size = config.block_size
self.apply(self._init_weights)
# TODO: only report the number of non-embedding parameters, as is standard practice
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
@@ -181,7 +184,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None):
b, t = idx.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
assert t <= self.block_size, f"Cannot forward, model block size is exhausted on {t} <= {self.block_size}"
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
@@ -194,6 +197,6 @@ class GPT(nn.Module):
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return logits, loss
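The ignore_index=-1 added to the loss here is what the new adder dataset (later in this commit) relies on to mask the loss at input positions; a tiny self-contained illustration with made-up tensors:

import torch
import torch.nn.functional as F

logits = torch.randn(6, 10)                    # 6 positions over a 10-way vocab (digits 0..9)
targets = torch.tensor([-1, -1, -1, 5, 3, 1])  # -1 marks positions whose loss should be ignored
loss = F.cross_entropy(logits, targets, ignore_index=-1)  # averaged over the 3 unmasked positions only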

mingpt/trainer.py
@@ -3,139 +3,84 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import math
import logging
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
class TrainerConfig:
# optimization parameters
max_epochs = 10
batch_size = 64
learning_rate = 3e-4
betas = (0.9, 0.95)
grad_norm_clip = 1.0
weight_decay = 0.1 # only applied on matmul weights
# learning rate decay params: linear warmup followed by cosine decay to 10% of original
lr_decay = False
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
final_tokens = 260e9 # (at what point we reach 10% of original LR)
# checkpoint settings
ckpt_path = None
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k,v in kwargs.items():
setattr(self, k, v)
from mingpt.utils import CfgNode as CN
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
@classmethod
def get_default_config(self):
C = CN()
# dataloder parameters
C.num_workers = 4
# optimizer parameters
C.batch_size = 64
C.learning_rate = 3e-4
C.betas = (0.9, 0.95)
C.weight_decay = 0.1 # only applied on matmul weights
C.grad_norm_clip = 1.0
return C
def __init__(self, config, model, train_dataset):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
self.callbacks = defaultdict(list)
# take over whatever gpus are on the system
self.device = 'cpu'
if torch.cuda.is_available():
self.device = torch.cuda.current_device()
self.model = torch.nn.DataParallel(self.model).to(self.device)
self.model = self.model.to(self.device)
def save_checkpoint(self):
# DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
def register_callback(self, onevent: str, callback):
self.callbacks[onevent].append(callback)
def train(self):
def trigger_callbacks(self, onevent: str):
for callback in self.callbacks.get(onevent, []):
callback(self)
def run(self):
model, config = self.model, self.config
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(loader, is_train):
model.train(is_train)
losses = []
pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
for it, (x, y) in pbar:
# place data on the correct device
x = x.to(self.device)
y = y.to(self.device)
# forward the model
with torch.set_grad_enabled(is_train):
logits, loss = model(x, y)
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
losses.append(loss.item())
if is_train:
# backprop and update the parameters
model.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
# decay the learning rate based on our progress
if config.lr_decay:
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
# report progress
pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")
if not is_train:
test_loss = float(np.mean(losses))
logger.info("test loss: %f", test_loss)
return test_loss
best_loss = float('inf')
self.tokens = 0 # counter used for learning rate decay
# setup the optimizer
optimizer = model.configure_optimizers(config)
# setup the dataloader
train_loader = DataLoader(
self.train_dataset,
shuffle=True,
sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
shuffle=False,
pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers
num_workers=config.num_workers,
)
if self.test_dataset is not None:
test_loader = DataLoader(
self.test_dataset,
shuffle=True,
pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers
)
for epoch in range(config.max_epochs):
run_epoch(train_loader, is_train=True)
if self.test_dataset is not None:
test_loss = run_epoch(test_loader, is_train=False)
self.iter_num = 0
data_iter = iter(train_loader)
while True:
# supports early stopping based on the test loss, or just save always if no test set is provided
good_model = self.test_dataset is None or test_loss < best_loss
if self.config.ckpt_path is not None and good_model:
best_loss = test_loss
self.save_checkpoint()
# fetch the next batch (x, y) and re-init iterator if needed
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(train_loader)
batch = next(data_iter)
batch = [t.to(self.device) for t in batch]
x, y = batch
# forward the model
model.train()
logits, self.loss = model(x, y)
# backprop and update the parameters
model.zero_grad(set_to_none=True)
self.loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
self.trigger_callbacks('on_batch_end')
self.iter_num += 1
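Taken together, the refactor swaps the epoch/test-loss loop for an open-ended iteration loop driven by callbacks, and the docstring's claim that nothing here is GPT-specific boils down to a small contract: a Dataset of (x, y) pairs and a model whose forward returns (logits, loss) and which provides configure_optimizers(config). A sketch of that contract with toy stand-ins (RandomTokens and TinyLM are made up for illustration, not from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from mingpt.trainer import Trainer

class RandomTokens(Dataset):
    """ toy dataset: next-token prediction over random 10-token sequences """
    def __len__(self):
        return 1000
    def __getitem__(self, idx):
        seq = torch.randint(0, 10, (7,))
        return seq[:-1], seq[1:]

class TinyLM(nn.Module):
    """ toy model exposing the interface the refactored Trainer expects """
    def __init__(self, vocab_size=10, n_embd=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
    def forward(self, x, targets=None):
        logits = self.head(self.emb(x))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss
    def configure_optimizers(self, config):
        return torch.optim.AdamW(self.parameters(), lr=config.learning_rate, betas=config.betas)

config = Trainer.get_default_config()
trainer = Trainer(config, TinyLM(), RandomTokens())

def batch_end_callback(trainer):
    if trainer.iter_num % 100 == 0:
        print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

trainer.register_callback('on_batch_end', batch_end_callback)
trainer.run()  # note: as written in this commit, run() loops with no built-in stopping condition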

mingpt/utils.py
@@ -45,3 +45,9 @@ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
x = torch.cat((x, ix), dim=1)
return x
class CfgNode:
""" a lightweight configuration class inspired by yacs """
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
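CfgNode (imported as CN throughout this commit) is just an attribute bag, so configs are built by plain assignment and can be nested; a small sketch of the usage pattern seen in the hunks above (the concrete values are illustrative):

from mingpt.utils import CfgNode as CN

C = CN()                                      # empty node; attributes are added by assignment
C.model = CN(n_layer=4, n_head=4, n_embd=64)  # nested sub-config built from kwargs
C.model.vocab_size = 10                       # fields can be filled in after construction
C.trainer = CN(learning_rate=3e-4)
print(C.model.n_layer, C.trainer.learning_rate)  # 4 0.0003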

projects/adder/adder.py Normal file (+186)

@@ -0,0 +1,186 @@
"""
Trains a GPT to add n-digit numbers.
"""
import os
import sys
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, CfgNode as CN
# -----------------------------------------------------------------------------
def get_config():
C = CN()
# system
C.system = CN()
C.system.seed = 1337
C.system.work_dir = './out/adder'
# data
C.data = AdditionDataset.get_default_config()
# model
C.model = GPT.get_default_config('GPT-Micro')
C.model.vocab_size = 10 # the digits 0..9
# a,b,a+b, and +1 due to potential carry overflow,
# but then also -1 because very last digit doesn't ever plug back
# as there is no explicit <EOS> token to predict, it is implied
C.model.block_size = C.data.ndigit + C.data.ndigit + C.data.ndigit + 1 - 1
# trainer
C.trainer = Trainer.get_default_config()
C.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
return C
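A quick check of the block_size arithmetic in the comment above, assuming the default ndigit=2 from AdditionDataset below: the rendered problem string has ndigit + ndigit + (ndigit + 1) = 7 digits, and the final digit is never fed back as input, so the model only ever needs 6 tokens of context.

ndigit = 2
block_size = ndigit + ndigit + ndigit + 1 - 1  # a digits + b digits + sum digits (with carry), minus the last one
assert block_size == 6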
# -----------------------------------------------------------------------------
class AdditionDataset(Dataset):
"""
Creates n-digit addition problems. For example, if n=2, then an example
addition problem would be to add 85 + 50 = 135. This problem would be
represented as the following string for the GPT:
"8550531"
This is because:
- we are discarding the + and =, which are not necessary. We just encode the digits
of the input numbers concatenated together.
- the result 135 is encoded backwards to make the addition easier to learn for the
GPT model, because of how the addition algorithm works.
As one more example, the problem 6 + 39 = 45 would be encoded as:
"0639054"
where you will notice that we are padding with zeros to make sure that we always
produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
At test time, we will feed in an addition problem by giving the first 2n digits,
and hoping that the GPT model completes the sequence with the next (n+1) digits
correctly.
"""
@classmethod
def get_default_config(self):
C = CN()
C.ndigit = 2
return C
def __init__(self, config, split):
self.config = config
self.split = split # train/test
self.vocab_size = 10 # 10 possible digits 0..9
# split up all addition problems into either training data or test data
ndigit = self.config.ndigit
assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support"
num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers
rng = torch.Generator()
rng.manual_seed(1337)
perm = torch.randperm(num, generator=rng)
num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500
self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]
def __len__(self):
return self.ixes.nelement()
def __getitem__(self, idx):
ndigit = self.config.ndigit
# given a problem index idx, first recover the associated a + b
idx = self.ixes[idx].item()
nd = 10**ndigit
a = idx // nd
b = idx % nd
# calculate the "label" of the addition problem a + b
c = a + b
# encode the digits of a, b, c into strings
astr = f'%0{ndigit}d' % a
bstr = f'%0{ndigit}d' % b
cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier
render = astr + bstr + cstr
dix = [int(s) for s in render] # convert each character to its token index
# x will be input to GPT and y will be the associated expected outputs
x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero
return x, y
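A worked example of the encoding described in the docstring, following __getitem__ step by step for the problem 85 + 50 = 135 with ndigit=2 (not part of the diff):

import torch

ndigit = 2
a, b = 85, 50
c = a + b                                     # 135
astr = f'%0{ndigit}d' % a                     # "85"
bstr = f'%0{ndigit}d' % b                     # "50"
cstr = (f'%0{ndigit+1}d' % c)[::-1]           # "531" (the sum, reversed)
render = astr + bstr + cstr                   # "8550531"
dix = [int(s) for s in render]                # [8, 5, 5, 0, 5, 3, 1]
x = torch.tensor(dix[:-1], dtype=torch.long)  # tensor([8, 5, 5, 0, 5, 3])
y = torch.tensor(dix[1:], dtype=torch.long)   # tensor([5, 5, 0, 5, 3, 1])
y[:ndigit*2 - 1] = -1                         # tensor([-1, -1, -1, 5, 3, 1]); loss only on the sum digits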
# -----------------------------------------------------------------------------
if __name__ == '__main__':
# get default config and overrides from the command line, if any
config = get_config()
#config.merge_from_list(sys.argv[1:])
# inits and logging
set_seed(config.system.seed)
os.makedirs(config.system.work_dir, exist_ok=True)
# construct train and test datasets
train_dataset = AdditionDataset(config.data, split='train')
test_dataset = AdditionDataset(config.data, split='test')
# construct the model
model = GPT(config.model)
# construct the trainer object
trainer = Trainer(config.trainer, model, train_dataset)
# helper function for the evaluation of a model
def eval_split(trainer, split, max_batches=-1):
dataset = {'train':train_dataset, 'test':test_dataset}[split]
ndigit = config.data.ndigit
results = []
mistakes_printed_already = 0
loader = DataLoader(dataset, batch_size=50, num_workers=0, drop_last=False)
for b, (x, y) in enumerate(loader):
x = x.to(trainer.device)
d1d2 = x[:, :ndigit*2]
d1d2d3 = sample(model, d1d2, ndigit+1)
d3 = d1d2d3[:, -(ndigit+1):]
d3 = d3.flip(1) # reverse the digits to their "normal" order
factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
# decode the integers from individual digits
d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
d3i_pred = (d3 * factors).sum(1)
d3i_gt = d1i + d2i
correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
for i in range(x.size(0)):
results.append(int(correct[i]))
if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
mistakes_printed_already += 1
print("GPT claims that %03d + %03d = %03d but gt is %03d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
if max_batches >= 0 and b+1 >= max_batches:
break
rt = torch.tensor(results, dtype=torch.float)
print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
# iteration callback
def batch_end_callback(trainer):
if trainer.iter_num % 100 == 0:
print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
if trainer.iter_num % 500 == 0:
model.eval()
eval_split(trainer, 'train', max_batches=-1)
eval_split(trainer, 'test', max_batches=-1)
# todo: save good models but only if it's the best model so far
# ckpt_path = os.path.join(config.system.work_dir, "model.pt")
# torch.save(model.state_dict(), ckpt_path)
trainer.register_callback('on_batch_end', batch_end_callback)
# run the optimization
trainer.run()

projects/adder/readme.md Normal file (+4)

@@ -0,0 +1,4 @@
### adder
Train a GPT model to add n-digit numbers

projects/readme.md Normal file (+4)

@@ -0,0 +1,4 @@
### minGPT projects
Various projects that use the minGPT library to achieve great things.