mirror of https://github.com/karpathy/minGPT

add a refactored BPE encoder from OpenAI. Basically I don't super trust the huggingface tokenizer: the implementation sprawls across multiple files and layers of inheritance, and has special magic handling around AddedTokens that I don't fully follow. Prefer to roll our own explicit implementation here that exactly mirrors the code of OpenAI and nothing else.

This commit is contained in:
Andrej 2022-07-12 02:01:41 +00:00 committed by GitHub
parent 40635a91f4
commit 9642f40b83
3 changed files with 317 additions and 27 deletions
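For reference, a minimal usage sketch of the new tokenizer in mingpt/bpe.py (added below), mirroring the huggingface/transformers calling convention; the variable names are just illustrative:

from mingpt.bpe import BPETokenizer

tokenizer = BPETokenizer()              # downloads and caches encoder.json / vocab.bpe on first use
idx = tokenizer("Hello!! I'm Andrej")   # torch.LongTensor of shape (1, T), i.e. a batch of 1
text = tokenizer.decode(idx[0])         # decode expects a 1D tensor of token indices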


@@ -13,9 +13,11 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import set_seed\n",
"from mingpt.bpe import BPETokenizer\n",
"set_seed(3407)"
]
},
@@ -65,11 +67,25 @@
"def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n",
" \n",
" # tokenize the input prompt into integer input sequence\n",
" tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
" if prompt == '': # to create unconditional samples we feed in the special start token\n",
" prompt = '<|endoftext|>'\n",
" encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n",
" x = encoded_input['input_ids']\n",
" if use_mingpt:\n",
" tokenizer = BPETokenizer()\n",
" if prompt == '':\n",
" # to create unconditional samples...\n",
" # manually create a tensor with only the special <|endoftext|> token\n",
" # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n",
" x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)\n",
" else:\n",
" x = tokenizer(prompt).to(device)\n",
" else:\n",
" tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
" if prompt == '': \n",
" # to create unconditional samples...\n",
" # huggingface/transformers tokenizer special cases these strings\n",
" prompt = '<|endoftext|>'\n",
" encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n",
" x = encoded_input['input_ids']\n",
" \n",
" # we'll process all desired num_samples in a batch, so expand out the batch dim\n",
" x = x.expand(num_samples, -1)\n",
"\n",
" # forward the model `steps` times to get samples, in a batch\n",
@@ -87,14 +103,6 @@
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-11 18:42:21.744061: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-07-11 18:42:21.744099: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
]
},
{
"name": "stdout",
"output_type": "stream",

mingpt/bpe.py · new file · 279 additions

@@ -0,0 +1,279 @@
"""
bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into
sequences of integers, where each integer represents small chunks of commonly
occurring characters. This implementation is based on openai's gpt2 encoder.py:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
but was mildly modified because the original implementation is a bit confusing.
I also tried to add as many comments as possible to explain my own understanding
of what's going on.
"""
import os
import json
import regex as re
import requests
import torch
# -----------------------------------------------------------------------------
def bytes_to_unicode():
"""
Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode
character that represents it visually. Some bytes have their appearance preserved
because they don't cause any trouble. These are defined in list bs. For example:
chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!".
However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these
bytes into new characters in a range where chr() returns a single nice character.
So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8).
In particular, the space character is 32, which we can see by ord(' '). This
function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'.
So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters
that "look nice", either in their original form, or a funny shifted character
like 'Ā', or 'Ġ', etc.
"""
# the 188 integers that render fine in their original form and need no shifting
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict
# now get the representations of the other 68 integers that do need shifting
# each will get mapped to chr(256 + n), where n will grow from 0...67 in the loop
n = 0
for b in range(2**8):
if b not in bs:
# if this byte is "ugly" then map it to the next available "nice" character
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
d = dict(zip(bs, cs))
return d
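# a few illustrative entries of the resulting mapping, matching the docstring above:
#   d = bytes_to_unicode()
#   d[33]  -> '!'   (a "nice" byte keeps its appearance)
#   d[0]   -> 'Ā'   (an "ugly" byte, shifted to chr(0 + 2**8))
#   d[32]  -> 'Ġ'   (the space byte, shifted to chr(288))
#   len(d) -> 256   (a bijection over all possible byte values)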
def get_pairs(word):
"""
Return the set of all bigrams (tuples of consecutive elements) in the iterable word.
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
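# for example: get_pairs(('h', 'e', 'l', 'l', 'o')) -> {('h','e'), ('e','l'), ('l','l'), ('l','o')}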
class Encoder:
def __init__(self, encoder, bpe_merges):
# byte encoder/decoder
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
# bpe token encoder/decoder
self.encoder = encoder
self.decoder = {v:k for k,v in self.encoder.items()}
# bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# the splitting pattern used for pre-tokenization
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment
"""
ok so what is this regex looking for, exactly?
python re reference: https://docs.python.org/3/library/re.html
- the vertical bar | is OR, so re.findall will chunk up the text as the pieces match, from left to right
- '\'s' would split up things like Andrej's -> (Andrej, 's)
- ' ?\p{L}+': optional space followed by 1+ unicode code points in the category "letter"
- ' ?\p{N}+': optional space followed by 1+ unicode code points in the category "number"
- ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number
- '\s+(?!\S)': 1+ whitespace characters (e.g. space, tab, etc.) UNLESS they are followed by non-whitespace
so this will consume whitespace characters in a sequence but exclude the last whitespace in
that sequence. that last whitespace has the opportunity to then match the optional ' ?' in
earlier patterns.
- '\s+': 1+ whitespace characters, probably intended to catch a full trailing sequence of whitespace at the end of a string
So TLDR:
- we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens
- we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces
"""
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
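# for example, re.findall(self.pat, "Andrej's dog has 2 bones!!") chunks the text into
# ["Andrej", "'s", " dog", " has", " 2", " bones", "!!"]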
self.cache = {}
def bpe(self, token):
"""
this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens
up the tree. token is a string of one individual 'word' (after regex tokenization)
and after byte encoding, e.g. 'Ġthere'.
"""
# token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere'
# memoization, for efficiency
if token in self.cache:
return self.cache[token]
word = tuple(token) # individual characters that make up the token, in a tuple
pairs = get_pairs(word) # get all bigrams
if not pairs:
return token
while True:
# find the next lowest rank bigram that can be merged
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break # no more bigrams are eligible to be merged
first, second = bigram
# we will now replace all occurrences of (first, second) in the list of current
# subwords with one merged token first_second, in the output list new_word
new_word = []
i = 0
while i < len(word):
# find the next occurrence of first in the sequence of current subwords
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError: # first no longer occurs in the remainder
new_word.extend(word[i:])
break
# if this occurrence is also followed by second, then merge them into one
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
# all occurrences of (first, second) have been merged to first_second
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
# concat all words into a string, and use ' ' as the separator. Note that
# by now all characters have been byte encoded, guaranteeing that ' ' is
# not used in the actual data and is a 'special' delimiter character
word = ' '.join(word)
# cache the result and return
self.cache[token] = word
return word
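# a rough illustration of the loop above, with hypothetical merge ranks: starting from
# ('Ġ','t','h','e','r','e'), if ('h','e') has the lowest rank we merge to ('Ġ','t','he','r','e'),
# then perhaps ('he','r') to get ('Ġ','t','her','e'), and so on until no remaining bigram is in
# self.bpe_ranks; the space-joined result at that point, e.g. 'Ġt her e', is what gets cached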
def encode(self, text):
""" string goes in, list of integers comes out """
bpe_idx = []
# pre-tokenize the input text into string tokens (words, roughly speaking)
tokens = re.findall(self.pat, text)
# process each token into BPE integers
for token in tokens:
# encode the token as a bytes (b'') object
token_bytes = token.encode('utf-8')
# translate all bytes to their unicode string representation and flatten
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
# perform all the applicable bpe merges according to self.bpe_ranks
token_merged = self.bpe(token_translated).split(' ')
# translate all bpe tokens to integers
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
# extend our running list of all output integers
bpe_idx.extend(token_ix)
return bpe_idx
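# tracing one pre-token through the pipeline above: ' dog' -> bytes b' dog' -> byte-encoded
# 'Ġdog' -> (assuming 'Ġdog' survives the merges as a single vocab entry) ['Ġdog'] -> one integer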
def encode_and_show_work(self, text):
""" debugging function, same as encode but returns all intermediate work """
bpe_idx = []
parts = []
tokens = re.findall(self.pat, text)
for token in tokens:
token_bytes = token.encode('utf-8')
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
token_merged = self.bpe(token_translated).split(' ')
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
bpe_idx.extend(token_ix)
parts.append({
'token': token,
'token_bytes': token_bytes,
'token_translated': token_translated,
'token_merged': token_merged,
'token_ix': token_ix,
})
out = {
'bpe_idx': bpe_idx, # the actual output sequence
'tokens': tokens, # result of pre-tokenization
'parts': parts, # intermediates for each token part
}
return out
def decode(self, bpe_idx):
""" list of integers comes in, string comes out """
# inverse map the integers to get the tokens
tokens_merged = [self.decoder[token] for token in bpe_idx]
# invert the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
tokens_flat = ''.join(tokens_merged)
tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
# recover the full utf-8 string
text = tokens_bytes.decode('utf-8', errors='replace')
return text
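# note that for any valid utf-8 input, decode(encode(text)) recovers text exactly: the byte
# <-> unicode mapping is a bijection and the bpe merges only regroup characters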
def get_file(local_file, remote_file):
""" downloads remote_file to local_file if necessary """
if not os.path.isfile(local_file):
print(f"downloading {remote_file} to {local_file}")
response = requests.get(remote_file)
with open(local_file, "wb") as f:
    f.write(response.content)
def get_encoder():
"""
Returns an instance of the GPT BPE Encoder/Decoder
and handles caching of "database" files.
"""
home_dir = os.path.expanduser('~')
cache_dir = os.path.join(home_dir, '.cache', 'mingpt')
os.makedirs(cache_dir, exist_ok=True)
# load encoder.json that has the raw mappings from token -> bpe index
encoder_local_file = os.path.join(cache_dir, 'encoder.json')
encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json'
get_file(encoder_local_file, encoder_remote_file)
with open(encoder_local_file, 'r') as f:
encoder = json.load(f)
assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token
# load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure
# in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab
vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
get_file(vocab_local_file, vocab_remote_file)
with open(vocab_local_file, 'r', encoding="utf-8") as f:
bpe_data = f.read()
# light postprocessing: strip the version header on the first line and the blank last line
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
assert len(bpe_merges) == 50000 # 50,000 merged tokens
# construct the Encoder object and return
enc = Encoder(encoder, bpe_merges)
return enc
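# illustrative sanity checks once the two files are cached (values follow from the asserts above
# plus the convention that <|endoftext|> is the last entry of the vocab):
#   e = get_encoder()
#   e.encoder['<|endoftext|>']  -> 50256
#   len(e.encoder)              -> 50257
#   len(e.bpe_ranks)            -> 50000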
# -----------------------------------------------------------------------------
class BPETokenizer:
""" PyTorch-aware class that wraps the Encoder above """
def __init__(self):
self.encoder = get_encoder()
def __call__(self, text, return_tensors='pt'):
# PyTorch only; here because we want to match huggingface/transformers interface
assert return_tensors == 'pt'
# single string input for now, in the future potentially a list of strings
assert isinstance(text, str)
# encode and create a "batch dimension" of 1
idx = [self.encoder.encode(text)]
# wrap into PyTorch tensor
out = torch.tensor(idx, dtype=torch.long)
return out
def decode(self, idx):
# ensure a simple 1D tensor for now
assert idx.ndim == 1
# decode indices to text
text = self.encoder.decode(idx.tolist())
return text


@@ -6,7 +6,7 @@ import unittest
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.bpe import BPETokenizer
# -----------------------------------------------------------------------------
class TestHuggingFaceImport(unittest.TestCase):
@@ -14,7 +14,7 @@ class TestHuggingFaceImport(unittest.TestCase):
def test_gpt2(self):
model_type = 'gpt2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
prompt = "Hello, my dog is a little"
prompt = "Hello!!!!!!!!!? 🤗, my dog is a little"
# create a minGPT and a huggingface/transformers model
model = GPT.from_pretrained(model_type)
@@ -23,32 +23,35 @@ class TestHuggingFaceImport(unittest.TestCase):
# ship both to device
model.to(device)
model_hf.to(device)
# set both to eval mode
model.eval()
model_hf.eval()
# tokenize an input prompt
tokenizer = GPT2Tokenizer.from_pretrained(model_type)
# tokenize input prompt
# ... with mingpt
tokenizer = BPETokenizer()
x1 = tokenizer(prompt).to(device)
# ... with huggingface/transformers
tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type)
model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning
if prompt == '': # to create unconditional samples we feed in the special start token
prompt = '<|endoftext|>'
encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
x = encoded_input['input_ids']
encoded_input = tokenizer_hf(prompt, return_tensors='pt').to(device)
x2 = encoded_input['input_ids']
# ensure the logits match exactly
logits1, loss = model(x)
logits2 = model_hf(x).logits
logits1, loss = model(x1)
logits2 = model_hf(x2).logits
self.assertTrue(torch.allclose(logits1, logits2))
# now draw the argmax samples from each
y1 = model.generate(x, max_new_tokens=20, do_sample=False)[0]
y2 = model_hf.generate(x, max_new_tokens=20, do_sample=False)[0]
y1 = model.generate(x1, max_new_tokens=20, do_sample=False)[0]
y2 = model_hf.generate(x2, max_new_tokens=20, do_sample=False)[0]
self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices
# convert indices to strings
out1 = tokenizer.decode(y1.cpu().squeeze())
out2 = tokenizer.decode(y2.cpu().squeeze())
out2 = tokenizer_hf.decode(y2.cpu().squeeze())
self.assertTrue(out1 == out2) # compare the exact output strings too
if __name__ == '__main__':
unittest.main()
unittest.main()