wendru18 committed on
Commit
dbffdfc
·
1 Parent(s): 7483438

switched to gradio

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. README.md +2 -2
  3. app.py +67 -67
.gitignore CHANGED
@@ -1 +1,2 @@
1
- data/docs.csv
 
 
1
+ data/docs.csv
2
+ gradio_cached_examples
README.md CHANGED
@@ -3,8 +3,8 @@ title: Persplain
3
  emoji: 📊
4
  colorFrom: gray
5
  colorTo: green
6
- sdk: streamlit
7
- sdk_version: 1.17.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
3
  emoji: 📊
4
  colorFrom: gray
5
  colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.24.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,109 +1,109 @@
 
1
  from transformers import RobertaForSequenceClassification, RobertaTokenizer
2
- from simpletransformers.classification import MultiLabelClassificationModel
3
  from transformers_interpret import MultiLabelClassificationExplainer
4
- import streamlit as st
5
  import pandas as pd
6
- import numpy as np
 
 
7
 
8
  traits = ["Openness to Experience", "Conscientiousness", "Extraversion", "Agreeableness", "Neuroticism" ]
9
 
10
  short_traits = ["o", "c", "e", "a", "n"]
11
 
12
- @st.experimental_memo
13
- def load_explainer():
14
 
 
15
  print("Loading model...")
16
-
17
  tokenizer = RobertaTokenizer.from_pretrained("andaqu/roBERTa-pers")
18
  model = RobertaForSequenceClassification.from_pretrained("andaqu/roBERTa-pers", problem_type="multi_label_classification")
19
-
20
  explainer = MultiLabelClassificationExplainer(model, tokenizer)
21
-
22
  try:
23
  model.to('cuda')
24
  if next(model.parameters()).is_cuda:
25
  print("Using GPU for inference!")
26
  except:
27
  print("GPU not available, using CPU instead.")
28
-
29
  print("Model loaded!")
30
-
31
  return explainer
32
 
33
- @st.experimental_memo
34
- def explain(text, _explainer):
35
-
36
- attributions = _explainer(text)
37
-
38
- preds = {label: pred_prob.item() for pred_prob, label in zip(_explainer.pred_probs_list, _explainer.labels)}
39
-
40
- attributions_html = {trait : attributions_to_html(attributions[trait]) for trait in attributions}
41
 
42
- return {"preds": preds, "word_attributions_html": attributions_html }
43
-
44
- def attributions_to_html(attributions):
45
- html = ""
46
- for word, attr in attributions:
47
-
48
- if word in ["<s>", "</s>"]:
49
- continue
50
-
51
- attr = round(attr, 2)
52
- abs_attr = abs(attr)
53
-
54
- color = "rgba(255,255,255,0)"
55
- if attr > 0: color = f"rgba(0,255,0,{abs_attr})"
56
- elif attr < 0: color = f"rgba(255,0,0,{abs_attr})"
57
-
58
- html += f'<span style="background-color: {color}" title="{str(attr)}">{word}</span> '
59
-
60
- return html
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if "text" in st.session_state: text = st.session_state.text
64
- else: st.session_state.text = ""
65
 
66
- if "explanation" in st.session_state: preds = st.session_state.explanation
67
- else: st.session_state.explanation = {"preds": {}, "word_attributions_html": ""}
68
 
 
69
 
70
- def main():
 
 
 
 
71
 
72
- st.title("Text to Personality Explainer 📊")
73
- text = ""
74
 
75
- explainer = load_explainer()
 
76
 
77
- text = st.text_area(label="Input text here...", value="I enjoy meeting people and working hard!")
78
 
79
- show_prediction = st.button("Predict Traits")
80
 
81
- st.session_state.text = text
82
 
83
- if show_prediction and st.session_state.text:
84
-
85
- explanation = explain(text, explainer)
86
- st.session_state.explanation = explanation
87
 
88
- if len(st.session_state.explanation["preds"]) > 0:
89
 
90
- st.write("## Predicted Traits")
 
 
 
91
 
92
- prediction = ["YES" if st.session_state.explanation["preds"][x] > 0.5 else "NO" for x in st.session_state.explanation["preds"]]
93
- probability = [str(round(st.session_state.explanation["preds"][x]*100)) + "%" for x in st.session_state.explanation["preds"]]
94
-
95
- st.table(pd.DataFrame([prediction, probability], columns=traits, index=["Predicted Traits", "Probability"]))
96
 
97
- st.write("## Explanation")
98
- # Show five buttons, horizontally, one for each trait
99
- cols = st.columns(5)
100
 
101
- for i in range(5):
102
- button = cols[i].button(traits[i])
103
-
104
- if button: st.markdown(st.session_state.explanation["word_attributions_html"][short_traits[i]], unsafe_allow_html=True)
105
 
 
 
 
 
 
 
106
 
 
107
 
108
- if __name__ == "__main__":
109
- main()
 
1
+ import gradio as gr
2
  from transformers import RobertaForSequenceClassification, RobertaTokenizer
 
3
  from transformers_interpret import MultiLabelClassificationExplainer
 
4
  import pandas as pd
