mirror of
https://github.com/karpathy/minGPT
synced 2024-05-04 06:36:10 +02:00
Reorganize the bench code to support multi-GPU training; the script body has to be indented under the `if __name__ == '__main__':` guard.
This commit is contained in:
parent
492b79fb31
commit
a796899f65
110
bench.py
110
bench.py
|
@ -18,23 +18,12 @@ from mingpt.model import GPT
|
|||
from mingpt.lr_decay import WarmupCosineLearningRateDecay
|
||||
from mingpt.utils import sample
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
torch.backends.cudnn.benchmark = True # autotune kernels
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
import os
|
||||
if int(os.environ.get('USE_LIGHTNING', 0)):
|
||||
logging.info("USING LIGHTNING!!")
|
||||
import pytorch_lightning as pl
|
||||
else:
|
||||
import mingpt.fake_lightning as pl
|
||||
logging.info("using our humble trainer")
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Text8Dataset(Dataset):
|
||||
|
@ -85,53 +74,64 @@ class Text8Dataset(Dataset):
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with")
|
||||
parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)")
|
||||
parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading")
|
||||
parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on")
|
||||
parser.add_argument('-p', '--pin-memory', type=int, default=1, help="pin memory on dataloaders?")
|
||||
parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16")
|
||||
parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location")
|
||||
args = parser.parse_args()
|
||||
print(vars(args))
|
||||
if __name__ == '__main__':
|
||||
|
||||
logging.info("preparing the data loaders")
|
||||
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
|
||||
train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(1e6)))
|
||||
val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(1e5)), override_vocab=train_dataset.vocab)
|
||||
test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(1e5)), override_vocab=train_dataset.vocab)
|
||||
common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers}
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
|
||||
val_dataloader = DataLoader(val_dataset, shuffle=False, **common)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
|
||||
parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with")
|
||||
parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)")
|
||||
parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on")
|
||||
parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading")
|
||||
parser.add_argument('-p', '--pin-memory', type=int, default=0, help="pin memory on dataloaders?")
|
||||
parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16")
|
||||
parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location")
|
||||
args = parser.parse_args()
|
||||
print(vars(args))
|
||||
|
||||
logging.info("creating the model")
|
||||
model = GPT(train_dataset.vocab_size, args.block_size, n_layer=6, n_head=8, n_embd=256)
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
logging.info("preparing the learning rate schedule")
|
||||
iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration
|
||||
epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
|
||||
lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2,
|
||||
final_tokens=args.num_epochs*epoch_tokens)
|
||||
torch.backends.cudnn.benchmark = True # autotune kernels
|
||||
|
||||
t0 = time.time()
|
||||
logging.info("training...")
|
||||
trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay],
|
||||
precision=args.precision, default_root_dir=args.default_root_dir)
|
||||
trainer.fit(model, train_dataloader, val_dataloader)
|
||||
t1 = time.time()
|
||||
logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)
|
||||
logging.info("preparing the data loaders")
|
||||
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
|
||||
train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(90e6)))
|
||||
val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(5e6)), override_vocab=train_dataset.vocab)
|
||||
test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(5e6)), override_vocab=train_dataset.vocab)
|
||||
common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers}
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
|
||||
val_dataloader = DataLoader(val_dataset, shuffle=False, **common)
|
||||
|
||||
logging.info("testing...")
|
||||
test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
|
||||
trainer.test(test_dataloaders=test_dataloader)
|
||||
logging.info("creating the model")
|
||||
model = GPT(train_dataset.vocab_size, args.block_size, n_layer=8, n_head=8, n_embd=256)
|
||||
|
||||
logging.info("sampling:")
|
||||
context = "anarchism originated as a term of"
|
||||
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...]
|
||||
if next(model.parameters()).is_cuda:
|
||||
x = x.cuda()
|
||||
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
|
||||
completion = ''.join([train_dataset.itos[int(i)] for i in y])
|
||||
print(completion)
|
||||
logging.info("preparing the learning rate schedule")
|
||||
iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration
|
||||
epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
|
||||
lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2,
|
||||
final_tokens=args.num_epochs*epoch_tokens)
|
||||
|
||||
t0 = time.time()
|
||||
logging.info("training...")
|
||||
trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay],
|
||||
precision=args.precision, default_root_dir=args.default_root_dir)
|
||||
trainer.fit(model, train_dataloader, val_dataloader)
|
||||
t1 = time.time()
|
||||
logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)
|
||||
|
||||
logging.info("testing...")
|
||||
test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
|
||||
trainer.test(test_dataloaders=test_dataloader)
|
||||
|
||||
logging.info("sampling:")
|
||||
context = "anarchism originated as a term of"
|
||||
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...]
|
||||
if next(model.parameters()).is_cuda:
|
||||
x = x.cuda()
|
||||
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
|
||||
completion = ''.join([train_dataset.itos[int(i)] for i in y])
|
||||
print(completion)
|
||||
|
|
Loading…
Reference in New Issue