import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

plt.style.use("ggplot")
MODELS = [
    "TehranNLP-org/bert-base-uncased-cls-sst2",
    "TehranNLP-org/bert-large-sst2",
    "WillHeld/roberta-base-sst2",
]
def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Horizontal bar chart of per-token contribution scores (file_name is accepted but currently unused)."""
    print(tokens)  # debug print of the tokens being plotted
    plt.figure(figsize=(4.5, 5))
    # Positive contributions in green, negative in red.
    colors = ["#019875" if l else "#B8293D" for l in (logits >= 0)]
    plt.barh(range(len(tokens)), logits, color=colors)
    plt.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
    plt.gca().invert_yaxis()
    max_limit = np.max(np.abs(logits)) + 0.2
    min_limit = -0.01 if np.min(logits) > 0 else -max_limit
    plt.xlim(min_limit, max_limit)
    plt.gca().set_xticks([min_limit, max_limit])
    plt.gca().set_xticklabels(label_names, fontsize=14, fontweight="bold")
    plt.gca().set_yticks(range(len(tokens)))
    plt.gca().set_yticklabels(tokens)
    plt.gca().yaxis.tick_right()
    # Color each token label to match its bar.
    for xtick, color in zip(plt.gca().get_yticklabels(), colors):
        xtick.set_color(color)
        xtick.set_fontweight("bold")
        xtick.set_verticalalignment("center")
    # Color the two class-name labels on the x-axis (negative red, positive green).
    for xtick, color in zip(plt.gca().get_xticklabels(), ["#B8293D", "#019875"]):
        xtick.set_color(color)
    # plt.title(title, fontsize=14, fontweight="bold")
    plt.title(title)
    plt.tight_layout()
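
# A minimal, commented-out usage sketch of plot_clf; the tokens and scores below are
# made-up placeholders (not real DecompX outputs), only to show the expected shapes:
#   plot_clf(["a", "good", "piece", "of", "work"],
#            np.array([0.1, 0.9, 0.3, -0.05, 0.4]),
#            label_names=["-", "+"], title="Sketch")
#   plt.show()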
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """
    Render one sentence as HTML, coloring each token by its importance score.

    importance: (sent_len,)
    """
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]
    importance = importance / np.abs(importance).max() / 1.5  # Normalize
    if discrete:
        # Replace scores by their rank so the coloring is uniform rather than magnitude-based.
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    html = "<pre style='color:black; padding: 3px;'>" + prefix
    for i in range(len(tokenized_text)):
        if importance[i] >= 0:
            rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i])  # Wistia
        else:
            rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i]))  # Wistia
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        html += (f"<span style='"
                 f"{color}"
                 f"color:black; border-radius: 5px; padding: 3px;"
                 f"font-weight: {int(800)};"
                 "'>")
        html += tokenized_text[i].replace('<', "[").replace(">", "]")
        html += "</span> "
    html += "</pre>"
    # display(HTML(html))
    return html
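
# A minimal, commented-out usage sketch of print_importance with hypothetical scores;
# importance is expected to be a 1-D array aligned with the token list:
#   snippet = print_importance(np.array([0.02, 0.85, -0.40, 0.02]),
#                              ["[CLS]", "great", "boring", "[SEP]"],
#                              prefix="sketch:".ljust(20))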
def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build the HTML preview of token importances for one example in the DataFrame."""
    html = ""
    NO_CLS_SEP = False
    df = decompx_outputs_df
    for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
        if col in df and df[col][idx] is not None:
            if "aggregated" in col:
                sentence_importance = df[col].iloc[idx][0, :]
            if "classifier" in col:
                # One row of colored tokens per output class.
                for label in range(df[col].iloc[idx].shape[-1]):
                    sentence_importance = df[col].iloc[idx][:, label]
                    html += print_importance(
                        sentence_importance,
                        df["tokens"].iloc[idx],
                        prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                        no_cls_sep=NO_CLS_SEP,
                        discrete=False
                    )
                break
                # Kept from the original but unreachable after the break above; the DataFrame
                # built in run_decompx also has no "label" column:
                # sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
            html += print_importance(
                sentence_importance,
                df["tokens"].iloc[idx],
                prefix=f"{col.split('_')[-1]}:".ljust(20),
                no_cls_sep=NO_CLS_SEP,
                discrete=discrete
            )
    return "<div style='overflow:auto; background-color:white; padding: 10px;'>" + html
def run_decompx(text, model):
    """
    Run DecompX on `text` with the selected `model` and return a token-level
    explanation plot plus an HTML preview.
    """
    SENTENCES = [text, "nothing"]  # dummy second sentence keeps a batch dimension after the .squeeze() calls below
    CONFIGS = {
        "DecompX":
            DecompXConfig(
                include_biases=True,
                bias_decomp_type="absdot",
                include_LN1=True,
                include_FFN=True,
                FFN_approx_type="GeLU_ZO",
                include_LN2=True,
                aggregation="vector",
                include_classifier_w_pooler=True,
                tanh_approx_type="ZO",
                output_all_layers=True,
                output_attention=None,
                output_res1=None,
                output_LN1=None,
                output_FFN=None,
                output_res2=None,
                output_encoder=None,
                output_aggregated="norm",
                output_pooler="norm",
                output_classifier=True,
            ),
    }
    MODEL = model

    # LOAD MODEL AND TOKENIZER
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)
    if "roberta" in MODEL:
        model = RobertaForSequenceClassification.from_pretrained(MODEL)
    elif "bert" in MODEL:
        model = BertForSequenceClassification.from_pretrained(MODEL)
    else:
        raise Exception(f"Not implemented model: {MODEL}")

    # RUN DECOMPX
    with torch.no_grad():
        model.eval()
        logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=CONFIGS["DecompX"]
        )

    decompx_outputs = {
        "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()  # Last layer & only CLS -> (batch, emb_dim)
    }

    ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes)
    importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
    decompx_outputs["importance_last_layer_classifier"] = importance

    ### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
    importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
    importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
    decompx_outputs["importance_all_layers_aggregated"] = importance

    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    # Plot the per-token contributions toward the predicted label for the first (real) sentence.
    idx = 0
    pred_label = np.argmax(decompx_outputs_df.iloc[idx]["logits"], axis=-1)
    label = decompx_outputs_df.iloc[idx]["importance_last_layer_classifier"][:, pred_label]
    tokens = decompx_outputs_df.iloc[idx]["tokens"][1:-1]  # drop [CLS]/[SEP]
    label = label[1:-1]
    label = label / np.max(np.abs(label))
    plot_clf(tokens, label, ['-', '+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
    return plt, print_preview(decompx_outputs_df)
demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Model", choices=MODELS),
    ],
    outputs=["plot", "html"],
    examples=[
        ["a good piece of work more often than not.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
        ["a good piece of work more often than not.", "TehranNLP-org/bert-large-sst2"],
        ["a good piece of work more often than not.", "WillHeld/roberta-base-sst2"],
        ["A deep and meaningful film.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
    ],
    cache_examples=True,
    title="DecompX Demo",
    description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
)
demo.launch()
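
# A minimal sketch of using the pipeline without the Gradio UI (commented out so the
# Space only serves the web demo). It assumes one of the SST-2 checkpoints listed in
# MODELS is downloadable; run_decompx returns the matplotlib module and an HTML string:
#   fig, html = run_decompx("a good piece of work more often than not.", MODELS[0])
#   plt.savefig("decompx_example.png")
#   open("decompx_example.html", "w").write(html)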