5
+ from transformers import logging
6
+
7
+ logging.set_verbosity_warning()
8
 
9
  traits = ["Openness to Experience", "Conscientiousness", "Extraversion", "Agreeableness", "Neuroticism" ]
10
 
11
  short_traits = ["o", "c", "e", "a", "n"]
12
 
13
+ short_to_long = {"o": "Openness to Experience", "c": "Conscientiousness", "e": "Extraversion", "a": "Agreeableness", "n": "Neuroticism" }
 
14
 
15
+ def load_explainer():
16
  print("Loading model...")
 
17
  tokenizer = RobertaTokenizer.from_pretrained("andaqu/roBERTa-pers")
18
  model = RobertaForSequenceClassification.from_pretrained("andaqu/roBERTa-pers", problem_type="multi_label_classification")
 
19
  explainer = MultiLabelClassificationExplainer(model, tokenizer)
 
20
  try:
21
  model.to('cuda')
22
  if next(model.parameters()).is_cuda:
23
  print("Using GPU for inference!")
24
  except:
25
  print("GPU not available, using CPU instead.")
 
26
  print("Model loaded!")
 
27
  return explainer
28
 
29
+ explainer = load_explainer()
 
 
 
 
 
 
 
30
 
31
+ def explain(text, _explainer):
32
+ if text is not None:
33
+ attributions = _explainer(text)
34
+ preds = {label: pred_prob.item() for pred_prob, label in zip(_explainer.pred_probs_list, _explainer.labels)}
35
+ attributions_html = {trait : attributions_to_html(attributions[trait], trait) for trait in attributions}
36
+ return {"preds": preds, "word_attributions_html": attributions_html }
37
+ else:
38
+ return None
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ def attributions_to_html(attributions, short_trait):
41
+ html = f""
42
+ for word, attr in attributions:
43
+ if word in ["<s>", "</s>"]:
44
+ continue
45
+ attr = round(attr, 2)
46
+ abs_attr = abs(attr)
47
+ color = "rgba(255,255,255,0)"
48
+ if attr > 0: color = f"rgba(0,255,0,{abs_attr})"
49
+ elif attr < 0: color = f"rgba(255,0,0,{abs_attr})"
50
+ html += f'<span style="background-color: {color}" title="{str(attr)}">{word}</span> '
51
+ html += f"<br>"
52
+ return html
53
 
54
+ def get_predictions(text):
55
+ explanation = explain(text, explainer)
56
 
57
+ prediction = ["YES" if explanation["preds"][x] > 0.5 else "NO" for x in explanation["preds"]]
58
+ probability = [str(round(explanation["preds"][x]*100)) + "%" for x in explanation["preds"]]
59
 
60
+ result_df = pd.DataFrame(data={"Predicted Traits": prediction, "Probability": probability}, index=traits)
61
 
62
+ def color_row(row):
63
+ if row['Predicted Traits'] == 'YES':
64
+ return ['background-color: green']*len(row)
65
+ else:
66
+ return ['background-color: red']*len(row)
67
 
68
+ # apply conditional formatting to dataframe
69
+ result_df = result_df.style.apply(color_row, axis=1)
70
 
71
+ def render_html(val):
72
+ return val
73
 
74
+ explanation_df = pd.DataFrame(data={"Explanation": [explanation["word_attributions_html"][x] for x in short_traits]}, index=traits)
75
 
76
+ explanation_df = explanation_df.style.format({'Explanation': render_html})
77
 
78
+ return result_df, explanation_df
79
 
80
+ def text_to_personality_explainer(text):
81
+ result_df, explanation_df = get_predictions(text)
 
 
82
 
83
+ return "<center>" + result_df.to_html() + "</center>", "<center>" + explanation_df.to_html() + "</center>"
84
 
85
+ main = gr.Blocks()
86
+ text_input = gr.Textbox(placeholder="Enter text here...")
87
+ result = gr.outputs.HTML()
88
+ explanation = gr.outputs.HTML()
89
 
90
+ with main:
91
+ gr.Markdown("# Text to Personality Explainer 📊")
92
+ gr.Markdown("Predict personality traits from text using a RoBERTa model fine-tuned on a Big Five Personality Traits dataset.")
93
+ gr.Markdown("Explanations are given in the form of word attributions, where the color of the word indicates the importance of the word for the prediction. Green words increase the probability of the trait, red words decrease the probability of the trait.")
94
 
95
+ gr.Examples(["I love working and talking to people!", "I am a bad person. :(", "I find it challenging to agree with my brother."], fn=text_to_personality_explainer, inputs=text_input, outputs=[result, explanation], cache_examples=False)
 
 
96
 
97
+ text_input.render()
98
+ text_button = gr.Button("Predict")
 
 
99
 
100
+ with gr.Tabs():
101
+ with gr.TabItem("Prediction"):
102
+ result.render()
103
+
104
+ with gr.TabItem("Explanation"):
105
+ explanation.render()
106
 
107
+ text_button.click(text_to_personality_explainer, inputs=text_input, outputs=[result, explanation])
108
 
109
+ main.launch(show_api=False)