import json
import re
import time
from typing import Iterable, List, Optional

import numba


def sample_vocab(tokens: Iterable[str], vocab_size: Optional[int] = None,
                 vocab_coverage: Optional[float] = None) -> List[str]:
    """Pick the most frequent tokens, either a fixed number of them (vocab_size)
    or as many as needed to cover a fraction of all token occurrences (vocab_coverage)."""
    assert (vocab_size is not None and vocab_coverage is None) or \
           (vocab_size is None and vocab_coverage is not None), \
        'exactly one of vocab_size or vocab_coverage must be specified'
    token_count = {}
    for c in tokens:
        token_count[c] = token_count.get(c, 0) + 1
    if vocab_size is not None:
        # keep the vocab_size most frequent tokens
        token_count = list(token_count.items())
        token_count.sort(key=lambda i: i[1], reverse=True)
        vocab = [c[0] for c in token_count[:vocab_size]]
    else:
        # keep the most frequent tokens until their cumulative frequency reaches vocab_coverage
        total_count = sum(token_count.values())
        token_freq = [(c, i / total_count) for c, i in token_count.items()]
        token_freq.sort(key=lambda i: i[1], reverse=True)
        freq_sum = 0.0
        split = 0
        for split in range(len(token_freq)):
            freq_sum += token_freq[split][1]
            if freq_sum >= vocab_coverage:
                break
        vocab = [c[0] for c in token_freq[:split + 1]]
    return vocab
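

# Usage sketch (illustrative, the corpus string and numbers below are made up): the same
# helper can be driven either by a fixed vocabulary size or by a cumulative-frequency target.
# >>> sample_vocab('hello world', vocab_size=3)        # the 3 most frequent characters
# >>> sample_vocab('hello world', vocab_coverage=0.9)  # smallest set covering >= 90% of occurrences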


class CharTokenizer:
    def __init__(self, corpus: str, vocab_size: Optional[int] = None, vocab_coverage: Optional[float] = None,
                 reserved_vocab: Optional[List[str]] = None, unk_literal: str = '<unk>'):
        if reserved_vocab is not None:
            assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
            assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
        else:
            reserved_vocab = []
        vocab = reserved_vocab.copy()
        # reserve one slot for the unk token; when sampling by coverage, vocab_size stays None
        vocab += sample_vocab(corpus, vocab_size - len(vocab) - 1 if vocab_size is not None else None,
                              vocab_coverage)
        self.s2i = {s: i + 1 for i, s in enumerate(vocab)}  # id 0 is reserved for unk
        self.s2i[unk_literal] = 0
        self.i2s = {i: s for s, i in self.s2i.items()}
        self.special_vocab = set(reserved_vocab + [unk_literal])
        self.unk_literal = unk_literal

    def encode(self, text: str) -> List[int]:
        cursor, ids = 0, []
        while cursor < len(text):
            # try to match a multi-character special token first
            for s in self.special_vocab:
                if text[cursor:].startswith(s):
                    ids.append(self.s2i[s])
                    cursor += len(s)
                    break
            else:
                # otherwise encode a single character, falling back to unk
                ids.append(self.s2i.get(text[cursor], self.s2i[self.unk_literal]))
                cursor += 1
        return ids

    def decode(self, ids: List[int]) -> str:
        return ''.join(self.i2s[i] for i in ids)

    def get_vocab_mapping(self):
        return self.s2i
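

# Usage sketch (illustrative, the toy corpus is made up): build a character-level tokenizer
# and round-trip a string. Characters not in the sampled vocabulary map to the unk id (0).
# >>> tok = CharTokenizer('hello world', vocab_coverage=1.0)
# >>> ids = tok.encode('hello')
# >>> tok.decode(ids)
# 'hello'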


class WordTokenizer:
    def __init__(self, corpus: str, vocab_size: Optional[int] = None, vocab_coverage: Optional[float] = None,
                 reserved_vocab: Optional[List[str]] = None, unk_literal: str = '<unk>'):
        if reserved_vocab is not None:
            assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
            assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
        else:
            reserved_vocab = []
        vocab = reserved_vocab.copy()
        # split the corpus into word tokens (\w+) and single non-word characters (\W)
        tokens = (m.group(0) for m in re.finditer(r'\w+|\W', corpus))
        # reserve one slot for the unk token; when sampling by coverage, vocab_size stays None
        vocab += sample_vocab(tokens, vocab_size - len(vocab) - 1 if vocab_size is not None else None,
                              vocab_coverage)
        self.s2i = {s: i + 1 for i, s in enumerate(vocab)}  # id 0 is reserved for unk
        self.s2i[unk_literal] = 0
        self.i2s = {i: s for s, i in self.s2i.items()}
        self.special_vocab = set(reserved_vocab + [unk_literal])
        self.unk_literal = unk_literal

    def encode(self, text: str) -> List[int]:
        # match special tokens first, then words, then single non-word characters
        specials = '|'.join(re.escape(s) for s in self.special_vocab)
        tokens = (m.group(0) for m in re.finditer(rf'{specials}|\w+|\W', text))
        return [self.s2i.get(t, self.s2i[self.unk_literal]) for t in tokens]

    def decode(self, ids: List[int]) -> str:
        return ''.join(self.i2s[i] for i in ids)

    def get_vocab_mapping(self):
        return self.s2i

    def get_vocab_size(self):
        return len(self.s2i)

    def eval_vocab_coverage(self, corpus: str):
        # fraction of tokens in the corpus that are not mapped to unk (id 0)
        encoded = self.encode(corpus)
        return 1 - (len([i for i in encoded if i == 0]) / len(encoded))
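

# Usage sketch (illustrative, the toy corpus is made up): word-level tokenization with a
# coverage target; eval_vocab_coverage reports the fraction of tokens in a held-out text
# that do not fall back to <unk>.
# >>> tok = WordTokenizer('the cat sat on the mat', vocab_coverage=1.0)
# >>> tok.decode(tok.encode('the cat sat'))
# 'the cat sat'
# >>> tok.eval_vocab_coverage('the dog sat')  # 'dog' is unseen, so coverage < 1.0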


class TRIETokenizer:
    @staticmethod
    def split_bytes(data: bytes):
        return [b'%c' % i for i in data]

    def __init__(self, vocab_file: str):
        # each node is (node value, parent index, token id, children); children maps a byte
        # value to the next node index, and -1 means "no child" / "no token"
        self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]
        with open(vocab_file, 'r') as file:
            vocabs = json.load(file)
        # add shorter entries first so that an entry which is a prefix of a longer one still
        # gets its token id assigned (add_vocab only assigns ids to newly created nodes)
        vocabs.sort(key=lambda i: len(i['bytes']))
        for entry in vocabs:
            self.add_vocab(bytes(entry['bytes']), entry['id'])
        self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}

    def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
        # walk the trie byte by byte, creating missing nodes; the final node gets the token id
        cur_node_idx = 0
        for i, b in enumerate(vocab_bytes):
            cur_node = self.nodes[cur_node_idx]
            if cur_node[3][b] != -1:
                cur_node_idx = cur_node[3][b]
            else:
                new_node_idx = len(self.nodes)
                self.nodes.append((vocab_bytes, cur_node_idx, vocab_id if i == len(vocab_bytes) - 1 else -1,
                                   [-1 for _ in range(256)]))
                cur_node[3][b] = new_node_idx
                cur_node_idx = new_node_idx

    def attempt_match(self, match_bytes: bytes):
        # greedy longest match: follow the trie as far as possible and remember the deepest
        # node that carries a token id
        match_length, match_token_id = -1, -1
        cur_node_idx, depth = 0, 0
        for i, b in enumerate(match_bytes):
            match_node_idx = self.nodes[cur_node_idx][3][b]
            if match_node_idx == -1:
                break
            cur_node = self.nodes[match_node_idx]
            if cur_node[2] != -1:
                match_length = depth
                match_token_id = cur_node[2]
            cur_node_idx = match_node_idx
            depth += 1
        return match_length, match_token_id

    def encode(self, text: str):
        text_bytes = text.encode('utf-8')
        tokens, length = [], 0
        while length < len(text_bytes):
            offset, token_id = self.attempt_match(text_bytes[length:])
            assert offset >= 0, f'no vocab entry matches byte {text_bytes[length]} at position {length}'
            tokens.append(token_id)
            length += offset + 1
        return tokens

    def decode(self, token_ids: List[int]):
        return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')

    def get_vocab_size(self):
        return len(self.id_to_bytes)
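

# Usage sketch (illustrative): the vocab file is assumed to be a JSON list of entries of the
# form {"bytes": [<byte values>], "id": <token id>}, matching how it is read above. The tiny
# vocabulary below is made up for demonstration; every byte of the encoded text must be
# reachable through the trie, otherwise encode() fails its assertion.
# >>> import json, tempfile
# >>> vocab = [{'bytes': list(b'h'), 'id': 0}, {'bytes': list(b'i'), 'id': 1},
# ...          {'bytes': list(b'hi'), 'id': 2}]
# >>> with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
# ...     json.dump(vocab, f)
# >>> tok = TRIETokenizer(f.name)
# >>> tok.encode('hi')  # greedy longest match
# [2]
# >>> tok.decode([2])
# 'hi'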


# Module-level variants of TRIETokenizer.attempt_match / TRIETokenizer.encode that take the
# node table as an argument. TRIETokenizerFast passes a numba.typed.List here (presumably so
# these helpers can be JIT-compiled, e.g. with numba.njit), but they also run as plain Python.
def trie_attempt_match_jit(trie_nodes, match_bytes: bytes):
    match_length, match_token_id = -1, -1
    cur_node_idx, depth = 0, 0
    for i, b in enumerate(match_bytes):
        match_node_idx = trie_nodes[cur_node_idx][3][int(b)]
        if match_node_idx == -1:
            break
        cur_node = trie_nodes[match_node_idx]
        if cur_node[2] != -1:
            match_length = depth
            match_token_id = cur_node[2]
        cur_node_idx = match_node_idx
        depth += 1
    return match_length, match_token_id


def trie_encode_jit(trie_nodes, text_bytes: bytes):
    tokens, length = [], 0
    while length < len(text_bytes):
        offset, token_id = trie_attempt_match_jit(trie_nodes, text_bytes[length:])
        assert offset >= 0, f'no vocab entry matches byte {text_bytes[length]} at position {length}'
        tokens.append(token_id)
        length += offset + 1
    return tokens
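

# Illustrative note: the helpers above only need an indexable node table and raw bytes, so
# they can also be exercised against a plain TRIETokenizer's node list (the path below is a
# placeholder for a real vocab file):
# >>> plain = TRIETokenizer('vocab.json')
# >>> trie_encode_jit(plain.nodes, 'some text'.encode('utf-8'))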


class TRIETokenizerFast:
    def __init__(self, vocab_file: str):
        # same trie construction as TRIETokenizer; see the comments there
        # each node is (node value, parent index, token id, children)
        self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]
        with open(vocab_file, 'r') as file:
            vocabs = json.load(file)
        vocabs.sort(key=lambda i: len(i['bytes']))
        for entry in vocabs:
            self.add_vocab(bytes(entry['bytes']), entry['id'])
        self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}
        # convert the node table to a numba typed list for the module-level helpers
        self.nodesJit = numba.typed.List(self.nodes)

    def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
        cur_node_idx = 0
        for i, b in enumerate(vocab_bytes):
            cur_node = self.nodes[cur_node_idx]
            if cur_node[3][b] != -1:
                cur_node_idx = cur_node[3][b]
            else:
                new_node_idx = len(self.nodes)
                self.nodes.append((vocab_bytes, cur_node_idx, vocab_id if i == len(vocab_bytes) - 1 else -1,
                                   [-1 for _ in range(256)]))
                cur_node[3][b] = new_node_idx
                cur_node_idx = new_node_idx

    def encode(self, text: str):
        return trie_encode_jit(self.nodesJit, text.encode('utf-8'))

    def decode(self, token_ids: List[int]):
        return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')

    def get_vocab_size(self):
        return len(self.id_to_bytes)


# if __name__ == '__main__':
#     tokenizer = TRIETokenizerFast('llama_vocab_pruned_20k.json')
#     with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
#         text = file.read()[:10240]
#
#     total_tokens = 0
#     s = time.time()
#     for i in range(1000):
#         encoded = tokenizer.encode(text)
#         total_tokens += len(encoded)
#         print(len(encoded))
#     e = time.time()
#     print(f'{e - s:.3f} secs, {total_tokens / (e - s):.3f} tps')