1
0
Fork 0
mirror of https://github.com/karpathy/minGPT synced 2024-03-29 10:19:59 +01:00

tiny tweaks to printing and some function APIs

This commit is contained in:
Andrej 2022-05-31 23:07:00 +00:00 committed by GitHub
parent 9ec160cd8c
commit 52cb434db2
2 changed files with 7 additions and 5 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.ipynb_checkpoints/
__pycache__/
*.swp
.env
.pylintrc

View File

@ -148,7 +148,7 @@ if __name__ == '__main__':
trainer = Trainer(config.trainer, model, train_dataset)
# helper function for the evaluation of a model
def eval_split(trainer, split, max_batches=-1):
def eval_split(trainer, split, max_batches=None):
dataset = {'train':train_dataset, 'test':test_dataset}[split]
ndigit = config.data.ndigit
results = []
@ -176,7 +176,7 @@ if __name__ == '__main__':
if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
mistakes_printed_already += 1
print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
if max_batches >= 0 and b+1 >= max_batches:
if max_batches is not None and b+1 >= max_batches:
break
rt = torch.tensor(results, dtype=torch.float)
print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
@ -195,16 +195,16 @@ if __name__ == '__main__':
if trainer.iter_num % 500 == 0:
# evaluate both the train and test score
train_max_batches = {1: -1, 2: -1, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no
train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no
model.eval()
with torch.no_grad():
train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
test_score = eval_split(trainer, 'test', max_batches=-1)
test_score = eval_split(trainer, 'test', max_batches=None)
score = train_score + test_score
# save the model if this is the best score we've seen so far
if score > top_score:
top_score = score
print("saving model with new top score of %d" % (score, ))
print(f"saving model with new top score of {score}")
ckpt_path = os.path.join(config.system.work_dir, "model.pt")
torch.save(model.state_dict(), ckpt_path)
# revert model to training mode