import gradio as gr
import re
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from keybert import KeyBERT
from ferret import Benchmark

# Load the fine-tuned multi-label classifier and its tokenizer.
name = "karalif/myTestModel"
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name, normalization=True)

# ferret benchmark wrapper, used below to compute per-token explanations.
bench = Benchmark(model, tokenizer)
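# Illustrative sketch only (not executed here): Benchmark.explain runs ferret's
# configured explainers on one input and returns a list of explanation objects.
# The attribute names (.explainer, .tokens, .scores) follow how predict() reads
# them further down in this file; the example text is arbitrary.
#
#   explanations = bench.explain("Sæl og blessuð Kristín", target=0)
#   for e in explanations:
#       print(e.explainer, list(zip(e.tokens, e.scores)))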
#text = "hvað er maðurinn eiginlega að pæla ég fatta ekki??????????"

def get_prediction(text):
    # Tokenize and run the classifier; sigmoid turns the logits into
    # independent per-label probabilities (multi-label setup).
    encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=200)
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    logits = outputs.logits
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(logits.squeeze().cpu()).numpy()

    # Extract up to five influential keywords from the input text.
    kw_model = KeyBERT()
    keywords = kw_model.extract_keywords(text, keyphrase_ngram_range=(1, 1), stop_words='english', use_maxsum=True, nr_candidates=20, top_n=5)

    # Build an HTML snippet with one colour-coded probability per label.
    response = ""
    labels = ['Politeness', 'Toxicity', 'Sentiment', 'Formality']
    colors = ['#b8e994', '#f8d7da', '#fff3cd', '#bee5eb']  # corresponding colors for the labels

    for i, label in enumerate(labels):
        response += f"<span style='background-color:{colors[i]}; color:black;'>{label}</span>: {probs[i]*100:.1f}%<br>"

    influential_keywords = "INFLUENTIAL KEYWORDS:<br>"
    for keyword, score in keywords:
        influential_keywords += f"{keyword} (Score: {score:.2f})<br>"

    return response, keywords, influential_keywords
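# Usage sketch (assumes the model and KeyBERT load successfully; the example
# text is arbitrary):
#
#   html, keywords, kw_html = get_prediction("Góðan daginn, hvernig hefur þú það?")
#   # `html` is the colour-coded probability list, `keywords` the raw
#   # (keyword, score) pairs from KeyBERT, and `kw_html` their HTML rendering.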
def replace_encoding(tokens):
    # The tokenizer uses byte-level BPE, so Icelandic characters come back as
    # mojibake (e.g. 'Ã°' for 'ð') and 'Ġ' marks a leading space. Map the raw
    # tokens back to readable text, skipping the special tokens at both ends.
    return [token.replace('Ġ', ' ')
            .replace('Ã°', 'ð')
            .replace('Ã©', 'é')
            .replace('Ã¦', 'æ')
            .replace('Ã½', 'ý')
            .replace('Ã¡', 'á')
            .replace('Ãº', 'ú')
            .replace('ÃŃ', 'í')
            .replace('Ã¶', 'ö')
            .replace('Ã¾', 'þ')
            .replace('Ãģ', 'Á')
            .replace('Ãį', 'Í')
            .replace('Ãĵ', 'Ó')
            .replace('ÃĨ', 'Æ')
            .replace('ÃIJ', 'Ð')
            .replace('Ãĸ', 'Ö')
            .replace('Ãī', 'É')
            .replace('Ãļ', 'Ú')
            for token in tokens[1:-1]]
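# Alternative sketch, for comparison only: the tokenizer itself can usually undo
# its byte-level encoding via convert_tokens_to_string, which avoids maintaining
# a manual character table. The app keeps the explicit table above.
#
#   def replace_encoding_via_tokenizer(tokens):
#       return [tokenizer.convert_tokens_to_string([t]) for t in tokens[1:-1]]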
def predict(text):
    # Per-token explanations from ferret, one call per output label.
    explanations_formality = bench.explain(text, target=0)
    explanations_sentiment = bench.explain(text, target=1)
    explanations_politeness = bench.explain(text, target=2)
    explanations_toxicity = bench.explain(text, target=3)

    greeting_pattern = r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)"

    prediction_output, keywords, influential_keywords = get_prediction(text)

    greeting_feedback = ""
    modified_input = text

    # Highlight the KeyBERT keywords in the echoed input.
    for keyword, _ in keywords:
        modified_input = modified_input.replace(keyword, f"<span style='color:green;'>{keyword}</span>")

    #if not re.match(greeting_pattern, text, re.IGNORECASE):
    #    greeting_feedback = "OTHER FEEDBACK:<br>Heilsaðu dóninn þinn<br>"

    response = f"INPUT:<br>{modified_input}<br><br>MY PREDICTION:<br>{prediction_output}<br>{influential_keywords}<br>{greeting_feedback}"

    # Influential words per label, taken from the Partition SHAP explainer.
    explanation_lists = [explanations_toxicity, explanations_formality, explanations_sentiment, explanations_politeness]
    labels = ['Toxicity', 'Formality', 'Sentiment', 'Politeness']

    response += "<br>MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    for i, explanations in enumerate(explanation_lists):
        label = labels[i]
        for explanation in explanations:
            if explanation.explainer == 'Partition SHAP':
                tokens = replace_encoding(explanation.tokens)
                token_score_pairs = zip(tokens, explanation.scores)
                formatted_output = ' '.join([f"{token} ({score})" for token, score in token_score_pairs])
                response += f"{label}: {formatted_output}<br>"

    #response += "<br>TOP 2 MOST INFLUENTIAL WORDS FOR EACH LABEL:<br>"
    #for i, explanations in enumerate(explanation_lists):
    #    label = labels[i]
    #    response += f"{label}:<br>"
    #    for explanation in explanations:
    #        if explanation.explainer == 'Partition SHAP':
    #            sorted_scores = sorted(enumerate(explanation.scores), key=lambda x: abs(x[1]), reverse=True)[:2]
    #            tokens = replace_encoding(explanation.tokens)
    #            tokens = [tokens[idx] for idx, _ in sorted_scores]
    #            formatted_output = ' '.join(tokens)
    #            response += f"{formatted_output}<br>"

    return response
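# Quick local check (sketch only; runs the full explanation stack, so it is slow):
#
#   if __name__ == "__main__":
#       print(predict("Sæl og blessuð Kristín, hvað er að frella af þér gamla??"))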
description_html = """
<center>
<img src='http://www.ru.is/media/HR_logo_vinstri_transparent.png' width='250' height='auto'>
</center>
"""

demo = gr.Interface(
    fn=predict,
    inputs=gr.TextArea(label="Enter text here:"),
    outputs=gr.HTML(label="Leiðrétt"),
    description=description_html,
    examples=[
        ["Sæl og blessuð Kristín, hvað er að frella af þér gamla??"],
    ],
    theme=gr.themes.Default(primary_hue="red", secondary_hue="pink")
)

demo.launch()