# PRENT-Codebook / helpers.py
import json
import string
from time import time
import en_core_web_lg
import inflect
import nltk
import numpy as np
import pandas as pd
import streamlit as st
from nltk.tokenize import sent_tokenize
from transformers import pipeline
# Set constant values
INFLECT_ENGINE = inflect.engine()
TOP_K = 30
NLI_LIMIT = 0.9
st.set_page_config(layout="wide")
def get_top_k():
return TOP_K
def get_nli_limit():
return NLI_LIMIT
### Streamlit specific
@st.cache(allow_output_mutation=True)
def load_model_prompting():
return pipeline("fill-mask", model="distilbert-base-uncased")
@st.cache(allow_output_mutation=True)
def load_model_nli():
    try:
        # Prefer the Apple-silicon GPU ("mps") when available
        return pipeline(
            task="sentiment-analysis", model="roberta-large-mnli", device="mps"
        )
    except Exception:
        # Fall back to the default device (CPU)
        return pipeline(task="sentiment-analysis", model="roberta-large-mnli")
@st.cache(allow_output_mutation=True)
def load_spacy_pipeline():
return en_core_web_lg.load()
@st.cache()
def download_punkt():
nltk.download("punkt")
download_punkt()
@st.experimental_memo(max_entries=1)
def read_json_from_web(uploaded_json):
return json.load(uploaded_json)
@st.experimental_memo(max_entries=1)
def read_csv_from_web(uploaded_file):
"""Read CSV from the streamlit interface
:param uploaded_file: File to read
:type uploaded_file: UploadedFile (BytesIO)
:return: Dataframe
:rtype: pandas DataFrame
"""
try:
        # Try to sniff the separator (handles both comma- and semicolon-separated files)
        data = pd.read_csv(uploaded_file, sep=None, engine="python")
    # If no separator can be sniffed, parsing raises and we fall through to the except
except pd.errors.ParserError:
# This should be the case when there is no separator (1 column csv)
# Reset the IO object due to the previous crash
uploaded_file.seek(0)
# Use standard reading of CSV (no separator)
data = pd.read_csv(uploaded_file)
return data
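# Illustrative sketch (not used by the app): with `sep=None` and the python
# engine, pandas sniffs the delimiter, so comma- and semicolon-separated
# content both parse into the same two columns. The contents here are made up.
def _demo_csv_separator_sniffing():
    import io
    comma = io.StringIO("a,b\n1,2\n")
    semicolon = io.StringIO("a;b\n1;2\n")
    assert pd.read_csv(comma, sep=None, engine="python").columns.tolist() == ["a", "b"]
    assert pd.read_csv(semicolon, sep=None, engine="python").columns.tolist() == ["a", "b"]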
def apply_style():
# Avoid having ellipsis in the multi select options
styl = """
<style>
.stMultiSelect span{
max-width: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
# Set color of multiselect to red
st.markdown(
"""
<style>
span[data-baseweb="tag"] {
background-color: red !important;
}
</style>
""",
unsafe_allow_html=True,
)
hide_st_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
</style>
"""
st.markdown(hide_st_style, unsafe_allow_html=True)
def choose_text_menu(text):
if "text" not in st.session_state:
st.session_state.text = "Several demonstrators were injured."
text = st.text_area("Event description", st.session_state.text)
return text
def initiate_widget_st_state(widget_key, perm_key, default_value):
if perm_key not in st.session_state:
st.session_state[perm_key] = default_value
if widget_key not in st.session_state:
st.session_state[widget_key] = st.session_state[perm_key]
def get_idx_column(col_name, col_list):
if col_name in col_list:
return col_list.index(col_name)
else:
return 0
def callback_add_to_multiselect(str_to_add, multiselect_key, text_input_key, *keys):
    if len(str_to_add) == 0:
        st.warning("Word is empty. Did you press Enter in the text field?")
        return
    # Walk down the nested session-state dict; the last key addresses the list to append to
    current_dict = st.session_state
    *dict_keys, item_key = keys
    try:
        for key in dict_keys:
            current_dict = current_dict[key]
        current_dict[item_key].append(str_to_add)
    except KeyError as e:
        raise KeyError(keys) from e
if multiselect_key in st.session_state:
st.session_state[multiselect_key].append(str_to_add)
else:
st.session_state[multiselect_key] = [str_to_add]
st.session_state[text_input_key] = ""
# Split the text into sentences. Necessary for NLI models
def split_sentences(text):
return sent_tokenize(text)
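# Illustrative example: `split_sentences` relies on NLTK's punkt model
# (downloaded above) to segment a description into sentences.
def _demo_split_sentences():
    text = "Protesters marched downtown. Several demonstrators were injured."
    assert split_sentences(text) == [
        "Protesters marched downtown.",
        "Several demonstrators were injured.",
    ]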
def get_num_sentences_in_list_text(list_texts):
num_sentences = 0
for text in list_texts:
num_sentences += len(split_sentences(text))
return num_sentences
###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
"""Query the prompting model
:param model: Prompting model object
:type model: Huggingface pipeline object
:param text: Event description (context)
:type text: str
:param prompt_with_mask: Prompt with a mask
:type prompt_with_mask: str
:param top_k: Number of tokens to output
:type top_k: integer
:param targets: Restrict the answer to these possible tokens
:type targets: list
:return: Results of the prompting model
:rtype: list of dict
"""
sequence = text + prompt_with_mask
output_tokens = model(sequence, top_k=top_k, targets=targets)
return output_tokens
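# Illustrative sketch of a prompting call (the event text and template are
# made up; the leading space matters because text and prompt are concatenated):
def _demo_query_model_prompting():
    model = load_model_prompting()
    prompt_with_mask = " This event involves {}.".format(model.tokenizer.mask_token)
    # Returns `top_k` dicts such as {"token_str": "police", "score": 0.12, ...}
    return query_model_prompting(
        model, "Several demonstrators were injured.", prompt_with_mask, top_k=5, targets=None
    )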
def do_sentence_entailment(sentence, hypothesis, model):
"""Concatenate context and hypothesis then perform entailment
:param sentence: Event description (context), 1 sentence
:type sentence: str
:param hypothesis: Mask filled with a token
:type hypothesis: str
:param model: NLI Model
:type model: Huggingface pipeline
:return: DataFrame containing the result of the entailment
:rtype: pandas DataFrame
"""
text = sentence + "</s></s>" + hypothesis
res = model(text, return_all_scores=True)
df_res = pd.DataFrame(res[0])
df_res["label"] = df_res["label"].apply(lambda x: x.lower())
df_res.columns = ["Label", "Score"]
return df_res
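# Illustrative sketch: the "</s></s>" separator joins premise and hypothesis
# the way roberta-large-mnli expects sentence pairs (the texts are made up):
def _demo_do_sentence_entailment():
    df = do_sentence_entailment(
        "Several demonstrators were injured.",
        "This event involves protesters.",
        load_model_nli(),
    )
    # df has one row per label: contradiction / neutral / entailment
    return df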
def softmax(x):
"""Compute softmax values for each sets of scores in x."""
return np.exp(x) / np.sum(np.exp(x), axis=0)
def get_singular_form(word):
"""Get the singular form of a word
:param word: word
:type word: string
:return: singular form of the word
:rtype: string
"""
if INFLECT_ENGINE.singular_noun(word):
return INFLECT_ENGINE.singular_noun(word)
else:
return word
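# Example: inflect's singular_noun() returns False for words that are already
# singular, hence the branch above.
def _demo_get_singular_form():
    assert get_singular_form("protests") == "protest"
    assert get_singular_form("protest") == "protest"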
######### NLI + PROMPTING
def do_text_entailment(text, hypothesis, model):
"""
Do entailment for each sentence of the event description as
model was trained on sentence pair
:param text: Event Description (context)
:type text: str
:param hypothesis: Mask filled with a token
:type hypothesis: str
:param model: Model NLI
:type model: Huggingface pipeline
:return: List of entailment results for each sentence of the text
:rtype: list
"""
text_entailment_results = []
    for sentence in split_sentences(text):
df_score = do_sentence_entailment(sentence, hypothesis, model)
text_entailment_results.append((sentence, hypothesis, df_score))
return text_entailment_results
def get_true_entailment(text_entailment_results, nli_limit):
    """
    From the per-sentence entailment results, extract the maximum entailment
    score and check whether it is higher than the entailment threshold.
    """
    true_hypothesis_list = []
    max_score = 0
    for sentence_entailment in text_entailment_results:
        df_score = sentence_entailment[2]
        score = df_score[df_score["Label"] == "entailment"]["Score"].values.max()
        max_score = max(max_score, score)
    # Append once, after scanning every sentence, so the same hypothesis cannot
    # be added several times with intermediate running maxima
    if text_entailment_results and max_score > nli_limit:
        true_hypothesis_list.append(
            (text_entailment_results[0][1], np.round(max_score, 2))
        )
    return list(set(true_hypothesis_list))
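# Illustrative sketch with a hand-built score table (the numbers are made up):
# only the best sentence-level score counts, and it must clear `nli_limit`.
def _demo_get_true_entailment():
    df_score = pd.DataFrame(
        {
            "Label": ["contradiction", "neutral", "entailment"],
            "Score": [0.02, 0.03, 0.95],
        }
    )
    results = [("Some sentence.", "This event involves protests.", df_score)]
    assert get_true_entailment(results, nli_limit=0.9) == [
        ("This event involves protests.", 0.95)
    ]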
def run_model_nli(data, batch_size, model_nli, use_tf=False):
if not use_tf:
return model_nli(data, top_k=3, batch_size=batch_size)
else:
raise NotImplementedError
# return run_pipeline_on_gpu(data, batch_size, model_nli["tokenizer"], model_nli["model"])
def prompt_to_nli_batching(
text,
prompt,
model_prompting,
nli_model,
nlp,
top_k=10,
nli_limit=0.5,
targets=None,
additional_words=None,
remove_lemma=False,
use_tf=False,
):
    # Check if the text ends with punctuation
if text[-1] not in string.punctuation:
text += "."
prompt_masked = prompt.format(model_prompting.tokenizer.mask_token)
output_prompting = query_model_prompting(
model_prompting, text, prompt_masked, top_k, targets=targets
)
if remove_lemma:
output_prompting = filter_prompt_output_by_lemma(prompt, output_prompting, nlp)
full_batch_concat = []
prompt_tokens = []
for token in output_prompting:
hypothesis = prompt.format(token["token_str"])
        for sentence in split_sentences(text):
full_batch_concat.append(sentence + "</s></s>" + hypothesis)
prompt_tokens.append((token["token_str"], token["score"]))
    # Add words that must always be tried for entailment,
    # and bump top_k (which is reused as the batch size below)
    if additional_words:
        for sentence in split_sentences(text):
            for token in additional_words:
                hypothesis = prompt.format(token)
                full_batch_concat.append(sentence + "</s></s>" + hypothesis)
                prompt_tokens.append((token, 1))
        top_k = top_k + 1
results_nli = run_model_nli(full_batch_concat, top_k, nli_model, use_tf)
# Get entailed tokens
entailed_tokens = []
for i, res in enumerate(results_nli):
entailed_tokens.extend(
[
(get_singular_form(prompt_tokens[i][0]), x["score"])
for x in res
if ((x["label"] == "ENTAILMENT") & (x["score"] > nli_limit))
]
)
if entailed_tokens:
entailed_tokens = list(
pd.DataFrame(entailed_tokens).groupby(0).max()[1].items()
)
return entailed_tokens, list(set(prompt_tokens))
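# Illustrative end-to-end sketch (the event text and template are made up;
# the models are the cached ones loaded above):
def _demo_prompt_to_nli_batching():
    entailed, proposed = prompt_to_nli_batching(
        "Several demonstrators were injured.",
        " This event involves {}.",
        load_model_prompting(),
        load_model_nli(),
        load_spacy_pipeline(),
        top_k=10,
        nli_limit=0.5,
        remove_lemma=True,
    )
    # `entailed`: [(token, best entailment score), ...] above the threshold
    # `proposed`: [(token, prompting score), ...] for every candidate tried
    return entailed, proposed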
def remove_similar_lemma_from_list(prompt, list_words, nlp):
    ## Compute a dictionary mapping each token's lemma to the tokens that share it
    ## (duplicate lemmas produce a list of the corresponding tokens)
lemma_dict = {}
for each in list_words:
mask_filled = nlp(prompt.strip(".").format(each))
lemma_dict.setdefault([x.lemma_ for x in mask_filled][-1], []).append(each)
## Get back the list of tokens
## If multiple tokens available then take the shortest one
new_token_list = []
for key in lemma_dict.keys():
if len(lemma_dict[key]) >= 1:
new_token_list.append(min(lemma_dict[key], key=len))
else:
raise ValueError("Lemma dict has 0 corresponding words")
return new_token_list
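# Illustrative example: candidates sharing a lemma collapse to the shortest
# surface form (the template and words here are made up; the exact output
# depends on the spaCy lemmatizer):
def _demo_remove_similar_lemma_from_list():
    tokens = remove_similar_lemma_from_list(
        "This event involves {}.",
        ["protest", "protests", "riots"],
        load_spacy_pipeline(),
    )
    assert sorted(tokens) == ["protest", "riots"]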
def filter_prompt_output_by_lemma(prompt, output_prompting, nlp):
"""
Remove all similar lemmas from the prompt output (e.g. "protest", "protests")
"""
list_words = [x["token_str"] for x in output_prompting]
new_token_list = remove_similar_lemma_from_list(prompt, list_words, nlp)
return [x for x in output_prompting if x["token_str"] in new_token_list]
# Streamlit specific run functions
@st.experimental_memo(max_entries=1024)
def do_prent(text, template, top_k, nli_limit, additional_words=None):
"""Function used to execute PRENT model
:param text: Event text
:type text: string
:param template: Template with mask
:type template: string
:param top_k: Maximum tokens to output from prompting model
:type top_k: int
:param nli_limit: Threshold of entailment for NLI [0,1]
:type nli_limit: float
    :param additional_words: List of words that bypass prompting and go directly to NLI, defaults to None
:type additional_words: list, optional
:return: (Results Entailment, Results Prompting)
:rtype: tuple
"""
results_nli, results_pr = prompt_to_nli_batching(
text,
template,
load_model_prompting(),
load_model_nli(),
load_spacy_pipeline(),
top_k=top_k,
nli_limit=nli_limit,
targets=None,
additional_words=additional_words,
remove_lemma=True,
)
return results_nli, results_pr
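# Illustrative call (the text and template are made up; the template must
# contain the "{}" mask slot, and TOP_K / NLI_LIMIT come from the constants above):
def _demo_do_prent():
    entailed, proposed = do_prent(
        "Several demonstrators were injured.",
        "This event involves {}.",
        top_k=get_top_k(),
        nli_limit=get_nli_limit(),
    )
    return entailed, proposed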
def get_additional_words():
"""Extract the additional words from the codebook
:return: list of additional words
:rtype: list
"""
if "add_words" in st.session_state.codebook:
additional_words = st.session_state.codebook["add_words"]
else:
additional_words = None
return additional_words
def run_prent(
text="", templates=[], additional_words=None, progress=True, display_text=True
):
"""Execute PRENT over a list of templates and display streamlit widgets
:param text: Event description, defaults to ""
:type text: str, optional
:param templates: Templates with a mask, defaults to []
:type templates: list, optional
:param additional_words: List of words to bypass prompting, defaults to None
:type additional_words: list, optional
    :param progress: Whether to display the progress bar, defaults to True
    :type progress: bool, optional
    :param display_text: Whether to display the event description while computing, defaults to True
    :type display_text: bool, optional
:return: (results of prent, computation time)
:rtype: tuple
"""
# Check if there is any template and event description available
if not templates:
st.warning("Template list is empty. Please add one.")
return None, None
if not text:
st.warning("Event description is empty.")
return None, None
# Display text only when computing
if display_text:
temp_text = st.empty()
temp_text.markdown("**Event Descriptions:** {}".format(text))
# Start progress bar
if progress:
progress_bar = st.progress(0)
num_prent_call = len(templates)
num_sentences = get_num_sentences_in_list_text([text])
    num_iter = 0
t0 = time()
    # Reset the accept/reject radio widget to "Ignore" before each run
if "accept_reject_text_perm" in st.session_state:
st.session_state["accept_reject_text_perm"] = "Ignore"
res = {}
for template in templates:
template = template.replace("[Z]", "{}")
results_nli, results_pr = do_prent(
text,
template,
top_k=TOP_K,
nli_limit=NLI_LIMIT,
additional_words=additional_words,
)
        # results_nli contains (token, entailment score) pairs; keep only the token strings
res[template] = [x[0] for x in results_nli]
        # Update progress bar
        num_iter += 1
        if progress:
            progress_bar.progress((1 / num_prent_call) * num_iter)
if display_text:
temp_text.markdown("")
time_comput = (time() - t0) / num_sentences
    # Only update the stored time when this run was actually computed:
    # a cached call returns much faster and would misrepresent the real computation time
    if not time_comput < st.session_state.time_comput / 5:
st.session_state.time_comput = int(time_comput)
# Store some results
res["templates_used"] = templates
res["additional_words_used"] = additional_words
return res, time_comput
####### Find event types based on codebook and PRENT results
def check_any_conds(cond_any, list_res):
"""Function that evaluates the "OR" conditions of the codebook versus the list of filled templates
:param cond_any: List of groundtruth filled templates
:type cond_any: list
:param list_res: A list of the filled templates given by PRENT
:type list_res: list
:return: True if any groundtruth template is inside the list given by PRENT
:rtype: bool
"""
cond_any = list(cond_any)
condition = False
    # Return False if there are no "any" conditions
if not cond_any:
return False
for cond in cond_any:
        # With the current codebook design, this should never be true;
        # previously, recursion allowed AND conditions nested inside an OR condition
if isinstance(cond, dict):
condition = check_all_conds(cond["all"], list_res)
else:
# Check lowercase version of templates
if cond.lower() in [x.lower() for x in list_res]:
condition = True
# Exit function as the other templates won't change the outcome
return condition
return condition
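# Example: matching is case-insensitive and a single hit satisfies an "OR":
def _demo_check_any_conds():
    list_res = ["This event involves protests."]
    assert check_any_conds(["This event involves PROTESTS."], list_res)
    assert not check_any_conds(["This event involves riots."], list_res)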
def check_all_conds(cond_all, list_res):
"""Function that evaluates the "AND" conditions of the codebook versus the list of filled templates
:param cond_all: List of groundtruth filled templates
:type cond_all: list
:param list_res: A list of the filled templates given by PRENT
:type list_res: list
    :return: True if all groundtruth templates are inside the list given by PRENT
:rtype: bool
"""
cond_all = list(cond_all)
    # Return False if there are no "all" conditions
if not cond_all:
return False
# Start bool on True, and put it to false if any template is missing
condition = True
for cond in cond_all:
        # With the current codebook design, this should never be true;
        # previously, recursion allowed OR conditions nested inside an AND condition
        if isinstance(cond, dict):
            condition = check_any_conds(cond["any"], list_res)
else:
# Check lowercase version of templates
if not (cond.lower() in [x.lower() for x in list_res]):
condition = False
# Exit function as the other templates won't change the outcome
return condition
return condition
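# Example: every "AND" condition must be present; an empty condition list is False:
def _demo_check_all_conds():
    list_res = ["This event involves protests.", "This event involves injuries."]
    assert check_all_conds(
        ["This event involves protests.", "This event involves injuries."], list_res
    )
    assert not check_all_conds(["This event involves riots."], list_res)
    assert not check_all_conds([], list_res)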
def find_event_types(codebook, list_res):
"""This function evaluates the codebook and then outputs a list of events types corresponding to the given results of PRENT (list of filled templates).
:param codebook: A codebook in the format given by the dashboard
:type codebook: dict
:param list_res: A list of the filled templates given by PRENT
:type list_res: list
:return: List of event type
:rtype: list
"""
list_event_type = []
# Iterate over all defined event types
for event_type in codebook["events"]:
code_event = codebook["events"][event_type]
is_not_all_event, is_not_any_event, is_not_event = False, False, False
is_all_event, is_any_event, is_event = False, False, False
# First check if NOT conditions are met
# e.g. a filled template that is contrary to the event is present
if "not_all" in code_event:
cond_all = code_event["not_all"]
if check_all_conds(cond_all, list_res):
is_not_all_event = True
if "not_any" in code_event:
cond_any = code_event["not_any"]
if check_any_conds(cond_any, list_res):
is_not_any_event = True
        # Next, check whether "not_all" and "not_any" are related by an "OR"
        # or an "AND". The "AND" case needs special care because one of the
        # two lists can be empty (and therefore False)
if code_event["not_all_any_rel"] == "AND":
if is_not_all_event and (not code_event["not_any"]):
# If all TRUE and ANY is empty (so false)
is_not_event = True
elif is_not_any_event and (not code_event["not_all"]):
# If any TRUE and ALL is empty (so false)
is_not_event = True
if is_not_all_event and is_not_any_event:
is_not_event = True
elif code_event["not_all_any_rel"] == "OR":
if is_not_all_event or is_not_any_event:
is_not_event = True
# The other checks are not necessary if this is true, so we go
# to the next iteration
if is_not_event:
continue
# Similar to the previous checks but this time we look for templates that should be present
if "all" in code_event:
cond_all = code_event["all"]
        # Check whether all "AND" conditions are met
if check_all_conds(cond_all, list_res):
is_all_event = True
if "any" in code_event:
        # Check whether any "OR" condition is met
cond_any = code_event["any"]
if check_any_conds(cond_any, list_res):
is_any_event = True
        # The "AND" case needs special care because one of the two lists can
        # be empty (and therefore False)
if code_event["all_any_rel"] == "AND":
if is_all_event and (not code_event["any"]):
# If all TRUE and ANY is empty (so false)
is_event = True
elif is_any_event and (not code_event["all"]):
# If any TRUE and ALL is empty (so false)
is_event = True
elif is_all_event and is_any_event:
is_event = True
elif code_event["all_any_rel"] == "OR":
if is_all_event or is_any_event:
is_event = True
# If all checks are correct, then we can add the event type to the output list
if is_event:
list_event_type.append(event_type)
return list_event_type
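# Illustrative sketch with a minimal hand-written codebook (the keys mirror
# the structure consumed above; the event type and templates are made up):
def _demo_find_event_types():
    codebook = {
        "events": {
            "Protest": {
                "all": ["This event involves protesters."],
                "any": [],
                "all_any_rel": "AND",
                "not_all": [],
                "not_any": ["This event involves soldiers."],
                "not_all_any_rel": "AND",
            }
        }
    }
    assert find_event_types(codebook, ["This event involves protesters."]) == ["Protest"]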