"""
|
|
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
|
|
so nothing in this file really has anything to do with GPT specifically.
|
|
"""

import time
from collections import defaultdict

import torch
from torch.utils.data.dataloader import DataLoader

from mingpt.utils import CfgNode as CN
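
# A usage sketch (illustrative values, not prescribed by this file): the Trainer
# assumes a model exposing configure_optimizers(config) and a forward pass that
# returns (logits, loss), which minGPT's GPT provides; `train_dataset` stands in
# for any torch Dataset yielding (input, target) pairs of token-index tensors.
#
#   from mingpt.model import GPT
#   model_config = GPT.get_default_config()
#   model_config.model_type = 'gpt-nano'
#   model_config.vocab_size = 50257
#   model_config.block_size = 128
#   model = GPT(model_config)
#
#   train_config = Trainer.get_default_config()
#   train_config.max_iters = 2000
#   train_config.batch_size = 32
#   trainer = Trainer(train_config, model, train_dataset)
#   trainer.run()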
|
class Trainer:
|
|
|
|
@staticmethod
|
|
def get_default_config():
|
|
C = CN()
|
|
# device to train on
|
|
C.device = 'auto'
|
|
# dataloder parameters
|
|
C.num_workers = 4
|
|
# optimizer parameters
|
|
C.max_iters = None
|
|
C.batch_size = 64
|
|
C.learning_rate = 3e-4
|
|
C.betas = (0.9, 0.95)
|
|
C.weight_decay = 0.1 # only applied on matmul weights
|
|
C.grad_norm_clip = 1.0
|
|
return C
|
|
|
|

    def __init__(self, config, model, train_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to the trainer instance later, for logging etc.
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)
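
    # A minimal sketch of the callback API (assumes a constructed `trainer`):
    # log the training loss every 100 iterations by hooking 'on_batch_end',
    # the one event run() fires.
    #
    #   def batch_end_callback(trainer):
    #       if trainer.iter_num % 100 == 0:
    #           print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
    #   trainer.add_callback('on_batch_end', batch_end_callback)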

    def run(self):
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )
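        # the sampler draws indices with replacement and an effectively infinite
        # num_samples, so train_loader yields a never-ending stream of random
        # batches; training length is set by config.max_iters, not by epochs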

        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            logits, self.loss = model(x, y)
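            # (the model is expected to return both logits and a scalar loss when
            # targets are given, as minGPT's GPT.forward(idx, targets) does)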

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
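            # clip_grad_norm_ rescales gradients in place so their global L2 norm
            # is at most config.grad_norm_clip, guarding against unstable updates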
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            # termination conditions
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break