# Uploaded by Detsutut — commit f388ec1 ("Upload 2 files"), 5.6 kB.
from transformers_interpret import SequenceClassificationExplainer
from captum.attr import visualization as viz
import html
class CustomExplainer(SequenceClassificationExplainer):
    """Sequence-classification explainer with word-level HTML rendering.

    Extends transformers-interpret's ``SequenceClassificationExplainer`` with:

    * ``merge_attributions`` - folds WordPiece continuation tokens (``##xx``)
      back into whole words, summing their attribution scores;
    * ``visualize_wordwise`` - writes a standalone, colour-coded HTML page of
      those word-level attributions.
    """

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """Visualize token attributions using captum's text visualizer.

        In a notebook the returned object is displayed inline. Otherwise pass a
        path as ``html_filepath`` and the visualization is also saved as an
        HTML file (``.html`` is appended when missing).

        Args:
            html_filepath: Optional destination for the HTML output.
            true_class: The known true class of the text, if any. When omitted,
                the predicted/explained class is reported in its place.

        Returns:
            The object returned by ``captum.attr.visualization.visualize_text``.
        """
        # Drop the "Ġ" word-boundary marker emitted by byte-level BPE tokenizers.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            # Single output node: every class slot is the rounded probability.
            rounded = round(float(self.pred_probs))
            if true_class is None:
                true_class = rounded
            predicted_class = rounded
            attr_class = rounded
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        # Named html_obj (not ``html``) so it does not shadow the stdlib module.
        html_obj = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath += ".html"
            # The page declares UTF-8, so encode it as UTF-8 explicitly instead
            # of relying on the platform default, which can reject or mangle
            # non-ASCII tokens.
            with open(html_filepath, "w", encoding="utf-8") as html_file:
                html_file.write("<meta charset='UTF-8'>" + html_obj.data)
        return html_obj

    def merge_attributions(self, token_level_attributions):
        """Merge WordPiece sub-tokens into whole words, summing their scores.

        A token starting with ``##`` is a continuation: its text (without the
        ``##`` prefix) is appended to the preceding word and its score is added
        to that word's score. A continuation token with no preceding word is
        kept as a word of its own instead of raising ``IndexError``.

        Args:
            token_level_attributions: Sequence of ``(token, score)`` pairs.

        Returns:
            List of ``(word, score)`` tuples in input order.
        """
        words = []
        scores = []
        for item in token_level_attributions:
            token, score = item[0], item[1]
            if token.startswith("##") and words:
                # Only the leading marker denotes a continuation; keep the rest.
                words[-1] += token[2:]
                scores[-1] += score
            else:
                words.append(token)
                scores.append(score)
        return list(zip(words, scores))

    @staticmethod
    def _weight_to_color(weight, min_weight, max_weight):
        """Return a CSS ``hsl(...)`` background colour for a signed attribution.

        Negative weights are red (hue 5, 100% saturation); non-negative weights
        are green (hue 147, 50% saturation). The magnitude is linearly min-max
        normalised into [0, 1] and mapped onto lightness 100% (weakest) down to
        50% (strongest).

        Args:
            weight: Signed attribution weight.
            min_weight: Smallest absolute weight in the set being rendered.
            max_weight: Largest absolute weight in the set being rendered.
        """
        hue, sat = (5, "100%") if weight < 0 else (147, "50%")
        span = max_weight - min_weight
        if span > 0:
            scaled = (abs(weight) - min_weight) / span
        else:
            # All magnitudes are equal (e.g. a single word), so no relative
            # ranking exists: show non-zero weights at full intensity.
            scaled = 1.0 if max_weight > 0 else 0.0
        # Keep lightness inside [50%, 100%] even for out-of-range inputs.
        scaled = min(max(scaled, 0.0), 1.0)
        return f"hsl({hue},{sat},{100 - 50 * scaled}%)"

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        """Explain ``sentence`` and save a word-level attribution page to ``path``.

        Each merged word is drawn on a background whose hue gives the sign of
        its attribution (green = towards the predicted class, red = against it)
        and whose intensity gives its relative magnitude.

        Args:
            sentence: Text to explain.
            path: Destination file for the HTML page (written as UTF-8).
            true_class: Known label of ``sentence``, shown in the legend.
        """
        # Run the explainer first: ``predicted_class_name`` is derived from the
        # most recently processed input, so reading it beforehand would report
        # the previous sentence's prediction (or fail on a fresh explainer).
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name

        # Escape everything interpolated into the page so labels or input text
        # containing '<', '&' or quotes cannot break (or inject into) the HTML.
        pred_label = html.escape(str(pred_class))
        true_label = html.escape(str(true_class))
        if pred_class == true_class:
            legend_sent = f"against {pred_label}"
        else:
            legend_sent = f"against {pred_label} and towards {true_label}"

        magnitudes = [abs(float(weight)) for _, weight in attribution_weights]
        min_weight = min(magnitudes, default=0.0)
        max_weight = max(magnitudes, default=0.0)

        attention_html = "".join(
            "<span class='word-box' style='background-color: "
            f"{self._weight_to_color(float(weight), min_weight, max_weight)};'>"
            f"{html.escape(str(word))}</span><span>&nbsp;</span>"
            for word, weight in attribution_weights
        )

        final_html = f"""
<!DOCTYPE html>
<html>
<head>
<title>Attention Visualization</title>
<style>
span {{
font-family: sans-serif;
font-size: 16px;
}}
</style>
<style>
/* Color legend */
.color-legend {{
display: inline-block;
margin: 10px 0;
padding: 10px 15px;
border: 1px solid #ccc;
border-radius: 5px;
}}
.word-box {{
display: inline-block;
border-radius: 5px;
padding: 0.2em;
}}
.color-legend span {{
display: inline-block;
margin: 0 5px;
}}
.positive-weight {{
color: green;
}}
.negative-weight {{
color: red;
}}
.color-legend span:first-child {{
margin-left: 0;
}}
</style>
<meta charset="utf-8" />
</head>
<body>
<div class="color-legend">
<p>PREDICTED LABEL: <b>{pred_label}</b><br>TRUE LABEL: <b>{true_label}</b></p>
<p><span class='word-box' style='background-color: hsl(5,100%,50%);'>Disagreement</span> ({legend_sent})</p>
<p><span class='word-box' style='background-color: hsl(147,50%,50%);'>Agreement</span> (towards {pred_label})</p>
</div>
<div>{attention_html}</div>
</body>
</html>
"""
        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)