# Captured from a Hugging Face Space (status: Running). File size: 8,629 Bytes.
# The page's blame gutter (commits 094135a, f34a8cd, 98b120b, 1c7eb00) and its
# line-number gutter (1-199) were captured with the file; they are not code.
import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
# Apply the "ggplot" style sheet to every figure drawn through pyplot.
plt.style.use("ggplot")

# Checkpoint names that can be passed to the explanation function.
MODELS = [
    'TehranNLP-org/bert-base-uncased-cls-sst2',
    'TehranNLP-org/bert-large-sst2',
    "WillHeld/roberta-base-sst2",
]
def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart with one bar per token.

    Bars are green (#019875) for scores >= 0 and red (#B8293D) otherwise.
    The y axis is inverted so the first token is at the top, and the token
    labels sit on the right-hand side in the colour of their bar. The x axis
    carries exactly two ticks, at its limits, labelled with ``label_names``.

    Args:
        tokens: Sequence of token strings, one per bar.
        logits: Per-token scores, same length as ``tokens``. Any array-like
            is accepted. An empty sequence produces an empty chart instead
            of raising.
        label_names: Two labels for the left and right x-axis ticks (the
            first is coloured red, the second green).
        title: Title placed above the chart.
        file_name: Accepted for backward compatibility. It is not used and
            nothing is written to disk.

    Returns:
        The matplotlib ``Figure`` that was drawn. It is also pyplot's current
        figure, so callers that keep using ``plt`` afterwards still work.
        Callers may pass it to ``plt.close`` once they are done with it.
    """
    scores = np.asarray(logits)
    fig = plt.figure(figsize=(4.5, 5))
    ax = fig.gca()

    bar_colors = ["#019875" if is_positive else "#B8293D" for is_positive in (scores >= 0)]
    positions = range(len(tokens))
    # An empty colour list is replaced by None so that drawing zero bars
    # never depends on how matplotlib treats an empty colour sequence.
    ax.barh(positions, scores, color=bar_colors or None)
    ax.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
    ax.invert_yaxis()

    # np.max / np.min raise on zero-size input, so guard the reductions.
    has_scores = scores.size > 0
    max_limit = (np.max(np.abs(scores)) if has_scores else 0.0) + 0.2
    # When every score is positive the negative half of the axis is dropped.
    min_limit = -0.01 if has_scores and np.min(scores) > 0 else -max_limit
    ax.set_xlim(min_limit, max_limit)
    ax.set_xticks([min_limit, max_limit])
    ax.set_xticklabels(label_names, fontsize=14, fontweight="bold")

    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.yaxis.tick_right()
    for tick_label, color in zip(ax.get_yticklabels(), bar_colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")
    for tick_label, color in zip(ax.get_xticklabels(), ["#B8293D", "#019875"]):
        tick_label.set_color(color)

    ax.set_title(title)
    fig.tight_layout()
    return fig
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """
    Render one row of tokens as HTML, shading each token by its importance.

    importance: (sent_len) array-like of signed scores, one per token.
    tokenized_text: sequence of token strings, same length as importance.
    discrete: if True, colour by rank (scaled by 1 / len / 1.6) instead of
        by the normalized value.
    prefix: text placed at the start of the row, before the first token.
    no_cls_sep: if True, drop the first and last entry of both inputs.

    Returns an HTML string: a <pre> element holding one <span> per token.
    Scores >= 0 use the 'Greens' colormap, negative ones 'Reds'. Scores are
    divided by their largest magnitude and by 1.5, so the colormap input
    never exceeds 1/1.5 and the text colour stays black.

    A row with no tokens yields a <pre> holding only the prefix. An all-zero
    row is left unscaled instead of being divided by zero.
    """
    scores = np.asarray(importance)
    tokens = list(tokenized_text)
    if no_cls_sep:
        scores = scores[1:-1]
        tokens = tokens[1:-1]

    opening = "<pre style='color:black; padding: 3px;'>" + prefix
    if not tokens:
        # Nothing to colour; also avoids reducing over a zero-size array.
        return opening + "</pre>"

    peak = np.abs(scores).max() if scores.size else 0.0
    if peak != 0:
        scores = scores / peak / 1.5  # Normalize
    if discrete:
        scores = np.argsort(np.argsort(scores)) / len(scores) / 1.6

    # Look the colormaps up once instead of once per token.
    greens = matplotlib.colormaps.get_cmap('Greens')
    reds = matplotlib.colormaps.get_cmap('Reds')

    pieces = [opening]
    for i, token in enumerate(tokens):
        score = scores[i]
        rgba = greens(score) if score >= 0 else reds(np.abs(score))
        span_style = (
            f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); "
            "color:black; border-radius: 5px; padding: 3px;"
            "font-weight: 800;"
        )
        # '&' is escaped first so it cannot start an entity; angle brackets
        # are shown as square brackets so tokens never open a tag.
        shown = token.replace("&", "&amp;").replace('<', "[").replace(">", "]")
        pieces.append(f"<span style='{span_style}'>{shown}</span> ")
    pieces.append("</pre>")
    return "".join(pieces)
def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build the HTML preview of token importances for one row of the frame.

    Args:
        decompx_outputs_df: DataFrame with a "tokens" column and, optionally,
            the columns "importance_last_layer_aggregated" and
            "importance_last_layer_classifier". A missing column, or a None
            cell at ``idx``, is skipped.
        idx: Positional index of the row to render.
        discrete: Passed to ``print_importance`` for the aggregated row only.
            Classifier rows are always rendered with ``discrete=False``.

    Returns:
        An HTML string: a white, scrollable ``<div>`` containing one
        ``print_importance`` row for the aggregated importance (its first
        row, ``[0, :]``) and one row per label for the classifier importance
        (``[:, label]``).
    """
    df = decompx_outputs_df
    no_cls_sep = False
    rows = []
    for col in ("importance_last_layer_aggregated", "importance_last_layer_classifier"):
        if col not in df:
            continue
        scores = df[col].iloc[idx]
        if scores is None:
            continue
        tokens = df["tokens"].iloc[idx]
        name = col.split('_')[-1]
        if "classifier" in col:
            # One row per output label.
            for label in range(scores.shape[-1]):
                rows.append(print_importance(
                    scores[:, label],
                    tokens,
                    prefix=f"{name} Label{label}:".ljust(20),
                    no_cls_sep=no_cls_sep,
                    discrete=False,
                ))
        else:
            rows.append(print_importance(
                scores[0, :],
                tokens,
                prefix=f"{name}:".ljust(20),
                no_cls_sep=no_cls_sep,
                discrete=discrete,
            ))
    # Close the wrapper so the fragment is well-formed on its own.
    return (
        "<div style='overflow:auto; background-color:white; padding: 10px;'>"
        + "".join(rows)
        + "</div>"
    )
def run_decompx(text, model):
    """
    Provide DecompX Token Explanation of Model on Text

    Args:
        text: The sentence to explain.
        model: Checkpoint name to load. A name containing "roberta" selects
            RobertaForSequenceClassification; otherwise a name containing
            "bert" selects BertForSequenceClassification.

    Returns:
        A tuple ``(plt, html)``: the pyplot module, whose current figure is
        the bar chart for the predicted label, and the HTML preview string
        produced by ``print_preview``.

    Raises:
        ValueError: if no model name is given, if the name matches neither
            supported family, or if ``text`` yields no tokens between the
            first and last token of the encoded sentence.
    """
    if not isinstance(model, str) or not model:
        raise ValueError("A model name must be selected.")
    # Resolve the model class first so an unsupported name fails before
    # anything is loaded.
    if "roberta" in model:
        model_cls = RobertaForSequenceClassification
    elif "bert" in model:
        model_cls = BertForSequenceClassification
    else:
        raise ValueError(f"Not implemented model: {model}")

    # A second, fixed sentence is always batched with the input.
    # NOTE(review): presumably this keeps the batch axis at size 2 so the
    # .squeeze() calls below cannot drop it -- confirm before removing.
    sentences = [text, "nothing"]
    decompx_config = DecompXConfig(
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    )

    # LOAD MODEL AND TOKENIZER
    tokenizer = AutoTokenizer.from_pretrained(model)
    tokenized_sentence = tokenizer(sentences, return_tensors="pt", padding=True)
    # Number of attended positions per sentence, as plain ints.
    batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1).tolist()
    if batch_lengths[0] <= 2:
        # The first and last token are dropped before plotting, so nothing
        # would be left to explain.
        raise ValueError("Text must contain at least one token to explain.")

    # NOTE(review): the weights are loaded again on every call; consider
    # caching the loaded model per checkpoint name if latency matters.
    classifier = model_cls.from_pretrained(model)

    # RUN DECOMPX
    classifier.eval()
    with torch.no_grad():
        logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = classifier(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=decompx_config,
        )

    decompx_outputs = {
        "tokens": [
            tokenizer.convert_ids_to_tokens(ids[:length])
            for ids, length in zip(tokenized_sentence["input_ids"], batch_lengths)
        ],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),  # Last layer & only CLS -> (batch, emb_dim)
    }

    ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes)
    # Trim the padded positions of each sentence.
    importance = [scores[:length, :] for scores, length in zip(importance, batch_lengths)]
    decompx_outputs["importance_last_layer_classifier"] = importance

    ### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
    importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
    importance = [scores[:, :length, :length] for scores, length in zip(importance, batch_lengths)]
    decompx_outputs["importance_all_layers_aggregated"] = importance

    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    # Only the user's sentence (row 0) is plotted.
    idx = 0
    row = decompx_outputs_df.iloc[idx]
    pred_label = np.argmax(row["logits"], axis=-1)
    scores = row["importance_last_layer_classifier"][:, pred_label]
    # Drop the first and last token from the chart.
    tokens = row["tokens"][1:-1]
    scores = scores[1:-1]
    peak = np.max(np.abs(scores))
    if peak > 0:  # leave an all-zero row as is instead of dividing by zero
        scores = scores / peak

    # pyplot keeps every figure alive until it is closed; drop the ones left
    # over from earlier calls so they do not accumulate.
    plt.close("all")
    plot_clf(tokens, scores, ['-','+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
    return plt, print_preview(decompx_outputs_df)
# Gradio UI: a text box and a model selector in, a plot and an HTML panel out.
demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Model", choices=MODELS),
    ],
    outputs=["plot", "html"],
    examples=[
        ["a good piece of work more often than not.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
        ["a good piece of work more often than not.", "TehranNLP-org/bert-large-sst2"],
        ["a good piece of work more often than not.", "WillHeld/roberta-base-sst2"],
        ["A deep and meaningful film.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
    ],
    cache_examples=True,
    title="DecompX Demo",
    description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
)
# Start the web server for the interface.
demo.launch()