import os
import sys
import json
import random
from ast import literal_eval

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def set_seed(seed):
    """ seed all the RNGs (python, numpy, torch cpu and cuda) for reproducibility """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
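
# e.g. calling set_seed(3407) at the top of a training script makes runs
# repeatable across python's random module, numpy, and torch (cpu and cuda);
# the particular seed value is arbitrary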

def setup_logging(config):
    """ monotonous bookkeeping """
    work_dir = config.system.work_dir
    # create the work directory if it doesn't already exist
    os.makedirs(work_dir, exist_ok=True)
    # log the args (if any)
    with open(os.path.join(work_dir, 'args.txt'), 'w') as f:
        f.write(' '.join(sys.argv))
    # log the config itself
    with open(os.path.join(work_dir, 'config.json'), 'w') as f:
        f.write(json.dumps(config.to_dict(), indent=4))
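
# a minimal usage sketch (assuming a config assembled from the CfgNode class
# defined further below):
#   config = CfgNode(system=CfgNode(work_dir='./out'))
#   setup_logging(config)   # creates ./out, then writes ./out/args.txt and ./out/config.json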

def top_k_logits(logits, k):
    # keep only the k largest logits per row, masking the rest out to -inf
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out
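
# a worked example with made-up numbers:
#   top_k_logits(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), k=2)
#   # -> tensor([[-inf, 3.0, 2.0, -inf]])
# everything outside the top k entries of each row is pushed to -inf, so a
# subsequent softmax assigns zero probability to those tokens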

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity, unlike an RNN which is only linear, and has a finite context window
    of block_size, unlike an RNN which has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
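
# example usage (assuming a trained GPT `model` and an already-tokenized prompt):
#   x = torch.tensor([[0, 1, 2]], dtype=torch.long)  # shape (b=1, t=3)
#   y = sample(model, x, steps=10, temperature=0.9, sample=True, top_k=5)
#   # y has shape (1, 13): the 3 prompt tokens followed by 10 generated tokens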

class CfgNode:
    """ a lightweight configuration class inspired by yacs """
    # TODO: convert to subclass from a dict like in yacs?
    # TODO: implement freezing to prevent shooting of own foot
    # TODO: additional existence/override checks when reading/writing params?

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __str__(self):
        return self._str_helper(0)

    def _str_helper(self, indent):
        """ need to have a helper to support nested indentation for pretty printing """
        parts = []
        for k, v in self.__dict__.items():
            if isinstance(v, CfgNode):
                parts.append("%s:\n" % k)
                parts.append(v._str_helper(indent + 1))
            else:
                parts.append("%s: %s\n" % (k, v))
        parts = [' ' * (indent * 4) + p for p in parts]
        return "".join(parts)

    def to_dict(self):
        """ return a dict representation of the config """
        return { k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items() }

    def merge_from_dict(self, d):
        self.__dict__.update(d)

    def merge_from_args(self, args):
        """
        update the configuration from a list of strings that is expected
        to come from the command line, i.e. sys.argv[1:].

        The arguments are expected to be in the form of `--arg=value`, and
        the arg can use . to denote nested sub-attributes. Example:

        --model.n_layer=10 --trainer.batch_size=32
        """
        for arg in args:

            keyval = arg.split('=')
            assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
            key, val = keyval # unpack

            # first translate val into a python object
            try:
                # - if val is simply a string, literal_eval will throw a ValueError and we keep it as a string
                # - if val represents a thing (like a 3, 3.14, [1,2,3], False, None, etc.) it will be created
                val = literal_eval(val)
            except ValueError:
                pass

            # find the appropriate object to insert the attribute into
            assert key[:2] == '--', "expecting each override arg to start with '--', got %s" % key
            key = key[2:] # strip the '--'
            keys = key.split('.')
            obj = self
            for k in keys[:-1]:
                obj = getattr(obj, k)
            leaf_key = keys[-1]

            # ensure that this attribute exists
            assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"

            # overwrite the attribute
            print("command line overwriting config attribute %s with %s" % (key, val))
            setattr(obj, leaf_key, val)
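
# a small usage sketch of CfgNode (the attribute names here are made up):
#   C = CfgNode(model=CfgNode(n_layer=6), trainer=CfgNode(batch_size=64))
#   C.merge_from_args(['--model.n_layer=10'])
#   assert C.model.n_layer == 10  # literal_eval turned the string '10' into an int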