# app.py -- Gradio demo for the "karalif/myTestModel" multi-label text
# classifier (Hugging Face Space, revision 216e8fc).
import gradio as gr
import re
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
from keybert import KeyBERT
from datasets import load_dataset
import shap
from transformers_interpret import SequenceClassificationExplainer
from ferret import Benchmark
# Dead code from an earlier revision, kept commented out by the author:
#model_identifier = "karalif/myTestModel"
#model = AutoModelForSequenceClassification.from_pretrained(model_identifier)
#tokenizer = AutoTokenizer.from_pretrained(model_identifier)
# Load the fine-tuned multi-label classifier and its tokenizer from the Hub.
name = "karalif/myTestModel"
model = AutoModelForSequenceClassification.from_pretrained(name)
# normalization=True is forwarded to the tokenizer backend; presumably enables
# its text normalizer for Icelandic input -- TODO confirm intended effect.
tokenizer = AutoTokenizer.from_pretrained(name, normalization=True)
# ferret benchmark wrapper used to produce per-label token attributions.
bench = Benchmark(model, tokenizer)
# Sample input left over from development:
#text = "hvað er maðurinn eiginlega að pæla ég fatta ekki??????????"
def get_prediction(text):
    """Classify `text` with the multi-label model and extract keywords.

    Returns a 3-tuple:
        response             -- HTML string: one colored probability line per label
        keywords             -- list of (keyword, score) pairs from KeyBERT
        influential_keywords -- HTML string listing those keyword/score pairs

    Fix: the original also ran four ``bench.explain(text, target=0..3)``
    passes whose results were neither used nor returned -- four wasted
    explanation runs per request. They are removed here; callers that need
    explanations compute them themselves via the module-level ``bench``.
    """
    # Tokenize and score; sigmoid (not softmax) because labels are independent.
    encoding = tokenizer(text, return_tensors="pt", padding="max_length",
                         truncation=True, max_length=200)
    encoding = {k: v.to(model.device) for k, v in encoding.items()}
    with torch.no_grad():
        logits = model(**encoding).logits
    probs = torch.sigmoid(logits.squeeze().cpu()).numpy()

    # NOTE(review): KeyBERT is re-instantiated on every call; hoist to module
    # level if latency matters.
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(
        text, keyphrase_ngram_range=(1, 1), stop_words='english',
        use_maxsum=True, nr_candidates=20, top_n=5)

    # NOTE(review): this label order differs from the target-index order used
    # for the ferret explanations elsewhere in this file -- confirm which
    # ordering matches the model head.
    labels = ['Politeness', 'Toxicity', 'Sentiment', 'Formality']
    colors = ['#b8e994', '#f8d7da', '#fff3cd', '#bee5eb']  # Corresponding colors for labels
    response = ""
    for i, label in enumerate(labels):
        response += f"<span style='background-color:{colors[i]}; color:black;'>{label}</span>: {probs[i]*100:.1f}%<br>"

    influential_keywords = "INFLUENTIAL KEYWORDS:<br>"
    for keyword, score in keywords:
        influential_keywords += f"{keyword} (Score: {score:.2f})<br>"

    return response, keywords, influential_keywords
def replace_encoding(tokens):
    """Clean byte-level BPE tokens for display.

    Drops the first and last entries of `tokens` (assumed to be special
    tokens) and rewrites the 'Ġ' word-boundary marker plus a fixed table of
    mis-decoded character sequences back to readable characters.
    """
    # Substitutions are applied in this exact order for each token.
    substitutions = (
        ('Ġ', ' '),
        ('ð', 'ð'),
        ('é', 'é'),
        ('æ', 'æ'),
        ('ý', 'ý'),
        ('á', 'á'),
        ('ú', 'ú'),
        ('ÃŃ', 'í'),
        ('Ãö', 'ö'),
        ('þ', 'þ'),
        ('Ãģ', 'Á'),
        ('Ãį', 'Ú'),
        ('Ãĵ', 'Ó'),
        ('ÃĨ', 'Æ'),
        ('ÃIJ', 'Ð'),
        ('Ãĸ', 'Ö'),
        ('Ãī', 'É'),
        ('Ãļ', 'ý'),
    )
    cleaned = []
    for token in tokens[1:-1]:
        for old, new in substitutions:
            token = token.replace(old, new)
        cleaned.append(token)
    return cleaned
