Mirror of https://github.com/karpathy/minGPT, synced 2024-06-10 08:46:10 +02:00
Merge branch 'waynemystir-master'
Commit dffb6a14e2
.gitignore
@@ -1,2 +1,3 @@
 .ipynb_checkpoints/
 __pycache__/
+*.swp
mingpt/trainer.py
@@ -61,13 +61,8 @@ class Trainer:
         raw_model = model.module if hasattr(self.model, "module") else model
         optimizer = raw_model.configure_optimizers(config)
 
-        def run_epoch(split):
-            is_train = split == 'train'
+        def run_epoch(loader, is_train):
             model.train(is_train)
-            data = self.train_dataset if is_train else self.test_dataset
-            loader = DataLoader(data, shuffle=True, pin_memory=True,
-                                batch_size=config.batch_size,
-                                num_workers=config.num_workers)
 
             losses = []
             pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
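With this change, run_epoch no longer picks a dataset and builds its own DataLoader; the caller supplies a ready-made loader plus an explicit is_train flag. The call sites change accordingly, e.g. (loader names as introduced in the next hunk):

    run_epoch(train_loader, is_train=True)              # training pass, with tqdm progress bar
    test_loss = run_epoch(test_loader, is_train=False)  # eval pass; loss feeds early stopping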
@@ -117,11 +112,27 @@ class Trainer:
 
         best_loss = float('inf')
         self.tokens = 0 # counter used for learning rate decay
-        for epoch in range(config.max_epochs):
 
-            run_epoch('train')
+        train_loader = DataLoader(
+            self.train_dataset,
+            shuffle=True,
+            pin_memory=True,
+            batch_size=config.batch_size,
+            num_workers=config.num_workers
+        )
+        if self.test_dataset is not None:
+            test_loader = DataLoader(
+                self.test_dataset,
+                shuffle=True,
+                pin_memory=True,
+                batch_size=config.batch_size,
+                num_workers=config.num_workers
+            )
+
+        for epoch in range(config.max_epochs):
+            run_epoch(train_loader, is_train=True)
             if self.test_dataset is not None:
-                test_loss = run_epoch('test')
+                test_loss = run_epoch(test_loader, is_train=False)
 
             # supports early stopping based on the test loss, or just save always if no test set is provided
             good_model = self.test_dataset is None or test_loss < best_loss
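Net effect of the two hunks: the DataLoaders are constructed once in train() and reused across epochs, instead of being rebuilt inside run_epoch on every call. Below is a minimal runnable sketch of the resulting pattern, with a toy model and random tensors standing in for minGPT's actual model and datasets (everything in the sketch is illustrative, not code from the commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from tqdm import tqdm

    # Toy stand-ins for the real model/optimizer -- the loader-handling
    # pattern, not the model, is what this sketch demonstrates.
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = torch.nn.MSELoss()

    def run_epoch(loader, is_train):
        model.train(is_train)
        losses = []
        # progress bar on training passes only, as in the patched trainer
        pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
        for it, (x, y) in pbar:
            with torch.set_grad_enabled(is_train):
                loss = loss_fn(model(x), y)
            losses.append(loss.item())
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return sum(losses) / len(losses)

    # build each loader once, outside the epoch loop -- the point of the refactor
    train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
                              shuffle=True, batch_size=16)
    test_loader = DataLoader(TensorDataset(torch.randn(16, 8), torch.randn(16, 1)),
                             shuffle=True, batch_size=16)

    best_loss = float('inf')
    for epoch in range(3):
        run_epoch(train_loader, is_train=True)
        test_loss = run_epoch(test_loader, is_train=False)
        good_model = test_loss < best_loss  # same early-stopping test as the trainer
        best_loss = min(best_loss, test_loss)

pin_memory and num_workers are omitted here for brevity; the real trainer passes pin_memory=True and takes batch_size and num_workers from its config, as the diff shows.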