#!/usr/bin/env python3
"""
Byte Pair Encoding Tokenizer for Indian Languages

A simple implementation of a BPE tokenizer with Marathi-specific preprocessing.

Author: Shilpaj Bhalerao
Date: 2025-01-05
"""
# Standard Library Imports
import json
import re

# Third Party Imports
from tqdm import tqdm


class BPETokenizer:
    """
    Byte Pair Encoding Tokenizer.

    :param vocab_size: Size of the final vocabulary (including the 256 base bytes)
    :param use_regex: Whether to apply the Marathi-specific regex preprocessing

    Attributes:
        merges (dict): Dictionary mapping byte-pair tuples to their merged token IDs
        vocab (dict): Dictionary mapping token IDs to their byte sequences
        inverse_vocab (dict): Dictionary mapping byte sequences to token IDs
    """

    def __init__(self, vocab_size=1000, use_regex=False):
        """
        Initialize the tokenizer with the desired vocabulary size.

        :param vocab_size: Size of the final vocabulary (including the 256 base bytes)
        :param use_regex: If True, run Marathi regex preprocessing before training/encoding
        """
        self.vocab_size = vocab_size
        self.merges = {}
        self.len_of_ids = 0
        self.len_raw_bytes = 0
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.inverse_vocab = {bytes([idx]): idx for idx in range(256)}
        self.use_regex = use_regex

        # Marathi tokenization regex pattern
        self.marathi_regex = re.compile(
            r"([\u0900-\u094F\u0951-\u097F]+|"              # Marathi words and ligatures
            r"[\u0966-\u096F]+|"                            # Marathi numerals (०-९)
            r"\d+(?:\s[\u0900-\u097F]+)?|"                  # Arabic numerals with Marathi context
            r"#[\w\u0900-\u097F]+|"                         # Hashtags
            r"[\w\u0900-\u097F]+[''][\w\u0900-\u097F]+|"    # Compound words with apostrophes
            r"[\w\u0900-\u097F]+(?:-[\w\u0900-\u097F]+)*|"  # Hyphenated words
            r"[\w\u0900-\u097F]+\.[\w\u0900-\u097F]*|"      # Abbreviations
            r'\"[^\"]+\"|\'[^\']+\'|'                       # Quoted text
            r"[\u0964\u0965.!?…]|"                          # Marathi punctuation
            r"[^\s\u0900-\u097F]+)"                         # Non-Marathi symbols
        )

    def preprocess(self, text: str) -> str:
        """
        Preprocess Marathi text before tokenization.

        :param text: Input Marathi text
        :return: Preprocessed text with tokens separated by spaces
        """
        # Find all tokens using the Marathi regex
        tokens = self.marathi_regex.findall(text)

        # Join tokens with spaces
        processed_text = ' '.join(tokens)

        # Normalize whitespace
        processed_text = ' '.join(processed_text.split())
        return processed_text
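
    # Illustrative example (added, not from the original file): with use_regex=True,
    # preprocess() splits raw text into regex-matched tokens joined by single spaces, e.g.
    #   BPETokenizer(use_regex=True).preprocess("मराठी भाषा, #मराठी!")
    #   # -> 'मराठी भाषा , #मराठी !'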

    def _get_stats(self, ids: list[int]) -> dict[tuple[int, int], int]:
        """
        Count frequency of adjacent pairs in the sequence.

        :param ids: List of integers
        :return: Dictionary of pairs and their frequencies
        """
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts
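
    # Worked example (added for clarity): _get_stats([101, 32, 101, 32, 101])
    # returns {(101, 32): 2, (32, 101): 2}, i.e. each adjacent pair with its count.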

    def _merge(self, ids: list[int], pair: tuple[int, int], idx: int) -> list[int]:
        """
        Replace all occurrences of pair with the new token idx.

        :param ids: List of integers
        :param pair: Tuple of two token IDs to merge
        :param idx: New token ID that replaces the pair
        :return: List of integers with the pair merged
        """
        newids = []
        i = 0
        while i < len(ids):
            # Merge only when the pair matches and we are not at the last element
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids
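
    # Worked example (added for clarity): _merge([101, 32, 101, 32], (101, 32), 256)
    # returns [256, 256] -- every non-overlapping occurrence of (101, 32) becomes token 256.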

    def train(self, text: str):
        """
        Train the BPE tokenizer on the given text.

        :param text: Input text to train on
        """
        print("Training BPE tokenizer...")

        # Preprocess text first, if regex preprocessing is enabled
        if self.use_regex:
            text = self.preprocess(text)

        # Convert text to bytes and get the initial tokens
        raw_bytes = text.encode("utf-8")
        raw_bytes = list(map(int, raw_bytes))  # convert to integers
        self.len_raw_bytes = len(raw_bytes)

        # Calculate the number of merges needed
        num_merges = self.vocab_size - 256
        ids = list(raw_bytes)  # copy so we don't destroy the original list

        # Perform merges
        for i in tqdm(range(num_merges)):
            stats = self._get_stats(ids)
            if not stats:
                break

            # Find the most frequent pair and assign it the next token ID
            pair = max(stats, key=stats.get)
            idx = 256 + i

            # Perform the merge
            ids = self._merge(ids, pair, idx)
            self.len_of_ids = len(ids)
            self.merges[pair] = idx

            # Update the vocabulary
            new_token = self.vocab[pair[0]] + self.vocab[pair[1]]
            self.vocab[idx] = new_token
            self.inverse_vocab[new_token] = idx
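
    # Note (added): each successful merge adds exactly one new token, so the trained
    # vocabulary holds the 256 base bytes plus up to (vocab_size - 256) merged tokens,
    # fewer if the loop breaks early because no adjacent pairs remain.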

    def encode(self, text: str) -> list[int]:
        """
        Encode text into token IDs.

        :param text: Text to encode
        :return: List of token IDs
        """
        # Preprocess if needed
        if self.use_regex:
            text = self.preprocess(text)

        # Convert text to a list of byte values
        tokens = list(text.encode("utf-8"))

        # Repeatedly apply the learned merge with the lowest token ID (earliest learned)
        while len(tokens) >= 2:
            stats = self._get_stats(tokens)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            tokens = self._merge(tokens, pair, idx)
        return tokens
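
    # Illustrative example (added; assumes a hypothetical learned merge, not actual
    # trained values): if self.merges == {(104, 105): 256}, then encode("hi") first
    # maps the text to bytes [104, 105] and then merges them, returning [256].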

    def decode(self, ids: list[int]) -> str:
        """
        Decode token IDs back to text.

        :param ids: List of token IDs
        :return: Decoded text
        """
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")

    def token_to_text(self, token_id: int) -> str:
        """
        Convert a single token ID to its text representation.

        :param token_id: Token ID
        :return: Text representation of the token
        """
        return self.vocab[token_id].decode("utf-8", errors="replace")

    def save(self, path: str):
        """
        Save the tokenizer state to a file.

        :param path: Path to save the file
        """
        state = {
            'vocab_size': self.vocab_size,
            'merges': list(self.merges.items()),  # convert to a list of (pair, idx) entries
            'vocab': {k: list(v) for k, v in self.vocab.items()}  # convert bytes to lists of ints
        }
        with open(path, 'w') as f:
            json.dump(state, f)

    @classmethod
    def load(cls, path: str):
        """
        Load the tokenizer state from a file.

        :param path: Path to load the file from
        :return: Loaded tokenizer
        """
        with open(path, 'r') as f:
            state = json.load(f)
        tokenizer = cls(vocab_size=state['vocab_size'])

        # Convert lists back to tuples for the merge pairs
        tokenizer.merges = {tuple(k): v for k, v in state['merges']}
        tokenizer.vocab = {int(k): bytes(v) for k, v in state['vocab'].items()}
        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
        return tokenizer
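
    # Note (added): encode() ranks candidate merges by their stored token IDs, so a
    # tokenizer reloaded from JSON reproduces the same merge priority as the one that
    # was trained and saved.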

    def get_vocab_size(self) -> int:
        """
        Get the size of the vocabulary.

        :return: Size of the vocabulary
        """
        return len(self.vocab)

    def get_compression_ratio(self, text: str) -> float:
        """
        Get the compression ratio measured during training.

        :param text: Input text (only preprocessed here for consistency; the ratio
                     itself uses the lengths recorded during training)
        :return: Compression ratio (raw byte length / merged token length)
        """
        # Preprocess if needed
        if self.use_regex:
            text = self.preprocess(text)
        return round(self.len_raw_bytes / self.len_of_ids, 4)

    def get_token_length(self, text: str) -> int:
        """
        Get the raw byte length of the training text (before merges).

        :param text: Input text (unused; the length is recorded during training)
        :return: Number of raw bytes seen during training
        """
        return self.len_raw_bytes

    def get_ids_length(self, text: str) -> int:
        """
        Get the number of tokens after BPE merges on the training text.

        :param text: Input text (unused; the length is recorded during training)
        :return: Number of token IDs after merging
        """
        return self.len_of_ids

    def is_encoded_equals_decoded(self, text: str) -> bool:
        """
        Check if encoding and decoding are consistent.

        Note: with use_regex=True the decoded text is the preprocessed form,
        so it may legitimately differ from the raw input.

        :param text: Input text
        :return: True if consistent, False otherwise
        """
        encoded = self.encode(text)
        decoded = self.decode(encoded)
        return text == decoded


if __name__ == "__main__":
    # Read text from file (UTF-8 so Devanagari text is handled correctly)
    with open("dataset.txt", "r", encoding="utf-8") as file:
        text = file.read()

    # Initialize and train
    tokenizer = BPETokenizer(vocab_size=3000)
    tokenizer.train(text)

    # Save and load
    tokenizer.save("tokenizer.json")
    loaded_tokenizer = BPETokenizer.load("tokenizer.json")

    # Encode with the trained tokenizer and decode with the loaded one
    sample = "या पुतळ्याच्या डोक्यावर अज्ञातांनी चप्पल ठेवल्याचे आढळून आले आहे."
    encoded = tokenizer.encode(sample)
    decoded = loaded_tokenizer.decode(encoded)

    # Check consistency
    print("Does the loaded tokenizer decode back to the original sentence?", decoded == sample)

    # Print statistics
    print(f"Vocab size: {tokenizer.get_vocab_size()}")
    print(f"Token length: {tokenizer.get_token_length(text)}")
    print(f"Ids length: {tokenizer.get_ids_length(text)}")
    print(f"Compression ratio: {tokenizer.get_compression_ratio(text)}X")