mirror of https://github.com/karpathy/minGPT

properly separate params that should be weight decayed, and make a small incremental step towards Lightning compatibility by creating the optimizer object inside the model's configure_optimizers

This commit is contained in:
Andrej 2020-08-23 15:48:20 -07:00
parent 23982656df
commit bbbdac74fa
2 changed files with 53 additions and 14 deletions
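
At a glance, the change means the model now hands the trainer a ready-made torch.optim.AdamW with two parameter groups. A minimal usage sketch, assuming GPT and GPTConfig keep the constructor interface they have elsewhere in the repo (neither appears in this diff) and using a bare namespace for the three config fields the new method reads; all concrete values are illustrative:

from types import SimpleNamespace
from mingpt.model import GPT, GPTConfig

# assumption: GPTConfig/GPT are constructed as elsewhere in the repo (not shown in this diff)
mconf = GPTConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64)
model = GPT(mconf)

# stand-in for the trainer config; only the fields configure_optimizers reads are set
tconf = SimpleNamespace(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95))

optimizer = model.configure_optimizers(tconf)  # AdamW with decayed and non-decayed param groups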


@@ -119,6 +119,9 @@ class GPT(nn.Module):
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
@@ -128,8 +131,51 @@ class GPT(nn.Module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def get_block_size(self):
return self.block_size
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
return optimizer
def forward(self, idx, targets=None):
b, t = idx.size()
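
To see what the bucketing in configure_optimizers actually produces, here is a small standalone sketch (a toy module of my own, not from the repo) that applies the same whitelist/blacklist rules and shows which parameter names land in each group:

import torch.nn as nn

class Toy(nn.Module):
    # one Linear (its weight decays), one LayerNorm and one Embedding (their weights do not),
    # and every bias is excluded from decay
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.ln = nn.LayerNorm(4)
        self.emb = nn.Embedding(10, 4)

toy = Toy()
decay, no_decay = set(), set()
for mn, m in toy.named_modules():
    for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn
        if pn.endswith('bias'):
            no_decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, (nn.Linear,)):
            decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, (nn.LayerNorm, nn.Embedding)):
            no_decay.add(fpn)

print(sorted(decay))     # ['fc.weight']
print(sorted(no_decay))  # ['emb.weight', 'fc.bias', 'ln.bias', 'ln.weight']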


@@ -51,22 +51,15 @@ class Trainer:
self.model = torch.nn.DataParallel(self.model).to(self.device)
def save_checkpoint(self):
ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
# DataParallel wrappers keep raw model object in .module attribute
raw_model = self.model.module if hasattr(self.model, "module") else self.model
logger.info("saving %s", self.config.ckpt_path)
torch.save(ckpt_model.state_dict(), self.config.ckpt_path)
torch.save(raw_model.state_dict(), self.config.ckpt_path)
def train(self):
model, config = self.model, self.config
# create the optimizer
no_decay = ["bias", "LayerNorm.weight"]
params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
optim_groups = [
{"params": params_decay, "weight_decay": config.weight_decay},
{"params": params_nodecay, "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=config.learning_rate, betas=config.betas)
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)
def run_epoch(split):
is_train = split == 'train'
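
The "Lightning compatibility" in the commit message refers to PyTorch Lightning's convention that the module, not the trainer, owns optimizer creation via a configure_optimizers hook, which the trainer change above mirrors. A rough sketch of that hook for comparison (not code from this repo; the layer and learning rate are placeholders):

import torch
import pytorch_lightning as pl

class LitModule(pl.LightningModule):
    # illustration only: Lightning asks the module for its optimizer,
    # just as minGPT's Trainer now calls raw_model.configure_optimizers(config)
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)  # placeholder parameters
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-4)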