# Captured from a Hugging Face Spaces page whose status read: "Runtime error".
import textwrap
from copy import deepcopy
from ctypes import DEFAULT_MODE

import streamlit as st
from ferret import Benchmark
from torch.nn.functional import softmax
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
# Hugging Face model id pre-filled in the "HF Model" text input.
DEFAULT_MODEL = "Hate-speech-CNERG/bert-base-uncased-hatexplain"

# Dataset sample indices pre-filled (comma-separated) in the "List of samples" input.
DEFAULT_SAMPLES = ",".join(str(index) for index in (3, 5, 8, 13, 15, 17, 18, 25, 27, 28))
def get_model(model_name):
    """Load a sequence-classification model by name.

    Args:
        model_name: Model identifier passed to
            ``AutoModelForSequenceClassification.from_pretrained`` (e.g. a
            Hugging Face Hub id such as ``DEFAULT_MODEL``).

    Returns:
        The loaded model instance.
    """
    # NOTE(review): nothing here caches the result; if Streamlit reruns the
    # page on each interaction the model may be reloaded every time "Run" is
    # pressed — consider a Streamlit caching decorator if the installed
    # version supports one.
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_model
def get_config(model_name):
    """Load the model configuration for *model_name*.

    Args:
        model_name: Identifier passed to ``AutoConfig.from_pretrained``
            (e.g. a Hugging Face Hub id such as ``DEFAULT_MODEL``).

    Returns:
        The configuration object returned by ``AutoConfig.from_pretrained``.
    """
    model_config = AutoConfig.from_pretrained(model_name)
    return model_config
def get_tokenizer(tokenizer_name):
    """Load the tokenizer named *tokenizer_name*, requesting the fast variant.

    Args:
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer instance, loaded with ``use_fast=True``.
    """
    loader_options = {"use_fast": True}
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **loader_options)
    return tokenizer
def parse_sample_indices(samples_string):
    """Parse the comma-separated list of dataset indices typed by the user.

    Whitespace around each entry is ignored and empty entries (e.g. from a
    trailing comma or an empty input) are skipped, so ``" 3, 5,8,"`` yields
    ``[3, 5, 8]``. Order and duplicates are preserved.

    Args:
        samples_string: Raw text from the "List of samples" input.

    Returns:
        The indices as a list of ints (possibly empty).

    Raises:
        ValueError: If a non-empty entry is not an integer.
    """
    indices = []
    for token in samples_string.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            indices.append(int(token))
        except ValueError:
            raise ValueError(
                f"Invalid sample index {token!r}: expected comma-separated integers."
            ) from None
    return indices


def body():
    """Render the "Evaluate explanations on dataset samples" Streamlit page.

    The user picks a model, a target class and a list of HateXPlain sample
    indices. When "Run" is pressed, a ferret ``Benchmark`` is built, the
    explanations for those samples are evaluated and the averaged metrics
    table is shown. A metrics legend and an equivalent code snippet are
    always rendered below.

    Invalid input (empty/unknown model name, malformed sample list) is
    reported with ``st.warning``/``st.error`` instead of crashing the page.
    """
    st.title("Evaluate explanations on dataset samples")
    st.markdown(
        textwrap.dedent(
            """
            Let's test how our built-in explainers behave on state-of-the-art datasets for explainability.

            *ferret* exposes an extensible Dataset API. We currently implement [MovieReviews](https://huggingface.co/datasets/movie_rationales) and [HateXPlain](https://huggingface.co/datasets/hatexplain).
            In this demo, we let you experiment with HateXPlain.

            You just need to choose a prediction model and a set of samples to test.
            We will trigger *ferret* to:

            1. download the model;
            2. explain every sample you chose;
            3. average all faithfulness and plausibility metrics we support 📊
            """
        )
    )

    col1, col2 = st.columns([3, 1])

    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL).strip()
        # Validate before loading: an empty or unknown name would otherwise
        # raise from from_pretrained and take the whole page down.
        if not model_name:
            st.warning("Please enter a Hugging Face model name.")
            return
        try:
            config = get_config(model_name)
        except (OSError, ValueError) as err:
            st.error(f"Could not load the configuration of '{model_name}': {err}")
            return

    with col2:
        class_labels = list(config.id2label.values())
        target = st.selectbox(
            "Target",
            options=class_labels,
            # Pre-select the second class, clamped so single-label models
            # do not get an out-of-range default.
            index=min(1, len(class_labels) - 1),
            help="Class label you want to explain.",
        )
    target_idx = class_labels.index(target)

    samples_string = st.text_input(
        "List of samples",
        DEFAULT_SAMPLES,
        help="List of indices in the dataset, comma-separated.",
    )
    compute = st.button("Run")

    try:
        samples = parse_sample_indices(samples_string)
    except ValueError as err:
        st.error(str(err))
        return
    if not samples:
        st.warning("Please enter at least one sample index.")
        return

    if compute:
        with st.spinner("Preparing the magic. Hang in there..."):
            model = get_model(model_name)
            tokenizer = get_tokenizer(model_name)
            bench = Benchmark(model, tokenizer)

        with st.spinner("Explaining sample (this might take a while)..."):
            data = bench.load_dataset("hatexplain")
            sample_evaluations = bench.evaluate_samples(data, samples, target=target_idx)
            table = bench.show_samples_evaluation_table(sample_evaluations).format("{:.2f}")

        st.markdown("### Averaged metrics")
        st.dataframe(table)
        st.caption("Darker colors mean better performance.")

    st.markdown(
        textwrap.dedent(
            """
            **Legend**

            - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the explanation captures all the tokens needed to make the prediction. Higher is better.
            - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., if the relevant tokens in the explanation are sufficient to make the prediction. Lower is better.
            - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation coefficient τ between the explanation and leave-one-out importances. Closer to 1 is better.

            See the paper for details.
            """
        )
    )

    st.markdown(
        """
        **In code, it would be as simple as**
        """
    )
    # The snippet mirrors exactly what "Run" executes, including the target
    # class, so copying it reproduces the table above.
    st.code(
        textwrap.dedent(
            f"""\
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
            from ferret import Benchmark

            model = AutoModelForSequenceClassification.from_pretrained("{model_name}")
            tokenizer = AutoTokenizer.from_pretrained("{model_name}")

            bench = Benchmark(model, tokenizer)
            data = bench.load_dataset("hatexplain")
            evaluations = bench.evaluate_samples(data, {samples}, target={target_idx})
            bench.show_samples_evaluation_table(evaluations)
            """
        ),
        language="python",
    )