1
0
Fork 0
mirror of https://github.com/karpathy/minGPT synced 2024-04-26 20:05:49 +02:00

first commit, able to multigpu train fp32 GPTs on math and character-level data, but have done barely any tuning.

This commit is contained in:
Andrej Karpathy 2020-08-17 00:39:02 -07:00
commit 0d9d098cd2
9 changed files with 1334 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
.ipynb_checkpoints/
__pycache__/

99
README.md Normal file
View File

@ -0,0 +1,99 @@
# minGPT
![mingpt](mingpt.jpg)
A PyTorch re-implementation of [GPT](https://github.com/openai/gpt-3) training. minGPT tries to be small, clean, interpretable and educational, as most of the currently available ones are a bit sprawling. GPT is not a complicated model and this implementation is appropriately about 300 lines of code, including boilerplate and a totally unnecessary custom causal self-attention module. Anyway, all that's going on is that a sequence of indices goes into a sequence of transformer blocks, and a probability distribution of the next index comes out. The rest of the complexity is just being clever with batching (both across examples and over sequence length) so that training is efficient.
The core minGPT "library" (hah) is two files: `mingpt/model.py` contains the actual Transformer model definition and `mingpt/trainer.py` is (GPT-independent) PyTorch boilerplate that trains the model. The attached Jupyter notebooks then show how the "library" (hah) can be used to train sequence models:
- `play_math.ipynb` trains a GPT focused on addition (inspired by the addition section in the GPT-3 paper)
- `play_char.ipynb` trains a GPT to be a character-level language model on arbitrary text, similar to my older char-rnn but with a transformer instead of an RNN
- `play_words.ipynb` a BPE version that does not yet exist
With a bpe encoder, distributed training and maybe fp16 this implementation may be able to reproduce GPT-1/GPT-2 results, though I haven't tried $$$. GPT-3 is likely out of reach as my understanding is that it does not fit into GPU memory and requires a more careful model-parallel treatment.
### Example usage
This code is simple enough to just hack inline, not "used", but current API looks something like:
```python
# you're on your own to define a class that returns individual examples as PyTorch LongTensors
from torch.utils.data import Dataset
train_dataset = MyDataset(...)
test_dataset = MyDataset(...)
# construct a GPT model
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(vocab_size, block_size, n_layer=12, n_head=12, n_embd=768) # a GPT-1
model = GPT(mconf)
# construct a trainer
from mingpt.trainer import Trainer, TrainerConfig
tconf = TrainerConfig(max_epochs=10, batch_size=256)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()
# (... enjoy the show for a while... )
# sample from the model (the [None, ...] and [0] are to push/pop a needed dummy batch dimension)
from mingpt.utils import sample
x = torch.tensor([1, 2, 3], dtype=torch.long)[None, ...] # context conditioning
y = sample(model, x, steps=30, temperature=1.0, sample=True, top_k=5)[0]
print(y) # our model filled in the integer sequence with 30 additional likely integers
```
### References
Code:
- [openai/gpt-2](https://github.com/openai/gpt-2) has the model but not the training code, and in TensorFlow
- [openai/image-gpt](https://github.com/openai/image-gpt) has some more modern gpt-3 like modification in its code, good reference as well
- huggingface/transformers has a [language-modeling example](https://github.com/huggingface/transformers/tree/master/examples/language-modeling). It is full-featured but as a result also somewhat challenging to trace. E.g. some large functions have as much as 90% unused code behind various branching statments that is unsued in the default setting of simple language modeling.
Papers + some implementation notes:
#### Improving Language Understanding by Generative Pre-Training (GPT-1)
- Our model largely follows the original transformer work
- We trained a 12-layer decoder-only transformer with masked self-attention heads (768 dimensional states and 12 attention heads). For the position-wise feed-forward networks, we used 3072 dimensional inner states.
- Adam max learning rate of 2.5e-4. (later GPT-3 for this model size uses 6e-4)
- LR decay: increased linearly from zero over the first 2000 updates and annealed to 0 using a cosine schedule
- We train for 100 epochs on minibatches of 64 randomly sampled, contiguous sequences of 512 tokens.
- Since layernorm is used extensively throughout the model, a simple weight initialization of N(0, 0.02) was sufficient
- bytepair encoding (BPE) vocabulary with 40,000 merges
- residual, embedding, and attention dropouts with a rate of 0.1 for regularization.
- modified version of L2 regularization proposed in (37), with w = 0.01 on all non bias or gain weights
- For the activation function, we used the Gaussian Error Linear Unit (GELU).
- We used learned position embeddings instead of the sinusoidal version proposed in the original work
- For finetuning: We add dropout to the classifier with a rate of 0.1. learning rate of 6.25e-5 and a batchsize of 32. 3 epochs. We use a linear learning rate decay schedule with warmup over 0.2% of training. λ was set to 0.5.
- GPT-1 model is 12 layers and d_model 768, ~117M params
#### Language Models are Unsupervised Multitask Learners (GPT-2)
- LayerNorm was moved to the input of each sub-block, similar to a pre-activation residual network
- an additional layer normalization was added after the final self-attention block.
- modified initialization which accounts for the accumulation on the residual path with model depth is used. We scale the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. (weird because in their released code i can only find a simple use of the old 0.02... in their release of image-gpt I found it used for c_proj, and even then only for attn, not for mlp. huh. https://github.com/openai/image-gpt/blob/master/src/model.py)
- the vocabulary is expanded to 50,257
- increase the context size from 512 to 1024 tokens
- larger batchsize of 512 is used
- GPT-2 used 48 layers and d_model 1600 (vs. original 12 layers and d_model 768). ~1.542B params
#### Language Models are Few-Shot Learners (GPT-3)
- GPT-3: 96 layers, 96 heads, with d_model of 12,288 (175B parameters).
- GPT-1-like: 12 layers, 12 heads, d_model 768 (125M)
- We use the same model and architecture as GPT-2, including the modified initialization, pre-normalization, and reversible tokenization described therein
- we use alternating dense and locally banded sparse attention patterns in the layers of the transformer, similar to the Sparse Transformer
- we always have the feedforward layer four times the size of the bottleneck layer, dff = 4 dmodel
- all models use a context window of nctx = 2048 tokens.
- Adam with β1 = 0.9, β2 = 0.95, and eps = 108
- All models use weight decay of 0.1 to provide a small amount of regularization. (NOTE: GPT-1 used 0.01 I believe, see above)
- clip the global norm of the gradient at 1.0
- Linear LR warmup over the first 375 million tokens. Then use cosine decay for learning rate down to 10% of its value, over 260 billion tokens.
- gradually increase the batch size linearly from a small value (32k tokens) to the full value over the first 4-12 billion tokens of training, depending on the model size.
- full 2048-sized time context window is always used, with a special END OF DOCUMENT token delimiter
### License
MIT

BIN
mingpt.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

0
mingpt/__init__.py Normal file
View File

151
mingpt/model.py Normal file
View File

@ -0,0 +1,151 @@
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k,v in kwargs.items():
setattr(self, k, v)
class GPT1Config(GPTConfig):
""" GPT-1 like network roughly 125M params """
n_layer = 12
n_head = 12
n_embd = 768
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
I believe I could have just used torch.nn.MultiheadAttention but their documentation
is all but absent and code ugly so I don't trust it, rolling my own here.
"""
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads
self.key = nn.Linear(config.n_embd, config.n_embd)
self.query = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
# regularization
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, layer_past=None):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, -1e10) # todo: just use float('-inf') instead?
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, config):
super().__init__()
# input embedding stem
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def get_block_size(self):
return self.block_size
def forward(self, idx, targets=None):
b, t = idx.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss

129
mingpt/trainer.py Normal file
View File

@ -0,0 +1,129 @@
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import math
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
class TrainerConfig:
# optimization parameters
max_epochs = 10
batch_size = 64
learning_rate = 3e-4
betas = (0.9, 0.95)
grad_norm_clip = 1.0
weight_decay = 0.1 # only applied on matmul weights
# learning rate decay params: linear warmup followed by cosine decay to 10% of original
lr_decay = False
warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
final_tokens = 260e9 # (at what point we reach 10% of original LR)
# checkpoint settings
ckpt_path = None
num_workers = 0 # for DataLoader
def __init__(self, **kwargs):
for k,v in kwargs.items():
setattr(self, k, v)
class Trainer:
def __init__(self, model, train_dataset, test_dataset, config):
self.model = model
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.config = config
# take over whatever gpus are on the system
self.device = 'cpu'
if torch.cuda.is_available():
self.device = torch.cuda.current_device()
self.model = torch.nn.DataParallel(self.model).to(self.device)
def save_checkpoint(self):
if self.config.ckpt_path is not None:
ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(ckpt_model.state_dict(), self.config.ckpt_path)
def train(self):
model, config = self.model, self.config
# create the optimizer
no_decay = ["bias", "LayerNorm.weight"]
params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
optim_groups = [
{"params": params_decay, "weight_decay": config.weight_decay},
{"params": params_nodecay, "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=config.learning_rate, betas=config.betas)
def run_epoch(split):
is_train = split == 'train'
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
loader = DataLoader(data, batch_size=config.batch_size, num_workers=config.num_workers)
losses = []
pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
for it, (x, y) in pbar:
# place data on the correct device
x = x.to(self.device)
y = y.to(self.device)
# forward the model
with torch.set_grad_enabled(is_train):
logits, loss = model(x, y)
loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
losses.append(loss.item())
if is_train:
# backprop and update the parameters
model.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
optimizer.step()
# decay the learning rate based on our progress
if config.lr_decay:
self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
if self.tokens < config.warmup_tokens:
# linear warmup
lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
else:
# cosine learning rate decay
progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
lr = config.learning_rate * lr_mult
for param_group in optimizer.param_groups:
param_group['lr'] = lr
else:
lr = config.learning_rate
# report progress
pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")
if not is_train:
logger.info("test loss: %f", np.mean(losses))
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):
run_epoch('train')
if self.test_dataset is not None:
run_epoch('test')
self.save_checkpoint()

47
mingpt/utils.py Normal file
View File

@ -0,0 +1,47 @@
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()
out[out < v[:, [-1]]] = 1e-10
return out
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
"""
take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
the sequence, feeding the predictions back into the model each time. Clearly the sampling
has quadratic complexity unlike an RNN that is only linear, and has a finite context window
of block_size, unlike an RNN that has an infinite context window.
"""
block_size = model.get_block_size()
model.eval()
for k in range(steps):
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
logits, _ = model(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)
return x

493
play_char.ipynb Normal file
View File

@ -0,0 +1,493 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a character-level GPT on some text data\n",
"\n",
"The inputs here are simple text files, which we chop up to individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some shakespear, which we'll get it to predict character-level."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# set up logging\n",
"import logging\n",
"logging.basicConfig(\n",
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
" level=logging.INFO,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# make deterministic\n",
"from mingpt.utils import set_seed\n",
"set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from torch.utils.data import Dataset\n",
"\n",
"class CharDataset(Dataset):\n",
"\n",
" def __init__(self, data, block_size):\n",
" chars = list(set(data))\n",
" data_size, vocab_size = len(data), len(chars)\n",
" print('data has %d characters, %d unique.' % (data_size, vocab_size))\n",
" \n",
" self.stoi = { ch:i for i,ch in enumerate(chars) }\n",
" self.itos = { i:ch for i,ch in enumerate(chars) }\n",
" self.block_size = block_size\n",
" self.vocab_size = vocab_size\n",
" self.data = data\n",
" \n",
" def __len__(self):\n",
" return math.ceil(len(self.data) / (self.block_size + 1))\n",
"\n",
" def __getitem__(self, idx):\n",
" # we're actually going to \"cheat\" and pick a spot in the dataset at random\n",
" i = np.random.randint(0, len(self.data) - (self.block_size + 1))\n",
" chunk = self.data[i:i+self.block_size+1]\n",
" dix = [self.stoi[s] for s in chunk]\n",
" x = torch.tensor(dix[:-1], dtype=torch.long)\n",
" y = torch.tensor(dix[1:], dtype=torch.long)\n",
" return x, y\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"block_size = 128 # spatial extent of the model for its context"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data has 1115394 characters, 65 unique.\n"
]
}
],
"source": [
"# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt\n",
"text = open('input.txt', 'r').read() # don't worry we won't run out of file handles\n",
"train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"08/17/2020 00:11:58 - INFO - mingpt.model - number of parameters: 2.535219e+07\n"
]
}
],
"source": [
"from mingpt.model import GPT, GPTConfig\n",
"mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,\n",
" n_layer=8, n_head=8, n_embd=512)\n",
"model = GPT(mconf)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/17 [00:00<?, ?it/s]/apcv/shared/conda-envs/apcv-6244e1d-566/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"epoch 1 iter 16: train loss 3.31022. lr 5.999637e-04: 100%|██████████| 17/17 [00:36<00:00, 2.18s/it]\n",
"epoch 2 iter 16: train loss 2.89320. lr 5.998533e-04: 100%|██████████| 17/17 [00:04<00:00, 3.78it/s]\n",
"epoch 3 iter 16: train loss 2.63845. lr 5.996690e-04: 100%|██████████| 17/17 [00:04<00:00, 3.74it/s]\n",
"epoch 4 iter 16: train loss 2.54588. lr 5.994107e-04: 100%|██████████| 17/17 [00:04<00:00, 3.87it/s]\n",
"epoch 5 iter 16: train loss 2.49512. lr 5.990785e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 6 iter 16: train loss 2.46732. lr 5.986726e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 7 iter 16: train loss 2.44716. lr 5.981929e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 8 iter 16: train loss 2.37363. lr 5.976397e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 9 iter 16: train loss 2.34669. lr 5.970130e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 10 iter 16: train loss 2.28792. lr 5.963130e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 11 iter 16: train loss 2.21925. lr 5.955399e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 12 iter 16: train loss 2.16131. lr 5.946939e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 13 iter 16: train loss 2.12197. lr 5.937751e-04: 100%|██████████| 17/17 [00:04<00:00, 4.10it/s]\n",
"epoch 14 iter 16: train loss 2.06564. lr 5.927839e-04: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 15 iter 16: train loss 2.00401. lr 5.917204e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 16 iter 16: train loss 1.96109. lr 5.905849e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 17 iter 16: train loss 1.89554. lr 5.893777e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 18 iter 16: train loss 1.85840. lr 5.880992e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 19 iter 16: train loss 1.80772. lr 5.867495e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 20 iter 16: train loss 1.76782. lr 5.853291e-04: 100%|██████████| 17/17 [00:04<00:00, 4.06it/s]\n",
"epoch 21 iter 16: train loss 1.73638. lr 5.838382e-04: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 22 iter 16: train loss 1.71822. lr 5.822774e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 23 iter 16: train loss 1.65840. lr 5.806468e-04: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 24 iter 16: train loss 1.62377. lr 5.789471e-04: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 25 iter 16: train loss 1.58972. lr 5.771785e-04: 100%|██████████| 17/17 [00:04<00:00, 4.11it/s]\n",
"epoch 26 iter 16: train loss 1.56834. lr 5.753415e-04: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 27 iter 16: train loss 1.53979. lr 5.734365e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 28 iter 16: train loss 1.50191. lr 5.714641e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 29 iter 16: train loss 1.48912. lr 5.694247e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 30 iter 16: train loss 1.46471. lr 5.673188e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 31 iter 16: train loss 1.42849. lr 5.651469e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 32 iter 16: train loss 1.43040. lr 5.629096e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 33 iter 16: train loss 1.41800. lr 5.606075e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 34 iter 16: train loss 1.38389. lr 5.582410e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 35 iter 16: train loss 1.39671. lr 5.558108e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 36 iter 16: train loss 1.37993. lr 5.533175e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 37 iter 16: train loss 1.34003. lr 5.507617e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 38 iter 16: train loss 1.33921. lr 5.481440e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 39 iter 16: train loss 1.33100. lr 5.454651e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 40 iter 16: train loss 1.31935. lr 5.427256e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 41 iter 16: train loss 1.30351. lr 5.399262e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 42 iter 16: train loss 1.29172. lr 5.370676e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 43 iter 16: train loss 1.28822. lr 5.341505e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 44 iter 16: train loss 1.26141. lr 5.311756e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 45 iter 16: train loss 1.26616. lr 5.281437e-04: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 46 iter 16: train loss 1.25256. lr 5.250555e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 47 iter 16: train loss 1.22032. lr 5.219118e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 48 iter 16: train loss 1.23235. lr 5.187133e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 49 iter 16: train loss 1.23283. lr 5.154608e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 50 iter 16: train loss 1.20031. lr 5.121552e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 51 iter 16: train loss 1.18663. lr 5.087972e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 52 iter 16: train loss 1.19119. lr 5.053876e-04: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 53 iter 16: train loss 1.19220. lr 5.019275e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 54 iter 16: train loss 1.16783. lr 4.984174e-04: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 55 iter 16: train loss 1.15194. lr 4.948584e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 56 iter 16: train loss 1.14275. lr 4.912514e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 57 iter 16: train loss 1.14208. lr 4.875971e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 58 iter 16: train loss 1.13699. lr 4.838966e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 59 iter 16: train loss 1.12777. lr 4.801506e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 60 iter 16: train loss 1.11677. lr 4.763603e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 61 iter 16: train loss 1.11154. lr 4.725264e-04: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 62 iter 16: train loss 1.09913. lr 4.686499e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 63 iter 16: train loss 1.08472. lr 4.647318e-04: 100%|██████████| 17/17 [00:04<00:00, 4.12it/s]\n",
"epoch 64 iter 16: train loss 1.07229. lr 4.607731e-04: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 65 iter 16: train loss 1.05527. lr 4.567747e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 66 iter 16: train loss 1.05399. lr 4.527376e-04: 100%|██████████| 17/17 [00:04<00:00, 4.09it/s]\n",
"epoch 67 iter 16: train loss 1.04091. lr 4.486628e-04: 100%|██████████| 17/17 [00:04<00:00, 4.12it/s]\n",
"epoch 68 iter 16: train loss 1.02974. lr 4.445513e-04: 100%|██████████| 17/17 [00:04<00:00, 4.16it/s]\n",
"epoch 69 iter 16: train loss 1.02136. lr 4.404042e-04: 100%|██████████| 17/17 [00:04<00:00, 4.12it/s]\n",
"epoch 70 iter 16: train loss 0.99327. lr 4.362224e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 71 iter 16: train loss 1.00298. lr 4.320070e-04: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 72 iter 16: train loss 0.98914. lr 4.277590e-04: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 73 iter 16: train loss 0.97867. lr 4.234795e-04: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 74 iter 16: train loss 0.96466. lr 4.191696e-04: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 75 iter 16: train loss 0.95394. lr 4.148302e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 76 iter 16: train loss 0.94508. lr 4.104625e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 77 iter 16: train loss 0.92286. lr 4.060675e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 78 iter 16: train loss 0.93483. lr 4.016464e-04: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 79 iter 16: train loss 0.91154. lr 3.972002e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 80 iter 16: train loss 0.90103. lr 3.927300e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 81 iter 16: train loss 0.88080. lr 3.882369e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 82 iter 16: train loss 0.87298. lr 3.837220e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 83 iter 16: train loss 0.87236. lr 3.791865e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 84 iter 16: train loss 0.84863. lr 3.746315e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 85 iter 16: train loss 0.84596. lr 3.700580e-04: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 86 iter 16: train loss 0.83594. lr 3.654672e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 87 iter 16: train loss 0.80802. lr 3.608603e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 88 iter 16: train loss 0.81852. lr 3.562384e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 89 iter 16: train loss 0.79750. lr 3.516026e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 90 iter 16: train loss 0.79546. lr 3.469540e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 91 iter 16: train loss 0.78435. lr 3.422939e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 92 iter 16: train loss 0.77504. lr 3.376233e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 93 iter 16: train loss 0.75585. lr 3.329435e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 94 iter 16: train loss 0.74559. lr 3.282555e-04: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 95 iter 16: train loss 0.73854. lr 3.235605e-04: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 96 iter 16: train loss 0.73887. lr 3.188598e-04: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 97 iter 16: train loss 0.72368. lr 3.141544e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 98 iter 16: train loss 0.70582. lr 3.094455e-04: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 99 iter 16: train loss 0.70190. lr 3.047342e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 100 iter 16: train loss 0.68987. lr 3.000218e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 101 iter 16: train loss 0.67758. lr 2.953094e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 102 iter 16: train loss 0.66017. lr 2.905981e-04: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 103 iter 16: train loss 0.65877. lr 2.858892e-04: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 104 iter 16: train loss 0.64497. lr 2.811837e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 105 iter 16: train loss 0.64179. lr 2.764829e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 106 iter 16: train loss 0.63559. lr 2.717879e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 107 iter 16: train loss 0.62855. lr 2.670999e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 108 iter 16: train loss 0.61758. lr 2.624199e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 109 iter 16: train loss 0.60657. lr 2.577493e-04: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 110 iter 16: train loss 0.59996. lr 2.530890e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 111 iter 16: train loss 0.58367. lr 2.484404e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 112 iter 16: train loss 0.58197. lr 2.438044e-04: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 113 iter 16: train loss 0.57091. lr 2.391824e-04: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 114 iter 16: train loss 0.57271. lr 2.345753e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 115 iter 16: train loss 0.55243. lr 2.299844e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 116 iter 16: train loss 0.54761. lr 2.254108e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 117 iter 16: train loss 0.54044. lr 2.208555e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 118 iter 16: train loss 0.53038. lr 2.163198e-04: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 119 iter 16: train loss 0.53014. lr 2.118048e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 120 iter 16: train loss 0.52053. lr 2.073115e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 121 iter 16: train loss 0.51295. lr 2.028411e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 122 iter 16: train loss 0.51009. lr 1.983946e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 123 iter 16: train loss 0.51131. lr 1.939732e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 124 iter 16: train loss 0.49467. lr 1.895780e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 125 iter 16: train loss 0.48849. lr 1.852101e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 126 iter 16: train loss 0.47873. lr 1.808704e-04: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 127 iter 16: train loss 0.47302. lr 1.765602e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 128 iter 16: train loss 0.47510. lr 1.722804e-04: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 129 iter 16: train loss 0.46640. lr 1.680321e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 130 iter 16: train loss 0.46535. lr 1.638164e-04: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 131 iter 16: train loss 0.45578. lr 1.596343e-04: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 132 iter 16: train loss 0.45213. lr 1.554869e-04: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 133 iter 16: train loss 0.44695. lr 1.513751e-04: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 134 iter 16: train loss 0.44066. lr 1.473000e-04: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 135 iter 16: train loss 0.42818. lr 1.432625e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 136 iter 16: train loss 0.43211. lr 1.392637e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 137 iter 16: train loss 0.42977. lr 1.353046e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 138 iter 16: train loss 0.41763. lr 1.313862e-04: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 139 iter 16: train loss 0.41813. lr 1.275093e-04: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 140 iter 16: train loss 0.41200. lr 1.236750e-04: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 141 iter 16: train loss 0.40947. lr 1.198842e-04: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 142 iter 16: train loss 0.41038. lr 1.161379e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 143 iter 16: train loss 0.40603. lr 1.124369e-04: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 144 iter 16: train loss 0.40523. lr 1.087822e-04: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 145 iter 16: train loss 0.39535. lr 1.051747e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 146 iter 16: train loss 0.39754. lr 1.016153e-04: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 147 iter 16: train loss 0.38898. lr 9.810479e-05: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 148 iter 16: train loss 0.38628. lr 9.464413e-05: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 149 iter 16: train loss 0.38674. lr 9.123415e-05: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 150 iter 16: train loss 0.38966. lr 8.787567e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 151 iter 16: train loss 0.38590. lr 8.456954e-05: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 152 iter 16: train loss 0.37149. lr 8.131657e-05: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 153 iter 16: train loss 0.37366. lr 7.811757e-05: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 154 iter 16: train loss 0.36515. lr 7.497331e-05: 100%|██████████| 17/17 [00:04<00:00, 3.91it/s]\n",
"epoch 155 iter 16: train loss 0.36510. lr 7.188458e-05: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 156 iter 16: train loss 0.36846. lr 6.885214e-05: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 157 iter 16: train loss 0.35783. lr 6.587674e-05: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 158 iter 16: train loss 0.36345. lr 6.295911e-05: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 159 iter 16: train loss 0.35740. lr 6.009997e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 160 iter 16: train loss 0.36017. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 161 iter 16: train loss 0.35203. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 162 iter 16: train loss 0.34658. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 163 iter 16: train loss 0.35008. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 164 iter 16: train loss 0.34701. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 165 iter 16: train loss 0.34820. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 166 iter 16: train loss 0.34178. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 167 iter 16: train loss 0.34653. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.03it/s]\n",
"epoch 168 iter 16: train loss 0.34171. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 169 iter 16: train loss 0.34197. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 170 iter 16: train loss 0.34081. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 171 iter 16: train loss 0.33828. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 172 iter 16: train loss 0.33962. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 173 iter 16: train loss 0.33878. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 174 iter 16: train loss 0.34056. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.98it/s]\n",
"epoch 175 iter 16: train loss 0.33221. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 176 iter 16: train loss 0.33398. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.97it/s]\n",
"epoch 177 iter 16: train loss 0.32910. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.92it/s]\n",
"epoch 178 iter 16: train loss 0.33327. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.99it/s]\n",
"epoch 179 iter 16: train loss 0.32880. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 180 iter 16: train loss 0.32845. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.01it/s]\n",
"epoch 181 iter 16: train loss 0.32857. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 182 iter 16: train loss 0.33096. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 183 iter 16: train loss 0.32637. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.06it/s]\n",
"epoch 184 iter 16: train loss 0.32473. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 185 iter 16: train loss 0.33256. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 186 iter 16: train loss 0.32295. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.04it/s]\n",
"epoch 187 iter 16: train loss 0.32289. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 188 iter 16: train loss 0.31948. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.02it/s]\n",
"epoch 189 iter 16: train loss 0.32211. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.93it/s]\n",
"epoch 190 iter 16: train loss 0.32283. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.94it/s]\n",
"epoch 191 iter 16: train loss 0.31314. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.89it/s]\n",
"epoch 192 iter 16: train loss 0.31746. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.95it/s]\n",
"epoch 193 iter 16: train loss 0.31409. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 3.96it/s]\n",
"epoch 194 iter 16: train loss 0.31736. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.00it/s]\n",
"epoch 195 iter 16: train loss 0.31542. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 196 iter 16: train loss 0.30956. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n",
"epoch 197 iter 16: train loss 0.31757. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.08it/s]\n",
"epoch 198 iter 16: train loss 0.31694. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 199 iter 16: train loss 0.30833. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.07it/s]\n",
"epoch 200 iter 16: train loss 0.30588. lr 6.000000e-05: 100%|██████████| 17/17 [00:04<00:00, 4.05it/s]\n"
]
}
],
"source": [
"from mingpt.trainer import Trainer, TrainerConfig\n",
"\n",
"# initialize a trainer instance and kick off training\n",
"tconf = TrainerConfig(max_epochs=200, batch_size=512, learning_rate=6e-4,\n",
" lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(train_dataset)*block_size,\n",
" num_workers=4)\n",
"trainer = Trainer(model, train_dataset, None, tconf)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"O God, O God! which is the business so harm!\n",
"Well, lords, and save yourselves; and no oath to be angry\n",
"That in their embraces: and, to brave the life\n",
"We have forgot and bandy as that time\n",
"Have told me and he bids me for this excellent,\n",
"Now I would say he looks on the banks\n",
"And give more strength than a wild and provide\n",
"A salt that with some friendly vow,\n",
"That from the reaches of the gain and stop the sleeves\n",
"Do scope that which He should hide for his guard\n",
"As miser made thee first way from his holy exercise.\n",
"\n",
"BUCKINGHAM:\n",
"Go, rating to London, with all these woful chances\n",
"Misthink the king and not be satisfied!\n",
"\n",
"Son:\n",
"Was ever son so rued a father's death?\n",
"\n",
"Father:\n",
"The warn's idle buy and blows: and then to make a\n",
"fire, sir, I will keep my capss with stars out\n",
"And safely point of good content.\n",
"Signior Lucentio, let us hence; good gods rest ourselves:\n",
"We shall we show her own heaven and the king\n",
"In me resolved: I have seen a lady's nose\n",
"That has been blue, but not her eyebrows.\n",
"\n",
"First Lady:\n",
"Hark ye;\n",
"The queen your mother rounds apace: we shall\n",
"Present our services to a fine new prince\n",
"One of these days; and then you'ld wanton with us,\n",
"If we would have you.\n",
"\n",
"Second Lady:\n",
"She is spread of late\n",
"Into a goodly bulk: good time encounter her!\n",
"\n",
"HENRY BOLINGBROKE:\n",
"I swear.\n",
"\n",
"THOMAS MOWBRAY:\n",
"And I, to keep all this.\n",
"\n",
"HENRY BOLINGBROKE:\n",
"Sweet peace conduct his sweet soul to the bosom\n",
"Of good old Abraham! Lords appellants,\n",
"Your differences shall all rest under gage\n",
"Till we assign you to your days of trial.\n",
"\n",
"DUKE OF YORK:\n",
"Sweet York, what wilt thou do?\n",
"Wilt thou not hide the trespass of thine own?\n",
"Have we more sons? or are we like to have?\n",
"Is not my teeming date drunk up with time?\n",
"And wilt thou pluck my fair son from mine age,\n",
"And rob me of a happy mother's name?\n",
"Is he not like thee? is he not thine own?\n",
"\n",
"DORCAS:\n",
"Whither?\n",
"\n",
"MOPSA:\n",
"It becomes thy oath full well,\n",
"Thou to me thy secrets tell.\n",
"\n",
"DORCAS:\n",
"Me too, let me go thither.\n",
"\n",
"MOPSA:\n",
"Or thou goest to the orange or mill.\n",
"\n",
"DORCAS:\n",
"If to either, tho\n"
]
}
],
"source": [
"# alright, let's sample some character-level shakespear\n",
"from mingpt.utils import sample\n",
"\n",
"context = \"O God, O God!\"\n",
"x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)\n",
"y = sample(model, x, 2000, temperature=0.9, sample=True, top_k=5)[0]\n",
"completion = ''.join([train_dataset.itos[int(i)] for i in y])\n",
"print(completion)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# well that was fun"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

413
play_math.ipynb Normal file
View File

@ -0,0 +1,413 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train GPT on addition\n",
"\n",
"Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# set up logging\n",
"import logging\n",
"logging.basicConfig(\n",
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
" level=logging.INFO,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# make deterministic\n",
"from mingpt.utils import set_seed\n",
"set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class AdditionDataset(Dataset):\n",
" \"\"\"\n",
" Returns addition problems of up to some number of digits in the inputs. Recall\n",
" that all GPT cares about are sequences of integers, and completing them according to\n",
" patterns in the data. Therefore, we have to somehow encode addition problems\n",
" as a sequence of integers.\n",
" \n",
" The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our\n",
" encoding will simply be the n-digit first number, n-digit second number, \n",
" and (n+1)-digit result, all simply concatenated together. Because each addition\n",
" problem is so structured, there is no need to bother the model with encoding\n",
" +, =, or other tokens. Each possible sequence has the same length, and simply\n",
" contains the raw digits of the addition problem.\n",
" \n",
" As a few examples, the 2-digit problems:\n",
" - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]\n",
" - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]\n",
" etc.\n",
" \n",
" We will also only train GPT on the final (n+1)-digits because the first\n",
" two n-digits are always assumed to be given. So when we give GPT an exam later,\n",
" we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like\n",
" to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]\n",
" in 3 sequential steps.\n",
" \n",
" fun exercise: does it help if the result is asked to be produced in reverse order?\n",
" \"\"\"\n",
"\n",
" def __init__(self, ndigit, split):\n",
" self.split = split # train/test\n",
" self.ndigit = ndigit\n",
" self.vocab_size = 10 # 10 possible digits 0..9\n",
" # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back\n",
" self.block_size = ndigit + ndigit + ndigit + 1 - 1\n",
" \n",
" # split up all addition problems into either training data or test data\n",
" num = (10**self.ndigit)**2 # total number of possible combinations\n",
" r = np.random.RandomState(1337) # make deterministic\n",
" perm = r.permutation(num)\n",
" num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000\n",
" self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]\n",
"\n",
" def __len__(self):\n",
" return self.ixes.size\n",
"\n",
" def __getitem__(self, idx):\n",
" # given a problem index idx, first recover the associated a + b\n",
" idx = self.ixes[idx]\n",
" nd = 10**self.ndigit\n",
" a = idx // nd\n",
" b = idx % nd\n",
" c = a + b\n",
" render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes \"0325028\" \n",
" dix = [int(s) for s in render] # convert each character to its token index\n",
" # x will be input to GPT and y will be the associated expected outputs\n",
" x = torch.tensor(dix[:-1], dtype=torch.long)\n",
" y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence\n",
" y[:self.ndigit*2-1] = -100 # we will only train in the output locations. -100 will mask loss to zero\n",
" return x, y\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# create a dataset for e.g. 2-digit addition\n",
"ndigit = 2\n",
"train_dataset = AdditionDataset(ndigit=ndigit, split='train')\n",
"test_dataset = AdditionDataset(ndigit=ndigit, split='test')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([4, 7, 1, 7, 0, 6]), tensor([-100, -100, -100, 0, 6, 4]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataset[0] # sample a training instance just to see what one raw example looks like"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"08/16/2020 23:47:41 - INFO - mingpt.model - number of parameters: 4.001280e+05\n"
]
}
],
"source": [
"from mingpt.model import GPT, GPTConfig, GPT1Config\n",
"\n",
"# initialize a baby GPT model\n",
"mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, \n",
" n_layer=2, n_head=4, n_embd=128)\n",
"model = GPT(mconf)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/18 [00:00<?, ?it/s]/apcv/shared/conda-envs/apcv-6244e1d-566/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"epoch 1 iter 17: train loss 1.74049. lr 5.994512e-04: 100%|██████████| 18/18 [00:30<00:00, 1.70s/it]\n",
"08/16/2020 23:48:16 - INFO - mingpt.trainer - test loss: 1.693525\n",
"epoch 2 iter 17: train loss 1.50974. lr 5.977197e-04: 100%|██████████| 18/18 [00:01<00:00, 11.61it/s]\n",
"08/16/2020 23:48:18 - INFO - mingpt.trainer - test loss: 1.466473\n",
"epoch 3 iter 17: train loss 1.31133. lr 5.948114e-04: 100%|██████████| 18/18 [00:01<00:00, 11.45it/s]\n",
"08/16/2020 23:48:20 - INFO - mingpt.trainer - test loss: 1.256615\n",
"epoch 4 iter 17: train loss 1.22379. lr 5.907379e-04: 100%|██████████| 18/18 [00:01<00:00, 11.50it/s]\n",
"08/16/2020 23:48:21 - INFO - mingpt.trainer - test loss: 1.160792\n",
"epoch 5 iter 17: train loss 1.14308. lr 5.855153e-04: 100%|██████████| 18/18 [00:01<00:00, 11.63it/s]\n",
"08/16/2020 23:48:23 - INFO - mingpt.trainer - test loss: 1.091487\n",
"epoch 6 iter 17: train loss 1.09970. lr 5.791641e-04: 100%|██████████| 18/18 [00:01<00:00, 11.56it/s]\n",
"08/16/2020 23:48:25 - INFO - mingpt.trainer - test loss: 1.050111\n",
"epoch 7 iter 17: train loss 1.08481. lr 5.717095e-04: 100%|██████████| 18/18 [00:01<00:00, 11.53it/s]\n",
"08/16/2020 23:48:26 - INFO - mingpt.trainer - test loss: 1.037456\n",
"epoch 8 iter 17: train loss 1.03496. lr 5.631810e-04: 100%|██████████| 18/18 [00:01<00:00, 11.59it/s]\n",
"08/16/2020 23:48:28 - INFO - mingpt.trainer - test loss: 0.997156\n",
"epoch 9 iter 17: train loss 0.98606. lr 5.536122e-04: 100%|██████████| 18/18 [00:01<00:00, 11.67it/s]\n",
"08/16/2020 23:48:30 - INFO - mingpt.trainer - test loss: 0.836543\n",
"epoch 10 iter 17: train loss 0.59589. lr 5.430411e-04: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
"08/16/2020 23:48:31 - INFO - mingpt.trainer - test loss: 0.438013\n",
"epoch 11 iter 17: train loss 0.50257. lr 5.315093e-04: 100%|██████████| 18/18 [00:01<00:00, 12.99it/s]\n",
"08/16/2020 23:48:33 - INFO - mingpt.trainer - test loss: 0.343370\n",
"epoch 12 iter 17: train loss 0.44096. lr 5.190624e-04: 100%|██████████| 18/18 [00:01<00:00, 12.08it/s]\n",
"08/16/2020 23:48:34 - INFO - mingpt.trainer - test loss: 0.277625\n",
"epoch 13 iter 17: train loss 0.37445. lr 5.057497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
"08/16/2020 23:48:36 - INFO - mingpt.trainer - test loss: 0.236511\n",
"epoch 14 iter 17: train loss 0.31269. lr 4.916238e-04: 100%|██████████| 18/18 [00:01<00:00, 12.90it/s]\n",
"08/16/2020 23:48:37 - INFO - mingpt.trainer - test loss: 0.207689\n",
"epoch 15 iter 17: train loss 0.34095. lr 4.767405e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
"08/16/2020 23:48:39 - INFO - mingpt.trainer - test loss: 0.165566\n",
"epoch 16 iter 17: train loss 0.25957. lr 4.611586e-04: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
"08/16/2020 23:48:40 - INFO - mingpt.trainer - test loss: 0.123080\n",
"epoch 17 iter 17: train loss 0.23488. lr 4.449397e-04: 100%|██████████| 18/18 [00:01<00:00, 12.85it/s]\n",
"08/16/2020 23:48:42 - INFO - mingpt.trainer - test loss: 0.091252\n",
"epoch 18 iter 17: train loss 0.20269. lr 4.281479e-04: 100%|██████████| 18/18 [00:01<00:00, 12.72it/s]\n",
"08/16/2020 23:48:43 - INFO - mingpt.trainer - test loss: 0.078601\n",
"epoch 19 iter 17: train loss 0.19535. lr 4.108497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
"08/16/2020 23:48:45 - INFO - mingpt.trainer - test loss: 0.055412\n",
"epoch 20 iter 17: train loss 0.16152. lr 3.931133e-04: 100%|██████████| 18/18 [00:01<00:00, 12.66it/s]\n",
"08/16/2020 23:48:46 - INFO - mingpt.trainer - test loss: 0.051874\n",
"epoch 21 iter 17: train loss 0.14061. lr 3.750088e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
"08/16/2020 23:48:48 - INFO - mingpt.trainer - test loss: 0.044502\n",
"epoch 22 iter 17: train loss 0.16309. lr 3.566079e-04: 100%|██████████| 18/18 [00:01<00:00, 12.67it/s]\n",
"08/16/2020 23:48:49 - INFO - mingpt.trainer - test loss: 0.036376\n",
"epoch 23 iter 17: train loss 0.14411. lr 3.379832e-04: 100%|██████████| 18/18 [00:01<00:00, 13.21it/s]\n",
"08/16/2020 23:48:51 - INFO - mingpt.trainer - test loss: 0.029843\n",
"epoch 24 iter 17: train loss 0.12110. lr 3.192084e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
"08/16/2020 23:48:52 - INFO - mingpt.trainer - test loss: 0.025040\n",
"epoch 25 iter 17: train loss 0.11360. lr 3.003577e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
"08/16/2020 23:48:54 - INFO - mingpt.trainer - test loss: 0.023500\n",
"epoch 26 iter 17: train loss 0.13910. lr 2.815056e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
"08/16/2020 23:48:55 - INFO - mingpt.trainer - test loss: 0.022606\n",
"epoch 27 iter 17: train loss 0.07931. lr 2.627266e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
"08/16/2020 23:48:57 - INFO - mingpt.trainer - test loss: 0.015403\n",
"epoch 28 iter 17: train loss 0.09684. lr 2.440948e-04: 100%|██████████| 18/18 [00:01<00:00, 11.92it/s]\n",
"08/16/2020 23:48:58 - INFO - mingpt.trainer - test loss: 0.015245\n",
"epoch 29 iter 17: train loss 0.09055. lr 2.256841e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
"08/16/2020 23:49:00 - INFO - mingpt.trainer - test loss: 0.012647\n",
"epoch 30 iter 17: train loss 0.08837. lr 2.075671e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
"08/16/2020 23:49:01 - INFO - mingpt.trainer - test loss: 0.011611\n",
"epoch 31 iter 17: train loss 0.08425. lr 1.898155e-04: 100%|██████████| 18/18 [00:01<00:00, 12.43it/s]\n",
"08/16/2020 23:49:03 - INFO - mingpt.trainer - test loss: 0.009952\n",
"epoch 32 iter 17: train loss 0.10772. lr 1.724993e-04: 100%|██████████| 18/18 [00:01<00:00, 12.40it/s]\n",
"08/16/2020 23:49:05 - INFO - mingpt.trainer - test loss: 0.008648\n",
"epoch 33 iter 17: train loss 0.07272. lr 1.556871e-04: 100%|██████████| 18/18 [00:01<00:00, 12.57it/s]\n",
"08/16/2020 23:49:06 - INFO - mingpt.trainer - test loss: 0.010154\n",
"epoch 34 iter 17: train loss 0.05550. lr 1.394453e-04: 100%|██████████| 18/18 [00:01<00:00, 12.47it/s]\n",
"08/16/2020 23:49:08 - INFO - mingpt.trainer - test loss: 0.007668\n",
"epoch 35 iter 17: train loss 0.05451. lr 1.238381e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
"08/16/2020 23:49:09 - INFO - mingpt.trainer - test loss: 0.008095\n",
"epoch 36 iter 17: train loss 0.09133. lr 1.089272e-04: 100%|██████████| 18/18 [00:01<00:00, 12.39it/s]\n",
"08/16/2020 23:49:11 - INFO - mingpt.trainer - test loss: 0.006615\n",
"epoch 37 iter 17: train loss 0.06825. lr 9.477150e-05: 100%|██████████| 18/18 [00:01<00:00, 12.27it/s]\n",
"08/16/2020 23:49:12 - INFO - mingpt.trainer - test loss: 0.005874\n",
"epoch 38 iter 17: train loss 0.05798. lr 8.142699e-05: 100%|██████████| 18/18 [00:01<00:00, 12.49it/s]\n",
"08/16/2020 23:49:14 - INFO - mingpt.trainer - test loss: 0.005701\n",
"epoch 39 iter 17: train loss 0.06975. lr 6.894639e-05: 100%|██████████| 18/18 [00:01<00:00, 12.88it/s]\n",
"08/16/2020 23:49:15 - INFO - mingpt.trainer - test loss: 0.005469\n",
"epoch 40 iter 17: train loss 0.06070. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
"08/16/2020 23:49:17 - INFO - mingpt.trainer - test loss: 0.005307\n",
"epoch 41 iter 17: train loss 0.06378. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.60it/s]\n",
"08/16/2020 23:49:18 - INFO - mingpt.trainer - test loss: 0.005681\n",
"epoch 42 iter 17: train loss 0.04885. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
"08/16/2020 23:49:20 - INFO - mingpt.trainer - test loss: 0.005456\n",
"epoch 43 iter 17: train loss 0.06409. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
"08/16/2020 23:49:21 - INFO - mingpt.trainer - test loss: 0.004907\n",
"epoch 44 iter 17: train loss 0.07563. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
"08/16/2020 23:49:23 - INFO - mingpt.trainer - test loss: 0.004650\n",
"epoch 45 iter 17: train loss 0.03149. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.79it/s]\n",
"08/16/2020 23:49:24 - INFO - mingpt.trainer - test loss: 0.004626\n",
"epoch 46 iter 17: train loss 0.07037. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.86it/s]\n",
"08/16/2020 23:49:26 - INFO - mingpt.trainer - test loss: 0.004147\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 47 iter 17: train loss 0.07650. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.82it/s]\n",
"08/16/2020 23:49:27 - INFO - mingpt.trainer - test loss: 0.004611\n",
"epoch 48 iter 17: train loss 0.06342. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.63it/s]\n",
"08/16/2020 23:49:29 - INFO - mingpt.trainer - test loss: 0.004083\n",
"epoch 49 iter 17: train loss 0.12429. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
"08/16/2020 23:49:30 - INFO - mingpt.trainer - test loss: 0.004081\n",
"epoch 50 iter 17: train loss 0.04616. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.19it/s]\n",
"08/16/2020 23:49:32 - INFO - mingpt.trainer - test loss: 0.003922\n"
]
}
],
"source": [
"from mingpt.trainer import Trainer, TrainerConfig\n",
"\n",
"# initialize a trainer instance and kick off training\n",
"tconf = TrainerConfig(max_epochs=50, batch_size=512, learning_rate=6e-4,\n",
" lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(ndigit+1),\n",
" num_workers=4)\n",
"trainer = Trainer(model, train_dataset, test_dataset, tconf)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# now let's give the trained model an addition exam\n",
"from torch.utils.data.dataloader import DataLoader\n",
"from mingpt.utils import sample\n",
"\n",
"def give_exam(dataset, batch_size=32, max_batches=-1):\n",
" \n",
" results = []\n",
" loader = DataLoader(dataset, batch_size=batch_size)\n",
" for b, (x, y) in enumerate(loader):\n",
" x = x.to(trainer.device)\n",
" d1d2 = x[:, :ndigit*2]\n",
" d1d2d3 = sample(model, d1d2, ndigit+1)\n",
" d3 = d1d2d3[:, -(ndigit+1):]\n",
" factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)\n",
" # decode the integers from individual digits\n",
" d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)\n",
" d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)\n",
" d3i_pred = (d3 * factors).sum(1)\n",
" d3i_gt = d1i + d2i\n",
" correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol\n",
" for i in range(x.size(0)):\n",
" results.append(int(correct[i]))\n",
" judge = 'YEP!!!' if correct[i] else 'NOPE'\n",
" if not correct[i]:\n",
" print(\"GPT claims that %03d + %03d = %03d (gt is %03d; %s)\" \n",
" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i], judge))\n",
" \n",
" if max_batches >= 0 and b+1 >= max_batches:\n",
" break\n",
"\n",
" print(\"final score: %d/%d = %.2f%% correct\" % (np.sum(results), len(results), 100*np.mean(results)))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"final score: 9000/9000 = 100.00% correct\n"
]
}
],
"source": [
"# training set: how well did we memorize?\n",
"give_exam(train_dataset, batch_size=1024, max_batches=10)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT claims that 055 + 045 = 090 (gt is 100; NOPE)\n",
"final score: 999/1000 = 99.90% correct\n"
]
}
],
"source": [
"# test set: how well did we generalize?\n",
"give_exam(test_dataset, batch_size=1024, max_batches=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# well that's amusing... our model learned everything except 55 + 45"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}