# ERABB / app.py
import gradio as gr
import regex as re  # the `regex` module (not stdlib `re`) is needed for \p{...} classes and possessive quantifiers
from tqdm import tqdm
import pickle
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig,
)
def check_check():
    """Sanity check: load a 4-bit quantized Llama-2 chat model and restore
    weights from a pickled state dict."""
    model_name = "NousResearch/Llama-2-7b-chat-hf"

    # 4-bit NF4 quantization, computing in float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # restore model weights from the pickled state dict
    with open('data.pkl', 'rb') as f:
        data_dict = pickle.load(f)
    model.load_state_dict(data_dict)
    print("Let's go baby")

    # pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=1500)
    # result = pipe(text)
    # print(result[0]['generated_text'])
class Tokenizer:
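    """Byte-level BPE tokenizer (minbpe-style), trained here on Urdu text."""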
def __init__(self):
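        # base vocabulary: one token per byte value (ids 0-255)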
self.vocab = {idx : bytes([idx]) for idx in range(256)}
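        # GPT-4-style split regex: segments text into chunks before byte-pair merging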
self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
self.merges = {}
def merge(self, tokens, target, new_token):
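        """Replace every adjacent occurrence of the pair `target` in `tokens` with `new_token`."""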
new_tokens = []
i = 0
while i < len(tokens):
if i + 1 < len(tokens) and tokens[i] == target[0] and tokens[i + 1] == target[1]:
i += 1
new_tokens.append(new_token)
else:
new_tokens.append(tokens[i])
i += 1
return new_tokens
def get_stats(self, idsList):
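        """Count frequencies of adjacent token pairs across one or more token lists."""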
pairs = {}
if not isinstance(idsList[0], list):
idsList = [idsList]
        for tokens in idsList:
            for a, b in zip(tokens, tokens[1:]):
                pairs[(a, b)] = pairs.get((a, b), 0) + 1
return pairs
def get_max_pair(self, idsList):
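        """Return the most frequent adjacent pair."""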
pairs = self.get_stats(idsList)
        return max(pairs, key=pairs.get)
def get_min(self, idsList):
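        """Return the adjacent pair with the lowest merge rank, i.e. the earliest-learned merge."""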
stats = self.get_stats(idsList)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
return pair
def train(self, epochs, text):
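        """Learn one BPE merge per epoch; returns the flattened token ids of the training text."""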
pat = re.compile(self.pattern)
        textList = re.findall(pat, text)
        idsList = [list(chunk.encode('utf-8')) for chunk in textList]
for epoch in tqdm(range(epochs)):
max_pair = self.get_max_pair(idsList)
new_token = 256 + epoch
self.merges[max_pair] = new_token
idsList = [self.merge(tokens, max_pair, new_token) for tokens in idsList]
self.vocab[new_token] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
return [x for xs in idsList for x in xs]
def encode(self, text):
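        """Encode text into token ids, greedily applying learned merges in rank order."""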
tokens = list(text.encode('utf-8'))
while len(tokens) >= 2:
pair = self.get_min(tokens)
if pair not in self.merges:
break
idx = self.merges[pair]
tokens = self.merge(tokens, pair, idx)
return tokens
def decode(self, tokens):
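        """Decode token ids back into text via the byte-level vocab."""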
tokens = b"".join(self.vocab[token] for token in tokens)
text = tokens.decode('utf-8', errors='replace')
return text
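
# Minimal usage sketch (illustrative only; 'corpus.txt' and epochs=500 are
# assumptions, not part of this Space: the app below loads the pre-trained
# vocab.pkl / merges.pkl instead):
#
#   tok = Tokenizer()
#   with open('corpus.txt', encoding='utf-8') as f:
#       tok.train(epochs=500, text=f.read())
#   ids = tok.encode("غالب")
#   assert tok.decode(ids) == "غالب"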
title = "Ghalib doing tiktok"
description = "A simple Gradio interface for running the Urdu BPE tokenizer"
tokenizer = Tokenizer()
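# `temp` stays untrained (raw bytes only); it is the baseline for the compression ratio in inference()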
temp = Tokenizer()
# load the pre-trained vocabulary and merges
with open('vocab.pkl', 'rb') as f:
    tokenizer.vocab = pickle.load(f)
with open('merges.pkl', 'rb') as f:
    tokenizer.merges = pickle.load(f)
def inference(text):
tokens = tokenizer.encode(text)
str_tokens = b" | ".join(tokenizer.vocab[token] for token in tokens)
seps = str_tokens.decode('utf-8', errors='replace')
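    # third output: raw UTF-8 byte count divided by BPE token count (compression ratio)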
return tokens, seps, len(temp.encode(text)) / len(tokenizer.encode(text))
iface = gr.Interface(
inference,
inputs=["text"],
outputs=["text", "text", "text"],
examples=["سفید رنگ ہیں آخر سیاہ مو کرتے لٹاتے دولت دنیا کو میکدے میں ہم طلائی ساغر مے نقرئی سبو کرتے ہمیشہ میں نے گریباں کو چاک چاک کیا",
" دل کہ آتے ہیں جس کو دھیان بہت خود بھی آتا ہے اپنے دھیان میں کیاوہ ملے تو یہ پوچھنا ہے مجھےاب بھی ہوں میں تری امان میں کیا"],
    title=title,
    description=description,
)
check_check()
iface.launch()