mirror of
https://github.com/karpathy/minGPT
synced 2024-03-29 10:19:59 +01:00
tiny tweaks to printing and some function apis
This commit is contained in:
parent
9ec160cd8c
commit
52cb434db2
|
@ -1,3 +1,5 @@
|
|||
.ipynb_checkpoints/
|
||||
__pycache__/
|
||||
*.swp
|
||||
.env
|
||||
.pylintrc
|
||||
|
|
|
@ -148,7 +148,7 @@ if __name__ == '__main__':
|
|||
trainer = Trainer(config.trainer, model, train_dataset)
|
||||
|
||||
# helper function for the evaluation of a model
|
||||
def eval_split(trainer, split, max_batches=-1):
|
||||
def eval_split(trainer, split, max_batches=None):
|
||||
dataset = {'train':train_dataset, 'test':test_dataset}[split]
|
||||
ndigit = config.data.ndigit
|
||||
results = []
|
||||
|
@ -176,7 +176,7 @@ if __name__ == '__main__':
|
|||
if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
|
||||
mistakes_printed_already += 1
|
||||
print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
|
||||
if max_batches >= 0 and b+1 >= max_batches:
|
||||
if max_batches is not None and b+1 >= max_batches:
|
||||
break
|
||||
rt = torch.tensor(results, dtype=torch.float)
|
||||
print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
|
||||
|
@ -195,16 +195,16 @@ if __name__ == '__main__':
|
|||
|
||||
if trainer.iter_num % 500 == 0:
|
||||
# evaluate both the train and test score
|
||||
train_max_batches = {1: -1, 2: -1, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no
|
||||
train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
|
||||
test_score = eval_split(trainer, 'test', max_batches=-1)
|
||||
test_score = eval_split(trainer, 'test', max_batches=None)
|
||||
score = train_score + test_score
|
||||
# save the model if this is the best score we've seen so far
|
||||
if score > top_score:
|
||||
top_score = score
|
||||
print("saving model with new top score of %d" % (score, ))
|
||||
print(f"saving model with new top score of {score}")
|
||||
ckpt_path = os.path.join(config.system.work_dir, "model.pt")
|
||||
torch.save(model.state_dict(), ckpt_path)
|
||||
# revert model to training mode
|
||||
|
|
Loading…
Reference in New Issue