julien-c (HF staff) committed
Commit 05db436
1 Parent(s): c174bab

encoder.py from https://github.com/openai/gpt-2/blob/master/src/encoder.py

Files changed (1):
  1. encoder.py (+117, -0)

encoder.py ADDED
"""Byte pair encoding utilities"""

import os
import json
import regex as re
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            # Bytes outside the printable ranges get remapped to code points above 255.
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
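
# Illustration (editor's note, not part of the upstream file): every byte ends
# up mapped to a printable character. The space byte 0x20, for example, becomes
# 'Ġ' (chr(288)), which is why word-initial tokens in the GPT-2 vocab start
# with 'Ġ':
#
#     bytes_to_unicode()[ord(' ')]   # -> 'Ġ'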

def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
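
# Illustration (editor's note, not part of the upstream file): for "hello"
# split into single characters,
#
#     get_pairs(('h', 'e', 'l', 'l', 'o'))
#     # -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}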

class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Apply the lowest-ranked (earliest-learned) merge present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i+1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
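
    # Illustration (editor's note, not part of the upstream file): with a
    # merge table containing only ('l', 'o'), bpe() applies that one merge
    # and then stops, since no remaining pair has a rank:
    #
    #     enc = Encoder(encoder={}, bpe_merges=[('l', 'o')])
    #     enc.bpe('hello')   # -> 'h e l lo'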

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Map each utf-8 byte to its printable unicode stand-in before applying BPE.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(model_name, models_dir):
    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # The first line of vocab.bpe is a version header; the last line is empty.
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
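
For reference, a minimal usage sketch (editor's note, not part of the committed file). It assumes encoder.json and vocab.bpe have already been downloaded into models/124M, the layout used by the upstream openai/gpt-2 repo; the model name and directory here are illustrative:

    enc = get_encoder('124M', 'models')
    ids = enc.encode("Hello world")           # list of integer token ids
    assert enc.decode(ids) == "Hello world"   # byte-level BPE round-trips losslessly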