mirror of https://github.com/karpathy/minGPT
synced 2024-11-15 19:10:39 +01:00

Commit 8425759c24 (parent 3ed14b2cec): early work, refactoring the adder first
mingpt/model.py

@@ -14,25 +14,10 @@ import torch
import torch.nn as nn
from torch.nn import functional as F

from mingpt.utils import CfgNode as CN

logger = logging.getLogger(__name__)

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k,v in kwargs.items():
            setattr(self, k, v)

class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768

# -----------------------------------------------------------------------------

class CausalSelfAttention(nn.Module):
    """
@@ -101,6 +86,23 @@ class Block(nn.Module):
class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    @classmethod
    def get_default_config(self, type):
        C = CN(**{
            'GPT-1':      dict(n_layer=12, n_head=12, n_embd=768),
            'Gopher-44M': dict(n_layer=8, n_head=16, n_embd=512),
            'GPT-Micro':  dict(n_layer=4, n_head=4, n_embd=64), # I made this one up...
        }[type])
        # these options must be filled in externally
        C.vocab_size = None
        C.block_size = None
        # dropout hyperparameters
        C.embd_pdrop = 0.1
        C.resid_pdrop = 0.1
        C.attn_pdrop = 0.1
        return C

    def __init__(self, config):
        super().__init__()

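As an aside (not part of the diff), here is a minimal sketch of how these CN-based defaults are meant to be consumed, assuming CfgNode simply stores keyword arguments as attributes as the mingpt/utils.py hunk further down shows; it mirrors what projects/adder/adder.py below does:

import torch
from mingpt.model import GPT

config = GPT.get_default_config('GPT-Micro')  # n_layer=4, n_head=4, n_embd=64
config.vocab_size = 10                        # must be filled in externally
config.block_size = 6                         # e.g. the adder's 3*ndigit + 1 - 1 with ndigit=2
model = GPT(config)                           # config carries all hyperparameters as attributes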
@@ -117,6 +119,7 @@ class GPT(nn.Module):
        self.block_size = config.block_size
        self.apply(self._init_weights)

        # TODO: only report the number of non-embedding parameters, as is standard practice
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):

@@ -181,7 +184,7 @@ class GPT(nn.Module):
    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        assert t <= self.block_size, f"Cannot forward, model block size is exhausted on {t} <= {self.block_size}"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
@@ -194,6 +197,6 @@ class GPT(nn.Module):
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

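The switch to ignore_index=-1 lets a dataset mask out positions it does not want to train on; the adder dataset below sets exactly those targets to -1. A small self-contained check of that behaviour (illustrative, not part of the diff):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 10)                # 5 positions, 10-token vocabulary
targets = torch.tensor([-1, -1, 3, 7, 2])  # first two positions are masked out
loss = F.cross_entropy(logits, targets, ignore_index=-1)
# the masked positions are excluded from the mean, so this equals the loss on the rest
ref = F.cross_entropy(logits[2:], targets[2:])
assert torch.allclose(loss, ref)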
mingpt/trainer.py

@@ -3,139 +3,84 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np
from collections import defaultdict

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)

from mingpt.utils import CfgNode as CN

class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
    @classmethod
    def get_default_config(self):
        C = CN()
        # dataloader parameters
        C.num_workers = 4
        # optimizer parameters
        C.batch_size = 64
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1 # only applied on matmul weights
        C.grad_norm_clip = 1.0
        return C

    def __init__(self, config, model, train_dataset):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.callbacks = defaultdict(list)

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)
            self.model = self.model.to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def register_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def train(self):
    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def run(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(loader, is_train):
            model.train(is_train)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate
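An aside, not part of the diff: the schedule being removed above is a linear warmup followed by a cosine decay floored at 10% of the base learning rate. A small standalone sketch of the same arithmetic, using the default token counts from the old TrainerConfig:

import math

def lr_mult(tokens, warmup_tokens=375e6, final_tokens=260e9):
    # same formula as the (removed) schedule above; lr = learning_rate * lr_mult
    if tokens < warmup_tokens:
        return float(tokens) / float(max(1, warmup_tokens))           # linear warmup
    progress = float(tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))       # cosine decay, floored at 0.1

print(lr_mult(100e6))   # ~0.27, still warming up
print(lr_mult(375e6))   # 1.0, warmup just finished
print(lr_mult(130e9))   # ~0.5, roughly halfway through the cosine
print(lr_mult(260e9))   # 0.1, floor reached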
                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

        best_loss = float('inf')
        self.tokens = 0 # counter used for learning rate decay
        # setup the optimizer
        optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            shuffle=True,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers
            num_workers=config.num_workers,
        )
        if self.test_dataset is not None:
            test_loader = DataLoader(
                self.test_dataset,
                shuffle=True,
                pin_memory=True,
                batch_size=config.batch_size,
                num_workers=config.num_workers
            )

        for epoch in range(config.max_epochs):
            run_epoch(train_loader, is_train=True)
            if self.test_dataset is not None:
                test_loss = run_epoch(test_loader, is_train=False)
        self.iter_num = 0
        data_iter = iter(train_loader)
        while True:

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_dataset is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                self.save_checkpoint()
            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            model.train()
            logits, self.loss = model(x, y)

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1

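The new loop never exhausts its DataLoader: the RandomSampler draws with replacement for effectively forever, and the try/except around next() restarts the iterator if it ever does run out. A stripped-down sketch of that pattern with a toy dataset (illustrative, not part of the diff):

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

dataset = TensorDataset(torch.arange(10))             # toy stand-in for train_dataset
loader = DataLoader(
    dataset,
    sampler=RandomSampler(dataset, replacement=True, num_samples=int(1e10)),
    batch_size=4,
)

data_iter = iter(loader)
for step in range(5):                                 # the trainer uses `while True` instead
    try:
        batch = next(data_iter)
    except StopIteration:                             # only reached if the sampler were finite
        data_iter = iter(loader)
        batch = next(data_iter)
    print(step, batch[0].tolist())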
mingpt/utils.py

@@ -45,3 +45,9 @@ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
        x = torch.cat((x, ix), dim=1)

    return x

class CfgNode:
    """ a lightweight configuration class inspired by yacs """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

projects/adder/adder.py (new file, 186 lines)

@@ -0,0 +1,186 @@
"""
Trains a GPT to add n-digit numbers.
"""

import os
import sys

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, CfgNode as CN

# -----------------------------------------------------------------------------

def get_config():

    C = CN()

    # system
    C.system = CN()
    C.system.seed = 1337
    C.system.work_dir = './out/adder'

    # data
    C.data = AdditionDataset.get_default_config()

    # model
    C.model = GPT.get_default_config('GPT-Micro')
    C.model.vocab_size = 10 # the digits 0..9
    # a,b,a+b, and +1 due to potential carry overflow,
    # but then also -1 because very last digit doesn't ever plug back
    # as there is no explicit <EOS> token to predict, it is implied
    C.model.block_size = C.data.ndigit + C.data.ndigit + C.data.ndigit + 1 - 1

    # trainer
    C.trainer = Trainer.get_default_config()
    C.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster

    return C

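To make the block_size arithmetic above concrete (an illustrative check, not part of the file): with the default ndigit=2, a rendered problem is 2 + 2 + 3 = 7 digit tokens, and the model only ever sees the first 6 of them as input, since the last digit is only ever a prediction target.

ndigit = 2
rendered_len = ndigit + ndigit + (ndigit + 1)   # a, b, and the carry-padded (reversed) sum = 7
block_size = ndigit + ndigit + ndigit + 1 - 1   # = 6
assert rendered_len - 1 == block_size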
# -----------------------------------------------------------------------------

class AdditionDataset(Dataset):
    """
    Creates n-digit addition problems. For example, if n=2, then an example
    addition problem would be to add 85 + 50 = 135. This problem would be
    represented as the following string for the GPT:

    "8550531"

    This is because:
    - we are discarding the + and =, which are not necessary. We just encode the digits
      of the input numbers concatenated together.
    - the result 135 is encoded backwards to make the addition easier to learn for the
      GPT model, because of how the addition algorithm works.

    As one more example, the problem 6 + 39 = 45 would be encoded as:

    "0639540"

    where you will notice that we are padding with zeros to make sure that we always
    produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
    At test time, we will feed in an addition problem by giving the first 2n digits,
    and hoping that the GPT model completes the sequence with the next (n+1) digits
    correctly.
    """

    @classmethod
    def get_default_config(self):
        C = CN()
        C.ndigit = 2
        return C

    def __init__(self, config, split):
        self.config = config
        self.split = split # train/test
        self.vocab_size = 10 # 10 possible digits 0..9

        # split up all addition problems into either training data or test data
        ndigit = self.config.ndigit
        assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support"
        num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers
        rng = torch.Generator()
        rng.manual_seed(1337)
        perm = torch.randperm(num, generator=rng)
        num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def __len__(self):
        return self.ixes.nelement()

    def __getitem__(self, idx):
        ndigit = self.config.ndigit
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx].item()
        nd = 10**ndigit
        a = idx // nd
        b = idx % nd
        # calculate the "label" of the addition problem a + b
        c = a + b
        # encode the digits of a, b, c into strings
        astr = f'%0{ndigit}d' % a
        bstr = f'%0{ndigit}d' % b
        cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier
        render = astr + bstr + cstr
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero
        return x, y

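A worked example of what __getitem__ produces, using the docstring's 85 + 50 = 135 problem (illustrative, not part of the file):

# for ndigit=2, the problem renders as "85" + "50" + "531" = "8550531"
dix = [8, 5, 5, 0, 5, 3, 1]
x = dix[:-1]                 # [8, 5, 5, 0, 5, 3] -> fed to the GPT (length block_size = 6)
y = dix[1:]                  # [5, 5, 0, 5, 3, 1] -> next-token targets
y[:2*2 - 1] = [-1, -1, -1]   # mask the positions that merely read the question digits
# y == [-1, -1, -1, 5, 3, 1]: loss is only computed where answer digits are predicted,
# which is exactly what ignore_index=-1 in GPT.forward() skips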
# -----------------------------------------------------------------------------

if __name__ == '__main__':

    # get default config and overrides from the command line, if any
    config = get_config()
    #config.merge_from_list(sys.argv[1:])

    # inits and logging
    set_seed(config.system.seed)
    os.makedirs(config.system.work_dir, exist_ok=True)

    # construct train and test datasets
    train_dataset = AdditionDataset(config.data, split='train')
    test_dataset = AdditionDataset(config.data, split='test')

    # construct the model
    model = GPT(config.model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # helper function for the evaluation of a model
    def eval_split(trainer, split, max_batches=-1):
        dataset = {'train':train_dataset, 'test':test_dataset}[split]
        ndigit = config.data.ndigit
        results = []
        mistakes_printed_already = 0
        loader = DataLoader(dataset, batch_size=50, num_workers=0, drop_last=False)
        for b, (x, y) in enumerate(loader):
            x = x.to(trainer.device)
            d1d2 = x[:, :ndigit*2]
            d1d2d3 = sample(model, d1d2, ndigit+1)
            d3 = d1d2d3[:, -(ndigit+1):]
            d3 = d3.flip(1) # reverse the digits to their "normal" order
            factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
            # decode the integers from individual digits
            d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
            d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
            d3i_pred = (d3 * factors).sum(1)
            d3i_gt = d1i + d2i
            correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
            for i in range(x.size(0)):
                results.append(int(correct[i]))
                if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
                    mistakes_printed_already += 1
                    print("GPT claims that %03d + %03d = %03d but gt is %03d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
            if max_batches >= 0 and b+1 >= max_batches:
                break
        rt = torch.tensor(results, dtype=torch.float)
        print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))

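The place-value decoding with `factors` may be easier to see on a concrete batch row (an illustrative sketch assuming ndigit=2, not part of the file):

import torch

ndigit = 2
factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]])   # [[100, 10, 1]]

d1d2 = torch.tensor([[8, 5, 5, 0]])   # the question "85" and "50" as digit tokens
d3 = torch.tensor([[1, 3, 5]])        # sampled answer "531", already flipped back to 1,3,5

d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1)           # [8,5] . [10,1] = 85
d2i = (d1d2[:, ndigit:ndigit*2] * factors[:, 1:]).sum(1)   # [5,0] . [10,1] = 50
d3i = (d3 * factors).sum(1)                                # [1,3,5] . [100,10,1] = 135
print(d1i.item(), d2i.item(), d3i.item())                  # 85 50 135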
    # iteration callback
    def batch_end_callback(trainer):

        if trainer.iter_num % 100 == 0:
            print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 == 0:
            model.eval()
            eval_split(trainer, 'train', max_batches=-1)
            eval_split(trainer, 'test', max_batches=-1)

            # todo: save good models but only if it's the best model so far
            # ckpt_path = os.path.join(config.system.work_dir, "model.pt")
            # torch.save(model.state_dict(), ckpt_path)

    trainer.register_callback('on_batch_end', batch_end_callback)

    # run the optimization
    trainer.run()

projects/adder/readme.md (new file, 4 lines)

@@ -0,0 +1,4 @@

### adder

Train a GPT model to add n-digit numbers
projects/readme.md (new file, 4 lines)

@@ -0,0 +1,4 @@

### minGPT projects

Various projects that use the minGPT library to achieve great things.