# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" that preceded this file
# appear to be a Hugging Face Spaces status banner captured during extraction,
# not application source code.
"""XAI for Transformers Intent Classifier App.""" | |
from collections import Counter | |
from itertools import count | |
from operator import itemgetter | |
from re import DOTALL, sub | |
import streamlit as st | |
from plotly.express import bar | |
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, | |
pipeline) | |
from transformers_interpret import SequenceClassificationExplainer | |
# CSS injected below to hide Streamlit's default hamburger menu and footer.
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
# Plotly chart config: suppress the floating mode bar on every chart.
hide_plotly_bar = {"displayModeBar": False}
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# Hugging Face Hub repo of the fine-tuned classifier and the pipeline task name.
repo_id = "remzicam/privacy_intent"
task = "text-classification"
title = "XAI for Intent Classification and Model Interpretation"
# Render the centered, colored page title as raw HTML.
st.markdown(
    f"<h1 style='text-align: center; color: #0068C9;'>{title}</h1>", unsafe_allow_html=True
)
@st.cache_resource  # requires streamlit>=1.18 — cache the heavy model load across reruns
def load_models():
    """Load the intent classifier pipeline and its explainer.

    Builds the sequence-classification model and tokenizer from the
    HuggingFace Hub, wraps them in a text-classification pipeline, and
    creates a transformers-interpret explainer for the same model.

    Streamlit re-executes the whole script on every user interaction, so
    without caching the model would be reloaded on each submit; the
    ``st.cache_resource`` decorator makes the load happen once per process.

    Returns:
        tuple: (privacy_intent_pipe, cls_explainer)
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        repo_id, low_cpu_mem_usage=True
    )
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # NOTE: return_all_scores is deprecated upstream in favor of top_k=None,
    # but switching would change the output nesting; keep the legacy flag so
    # callers that index `pipe(text)[0]` keep working.
    privacy_intent_pipe = pipeline(
        task, model=model, tokenizer=tokenizer, return_all_scores=True
    )
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    return privacy_intent_pipe, cls_explainer
# Module-level singletons used by the helper functions below.
privacy_intent_pipe, cls_explainer = load_models()
def label_probs_figure_creater(input_text: str):
    """Classify *input_text* and chart the per-label probabilities.

    Args:
        input_text (str): Sentence to run through the classifier.

    Returns:
        tuple: (plotly bar figure of label scores, top-scoring label).
    """
    # Pipeline returns one list of {"label", "score"} dicts per input text.
    ranked_outputs = sorted(privacy_intent_pipe(input_text)[0], key=itemgetter("score"))
    # Highest-probability label sits at the end after the ascending sort.
    predicted_label = ranked_outputs[-1]["label"]

    fig = bar(
        ranked_outputs,
        x="score",
        y="label",
        color="score",
        color_continuous_scale="rainbow",
        width=600,
        height=400,
    )
    # Shared axis styling; only the tick font differs between the two axes.
    axis_style = dict(showline=True, showgrid=True, linecolor="black")
    fig.update_layout(
        title="Model Prediction Probabilities for Each Label",
        xaxis_title="",
        yaxis_title="",
        xaxis=dict(tickfont=dict(family="Calibri"), **axis_style),
        yaxis=dict(tickfont=dict(family="Times New Roman"), **axis_style),
        plot_bgcolor="white",
        title_x=0.5,
    )
    return fig, predicted_label
def xai_attributions_html(input_text: str):
    """Explain the classifier's decision for *input_text*.

    Computes per-token attribution scores with the explainer, renders the
    explainer's HTML visualization, and strips special tokens plus a few
    boilerplate table cells from the markup.

    Args:
        input_text (str): Sentence to explain.

    Returns:
        tuple: (filtered word attributions, cleaned HTML string).
    """
    attributions = cls_explainer(input_text)
    # Drop the special tokens at both ends, then any tokens of length <= 1.
    attributions = [pair for pair in attributions[1:-1] if len(pair[0]) > 1]

    markup = cls_explainer.visualize().data
    # Scrub leftover special-token markers from the rendered HTML.
    for marker in ("#s", "#/s"):
        markup = markup.replace(marker, "")
    # Strip the first four header and data cells of the legend table.
    for cell_pattern in ("<th.*?/th>", "<td.*?/td>"):
        markup = sub(cell_pattern, "", markup, 4, DOTALL)
    return attributions, markup + "<br>"
def explanation_intro(prediction_label: str):
    """Build the HTML blurb that introduces the attribution visualization.

    Args:
        prediction_label (str): Label the model predicted.

    Returns:
        str: HTML snippet for st.markdown with unsafe_allow_html=True.
    """
    blurb = f"""<div style="background-color: lightblue;
