# ERABB / app.py
# (Hugging Face Hub viewer residue removed from this header: author 909ahmed,
#  commit 6c1f2be "Update app.py", verified, 3.59 kB, "No virus".)
import gradio as gr
import regex as re
from tqdm import tqdm
import pickle
class Tokenizer:
    """Byte-pair-encoding (BPE) tokenizer.

    Starts from a base vocabulary of the 256 single bytes and learns merge
    rules with `train`; `encode`/`decode` apply those rules to text.
    """

    def __init__(self):
        # Base vocabulary: token id -> bytes, one entry per raw byte value.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # GPT-4-style split pattern. NOTE: possessive quantifiers (?+, ++) and
        # \p{...} classes require the third-party `regex` module, not stdlib `re`.
        self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        # Learned merge rules: (token_a, token_b) pair -> merged token id.
        self.merges = {}

    def merge(self, tokens, target, new_token):
        """Return a copy of `tokens` with each adjacent `target` pair replaced by `new_token`.

        Non-overlapping, left-to-right: in [a, a, a] with target (a, a) only the
        first pair is merged.
        """
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == target[0] and tokens[i + 1] == target[1]:
                new_tokens.append(new_token)
                i += 2  # consume both members of the matched pair
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def get_stats(self, idsList):
        """Count adjacent token pairs.

        Accepts either a single token list or a list of token lists; pairs are
        counted within each sub-list only (never across chunk boundaries).
        Returns a dict {(a, b): count}. Empty input yields {} (previously this
        raised IndexError on `idsList[0]`).
        """
        pairs = {}
        if idsList and not isinstance(idsList[0], list):
            idsList = [idsList]  # normalize a flat token list to a one-chunk list
        for tokens in idsList:
            for pair in zip(tokens, tokens[1:]):
                pairs[pair] = pairs.get(pair, 0) + 1
        return pairs

    def get_max_pair(self, idsList):
        """Return the most frequent adjacent pair.

        Uses a stable sort and takes the last element so that, among equally
        frequent pairs, the last-encountered one wins — preserved deliberately,
        since `max()` would break ties differently and change training output.
        """
        pairs = self.get_stats(idsList)
        return sorted(pairs.items(), key=lambda item: item[1])[-1][0]

    def get_min(self, idsList):
        """Return the adjacent pair with the lowest merge rank (earliest-learned).

        Pairs with no merge rule rank as infinity; callers must check membership
        in `self.merges` before merging.
        """
        stats = self.get_stats(idsList)
        return min(stats, key=lambda p: self.merges.get(p, float("inf")))

    def train(self, epochs, text):
        """Learn `epochs` merge rules from `text`.

        The text is pre-split with the regex pattern so merges never cross
        chunk boundaries. New token ids are assigned sequentially from 256.
        Returns the flattened token-id list of the training text after all
        merges. (Renamed the comprehension variable that previously shadowed
        the `text` parameter.)
        """
        pat = re.compile(self.pattern)
        chunks = re.findall(pat, text)
        idsList = [list(chunk.encode('utf-8')) for chunk in chunks]
        for epoch in tqdm(range(epochs)):
            max_pair = self.get_max_pair(idsList)
            new_token = 256 + epoch
            self.merges[max_pair] = new_token
            idsList = [self.merge(tokens, max_pair, new_token) for tokens in idsList]
            # The merged token's bytes are the concatenation of its parts.
            self.vocab[new_token] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
        return [tok for chunk_ids in idsList for tok in chunk_ids]

    def encode(self, text):
        """Encode `text` to a list of token ids, applying merges in learned order."""
        tokens = list(text.encode('utf-8'))
        while len(tokens) >= 2:
            pair = self.get_min(tokens)
            if pair not in self.merges:
                break  # no applicable merge rule remains
            tokens = self.merge(tokens, pair, self.merges[pair])
        return tokens

    def decode(self, tokens):
        """Decode a list of token ids back to text (invalid UTF-8 -> U+FFFD)."""
        data = b"".join(self.vocab[token] for token in tokens)
        return data.decode('utf-8', errors='replace')
# Page metadata for the Gradio UI.
title = "Ghalib doing tiktok"
description = "A simple Gradio interface to infer urdu tokenizer"

# Build the tokenizer and load its trained state from disk.
# NOTE(review): the filenames and attributes look swapped here
# ('merges.pkl' feeds `vocab`, 'vocab.pkl' feeds `merges`). This presumably
# matches how the pickles were written during training — confirm against the
# training script before "fixing". Pickles are trusted local artifacts.
tokenizer = Tokenizer()
with open('merges.pkl', 'rb') as fh:
    tokenizer.vocab = pickle.load(fh)
with open('vocab.pkl', 'rb') as fh:
    tokenizer.merges = pickle.load(fh)


def inference(text):
    """Tokenize the input text and return the list of token ids."""
    return tokenizer.encode(text)


iface = gr.Interface(
    inference,
    inputs=["text"],
    outputs=["text"],
    examples=[
        "سفید رنگ ہیں آخر سیاہ مو کرتے لٹاتے دولت دنیا کو میکدے میں ہم طلائی ساغر مے نقرئی سبو کرتے ہمیشہ میں نے گریباں کو چاک چاک کیا",
        " دل کہ آتے ہیں جس کو دھیان بہت خود بھی آتا ہے اپنے دھیان میں کیاوہ ملے تو یہ پوچھنا ہے مجھےاب بھی ہوں میں تری امان میں کیا",
    ],
    title=title,
    description=description,
)

iface.launch()