{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train GPT on addition\n",
    "\n",
    "Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up logging\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "    level=logging.INFO,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make deterministic\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class AdditionDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Returns addition problems with operands of up to `ndigit` digits. Recall\n",
    "    that all GPT cares about is sequences of integers, and completing them according to\n",
    "    patterns in the data. Therefore, we have to somehow encode each addition problem\n",
    "    as a sequence of integers.\n",
    "    \n",
    "    The sum of two n-digit numbers is a number with up to n+1 digits. So our\n",
    "    encoding will simply be the n-digit first number, the n-digit second number,\n",
    "    and the (n+1)-digit result, all zero-padded and concatenated together. Because each\n",
    "    addition problem is structured identically, there is no need to bother the model\n",
    "    with encoding +, =, or other tokens: every sequence has the same length and\n",
    "    contains only the raw digits of the problem.\n",
    "    \n",
    "    As a few examples, the 2-digit problems:\n",
    "    - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]\n",
    "    - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]\n",
    "    etc.\n",
    "    \n",
    "    We will also only train GPT on the final (n+1) digits, because the two\n",
    "    n-digit operands are always assumed to be given. So when we give GPT an exam later,\n",
    "    we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like\n",
    "    to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]\n",
    "    in 3 sequential steps.\n",
    "    \n",
    "    fun exercise: does it help if the result is asked to be produced in reverse order?\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ndigit, split):\n",
    "        self.split = split # train/test\n",
    "        self.ndigit = ndigit\n",
    "        self.vocab_size = 10 # 10 possible digits 0..9\n",
    "        # the full render has 2*ndigit + (ndigit + 1) digits; the block is one shorter\n",
    "        # because the very last digit is only ever predicted, never fed back in as input\n",
    "        self.block_size = ndigit + ndigit + ndigit + 1 - 1\n",
    "        \n",
    "        # split up all addition problems into either training data or test data\n",
    "        num = (10**self.ndigit)**2 # total number of possible a + b combinations\n",
    "        r = np.random.RandomState(1337) # make deterministic\n",
    "        perm = r.permutation(num)\n",
    "        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, capped at 1000\n",
    "        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.ixes.size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # given a problem index idx, first recover the associated a + b\n",
    "        idx = self.ixes[idx]\n",
    "        nd = 10**self.ndigit\n",
    "        a = idx // nd\n",
    "        b = idx % nd\n",
    "        c = a + b\n",
    "        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes \"0325028\"\n",
    "        dix = [int(s) for s in render] # convert each character to its token index\n",
    "        # x will be input to GPT and y will be the associated expected outputs\n",
    "        x = torch.tensor(dix[:-1], dtype=torch.long)\n",
    "        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence\n",
    "        y[:self.ndigit*2-1] = -100 # only train on the output locations; -100 is the ignore_index that masks the rest out of the loss\n",
    "        return x, y\n"
   ]
  },
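  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (an illustrative sketch added for exposition, not part of the original notebook): re-derive the digit encoding by hand for the docstring's first example, using only standard Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch: re-derive the [a-digits, b-digits, sum-digits] encoding by hand\n",
    "a, b, nd = 85, 50, 2\n",
    "c = a + b # 135\n",
    "render = f'%0{nd}d%0{nd}d%0{nd+1}d' % (a, b, c) # \"8550135\"\n",
    "print([int(ch) for ch in render]) # [8, 5, 5, 0, 1, 3, 5], matching the docstring"
   ]
  },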
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a dataset for e.g. 2-digit addition\n",
    "ndigit = 2\n",
    "train_dataset = AdditionDataset(ndigit=ndigit, split='train')\n",
    "test_dataset = AdditionDataset(ndigit=ndigit, split='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([4, 7, 1, 7, 0, 6]), tensor([-100, -100, -100, 0, 6, 4]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0] # sample a training instance just to see what one raw example looks like"
   ]
  },
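  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As another illustrative sketch (added for exposition, not part of the original notebook), we can decode the `(x, y)` pair above back into its underlying problem; the `-100` entries mark the given-operand positions whose loss is masked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch: decode the sampled (x, y) pair back into its addition problem\n",
    "x, y = train_dataset[0]\n",
    "digits = x.tolist() + [y[-1].item()] # x plus the final target digit gives the full render\n",
    "a = int(''.join(map(str, digits[:ndigit])))\n",
    "b = int(''.join(map(str, digits[ndigit:ndigit*2])))\n",
    "c = int(''.join(map(str, digits[ndigit*2:])))\n",
    "print(f'{a} + {b} = {c}') # 47 + 17 = 64\n",
    "assert a + b == c"
   ]
  },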
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "08/16/2020 23:47:41 - INFO - mingpt.model - number of parameters: 4.001280e+05\n"
     ]
    }
   ],
   "source": [
    "from mingpt.model import GPT, GPTConfig, GPT1Config\n",
    "\n",
    "# initialize a baby GPT model\n",
    "mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,\n",
    "                  n_layer=2, n_head=4, n_embd=128)\n",
    "model = GPT(mconf)"
   ]
  },
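  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A one-line sanity check (illustrative, not part of the original notebook) that the parameter count logged above adds up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch: confirm the ~4.0e5 parameters reported by the logger above\n",
    "print(sum(p.numel() for p in model.parameters()))"
   ]
  },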
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/18 [00:00<?, ?it/s]/apcv/shared/conda-envs/apcv-6244e1d-566/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "epoch 1 iter 17: train loss 1.74049. lr 5.994512e-04: 100%|██████████| 18/18 [00:30<00:00, 1.70s/it]\n",
      "08/16/2020 23:48:16 - INFO - mingpt.trainer - test loss: 1.693525\n",
      "epoch 2 iter 17: train loss 1.50974. lr 5.977197e-04: 100%|██████████| 18/18 [00:01<00:00, 11.61it/s]\n",
      "08/16/2020 23:48:18 - INFO - mingpt.trainer - test loss: 1.466473\n",
      "epoch 3 iter 17: train loss 1.31133. lr 5.948114e-04: 100%|██████████| 18/18 [00:01<00:00, 11.45it/s]\n",
      "08/16/2020 23:48:20 - INFO - mingpt.trainer - test loss: 1.256615\n",
      "epoch 4 iter 17: train loss 1.22379. lr 5.907379e-04: 100%|██████████| 18/18 [00:01<00:00, 11.50it/s]\n",
      "08/16/2020 23:48:21 - INFO - mingpt.trainer - test loss: 1.160792\n",
      "epoch 5 iter 17: train loss 1.14308. lr 5.855153e-04: 100%|██████████| 18/18 [00:01<00:00, 11.63it/s]\n",
      "08/16/2020 23:48:23 - INFO - mingpt.trainer - test loss: 1.091487\n",
      "epoch 6 iter 17: train loss 1.09970. lr 5.791641e-04: 100%|██████████| 18/18 [00:01<00:00, 11.56it/s]\n",
      "08/16/2020 23:48:25 - INFO - mingpt.trainer - test loss: 1.050111\n",
      "epoch 7 iter 17: train loss 1.08481. lr 5.717095e-04: 100%|██████████| 18/18 [00:01<00:00, 11.53it/s]\n",
      "08/16/2020 23:48:26 - INFO - mingpt.trainer - test loss: 1.037456\n",
      "epoch 8 iter 17: train loss 1.03496. lr 5.631810e-04: 100%|██████████| 18/18 [00:01<00:00, 11.59it/s]\n",
      "08/16/2020 23:48:28 - INFO - mingpt.trainer - test loss: 0.997156\n",
      "epoch 9 iter 17: train loss 0.98606. lr 5.536122e-04: 100%|██████████| 18/18 [00:01<00:00, 11.67it/s]\n",
      "08/16/2020 23:48:30 - INFO - mingpt.trainer - test loss: 0.836543\n",
      "epoch 10 iter 17: train loss 0.59589. lr 5.430411e-04: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
      "08/16/2020 23:48:31 - INFO - mingpt.trainer - test loss: 0.438013\n",
      "epoch 11 iter 17: train loss 0.50257. lr 5.315093e-04: 100%|██████████| 18/18 [00:01<00:00, 12.99it/s]\n",
      "08/16/2020 23:48:33 - INFO - mingpt.trainer - test loss: 0.343370\n",
      "epoch 12 iter 17: train loss 0.44096. lr 5.190624e-04: 100%|██████████| 18/18 [00:01<00:00, 12.08it/s]\n",
      "08/16/2020 23:48:34 - INFO - mingpt.trainer - test loss: 0.277625\n",
      "epoch 13 iter 17: train loss 0.37445. lr 5.057497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
      "08/16/2020 23:48:36 - INFO - mingpt.trainer - test loss: 0.236511\n",
      "epoch 14 iter 17: train loss 0.31269. lr 4.916238e-04: 100%|██████████| 18/18 [00:01<00:00, 12.90it/s]\n",
      "08/16/2020 23:48:37 - INFO - mingpt.trainer - test loss: 0.207689\n",
      "epoch 15 iter 17: train loss 0.34095. lr 4.767405e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:39 - INFO - mingpt.trainer - test loss: 0.165566\n",
      "epoch 16 iter 17: train loss 0.25957. lr 4.611586e-04: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:48:40 - INFO - mingpt.trainer - test loss: 0.123080\n",
      "epoch 17 iter 17: train loss 0.23488. lr 4.449397e-04: 100%|██████████| 18/18 [00:01<00:00, 12.85it/s]\n",
      "08/16/2020 23:48:42 - INFO - mingpt.trainer - test loss: 0.091252\n",
      "epoch 18 iter 17: train loss 0.20269. lr 4.281479e-04: 100%|██████████| 18/18 [00:01<00:00, 12.72it/s]\n",
      "08/16/2020 23:48:43 - INFO - mingpt.trainer - test loss: 0.078601\n",
      "epoch 19 iter 17: train loss 0.19535. lr 4.108497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
      "08/16/2020 23:48:45 - INFO - mingpt.trainer - test loss: 0.055412\n",
      "epoch 20 iter 17: train loss 0.16152. lr 3.931133e-04: 100%|██████████| 18/18 [00:01<00:00, 12.66it/s]\n",
      "08/16/2020 23:48:46 - INFO - mingpt.trainer - test loss: 0.051874\n",
      "epoch 21 iter 17: train loss 0.14061. lr 3.750088e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
      "08/16/2020 23:48:48 - INFO - mingpt.trainer - test loss: 0.044502\n",
      "epoch 22 iter 17: train loss 0.16309. lr 3.566079e-04: 100%|██████████| 18/18 [00:01<00:00, 12.67it/s]\n",
      "08/16/2020 23:48:49 - INFO - mingpt.trainer - test loss: 0.036376\n",
      "epoch 23 iter 17: train loss 0.14411. lr 3.379832e-04: 100%|██████████| 18/18 [00:01<00:00, 13.21it/s]\n",
      "08/16/2020 23:48:51 - INFO - mingpt.trainer - test loss: 0.029843\n",
      "epoch 24 iter 17: train loss 0.12110. lr 3.192084e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:52 - INFO - mingpt.trainer - test loss: 0.025040\n",
      "epoch 25 iter 17: train loss 0.11360. lr 3.003577e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
      "08/16/2020 23:48:54 - INFO - mingpt.trainer - test loss: 0.023500\n",
      "epoch 26 iter 17: train loss 0.13910. lr 2.815056e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
      "08/16/2020 23:48:55 - INFO - mingpt.trainer - test loss: 0.022606\n",
      "epoch 27 iter 17: train loss 0.07931. lr 2.627266e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:57 - INFO - mingpt.trainer - test loss: 0.015403\n",
      "epoch 28 iter 17: train loss 0.09684. lr 2.440948e-04: 100%|██████████| 18/18 [00:01<00:00, 11.92it/s]\n",
      "08/16/2020 23:48:58 - INFO - mingpt.trainer - test loss: 0.015245\n",
      "epoch 29 iter 17: train loss 0.09055. lr 2.256841e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
      "08/16/2020 23:49:00 - INFO - mingpt.trainer - test loss: 0.012647\n",
      "epoch 30 iter 17: train loss 0.08837. lr 2.075671e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
      "08/16/2020 23:49:01 - INFO - mingpt.trainer - test loss: 0.011611\n",
      "epoch 31 iter 17: train loss 0.08425. lr 1.898155e-04: 100%|██████████| 18/18 [00:01<00:00, 12.43it/s]\n",
      "08/16/2020 23:49:03 - INFO - mingpt.trainer - test loss: 0.009952\n",
      "epoch 32 iter 17: train loss 0.10772. lr 1.724993e-04: 100%|██████████| 18/18 [00:01<00:00, 12.40it/s]\n",
      "08/16/2020 23:49:05 - INFO - mingpt.trainer - test loss: 0.008648\n",
      "epoch 33 iter 17: train loss 0.07272. lr 1.556871e-04: 100%|██████████| 18/18 [00:01<00:00, 12.57it/s]\n",
      "08/16/2020 23:49:06 - INFO - mingpt.trainer - test loss: 0.010154\n",
      "epoch 34 iter 17: train loss 0.05550. lr 1.394453e-04: 100%|██████████| 18/18 [00:01<00:00, 12.47it/s]\n",
      "08/16/2020 23:49:08 - INFO - mingpt.trainer - test loss: 0.007668\n",
      "epoch 35 iter 17: train loss 0.05451. lr 1.238381e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
      "08/16/2020 23:49:09 - INFO - mingpt.trainer - test loss: 0.008095\n",
      "epoch 36 iter 17: train loss 0.09133. lr 1.089272e-04: 100%|██████████| 18/18 [00:01<00:00, 12.39it/s]\n",
      "08/16/2020 23:49:11 - INFO - mingpt.trainer - test loss: 0.006615\n",
      "epoch 37 iter 17: train loss 0.06825. lr 9.477150e-05: 100%|██████████| 18/18 [00:01<00:00, 12.27it/s]\n",
      "08/16/2020 23:49:12 - INFO - mingpt.trainer - test loss: 0.005874\n",
      "epoch 38 iter 17: train loss 0.05798. lr 8.142699e-05: 100%|██████████| 18/18 [00:01<00:00, 12.49it/s]\n",
      "08/16/2020 23:49:14 - INFO - mingpt.trainer - test loss: 0.005701\n",
      "epoch 39 iter 17: train loss 0.06975. lr 6.894639e-05: 100%|██████████| 18/18 [00:01<00:00, 12.88it/s]\n",
      "08/16/2020 23:49:15 - INFO - mingpt.trainer - test loss: 0.005469\n",
      "epoch 40 iter 17: train loss 0.06070. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
      "08/16/2020 23:49:17 - INFO - mingpt.trainer - test loss: 0.005307\n",
      "epoch 41 iter 17: train loss 0.06378. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.60it/s]\n",
      "08/16/2020 23:49:18 - INFO - mingpt.trainer - test loss: 0.005681\n",
      "epoch 42 iter 17: train loss 0.04885. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
      "08/16/2020 23:49:20 - INFO - mingpt.trainer - test loss: 0.005456\n",
      "epoch 43 iter 17: train loss 0.06409. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
      "08/16/2020 23:49:21 - INFO - mingpt.trainer - test loss: 0.004907\n",
      "epoch 44 iter 17: train loss 0.07563. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:49:23 - INFO - mingpt.trainer - test loss: 0.004650\n",
      "epoch 45 iter 17: train loss 0.03149. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.79it/s]\n",
      "08/16/2020 23:49:24 - INFO - mingpt.trainer - test loss: 0.004626\n",
      "epoch 46 iter 17: train loss 0.07037. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.86it/s]\n",
      "08/16/2020 23:49:26 - INFO - mingpt.trainer - test loss: 0.004147\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 47 iter 17: train loss 0.07650. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.82it/s]\n",
      "08/16/2020 23:49:27 - INFO - mingpt.trainer - test loss: 0.004611\n",
      "epoch 48 iter 17: train loss 0.06342. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.63it/s]\n",
      "08/16/2020 23:49:29 - INFO - mingpt.trainer - test loss: 0.004083\n",
      "epoch 49 iter 17: train loss 0.12429. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:49:30 - INFO - mingpt.trainer - test loss: 0.004081\n",
      "epoch 50 iter 17: train loss 0.04616. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.19it/s]\n",
      "08/16/2020 23:49:32 - INFO - mingpt.trainer - test loss: 0.003922\n"
     ]
    }
   ],
   "source": [
    "from mingpt.trainer import Trainer, TrainerConfig\n",
    "\n",
    "# initialize a trainer instance and kick off training\n",
    "tconf = TrainerConfig(max_epochs=50, batch_size=512, learning_rate=6e-4,\n",
    "                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(ndigit+1),\n",
    "                      num_workers=4)\n",
    "trainer = Trainer(model, train_dataset, test_dataset, tconf)\n",
    "trainer.train()"
   ]
  },
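  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A brief aside (illustrative, not part of the original notebook): `final_tokens` is sized so the learning-rate decay finishes right at the end of training, assuming the trainer's token counter tallies only the unmasked target positions, of which each example contributes `ndigit + 1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch: the token budget behind final_tokens above\n",
    "# each example trains on ndigit+1 answer digits, over 50 epochs of the train set\n",
    "print(50 * len(train_dataset) * (ndigit + 1)) # 1350000 for the 9000 train examples"
   ]
  },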
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's give the trained model an addition exam\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from mingpt.utils import sample\n",
    "\n",
    "def give_exam(dataset, batch_size=32, max_batches=-1):\n",
    "\n",
    "    results = []\n",
    "    loader = DataLoader(dataset, batch_size=batch_size)\n",
    "    for b, (x, y) in enumerate(loader):\n",
    "        x = x.to(trainer.device)\n",
    "        d1d2 = x[:, :ndigit*2] # the two operands' digits form the prompt\n",
    "        d1d2d3 = sample(model, d1d2, ndigit+1) # complete the ndigit+1 answer digits\n",
    "        d3 = d1d2d3[:, -(ndigit+1):] # isolate the predicted answer digits\n",
    "        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)\n",
    "        # decode the integers from individual digits\n",
    "        d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)\n",
    "        d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)\n",
    "        d3i_pred = (d3 * factors).sum(1)\n",
    "        d3i_gt = d1i + d2i\n",
    "        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol\n",
    "        for i in range(x.size(0)):\n",
    "            results.append(int(correct[i]))\n",
    "            judge = 'YEP!!!' if correct[i] else 'NOPE'\n",
    "            if not correct[i]:\n",
    "                print(\"GPT claims that %03d + %03d = %03d (gt is %03d; %s)\"\n",
    "                      % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i], judge))\n",
    "\n",
    "        if max_batches >= 0 and b+1 >= max_batches:\n",
    "            break\n",
    "\n",
    "    print(\"final score: %d/%d = %.2f%% correct\" % (np.sum(results), len(results), 100*np.mean(results)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final score: 9000/9000 = 100.00% correct\n"
     ]
    }
   ],
   "source": [
    "# training set: how well did we memorize?\n",
    "give_exam(train_dataset, batch_size=1024, max_batches=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT claims that 055 + 045 = 090 (gt is 100; NOPE)\n",
      "final score: 999/1000 = 99.90% correct\n"
     ]
    }
   ],
   "source": [
    "# test set: how well did we generalize?\n",
    "give_exam(test_dataset, batch_size=1024, max_batches=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# well that's amusing... our model learned everything except 55 + 45"
   ]
  },
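  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a final illustrative probe (a sketch added for exposition, not part of the original notebook), we can feed that one failing problem to `sample` directly, mirroring how `give_exam` builds its queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch: query the single failure case (55 + 45) directly\n",
    "q = torch.tensor([[5, 5, 4, 5]], dtype=torch.long).to(trainer.device)\n",
    "pred = sample(model, q, ndigit + 1) # complete the ndigit+1 answer digits\n",
    "print(pred[0].tolist()) # the last ndigit+1 entries are the model's claimed digits of 55 + 45"
   ]
  }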
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}