remzicam commited on
Commit
a45605a
·
1 Parent(s): 62b8fda

initial commit

Browse files
Files changed (2) hide show
  1. app.py +246 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """XAI for Transformers Intent Classifier App."""
2
+
3
+ from collections import Counter
4
+ from itertools import count
5
+ from operator import itemgetter
6
+ from re import DOTALL, sub
7
+
8
+ import streamlit as st
9
+ from plotly.express import bar
10
+ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
11
+ pipeline)
12
+ from transformers_interpret import SequenceClassificationExplainer
13
+
14
# CSS injected below to hide Streamlit's default hamburger menu and footer.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
# Plotly chart config: hide the floating mode bar on every chart.
hide_plotly_bar = {"displayModeBar": False}
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# Hugging Face Hub model repo and pipeline task used throughout the app.
repo_id = "remzicam/privacy_intent"
task = "text-classification"
title = "XAI for Intent Classification and Model Interpretation"
# Centered, colored page title rendered as raw HTML.
st.markdown(
    f"<h1 style='text-align: center; color: #0068C9;'>{title}</h1>", unsafe_allow_html=True
)
28
+
29
+
30
# st.cache (used originally) has been deprecated and removed in modern
# Streamlit releases; st.cache_resource is the designated replacement for
# caching unserializable resources such as models and pipelines.
@st.cache_resource
def load_models():
    """
    Load the classifier and its explainer once and cache them across reruns.

    Downloads the model and tokenizer from the Hugging Face Hub, wraps them in
    a text-classification pipeline, and builds a model-interpretation object.

    Returns:
        tuple: (privacy_intent_pipe, cls_explainer).
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        repo_id, low_cpu_mem_usage=True  # needs `accelerate`; lowers peak RAM on load
    )
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # NOTE(review): return_all_scores is deprecated in newer transformers in
    # favor of top_k=None, but the two produce differently nested outputs and
    # callers index with [0] — kept as-is deliberately; revisit with callers.
    privacy_intent_pipe = pipeline(
        task, model=model, tokenizer=tokenizer, return_all_scores=True
    )
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    return privacy_intent_pipe, cls_explainer
48
+
49
+
50
# Build the cached pipeline and explainer once at startup (reused across reruns).
privacy_intent_pipe, cls_explainer = load_models()
51
+
52
+
53
def label_probs_figure_creater(input_text: str):
    """
    Classify *input_text* and build a bar chart of per-label probabilities.

    Args:
        input_text (str): The text to analyze.

    Returns:
        tuple: (plotly figure, label with the highest probability).
    """
    label_scores = sorted(privacy_intent_pipe(input_text)[0], key=itemgetter("score"))
    # The ascending sort puts the most probable label last.
    prediction_label = label_scores[-1]["label"]
    fig = bar(
        label_scores,
        x="score",
        y="label",
        color="score",
        color_continuous_scale="rainbow",
        width=600,
        height=400,
    )
    # Axis styling: both axes share line/grid settings; only the fonts differ.
    x_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Calibri"),
    )
    y_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Times New Roman"),
    )
    fig.update_layout(
        title="Model Prediction Probabilities for Each Label",
        xaxis_title="",
        yaxis_title="",
        xaxis=x_axis_style,
        yaxis=y_axis_style,
        plot_bgcolor="white",
        title_x=0.5,
    )
    return fig, prediction_label
98
+
99
+
100
def xai_attributions_html(input_text: str):
    """
    Compute per-token attributions for *input_text* and build an HTML view.

    Runs the explainer on the input, renders its attribution visualization,
    and strips special-token markers plus the first four legend table cells
    from the generated HTML.

    Args:
        input_text (str): The text to explain.

    Returns:
        tuple: (word attributions from the explainer, cleaned HTML string).
    """
    word_attributions = cls_explainer(input_text)
    html = cls_explainer.visualize().data
    # Remove the special-token markers emitted by the visualizer.
    html = html.replace("#s", "").replace("#/s", "")
    # Drop the first four header and value cells of the legend table.
    # Keyword arguments: passing count/flags positionally to re.sub is
    # deprecated as of Python 3.13 (and easy to misread as a flags slot).
    html = sub("<th.*?/th>", "", html, count=4, flags=DOTALL)
    html = sub("<td.*?/td>", "", html, count=4, flags=DOTALL)
    return word_attributions, html
122
+
123
+
124
def explanation_intro(prediction_label: str):
    """
    Generate the model-explanation markdown for the predicted label.

    Args:
        prediction_label (str): The label that the model predicted.

    Returns:
        str: Markdown text explaining how to read the attribution colors.
    """
    # (Removed a stale "#TODO: write docstring" — the docstring already existed.)
    return f"""The model predicted the given sentence as **:blue[{prediction_label}]**.
    The figure below shows the contribution of each token to this decision.
    **:green[Green]** tokens indicate a **positive contribution**, while **:red[red]** tokens indicate a **negative** contribution.
    The **bolder** the color, the greater the value."""
138
+
139
+
140
def explanation_viz(prediction_label: str, word_attributions):
    """
    Build a markdown sentence naming the token that drove the prediction.

    Args:
        prediction_label (str): The label that the model predicted.
        word_attributions: List of (word, attribution score) tuples.

    Returns:
        str: A single markdown sentence.
    """
    # Locate the word paired with the largest attribution score
    # (first occurrence wins on ties, matching max()'s behavior).
    attribution_scores = [pair[1] for pair in word_attributions]
    top_index = attribution_scores.index(max(attribution_scores))
    top_attention_word = word_attributions[top_index][0]
    return f"""The word **_{top_attention_word}_** is the biggest driver for the decision of the model as **:blue[{prediction_label}]**."""
154
+
155
+
156
def word_attributions_dict_creater(word_attributions):
    """
    Convert explainer output into a plotly-friendly dict of words/scores/colors.

    Drops the special tokens at both ends, removes strings of one character or
    fewer, reverses the order (so plotly draws the last token at the top),
    colors positive scores light green and negative scores red (the maximum
    score gets dark green), and suffixes duplicated words with a running
    counter so every bar label stays unique.

    Args:
        word_attributions: List of (word, attribution score) tuples from the
            model explainer.

    Returns:
        dict: Keys "word", "score", and "colors".
    """
    # Strip the leading/trailing special tokens (e.g. <s> ... </s>).
    word_attributions = word_attributions[1:-1]
    # Remove strings of one character or fewer (punctuation / sub-word debris).
    word_attributions = [i for i in word_attributions if len(i[0]) > 1]
    # Guard: nothing left to plot — return empty series instead of letting the
    # zip() unpack below raise ValueError on an empty list.
    if not word_attributions:
        return {"word": [], "score": (), "colors": []}
    word_attributions.reverse()
    words, scores = zip(*word_attributions)
    # Colorize positive and negative scores.
    colors = ["red" if x < 0 else "lightgreen" for x in scores]
    # Darker tone for the maximum score.
    max_index = scores.index(max(scores))
    colors[max_index] = "darkgreen"
    # Numerate duplicated strings so plotly bars don't collapse onto one label.
    c = Counter(words)
    iters = {k: count(1) for k, v in c.items() if v > 1}
    words_ = [x + "_" + str(next(iters[x])) if x in iters else x for x in words]
    # plotly accepts dictionaries
    return {
        "word": words_,
        "score": scores,
        "colors": colors,
    }
188
+
189
+
190
def attention_score_figure_creater(word_attributions_dict):
    """
    Draw a horizontal bar chart of per-word attention scores.

    Args:
        word_attributions_dict: Dict with keys "word", "score", and "colors".

    Returns:
        A plotly figure object.
    """
    figure = bar(word_attributions_dict, x="score", y="word", width=400, height=500)
    # Apply the precomputed per-bar colors (green/red attribution coding).
    figure.update_traces(marker_color=word_attributions_dict["colors"])
    # Axis styling: identical line/grid settings, different tick fonts.
    x_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Calibri"),
    )
    y_axis_style = dict(
        showline=True,
        showgrid=True,
        linecolor="black",
        tickfont=dict(family="Times New Roman"),
    )
    figure.update_layout(
        title="Word-Attention Score",
        xaxis_title="",
        yaxis_title="",
        xaxis=x_axis_style,
        yaxis=y_axis_style,
        plot_bgcolor="white",
        title_x=0.5,
    )
    return figure
226
+
227
+
228
# --- UI: input form and, on submit, the full explanation flow ---------------
form = st.form(key="intent-form")
input_text = form.text_area(
    label="Text",
    value="At any time during your use of the Services, you may decide to share some information or content publicly or privately.",
)
submit = form.form_submit_button("Submit")

if submit:
    # 1) Classify and show per-label probability bars.
    label_probs_figure, prediction_label = label_probs_figure_creater(input_text)
    st.plotly_chart(label_probs_figure, config=hide_plotly_bar)
    # 2) Explain how to read the attribution colors.
    explanation_general = explanation_intro(prediction_label)
    st.info(explanation_general)
    # 3) Token-level attribution visualization (raw HTML from the explainer).
    word_attributions, html = xai_attributions_html(input_text)
    st.markdown(html, unsafe_allow_html=True)
    # 4) Call out the single most influential word.
    explanation_specific = explanation_viz(prediction_label, word_attributions)
    st.info(explanation_specific)
    # 5) Bar chart of per-word attention scores.
    word_attributions_dict = word_attributions_dict_creater(word_attributions)
    attention_score_figure = attention_score_figure_creater(word_attributions_dict)
    st.plotly_chart(attention_score_figure, config=hide_plotly_bar)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
--find-links https://download.pytorch.org/whl/torch_stable.html
accelerate
plotly
streamlit
torch==1.13.1+cpu
transformers
transformers-interpret