import json import string from time import time import en_core_web_lg import inflect import nltk import numpy as np import pandas as pd import streamlit as st from nltk.tokenize import sent_tokenize from transformers import pipeline # Set constant values INFLECT_ENGINE = inflect.engine() TOP_K = 30 NLI_LIMIT = 0.9 st.set_page_config(layout="wide") def get_top_k(): return TOP_K def get_nli_limit(): return NLI_LIMIT ### Streamlit specific @st.cache(allow_output_mutation=True) def load_model_prompting(): return pipeline("fill-mask", model="distilbert-base-uncased") @st.cache(allow_output_mutation=True) def load_model_nli(): try: return pipeline( task="sentiment-analysis", model="roberta-large-mnli", device="mps" ) except: return pipeline(task="sentiment-analysis", model="roberta-large-mnli") @st.cache(allow_output_mutation=True) def load_spacy_pipeline(): return en_core_web_lg.load() @st.cache() def download_punkt(): nltk.download("punkt") download_punkt() @st.experimental_memo(max_entries=1) def read_json_from_web(uploaded_json): return json.load(uploaded_json) @st.experimental_memo(max_entries=1) def read_csv_from_web(uploaded_file): """Read CSV from the streamlit interface :param uploaded_file: File to read :type uploaded_file: UploadedFile (BytesIO) :return: Dataframe :rtype: pandas DataFrame """ try: # Try first to read comma separated and semicolon separated files data = pd.read_csv(uploaded_file, sep=None, engine="python") # If both are not correct, then it will error and go to the except except pd.errors.ParserError: # This should be the case when there is no separator (1 column csv) # Reset the IO object due to the previous crash uploaded_file.seek(0) # Use standard reading of CSV (no separator) data = pd.read_csv(uploaded_file) return data def apply_style(): # Avoid having ellipsis in the multi select options styl = """ """ st.markdown(styl, unsafe_allow_html=True) # Set color of multiselect to red st.markdown( """ """, unsafe_allow_html=True, ) hide_st_style = """ """ st.markdown(hide_st_style, unsafe_allow_html=True) def choose_text_menu(text): if "text" not in st.session_state: st.session_state.text = "Several demonstrators were injured." text = st.text_area("Event description", st.session_state.text) return text def initiate_widget_st_state(widget_key, perm_key, default_value): if perm_key not in st.session_state: st.session_state[perm_key] = default_value if widget_key not in st.session_state: st.session_state[widget_key] = st.session_state[perm_key] def get_idx_column(col_name, col_list): if col_name in col_list: return col_list.index(col_name) else: return 0 def callback_add_to_multiselect(str_to_add, multiselect_key, text_input_key, *keys): if len(str_to_add) == 0: st.warning("Word is empty, did you press Enter on the field text?") return current_dict = st.session_state *dict_keys, item_keys = keys try: for key in dict_keys: current_dict = current_dict[key] current_dict[item_keys].append(str_to_add) except KeyError as e: raise KeyError(keys) from e if multiselect_key in st.session_state: st.session_state[multiselect_key].append(str_to_add) else: st.session_state[multiselect_key] = [str_to_add] st.session_state[text_input_key] = "" # Split the text into sentences. Necessary for NLI models def split_sentences(text): return sent_tokenize(text) def get_num_sentences_in_list_text(list_texts): num_sentences = 0 for text in list_texts: num_sentences += len(split_sentences(text)) return num_sentences ###### Prompting def query_model_prompting(model, text, prompt_with_mask, top_k, targets): """Query the prompting model :param model: Prompting model object :type model: Huggingface pipeline object :param text: Event description (context) :type text: str :param prompt_with_mask: Prompt with a mask :type prompt_with_mask: str :param top_k: Number of tokens to output :type top_k: integer :param targets: Restrict the answer to these possible tokens :type targets: list :return: Results of the prompting model :rtype: list of dict """ sequence = text + prompt_with_mask output_tokens = model(sequence, top_k=top_k, targets=targets) return output_tokens def do_sentence_entailment(sentence, hypothesis, model): """Concatenate context and hypothesis then perform entailment :param sentence: Event description (context), 1 sentence :type sentence: str :param hypothesis: Mask filled with a token :type hypothesis: str :param model: NLI Model :type model: Huggingface pipeline :return: DataFrame containing the result of the entailment :rtype: pandas DataFrame """ text = sentence + "" + hypothesis res = model(text, return_all_scores=True) df_res = pd.DataFrame(res[0]) df_res["label"] = df_res["label"].apply(lambda x: x.lower()) df_res.columns = ["Label", "Score"] return df_res def softmax(x): """Compute softmax values for each sets of scores in x.""" return np.exp(x) / np.sum(np.exp(x), axis=0) def get_singular_form(word): """Get the singular form of a word :param word: word :type word: string :return: singular form of the word :rtype: string """ if INFLECT_ENGINE.singular_noun(word): return INFLECT_ENGINE.singular_noun(word) else: return word ######### NLI + PROMPTING def do_text_entailment(text, hypothesis, model): """ Do entailment for each sentence of the event description as model was trained on sentence pair :param text: Event Description (context) :type text: str :param hypothesis: Mask filled with a token :type hypothesis: str :param model: Model NLI :type model: Huggingface pipeline :return: List of entailment results for each sentence of the text :rtype: list """ text_entailment_results = [] for i, sentence in enumerate(split_sentences(text)): df_score = do_sentence_entailment(sentence, hypothesis, model) text_entailment_results.append((sentence, hypothesis, df_score)) return text_entailment_results def get_true_entailment(text_entailment_results, nli_limit): """ From the result of each sentence entailment, extract the maximum entailment score and check if it's higher than the entailment threshold. """ true_hypothesis_list = [] max_score = 0 for sentence_entailment in text_entailment_results: df_score = sentence_entailment[2] score = df_score[df_score["Label"] == "entailment"]["Score"].values.max() if score > max_score: max_score = score if max_score > nli_limit: true_hypothesis_list.append((sentence_entailment[1], np.round(max_score, 2))) return list(set(true_hypothesis_list)) def run_model_nli(data, batch_size, model_nli, use_tf=False): if not use_tf: return model_nli(data, top_k=3, batch_size=batch_size) else: raise NotImplementedError # return run_pipeline_on_gpu(data, batch_size, model_nli["tokenizer"], model_nli["model"]) def prompt_to_nli_batching( text, prompt, model_prompting, nli_model, nlp, top_k=10, nli_limit=0.5, targets=None, additional_words=None, remove_lemma=False, use_tf=False, ): # Check if text has end ponctuation if text[-1] not in string.punctuation: text += "." prompt_masked = prompt.format(model_prompting.tokenizer.mask_token) output_prompting = query_model_prompting( model_prompting, text, prompt_masked, top_k, targets=targets ) if remove_lemma: output_prompting = filter_prompt_output_by_lemma(prompt, output_prompting, nlp) full_batch_concat = [] prompt_tokens = [] for token in output_prompting: hypothesis = prompt.format(token["token_str"]) for i, sentence in enumerate(split_sentences(text)): full_batch_concat.append(sentence + "" + hypothesis) prompt_tokens.append((token["token_str"], token["score"])) # Add words that must be tried for entailment # Also increase batch_size if additional_words: for i, sentence in enumerate(split_sentences(text)): for token in additional_words: hypothesis = prompt.format(token) full_batch_concat.append(sentence + "" + hypothesis) prompt_tokens.append((token, 1)) top_k = top_k + 1 results_nli = run_model_nli(full_batch_concat, top_k, nli_model, use_tf) # Get entailed tokens entailed_tokens = [] for i, res in enumerate(results_nli): entailed_tokens.extend( [ (get_singular_form(prompt_tokens[i][0]), x["score"]) for x in res if ((x["label"] == "ENTAILMENT") & (x["score"] > nli_limit)) ] ) if entailed_tokens: entailed_tokens = list( pd.DataFrame(entailed_tokens).groupby(0).max()[1].items() ) return entailed_tokens, list(set(prompt_tokens)) def remove_similar_lemma_from_list(prompt, list_words, nlp): ## Compute a dictionnary with the lemma for all tokens ## If there is a duplicate lemma then the dictionnary value will be a list of the corresponding tokens lemma_dict = {} for each in list_words: mask_filled = nlp(prompt.strip(".").format(each)) lemma_dict.setdefault([x.lemma_ for x in mask_filled][-1], []).append(each) ## Get back the list of tokens ## If multiple tokens available then take the shortest one new_token_list = [] for key in lemma_dict.keys(): if len(lemma_dict[key]) >= 1: new_token_list.append(min(lemma_dict[key], key=len)) else: raise ValueError("Lemma dict has 0 corresponding words") return new_token_list def filter_prompt_output_by_lemma(prompt, output_prompting, nlp): """ Remove all similar lemmas from the prompt output (e.g. "protest", "protests") """ list_words = [x["token_str"] for x in output_prompting] new_token_list = remove_similar_lemma_from_list(prompt, list_words, nlp) return [x for x in output_prompting if x["token_str"] in new_token_list] # Streamlit specific run functions @st.experimental_memo(max_entries=1024) def do_prent(text, template, top_k, nli_limit, additional_words=None): """Function used to execute PRENT model :param text: Event text :type text: string :param template: Template with mask :type template: string :param top_k: Maximum tokens to output from prompting model :type top_k: int :param nli_limit: Threshold of entailment for NLI [0,1] :type nli_limit: float :param additional_words: List of words that bypass prompting and goes directly to NLI, defaults to None :type additional_words: list, optional :return: (Results Entailment, Results Prompting) :rtype: tuple """ results_nli, results_pr = prompt_to_nli_batching( text, template, load_model_prompting(), load_model_nli(), load_spacy_pipeline(), top_k=top_k, nli_limit=nli_limit, targets=None, additional_words=additional_words, remove_lemma=True, ) return results_nli, results_pr def get_additional_words(): """Extract the additional words from the codebook :return: list of additional words :rtype: list """ if "add_words" in st.session_state.codebook: additional_words = st.session_state.codebook["add_words"] else: additional_words = None return additional_words def run_prent( text="", templates=[], additional_words=None, progress=True, display_text=True ): """Execute PRENT over a list of templates and display streamlit widgets :param text: Event description, defaults to "" :type text: str, optional :param templates: Templates with a mask, defaults to [] :type templates: list, optional :param additional_words: List of words to bypass prompting, defaults to None :type additional_words: list, optional :param progress: Display or not the progress bar, defaults to True :type progress: bool, optional :return: (results of prent, computation time) :rtype: tuple """ # Check if there is any template and event description available if not templates: st.warning("Template list is empty. Please add one.") return None, None if not text: st.warning("Event description is empty.") return None, None # Display text only when computing if display_text: temp_text = st.empty() temp_text.markdown("**Event Descriptions:** {}".format(text)) # Start progress bar if progress: progress_bar = st.progress(0) num_prent_call = len(templates) num_sentences = get_num_sentences_in_list_text([text]) iter = 0 t0 = time() # We set the radio choice of streamlit to Ignore at first if "accept_reject_text_perm" in st.session_state: st.session_state["accept_reject_text_perm"] = "Ignore" res = {} for template in templates: template = template.replace("[Z]", "{}") results_nli, results_pr = do_prent( text, template, top_k=TOP_K, nli_limit=NLI_LIMIT, additional_words=additional_words, ) # Results_nli contains % of entailment, we only care about the tokens string res[template] = [x[0] for x in results_nli] # Update progress bar iter += 1 if progress: progress_bar.progress((1 / num_prent_call) * (iter)) if display_text: temp_text.markdown("") time_comput = (time() - t0) / num_sentences # This check is done otherwise the time of computation is replaced by the # time of computation when using cached value if not time_comput < st.session_state.time_comput / 5: st.session_state.time_comput = int(time_comput) # Store some results res["templates_used"] = templates res["additional_words_used"] = additional_words return res, time_comput ####### Find event types based on codebook and PRENT results def check_any_conds(cond_any, list_res): """Function that evaluates the "OR" conditions of the codebook versus the list of filled templates :param cond_any: List of groundtruth filled templates :type cond_any: list :param list_res: A list of the filled templates given by PRENT :type list_res: list :return: True if any groundtruth template is inside the list given by PRENT :rtype: bool """ cond_any = list(cond_any) condition = False # Return False if there is no any condition if not cond_any: return False for cond in cond_any: # With the current codebook design, this should never be true. # Before it was possible to have recursion to check AND conditions inside an OR condition if isinstance(cond, dict): condition = check_all_conds(cond["all"], list_res) else: # Check lowercase version of templates if cond.lower() in [x.lower() for x in list_res]: condition = True # Exit function as the other templates won't change the outcome return condition return condition def check_all_conds(cond_all, list_res): """Function that evaluates the "AND" conditions of the codebook versus the list of filled templates :param cond_all: List of groundtruth filled templates :type cond_all: list :param list_res: A list of the filled templates given by PRENT :type list_res: list :return: True if all groundtruth template are inside the list given by PRENT :rtype: bool """ cond_all = list(cond_all) # Return False if there is no all condition if not cond_all: return False # Start bool on True, and put it to false if any template is missing condition = True for cond in cond_all: # With the current codebook design, this should never be true. # Before it was possible to have recursion to check OR conditions inside an AND condition if isinstance(cond, dict): condition = check_any_conds(cond["any"]) else: # Check lowercase version of templates if not (cond.lower() in [x.lower() for x in list_res]): condition = False # Exit function as the other templates won't change the outcome return condition return condition def find_event_types(codebook, list_res): """This function evaluates the codebook and then outputs a list of events types corresponding to the given results of PRENT (list of filled templates). :param codebook: A codebook in the format given by the dashboard :type codebook: dict :param list_res: A list of the filled templates given by PRENT :type list_res: list :return: List of event type :rtype: list """ list_event_type = [] # Iterate over all defined event types for event_type in codebook["events"]: code_event = codebook["events"][event_type] is_not_all_event, is_not_any_event, is_not_event = False, False, False is_all_event, is_any_event, is_event = False, False, False # First check if NOT conditions are met # e.g. a filled template that is contrary to the event is present if "not_all" in code_event: cond_all = code_event["not_all"] if check_all_conds(cond_all, list_res): is_not_all_event = True if "not_any" in code_event: cond_any = code_event["not_any"] if check_any_conds(cond_any, list_res): is_not_any_event = True # Next we need to check if the "not_all" and "not_any" are related # by an "OR" or "AND". # This latest case needs special care because one of two list can # be empty so False if code_event["not_all_any_rel"] == "AND": if is_not_all_event and (not code_event["not_any"]): # If all TRUE and ANY is empty (so false) is_not_event = True elif is_not_any_event and (not code_event["not_all"]): # If any TRUE and ALL is empty (so false) is_not_event = True if is_not_all_event and is_not_any_event: is_not_event = True elif code_event["not_all_any_rel"] == "OR": if is_not_all_event or is_not_any_event: is_not_event = True # The other checks are not necessary if this is true, so we go # to the next iteration if is_not_event: continue # Similar to the previous checks but this time we look for templates that should be present if "all" in code_event: cond_all = code_event["all"] ## Then check if All conditions are met, if not exit if check_all_conds(cond_all, list_res): is_all_event = True if "any" in code_event: ## Finally check if Any conditions is met, if not exit cond_any = code_event["any"] if check_any_conds(cond_any, list_res): is_any_event = True # This case needs special care because one of two list can # be empty so False if code_event["all_any_rel"] == "AND": if is_all_event and (not code_event["any"]): # If all TRUE and ANY is empty (so false) is_event = True elif is_any_event and (not code_event["all"]): # If any TRUE and ALL is empty (so false) is_event = True elif is_all_event and is_any_event: is_event = True elif code_event["all_any_rel"] == "OR": if is_all_event or is_any_event: is_event = True # If all checks are correct, then we can add the event type to the output list if is_event: list_event_type.append(event_type) return list_event_type