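"""Byte-level BPE tokenizer building blocks for xsbpe: adjacent-pair
counting, pair merging, token-rendering helpers, and a Tokenizer base
class with save/load in the 'xsbpe v1' model format."""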
import unicodedata
from collections import defaultdict
from itertools import pairwise


def get_adjacent_pair_counts(ids) -> defaultdict:
    """Count how often each adjacent (left, right) pair occurs in ids."""
    counts = defaultdict(int)
    for pair in pairwise(ids):  # itertools.pairwise requires Python 3.10+
        counts[pair] += 1
    return counts
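
# Illustration (made-up ids):
#   >>> get_adjacent_pair_counts([1, 2, 3, 1, 2])
#   defaultdict(<class 'int'>, {(1, 2): 2, (2, 3): 1, (3, 1): 1})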


def merge_pairs(ids, pair, idx):
    """Replace every occurrence of pair in ids with the new token idx."""
    newids = []
    i = 0
    n = len(ids)
    while i < n:
        # If the target pair starts here, emit the merged token and skip both.
        if i < n - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
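
# Illustration (made-up ids): merge every (1, 2) into a new token 256.
#   >>> merge_pairs([1, 2, 3, 1, 2], (1, 2), 256)
#   [256, 3, 256]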


def replace_control_characters(s: str) -> str:
    """Escape control characters (Unicode category C*) so they don't distort output."""
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != 'C':
            chars.append(ch)  # printable character, keep as-is
        else:
            chars.append(f'\\u{ord(ch):04x}')  # escape, e.g. '\n' -> '\u000a'
    return ''.join(chars)
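
# Illustration: control characters are escaped rather than printed raw.
#   >>> replace_control_characters('hi\nthere')
#   'hi\\u000athere'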


def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes: decode UTF-8 lossily, then escape control chars."""
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s
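
# Illustration: bytes that are not valid UTF-8 render as U+FFFD, e.g.
#   render_token(b'hello\xff') -> 'hello\ufffd'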


class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        self.merges = {}          # (int, int) -> int
        self.pattern = ''         # optional split pattern; saved to the model file, unused by the base class
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 256}
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def decode(self, ids):
        raise NotImplementedError

    def _build_vocab(self):
        # The base vocabulary is the 256 raw bytes; merges and special
        # tokens are layered on top of it, in order.
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode('utf-8')
        return vocab
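
    # Example of how merges compose (hypothetical values): with
    # merges == {(104, 105): 256}, where 104 == ord('h') and 105 == ord('i'),
    # _build_vocab() yields vocab[256] == vocab[104] + vocab[105] == b'hi'.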

    def save(self, file_prefix):
        # Similar to sentencepiece model saving; the .model file is what load() reads
        model_file = file_prefix + '.model'
        with open(model_file, 'w') as f:
            f.write('xsbpe v1\n')
            f.write(f'{self.pattern}\n')
            f.write(f'{len(self.special_tokens)}\n')
            for special, idx in self.special_tokens.items():
                f.write(f'{special} {idx}\n')
            # merges are written in creation order; load() relies on this
            # to reassign the same indices starting at 256
            for idx1, idx2 in self.merges:
                f.write(f'{idx1} {idx2}\n')
        # vocab file meant for human inspection only
        vocab_file = file_prefix + '.vocab'
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, 'w', encoding='utf-8') as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    # merged token: also render its two children
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f'[{s0}][{s1}] -> [{s}] {idx}\n')
                else:
                    # leaf token (one of the first 256 bytes)
                    f.write(f'[{s}] {idx}\n')
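
    # Resulting .model layout (schematic):
    #   xsbpe v1
    #   <pattern>
    #   <number of special tokens>
    #   <special> <idx>     one line per special token
    #   <idx1> <idx2>       one line per merge, in creation order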
    def load(self, model_file):
        assert model_file.endswith('.model')
        merges = {}
        special_tokens = {}
        idx = 256  # merge indices restart right after the 256 byte tokens
        with open(model_file, 'r', encoding='utf-8') as f:
            version = f.readline().strip()
            assert version == 'xsbpe v1'
            self.pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # remaining lines are merges, in the order they were created
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
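

if __name__ == '__main__':
    # Minimal save/load round-trip smoke test. Illustrative only: the file
    # prefix 'tmp_xsbpe' and the special token below are made-up values,
    # and the base class has no trained merges, so this exercises just the
    # 256-byte base vocab plus the save/load plumbing.
    tok = Tokenizer()
    tok.special_tokens = {'<|endoftext|>': 256}
    tok.vocab = tok._build_vocab()
    tok.save('tmp_xsbpe')

    loaded = Tokenizer()
    loaded.load('tmp_xsbpe.model')
    assert loaded.special_tokens == tok.special_tokens
    assert loaded.vocab == tok.vocab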