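"""Gradio demo for a byte-level BPE tokenizer trained on Urdu text.

The app loads a pre-trained vocabulary and merge table from pickle files and,
for any input text, shows the token ids, the token boundaries, and the
compression ratio relative to raw UTF-8 bytes. check_check() additionally
loads the Llama-2-7b-chat model in 4-bit precision and restores its weights
from a pickled state dict.
"""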
import gradio as gr
import regex as re
import torch
from tqdm import tqdm
import pickle
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig,
)
def check_check():
    """Load the Llama-2 chat model in 4-bit precision and restore its weights from data.pkl."""
    model_name = "NousResearch/Llama-2-7b-chat-hf"

    # 4-bit NF4 quantization with float16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # spread layers across the available devices
        quantization_config=bnb_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Restore the weights that were pickled separately.
    file_path = 'data.pkl'
    with open(file_path, 'rb') as f:
        data_dict = pickle.load(f)
    model.load_state_dict(data_dict)
    print("Let's go baby")

    # pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=1500)
    # result = pipe(text)
    # print(result[0]['generated_text'])
class Tokenizer:
    """Minimal byte-level BPE tokenizer: 256 base byte tokens plus learned merges."""

    def __init__(self):
        # Base vocabulary: one token per byte value.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # GPT-4-style split pattern used to pre-tokenize text before byte-pair merging.
        self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        # Learned merges: (token_a, token_b) -> new token id, in order of creation.
        self.merges = {}

    def merge(self, tokens, target, new_token):
        """Replace every occurrence of the pair `target` in `tokens` with `new_token`."""
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == target[0] and tokens[i + 1] == target[1]:
                new_tokens.append(new_token)
                i += 2  # skip both members of the merged pair
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens
    def get_stats(self, idsList):
        """Count how often each adjacent token pair occurs across all chunks."""
        pairs = {}
        if not isinstance(idsList[0], list):
            idsList = [idsList]  # allow a single flat token list
        for tokens in idsList:
            for a, b in zip(tokens, tokens[1:]):
                pairs[(a, b)] = pairs.get((a, b), 0) + 1
        return pairs

    def get_max_pair(self, idsList):
        """Return the most frequent adjacent pair (the next merge to learn)."""
        pairs = self.get_stats(idsList)
        return sorted(pairs.items(), key=lambda item: item[1])[-1][0]

    def get_min(self, idsList):
        """Return the pair with the lowest merge rank, i.e. the earliest learned merge."""
        stats = self.get_stats(idsList)
        return min(stats, key=lambda p: self.merges.get(p, float("inf")))
    def train(self, epochs, text):
        """Learn `epochs` merges from `text`; each epoch adds one new token id (256, 257, ...)."""
        pat = re.compile(self.pattern)
        textList = re.findall(pat, text)
        # Start from raw UTF-8 bytes for every pre-tokenized chunk.
        idsList = [list(chunk.encode('utf-8')) for chunk in textList]
        for epoch in tqdm(range(epochs)):
            max_pair = self.get_max_pair(idsList)
            new_token = 256 + epoch
            self.merges[max_pair] = new_token
            idsList = [self.merge(tokens, max_pair, new_token) for tokens in idsList]
            self.vocab[new_token] = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
        return [x for xs in idsList for x in xs]
    def encode(self, text):
        """Encode text to token ids: start from UTF-8 bytes, then apply merges in rank order."""
        tokens = list(text.encode('utf-8'))
        while len(tokens) >= 2:
            pair = self.get_min(tokens)
            if pair not in self.merges:
                break  # no learnable pair left to merge
            idx = self.merges[pair]
            tokens = self.merge(tokens, pair, idx)
        return tokens

    def decode(self, tokens):
        """Decode token ids back to text by concatenating their byte sequences."""
        raw = b"".join(self.vocab[token] for token in tokens)
        return raw.decode('utf-8', errors='replace')
title = "Ghalib doing tiktok"
description = "A simple Gradio interface to infer urdu tokenizer"
tokenizer = Tokenizer()
temp = Tokenizer()
with open('vocab.pkl', 'rb') as files:
tokenizer.vocab = pickle.load(files)
with open('merges.pkl', 'rb') as files:
tokenizer.merges = pickle.load(files)
def inference(text):
tokens = tokenizer.encode(text)
str_tokens = b" | ".join(tokenizer.vocab[token] for token in tokens)
seps = str_tokens.decode('utf-8', errors='replace')
return tokens, seps, len(temp.encode(text)) / len(tokenizer.encode(text))
iface = gr.Interface(
    inference,
    inputs=["text"],
    outputs=["text", "text", "text"],
    examples=[
        "سفید رنگ ہیں آخر سیاہ مو کرتے لٹاتے دولت دنیا کو میکدے میں ہم طلائی ساغر مے نقرئی سبو کرتے ہمیشہ میں نے گریباں کو چاک چاک کیا",
        " دل کہ آتے ہیں جس کو دھیان بہت خود بھی آتا ہے اپنے دھیان میں کیاوہ ملے تو یہ پوچھنا ہے مجھےاب بھی ہوں میں تری امان میں کیا",
    ],
    title=title,
    description=description,
)

check_check()
iface.launch()