Mirror of https://github.com/karpathy/minGPT

commit dffb6a14e2
Author: Andrej Karpathy
Date:   2022-03-26 13:48:03 +00:00

    Merge branch 'waynemystir-master'

2 changed files with 21 additions and 9 deletions

.gitignore
@@ -1,2 +1,3 @@
 .ipynb_checkpoints/
 __pycache__/
+*.swp

mingpt/trainer.py
@@ -61,13 +61,8 @@ class Trainer:
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)

-        def run_epoch(split):
-            is_train = split == 'train'
+        def run_epoch(loader, is_train):
             model.train(is_train)
-            data = self.train_dataset if is_train else self.test_dataset
-            loader = DataLoader(data, shuffle=True, pin_memory=True,
-                                batch_size=config.batch_size,
-                                num_workers=config.num_workers)

             losses = []
             pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
@@ -117,11 +112,27 @@ class Trainer:
         best_loss = float('inf')
         self.tokens = 0 # counter used for learning rate decay
-        for epoch in range(config.max_epochs):
-            run_epoch('train')
+
+        train_loader = DataLoader(
+            self.train_dataset,
+            shuffle=True,
+            pin_memory=True,
+            batch_size=config.batch_size,
+            num_workers=config.num_workers
+        )
+        if self.test_dataset is not None:
+            test_loader = DataLoader(
+                self.test_dataset,
+                shuffle=True,
+                pin_memory=True,
+                batch_size=config.batch_size,
+                num_workers=config.num_workers
+            )
+
+        for epoch in range(config.max_epochs):
+            run_epoch(train_loader, is_train=True)
             if self.test_dataset is not None:
-                test_loss = run_epoch('test')
+                test_loss = run_epoch(test_loader, is_train=False)

             # supports early stopping based on the test loss, or just save always if no test set is provided
             good_model = self.test_dataset is None or test_loss < best_loss
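
Taken together, the two hunks hoist DataLoader construction out of run_epoch and into the body of train(): each loader is built once up front (the test loader only when a test set exists) and passed to run_epoch together with an explicit is_train flag, instead of run_epoch rebuilding a loader from a 'train'/'test' string on every call. Below is a minimal, self-contained sketch of the resulting pattern; the linear model, random tensors, and hyperparameters are stand-ins for illustration, not minGPT's actual GPT/Trainer classes.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in model, optimizer, and datasets (hypothetical, for illustration only).
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
test_dataset = TensorDataset(torch.randn(16, 8), torch.randn(16, 1))

def run_epoch(loader, is_train):
    # The loader now arrives as an argument; run_epoch no longer builds it.
    model.train(is_train)
    losses = []
    for x, y in loader:
        with torch.set_grad_enabled(is_train):
            loss = nn.functional.mse_loss(model(x), y)
        losses.append(loss.item())
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return sum(losses) / len(losses)

# Loaders are constructed once, outside the epoch loop, as in the diff above.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=16)

best_loss = float('inf')
for epoch in range(2):
    run_epoch(train_loader, is_train=True)
    test_loss = run_epoch(test_loader, is_train=False)
    best_loss = min(best_loss, test_loss)  # early-stopping bookkeeping, as in the diff

Passing the loader and an explicit boolean also makes run_epoch independent of the 'train'/'test' split-name convention used before, and keeps its inputs visible at the call site.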