def predict(text):
    """Build the full HTML feedback report for `text`.

    Bug fix: the original referenced ``explanations_toxicity`` etc., which
    were locals of ``get_prediction`` and never returned -- a guaranteed
    NameError at runtime. The ferret explanations are now computed here,
    using the same target indices the original assigned in
    ``get_prediction`` (0=Formality, 1=Sentiment, 2=Politeness, 3=Toxicity
    -- TODO confirm against the model head order).
    """
    greeting_pattern = r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)"
    prediction_output, keywords, influential_keywords = get_prediction(text)

    # Highlight each influential keyword inside the echoed input.
    modified_input = text
    for keyword, _ in keywords:
        modified_input = modified_input.replace(
            keyword, f"<span style='color:green;'>{keyword}</span>")

    # Scold the user (in Icelandic) if the message does not open with a greeting.
    greeting_feedback = ""
    if not re.match(greeting_pattern, text, re.IGNORECASE):
        greeting_feedback = "OTHER FEEDBACK:<br>Heilsaðu dóninn þinn<br>"

    response = f"INPUT:<br>{modified_input}<br><br>MY PREDICTION:<br>{prediction_output}<br>{influential_keywords}<br>{greeting_feedback}"

    # Per-label token attributions from ferret, in display order.
    labels = ['Toxicity', 'Formality', 'Sentiment', 'Politeness']
    targets = {'Formality': 0, 'Sentiment': 1, 'Politeness': 2, 'Toxicity': 3}
    explanation_lists = [bench.explain(text, target=targets[label]) for label in labels]

    response += "<br>MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    for label, explanations in zip(labels, explanation_lists):
        for explanation in explanations:
            if explanation.explainer == 'Partition SHAP':
                tokens = replace_encoding(explanation.tokens)
                pairs = zip(tokens, explanation.scores)
                formatted_output = ' '.join(f"{token} ({score})" for token, score in pairs)
                response += f"{label}: {formatted_output}<br>"

    response += "<br>TOP 2 MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    for label, explanations in zip(labels, explanation_lists):
        response += f"{label}:<br>"
        for explanation in explanations:
            if explanation.explainer == 'Partition SHAP':
                # Rank token indices by absolute attribution score, keep top 2.
                top2 = sorted(enumerate(explanation.scores),
                              key=lambda x: abs(x[1]), reverse=True)[:2]
                tokens = replace_encoding(explanation.tokens)
                response += f"{' '.join(tokens[idx] for idx, _ in top2)}<br>"

    return response
# Static header shown above the interface (Reykjavik University logo).
description_html = """
<center>
<img src='http://www.ru.is/media/HR_logo_vinstri_transparent.png' width='250' height='auto'>
</center>
"""

# Sample Icelandic inputs users can click to try the model.
example_texts = [
    ["Það voru vitni að árásinni sem tilkynntu málið til lögreglu sem kom skjótt á vettvang."],
    ["Ég held þetta sé ekki góður tími fara heimsókn."],
    ["Sæl og blessuð Kristín, hvað er að frella af þér gamla??"],
    ["Hver á þenan bússtað? já eða nei."],
    ["Hafi þau svo látið gólfið þorna vel og síðan flotað það til lagfæringar eftir motturnar."],
]

# Gradio UI: free-text input in, HTML-formatted prediction report out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.TextArea(label="Enter text here:"),
    outputs=gr.HTML(label="Leiðrétt"),
    description=description_html,
    examples=example_texts,
    theme=gr.themes.Default(primary_hue="red", secondary_hue="pink"),
)
demo.launch()