mirror of https://github.com/karpathy/minGPT
synced 2024-05-21 15:06:05 +02:00
163 lines · 6.0 KiB · Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Hugging Face reference implementation of GPT-2\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"\n",
"# minGPT model plus its sampling / seeding utilities\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import sample, set_seed\n",
"\n",
"set_seed(3407)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# which backend to run: True -> minGPT, False -> huggingface/transformers\n",
"use_mingpt = True\n",
"# checkpoint name understood by both backends\n",
"model_type = 'gpt2-xl'\n",
"# torch device string that the model and inputs are moved to\n",
"device = 'cuda'"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"number of parameters: 1557.61M\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# instantiate the pretrained checkpoint with the selected backend\n",
"if not use_mingpt:\n",
"    model = GPT2LMHeadModel.from_pretrained(model_type)\n",
"    # silence the 'pad_token_id not set' warning emitted by generate()\n",
"    model.config.pad_token_id = model.config.eos_token_id\n",
"else:\n",
"    model = GPT.from_pretrained(model_type)\n",
"\n",
"# move the weights to the target device and switch to inference mode\n",
"model.to(device)\n",
"model.eval();"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n",
"    \"\"\"Sample `num_samples` continuations of `prompt`, `steps` new tokens each, and print them.\"\"\"\n",
"\n",
"    # turn the prompt into a batch of identical integer token sequences\n",
"    tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
"    if prompt == '':\n",
"        # unconditional generation starts from the special start-of-text token\n",
"        prompt = '<|endoftext|>'\n",
"    encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n",
"    x = encoded_input['input_ids'].expand(num_samples, -1)\n",
"\n",
"    # run the model autoregressively for `steps` tokens, whole batch at once\n",
"    if use_mingpt:\n",
"        y = sample(model=model, x=x, steps=steps, sample=do_sample, top_k=40)\n",
"    else:\n",
"        y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n",
"\n",
"    # decode and print each sampled row, separated by a rule\n",
"    for row in y:\n",
"        print('-' * 80)\n",
"        print(tokenizer.decode(row.cpu().squeeze()))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-07-08 23:51:10.949993: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
|
|
"2022-07-08 23:51:10.950042: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n",
|
|
"--------------------------------------------------------------------------------\n",
|
|
"Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# 10 samples, 20 new tokens each, continuing the given prompt\n",
"generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)"
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.10.4 64-bit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
},
|
|
"orig_nbformat": 4,
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|