mirror of https://github.com/karpathy/minGPT synced 2024-03-29 10:19:59 +01:00

Refactor sequence generation into the model and match the huggingface/transformers API. Touches everything, but this makes a lot more sense to me aesthetically.

This commit is contained in:
Andrej 2022-07-11 18:50:53 +00:00 committed by GitHub
parent 5af9e5c5d7
commit acaadacd59
7 changed files with 83 additions and 101 deletions
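At the call sites, the change swaps the standalone sample() helper from mingpt.utils for a GPT.generate() method whose keyword arguments mirror huggingface/transformers. A minimal before/after sketch follows; here `model` stands for any minGPT GPT instance and `x` for a (b, t) LongTensor of token indices, and the concrete argument values are illustrative only (the names are taken from the diffs below):

# before this commit: a free function, removed from mingpt.utils below
from mingpt.utils import sample
y = sample(model, x, steps=20, temperature=1.0, sample=False, top_k=None)

# after this commit: a method on the model, huggingface/transformers-style
y = model.generate(x, max_new_tokens=20, temperature=1.0, do_sample=False, top_k=None)

Both calls return the conditioning indices with the generated continuation concatenated along dim=1.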

View File

@@ -15,7 +15,9 @@
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data.dataloader import DataLoader"
"from torch.utils.data.dataloader import DataLoader\n",
"from mingpt.utils import set_seed\n",
"set_seed(3407)"
]
},
{
@@ -96,17 +98,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2 -1\n",
"2 -1\n",
"1 -1\n",
"0 -1\n",
"1 -1\n",
"0 -1\n",
"2 0\n",
"0 -1\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 0\n",
"0 1\n",
"1 2\n",
"2 2\n",
"2 2\n"
"1 1\n"
]
}
],
@@ -152,7 +154,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"running on device cpu\n"
"running on device cuda\n"
]
}
],
@@ -176,26 +178,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"iter_dt 0.00ms; iter 0: train loss 1.09793\n",
"iter_dt 29.36ms; iter 100: train loss 0.14420\n",
"iter_dt 29.03ms; iter 200: train loss 0.04971\n",
"iter_dt 28.62ms; iter 300: train loss 0.03680\n",
"iter_dt 28.92ms; iter 400: train loss 0.01332\n",
"iter_dt 28.34ms; iter 500: train loss 0.01905\n",
"iter_dt 28.35ms; iter 600: train loss 0.02515\n",
"iter_dt 28.69ms; iter 700: train loss 0.02522\n",
"iter_dt 28.70ms; iter 800: train loss 0.02379\n",
"iter_dt 28.39ms; iter 900: train loss 0.00192\n",
"iter_dt 28.40ms; iter 1000: train loss 0.01416\n",
"iter_dt 28.47ms; iter 1100: train loss 0.00136\n",
"iter_dt 28.21ms; iter 1200: train loss 0.02124\n",
"iter_dt 28.21ms; iter 1300: train loss 0.05553\n",
"iter_dt 28.39ms; iter 1400: train loss 0.00930\n",
"iter_dt 28.00ms; iter 1500: train loss 0.00863\n",
"iter_dt 28.57ms; iter 1600: train loss 0.00624\n",
"iter_dt 28.39ms; iter 1700: train loss 0.00355\n",
"iter_dt 28.35ms; iter 1800: train loss 0.00235\n",
"iter_dt 28.98ms; iter 1900: train loss 0.00243\n"
"iter_dt 0.00ms; iter 0: train loss 1.06407\n",
"iter_dt 18.17ms; iter 100: train loss 0.14712\n",
"iter_dt 18.70ms; iter 200: train loss 0.05315\n",
"iter_dt 19.65ms; iter 300: train loss 0.04404\n",
"iter_dt 31.64ms; iter 400: train loss 0.04724\n",
"iter_dt 18.43ms; iter 500: train loss 0.02521\n",
"iter_dt 19.83ms; iter 600: train loss 0.03352\n",
"iter_dt 19.58ms; iter 700: train loss 0.00539\n",
"iter_dt 18.72ms; iter 800: train loss 0.02057\n",
"iter_dt 18.26ms; iter 900: train loss 0.00360\n",
"iter_dt 18.50ms; iter 1000: train loss 0.00788\n",
"iter_dt 20.64ms; iter 1100: train loss 0.01162\n",
"iter_dt 18.63ms; iter 1200: train loss 0.00963\n",
"iter_dt 18.32ms; iter 1300: train loss 0.02066\n",
"iter_dt 18.40ms; iter 1400: train loss 0.01739\n",
"iter_dt 18.37ms; iter 1500: train loss 0.00376\n",
"iter_dt 18.67ms; iter 1600: train loss 0.00133\n",
"iter_dt 18.38ms; iter 1700: train loss 0.00179\n",
"iter_dt 18.66ms; iter 1800: train loss 0.00079\n",
"iter_dt 18.48ms; iter 1900: train loss 0.00042\n"
]
}
],
@@ -233,8 +235,6 @@
}
],
"source": [
"from mingpt.utils import sample\n",
"\n",
"def eval_split(trainer, split, max_batches):\n",
" dataset = {'train':train_dataset, 'test':test_dataset}[split]\n",
" n = train_dataset.length # naugy direct access shrug\n",
@@ -248,7 +248,7 @@
" inp = x[:, :n]\n",
" sol = y[:, -n:]\n",
" # let the model sample the rest of the sequence\n",
" cat = sample(model, inp, n, sample=False) # using greedy argmax, not sampling\n",
" cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n",
" sol_candidate = cat[:, n:] # isolate the filled in sequence\n",
" # compare the predicted sequence to the true sequence\n",
" correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
@@ -291,7 +291,7 @@
"inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
"assert inp[0].nelement() == n\n",
"with torch.no_grad():\n",
" cat = sample(model, inp, n, sample=False)\n",
" cat = model.generate(inp, n, do_sample=False)\n",
"sol = torch.sort(inp[0])[0]\n",
"sol_candidate = cat[:, n:]\n",
"print('input sequence :', inp.tolist())\n",
@@ -303,7 +303,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.5 ('base')",
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
@@ -317,12 +317,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "afdab15bd6582f87e2d1e596bfa7241af51aedf8abc909e2cab3828057cb30c9"
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
}
}
},

View File

@@ -15,7 +15,6 @@
"source": [
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import sample\n",
"from mingpt.utils import set_seed\n",
"set_seed(3407)"
]
@@ -74,10 +73,7 @@
" x = x.expand(num_samples, -1)\n",
"\n",
" # forward the model `steps` times to get samples, in a batch\n",
" if use_mingpt:\n",
" y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n",
" else:\n",
" y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
" y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
" \n",
" for i in range(num_samples):\n",
" out = tokenizer.decode(y[i].cpu().squeeze())\n",
@@ -95,8 +91,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
"2022-07-11 18:42:21.744061: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-07-11 18:42:21.744099: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
},
{

View File

@@ -1,10 +1,5 @@
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
@@ -161,13 +156,10 @@ class GPT(nn.Module):
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
# report number of parameters
# report number of parameters (note we don't count the decoder parameters in lm_head)
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -286,3 +278,33 @@ class GPT(nn.Module):
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
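A hedged, self-contained usage sketch of the new method (this is an illustration, not part of the diff; it assumes the GPT.get_default_config() / model_type config API used elsewhere in this version of minGPT, with the toy vocabulary and block size of the sorting demo above):

import torch
from mingpt.model import GPT

config = GPT.get_default_config()
config.model_type = 'gpt-nano'   # assumed preset name, as used elsewhere in minGPT
config.vocab_size = 3            # toy vocabulary, as in the sorting demo
config.block_size = 11
model = GPT(config)
model.eval()                     # per the docstring, generation should run in eval mode

idx = torch.tensor([[0, 2, 1]], dtype=torch.long)             # (b, t) conditioning indices
out = model.generate(idx, max_new_tokens=5, do_sample=False)  # greedy argmax decoding
print(out.shape)                 # torch.Size([1, 8]): the prompt with 5 new tokens appended

Note that, unlike the removed sample() helper below, generate() does not switch the model to eval mode internally, hence the explicit model.eval() call above.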

View File

@@ -7,8 +7,8 @@ from ast import literal_eval
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
# -----------------------------------------------------------------------------
def set_seed(seed):
random.seed(seed)
@@ -28,42 +28,6 @@ def setup_logging(config):
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
f.write(json.dumps(config.to_dict(), indent=4))
def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[:, [-1]]] = -float('Inf')
return out
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
"""
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
the sequence, feeding the predictions back into the model each time. Clearly the sampling
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
of block_size, unlike an RNN that has an infinite context window.
"""
block_size = model.get_block_size()
model.eval()
for k in range(steps):
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
logits, _ = model(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)
return x
class CfgNode:
""" a lightweight configuration class inspired by yacs """
# TODO: convert to subclass from a dict like in yacs?

View File

@@ -12,7 +12,7 @@ from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
@@ -154,7 +154,7 @@ if __name__ == '__main__':
# isolate the first two digits of the input sequence alone
d1d2 = x[:, :ndigit*2]
# let the model sample the rest of the sequence
d1d2d3 = sample(model, d1d2, ndigit+1, sample=False) # using greedy argmax, not sampling
d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling
# isolate the last digit of the sampled sequence
d3 = d1d2d3[:, -(ndigit+1):]
d3 = d3.flip(1) # reverse the digits to their "normal" order

View File

@@ -11,7 +11,7 @@ from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
@@ -117,7 +117,7 @@ if __name__ == '__main__':
# sample from the model...
context = "O God, O God!"
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, 500, temperature=1.0, sample=True, top_k=10)[0]
y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)
# save the latest model

View File

@@ -6,7 +6,6 @@ import unittest
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import sample
# -----------------------------------------------------------------------------
@@ -41,14 +40,15 @@ class TestHuggingFaceImport(unittest.TestCase):
logits2 = model_hf(x).logits
self.assertTrue(torch.allclose(logits1, logits2))
# now draw the argmax samples from each and compare them
y1 = sample(model=model, x=x, steps=20, sample=False)[0]
out1 = tokenizer.decode(y1.cpu().squeeze())
# now draw the argmax samples from each
y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0]
y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
out2 = tokenizer.decode(y2.cpu().squeeze())
self.assertTrue(torch.equal(y1, y2))
self.assertTrue(out1 == out2) # compare the output strings too, exactly
self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices
# convert indices to strings
out1 = tokenizer.decode(y1.cpu().squeeze())
out2 = tokenizer.decode(y2.cpu().squeeze())
self.assertTrue(out1 == out2) # compare the exact output strings too
if __name__ == '__main__':
unittest.main()