mirror of https://github.com/karpathy/minGPT, synced 2024-05-18 05:26:03 +02:00
properly separate params that should be weight decayed, and make a small incremental step towards Lightning compatibility by creating the optimizer object inside the model's configure_optimizers
This commit is contained in:
parent 23982656df
commit bbbdac74fa
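For context, the "Lightning compatibility" in the commit message refers to PyTorch Lightning's convention that the model itself builds its optimizer inside a configure_optimizers hook, which Lightning's Trainer then calls. A rough sketch of that convention for comparison; the LightningModule base class and the hook name are Lightning's real API, while the tiny layer and hyperparameters below are made up purely for illustration:

import torch
import torch.nn as nn
import pytorch_lightning as pl

class TinyLitModel(pl.LightningModule):
    # Not a complete training module (no training_step etc.); this only shows
    # that Lightning asks the model, not the trainer, to construct its optimizer.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)  # placeholder layer for illustration

    def configure_optimizers(self):
        # Lightning calls this hook and trains with whatever optimizer it returns,
        # the same division of responsibility this commit gives minGPT.
        return torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=0.1)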
@@ -119,6 +119,9 @@ class GPT(nn.Module):
         logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

+    def get_block_size(self):
+        return self.block_size
+
     def _init_weights(self, module):
         if isinstance(module, (nn.Linear, nn.Embedding)):
             module.weight.data.normal_(mean=0.0, std=0.02)
@@ -128,8 +131,51 @@ class GPT(nn.Module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def get_block_size(self):
-        return self.block_size
+    def configure_optimizers(self, train_config):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add('pos_emb')
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                            % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        return optimizer

     def forward(self, idx, targets=None):
         b, t = idx.size()
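The decay/no-decay bucketing added above can be exercised on its own. Below is a minimal sketch of the same rule applied to a toy module instead of the full GPT; the module and attribute names (apart from pos_emb, which mirrors the special case in the diff) are invented for illustration:

import torch
import torch.nn as nn

# Toy stand-in for GPT: one Linear, one LayerNorm, one Embedding, and a raw
# Parameter analogous to pos_emb.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.zeros(1, 8, 16))
        self.tok_emb = nn.Embedding(100, 16)
        self.fc = nn.Linear(16, 16)
        self.ln = nn.LayerNorm(16)

model = Toy()
decay, no_decay = set(), set()
whitelist = (nn.Linear,)                  # weights of these modules get weight decay
blacklist = (nn.LayerNorm, nn.Embedding)  # weights of these modules do not
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        fpn = f'{mn}.{pn}' if mn else pn  # full parameter name
        if pn.endswith('bias'):
            no_decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, whitelist):
            decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, blacklist):
            no_decay.add(fpn)
no_decay.add('pos_emb')  # root-level Parameter, special-cased like in the diff

print(sorted(decay))     # ['fc.weight']
print(sorted(no_decay))  # ['fc.bias', 'ln.bias', 'ln.weight', 'pos_emb', 'tok_emb.weight']

Only the Linear weight is decayed; biases, LayerNorm and Embedding weights, and the root-level positional parameter all land in the no-decay group, matching the docstring above.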
@@ -51,22 +51,15 @@ class Trainer:
             self.model = torch.nn.DataParallel(self.model).to(self.device)

     def save_checkpoint(self):
-        ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
+        # DataParallel wrappers keep raw model object in .module attribute
+        raw_model = self.model.module if hasattr(self.model, "module") else self.model
         logger.info("saving %s", self.config.ckpt_path)
-        torch.save(ckpt_model.state_dict(), self.config.ckpt_path)
+        torch.save(raw_model.state_dict(), self.config.ckpt_path)

     def train(self):
         model, config = self.model, self.config
-
-        # create the optimizer
-        no_decay = ["bias", "LayerNorm.weight"]
-        params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
-        params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
-        optim_groups = [
-            {"params": params_decay, "weight_decay": config.weight_decay},
-            {"params": params_nodecay, "weight_decay": 0.0},
-        ]
-        optimizer = optim.AdamW(optim_groups, lr=config.learning_rate, betas=config.betas)
+        raw_model = model.module if hasattr(self.model, "module") else model
+        optimizer = raw_model.configure_optimizers(config)

         def run_epoch(split):
             is_train = split == 'train'
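From the caller's side, the effect of the change is that the trainer no longer knows anything about parameter groups: it unwraps a possible DataParallel wrapper and asks the model for its optimizer. A small usage sketch, assuming minGPT's GPTConfig and TrainerConfig as of this commit (both accept free-form keyword arguments, so only the fields configure_optimizers reads need to be supplied); the sizes below are arbitrary:

import torch
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import TrainerConfig

# Arbitrary illustrative config values; only learning_rate, betas and
# weight_decay are read by configure_optimizers.
mconf = GPTConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64)
model = GPT(mconf)
tconf = TrainerConfig(learning_rate=3e-4, betas=(0.9, 0.95), weight_decay=0.1)

# If the model were wrapped in DataParallel, the raw module would live in .module,
# which is why the trainer unwraps before calling configure_optimizers.
raw_model = model.module if hasattr(model, "module") else model
optimizer = raw_model.configure_optimizers(tconf)
print(type(optimizer))  # torch.optim.AdamW, built with two param groups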