1
0
Fork 0
mirror of https://github.com/karpathy/minGPT synced 2024-05-18 21:46:03 +02:00
minGPT/generate.ipynb

167 lines
6.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows how to generate text from a prompt with a few sampling hyperparameters, using either minGPT or Hugging Face `transformers`. Pass an empty prompt to draw unconditional samples."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import set_seed\n",
"from mingpt.bpe import BPETokenizer\n",
"set_seed(3407)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"use_mingpt = True # use minGPT or huggingface/transformers model?\n",
"model_type = 'gpt2-xl'\n",
"# fall back to CPU so the notebook still runs on machines without a GPU (generation will be slower)\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 1557.61M\n"
]
}
],
"source": [
"# load pretrained weights with whichever library was selected above\n",
"if not use_mingpt:\n",
"    model = GPT2LMHeadModel.from_pretrained(model_type)\n",
"    # reuse the EOS token id as the pad token id to suppress a warning\n",
"    model.config.pad_token_id = model.config.eos_token_id\n",
"else:\n",
"    model = GPT.from_pretrained(model_type)\n",
"\n",
"# move the weights to the target device and switch to inference (eval) mode\n",
"model.to(device)\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def generate(prompt='', num_samples=10, steps=20, do_sample=True, top_k=40):\n",
"    \"\"\"Sample `num_samples` continuations of `prompt` and print each one.\n",
"\n",
"    Reads the notebook globals `use_mingpt`, `model`, `model_type` and `device`.\n",
"    An empty `prompt` draws unconditional samples that start from the <|endoftext|> token.\n",
"    `steps` is the number of new tokens generated per sample; `do_sample` and `top_k`\n",
"    are passed through to `model.generate`.\n",
"    \"\"\"\n",
"    # tokenize the input prompt into integer input sequence\n",
"    if use_mingpt:\n",
"        tokenizer = BPETokenizer()\n",
"        if prompt == '':\n",
"            # to create unconditional samples...\n",
"            # manually create a tensor with only the special <|endoftext|> token\n",
"            # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n",
"            # it must live on the model's device, just like the tokenized prompt below\n",
"            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long, device=device)\n",
"        else:\n",
"            x = tokenizer(prompt).to(device)\n",
"    else:\n",
"        tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
"        if prompt == '':\n",
"            # to create unconditional samples...\n",
"            # huggingface/transformers tokenizer special cases these strings\n",
"            prompt = '<|endoftext|>'\n",
"        encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n",
"        x = encoded_input['input_ids']\n",
"\n",
"    # we'll process all desired num_samples in a batch, so expand out the batch dim\n",
"    x = x.expand(num_samples, -1)\n",
"\n",
"    # forward the model `steps` times to get samples, in a batch\n",
"    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=top_k)\n",
"\n",
"    for i in range(num_samples):\n",
"        # y[i] is one row of the batch (assumes generate returns (num_samples, T));\n",
"        # no squeeze(), which would collapse a length-1 row into a 0-d tensor\n",
"        out = tokenizer.decode(y[i].cpu())\n",
"        print('-'*80)\n",
"        print(out)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n",
"\n"
]
}
],
"source": [
"# draw 10 sampled continuations of 20 new tokens each for this prompt\n",
"generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}