color: rgb(0, 66, 128);">The model predicted the given sentence as <span style="color: black"><strong>'{prediction_label}'</strong></span>.
The figure below shows the contribution of each token to this decision.
<span style="color: darkgreen"><strong> Green </strong></span> tokens indicate a <strong>positive </strong> contribution, while <span style="color: red"><strong> red </strong></span> tokens indicate a <strong>negative</strong> contribution.
The <strong>bolder</strong> the color, the greater the value.</div><br>"""
    return blurb
def explanation_viz(prediction_label: str, word_attributions):
    """Summarize which token most influenced the model's prediction.

    Args:
        prediction_label (str): Label the model predicted.
        word_attributions: Sequence of (token, attribution score) pairs.

    Returns:
        str: Markdown sentence naming the highest-attribution token.
    """
    strongest_token, _ = max(word_attributions, key=lambda pair: pair[1])
    return f"""The token **_'{strongest_token}'_** is the biggest driver for the decision of the model as **'{prediction_label}'**"""
def word_attributions_dict_creater(word_attributions):
    """Convert explainer output into a plotly-ready bar-chart dict.

    Reverses the (word, score) pairs so the chart reads top-down, assigns a
    color per score (red negative, light green positive, dark green for the
    maximum), and suffixes duplicated words with a counter so every y-axis
    tick label is unique.

    Fix: the previous version called ``word_attributions.reverse()``, mutating
    the caller's list in place; this version iterates a reversed view instead.

    Args:
        word_attributions: List of (word, attribution score) tuples.

    Returns:
        dict: keys "word" (list of str), "score" (tuple of float), and
        "colors" (list of str), consumable by plotly express.
    """
    # Iterate in reverse without mutating the caller's list.
    words, scores = zip(*reversed(word_attributions))
    # Colorize: negative scores red, non-negative light green.
    colors = ["red" if score < 0 else "lightgreen" for score in scores]
    # Darker tone for the single largest score (first occurrence on ties).
    colors[scores.index(max(scores))] = "darkgreen"
    # Suffix duplicated words with "_1", "_2", ... so labels are unique.
    counts = Counter(words)
    iters = {word: count(1) for word, n in counts.items() if n > 1}
    unique_words = [
        f"{word}_{next(iters[word])}" if word in iters else word for word in words
    ]
    # plotly.express.bar accepts a dict of column name -> sequence.
    return {
        "word": unique_words,
        "score": scores,
        "colors": colors,
    }
def attention_score_figure_creater(word_attributions_dict):
    """Draw a horizontal bar chart of per-word attention scores.

    Args:
        word_attributions_dict (dict): Mapping with "word", "score", and
            "colors" keys; one bar is drawn per word, colored per "colors".

    Returns:
        A plotly figure object.
    """
    fig = bar(word_attributions_dict, x="score", y="word", width=400, height=500)
    # Apply the precomputed per-bar colors.
    fig.update_traces(marker_color=word_attributions_dict["colors"])
    # Shared axis styling; only the tick font differs between the two axes.
    axis_style = dict(showline=True, showgrid=True, linecolor="black")
    fig.update_layout(
        title="Word-Attention Score",
        xaxis_title="",
        yaxis_title="",
        xaxis=dict(tickfont=dict(family="Calibri"), **axis_style),
        yaxis=dict(tickfont=dict(family="Times New Roman"), **axis_style),
        plot_bgcolor="white",
        title_x=0.5,
    )
    return fig
# --- UI: input form and results ---
form = st.form(key="intent-form")
input_text = form.text_area(
    label="Text",
    value="At any time during your use of the Services, you may decide to share some information or content publicly or privately.",
)
submit = form.form_submit_button("Submit")
if submit:
    # 1) Classify and show per-label probabilities.
    label_probs_figure, prediction_label = label_probs_figure_creater(input_text)
    st.plotly_chart(label_probs_figure, config=hide_plotly_bar)
    # 2) Explain what the colors in the attribution HTML mean.
    explanation_general = explanation_intro(prediction_label)
    st.markdown(explanation_general, unsafe_allow_html=True)
    # 3) Token-level attributions (runs the explainer; slow), rendered as HTML.
    with st.spinner():
        word_attributions, html = xai_attributions_html(input_text)
    st.markdown(html, unsafe_allow_html=True)
    # 4) Call out the single most influential token.
    explanation_specific = explanation_viz(prediction_label, word_attributions)
    st.info(explanation_specific)
    # 5) Bar chart of per-token attention scores.
    word_attributions_dict = word_attributions_dict_creater(word_attributions)
    attention_score_figure = attention_score_figure_creater(word_attributions_dict)
    st.plotly_chart(attention_score_figure, config=hide_plotly_bar)