import streamlit as st
from PIL import Image
from transformers import pipeline
import nltk
import spacy
import en_core_web_lg

from helpers import prompt_to_nli, display_nli_pr_results_as_list


def download_punkt():
    nltk.download('punkt')

@st.cache_resource  # caching decorator assumed; keeps the spaCy model loaded across reruns
def load_spacy_pipeline():
    return en_core_web_lg.load()

def choose_text_menu(text):
    if 'text' not in st.session_state:
        st.session_state.text = 'Several demonstrators were injured.'
    text = st.text_area('Event description', st.session_state.text)
    return text

# Load the models and keep them in the Streamlit cache so they are only
# instantiated once (st.cache_resource is one way to do this; assumed here).
@st.cache_resource
def load_model_prompting():
    return pipeline("fill-mask", model="distilbert-base-uncased")


@st.cache_resource
def load_model_nli():
    # roberta-large-mnli is an NLI model; the "sentiment-analysis" task name
    # simply selects the generic text-classification pipeline.
    return pipeline(task="sentiment-analysis", model="roberta-large-mnli")
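
# Illustrative output shapes of the two pipelines (values are made up, not real scores):
#   load_model_prompting()("This event involves [MASK].", top_k=2)
#   -> [{'token_str': 'violence', 'score': 0.05, ...}, {'token_str': 'children', ...}]
#   load_model_nli()({'text': premise, 'text_pair': hypothesis})
#   -> [{'label': 'ENTAILMENT', 'score': 0.93}]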

download_punkt()
nlp = load_spacy_pipeline()

### App START
st.markdown(
    """
# Demonstration dashboard of the PR-ENT Approach
### *Rethinking the Event Coding Pipeline with Prompt Entailment*
### ARXIV LINK HERE
### Clément Lefebvre (Swiss Data Science Center)
### Niklas Stoehr (ETH Zürich, https://niklas-stoehr.com/)
##### Version: 1.0
"""
)
st.markdown(""" | |
### 1. PR-ENT summary | |
""") | |
def load_prent_image():
    return Image.open('pipeline_flow.png')


st.image(load_prent_image(), caption="""PR-ENT Flow. First, we concatenate the event description e and the template t and feed them through a pretrained
prompting model to obtain a list of answer candidates. Then, for each answer candidate, we build a
hypothesis by filling the template and check for entailment with the premise (the event description).
Finally, by filtering on the entailment score, we obtain the list of entailed answer candidates related to the event description.
""")
model_nli = load_model_nli()
model_prompting = load_model_prompting()

st.markdown(""" | |
### 2. Write an event description: | |
The first step is to write an event description that will be fed to the pipeline. This can be any text in English. | |
""") | |
text = choose_text_menu('')
st.session_state.text = text

st.markdown(""" | |
### 3. Template design: | |
The second step is to design a template while keeping in mind the objective of the classification. | |
- A good starting point is to use `This event involves [Z].`. This template will ideally be filled with a 1 word summary of the event. | |
- Another good example is `People were [Z].`. With this one we mostly expect a verb that describes the action. | |
You can also use any template you design. Keep in mind that if the masked slot `[Z]` is at the end of the sentence, to not forget punctuation, | |
otherwise the model may fill the template with punctuation signs. | |
""") | |
if 'prompt' not in st.session_state:
    st.session_state.prompt = 'This event involves [Z].'
prompt = st.text_input('Template:', st.session_state.prompt)
st.session_state.prompt = prompt

st.markdown(""" | |
### 4. Select the two parameters: | |
- The first parameter `top_k` is the maximum number of tokens that will be given by the prompting model. | |
It's also the number of tokens that will be tried for entailment. Ideally, you want a high enough number of tokens, otherwise you may miss critical information. | |
However, each additional token will increase the computation time as it needs to go through the entailment model. | |
From our experiments, a good choice is between `[10,50]`, lower and you miss information, higher and you start getting unrelated tokens and long computation time. | |
- The second parameter is the minimum entailment score to confirm that the token is entailed with the event description. | |
By default, we set it at `0.5` (more entailed than not) but it can be modified depending on needs. | |
""") | |

def select_top_k():
    if 'top_k' not in st.session_state:
        st.session_state.top_k = 10
    return st.number_input('Number of max tokens to output (default: 10, min: 0, max: 50)?',
                           step=1, min_value=0, max_value=50, value=int(st.session_state.top_k))


def select_nli_limit():
    if 'nli_limit' not in st.session_state:
        st.session_state.nli_limit = 0.5
    return st.number_input('Minimum score of entailment (default: 0.5, min: 0, max: 1)?',
                           step=0.05, min_value=0.0, max_value=1.0, value=st.session_state.nli_limit)

def update_session_state_callback(value, key):
    # Generic on_change callback: copy a widget's value from st.session_state[key]
    # into st.session_state[value]. Not currently attached to any widget.
    st.session_state[value] = st.session_state[key]


top_k = select_top_k()
st.session_state.top_k = top_k
nli_limit = select_nli_limit()
st.session_state.nli_limit = nli_limit

st.markdown(""" | |
### 5. Remove similar tokens from output: | |
An additional option is to remove similar tokens (e.g. `protest, protests`) from the output. | |
This computes the lemma of each word (based on the template) and removes duplicate lemmas. | |
""") | |
if 'remove_lemma' not in st.session_state:
    st.session_state.remove_lemma = False
remove_lemma = st.checkbox('Remove similar lemmas (e.g. protest, protests) from the output?', value=st.session_state.remove_lemma)
st.session_state.remove_lemma = remove_lemma

# Save the current settings so they can be shown alongside the results
if "old_prompt" not in st.session_state:
    st.session_state.old_text = st.session_state.text
    st.session_state.old_prompt = st.session_state.prompt
    st.session_state.old_top_k = st.session_state.top_k
    st.session_state.old_nli_limit = st.session_state.nli_limit

st.markdown(""" | |
### 6. Run the pipeline | |
""") | |
st.markdown("""The entailed tokens are given as a list of words associated with the probability of entailment.""") | |
if st.button("Run PR-ENT"): | |
computation_state_prent = st.text("PR-ENT Computation Running.") | |
st.session_state.old_text =st.session_state.text | |
st.session_state.old_prompt =st.session_state.prompt | |
st.session_state.old_top_k = st.session_state.top_k | |
st.session_state.old_nli_limit = st.session_state.nli_limit | |
# Replace the mask | |
prompt = prompt.replace('[Z]', '{}') | |
prompt = prompt.replace('[MASK]', '{}') | |
results = prompt_to_nli(text, prompt, model_prompting, model_nli, nlp, top_k, nli_limit, remove_lemma) | |
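    # Assumed shape of `results` (inferred, illustrative): each item pairs
    # prompting and entailment info, e.g. (('injured', prompt_score), (label, entail_score)),
    # so x[0][0] is the candidate token and x[1][1] its entailment probability.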
    list_results = [x[0][0] + ' ' + str(int(x[1][1] * 100)) + '%' for x in results]
    st.session_state.list_results = list_results
    computation_state_prent.text("PR-ENT Computation Done.")

if 'list_results' in st.session_state:
    st.write('**Event Description**: {}'.format(st.session_state.old_text))
    st.write('**Template**: "{}"; **Top K**: {}; **Entailment Threshold**: {}.'.format(
        st.session_state.old_prompt, st.session_state.old_top_k, st.session_state.old_nli_limit))
    display_nli_pr_results_as_list('', st.session_state.list_results)

st.markdown(""" | |
### 7. Actor-target coding (experimental) | |
Available in actor-target tab (on the left) | |
""") |