from functools import lru_cache
from html import escape as html_escape

import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.figure import Figure
from IPython.display import display, HTML
from transformers import AutoTokenizer

from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

plt.style.use("ggplot")

MODELS = [
    "TehranNLP-org/bert-base-uncased-cls-sst2",
    "TehranNLP-org/bert-large-sst2",
    "WillHeld/roberta-base-sst2",
]

# Colours for non-negative / negative scores (bars and tick labels).
POSITIVE_COLOR = "#019875"
NEGATIVE_COLOR = "#B8293D"


def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart with one bar per token and return the figure.

    Args:
        tokens: token strings, drawn top-to-bottom in the given order.
        logits: 1-D sequence of per-token scores aligned with ``tokens``;
            scores >= 0 are drawn green, negative scores red.
        label_names: two x tick labels, placed at the left (negative) and
            right (positive) ends of the x axis.
        title: figure title.
        file_name: accepted for backward compatibility; nothing is written
            to disk.

    Returns:
        A ``matplotlib.figure.Figure``. It is created without pyplot, so it is
        not kept in pyplot's global figure registry and is freed with its last
        reference instead of accumulating across requests.
    """
    logits = np.asarray(logits)
    colors = [POSITIVE_COLOR if is_pos else NEGATIVE_COLOR for is_pos in (logits >= 0)]

    fig = Figure(figsize=(4.5, 5))
    ax = fig.add_subplot()
    ax.barh(range(len(tokens)), logits, color=colors)
    ax.axvline(0, color="black", ls="-", lw=2, alpha=0.2)
    ax.invert_yaxis()  # first token at the top

    # Symmetric x range with a small margin; when every score is positive,
    # only a sliver of the negative side is shown.
    max_limit = np.max(np.abs(logits)) + 0.2
    min_limit = -0.01 if np.min(logits) > 0 else -max_limit
    ax.set_xlim(min_limit, max_limit)
    ax.set_xticks([min_limit, max_limit])
    ax.set_xticklabels(label_names, fontsize=14, fontweight="bold")

    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
    ax.yaxis.tick_right()
    # Colour each token label like its bar.
    for tick_label, color in zip(ax.get_yticklabels(), colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")
    for tick_label, color in zip(ax.get_xticklabels(), [NEGATIVE_COLOR, POSITIVE_COLOR]):
        tick_label.set_color(color)

    ax.set_title(title)
    fig.tight_layout()
    return fig


def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render per-token importance scores as one coloured HTML line.

    Args:
        importance: 1-D array of scores, one per entry of ``tokenized_text``.
        tokenized_text: token strings.
        discrete: replace the scaled scores by ``rank / len / 1.6`` (all
            values are then non-negative, so every token gets a green shade).
        prefix: text shown before the first token (e.g. a row label).
        no_cls_sep: drop the first and last token/score before rendering.

    Returns:
        An HTML ``<pre>`` element with one background-coloured ``<span>`` per
        token: shades of green for scores >= 0, shades of red for scores < 0.
    """
    importance = np.asarray(importance, dtype=float)
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]

    # Scale so the largest |score| becomes 1/1.5 (~0.67). An all-zero or empty
    # vector is left as is instead of dividing by zero.
    max_abs = np.abs(importance).max() if importance.size else 0.0
    if max_abs > 0:
        importance = importance / max_abs / 1.5
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    # Look the colormaps up once instead of once per token.
    greens = matplotlib.colormaps.get_cmap("Greens")
    reds = matplotlib.colormaps.get_cmap("Reds")

    spans = []
    for token, score in zip(tokenized_text, importance):
        rgba = greens(score) if score >= 0 else reds(np.abs(score))
        # NOTE(review): after the scaling above |score| <= 1/1.5 (< 1/1.6 when
        # discrete), so this white-text branch is currently never taken; it is
        # kept in case the scaling changes.
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(score) > 0.9 else ""
        style = (
            f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); "
            + text_color
        )
        # Special tokens such as <s> are shown as [s]; escaping '&' as well
        # makes the (user-derived) token text render literally.
        shown = html_escape(token.replace("<", "[").replace(">", "]"), quote=False)
        spans.append(f"<span style='{style}'>{shown}</span> ")

    # <pre> preserves the space padding callers put into ``prefix`` (ljust),
    # so rows line up.
    return "<pre>" + html_escape(prefix, quote=False) + "".join(spans) + "</pre>"


def print_preview(decompx_outputs_df, idx=0, discrete=False, no_cls_sep=False):
    """Build the HTML explanation panel for row ``idx`` of a DecompX output frame.

    Rendered rows, each only when the column exists and the cell is not None:
      * ``importance_last_layer_aggregated``: row 0 of the cell's matrix,
        labelled ``"aggregated:"``;
      * ``importance_last_layer_classifier``: one row per class (last axis of
        the cell), labelled ``"classifier Label{k}:"``.

    Args:
        decompx_outputs_df: DataFrame with a ``tokens`` column plus the
            importance columns above.
        idx: positional row index.
        discrete: forwarded to ``print_importance`` for every row.
        no_cls_sep: forwarded to ``print_importance`` for every row.

    Returns:
        An HTML string wrapping all rendered rows in a ``<div>``.
    """
    df = decompx_outputs_df
    tokens = df["tokens"].iloc[idx]
    html = ""

    col = "importance_last_layer_aggregated"
    if col in df and df[col].iloc[idx] is not None:
        html += print_importance(
            df[col].iloc[idx][0, :],
            tokens,
            prefix="aggregated:".ljust(20),
            no_cls_sep=no_cls_sep,
            discrete=discrete,
        )

    col = "importance_last_layer_classifier"
    if col in df and df[col].iloc[idx] is not None:
        per_class = df[col].iloc[idx]
        for label in range(per_class.shape[-1]):
            html += print_importance(
                per_class[:, label],
                tokens,
                prefix=f"classifier Label{label}:".ljust(20),
                no_cls_sep=no_cls_sep,
                discrete=discrete,
            )

    return "<div>" + html + "</div>"


