File size: 5,293 Bytes
b0390ea
2893724
b0390ea
 
 
c2db5d1
b0390ea
7e45f0d
 
b0390ea
 
 
 
 
 
 
 
 
 
 
 
 
 
2893724
 
 
b0390ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a87154e
b0390ea
 
 
97ba57c
 
7e45f0d
97ba57c
 
 
b0390ea
7e45f0d
b0390ea
 
 
 
 
 
 
 
 
 
39c3042
b0390ea
39c3042
7e45f0d
 
 
 
 
 
b0390ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from ctypes import DEFAULT_MODE
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
from torch.nn.functional import softmax
from copy import deepcopy

# Hugging Face Hub model id pre-filled in the "HF Model" text input of body().
DEFAULT_MODEL = "Hate-speech-CNERG/bert-base-uncased-hatexplain"
# Comma-separated dataset indices pre-filled in the "List of samples" input;
# body() evaluates them on the "hatexplain" dataset.
DEFAULT_SAMPLES = "3,5,8,13,15,17,18,25,27,28"


@st.cache()
def get_model(model_name):
    """Load the sequence-classification model identified by ``model_name``.

    Wrapped in Streamlit's ``st.cache`` so that reruns of the page with the
    same name reuse the memoized model instead of loading it again.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_model


@st.cache()
def get_config(model_name):
    """Load the model configuration identified by ``model_name``.

    Memoized through Streamlit's ``st.cache``; repeated calls with the same
    name return the cached configuration object.
    """
    model_config = AutoConfig.from_pretrained(model_name)
    return model_config


def get_tokenizer(tokenizer_name):
    """Load the tokenizer identified by ``tokenizer_name`` (fast variant requested).

    Unlike the model/config loaders this function is not wrapped in
    ``st.cache``, so every call performs a fresh ``from_pretrained`` load.
    """
    load_options = {"use_fast": True}
    return AutoTokenizer.from_pretrained(tokenizer_name, **load_options)


def _parse_samples(samples_string):
    """Parse a comma-separated string of dataset indices into a list of ints.

    Surrounding whitespace and empty entries (e.g. from a trailing comma or an
    empty input) are ignored, so ``"3, 5,8,"`` yields ``[3, 5, 8]``.

    Args:
        samples_string: Raw text from the "List of samples" input.

    Returns:
        The indices as ints, in the order they were written.

    Raises:
        ValueError: If a non-empty entry is not an integer.
    """
    tokens = (token.strip() for token in samples_string.split(","))
    return [int(token) for token in tokens if token]


def body():
    """Render the "Evaluate explanations on dataset samples" Streamlit page.

    The user picks a Hugging Face model, a target class label and a list of
    HateXPlain sample indices. When "Run" is pressed, a *ferret* ``Benchmark``
    evaluates the explanations on those samples. The page then shows the
    resulting metrics table, a legend, and an equivalent code snippet.
    Invalid input is reported in the page instead of raising.
    """
    st.title("Evaluate explanations on dataset samples")

    st.markdown(
        """
        Let's test how our built-in explainers behave on state-of-the-art datasets for explainability.

        *ferret* exposes an extensible Dataset API. We currently implement [MovieReviews](https://huggingface.co/datasets/movie_rationales) and [HateXPlain](https://huggingface.co/datasets/hatexplain).

        In this demo, we let you experiment with HateXPlain.
        You just need to choose a prediction model and a set of samples to test.
        We will trigger *ferret* to:

        1. download the model;
        2. explain every sample you chose;
        3. average all faithfulness and plausibility metrics we support 📊
        """
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)

    if not model_name:
        # Loading a configuration for an empty name would raise; stop early.
        st.info("Enter a Hugging Face model name to continue.")
        return

    try:
        # Use the st.cache-wrapped loader rather than reloading on every rerun.
        config = get_config(model_name)
    except (OSError, ValueError) as err:
        st.error(f"Could not load the configuration for '{model_name}': {err}")
        return

    # Map each label to its real id key, so the target id does not depend on
    # the iteration order of id2label.
    label_to_id = {label: int(idx) for idx, label in config.id2label.items()}
    with col2:
        class_labels = list(label_to_id)
        target = st.selectbox(
            "Target",
            options=class_labels,
            index=0,
            help="Class label you want to explain.",
        )
    target_idx = label_to_id[target]

    samples_string = st.text_input(
        "List of samples",
        DEFAULT_SAMPLES,
        help="List of indices in the dataset, comma-separated.",
    )
    compute = st.button("Run")

    try:
        samples = _parse_samples(samples_string)
    except ValueError:
        st.error("Samples must be comma-separated integer indices, e.g. 3,5,8.")
        return

    if not compute:
        return
    if not samples:
        st.warning("Please provide at least one sample index.")
        return

    with st.spinner("Preparing the magic. Hang in there..."):
        model = get_model(model_name)
        tokenizer = get_tokenizer(model_name)
        bench = Benchmark(model, tokenizer)

    with st.spinner("Explaining sample (this might take a while)..."):

        # target_idx is an explicit argument (not only a closure variable) so
        # the st.cache key clearly changes when the user picks another target.
        @st.cache(allow_output_mutation=True)
        def compute_table(samples, target_idx):
            data = bench.load_dataset("hatexplain")
            sample_evaluations = bench.evaluate_samples(
                data, samples, target=target_idx
            )
            table = bench.show_samples_evaluation_table(sample_evaluations).format(
                "{:.2f}"
            )
            return table

        table = compute_table(samples, target_idx)

    st.markdown("### Averaged metrics")
    st.dataframe(table)
    st.caption("Darker colors mean better performance.")

    st.markdown(
        """
        **Legend**

        - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the explanation captures all the tokens needed to make the prediction. Higher is better.

        - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., if the relevant tokens in the explanation are sufficient to make the prediction. Lower is better.

        - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation coefficient τ between the explanation and leave-one-out importances. Closer to 1 is better.

        See the paper for details.
        """
    )

    st.markdown(
        """
        **In code, it would be as simple as**
    """
    )
    # The snippet mirrors the evaluation above, including the chosen target,
    # so running it reproduces the displayed table.
    st.code(
        f"""
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from ferret import Benchmark

model = AutoModelForSequenceClassification.from_pretrained("{model_name}")
tokenizer = AutoTokenizer.from_pretrained("{model_name}")

bench = Benchmark(model, tokenizer)
data = bench.load_dataset("hatexplain")
evaluations = bench.evaluate_samples(data, {samples}, target={target_idx})
bench.show_samples_evaluation_table(evaluations)
        """,
        language="python",
    )