mirror of
https://github.com/karpathy/minGPT
synced 2024-05-18 05:26:03 +02:00
ok i hated the previous global/local config idea. reverting it and simplifying and i think this is the best api so far
This commit is contained in:
parent
ea20661f78
commit
00aa9cb2ed
|
@ -3,9 +3,6 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo
|
|||
so nothing in this file really has anything to do with GPT specifically.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -17,7 +14,6 @@ class Trainer:
|
|||
|
||||
@classmethod
|
||||
def get_default_config(cls):
|
||||
""" returns a default 'local' config for this class alone """
|
||||
C = CN()
|
||||
# device to train on
|
||||
C.device = 'auto'
|
||||
|
@ -31,29 +27,19 @@ class Trainer:
|
|||
C.grad_norm_clip = 1.0
|
||||
return C
|
||||
|
||||
def __init__(self, gconfig, model, train_dataset):
|
||||
# gconfig is a 'global' config for everything, not just Trainer class alone
|
||||
self.gconfig = gconfig
|
||||
self.config = gconfig.trainer
|
||||
def __init__(self, config, model, train_dataset):
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.train_dataset = train_dataset
|
||||
self.callbacks = defaultdict(list)
|
||||
|
||||
# set up logging
|
||||
work_dir = gconfig.system.work_dir
|
||||
os.makedirs(work_dir, exist_ok=True)
|
||||
with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
|
||||
f.write(' '.join(sys.argv))
|
||||
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
|
||||
f.write(json.dumps(self.gconfig.to_dict(), indent=4))
|
||||
|
||||
# take over whatever gpus are on the system
|
||||
if self.config.device == 'auto':
|
||||
# determine the device we'll train on
|
||||
if config.device == 'auto':
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = self.config.device
|
||||
print("running on device", self.device)
|
||||
self.device = config.device
|
||||
self.model = self.model.to(self.device)
|
||||
print("running on device", self.device)
|
||||
|
||||
# variables that will be assigned to trainer class later for logging and etc
|
||||
self.iter_num = 0
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import random
|
||||
from ast import literal_eval
|
||||
|
||||
|
@ -13,6 +16,18 @@ def set_seed(seed):
|
|||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def setup_logging(config):
|
||||
""" monotonous bookkeeping """
|
||||
work_dir = config.system.work_dir
|
||||
# create the work directory if it doesn't already exist
|
||||
os.makedirs(work_dir, exist_ok=True)
|
||||
# log the args (if any)
|
||||
with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
|
||||
f.write(' '.join(sys.argv))
|
||||
# log the config itself
|
||||
with open(os.path.join(work_dir, 'config.json'), 'w') as f:
|
||||
f.write(json.dumps(config.to_dict(), indent=4))
|
||||
|
||||
def top_k_logits(logits, k):
|
||||
v, ix = torch.topk(logits, k)
|
||||
out = logits.clone()
|
||||
|
|
|
@ -4,6 +4,7 @@ Trains a GPT to add n-digit numbers.
|
|||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -11,7 +12,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||
|
||||
from mingpt.model import GPT
|
||||
from mingpt.trainer import Trainer
|
||||
from mingpt.utils import set_seed, sample, CfgNode as CN
|
||||
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
@ -125,7 +126,7 @@ if __name__ == '__main__':
|
|||
config = get_config()
|
||||
config.merge_from_args(sys.argv[1:])
|
||||
print(config)
|
||||
# inits
|
||||
setup_logging(config)
|
||||
set_seed(config.system.seed)
|
||||
|
||||
# construct train and test datasets
|
||||
|
@ -138,7 +139,7 @@ if __name__ == '__main__':
|
|||
model = GPT(config.model)
|
||||
|
||||
# construct the trainer object
|
||||
trainer = Trainer(config, model, train_dataset)
|
||||
trainer = Trainer(config.trainer, model, train_dataset)
|
||||
|
||||
# helper function for the evaluation of a model
|
||||
def eval_split(trainer, split, max_batches=None):
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||
|
||||
from mingpt.model import GPT
|
||||
from mingpt.trainer import Trainer
|
||||
from mingpt.utils import set_seed, sample, CfgNode as CN
|
||||
from mingpt.utils import set_seed, sample, setup_logging, CfgNode as CN
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
@ -90,6 +90,7 @@ if __name__ == '__main__':
|
|||
config = get_config()
|
||||
config.merge_from_args(sys.argv[1:])
|
||||
print(config)
|
||||
setup_logging(config)
|
||||
set_seed(config.system.seed)
|
||||
|
||||
# construct the training dataset
|
||||
|
@ -102,7 +103,7 @@ if __name__ == '__main__':
|
|||
model = GPT(config.model)
|
||||
|
||||
# construct the trainer object
|
||||
trainer = Trainer(config, model, train_dataset)
|
||||
trainer = Trainer(config.trainer, model, train_dataset)
|
||||
|
||||
# iteration callback
|
||||
def batch_end_callback(trainer):
|
||||
|
|
Loading…
Reference in New Issue