|
from transformers_interpret import SequenceClassificationExplainer |
|
from captum.attr import visualization as viz |
|
import html |
|
|
|
|
|
class CustomExplainer(SequenceClassificationExplainer):
    """Sequence-classification explainer with extra HTML visualizations.

    Adds to ``SequenceClassificationExplainer``:

    * ``visualize`` -- token-level Captum visualization, optionally saved to disk;
    * ``merge_attributions`` -- folds WordPiece ``##`` continuation tokens back
      into whole words, summing their scores;
    * ``visualize_wordwise`` -- writes a standalone HTML page coloring each word
      by the sign and magnitude of its attribution.
    """

    # (hue, saturation) used for negative ("disagreement") and non-negative
    # ("agreement") attributions, both for word boxes and for the legend.
    _NEGATIVE_HSL = (5, "100%")
    _POSITIVE_HSL = (147, "50%")

    def __init__(self, model, tokenizer, **kwargs):
        """Create the explainer.

        Args:
            model: The sequence-classification model to explain.
            tokenizer: The tokenizer matching ``model``.
            **kwargs: Any extra keyword arguments accepted by
                ``SequenceClassificationExplainer.__init__``; forwarded unchanged.
        """
        super().__init__(model, tokenizer, **kwargs)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """Visualize the word attributions of the most recent explainer call.

        In a notebook the returned object renders inline. If ``html_filepath``
        is given, the visualization is also written there as a UTF-8 HTML file
        (``.html`` is appended when the path lacks it).

        Args:
            html_filepath: Optional output path for the HTML file.
            true_class: The known true class of the text, if any. Defaults to
                ``self.selected_index`` (or, for single-node output models, the
                rounded prediction probability).

        Returns:
            The object returned by ``captum.attr.visualization.visualize_text``.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            # With a single output node the rounded probability is used as the
            # predicted, attributed and (by default) true class.
            rounded = round(float(self.pred_probs))
            if true_class is None:
                true_class = rounded
            predicted_class = rounded
            attr_class = rounded
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        # Named ``rendered`` rather than ``html`` so the ``html`` module is not
        # shadowed inside this method.
        rendered = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath += ".html"
            # The page declares UTF-8, so write UTF-8 regardless of the
            # platform's default text encoding.
            with open(html_filepath, "w", encoding="utf-8") as html_file:
                html_file.write("<meta charset='UTF-8'>" + rendered.data)

        return rendered

    def merge_attributions(self, token_level_attributions):
        """Merge WordPiece continuation tokens into whole words.

        A token starting with ``##`` is appended (without that prefix) to the
        preceding word, and its score is added to that word's score. A leading
        ``##`` token with no preceding word is kept as its own word.

        Args:
            token_level_attributions: Iterable of ``(token, score, ...)``
                sequences; only the first two items of each are used.

        Returns:
            List of ``(word, score)`` tuples in input order.
        """
        words = []
        scores = []
        for elem in token_level_attributions:
            token, score = elem[0], elem[1]
            is_continuation = token.startswith("##")
            if is_continuation and words:
                words[-1] = words[-1] + token[2:]
                # Rebind rather than ``+=`` so array/tensor scores from the
                # input are never mutated in place.
                scores[-1] = scores[-1] + score
            else:
                words.append(token[2:] if is_continuation else token)
                scores.append(score)
        return list(zip(words, scores))

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        """Write a standalone HTML page coloring each word of ``sentence``.

        Runs the explainer on ``sentence`` and merges WordPiece pieces into
        words (``merge_attributions``). Each word is colored red for a negative
        score and green otherwise. Larger absolute scores (min-max scaled
        across the sentence) get more saturated, darker backgrounds. A legend
        shows the predicted and true labels.

        Args:
            sentence: Text to explain.
            path: Destination file, written as UTF-8 (overwritten if present).
            true_class: Known true label, compared with the predicted label.
        """
        # Explain this sentence *before* reading ``predicted_class_name``: it
        # is explainer state (``visualize`` reads it next to the last call's
        # ``input_ids``), so reading it first would describe the previous input.
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name

        if pred_class == true_class:
            legend_sent = f"against {pred_class}"
        else:
            legend_sent = f"against {pred_class} and towards {true_class}"

        words_html = self._wordwise_spans(attribution_weights)
        final_html = self._wordwise_page(pred_class, true_class, legend_sent, words_html)

        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)

    @classmethod
    def _wordwise_spans(cls, attribution_weights):
        """Render ``(word, weight)`` pairs as colored, HTML-escaped word boxes."""
        magnitudes = [float(abs(w)) for _, w in attribution_weights]
        if not magnitudes:
            return ""
        min_weight = min(magnitudes)
        weight_range = max(magnitudes) - min_weight

        spans = []
        for (word, weight), magnitude in zip(attribution_weights, magnitudes):
            hue, sat = cls._NEGATIVE_HSL if weight < 0 else cls._POSITIVE_HSL
            if weight_range > 0:
                # Min-max scale |weight| into [0, 1] so lightness stays in [50%, 100%].
                scaled_weight = (magnitude - min_weight) / weight_range
            else:
                # No spread (e.g. a single word): full color unless the weight is zero.
                scaled_weight = 1.0 if magnitude > 0 else 0.0
            lightness = f"{100 - 50 * scaled_weight}%"
            color = f"hsl({hue},{sat},{lightness})"
            # Escape tokens: they come from user text and may contain '<' or '&'.
            spans.append(
                f"<span class='word-box' style='background-color: {color};'>"
                f"{html.escape(str(word))}</span><span> </span>"
            )
        return "".join(spans)

    @classmethod
    def _wordwise_page(cls, pred_class, true_class, legend_sent, words_html):
        """Wrap pre-rendered word boxes in a complete HTML page with a legend."""
        pred_label = html.escape(str(pred_class))
        true_label = html.escape(str(true_class))
        legend = html.escape(str(legend_sent))
        neg_hue, neg_sat = cls._NEGATIVE_HSL
        pos_hue, pos_sat = cls._POSITIVE_HSL
        return f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8" />
    <title>Attention Visualization</title>
    <style>
        span {{
            font-family: sans-serif;
            font-size: 16px;
        }}

        /* Color legend */
        .color-legend {{
            display: inline-block;
            margin: 10px 0;
            padding: 10px 15px;
            border: 1px solid #ccc;
            border-radius: 5px;
        }}

        .word-box {{
            display: inline-block;
            border-radius: 5px;
            padding: 0.2em;
        }}

        .color-legend span {{
            display: inline-block;
            margin: 0 5px;
        }}

        .positive-weight {{
            color: green;
        }}

        .negative-weight {{
            color: red;
        }}

        .color-legend span:first-child {{
            margin-left: 0;
        }}
    </style>
</head>
<body>
    <div class="color-legend">
        <p>PREDICTED LABEL: <b>{pred_label}</b><br>TRUE LABEL: <b>{true_label}</b></p>
        <p><span class='word-box' style='background-color: hsl({neg_hue},{neg_sat},50%);'>Disagreement</span> ({legend})</p>
        <p><span class='word-box' style='background-color: hsl({pos_hue},{pos_sat},50%);'>Agreement</span> (towards {pred_label})</p>
    </div>
    <div>{words_html}</div>
</body>
</html>
"""
|
|