1
0
mirror of https://github.com/karpathy/minGPT synced 2024-11-15 19:10:39 +01:00
minGPT/mingpt/utils.py

140 lines
4.9 KiB
Python

import os
import sys
import json
import random
from ast import literal_eval
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
def set_seed(seed):
    """
    Seed every random number generator this module relies on with the same value.

    The generators are seeded in a fixed order: Python's `random`, NumPy's
    global generator, torch's default generator and finally the torch CUDA
    generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def setup_logging(config):
    """
    Record how the run was launched inside config.system.work_dir.

    The directory is created if it is missing, then two files are written:
      - args.txt:    the space-joined contents of sys.argv
      - config.json: config.to_dict() serialized as JSON with indent=4
    """
    out_dir = config.system.work_dir
    os.makedirs(out_dir, exist_ok=True)

    # (file name, callable producing the text). The renderer is invoked only
    # after its file has been opened, one file at a time, in this order.
    outputs = (
        ('args.txt', lambda: ' '.join(sys.argv)),
        ('config.json', lambda: json.dumps(config.to_dict(), indent=4)),
    )
    for file_name, render in outputs:
        with open(os.path.join(out_dir, file_name), 'w') as handle:
            handle.write(render())
def top_k_logits(logits, k):
    """
    Keep the k largest entries along the last dimension and set the rest to -inf.

    Args:
        logits: tensor whose last dimension holds the scores to filter; any
            number of leading dimensions is accepted.
        k: how many of the highest-scoring entries to keep. Values larger than
            the size of the last dimension are clamped to that size, in which
            case nothing is masked.

    Returns:
        A new tensor of the same shape as `logits`; the input is not modified.
        Entries strictly below the k-th largest value are -inf, so entries
        tied with the k-th value are all kept.
    """
    # torch.topk raises if k exceeds the dimension size, so clamp it first
    k = min(k, logits.size(-1))
    top_values, _ = torch.topk(logits, k)
    # topk returns values sorted descending, so the last one is the cut-off;
    # the slice keeps the dimension so it broadcasts against logits
    threshold = top_values[..., -1:]
    return logits.masked_fill(logits < threshold, -float('Inf'))
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    Autoregressively extend x, a batch of index sequences of shape (b,t), by
    `steps` tokens: each predicted token is appended to the sequence and fed
    back into the model for the next prediction. Sampling therefore has
    quadratic complexity (unlike an RNN, which is linear) and only ever sees a
    finite context window of block_size tokens (unlike an RNN, whose context
    window is unbounded).

    The model is switched to eval mode. With sample=False the most likely token
    is taken at every step; with sample=True the token is drawn from the
    (optionally top_k-filtered) distribution. Returns the extended sequence.
    """
    max_context = model.get_block_size()
    model.eval()
    for _ in range(steps):
        # the model can only attend to the most recent max_context tokens
        if x.size(1) > max_context:
            context = x[:, -max_context:]
        else:
            context = x
        step_logits, _unused = model(context)
        # only the final position predicts the next token; apply temperature there
        next_logits = step_logits[:, -1, :] / temperature
        # optionally restrict the candidates to the top k options
        if top_k is not None:
            next_logits = top_k_logits(next_logits, top_k)
        # turn the logits into a probability distribution
        probs = F.softmax(next_logits, dim=-1)
        # draw from the distribution, or pick its mode
        if sample:
            next_ix = torch.multinomial(probs, num_samples=1)
        else:
            next_ix = torch.topk(probs, k=1, dim=-1).indices
        # grow the sequence by the chosen token and go again
        x = torch.cat((x, next_ix), dim=1)
    return x
class CfgNode:
    """
    a lightweight configuration class inspired by yacs

    Every keyword passed to the constructor becomes an attribute of the node.
    Nodes can be nested by using another CfgNode as an attribute value.
    """
    # TODO: convert to subclass from a dict like in yacs?
    # TODO: implement freezing to prevent shooting of own foot
    # TODO: additional existence/override checks when reading/writing params?

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __str__(self):
        return self._str_helper(0)

    def _str_helper(self, indent):
        """ need to have a helper to support nested indentation for pretty printing """
        pad = ' ' * (indent * 4)
        parts = []
        for k, v in self.__dict__.items():
            if isinstance(v, CfgNode):
                parts.append("%s%s:\n" % (pad, k))
                # the nested text already carries its own indentation on every
                # line, so it must not be prefixed again at this level
                parts.append(v._str_helper(indent + 1))
            else:
                parts.append("%s%s: %s\n" % (pad, k, v))
        return "".join(parts)

    def to_dict(self):
        """ return a dict representation of the config (nested nodes become nested dicts) """
        return {k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items()}

    def merge_from_dict(self, d):
        """ overwrite/add top-level attributes from the mapping d """
        self.__dict__.update(d)

    def merge_from_args(self, args):
        """
        update the configuration from a list of strings that is expected
        to come from the command line, i.e. sys.argv[1:].

        The arguments are expected to be in the form of `--arg=value`, and
        the arg can use . to denote nested sub-attributes. Example:

        --model.n_layer=10 --trainer.batch_size=32

        Only the first `=` separates the key from the value, so the value may
        itself contain `=`. A value that parses as a python literal (3, 3.14,
        [1,2,3], False, None, ...) is converted to that object; anything else
        is kept as the raw string.

        Raises AssertionError if an argument is not of the form `--arg=value`
        or if the final attribute does not already exist in the config.
        """
        for arg in args:
            # split on the first '=' only, the value is allowed to contain more
            keyval = arg.split('=', 1)
            assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
            key, val = keyval  # unpack

            # first translate val into a python object
            try:
                val = literal_eval(val)
            except (ValueError, SyntaxError, TypeError):
                # need some explanation here.
                # - if val is a bare word, literal_eval throws a ValueError
                # - if val is not valid python syntax at all (e.g. a path such
                #   as ./out/run1) it throws a SyntaxError
                # - some malformed literals (e.g. an unhashable dict key) throw
                #   a TypeError
                # in all of these cases val is simply kept as a string
                pass

            # find the appropriate object to insert the attribute into
            assert key[:2] == '--'
            key = key[2:]  # strip the '--'
            keys = key.split('.')
            obj = self
            for k in keys[:-1]:
                obj = getattr(obj, k)
            leaf_key = keys[-1]

            # ensure that this attribute exists
            assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"

            # overwrite the attribute
            print("command line overwriting config attribute %s with %s" % (key, val))
            setattr(obj, leaf_key, val)