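"""Byte-pair encoding (BPE) tokenizer: base Tokenizer class and helpers."""
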
import unicodedata
from collections import defaultdict
from itertools import pairwise


def get_adjacent_pair_counts(ids) -> defaultdict:
    """Count how often each adjacent pair of token ids occurs in `ids`."""
    counts = defaultdict(int)
    for pair in pairwise(ids):
        counts[pair] += 1
    return counts
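
# For example:
#   >>> get_adjacent_pair_counts([1, 2, 3, 1, 2])
#   defaultdict(<class 'int'>, {(1, 2): 2, (2, 3): 1, (3, 1): 1})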


def merge_pairs(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token id `idx`."""
    newids = []
    i = 0
    n = len(ids)
    while i < n:
        # if this position starts an occurrence of `pair`, emit `idx` instead
        if i < n - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
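
# For example, merging pair (1, 2) into new token id 4:
#   >>> merge_pairs([1, 2, 3, 1, 2], (1, 2), 4)
#   [4, 3, 4]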


def replace_control_characters(s: str) -> str:
    """Escape control characters (Unicode category 'C'), which would
    otherwise distort rendered output (e.g. a raw newline)."""
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != 'C':
            chars.append(ch)
        else:
            chars.append(f'\\u{ord(ch):04x}')  # escape
    return ''.join(chars)


def render_token(t: bytes) -> str:
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s
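
# For example, render_token(b'hi\n') returns 'hi\\u000a': the raw newline is
# replaced by a visible escape so the .vocab file stays one token per line.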


class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        self.merges = {}          # (int, int) -> int
        self.pattern = ''         # optional split pattern, stored in the model file
        self.special_tokens = {}  # str -> int
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def decode(self, ids):
        raise NotImplementedError

    def _build_vocab(self):
        # the vocab is deterministically derived from merges and special tokens
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode('utf-8')
        return vocab
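
    # For example, with self.merges == {(104, 105): 256} the derived vocab
    # maps 256 -> b'h' + b'i' == b'hi'.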

    def save(self, file_prefix):
        # similar to sentencepiece model saving
        model_file = file_prefix + '.model'
        with open(model_file, 'w', encoding='utf-8') as f:
            f.write('xsbpe v1\n')
            f.write(f'{self.pattern}\n')
            f.write(f'{len(self.special_tokens)}\n')
            for special, idx in self.special_tokens.items():
                f.write(f'{special} {idx}\n')
            for idx1, idx2 in self.merges:
                f.write(f'{idx1} {idx2}\n')
        # the vocab file is meant for human inspection only
        vocab_file = file_prefix + '.vocab'
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, 'w', encoding='utf-8') as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    # this token was produced by a merge: show its two parts
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f'[{s0}][{s1}] -> [{s}] {idx}\n')
                else:
                    # leaf token: one of the 256 raw bytes or a special token
                    f.write(f'[{s}] {idx}\n')

    def load(self, model_file):
        # inverse of save(), but only for the .model file
        assert model_file.endswith('.model')
        merges = {}
        special_tokens = {}
        idx = 256  # merged token ids start after the 256 raw bytes
        with open(model_file, 'r', encoding='utf-8') as f:
            version = f.readline().strip()
            assert version == 'xsbpe v1'
            self.pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
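

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only; the sample text is arbitrary):
    # run one BPE merge step by hand with the helpers above.
    ids = list('aaabdaaabac'.encode('utf-8'))
    counts = get_adjacent_pair_counts(ids)
    top_pair = max(counts, key=counts.get)  # most frequent adjacent pair
    ids = merge_pairs(ids, top_pair, 256)   # replace it with new token id 256
    print(top_pair, ids)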