1
0
Fork 0
mirror of https://github.com/karpathy/minGPT synced 2024-05-18 05:26:03 +02:00

ok i hated the previous global/local config idea. reverting it and simplying and i think this is the best api so far

This commit is contained in:
Andrej 2022-06-27 20:41:01 +00:00 committed by GitHub
parent ea20661f78
commit 00aa9cb2ed
4 changed files with 28 additions and 25 deletions

View File

@ -3,9 +3,6 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo
so nothing in this file really has anything to do with GPT specifically.
"""
import os
import sys
import json
import time
from collections import defaultdict
@ -17,7 +14,6 @@ class Trainer:
@classmethod
def get_default_config(cls):
""" returns a default 'local' config for this class alone """
C = CN()
# device to train on
C.device = 'auto'
@ -31,29 +27,19 @@ class Trainer:
C.grad_norm_clip = 1.0
return C
def __init__(self, gconfig, model, train_dataset):
# gconfig is a 'global' config for everything, not just Trainer class alone
self.gconfig = gconfig
self.config = gconfig.trainer
def __init__(self, config, model, train_dataset):
self.config = config
self.model = model
self.train_dataset = train_dataset
self.callbacks = defaultdict(list)
# set up logging
work_dir = gconfig.system.work_dir
os.makedirs(work_dir, exist_ok=True)
with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
f.write(' '.join(sys.argv))
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
f.write(json.dumps(self.gconfig.to_dict(), indent=4))
# take over whatever gpus are on the system
if self.config.device == 'auto':
# determine the device we'll train on
if config.device == 'auto':
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = self.config.device
print("running on device", self.device)
self.device = config.device
self.model = self.model.to(self.device)
print("running on device", self.device)
# variables that will be assigned to trainer class later for logging and etc
self.iter_num = 0

View File

@ -1,4 +1,7 @@
import os
import sys
import json
import random
from ast import literal_eval
@ -13,6 +16,18 @@ def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def setup_logging(config):
""" monotonous bookkeeping """
work_dir = config.system.work_dir
# create the work directory if it doesn't already exist
os.makedirs(work_dir, exist_ok=True)
# log the args (if any)
with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
f.write(' '.join(sys.argv))
# log the config itself
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
f.write(json.dumps(config.to_dict(), indent=4))
def top_k_logits(logits, k):
v, ix = torch.topk(logits, k)
out = logits.clone()

View File

@ -4,6 +4,7 @@ Trains a GPT to add n-digit numbers.
import os
import sys
import json
import torch
from torch.utils.data import Dataset
@ -11,7 +12,7 @@ from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, CfgNode as CN
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
@ -125,7 +126,7 @@ if __name__ == '__main__':
config = get_config()
config.merge_from_args(sys.argv[1:])
print(config)
# inits
setup_logging(config)
set_seed(config.system.seed)
# construct train and test datasets
@ -138,7 +139,7 @@ if __name__ == '__main__':
model = GPT(config.model)
# construct the trainer object
trainer = Trainer(config, model, train_dataset)
trainer = Trainer(config.trainer, model, train_dataset)
# helper function for the evaluation of a model
def eval_split(trainer, split, max_batches=None):

View File

@ -11,7 +11,7 @@ from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, sample, CfgNode as CN
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
@ -90,6 +90,7 @@ if __name__ == '__main__':
config = get_config()
config.merge_from_args(sys.argv[1:])
print(config)
setup_logging(config)
set_seed(config.system.seed)
# construct the training dataset
@ -102,7 +103,7 @@ if __name__ == '__main__':
model = GPT(config.model)
# construct the trainer object
trainer = Trainer(config, model, train_dataset)
trainer = Trainer(config.trainer, model, train_dataset)
# iteration callback
def batch_end_callback(trainer):