|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from unidecode import unidecode |
|
from collections import Counter |
|
import torch |
|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import re |
|
import string |
|
from peft import PeftModel, PeftConfig |
|
|
|
# Tokenizer, base causal LM and the PEFT adapter all come from the same
# Hugging Face Hub repository id.
_REPO_ID = "osiria/primo"

tokenizer = AutoTokenizer.from_pretrained(_REPO_ID)
_base_model = AutoModelForCausalLM.from_pretrained(_REPO_ID)
# Wrap the base model with the adapter weights stored in the same repo.
model = PeftModel.from_pretrained(_base_model, _REPO_ID)
|
|
|
class Prime:
    """Chat-style text generator on top of a causal LM, with anti-repetition heuristics.

    Builds a ``"[HUMAN] <prompt><sep>"`` prompt, extends it one token at a
    time (greedy or top-k sampling), steers away from immediate and n-gram
    repetitions, and returns the cleaned-up text following the last ``sep``.
    """

    # List markers such as "1." or "12."; the lookahead keeps the "3." of a
    # decimal like "3.5" from being treated as a marker.
    _NUMBERED_ITEM = re.compile(r"\b\d+\.(?!\d)")

    def __init__(self, tokenizer, model):
        # tokenizer: needs ``encode(str) -> list[int]`` and
        #   ``decode(ids, skip_special_tokens=...) -> str``.
        # model: called as ``model(input_ids=...)`` and must return an object
        #   with ``.logits``; the ``[0, -1]`` indexing in generate() presumably
        #   assumes (batch, seq, vocab) logits.
        self.tokenizer = tokenizer
        self.model = model

    def _check_sublist(self, lst, sub_lst, sep = " "):
        """Return True if ``sub_lst`` occurs as a contiguous run inside ``lst``.

        Elements are compared directly, so ``[1, 2]`` is not found inside
        ``[11, 2]`` (a string-join comparison reports a false match there).
        An empty ``sub_lst`` is always found. ``sep`` is unused and kept only
        for backward compatibility.
        """
        haystack = list(lst)
        needle = list(sub_lst)
        width = len(needle)
        return any(haystack[i:i + width] == needle
                   for i in range(len(haystack) - width + 1))

    def _exclude_sublist(self, lst, sub_lst, sep = " "):
        """Return a copy of ``lst`` with every non-overlapping run of ``sub_lst`` removed.

        Matching is element-wise, left to right, in a single pass, so
        ``[140, 19]`` is left intact when removing ``[40, 19]``. Returns ``[]``
        if everything is removed. ``sep`` is unused and kept only for backward
        compatibility.
        """
        source = list(lst)
        pattern = list(sub_lst)
        width = len(pattern)
        if not width:
            return source
        kept = []
        i = 0
        while i < len(source):
            if source[i:i + width] == pattern:
                i += width  # skip the whole match
            else:
                kept.append(source[i])
                i += 1
        return kept

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first = next(iter(self.model.parameters()), None)
        return first.device if first is not None else torch.device("cpu")

    def generate(self, prompt, message = "", sep = " [AI]", max_tokens = 100,
                 excluded = ((40, 19),), lookback = 5, resample_tokens = (27793,),
                 replace_tokens = None, stop_tokens = (239,), sample = False,
                 top_k = 5, checkpoint_token = 5):
        """Generate one reply to ``prompt``.

        Args:
            prompt: the user request.
            message: optional system message, prepended as ``"<message>. "``.
            sep: marker appended after the request; the reply is the text
                after its last occurrence in the decoded output.
            max_tokens: maximum number of tokens to generate.
            excluded: token sequences removed from the final token list
                before decoding.
            lookback: window length for the repeated n-gram check.
            resample_tokens: tokens always replaced by the runner-up candidate.
            replace_tokens: token -> token mapping applied to every generated
                token; ``None`` means ``{11302: 23318}``.
            stop_tokens: generation stops once one of these is produced; the
                decoded first one is trimmed from the end of the reply.
            sample: sample among the ``top_k`` most likely tokens instead of
                picking greedily. Repetition detection may switch it on.
            top_k: sampling pool size. In greedy mode it is forced to 2, which
                is the pool used if repetition later switches sampling on.
            checkpoint_token: after this token a truncation checkpoint is
                recorded (presumably sentence-ending punctuation; confirm
                against the tokenizer vocabulary).

        Returns:
            The reply with its first character upper-cased and the trailing
            stop text removed. When it contains more than one numbered list
            marker, the markers become ``<br>•`` bullets. May be ``""``.
        """
        if replace_tokens is None:
            replace_tokens = {11302: 23318}

        if message:
            prompt = message + ". " + prompt
        # Normalise typographic quotes (left/right double, right single) to
        # ASCII; escapes keep the literals safe from re-encoding of this file.
        prompt = (prompt.replace("\u201c", '"')
                        .replace("\u201d", '"')
                        .replace("\u2019", "'"))

        if not sample:
            top_k = 2
        top_k = max(int(top_k), 1)  # UI sliders may deliver floats
        # Always rank at least two candidates so the runner-up fallback below
        # exists even when sampling with top_k == 1.
        pool_size = max(top_k, 2)

        device = self._model_device()
        tokens = list(self.tokenizer.encode("[HUMAN] " + prompt + sep))
        generated = []
        checkpoint = 0  # len(tokens) right after the last checkpoint token

        while tokens[-1] not in stop_tokens and len(generated) < max_tokens:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                input_ids = torch.tensor([tokens], device=device)
                logits = self.model(input_ids=input_ids).logits[0, -1]
            candidates = torch.topk(torch.softmax(logits, dim=0), k=pool_size)

            if sample:
                pick = torch.multinomial(candidates.values[:top_k], 1)[0].item()
                next_token = candidates.indices[pick].item()
            else:
                next_token = candidates.indices[0].item()
            fallback = candidates.indices[1].item()

            # The n-gram that choosing the original next_token would complete.
            recent = generated[-lookback:] + [next_token]

            if next_token in resample_tokens:
                next_token = fallback

            if len(generated) >= lookback + 1 and next_token in generated[-2:]:
                # Repeats one of the last two tokens: take the runner-up.
                next_token = fallback
            elif len(generated) >= lookback and self._check_sublist(generated, recent):
                # The model is looping on an n-gram it already produced.
                if checkpoint:
                    # Cut back to the last checkpoint and stop generating.
                    tokens = tokens[:checkpoint]
                    break
                next_token = fallback
                sample = True

            next_token = replace_tokens.get(next_token, next_token)
            tokens.append(next_token)
            generated.append(next_token)
            if next_token == checkpoint_token:
                checkpoint = len(tokens)

        for ex_lst in excluded:
            tokens = self._exclude_sublist(tokens, ex_lst)

        text = self.tokenizer.decode(tokens, skip_special_tokens=True)
        stop_text = self.tokenizer.decode(stop_tokens[0]) if stop_tokens else ""
        return self._format_reply(text, sep, stop_text)

    def _format_reply(self, text, sep, stop_text):
        """Extract and tidy the reply part of the decoded conversation ``text``."""
        reply = text.split(sep)[-1].strip()
        if not reply:
            # Nothing after the separator (e.g. the model stopped immediately).
            return reply
        reply = reply[0].upper() + reply[1:]
        if stop_text and reply.endswith(stop_text):
            reply = reply[:-len(stop_text)]
        if len(self._NUMBERED_ITEM.findall(reply)) > 1:
            reply = self._NUMBERED_ITEM.sub("<br>\u2022", reply)
        return re.sub(r"^<br>", "", reply)
