wendru18
commited on
Commit
Β·
dbffdfc
1
Parent(s):
7483438
switched to gradio
Browse files- .gitignore +2 -1
- README.md +2 -2
- app.py +67 -67
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
data/docs.csv
|
|
|
|
1 |
+
data/docs.csv
|
2 |
+
gradio_cached_examples
|
README.md
CHANGED
@@ -3,8 +3,8 @@ title: Persplain
|
|
3 |
emoji: π
|
4 |
colorFrom: gray
|
5 |
colorTo: green
|
6 |
-
sdk:
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
3 |
emoji: π
|
4 |
colorFrom: gray
|
5 |
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.24.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
CHANGED
@@ -1,109 +1,109 @@
|
|
|
|
1 |
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
2 |
-
from simpletransformers.classification import MultiLabelClassificationModel
|
3 |
from transformers_interpret import MultiLabelClassificationExplainer
|
4 |
-
import streamlit as st
|
5 |
import pandas as pd
|
6 |
-
|
|
|
|
|
7 |
|
8 |
traits = ["Openness to Experience", "Conscientiousness", "Extraversion", "Agreeableness", "Neuroticism" ]
|
9 |
|
10 |
short_traits = ["o", "c", "e", "a", "n"]
|
11 |
|
12 |
-
|
13 |
-
def load_explainer():
|
14 |
|
|
|
15 |
print("Loading model...")
|
16 |
-
|
17 |
tokenizer = RobertaTokenizer.from_pretrained("andaqu/roBERTa-pers")
|
18 |
model = RobertaForSequenceClassification.from_pretrained("andaqu/roBERTa-pers", problem_type="multi_label_classification")
|
19 |
-
|
20 |
explainer = MultiLabelClassificationExplainer(model, tokenizer)
|
21 |
-
|
22 |
try:
|
23 |
model.to('cuda')
|
24 |
if next(model.parameters()).is_cuda:
|
25 |
print("Using GPU for inference!")
|
26 |
except:
|
27 |
print("GPU not available, using CPU instead.")
|
28 |
-
|
29 |
print("Model loaded!")
|
30 |
-
|
31 |
return explainer
|
32 |
|
33 |
-
|
34 |
-
def explain(text, _explainer):
|
35 |
-
|
36 |
-
attributions = _explainer(text)
|
37 |
-
|
38 |
-
preds = {label: pred_prob.item() for pred_prob, label in zip(_explainer.pred_probs_list, _explainer.labels)}
|
39 |
-
|
40 |
-
attributions_html = {trait : attributions_to_html(attributions[trait]) for trait in attributions}
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
attr = round(attr, 2)
|
52 |
-
abs_attr = abs(attr)
|
53 |
-
|
54 |
-
color = "rgba(255,255,255,0)"
|
55 |
-
if attr > 0: color = f"rgba(0,255,0,{abs_attr})"
|
56 |
-
elif attr < 0: color = f"rgba(255,0,0,{abs_attr})"
|
57 |
-
|
58 |
-
html += f'<span style="background-color: {color}" title="{str(attr)}">{word}</span> '
|
59 |
-
|
60 |
-
return html
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
|
66 |
-
if
|
67 |
-
|
68 |
|
|
|
69 |
|
70 |
-
def
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
|
|
76 |
|
77 |
-
|
78 |
|
79 |
-
|
80 |
|
81 |
-
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
explanation = explain(text, explainer)
|
86 |
-
st.session_state.explanation = explanation
|
87 |
|
88 |
-
|
89 |
|
90 |
-
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
|
97 |
-
|
98 |
-
# Show five buttons, horizontally, one for each trait
|
99 |
-
cols = st.columns(5)
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
if button: st.markdown(st.session_state.explanation["word_attributions_html"][short_traits[i]], unsafe_allow_html=True)
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
|
|
107 |
|
108 |
-
|
109 |
-
main()
|
|
|
1 |
+
import gradio as gr
|
2 |
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
|
|
3 |
from transformers_interpret import MultiLabelClassificationExplainer
|
|
|
4 |
import pandas as pd
|
5 |
+
from transformers import logging
|
6 |
+
|
7 |
+
logging.set_verbosity_warning()
|
8 |
|
9 |
traits = ["Openness to Experience", "Conscientiousness", "Extraversion", "Agreeableness", "Neuroticism" ]
|
10 |
|
11 |
short_traits = ["o", "c", "e", "a", "n"]
|
12 |
|
13 |
+
short_to_long = {"o": "Openness to Experience", "c": "Conscientiousness", "e": "Extraversion", "a": "Agreeableness", "n": "Neuroticism" }
|
|
|
14 |
|
15 |
+
def load_explainer():
|
16 |
print("Loading model...")
|
|
|
17 |
tokenizer = RobertaTokenizer.from_pretrained("andaqu/roBERTa-pers")
|
18 |
model = RobertaForSequenceClassification.from_pretrained("andaqu/roBERTa-pers", problem_type="multi_label_classification")
|
|
|
19 |
explainer = MultiLabelClassificationExplainer(model, tokenizer)
|
|
|
20 |
try:
|
21 |
model.to('cuda')
|
22 |
if next(model.parameters()).is_cuda:
|
23 |
print("Using GPU for inference!")
|
24 |
except:
|
25 |
print("GPU not available, using CPU instead.")
|
|
|
26 |
print("Model loaded!")
|
|
|
27 |
return explainer
|
28 |
|
29 |
+
explainer = load_explainer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
def explain(text, _explainer):
|
32 |
+
if text is not None:
|
33 |
+
attributions = _explainer(text)
|
34 |
+
preds = {label: pred_prob.item() for pred_prob, label in zip(_explainer.pred_probs_list, _explainer.labels)}
|
35 |
+
attributions_html = {trait : attributions_to_html(attributions[trait], trait) for trait in attributions}
|
36 |
+
return {"preds": preds, "word_attributions_html": attributions_html }
|
37 |
+
else:
|
38 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
def attributions_to_html(attributions, short_trait):
|
41 |
+
html = f""
|
42 |
+
for word, attr in attributions:
|
43 |
+
if word in ["<s>", "</s>"]:
|
44 |
+
continue
|
45 |
+
attr = round(attr, 2)
|
46 |
+
abs_attr = abs(attr)
|
47 |
+
color = "rgba(255,255,255,0)"
|
48 |
+
if attr > 0: color = f"rgba(0,255,0,{abs_attr})"
|
49 |
+
elif attr < 0: color = f"rgba(255,0,0,{abs_attr})"
|
50 |
+
html += f'<span style="background-color: {color}" title="{str(attr)}">{word}</span> '
|
51 |
+
html += f"<br>"
|
52 |
+
return html
|
53 |
|
54 |
+
def get_predictions(text):
|
55 |
+
explanation = explain(text, explainer)
|
56 |
|
57 |
+
prediction = ["YES" if explanation["preds"][x] > 0.5 else "NO" for x in explanation["preds"]]
|
58 |
+
probability = [str(round(explanation["preds"][x]*100)) + "%" for x in explanation["preds"]]
|
59 |
|
60 |
+
result_df = pd.DataFrame(data={"Predicted Traits": prediction, "Probability": probability}, index=traits)
|
61 |
|
62 |
+
def color_row(row):
|
63 |
+
if row['Predicted Traits'] == 'YES':
|
64 |
+
return ['background-color: green']*len(row)
|
65 |
+
else:
|
66 |
+
return ['background-color: red']*len(row)
|
67 |
|
68 |
+
# apply conditional formatting to dataframe
|
69 |
+
result_df = result_df.style.apply(color_row, axis=1)
|
70 |
|
71 |
+
def render_html(val):
|
72 |
+
return val
|
73 |
|
74 |
+
explanation_df = pd.DataFrame(data={"Explanation": [explanation["word_attributions_html"][x] for x in short_traits]}, index=traits)
|
75 |
|
76 |
+
explanation_df = explanation_df.style.format({'Explanation': render_html})
|
77 |
|
78 |
+
return result_df, explanation_df
|
79 |
|
80 |
+
def text_to_personality_explainer(text):
|
81 |
+
result_df, explanation_df = get_predictions(text)
|
|
|
|
|
82 |
|
83 |
+
return "<center>" + result_df.to_html() + "</center>", "<center>" + explanation_df.to_html() + "</center>"
|
84 |
|
85 |
+
main = gr.Blocks()
|
86 |
+
text_input = gr.Textbox(placeholder="Enter text here...")
|
87 |
+
result = gr.outputs.HTML()
|
88 |
+
explanation = gr.outputs.HTML()
|
89 |
|
90 |
+
with main:
|
91 |
+
gr.Markdown("# Text to Personality Explainer π")
|
92 |
+
gr.Markdown("Predict personality traits from text using a RoBERTa model fine-tuned on a Big Five Personality Traits dataset.")
|
93 |
+
gr.Markdown("Explanations are given in the form of word attributions, where the color of the word indicates the importance of the word for the prediction. Green words increase the probability of the trait, red words decrease the probability of the trait.")
|
94 |
|
95 |
+
gr.Examples(["I love working and talking to people!", "I am a bad person. :(", "I find it challenging to agree with my brother."], fn=text_to_personality_explainer, inputs=text_input, outputs=[result, explanation], cache_examples=False)
|
|
|
|
|
96 |
|
97 |
+
text_input.render()
|
98 |
+
text_button = gr.Button("Predict")
|
|
|
|
|
99 |
|
100 |
+
with gr.Tabs():
|
101 |
+
with gr.TabItem("Prediction"):
|
102 |
+
result.render()
|
103 |
+
|
104 |
+
with gr.TabItem("Explanation"):
|
105 |
+
explanation.render()
|
106 |
|
107 |
+
text_button.click(text_to_personality_explainer, inputs=text_input, outputs=[result, explanation])
|
108 |
|
109 |
+
main.launch(show_api=False)
|
|