mirror of https://github.com/karpathy/minGPT

reorg the bench code to support multigpu training, have to indent properly under __main__

Andrej Karpathy 2020-08-30 11:40:31 -07:00
parent 492b79fb31
commit a796899f65
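
Why the __main__ guard matters for multi-GPU training: with more than one GPU (and with DataLoader workers when num_workers > 0), extra worker processes are launched, and under the multiprocessing 'spawn' start method each of them re-imports the script. Any training code left at module level would then run again in every worker, which is what this commit fixes by indenting everything under the guard. A minimal sketch of the pattern, illustrative only and using a stand-in dataset rather than the Text8Dataset from bench.py:

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_dataloader(batch_size=64, num_workers=0):
    # stand-in data so the sketch is self-contained; bench.py uses Text8Dataset
    x = torch.randint(0, 27, (256, 128))
    y = torch.randint(0, 27, (256, 128))
    return DataLoader(TensorDataset(x, y), batch_size=batch_size,
                      shuffle=True, num_workers=num_workers)

if __name__ == '__main__':
    # everything that actually trains lives under the guard, so spawned
    # processes (extra GPUs, DataLoader workers) can import this module
    # without re-running the training loop themselves
    train_dataloader = build_dataloader()
    for xb, yb in train_dataloader:
        pass  # the trainer.fit(...) call would go here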

bench.py (110 changed lines)

@@ -18,23 +18,12 @@ from mingpt.model import GPT
from mingpt.lr_decay import WarmupCosineLearningRateDecay
from mingpt.utils import sample
-logger = logging.getLogger(__name__)
-logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-)
-torch.backends.cudnn.benchmark = True # autotune kernels
# -----------------------------------------------------------------------------
import os
if int(os.environ.get('USE_LIGHTNING', 0)):
    logging.info("USING LIGHTNING!!")
    import pytorch_lightning as pl
else:
    import mingpt.fake_lightning as pl
    logging.info("using our humble trainer")
# -----------------------------------------------------------------------------
class Text8Dataset(Dataset):
@@ -85,53 +74,64 @@ class Text8Dataset(Dataset):
# -----------------------------------------------------------------------------
-parser = argparse.ArgumentParser()
-parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
-parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with")
-parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)")
-parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading")
-parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on")
-parser.add_argument('-p', '--pin-memory', type=int, default=1, help="pin memory on dataloaders?")
-parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16")
-parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location")
-args = parser.parse_args()
-print(vars(args))
+if __name__ == '__main__':
-logging.info("preparing the data loaders")
-# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
-train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(1e6)))
-val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(1e5)), override_vocab=train_dataset.vocab)
-test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(1e5)), override_vocab=train_dataset.vocab)
-common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers}
-train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
-val_dataloader = DataLoader(val_dataset, shuffle=False, **common)
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
+    parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with")
+    parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)")
+    parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on")
+    parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading")
+    parser.add_argument('-p', '--pin-memory', type=int, default=0, help="pin memory on dataloaders?")
+    parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16")
+    parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location")
+    args = parser.parse_args()
+    print(vars(args))
-logging.info("creating the model")
-model = GPT(train_dataset.vocab_size, args.block_size, n_layer=6, n_head=8, n_embd=256)
+    logger = logging.getLogger(__name__)
+    logging.basicConfig(
+            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+            level=logging.INFO,
+    )
-logging.info("preparing the learning rate schedule")
-iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration
-epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
-lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2,
-                                         final_tokens=args.num_epochs*epoch_tokens)
+    torch.backends.cudnn.benchmark = True # autotune kernels
-t0 = time.time()
-logging.info("training...")
-trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay],
-                     precision=args.precision, default_root_dir=args.default_root_dir)
-trainer.fit(model, train_dataloader, val_dataloader)
-t1 = time.time()
-logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)
+    logging.info("preparing the data loaders")
+    # NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
+    train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(90e6)))
+    val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(5e6)), override_vocab=train_dataset.vocab)
+    test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(5e6)), override_vocab=train_dataset.vocab)
+    common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers}
+    train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
+    val_dataloader = DataLoader(val_dataset, shuffle=False, **common)
-logging.info("testing...")
-test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
-trainer.test(test_dataloaders=test_dataloader)
+    logging.info("creating the model")
+    model = GPT(train_dataset.vocab_size, args.block_size, n_layer=8, n_head=8, n_embd=256)
-logging.info("sampling:")
-context = "anarchism originated as a term of"
-x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...]
-if next(model.parameters()).is_cuda:
-    x = x.cuda()
-y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
-completion = ''.join([train_dataset.itos[int(i)] for i in y])
-print(completion)
+    logging.info("preparing the learning rate schedule")
+    iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration
+    epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
+    lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2,
+                                             final_tokens=args.num_epochs*epoch_tokens)
+    t0 = time.time()
+    logging.info("training...")
+    trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay],
+                         precision=args.precision, default_root_dir=args.default_root_dir)
+    trainer.fit(model, train_dataloader, val_dataloader)
+    t1 = time.time()
+    logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)
+    logging.info("testing...")
+    test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
+    trainer.test(test_dataloaders=test_dataloader)
+    logging.info("sampling:")
+    context = "anarchism originated as a term of"
+    x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...]
+    if next(model.parameters()).is_cuda:
+        x = x.cuda()
+    y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
+    completion = ''.join([train_dataset.itos[int(i)] for i in y])
+    print(completion)
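
Usage note (illustrative, not part of the commit): with the training code under the guard, the benchmark can be launched across GPUs using the flags defined above, for example python bench.py --num-gpus 2 --precision 16, and setting USE_LIGHTNING=1 in the environment switches from mingpt.fake_lightning to the real pytorch_lightning trainer.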