mirror of
https://github.com/karpathy/minGPT
synced 2024-05-04 14:46:11 +02:00
Add optimizer to Trainer's self for callbacks.
This commit is contained in:
parent
e2065c59c6
commit
c4c650e3d5
|
@ -31,6 +31,7 @@ class Trainer:
|
|||
def __init__(self, config, model, train_dataset):
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.optimizer = None
|
||||
self.train_dataset = train_dataset
|
||||
self.callbacks = defaultdict(list)
|
||||
|
||||
|
@ -61,7 +62,7 @@ class Trainer:
|
|||
model, config = self.model, self.config
|
||||
|
||||
# setup the optimizer
|
||||
optimizer = model.configure_optimizers(config)
|
||||
self.optimizer = model.configure_optimizers(config)
|
||||
|
||||
# setup the dataloader
|
||||
train_loader = DataLoader(
|
||||
|
@ -95,7 +96,7 @@ class Trainer:
|
|||
model.zero_grad(set_to_none=True)
|
||||
self.loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
|
||||
optimizer.step()
|
||||
self.optimizer.step()
|
||||
|
||||
self.trigger_callbacks('on_batch_end')
|
||||
self.iter_num += 1
|
||||
|
|
Loading…
Reference in New Issue