def _decompx_config():
    """Return a fresh DecompX configuration used for every explanation."""
    return DecompXConfig(
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    )


# Bounded so that arbitrary model ids sent through the API cannot keep an
# unbounded number of models resident in memory.
@lru_cache(maxsize=len(MODELS))
def _load_model_and_tokenizer(model_name):
    """Load (and memoize) the tokenizer and DecompX classifier for ``model_name``.

    Raises:
        ValueError: if the id names neither a RoBERTa nor a BERT model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # "roberta" must be tested first: every RoBERTa id also contains "bert".
    if "roberta" in model_name:
        classifier = RobertaForSequenceClassification.from_pretrained(model_name)
    elif "bert" in model_name:
        classifier = BertForSequenceClassification.from_pretrained(model_name)
    else:
        raise ValueError(f"Not implemented model: {model_name}")
    classifier.eval()
    return tokenizer, classifier


def run_decompx(text, model):
    """
    Provide DecompX Token Explanation of Model on Text

    Args:
        text: input sentence to explain.
        model: Hugging Face model id (one of ``MODELS`` in the UI).

    Returns:
        ``(figure, html)``: a bar chart of the per-token contributions to the
        predicted label, and the HTML panel from ``print_preview``.

    Raises:
        gr.Error: on empty text, a missing model, or text with no tokens.
        ValueError: for a model id that is neither BERT nor RoBERTa.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to explain.")
    if not model:
        raise gr.Error("Please select a model.")

    tokenizer, classifier = _load_model_and_tokenizer(model)

    # A throwaway second sentence is batched with the input; presumably it
    # keeps the batch axis from being dropped by the .squeeze() calls below
    # (TODO confirm). Only row 0 is used.
    sentences = [text, "nothing"]
    tokenized = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
    batch_lengths = tokenized["attention_mask"].sum(dim=-1).tolist()

    with torch.no_grad():
        logits, hidden_states, last_layer_outputs, all_layers_outputs = classifier(
            **tokenized,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=_decompx_config(),
        )

    decompx_outputs = {
        # Tokens of each sentence with padding removed.
        "tokens": [
            tokenizer.convert_ids_to_tokens(tokenized["input_ids"][i][:length].tolist())
            for i, length in enumerate(batch_lengths)
        ],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        # Last layer, first token only -> (batch, emb_dim)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),
    }

    # Per-token contribution to each class; indexed as (batch, seq_len, classes)
    # below, then trimmed to each sentence's real length.
    classifier_importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in last_layer_outputs.classifier]
    ).squeeze()
    decompx_outputs["importance_last_layer_classifier"] = [
        classifier_importance[j][:length, :] for j, length in enumerate(batch_lengths)
    ]

    # Aggregated token-to-token importance for every layer, reordered from
    # (layers, batch, seq, seq) to (batch, layers, seq, seq) and trimmed.
    aggregated = np.array(
        [g.squeeze().cpu().detach().numpy() for g in all_layers_outputs.aggregated]
    )
    aggregated = np.einsum("lbij->blij", aggregated)
    decompx_outputs["importance_all_layers_aggregated"] = [
        aggregated[j][:, :length, :length] for j, length in enumerate(batch_lengths)
    ]

    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    idx = 0
    row = decompx_outputs_df.iloc[idx]
    pred_label = int(np.argmax(row["logits"], axis=-1))
    # Drop the leading/trailing special tokens ([CLS]/[SEP] or <s>/</s>).
    tokens = row["tokens"][1:-1]
    if len(tokens) == 0:
        raise gr.Error("The text did not produce any tokens to explain.")
    scores = row["importance_last_layer_classifier"][:, pred_label][1:-1]
    max_abs = np.max(np.abs(scores))
    if max_abs > 0:  # avoid 0/0 -> NaN when every contribution is zero
        scores = scores / max_abs

    fig = plot_clf(
        tokens,
        scores,
        ["-", "+"],
        title=f"DecompX for Predicted Label: {pred_label}",
        file_name="example_sst2_our_method",
    )
    return fig, print_preview(decompx_outputs_df)


demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Model", choices=MODELS),
    ],
    outputs=["plot", "html"],
    examples=[
        ["a good piece of work more often than not.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
        ["a good piece of work more often than not.", "TehranNLP-org/bert-large-sst2"],
        ["a good piece of work more often than not.", "WillHeld/roberta-base-sst2"],
        ["A deep and meaningful film.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
    ],
    cache_examples=True,
    title="DecompX Demo",
    description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)",
)

demo.launch()