Mirror of https://github.com/karpathy/minGPT, synced 2024-03-29 10:19:59 +01:00.
refactor sequence generation into the model and match the huggingface/transformers API. This touches everything, but it makes a lot more sense to me aesthetically.
parent 5af9e5c5d7
commit acaadacd59
demo.ipynb (70 lines changed)
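The change in one line: generation moves out of mingpt/utils.py (the free function sample) and onto the model itself as GPT.generate, with huggingface/transformers-style argument names. A minimal before/after sketch of a call site, using the two signatures shown in the diffs below (model is assumed to be a trained minGPT GPT instance and x a LongTensor of token indices):

    # before: free function in mingpt/utils.py
    from mingpt.utils import sample
    y = sample(model, x, steps=20, temperature=1.0, sample=False, top_k=None)

    # after: method on the model, matching the huggingface/transformers API
    y = model.generate(x, max_new_tokens=20, temperature=1.0, do_sample=False, top_k=None)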
diff --git a/demo.ipynb b/demo.ipynb
@@ -15,7 +15,9 @@
    "source": [
     "import torch\n",
     "from torch.utils.data import Dataset\n",
-    "from torch.utils.data.dataloader import DataLoader"
+    "from torch.utils.data.dataloader import DataLoader\n",
+    "from mingpt.utils import set_seed\n",
+    "set_seed(3407)"
    ]
   },
   {
@@ -96,17 +98,17 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "2 -1\n",
     "2 -1\n",
     "1 -1\n",
     "0 -1\n",
     "1 -1\n",
     "0 -1\n",
     "2 0\n",
     "0 -1\n",
     "0 0\n",
     "0 0\n",
     "0 0\n",
     "0 0\n",
     "0 1\n",
     "1 2\n",
     "2 2\n",
-    "2 2\n"
+    "1 1\n"
    ]
   }
  ],
@@ -152,7 +154,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "running on device cpu\n"
+    "running on device cuda\n"
    ]
   }
  ],
@@ -176,26 +178,26 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "iter_dt 0.00ms; iter 0: train loss 1.09793\n",
-    "iter_dt 29.36ms; iter 100: train loss 0.14420\n",
-    "iter_dt 29.03ms; iter 200: train loss 0.04971\n",
-    "iter_dt 28.62ms; iter 300: train loss 0.03680\n",
-    "iter_dt 28.92ms; iter 400: train loss 0.01332\n",
-    "iter_dt 28.34ms; iter 500: train loss 0.01905\n",
-    "iter_dt 28.35ms; iter 600: train loss 0.02515\n",
-    "iter_dt 28.69ms; iter 700: train loss 0.02522\n",
-    "iter_dt 28.70ms; iter 800: train loss 0.02379\n",
-    "iter_dt 28.39ms; iter 900: train loss 0.00192\n",
-    "iter_dt 28.40ms; iter 1000: train loss 0.01416\n",
-    "iter_dt 28.47ms; iter 1100: train loss 0.00136\n",
-    "iter_dt 28.21ms; iter 1200: train loss 0.02124\n",
-    "iter_dt 28.21ms; iter 1300: train loss 0.05553\n",
-    "iter_dt 28.39ms; iter 1400: train loss 0.00930\n",
-    "iter_dt 28.00ms; iter 1500: train loss 0.00863\n",
-    "iter_dt 28.57ms; iter 1600: train loss 0.00624\n",
-    "iter_dt 28.39ms; iter 1700: train loss 0.00355\n",
-    "iter_dt 28.35ms; iter 1800: train loss 0.00235\n",
-    "iter_dt 28.98ms; iter 1900: train loss 0.00243\n"
+    "iter_dt 0.00ms; iter 0: train loss 1.06407\n",
+    "iter_dt 18.17ms; iter 100: train loss 0.14712\n",
+    "iter_dt 18.70ms; iter 200: train loss 0.05315\n",
+    "iter_dt 19.65ms; iter 300: train loss 0.04404\n",
+    "iter_dt 31.64ms; iter 400: train loss 0.04724\n",
+    "iter_dt 18.43ms; iter 500: train loss 0.02521\n",
+    "iter_dt 19.83ms; iter 600: train loss 0.03352\n",
+    "iter_dt 19.58ms; iter 700: train loss 0.00539\n",
+    "iter_dt 18.72ms; iter 800: train loss 0.02057\n",
+    "iter_dt 18.26ms; iter 900: train loss 0.00360\n",
+    "iter_dt 18.50ms; iter 1000: train loss 0.00788\n",
+    "iter_dt 20.64ms; iter 1100: train loss 0.01162\n",
+    "iter_dt 18.63ms; iter 1200: train loss 0.00963\n",
+    "iter_dt 18.32ms; iter 1300: train loss 0.02066\n",
+    "iter_dt 18.40ms; iter 1400: train loss 0.01739\n",
+    "iter_dt 18.37ms; iter 1500: train loss 0.00376\n",
+    "iter_dt 18.67ms; iter 1600: train loss 0.00133\n",
+    "iter_dt 18.38ms; iter 1700: train loss 0.00179\n",
+    "iter_dt 18.66ms; iter 1800: train loss 0.00079\n",
+    "iter_dt 18.48ms; iter 1900: train loss 0.00042\n"
    ]
   }
  ],
@@ -233,8 +235,6 @@
    }
   ],
   "source": [
-    "from mingpt.utils import sample\n",
-    "\n",
    "def eval_split(trainer, split, max_batches):\n",
    "    dataset = {'train':train_dataset, 'test':test_dataset}[split]\n",
    "    n = train_dataset.length # naugy direct access shrug\n",
@@ -248,7 +248,7 @@
    "        inp = x[:, :n]\n",
    "        sol = y[:, -n:]\n",
    "        # let the model sample the rest of the sequence\n",
-    "        cat = sample(model, inp, n, sample=False) # using greedy argmax, not sampling\n",
+    "        cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n",
    "        sol_candidate = cat[:, n:] # isolate the filled in sequence\n",
    "        # compare the predicted sequence to the true sequence\n",
    "        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
@@ -291,7 +291,7 @@
    "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
    "assert inp[0].nelement() == n\n",
    "with torch.no_grad():\n",
-    "    cat = sample(model, inp, n, sample=False)\n",
+    "    cat = model.generate(inp, n, do_sample=False)\n",
    "sol = torch.sort(inp[0])[0]\n",
    "sol_candidate = cat[:, n:]\n",
    "print('input sequence :', inp.tolist())\n",
@@ -303,7 +303,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3.8.5 ('base')",
+    "display_name": "Python 3.10.4 64-bit",
     "language": "python",
     "name": "python3"
    },
@@ -317,12 +317,12 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.10.4"
   },
   "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
-    "hash": "afdab15bd6582f87e2d1e596bfa7241af51aedf8abc909e2cab3828057cb30c9"
+    "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
    }
   }
  },
diff --git a/generate.ipynb b/generate.ipynb
@@ -15,7 +15,6 @@
    "source": [
     "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
     "from mingpt.model import GPT\n",
-    "from mingpt.utils import sample\n",
     "from mingpt.utils import set_seed\n",
     "set_seed(3407)"
    ]
@@ -74,10 +73,7 @@
    "    x = x.expand(num_samples, -1)\n",
    "\n",
    "    # forward the model `steps` times to get samples, in a batch\n",
-    "    if use_mingpt:\n",
-    "        y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n",
-    "    else:\n",
-    "        y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
+    "    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
    "    \n",
    "    for i in range(num_samples):\n",
    "        out = tokenizer.decode(y[i].cpu().squeeze())\n",
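Note how the refactor pays off immediately in this notebook: since minGPT's generate now has the same signature as huggingface's, the use_mingpt branch collapses and a single model.generate call serves both models.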
@@ -95,8 +91,8 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
-    "2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
+    "2022-07-11 18:42:21.744061: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
+    "2022-07-11 18:42:21.744099: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
    ]
   },
   {
diff --git a/mingpt/model.py b/mingpt/model.py
@@ -1,10 +1,5 @@
 """
-GPT model:
-- the initial stem consists of a combination of token encoding and a positional encoding
-- the meat of it is a uniform sequence of Transformer blocks
-- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
-- all blocks feed into a central residual pathway similar to resnets
-- the final decoder is a linear projection into a vanilla Softmax classifier
+Full definition of a GPT Language Model, all of it in this single file.
 
 References:
 1) the official GPT-2 TensorFlow implementation released by OpenAI:
@@ -161,13 +156,10 @@ class GPT(nn.Module):
             if pn.endswith('c_proj.weight'):
                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
 
-        # report number of parameters
+        # report number of parameters (note we don't count the decoder parameters in lm_head)
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
 
-    def get_block_size(self):
-        return self.block_size
-
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -286,3 +278,33 @@ class GPT(nn.Module):
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
 
         return logits, loss
+
+    @torch.no_grad()
+    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        for _ in range(max_new_tokens):
+            # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
+            # forward the model to get the logits for the index in the sequence
+            logits, _ = self(idx_cond)
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(logits, top_k)
+                logits[logits < v[:, [-1]]] = -float('Inf')
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+            # either sample from the distribution or take the most likely element
+            if do_sample:
+                idx_next = torch.multinomial(probs, num_samples=1)
+            else:
+                _, idx_next = torch.topk(probs, k=1, dim=-1)
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
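For reference, a minimal usage sketch of the new method (assuming a trained GPT instance model and a LongTensor idx of shape (b, t) on the model's device). Note that unlike the old utils.sample, generate does not call model.eval() for you, per the docstring above:

    model.eval()
    y_greedy = model.generate(idx, max_new_tokens=20, do_sample=False)  # greedy argmax decoding
    y_sample = model.generate(idx, max_new_tokens=20, do_sample=True,
                              temperature=1.0, top_k=40)                # stochastic top-k sampling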
diff --git a/mingpt/utils.py b/mingpt/utils.py
@@ -7,8 +7,8 @@ from ast import literal_eval
 
 import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 
 # -----------------------------------------------------------------------------
 
 def set_seed(seed):
     random.seed(seed)

@@ -28,42 +28,6 @@ def setup_logging(config)
     with open(os.path.join(work_dir, 'config.json'), 'w') as f:
         f.write(json.dumps(config.to_dict(), indent=4))
 
-def top_k_logits(logits, k):
-    v, ix = torch.topk(logits, k)
-    out = logits.clone()
-    out[out < v[:, [-1]]] = -float('Inf')
-    return out
-
-@torch.no_grad()
-def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
-    """
-    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
-    the sequence, feeding the predictions back into the model each time. Clearly the sampling
-    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
-    of block_size, unlike an RNN that has an infinite context window.
-    """
-    block_size = model.get_block_size()
-    model.eval()
-    for k in range(steps):
-        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
-        logits, _ = model(x_cond)
-        # pluck the logits at the final step and scale by temperature
-        logits = logits[:, -1, :] / temperature
-        # optionally crop probabilities to only the top k options
-        if top_k is not None:
-            logits = top_k_logits(logits, top_k)
-        # apply softmax to convert to probabilities
-        probs = F.softmax(logits, dim=-1)
-        # sample from the distribution or take the most likely
-        if sample:
-            ix = torch.multinomial(probs, num_samples=1)
-        else:
-            _, ix = torch.topk(probs, k=1, dim=-1)
-        # append to the sequence and continue
-        x = torch.cat((x, ix), dim=1)
-
-    return x
-
 class CfgNode:
     """ a lightweight configuration class inspired by yacs """
     # TODO: convert to subclass from a dict like in yacs?
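The top_k_logits helper disappears along with sample; its two lines are inlined in GPT.generate above. A toy illustration of what that top-k crop does (tensor values invented for the example):

    import torch
    logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
    v, _ = torch.topk(logits, 2)                 # top-2 values: tensor([[2.0, 1.0]])
    logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything below the 2nd-best logit
    # logits is now [[2.0, -inf, 1.0, -inf]], so softmax puts zero mass on the masked entries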
diff --git a/projects/adder/adder.py b/projects/adder/adder.py
@@ -12,7 +12,7 @@ from torch.utils.data.dataloader import DataLoader
 
 from mingpt.model import GPT
 from mingpt.trainer import Trainer
-from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
+from mingpt.utils import set_seed, setup_logging, CfgNode as CN
 
 # -----------------------------------------------------------------------------

@@ -154,7 +154,7 @@ if __name__ == '__main__':
             # isolate the first two digits of the input sequence alone
             d1d2 = x[:, :ndigit*2]
             # let the model sample the rest of the sequence
-            d1d2d3 = sample(model, d1d2, ndigit+1, sample=False) # using greedy argmax, not sampling
+            d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling
             # isolate the last digit of the sampled sequence
             d3 = d1d2d3[:, -(ndigit+1):]
             d3 = d3.flip(1) # reverse the digits to their "normal" order
diff --git a/projects/chargpt/chargpt.py b/projects/chargpt/chargpt.py
@@ -11,7 +11,7 @@ from torch.utils.data.dataloader import DataLoader
 
 from mingpt.model import GPT
 from mingpt.trainer import Trainer
-from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
+from mingpt.utils import set_seed, setup_logging, CfgNode as CN
 
 # -----------------------------------------------------------------------------

@@ -117,7 +117,7 @@ if __name__ == '__main__':
             # sample from the model...
             context = "O God, O God!"
             x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
-            y = sample(model, x, 500, temperature=1.0, sample=True, top_k=10)[0]
+            y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0]
             completion = ''.join([train_dataset.itos[int(i)] for i in y])
             print(completion)
             # save the latest model
diff --git a/tests/test_huggingface_import.py b/tests/test_huggingface_import.py
@@ -6,7 +6,6 @@ import unittest
 
 import torch
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 from mingpt.model import GPT
-from mingpt.utils import sample
 
 # -----------------------------------------------------------------------------

@@ -41,14 +40,15 @@ class TestHuggingFaceImport(unittest.TestCase):
         logits2 = model_hf(x).logits
         self.assertTrue(torch.allclose(logits1, logits2))
 
-        # now draw the argmax samples from each and compare them
-        y1 = sample(model=model, x=x, steps=20, sample=False)[0]
-        out1 = tokenizer.decode(y1.cpu().squeeze())
+        # now draw the argmax samples from each
+        y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0]
         y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
-        out2 = tokenizer.decode(y2.cpu().squeeze())
-        self.assertTrue(torch.equal(y1, y2))
-        self.assertTrue(out1 == out2) # compare the output strings too, exactly
+        self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices
+
+        # convert indices to strings
+        out1 = tokenizer.decode(y1.cpu().squeeze())
+        out2 = tokenizer.decode(y2.cpu().squeeze())
+        self.assertTrue(out1 == out2) # compare the exact output strings too
 
 if __name__ == '__main__':
     unittest.main()
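The reworked test also doubles as a recipe for checking minGPT's generate against huggingface's on real GPT-2 weights. A condensed sketch; the model construction via from_pretrained is assumed here, since the scraped diff does not show that part of the test, and the prompt is illustrative:

    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel
    from mingpt.model import GPT

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT.from_pretrained('gpt2')                 # minGPT carrying imported HF weights (assumed loader)
    model_hf = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval(); model_hf.eval()

    x = tokenizer("Hello, my dog is a little", return_tensors='pt')['input_ids']
    y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0]
    y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
    assert torch.equal(y1, y2)  # greedy continuations should match token for token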