import time
import re
import json
from typing import Iterable, List, Optional

import numba


def sample_vocab(tokens: Iterable[str],
                 vocab_size: Optional[int] = None,
                 vocab_coverage: Optional[float] = None) -> List[str]:
    assert (vocab_size is not None) != (vocab_coverage is not None), \
        'exactly one of vocab_size or vocab_coverage must be specified'
    token_count = {}
    for c in tokens:
        token_count[c] = token_count.get(c, 0) + 1
    if vocab_size is not None:
        # keep the vocab_size most frequent tokens
        token_count = list(token_count.items())
        token_count.sort(key=lambda i: i[1], reverse=True)
        vocab = [c[0] for c in token_count[:vocab_size]]
    else:
        # keep the most frequent tokens until the requested coverage is reached
        total_count = sum(token_count.values())
        token_freq = [(c, i / total_count) for c, i in token_count.items()]
        token_freq.sort(key=lambda i: i[1], reverse=True)
        freq_sum = 0.0
        split = 0
        for split in range(len(token_freq)):
            freq_sum += token_freq[split][1]
            if freq_sum >= vocab_coverage:
                break
        vocab = [c[0] for c in token_freq[:split + 1]]
    return vocab


class CharTokenizer:
    def __init__(self, corpus: str,
                 vocab_size: Optional[int] = None,
                 vocab_coverage: Optional[float] = None,
                 reserved_vocab: Optional[List[str]] = None,
                 unk_literal: str = ''):
        if reserved_vocab is not None:
            assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
            assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
        else:
            reserved_vocab = []
        vocab = reserved_vocab.copy()
        # reserve one slot for the unk token (id 0)
        vocab += sample_vocab(corpus,
                              vocab_size - len(vocab) - 1 if vocab_size is not None else None,
                              vocab_coverage)
        self.s2i = {s: i + 1 for i, s in enumerate(vocab)}
        self.s2i[unk_literal] = 0
        self.i2s = {i: s for s, i in self.s2i.items()}
        self.special_vocab = set(reserved_vocab + [unk_literal])
        self.unk_literal = unk_literal

    def encode(self, text: str) -> List[int]:
        cursor, ids = 0, []
        # try longer reserved literals first; the (possibly empty) unk literal is
        # skipped so the cursor always advances
        specials = sorted((s for s in self.special_vocab if s), key=len, reverse=True)
        while cursor < len(text):
            for s in specials:
                if text[cursor:].startswith(s):
                    ids.append(self.s2i[s])
                    cursor += len(s)
                    break
            else:
                ids.append(self.s2i.get(text[cursor], self.s2i[self.unk_literal]))
                cursor += 1
        return ids

    def decode(self, ids: List[int]) -> str:
        return ''.join(self.i2s[i] for i in ids)

    def get_vocab_mapping(self):
        return self.s2i


class WordTokenizer:
    def __init__(self, corpus: str,
                 vocab_size: Optional[int] = None,
                 vocab_coverage: Optional[float] = None,
                 reserved_vocab: Optional[List[str]] = None,
                 unk_literal: str = ''):
        if reserved_vocab is not None:
            assert len(reserved_vocab) == len(set(reserved_vocab)), 'no duplicate is allowed in reserved vocab'
            assert unk_literal not in reserved_vocab, f'unk literal "{unk_literal}" cannot be in reserved vocab'
        else:
            reserved_vocab = []
        vocab = reserved_vocab.copy()
        # split the corpus into words (\w+) and single non-word characters (\W)
        tokens = (c[1] if c[1] else c[2] for c in re.finditer(r'(\w+)|(\W)', corpus))
        # reserve one slot for the unk token (id 0)
        vocab += sample_vocab(tokens,
                              vocab_size - len(vocab) - 1 if vocab_size is not None else None,
                              vocab_coverage)
        self.s2i = {s: i + 1 for i, s in enumerate(vocab)}
        self.s2i[unk_literal] = 0
        self.i2s = {i: s for s, i in self.s2i.items()}
        self.special_vocab = set(reserved_vocab + [unk_literal])
        self.unk_literal = unk_literal

    def encode(self, text: str) -> List[int]:
        # match reserved literals first (longest first), then words, then single
        # non-word characters; the empty unk literal is excluded from the pattern
        specials = '|'.join(re.escape(s) for s in sorted(self.special_vocab, key=len, reverse=True) if s)
        pattern = rf'({specials}|\w+)|(\W)' if specials else r'(\w+)|(\W)'
        tokens = (c[1] if c[1] else c[2] for c in re.finditer(pattern, text))
        return [self.s2i.get(t, self.s2i[self.unk_literal]) for t in tokens]

    def decode(self, ids: List[int]) -> str:
        return ''.join(self.i2s[i] for i in ids)

    def get_vocab_mapping(self):
        return self.s2i

    def get_vocab_size(self):
        return len(self.s2i)

    def eval_vocab_coverage(self, corpus: str):
        # fraction of tokens in the corpus that do not map to the unk id (0)
        encoded = self.encode(corpus)
        return 1 - (len([i for i in encoded if i == 0]) / len(encoded))
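
# Illustrative usage sketch for the frequency-based tokenizers. The corpus path is
# the one used in the benchmark at the bottom of this file; the parameter values and
# the '<unk>' literal are made-up examples, not prescribed defaults.
#
# if __name__ == '__main__':
#     with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
#         corpus = file.read()
#     char_tok = CharTokenizer(corpus, vocab_coverage=0.999, unk_literal='<unk>')
#     word_tok = WordTokenizer(corpus, vocab_size=8192, unk_literal='<unk>')
#     ids = word_tok.encode('Once upon a time')
#     print(word_tok.decode(ids))
#     print(word_tok.eval_vocab_coverage(corpus))
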
class TRIETokenizer:
    @staticmethod
    def split_bytes(data: bytes):
        return [b'%c' % i for i in data]

    def __init__(self, vocab_file: str):
        # each node is (node value, parent index, token id, children); token id is -1
        # for intermediate nodes, and children maps a byte value to a node index
        self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]
        with open(vocab_file, 'r') as file:
            vocabs = json.load(file)
        # insert shorter vocab entries first so prefixes get their token ids
        vocabs.sort(key=lambda i: len(i['bytes']))
        for entry in vocabs:
            self.add_vocab(bytes(entry['bytes']), entry['id'])
        self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}

    def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
        cur_node_idx = 0
        for i, b in enumerate(vocab_bytes):
            cur_node = self.nodes[cur_node_idx]
            if cur_node[3][b] != -1:
                cur_node_idx = cur_node[3][b]
            else:
                new_node_idx = len(self.nodes)
                self.nodes.append((vocab_bytes[:i + 1], cur_node_idx,
                                   vocab_id if i == len(vocab_bytes) - 1 else -1,
                                   [-1 for _ in range(256)]))
                cur_node[3][b] = new_node_idx
                cur_node_idx = new_node_idx

    def attempt_match(self, match_bytes: bytes):
        # greedy longest-prefix match: walk the trie and remember the last node
        # that carries a token id
        match_length, match_token_id = -1, -1
        cur_node_idx = 0
        for i, b in enumerate(match_bytes):
            match_node_idx = self.nodes[cur_node_idx][3][b]
            if match_node_idx == -1:
                break
            cur_node = self.nodes[match_node_idx]
            if cur_node[2] != -1:
                match_length = i
                match_token_id = cur_node[2]
            cur_node_idx = match_node_idx
        return match_length, match_token_id

    def encode(self, text: str):
        text_bytes = text.encode('utf-8')
        tokens, length = [], 0
        while length < len(text_bytes):
            offset, token_id = self.attempt_match(text_bytes[length:])
            # every byte value is expected to be in the vocab, so a match must exist
            assert offset >= 0
            tokens.append(token_id)
            length += offset + 1
        return tokens

    def decode(self, token_ids: List[int]):
        return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')

    def get_vocab_size(self):
        return len(self.id_to_bytes)


@numba.njit
def trie_attempt_match_jit(trie_nodes, match_bytes: bytes):
    match_length, match_token_id = -1, -1
    cur_node_idx = 0
    for i, b in enumerate(match_bytes):
        match_node_idx = trie_nodes[cur_node_idx][3][int(b)]
        if match_node_idx == -1:
            break
        cur_node = trie_nodes[match_node_idx]
        if cur_node[2] != -1:
            match_length = i
            match_token_id = cur_node[2]
        cur_node_idx = match_node_idx
    return match_length, match_token_id


@numba.njit
def trie_encode_jit(trie_nodes, text_bytes: bytes):
    tokens, length = [], 0
    while length < len(text_bytes):
        offset, token_id = trie_attempt_match_jit(trie_nodes, text_bytes[length:])
        assert offset >= 0
        tokens.append(token_id)
        length += offset + 1
    return tokens


class TRIETokenizerFast:
    def __init__(self, vocab_file: str):
        # same node layout as TRIETokenizer: (node value, parent index, token id, children)
        self.nodes = [(b'', -1, -1, [-1 for _ in range(256)])]
        with open(vocab_file, 'r') as file:
            vocabs = json.load(file)
        vocabs.sort(key=lambda i: len(i['bytes']))
        for entry in vocabs:
            self.add_vocab(bytes(entry['bytes']), entry['id'])
        self.id_to_bytes = {i['id']: i['bytes'] for i in vocabs}
        # copy the nodes into a numba typed list so the jitted functions can use them
        self.nodesJit = numba.typed.List(self.nodes)

    def add_vocab(self, vocab_bytes: bytes, vocab_id: int):
        cur_node_idx = 0
        for i, b in enumerate(vocab_bytes):
            cur_node = self.nodes[cur_node_idx]
            if cur_node[3][b] != -1:
                cur_node_idx = cur_node[3][b]
            else:
                new_node_idx = len(self.nodes)
                self.nodes.append((vocab_bytes[:i + 1], cur_node_idx,
                                   vocab_id if i == len(vocab_bytes) - 1 else -1,
                                   [-1 for _ in range(256)]))
                cur_node[3][b] = new_node_idx
                cur_node_idx = new_node_idx

    def encode(self, text: str):
        return trie_encode_jit(self.nodesJit, text.encode('utf-8'))

    def decode(self, token_ids: List[int]):
        return bytes([t for i in token_ids for t in self.id_to_bytes[i]]).decode('utf-8', errors='replace')

    def get_vocab_size(self):
        return len(self.id_to_bytes)
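
# The trie tokenizers expect vocab_file to be a JSON array of entries of the form
# {"id": <token id>, "bytes": [<utf-8 byte values>]}; this schema is inferred from how
# the file is read above. A minimal round trip with a toy, in-memory vocab (hypothetical,
# not a real vocab file) might look like this:
#
# if __name__ == '__main__':
#     import os
#     import tempfile
#     toy_vocab = [{'id': i, 'bytes': [i]} for i in range(256)]                   # byte fallback entries
#     toy_vocab.append({'id': 256, 'bytes': list('hello'.encode('utf-8'))})       # one merged token
#     with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
#         json.dump(toy_vocab, f)
#     tok = TRIETokenizer(f.name)
#     ids = tok.encode('hello world')
#     print(ids)                 # 'hello' is matched greedily as the single id 256
#     print(tok.decode(ids))     # 'hello world'
#     os.unlink(f.name)
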

# if __name__ == '__main__':
#     tokenizer = TRIETokenizerFast('llama_vocab_pruned_20k.json')
#     with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
#         text = file.read()[:10240]
#
#     total_tokens = 0
#     s = time.time()
#     for i in range(1000):
#         encoded = tokenizer.encode(text)
#         total_tokens += len(encoded)
#     print(len(encoded))
#     e = time.time()
#     print(f'{e - s:.3f} secs, {total_tokens / (e - s):.3f} tps')
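
# A rough side-by-side of the plain and jitted tries could follow the same pattern
# (same assumed vocab and corpus paths as above; the first TRIETokenizerFast.encode
# call also pays numba's one-off compilation cost, so warm up before timing):
#
# if __name__ == '__main__':
#     with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
#         text = file.read()[:10240]
#     for cls in (TRIETokenizer, TRIETokenizerFast):
#         tokenizer = cls('llama_vocab_pruned_20k.json')
#         tokenizer.encode(text)  # warm-up (triggers JIT compilation for the fast variant)
#         s = time.time()
#         total_tokens = sum(len(tokenizer.encode(text)) for _ in range(100))
#         e = time.time()
#         print(f'{cls.__name__}: {e - s:.3f} secs, {total_tokens / (e - s):.3f} tps')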