import gradio as gr
import regex as re
from tqdm import tqdm
import pickle


class Tokenizer:
    """A byte-pair-encoding (BPE) tokenizer with a GPT-4-style split pattern."""

    def __init__(self):
        # Base vocabulary: one token per byte value (0-255).
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # Regex used to pre-split text into chunks before byte-pair merging.
        self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        # Learned merges: (token_a, token_b) -> new token id.
        self.merges = {}

    def merge(self, tokens, target, new_token):
        # Replace every occurrence of the adjacent pair `target` with `new_token`,
        # e.g. merge([1, 2, 3, 1, 2], (1, 2), 256) -> [256, 3, 256].
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == target[0] and tokens[i + 1] == target[1]:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def get_stats(self, idsList):
        # Count how often each adjacent token pair occurs, over a single token
        # list or a list of token lists.
        pairs = {}
        if not isinstance(idsList[0], list):
            idsList = [idsList]
        for tokens in idsList:
            for a, b in zip(tokens, tokens[1:]):
                pairs[(a, b)] = pairs.get((a, b), 0) + 1
        return pairs

    def get_max_pair(self, idsList):
        # Most frequent adjacent pair: the next pair to merge during training.
        pairs = self.get_stats(idsList)
        return max(pairs, key=pairs.get)

    def get_min(self, idsList):
        # Pair with the lowest merge rank, i.e. the merge learned earliest.
        stats = self.get_stats(idsList)
        return min(stats, key=lambda p: self.merges.get(p, float("inf")))

    def train(self, epochs, text):
        # Learn `epochs` byte-pair merges from the training text.
        pat = re.compile(self.pattern)
        chunks = re.findall(pat, text)
        idsList = [list(chunk.encode('utf-8')) for chunk in chunks]
        for epoch in tqdm(range(epochs)):
            # Merge the most frequent pair and give it the next free token id.
            max_pair = self.get_max_pair(idsList)
            new_token = 256 + epoch
            self.merges[max_pair] = new_token
            idsList = [self.merge(tokens, max_pair, new_token) for tokens in idsList]
            self.vocab[new_token] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
        # Flattened token ids of the training text after all merges.
        return [x for xs in idsList for x in xs]

    def encode(self, text):
        # Encode text into token ids by repeatedly applying the learned merges,
        # always merging the pair with the lowest merge rank first.
        tokens = list(text.encode('utf-8'))
        while len(tokens) >= 2:
            pair = self.get_min(tokens)
            if pair not in self.merges:
                # No remaining adjacent pair has a learned merge.
                break
            idx = self.merges[pair]
            tokens = self.merge(tokens, pair, idx)
        return tokens

    def decode(self, tokens):
        # Map token ids back to their byte sequences and decode as UTF-8.
        data = b"".join(self.vocab[token] for token in tokens)
        return data.decode('utf-8', errors='replace')
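

# Hedged sketch of how the vocab.pkl / merges.pkl files loaded below could be
# produced with Tokenizer.train; the corpus path and the number of merges are
# illustrative assumptions, not values taken from this project.
def train_and_save(corpus_path='urdu_corpus.txt', n_merges=500):
    trainer = Tokenizer()
    with open(corpus_path, 'r', encoding='utf-8') as f:  # assumed local corpus file
        trainer.train(n_merges, f.read())
    with open('vocab.pkl', 'wb') as f:
        pickle.dump(trainer.vocab, f)
    with open('merges.pkl', 'wb') as f:
        pickle.dump(trainer.merges, f)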


title = "Ghalib doing tiktok"
description = "A simple Gradio interface for running the Urdu tokenizer"

# Load the trained vocabulary and merge table.
tokenizer = Tokenizer()
with open('vocab.pkl', 'rb') as f:
    tokenizer.vocab = pickle.load(f)
with open('merges.pkl', 'rb') as f:
    tokenizer.merges = pickle.load(f)


def inference(text):
    tokens = tokenizer.encode(text)
    # Compression ratio: input bytes per output token.
    ratio = len(text.encode('utf-8')) / max(len(tokens), 1)
    return str(tokens), f"{ratio:.2f}"


iface = gr.Interface(
    inference,
    inputs=gr.Textbox(lines=4, label="Enter Text"),
    outputs=[gr.Textbox(label="Tokens"), gr.Textbox(label="Compression ratio")],
    examples=["سفید رنگ ہیں آخر سیاہ مو کرتے لٹاتے دولت دنیا کو میکدے میں ہم طلائی ساغر مے نقرئی سبو کرتے ہمیشہ میں نے گریباں کو چاک چاک کیا",
              " دل کہ آتے ہیں جس کو دھیان بہت خود بھی آتا ہے اپنے دھیان میں کیاوہ ملے تو یہ پوچھنا ہے مجھےاب بھی ہوں میں تری امان میں کیا"],
    title=title,
    description=description,
)

iface.launch()