import streamlit as st

from PIL import Image

from transformers import pipeline
import nltk
import spacy
import en_core_web_lg

from helpers import prompt_to_nli, display_nli_pr_results_as_list

@st.cache()
def download_punkt():
    nltk.download('punkt')

@st.cache(allow_output_mutation=True)
def load_spacy_pipeline():
    return en_core_web_lg.load()

def choose_text_menu(text):
    if 'text' not in st.session_state:
        st.session_state.text = 'Several demonstrators were injured.'
    text = st.text_area('Event description', st.session_state.text)

    return text

# Load Models in cache

@st.cache(allow_output_mutation=True)
def load_model_prompting():
    return pipeline("fill-mask", model="distilbert-base-uncased")

@st.cache(allow_output_mutation=True)
def load_model_nli():
    return pipeline(task="sentiment-analysis", model="roberta-large-mnli")

download_punkt()
nlp = load_spacy_pipeline()

### App START
st.markdown(
    """
    # Demonstration dashboard of the PR-ENT Approach.
    ### *Rethinking the Event Coding Pipeline with Prompt Entailment*; Clément Lefebvre, Niklas Stoehr
    ##### ARXIV LINK HERE
    ##### Contact: clement.lefebvre@datascience.ch
    ##### Version: 1.0
"""
)

st.markdown(""" 
### 1. PR-ENT summary
""")

@st.cache()
def load_prent_image():
    return Image.open('pipeline_flow.png')
st.image(load_prent_image(), caption="""PR-ENT Flow. First, we concatenate the event description e and the template t. Then we feed them through a pretrained
prompting model to obtain a list of answer candidates. Next, for each answer candidate, we build a
hypothesis by filling the template and check whether it is entailed by the premise (the event description).
Finally, by filtering on the entailment score, we obtain a list of answer candidates entailed by the event description.
""")



model_nli = load_model_nli()
model_prompting = load_model_prompting()

st.markdown("""
### 2. Write an event description:
The first step is to write an event description that will be fed to the pipeline. This can be any text in English.
""")
text = choose_text_menu('')
st.session_state.text = text

st.markdown(""" 
### 3. Template design:
The second step is to design a template while keeping in mind the objective of the classification.
- A good starting point is `This event involves [Z].`, which will ideally be filled with a one-word summary of the event.
- Another good example is `People were [Z].`, for which we mostly expect a verb describing the action.
- You can also use any template you design; the masked slot can be written as `[Z]` or `[MASK]`. If the slot is at the end of the sentence,
  do not forget the final punctuation; otherwise the model may fill the slot with a punctuation sign.
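
As a rough illustration of how a template is used, the sketch below mirrors the app's own step of replacing `[Z]` (or `[MASK]`) with a `{}` placeholder, then fills that slot with answer candidates to obtain the hypotheses checked for entailment. The names `template_to_format` and `build_hypotheses` are hypothetical and are not functions from this app or its `helpers` module.

```python
# Hypothetical helpers for illustration; not part of this app or helpers.py.
from typing import List


def template_to_format(template: str) -> str:
    # Same substitution as the app: both mask spellings become a {} slot.
    return template.replace('[Z]', '{}').replace('[MASK]', '{}')


def build_hypotheses(template: str, candidates: List[str]) -> List[str]:
    # Fill the slot with each answer candidate: one hypothesis per candidate.
    fmt = template_to_format(template)
    return [fmt.format(candidate) for candidate in candidates]


print(build_hypotheses('This event involves [Z].', ['protest', 'violence']))
# ['This event involves protest.', 'This event involves violence.']
```

Because the slot becomes a `str.format` placeholder, a template containing literal `{` or `}` characters would need them doubled (`{{`, `}}`).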
""")

if 'prompt' not in st.session_state:
    st.session_state.prompt = 'This event involves [Z].'
prompt = st.text_input('Template:',st.session_state.prompt)
st.session_state.prompt = prompt


st.markdown("""
### 4. Select the two parameters:
- The first parameter, `top_k`, is the maximum number of tokens returned by the prompting model,
  which is also the number of tokens tested for entailment. Ideally, you want enough tokens not to miss critical information,
  but each additional token increases computation time, since it must go through the entailment model.
  In our experiments, a good choice lies between `10` and `50`: lower values miss information, while higher values bring in unrelated tokens and long computation times.
- The second parameter is the minimum entailment score required to consider a token entailed by the event description.
  By default, it is set to `0.5` (more entailed than not), but it can be adjusted as needed.
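
The effect of the two parameters can be sketched as follows, assuming each candidate's entailment score is already known. In the app, every score requires one call to the entailment model, which is why a larger `top_k` costs more time. The function `filter_entailed` and the scores are hypothetical, and treating the threshold as inclusive (`>=`) is an assumption.

```python
# Hypothetical function and scores for illustration only.
from typing import Dict, List, Tuple


def filter_entailed(candidates: List[str], entail_score: Dict[str, float],
                    top_k: int = 10, nli_limit: float = 0.5) -> List[Tuple[str, float]]:
    kept = []
    # Only the first top_k candidates are checked: one entailment call each.
    for token in candidates[:top_k]:
        score = entail_score[token]  # stands in for an entailment-model call
        if score >= nli_limit:  # inclusive threshold is an assumption
            kept.append((token, score))
    return kept


scores = {'protest': 0.93, 'violence': 0.71, 'sports': 0.04}  # made-up values
candidates = ['protest', 'violence', 'sports']
print(filter_entailed(candidates, scores))           # [('protest', 0.93), ('violence', 0.71)]
print(filter_entailed(candidates, scores, top_k=1))  # [('protest', 0.93)]
```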
""")

def select_top_k():
    if 'top_k' not in st.session_state:
        st.session_state.top_k = 10

    return st.number_input('Maximum number of tokens to output (default: 10, min: 0, max: 50)', step=1, min_value=0, max_value=50, value=int(st.session_state.top_k))


def select_nli_limit():
    if 'nli_limit' not in st.session_state:
        st.session_state.nli_limit = 0.5

    return st.number_input('Minimum entailment score (default: 0.5, min: 0, max: 1)', step=0.05, min_value=0.0, max_value=1.0, value=float(st.session_state.nli_limit))

def update_session_state_callback(value, key):
    st.session_state[value] = st.session_state[key]

top_k = select_top_k()
st.session_state.top_k = top_k

nli_limit = select_nli_limit()
st.session_state.nli_limit = nli_limit

st.markdown("""
### 5. Remove similar tokens from output:
An additional option is to remove similar tokens (e.g. `protest`, `protests`) from the output.
This computes the lemma of each word (based on the template) and removes tokens whose lemma has already appeared.
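
A minimal sketch of the deduplication, assuming candidates arrive ordered by score so that the first token seen for each lemma is kept. The app lemmatizes with a spaCy pipeline; here a hypothetical toy lemmatizer is passed in instead so the example runs without spaCy, and `remove_duplicate_lemmas` is likewise a hypothetical name.

```python
# Hypothetical helper for illustration; the app itself relies on spaCy.
from typing import Callable, List


def remove_duplicate_lemmas(tokens: List[str], lemmatize: Callable[[str], str]) -> List[str]:
    seen = set()
    unique = []
    for token in tokens:  # assumed ordered by score, best first
        lemma = lemmatize(token)
        if lemma not in seen:  # keep only the first token for each lemma
            seen.add(lemma)
            unique.append(token)
    return unique


toy_lemmas = {'protests': 'protest', 'riots': 'riot'}  # toy lemmatizer
print(remove_duplicate_lemmas(['protest', 'protests', 'riots', 'riot'],
                              lambda w: toy_lemmas.get(w, w)))
# ['protest', 'riots']
```

With spaCy, a lemmatizer such as `lambda w: nlp(w)[0].lemma_` could be passed instead, although lemmas of isolated words may differ from lemmas computed in the context of the filled template.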
""")
if 'remove_lemma' not in st.session_state:
    st.session_state.remove_lemma = False
remove_lemma = st.checkbox('Remove tokens sharing the same lemma (e.g. protest, protests) from output?', value=st.session_state.remove_lemma)
st.session_state.remove_lemma = remove_lemma


# Save settings to display before the results
if "old_prompt" not in st.session_state:
    st.session_state.old_text = st.session_state.text
    st.session_state.old_prompt = st.session_state.prompt
    st.session_state.old_top_k = st.session_state.top_k
    st.session_state.old_nli_limit = st.session_state.nli_limit

st.markdown("""
### 6. Run the pipeline
""")

st.markdown("""The entailed tokens are given as a list of words associated with the probability of entailment.""")

if st.button("Run PR-ENT"):
    computation_state_prent = st.text("PR-ENT Computation Running.")
    st.session_state.old_text = st.session_state.text
    st.session_state.old_prompt = st.session_state.prompt
    st.session_state.old_top_k = st.session_state.top_k
    st.session_state.old_nli_limit = st.session_state.nli_limit

    # Replace the mask
    prompt = prompt.replace('[Z]', '{}')
    prompt = prompt.replace('[MASK]', '{}')
    results = prompt_to_nli(text, prompt, model_prompting, model_nli, nlp, top_k, nli_limit, remove_lemma)
    list_results = [x[0][0] + ' ' + str(int(x[1][1]*100)) + '%' for x in results]
    st.session_state.list_results = list_results

    computation_state_prent.text("PR-ENT Computation Done.")

if 'list_results' in st.session_state:
    st.write('**Event Description**: {}'.format(st.session_state.old_text))
    st.write('**Template**: "{}"; **Top K**: {}; **Entailment Threshold**: {}.'.format(st.session_state.old_prompt,st.session_state.old_top_k, st.session_state.old_nli_limit))
    display_nli_pr_results_as_list('', st.session_state.list_results)

st.markdown("""
### 7. Actor-target coding (experimental)

Available in the actor-target tab (on the left).

""")