# mingpt/fake_lightning.py, from https://github.com/karpathy/minGPT
"""
A manual, minimal and non-full-featured implementation of a boilerplate training loop.
Intentionally made to have the same API as PyTorch Lightning, giving two benefits:
1) Everyone can inspect/hack this simple implementation for educational purposes
2) Everyone can run the full Lightning implementation when they just want to go FAST
"""
import os
import math
import logging
from tqdm import tqdm
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class LightningModule(nn.Module):
    pass

class Callback:
    pass
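# Illustrative sketch (an assumption, not part of the original file): the smallest model
# the Trainer below can drive. The Trainer only relies on training_step/validation_step/
# test_step returning a dict with a 'loss' tensor, and on configure_optimizers returning
# an optimizer; the tiny linear regression model here is purely hypothetical.
class _ExampleModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)  # toy stand-in for a real model such as GPT

    def training_step(self, batch):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        return {'loss': loss}  # the Trainer reads result['loss']

    def validation_step(self, batch):
        return self.training_step(batch)

    def test_step(self, batch):
        return self.training_step(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=3e-4)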
# -----------------------------------------------------------------------------
"""
Simple Trainer object; Boilerplate that could apply to any arbitrary neural network,
so nothing here really has anything to do with GPT specifically. This is a
very basic Trainer class that will only train the model on up to one GPU.
"""
class Trainer:

    def __init__(self, max_epochs, gpus=0, gradient_clip_val=None, default_root_dir='.', callbacks=None,
                 precision=32, **kwargs):
        self.gpus = gpus
        self.max_epochs = max_epochs
        self.gradient_clip_val = gradient_clip_val
        self.callbacks = [] if callbacks is None else callbacks
        self.model = None
        if default_root_dir is not None:
            os.makedirs(default_root_dir, exist_ok=True)
        self.default_root_dir = default_root_dir
        if self.gpus > 1:
            logger.error("This simple Trainer does not support > 1 GPUs, will just use one.")
        if precision != 32:
            logger.error("This simple Trainer does not support non-fp32 precision, will use fp32")
    def save_checkpoint(self):
        ckpt_path = os.path.join(self.default_root_dir, 'model.pt')
        logger.info("saving model checkpoint to %s", ckpt_path)
        torch.save(self.model.state_dict(), ckpt_path)

    def load_checkpoint(self):
        ckpt_path = os.path.join(self.default_root_dir, 'model.pt')
        logger.info("loading model from %s", ckpt_path)
        state_dict = torch.load(ckpt_path)
        self.model.load_state_dict(state_dict)
    def eval_split_(self, dataloader, split):
        self.model.eval()
        use_gpu = self.gpus > 0 and torch.cuda.is_available()
        losses = []
        for it, (x, y) in enumerate(dataloader):

            # place data on the correct device
            if use_gpu:
                x, y = x.cuda(), y.cuda(non_blocking=True)

            # forward the model
            with torch.no_grad():
                if split == 'val':
                    result = self.model.validation_step((x, y))
                elif split == 'test':
                    result = self.model.test_step((x, y))
            losses.append(result['loss'].item())

        mean_loss = torch.mean(torch.tensor(losses)).item()
        logger.info("%s loss: %f", split, mean_loss)
        return mean_loss
    def test(self, test_dataloader):
        # note: unlike Lightning, this simple Trainer expects a single dataloader, not a list
        self.load_checkpoint() # load the best checkpoint we found during optimization
        return self.eval_split_(test_dataloader, 'test')

    def val(self, val_dataloader):
        return self.eval_split_(val_dataloader, 'val')
    def fit(self, model, train_dataloader, val_dataloader=None):
        self.model = model # bind model to the class here
        self.model.train()

        # ship model to gpu if possible
        use_gpu = self.gpus > 0 and torch.cuda.is_available()
        if use_gpu:
            logger.info("found CUDA device, shipping model to GPU")
            self.model.cuda()

        # prepare the optimizer
        optimizer = self.model.configure_optimizers()
        self.optimizers = [optimizer]

        # start the training loop
        best_val_loss = float('inf')
        for epoch in range(self.max_epochs):

            # do an epoch of training; switch back to train mode in case a validation
            # pass at the end of the previous epoch left the model in eval mode
            self.model.train()
            pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
            for it, (x, y) in pbar:

                # place data on the correct device
                if use_gpu:
                    x, y = x.cuda(), y.cuda(non_blocking=True)

                # forward the model
                result = self.model.training_step((x, y))
                loss = result['loss']

                # reset gradients; setting them to None is a faster alternative to model.zero_grad()
                for param in self.model.parameters():
                    param.grad = None

                # backward pass
                loss.backward()

                # clip the gradient to mitigate loss explosions
                if self.gradient_clip_val is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)

                # update all parameters
                optimizer.step() # todo: use fused optimizer

                # notify all relevant callbacks that a batch update ended, e.g. a callback may decay the learning rate
                for cb in self.callbacks:
                    if hasattr(cb, 'on_train_batch_end'):
                        cb.on_train_batch_end(self, None, (x, y))

                # report progress
                lr = optimizer.param_groups[0]['lr']
                pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            # calculate the current validation loss and checkpoint the model for early stopping
            if val_dataloader is not None:
                val_loss = self.val(val_dataloader)
                if (self.default_root_dir is not None) and (val_loss < best_val_loss):
                    best_val_loss = val_loss
                    self.save_checkpoint()
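
# -----------------------------------------------------------------------------
# Hypothetical end-to-end sketch (an assumption, not part of the original file): wires the
# illustrative _ExampleModule above into the Trainer with random data, just to show the
# fit -> validation checkpoint -> test flow. Real usage would pass a GPT-style
# LightningModule and proper DataLoaders instead.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    x = torch.randn(256, 8)
    y = torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    trainer = Trainer(max_epochs=2, gpus=0, gradient_clip_val=1.0, default_root_dir='.')
    trainer.fit(_ExampleModule(), train_loader, val_loader)
    trainer.test(val_loader)  # reloads the best checkpoint saved during fit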