""" Basic Byte Pair Encoding (BPE) tokenizer. This is a simple tokenizer that can be trained on a text and then used to encode and decode strings. The training process is based on the BPE algorithm, which merges the most common pairs of tokens iteratively to create new tokens. The tokenizer can be saved to disk and loaded back later. The vocabulary and merges can be inspected in human-readable form. No handling of special tokens or regex patterns. """ from tokenizer.utils import get_stats, merge from tokenizer.base import Tokenizer class BasicTokenizer(Tokenizer): def __init__(self): super().__init__() def train(self, text, vocab_size, verbose=False): assert vocab_size >= 256 num_merges = vocab_size - 256 # input text preprocessing text_bytes = text.encode("utf-8") # raw bytes ids = list(text_bytes) # list of integers in range 0..255 # iteratively merge the most common pairs to create new tokens merges = {} # (int, int) -> int vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes for i in range(num_merges): # count up the number of times every consecutive pair appears stats = get_stats(ids) # find the pair with the highest count pair = max(stats, key=stats.get) # mint a new token: assign it the next available id idx = 256 + i # replace all occurrences of pair in ids with idx ids = merge(ids, pair, idx) # save the merge merges[pair] = idx vocab[idx] = vocab[pair[0]] + vocab[pair[1]] # prints if verbose: print( f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences" ) # save class variables self.merges = merges # used in encode() self.vocab = vocab # used in decode() def decode(self, ids): # given ids (list of integers), return Python string text_bytes = b"".join(self.vocab[idx] for idx in ids) text = text_bytes.decode("utf-8", errors="replace") return text def encode(self, text): # given a string text, return the token ids text_bytes = text.encode("utf-8") # raw bytes ids = list(text_bytes) # list of integers in range 0..255 while len(ids) >= 2: # find the pair with the lowest merge index stats = get_stats(ids) pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) # subtle: if there are no more merges available, the key will # result in an inf for every single pair, and the min will be # just the first pair in the list, arbitrarily # we can detect this terminating case by a membership check if pair not in self.merges: break # nothing else can be merged anymore # otherwise let's merge the best pair (lowest merge index) idx = self.merges[pair] ids = merge(ids, pair, idx) return ids