{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A cute little demo showing the simplest usage of minGPT. Configured to run fine on a MacBook Air in about a minute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data.dataloader import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "class SortDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Dataset for the Sort problem. E.g. for problem length 6:\n",
    "    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n",
    "    Which will feed into the transformer concatenated as:\n",
    "    input:  0 0 2 1 0 1 0 0 0 1 1\n",
    "    output: I I I I I 0 0 0 1 1 2\n",
    "    where I is \"ignore\", as the transformer is reading the input sequence\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split, length=6, num_digits=3):\n",
    "        assert split in {'train', 'test'}\n",
    "        self.split = split\n",
    "        self.length = length\n",
    "        self.num_digits = num_digits\n",
    "    \n",
    "    def __len__(self):\n",
    "        return 10000 # ...\n",
    "    \n",
    "    def get_vocab_size(self):\n",
    "        return self.num_digits\n",
    "    \n",
    "    def get_block_size(self):\n",
    "        # the length of the sequence that will feed into the transformer,\n",
    "        # containing the concatenated input and output, but -1 because\n",
    "        # the transformer starts making predictions at the last input element\n",
    "        return self.length * 2 - 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        # use rejection sampling to generate an input example from the desired split\n",
    "        while True:\n",
    "            # generate some random integers\n",
    "            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n",
    "            # half of the time let's try to boost the number of examples that\n",
    "            # have a large number of repeats, as this is what the model seems to struggle\n",
    "            # with later in training, and they are kind of rare\n",
    "            if torch.rand(1).item() < 0.5:\n",
    "                if inp.unique().nelement() > self.length // 2:\n",
    "                    # too many unique digits, re-sample\n",
    "                    continue\n",
    "            # figure out if this generated example is train or test based on its hash\n",
    "            h = hash(pickle.dumps(inp.tolist()))\n",
    "            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n",
    "            if inp_split == self.split:\n",
    "                break # ok\n",
    "        \n",
    "        # solve the task: i.e. sort\n",
    "        sol = torch.sort(inp)[0]\n",
    "\n",
    "        # concatenate the problem specification and the solution\n",
    "        cat = torch.cat((inp, sol), dim=0)\n",
    "\n",
    "        # the inputs to the transformer will be the offset sequence\n",
    "        x = cat[:-1].clone()\n",
    "        y = cat[1:].clone()\n",
    "        # we only want to predict at output locations, mask out the loss at the input locations\n",
    "        y[:self.length-1] = -1\n",
    "        return x, y\n"
   ]
  },
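  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the `(x, y)` construction, done by hand for the docstring example above (the `demo_*` names are purely illustrative). Targets of -1 mark the \"ignore\" positions: the model's cross-entropy loss skips them, so training only supervises the sorted continuation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hand-build the training pair for the docstring example, mirroring __getitem__\n",
    "demo_inp = torch.tensor([0, 0, 2, 1, 0, 1], dtype=torch.long)\n",
    "demo_sol = torch.sort(demo_inp)[0]\n",
    "demo_cat = torch.cat((demo_inp, demo_sol), dim=0)\n",
    "demo_x = demo_cat[:-1].clone()  # the 11 tokens the transformer reads\n",
    "demo_y = demo_cat[1:].clone()   # next-token targets, shifted by one\n",
    "demo_y[:len(demo_inp)-1] = -1   # mask the loss while the model is still reading the input\n",
    "print('x:', demo_x.tolist())\n",
    "print('y:', demo_y.tolist())"
   ]
  },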
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 -1\n",
      "2 -1\n",
      "0 -1\n",
      "1 -1\n",
      "0 -1\n",
      "2 0\n",
      "0 0\n",
      "0 1\n",
      "1 2\n",
      "2 2\n",
      "2 2\n"
     ]
    }
   ],
   "source": [
    "# print an example instance of the dataset\n",
    "train_dataset = SortDataset('train')\n",
    "test_dataset = SortDataset('test')\n",
    "x, y = train_dataset[0]\n",
    "for a, b in zip(x, y):\n",
    "    print(int(a), int(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 0.09M\n"
     ]
    }
   ],
   "source": [
    "# create a GPT instance\n",
    "from mingpt.model import GPT\n",
    "\n",
    "model_config = GPT.get_default_config()\n",
    "model_config.model_type = 'gpt-nano'\n",
    "model_config.vocab_size = train_dataset.get_vocab_size()\n",
    "model_config.block_size = train_dataset.get_block_size()\n",
    "model = GPT(model_config)"
   ]
  },
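  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check on the printed count, a small sketch that tallies the trainable parameters with plain PyTorch. Note this sums over *all* trainable tensors, so it can differ slightly from the model's own report if that report excludes some modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# count every trainable parameter directly\n",
    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"{n_params} trainable parameters (~{n_params/1e6:.2f}M)\")"
   ]
  },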
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running on device cpu\n"
     ]
    }
   ],
   "source": [
    "# create a Trainer object\n",
    "from mingpt.trainer import Trainer\n",
    "\n",
    "train_config = Trainer.get_default_config()\n",
    "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n",
    "train_config.max_iters = 2000\n",
    "train_config.num_workers = 0\n",
    "trainer = Trainer(train_config, model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter_dt 0.00ms; iter 0: train loss 1.09793\n",
      "iter_dt 29.36ms; iter 100: train loss 0.14420\n",
      "iter_dt 29.03ms; iter 200: train loss 0.04971\n",
      "iter_dt 28.62ms; iter 300: train loss 0.03680\n",
      "iter_dt 28.92ms; iter 400: train loss 0.01332\n",
      "iter_dt 28.34ms; iter 500: train loss 0.01905\n",
      "iter_dt 28.35ms; iter 600: train loss 0.02515\n",
      "iter_dt 28.69ms; iter 700: train loss 0.02522\n",
      "iter_dt 28.70ms; iter 800: train loss 0.02379\n",
      "iter_dt 28.39ms; iter 900: train loss 0.00192\n",
      "iter_dt 28.40ms; iter 1000: train loss 0.01416\n",
      "iter_dt 28.47ms; iter 1100: train loss 0.00136\n",
      "iter_dt 28.21ms; iter 1200: train loss 0.02124\n",
      "iter_dt 28.21ms; iter 1300: train loss 0.05553\n",
      "iter_dt 28.39ms; iter 1400: train loss 0.00930\n",
      "iter_dt 28.00ms; iter 1500: train loss 0.00863\n",
      "iter_dt 28.57ms; iter 1600: train loss 0.00624\n",
      "iter_dt 28.39ms; iter 1700: train loss 0.00355\n",
      "iter_dt 28.35ms; iter 1800: train loss 0.00235\n",
      "iter_dt 28.98ms; iter 1900: train loss 0.00243\n"
     ]
    }
   ],
   "source": [
    "def batch_end_callback(trainer):\n",
    "    if trainer.iter_num % 100 == 0:\n",
    "        print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n",
    "trainer.set_callback('on_batch_end', batch_end_callback)\n",
    "\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's perform some evaluation\n",
    "model.eval();"
   ]
  },
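  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`model.eval()` switches the network into inference mode, which in particular disables the dropout layers that were active during training; the evaluation below additionally wraps generation in `torch.no_grad()` so no gradients are tracked."
   ]
  },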
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train final score: 5000/5000 = 100.00% correct\n",
      "test final score: 5000/5000 = 100.00% correct\n"
     ]
    }
   ],
   "source": [
    "from mingpt.utils import sample\n",
    "\n",
    "def eval_split(trainer, split, max_batches):\n",
    "    dataset = {'train': train_dataset, 'test': test_dataset}[split]\n",
    "    n = train_dataset.length # naughty direct access shrug\n",
    "    results = []\n",
    "    mistakes_printed_already = 0\n",
    "    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n",
    "    for b, (x, y) in enumerate(loader):\n",
    "        x = x.to(trainer.device)\n",
    "        y = y.to(trainer.device)\n",
    "        # isolate the input pattern alone\n",
    "        inp = x[:, :n]\n",
    "        sol = y[:, -n:]\n",
    "        # let the model sample the rest of the sequence\n",
    "        cat = sample(model, inp, n, sample=False) # using greedy argmax, not sampling\n",
    "        sol_candidate = cat[:, n:] # isolate the filled-in sequence\n",
    "        # compare the predicted sequence to the true sequence\n",
    "        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
    "        for i in range(x.size(0)):\n",
    "            results.append(int(correct[i]))\n",
    "            if not correct[i] and mistakes_printed_already < 3: # only print up to 3 mistakes to get a sense\n",
    "                mistakes_printed_already += 1\n",
    "                print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n",
    "        if max_batches is not None and b+1 >= max_batches:\n",
    "            break\n",
    "    rt = torch.tensor(results, dtype=torch.float)\n",
    "    print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n",
    "    return rt.sum()\n",
    "\n",
    "# run a lot of examples from both train and test through the model and verify the output correctness\n",
    "with torch.no_grad():\n",
    "    train_score = eval_split(trainer, 'train', max_batches=50)\n",
    "    test_score = eval_split(trainer, 'test', max_batches=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input sequence  : [[0, 0, 2, 1, 0, 1]]\n",
      "predicted sorted: [[0, 0, 0, 1, 1, 2]]\n",
      "gt sort         : [0, 0, 0, 1, 1, 2]\n",
      "matches         : True\n"
     ]
    }
   ],
   "source": [
    "# let's run a random given sequence through the model as well\n",
    "n = train_dataset.length # naughty direct access shrug\n",
    "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
    "assert inp[0].nelement() == n\n",
    "with torch.no_grad():\n",
    "    cat = sample(model, inp, n, sample=False)\n",
    "sol = torch.sort(inp[0])[0]\n",
    "sol_candidate = cat[:, n:]\n",
    "print('input sequence  :', inp.tolist())\n",
    "print('predicted sorted:', sol_candidate.tolist())\n",
    "print('gt sort         :', sol.tolist())\n",
    "print('matches         :', bool((sol == sol_candidate).all()))"
   ]
  },
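  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a hedged sketch of stochastic decoding on the same input, for contrast with the greedy argmax used above. It assumes `mingpt.utils.sample` accepts `temperature`, `sample` and `top_k` keyword arguments, as in the minGPT version this notebook targets; if your version differs, drop the extra arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample (rather than argmax) the continuation a few times; a well-trained\n",
    "# model should still produce the correct sort almost every time\n",
    "with torch.no_grad():\n",
    "    for _ in range(3):\n",
    "        cat = sample(model, inp, n, temperature=1.0, sample=True, top_k=None)\n",
    "        print('sampled sort:', cat[:, n:].tolist())"
   ]
  }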
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "afdab15bd6582f87e2d1e596bfa7241af51aedf8abc909e2cab3828057cb30c9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}