|
|
|
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# crashing on CPU-only hosts with a hard-coded "cuda" device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# from_pretrained above is called without a device_map, so the weights start on
# the CPU; move them to the device the input ids are sent to, then switch to
# evaluation mode (disables dropout).
model.to(device)
model.eval()

prime = Prime(tokenizer = tokenizer, model = model)
|
|
|
def process_input(user_input, max_tokens, sample, top_k, message):
    """Generate the assistant reply for one user turn.

    Adapter between the Gradio callback arguments and ``prime.generate``:
    ``user_input`` becomes the prompt and ``message`` the system message; the
    remaining options are forwarded unchanged.
    """
    generation_options = {"max_tokens": max_tokens, "sample": sample, "top_k": top_k}
    return prime.generate(prompt=user_input, message=message, **generation_options)
|
|
|
|
|
# Markdown/HTML banner shown at the top of the page via gr.Markdown(header): a
# horizontal rule, an embedded <style> defining .vertical-text, six coloured
# spans using that class, and a centred logo image.
# NOTE(review): the span contents display as "β", which looks like mis-encoded
# multi-byte characters; restore the intended glyphs from the original source.
# NOTE(review): <body> nested inside <center> is not valid HTML (browsers
# tolerate it); the .vertical-text background-color is overridden inline.
# NOTE(review): "file/primo.png" presumably relies on Gradio serving a local
# primo.png next to the app; confirm the asset exists and is allowed.
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
writing-mode: vertical-lr;
text-orientation: upright;
background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;">β</span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;">ββ</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">βββββ</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">βββββ</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;">ββ</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;">β</span>
</body>
</center>
<br>
<center><img src="file/primo.png" width="100"></center>
'''
|
|
|
import gradio as gr |
|
import random |
|
import time |
|
|
|
with gr.Blocks(title="primo", css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="md", spacing_size="md")) as interface:
    gr.Markdown(header)
    with gr.Row():
        # Left column: generation options forwarded to process_input.
        with gr.Column(scale=1):
            gr.Markdown("<b>opzioni</b>")
            max_tokens = gr.Slider(1, 250, value=150, label="massimo numero di token", info="scegli un limite tra 1 e 250")
            sample = gr.Checkbox(label="campionamento")
            # Label fixed: it was mis-encoded as "creativitΓ ".
            top_k = gr.Slider(1, 5, step=1, value=1, label="creatività", info="scegli un livello tra 1 e 5")
            message = gr.Textbox(label="messaggio di sistema", value = "")
            clear = gr.Button("pulisci conversazione")
        # Centre column: the conversation and the request box.
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(label = "prime").style(height=600)
            msg = gr.Textbox(label = "richiesta")

            def user(user_message, history):
                """Append the user's turn (reply pending) and lock the input box."""
                return gr.update(value="", interactive=False), history + [[user_message, None]]

            def bot(history, message, max_tokens, sample, top_k):
                """Generate the reply to the last turn and stream it one character at a time."""
                bot_message = process_input(history[-1][0], message = message, max_tokens = max_tokens,
                                            sample = sample, top_k = top_k)
                history[-1][1] = ""
                if not bot_message:
                    # An empty reply would otherwise make this generator yield
                    # nothing; push one update so the chatbot still refreshes.
                    yield history
                    return
                for character in bot_message:
                    history[-1][1] += character
                    time.sleep(0.05)  # typing effect
                    yield history

            response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, message, max_tokens, sample, top_k], chatbot
            )
            # Re-enable the request box once the reply has been streamed.
            response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
            clear.click(lambda: None, None, chatbot, queue=False)
        # Right column: usage warning.
        with gr.Column(scale=1):
            gr.Markdown("<b>attenzione</b>")
            gr.Markdown("il modello potrebbe comportarsi in maniera imprevista nel caso in cui riceva prompt troppo lontani dal suo pre-training o fine-tuning e, per via della natura probabilistica del meccanismo di generazione, potrebbe occasionalmente produrre contenuti distorti o offensivi in relazione a tematiche come il genere, le etnie, le ideologie, e le convinzioni politiche o religiose<br><br>per via di queste limitazioni, il modello e i suoi output dovrebbero essere usati con cautela, e non dovrebbero essere coinvolti in contesti che richiedono che il testo generato sia corretto o veritiero")
|
# Enable Gradio's request queue (the `bot` handler is a generator that streams
# partial replies) and start serving; queue() returns the Blocks itself.
interface.queue().launch()