""" Minimal (byte-level) Byte Pair Encoding tokenizer. Unlike RegexTokenizer: - Operates on integer codes from an encodec codebook. """ import regex as re from .base import Tokenizer, get_stats, merge class CodebookTokenizer(Tokenizer): def __init__(self, pattern=None, codebook_size=1024): """ - pattern: optional string to override the default (GPT-4 split pattern) - special_tokens: str -> int dictionary of special tokens example: {'<|endoftext|>': 100257} """ self.merges = {} # (int, int) -> int self.pattern = pattern self.compiled_pattern = re.compile(self.pattern) self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257} self.inverse_special_tokens = {} self.codebook_size = codebook_size self.vocab = self._build_vocab() # int -> bytes def train(self, text, vocab_size, verbose=False): assert vocab_size >= self.codebook_size num_merges = vocab_size - self.codebook_size # split the text up into text chunks # text is a continuous signal, there is no splitting it up. text_chunks = [text,] # re.findall(self.compiled_pattern, text) # input text preprocessing ids = [[int(idx) for idx in ch.split(' ')] for ch in text_chunks] # iteratively merge the most common pairs to create new tokens merges = {} # (int, int) -> int # vocab = {idx: bytes([idx]) for idx in range(self.codebook_size)} # idx -> bytes vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} # idx -> bytes for i in range(num_merges): # count the number of times every consecutive pair appears stats = {} for chunk_ids in ids: # passing in stats will update it in place, adding up counts get_stats(chunk_ids, stats) # find the pair with the highest count pair = max(stats, key=stats.get) # mint a new token: assign it the next available id idx = self.codebook_size + i # replace all occurrences of pair in ids with idx ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] # save the merge merges[pair] = idx vocab[idx] = vocab[pair[0]] + vocab[pair[1]] # prints if verbose: print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") # save class variables self.merges = merges # used in encode() self.vocab = vocab # used in decode() def register_special_tokens(self, special_tokens): # special_tokens is a dictionary of str -> int # example: {"<|endoftext|>": 100257} self.special_tokens = special_tokens self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} def decode(self, ids): # given ids (list of integers), return Python string part_bytes = [] for idx in ids: if idx in self.vocab: part_bytes.append(self.vocab[idx]) elif idx in self.inverse_special_tokens: part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) else: raise ValueError(f"invalid token id: {idx}") text_bytes = b"".join(part_bytes) text = text_bytes.decode("utf-8", errors="replace") return text def decode_int(self, ids) -> list[int]: ret: str = self.decode(ids) for s in self.special_tokens: ret = ret.replace(s, ' ' + s + ' ') ret = ret.strip() ret = [int(t) if t[0].isnumeric() else t for t in ret.split(' ') if len(t) > 0] return ret def _encode_chunk(self, text_bytes): # return the token ids # let's begin. first, convert all bytes to integers in range 0..255 ids = list(text_bytes) while len(ids) >= 2: # find the pair with the lowest merge index stats = get_stats(ids) pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) # subtle: if there are no more merges available, the key will # result in an inf for every single pair, and the min will be # just the first pair in the list, arbitrarily # we can detect this terminating case by a membership check if pair not in self.merges: break # nothing else can be merged anymore # otherwise let's merge the best pair (lowest merge index) idx = self.merges[pair] ids = merge(ids, pair, idx) return ids def encode_ordinary(self, text): """Encoding that ignores any special tokens.""" # split text into chunks of text by categories defined in regex pattern text_chunks = [text,] #re.findall(self.compiled_pattern, text) # all chunks of text are encoded separately, then results are joined ids = [] for chunk in text_chunks: # chunk_bytes = chunk.encode("utf-8") # raw bytes chunk_ids = [int(idx) for idx in chunk.split(' ')] chunk_ids = self._encode_chunk(chunk_ids) ids.extend(chunk_ids) return ids def encode(self, text, allowed_special="none_raise"): """ Unlike encode_ordinary, this function handles special tokens. allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens if none_raise, then an error is raised if any special token is encountered in text this is the default tiktoken behavior right now as well any other behavior is either annoying, or a major footgun """ # decode the user desire w.r.t. handling of special tokens special = None if allowed_special == "all": special = self.special_tokens elif allowed_special == "none": special = {} elif allowed_special == "none_raise": special = {} assert all(token not in text for token in self.special_tokens) elif isinstance(allowed_special, set): special = {k: v for k, v in self.special_tokens.items() if k in allowed_special} else: raise ValueError(f"allowed_special={allowed_special} not understood") if not special: # shortcut: if no special tokens, just use the ordinary encoding return self.encode_ordinary(text) # otherwise, we have to be careful with potential special tokens in text # we handle special tokens by splitting the text # based on the occurrence of any exact match with any of the special tokens # we can use re.split for this. note that surrounding the pattern with () # makes it into a capturing group, so the special tokens will be included special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")" special_chunks = re.split(special_pattern, text) # now all the special characters are separated from the rest of the text # all chunks of text are encoded separately, then results are joined ids = [] for part in special_chunks: part = part.strip() if len(part) == 0: continue if part in special: # this is a special token, encode it separately as a special case ids.append(special[part]) else: # this is an ordinary sequence, encode it normally ids.extend(self.encode_ordinary(part)) return ids def load(self, model_file): """Inverse of save() but only for the model file""" model_file = str(model_file) assert model_file.endswith(".model") # read the model file merges = {} special_tokens = {} idx = self.codebook_size with open(model_file, 'r', encoding="utf-8") as f: # read the version version = f.readline().strip() assert version == "minbpe v1" # read the pattern self.pattern = f.readline().strip() # read the special tokens num_special = int(f.readline().strip()) for _ in range(num_special): special, special_idx = f.readline().strip().split() special_tokens[special] = int(special_idx) # read the merges for line in f: # print(line) idx1, idx2 = map(int, line.split()) merges[(idx1, idx2)] = idx idx += 1 self.merges = merges self.special_tokens = special_tokens self.vocab = self._build_vocab() def _build_vocab(self): # vocab is simply and deterministically derived from merges vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} for (p0, p1), idx in self.merges.items(): vocab[idx] = vocab[p0] + vocab[p1] for special, idx in self.special_tokens.items(): vocab[idx] = special.encode("utf-8") return vocab