mirror of https://github.com/karpathy/minGPT

Add optimizer to Trainer's self for callbacks.

Author: Luigi Di Sotto
Date:   2022-07-26 10:17:44 +02:00
Parent: e2065c59c6
Commit: c4c650e3d5


@@ -31,6 +31,7 @@ class Trainer:
     def __init__(self, config, model, train_dataset):
         self.config = config
         self.model = model
+        self.optimizer = None
         self.train_dataset = train_dataset
         self.callbacks = defaultdict(list)
 
@@ -61,7 +62,7 @@ class Trainer:
         model, config = self.model, self.config
 
         # setup the optimizer
-        optimizer = model.configure_optimizers(config)
+        self.optimizer = model.configure_optimizers(config)
 
         # setup the dataloader
         train_loader = DataLoader(
@@ -95,7 +96,7 @@ class Trainer:
             model.zero_grad(set_to_none=True)
             self.loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
-            optimizer.step()
+            self.optimizer.step()
 
             self.trigger_callbacks('on_batch_end')
             self.iter_num += 1
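
With the optimizer stored on the trainer, callbacks registered via add_callback (which are passed the trainer instance when trigger_callbacks fires) can now inspect or adjust it. The following is an illustrative sketch, not part of the commit; the prior construction of the Trainer, the logging interval, and the callback name are assumptions for the example:

# Hypothetical usage sketch: assumes a Trainer instance was built elsewhere,
# e.g. trainer = Trainer(config, model, train_dataset).

def batch_end_callback(trainer):
    # trainer.optimizer is None until Trainer.run() calls configure_optimizers();
    # once training starts, each on_batch_end event can read optimizer state,
    # for example the current learning rate.
    if trainer.iter_num % 100 == 0:
        lr = trainer.optimizer.param_groups[0]['lr']
        print(f"iter {trainer.iter_num}: loss {trainer.loss.item():.5f}, lr {lr:.2e}")

trainer.add_callback('on_batch_end', batch_end_callback)
trainer.run()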