import gradio as gr
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from transformers_interpret import MultiLabelClassificationExplainer
import pandas as pd
from transformers import logging
logging.set_verbosity_warning()
# Full trait names; used as the row index of the result tables.
traits = [
    "Openness to Experience",
    "Conscientiousness",
    "Extraversion",
    "Agreeableness",
    "Neuroticism",
]
# One-letter trait codes, position-aligned with `traits`.
short_traits = ["o", "c", "e", "a", "n"]
# Code -> full trait name, built from the two aligned lists above.
short_to_long = dict(zip(short_traits, traits))
def load_explainer():
    """Load the "andaqu/roBERTa-pers" model and wrap it in an explainer.

    The tokenizer and the multi-label sequence classifier are both fetched
    with ``from_pretrained``. Moving the model to CUDA is attempted on a
    best-effort basis: if it fails, a message is printed and the model is
    used as loaded.

    Returns:
        A ``MultiLabelClassificationExplainer`` built from the model and
        tokenizer.
    """
    print("Loading model...")
    tokenizer = RobertaTokenizer.from_pretrained("andaqu/roBERTa-pers")
    model = RobertaForSequenceClassification.from_pretrained(
        "andaqu/roBERTa-pers", problem_type="multi_label_classification"
    )

    # Settle the model's device *before* the explainer is constructed, so the
    # explainer never sees the model on a device it is later moved away from.
    try:
        model.to('cuda')
    except Exception:
        # Deliberate fallback. `Exception` (not a bare except) so that
        # Ctrl-C / SystemExit during loading are not swallowed.
        print("GPU not available, using CPU instead.")
    else:
        if next(model.parameters()).is_cuda:
            print("Using GPU for inference!")

    explainer = MultiLabelClassificationExplainer(model, tokenizer)
    print("Model loaded!")
    return explainer
# Built once at import time; get_predictions() reads this module-level instance.
explainer = load_explainer()
def explain(text, _explainer):
    """Run ``_explainer`` on ``text`` and package the result for display.

    Args:
        text: Input passed to ``_explainer``; ``None`` short-circuits.
        _explainer: Callable explainer exposing ``pred_probs_list`` and
            ``labels`` after being called.

    Returns:
        ``None`` when ``text`` is ``None``. Otherwise a dict with:
          * ``"preds"``: label -> ``.item()`` of the matching entry in
            ``_explainer.pred_probs_list``.
          * ``"word_attributions_html"``: label -> HTML produced by
            ``attributions_to_html`` for that label's attributions.
    """
    if text is None:
        return None

    # Calling the explainer must come first: the prediction attributes read
    # below are taken from it afterwards.
    word_scores = _explainer(text)

    probabilities = {}
    for prob, label in zip(_explainer.pred_probs_list, _explainer.labels):
        probabilities[label] = prob.item()

    rendered = {}
    for trait in word_scores:
        rendered[trait] = attributions_to_html(word_scores[trait], trait)

    return {"preds": probabilities, "word_attributions_html": rendered}
def attributions_to_html(attributions, short_trait):
    """Render word attributions as colour-highlighted HTML.

    Each word is wrapped in a ``<span>`` whose background colour encodes its
    attribution: green for positive, red for negative, fully transparent for
    zero. The colour's alpha is the absolute attribution rounded to two
    decimals.

    Args:
        attributions: Iterable of ``(word, attribution)`` pairs.
        short_trait: Trait code the attributions belong to. Currently unused;
            kept so existing callers keep working.

    Returns:
        The HTML string, terminated by a newline.
    """
    # Local import: the original file does not import `html` at module level.
    from html import escape

    # Empty tokens plus RoBERTa's sequence delimiters.
    # NOTE(review): the original filter list held only empty strings; the
    # delimiter tokens are presumed to be what it was meant to drop - confirm.
    skipped = {"", "<s>", "</s>"}

    pieces = []
    for word, attr in attributions:
        if word in skipped:
            continue
        score = round(attr, 2)
        opacity = abs(score)
        if score > 0:
            color = f"rgba(0,255,0,{opacity})"
        elif score < 0:
            color = f"rgba(255,0,0,{opacity})"
        else:
            color = "rgba(255,255,255,0)"
        # Words come from user-supplied text, so escape them before they are
        # placed into markup.
        pieces.append(
            f'<span style="background-color: {color}">{escape(word)}</span> '
        )
    pieces.append("\n")
    return "".join(pieces)
def get_predictions(text):
    """Predict traits for ``text`` and build the two styled result tables.

    Args:
        text: Input text handed to ``explain``. Must not be ``None``:
            ``explain`` returns ``None`` for it and the lookups below would
            fail.

    Returns:
        A ``(result_df, explanation_df)`` tuple of styled DataFrames, both
        indexed by ``traits``:
          * ``result_df``: "Predicted Traits" ("YES" when the probability is
            above 0.5, else "NO") and "Probability" (rounded percentage);
            rows are coloured green for YES and red for NO.
          * ``explanation_df``: one "Explanation" column holding the
            per-trait attribution HTML.
    """
    explanation = explain(text, explainer)
    preds = explanation["preds"]

    # Read predictions in `short_traits` order - the same order used for the
    # explanations below and the one the `traits` index assumes - rather than
    # relying on the dict's iteration order, so rows cannot be mislabelled.
    prediction = ["YES" if preds[code] > 0.5 else "NO" for code in short_traits]
    probability = [str(round(preds[code] * 100)) + "%" for code in short_traits]
    result_df = pd.DataFrame(
        data={"Predicted Traits": prediction, "Probability": probability},
        index=traits,
    )

    def color_row(row):
        if row['Predicted Traits'] == 'YES':
            return ['background-color: green'] * len(row)
        return ['background-color: red'] * len(row)

    # apply conditional formatting to dataframe
    result_df = result_df.style.apply(color_row, axis=1)

    def render_html(val):
        # Identity formatter: the cell value is passed through unchanged.
        return val

    explanation_df = pd.DataFrame(
        data={
            "Explanation": [
                explanation["word_attributions_html"][code] for code in short_traits
            ]
        },
        index=traits,
    )
    explanation_df = explanation_df.style.format({'Explanation': render_html})
    return result_df, explanation_df
def text_to_personality_explainer(text):
result_df, explanation_df = get_predictions(text)
return "