# minGPT/mingpt/model.py (mirrored from https://github.com/karpathy/minGPT)
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""
import math
import logging
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
# -----------------------------------------------------------------------------
import os
if int(os.environ.get('USE_LIGHTNING', 0)):
    import pytorch_lightning as pl
else:
    import mingpt.fake_lightning as pl
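# by default (USE_LIGHTNING unset or 0) the lightweight mingpt.fake_lightning shim stands in for
# pytorch_lightning, so `pl.LightningModule` below resolves to the shim's class; set USE_LIGHTNING=1
# in the environment to use the real PyTorch Lightning package instead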
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here, but I am including an
    explicit implementation to show that there is nothing too scary here.
    """

    def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                     .view(1, 1, block_size, block_size))
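        # e.g. for block_size=3 the (1, 1, 3, 3) mask buffer holds the lower-triangular pattern
        #   [[1, 0, 0],
        #    [1, 1, 0],
        #    [1, 1, 1]]
        # so position t can only attend to positions <= t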

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
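
# illustrative usage (not part of the model): the layer maps (B, T, C) -> (B, T, C), e.g.
#   attn = CausalSelfAttention(n_embd=64, block_size=16, n_head=4, attn_pdrop=0.1, resid_pdrop=0.1)
#   y = attn(torch.randn(2, 16, 64))  # y.shape == (2, 16, 64)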

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, block_size, n_head, attn_pdrop, resid_pdrop)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
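
# each Block is shape-preserving ((B, T, n_embd) -> (B, T, n_embd)) and uses pre-LayerNorm residual
# connections (x = x + sublayer(ln(x))), so blocks can simply be stacked with nn.Sequential as done
# in GPT below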

class GPT(pl.LightningModule):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self,
                 # model definition args
                 vocab_size: int, # size of the vocabulary (number of possible tokens)
                 block_size: int, # length of the model's context window in time
                 n_layer: int, # depth of the model; number of Transformer blocks in sequence
                 n_embd: int, # the "width" of the model, number of channels in each Transformer
                 n_head: int, # number of heads in each multi-head attention inside each Transformer block
                 # model optimization args
                 learning_rate: float = 3e-4, # the base learning rate of the model
                 weight_decay: float = 0.1, # amount of regularizing L2 weight decay on MatMul ops
                 betas: Tuple[float, float] = (0.9, 0.95), # momentum terms (betas) for the Adam optimizer
                 embd_pdrop: float = 0.1, # \in [0,1]: amount of dropout on input embeddings
                 resid_pdrop: float = 0.1, # \in [0,1]: amount of dropout in each residual connection
                 attn_pdrop: float = 0.1, # \in [0,1]: amount of dropout on the attention matrix
                 ):
        super().__init__()

        # save these for optimizer init later
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.betas = betas

        # input embedding stem: drop(content + position)
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        # deep transformer: just a sequence of transformer blocks
        self.blocks = nn.Sequential(*[Block(n_embd, block_size, n_head, attn_pdrop, resid_pdrop) for _ in range(n_layer)])
        # decoder: at the end one more layernorm and decode the answers
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False) # no need for extra bias due to one in ln_f

        self.block_size = block_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """
        Vanilla model initialization:
        - all MatMul weights initialized from N(0, 0.02) and biases set to zero
        - all LayerNorm post-normalization scaling set to identity, so weight=1, bias=0
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=self.betas)
        return optimizer
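
    # for reference, with the modules defined above the buckets come out as:
    #   decay:    'head.weight', 'blocks.0.attn.key.weight', 'blocks.0.mlp.0.weight', ... (all nn.Linear weights)
    #   no_decay: 'pos_emb', 'tok_emb.weight', 'blocks.0.ln1.weight', 'ln_f.weight', ... and every bias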

    def forward(self, idx):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x) # (b, t, vocab_size) unnormalized scores over the vocabulary at every position
        return logits
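        # note: to sample a continuation one would typically take F.softmax(logits[:, -1, :], dim=-1),
        # i.e. the predicted distribution over the token that follows the last position in idx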

    def step_(self, split, batch, batch_idx=None):
        # shared step for train/val/test; the split name is currently unused since all three compute the same LM loss
        idx, targets = batch
        logits = self(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return {'loss': loss}

    def training_step(self, *args, **kwargs):
        return self.step_('train', *args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.step_('val', *args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.step_('test', *args, **kwargs)
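
# -----------------------------------------------------------------------------
# a minimal smoke test, assuming a toy configuration; the numbers below are illustrative and not
# part of any of the original training setups (requires pytorch_lightning or the mingpt.fake_lightning
# shim to be importable, see the import switch at the top of the file)
if __name__ == '__main__':
    vocab_size, block_size = 100, 32
    model = GPT(vocab_size=vocab_size, block_size=block_size, n_layer=2, n_embd=64, n_head=4)
    idx = torch.randint(0, vocab_size, (4, block_size)) # a batch of 4 random token sequences
    logits = model(idx)                                 # (4, 32, 100) logits over the vocabulary
    print(logits.shape)
    optimizer = model.configure_optimizers()            # AdamW with the decay/no_decay split above
    print(optimizer)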