"""Streamlit page: evaluate ferret explainers on HateXplain dataset samples."""
from ctypes import DEFAULT_MODE  # NOTE(review): appears unused in this file; likely an accidental auto-import.
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
from torch.nn.functional import softmax

DEFAULT_MODEL = "cardiffnlp/twitter-xlm-roberta-base-sentiment"


@st.cache()
def get_model(model_name):
    """Load a sequence-classification model from the HF Hub, cached by Streamlit across reruns."""
    return AutoModelForSequenceClassification.from_pretrained(model_name)


@st.cache()
def get_config(model_name):
    """Load a model configuration from the HF Hub, cached by Streamlit across reruns."""
    return AutoConfig.from_pretrained(model_name)


def get_tokenizer(tokenizer_name):
    """Load the fast tokenizer for ``tokenizer_name`` (not cached)."""
    return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)


def _parse_samples(samples_string):
    """Parse a comma-separated string of dataset indices into a list of ints.

    Whitespace around entries is ignored and empty entries (e.g. produced by a
    trailing comma) are skipped.

    Raises:
        ValueError: if a non-empty entry is not an integer.
    """
    return [int(part) for part in samples_string.split(",") if part.strip()]


def body():
    """Render the "Evaluate explanations on dataset samples" demo page.

    The user picks a HF model, a target class label and a list of HateXplain
    sample indices. On "Run", a ferret ``Benchmark`` explains and evaluates
    those samples and the resulting metrics table is displayed, followed by a
    metric legend and the equivalent library code.
    """
    st.title("Evaluate explanations on dataset samples")

    st.markdown(
        """
        Let's test how our built-in explainers behave on state-of-the-art datasets for explainability.

        *ferret* exposes an extensible Dataset API. We currently implement [MovieReviews](https://huggingface.co/datasets/movie_rationales) and [HateXPlain](https://huggingface.co/datasets/hatexplain).
        In this demo, we let you experiment with HateXPlain.

        You just need to choose a prediction model and a set of samples to test.
        We will trigger *ferret* to:

        1. download the model;
        2. explain every sample you chose;
        3. average all faithfulness and plausibility metrics we support 📊

        """
    )

    col1, col2 = st.columns([3, 1])

    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)
        # Use the cached loader rather than re-fetching the config on every rerun.
        config = get_config(model_name)

    with col2:
        class_labels = list(config.id2label.values())
        # NOTE(review): the selected target is not passed to evaluate_samples below,
        # so the widget currently has no effect on the computed table.
        st.selectbox(
            "Target",
            options=class_labels,
            # Default to the second label, but stay in range for single-label models.
            index=min(1, max(len(class_labels) - 1, 0)),
            help="Class label you want to explain.",
        )

    samples_string = st.text_input(
        "List of samples",
        "11,6,42",
        help="List of indices in the dataset, comma-separated.",
    )
    try:
        # Materialize a real list: a ``map`` iterator is consumed on first use and
        # its repr ("<map object ...>") would leak into the code snippet below.
        samples = _parse_samples(samples_string)
    except ValueError:
        st.error("Sample indices must be comma-separated integers, e.g. 11,6,42.")
        samples = []

    compute = st.button("Run")

    if compute and model_name and samples:
        with st.spinner("Preparing the magic. Hang in there..."):
            model = get_model(model_name)
            tokenizer = get_tokenizer(model_name)
            bench = Benchmark(model, tokenizer)

        with st.spinner("Explaining sample (this might take a while)..."):

            @st.cache()
            def compute_table(samples):
                data = bench.load_dataset("hatexplain")
                sample_evaluations = bench.evaluate_samples(data, samples)
                table = bench.show_samples_evaluation_table(sample_evaluations)
                return table

            table = compute_table(samples)

        st.markdown("### Averaged metrics")
        st.dataframe(table)
        st.caption("Darker colors mean better performance.")

    st.markdown(
        """
        **Legend**

        - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the explanation captures all the tokens needed to make the prediction. Higher is better.

        - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., if the relevant tokens in the explanation are sufficient to make the prediction. Lower is better.

        - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation coefficient τ between the explanation and leave-one-out importances. Closer to 1 is better.

        See the paper for details.
        """
    )

    st.markdown(
        """
        **In code, it would be as simple as**
        """
    )
    st.code(
        f"""
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from ferret import Benchmark

model = AutoModelForSequenceClassification.from_pretrained("{model_name}")
tokenizer = AutoTokenizer.from_pretrained("{model_name}")

bench = Benchmark(model, tokenizer)
data = bench.load_dataset("hatexplain")
evaluations = bench.evaluate_samples(data, {samples})
bench.show_samples_evaluation_table(evaluations)
        """,
        language="python",
    )