Spaces:
Sleeping
Sleeping
update files
Browse files- app.py +15 -13
- requirements.txt +2 -1
app.py
CHANGED
@@ -113,6 +113,10 @@ def xai_attributions_html(input_text: str):
|
|
113 |
"""
|
114 |
|
115 |
word_attributions = cls_explainer(input_text)
|
|
|
|
|
|
|
|
|
116 |
html = cls_explainer.visualize().data
|
117 |
html = html.replace("#s", "")
|
118 |
html = html.replace("#/s", "")
|
@@ -121,7 +125,7 @@ def xai_attributions_html(input_text: str):
|
|
121 |
return word_attributions, html
|
122 |
|
123 |
|
124 |
-
def explanation_intro(prediction_label: str):
|
125 |
"""
|
126 |
generates model explanaiton markdown from prediction label of the model.
|
127 |
|
@@ -131,7 +135,7 @@ def explanation_intro(prediction_label: str): #TODO: write docstring
|
|
131 |
Returns:
|
132 |
A string
|
133 |
"""
|
134 |
-
return f"""The model predicted the given sentence as **:blue[{prediction_label}]**.
|
135 |
The figure below shows the contribution of each token to this decision.
|
136 |
**:green[Green]** tokens indicate a **positive contribution**, while **:red[red]** tokens indicate a **negative** contribution.
|
137 |
The **bolder** the color, the greater the value."""
|
@@ -150,7 +154,7 @@ def explanation_viz(prediction_label: str, word_attributions):
|
|
150 |
A string
|
151 |
"""
|
152 |
top_attention_word = max(word_attributions, key=itemgetter(1))[0]
|
153 |
-
return f"""The
|
154 |
|
155 |
|
156 |
def word_attributions_dict_creater(word_attributions):
|
@@ -164,9 +168,6 @@ def word_attributions_dict_creater(word_attributions):
|
|
164 |
Returns:
|
165 |
A dictionary with the keys "word", "score", and "colors".
|
166 |
"""
|
167 |
-
word_attributions = word_attributions[1:-1]
|
168 |
-
# remove strings shorter than 1 chrachter
|
169 |
-
word_attributions = [i for i in word_attributions if len(i[0]) > 1]
|
170 |
word_attributions.reverse()
|
171 |
words, scores = zip(*word_attributions)
|
172 |
# colorize positive and negative scores
|
@@ -237,10 +238,11 @@ if submit:
|
|
237 |
st.plotly_chart(label_probs_figure, config=hide_plotly_bar)
|
238 |
explanation_general = explanation_intro(prediction_label)
|
239 |
st.info(explanation_general)
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
|
113 |
"""
|
114 |
|
115 |
word_attributions = cls_explainer(input_text)
|
116 |
+
#remove special tokens
|
117 |
+
word_attributions = word_attributions[1:-1]
|
118 |
+
# remove strings shorter than 1 chrachter
|
119 |
+
word_attributions = [i for i in word_attributions if len(i[0]) > 1]
|
120 |
html = cls_explainer.visualize().data
|
121 |
html = html.replace("#s", "")
|
122 |
html = html.replace("#/s", "")
|
|
|
125 |
return word_attributions, html
|
126 |
|
127 |
|
128 |
+
def explanation_intro(prediction_label: str):
|
129 |
"""
|
130 |
generates model explanaiton markdown from prediction label of the model.
|
131 |
|
|
|
135 |
Returns:
|
136 |
A string
|
137 |
"""
|
138 |
+
return f"""The model predicted the given sentence as **:blue['{prediction_label}']**.
|
139 |
The figure below shows the contribution of each token to this decision.
|
140 |
**:green[Green]** tokens indicate a **positive contribution**, while **:red[red]** tokens indicate a **negative** contribution.
|
141 |
The **bolder** the color, the greater the value."""
|
|
|
154 |
A string
|
155 |
"""
|
156 |
top_attention_word = max(word_attributions, key=itemgetter(1))[0]
|
157 |
+
return f"""The token **_'{top_attention_word}'_** is the biggest driver for the decision of the model as **:blue['{prediction_label}']**."""
|
158 |
|
159 |
|
160 |
def word_attributions_dict_creater(word_attributions):
|
|
|
168 |
Returns:
|
169 |
A dictionary with the keys "word", "score", and "colors".
|
170 |
"""
|
|
|
|
|
|
|
171 |
word_attributions.reverse()
|
172 |
words, scores = zip(*word_attributions)
|
173 |
# colorize positive and negative scores
|
|
|
238 |
st.plotly_chart(label_probs_figure, config=hide_plotly_bar)
|
239 |
explanation_general = explanation_intro(prediction_label)
|
240 |
st.info(explanation_general)
|
241 |
+
with st.spinner():
|
242 |
+
word_attributions, html = xai_attributions_html(input_text)
|
243 |
+
st.markdown(html, unsafe_allow_html=True)
|
244 |
+
explanation_specific = explanation_viz(prediction_label, word_attributions)
|
245 |
+
st.info(explanation_specific)
|
246 |
+
word_attributions_dict = word_attributions_dict_creater(word_attributions)
|
247 |
+
attention_score_figure = attention_score_figure_creater(word_attributions_dict)
|
248 |
+
st.plotly_chart(attention_score_figure, config=hide_plotly_bar)
|
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
|
|
|
|
2 |
accelerate
|
3 |
plotly
|
4 |
-
torch==1.13.1+cpu
|
5 |
transformers
|
6 |
transformers-interpret
|
|
|
1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
torch==1.13.1+cpu
|
3 |
+
streamlit==1.16.0
|
4 |
accelerate
|
5 |
plotly
|
|
|
6 |
transformers
|
7 |
transformers-interpret
|