# Import the core libraries.
import re, nltk, pandas as pd, numpy as np, ssl, streamlit as st
from nltk.corpus import wordnet
import spacy
nlp = spacy.load("en_core_web_lg")

# Import the transformer components used for sentiment prediction.
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import torch
import torch.nn.functional as F
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
# return_all_scores=True keeps the score for both labels; newer transformers versions use top_k=None for the same behavior.
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
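# Illustrative check of the pipeline output (values are approximate and depend on the model version):
#   pipe("I love this movie.")[0]
#   -> [{'label': 'NEGATIVE', 'score': 0.0002}, {'label': 'POSITIVE', 'score': 0.9998}]
# With return_all_scores=True each input yields a list with one score per label, which eval_pred_test below relies on.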
# If NLTK reports that the "omw-1.4" corpus cannot be found, uncomment this block to download it
# over an unverified SSL context (https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed).
'''try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context
nltk.download('omw-1.4')'''
# A simple function that pulls synonyms and antonyms from WordNet, optionally filtered by spaCy's POS tag.
def syn_ant(word, POS=False, human=True):
    pos_options = ['NOUN', 'VERB', 'ADJ', 'ADV']
    synonyms = []
    antonyms = []
    # WordNet hates spaces, so replace them with underscores.
    if " " in word:
        word = word.replace(" ", "_")
    if POS in pos_options:
        for syn in wordnet.synsets(word, pos=getattr(wordnet, POS)):
            for l in syn.lemmas():
                current = l.name()
                if human:
                    current = re.sub("_", " ", current)
                synonyms.append(current)
                if l.antonyms():
                    for ant in l.antonyms():
                        cur_ant = ant.name()
                        if human:
                            cur_ant = re.sub("_", " ", cur_ant)
                        antonyms.append(cur_ant)
    else:
        for syn in wordnet.synsets(word):
            for l in syn.lemmas():
                current = l.name()
                if human:
                    current = re.sub("_", " ", current)
                synonyms.append(current)
                if l.antonyms():
                    for ant in l.antonyms():
                        cur_ant = ant.name()
                        if human:
                            cur_ant = re.sub("_", " ", cur_ant)
                        antonyms.append(cur_ant)
    synonyms = list(set(synonyms))
    antonyms = list(set(antonyms))
    return synonyms, antonyms
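# Illustrative example (results depend on the installed WordNet data, and order may vary):
#   syn_ant("happy", POS="ADJ")
#   -> (['happy', 'felicitous', 'glad', 'well-chosen'], ['unhappy'])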
def process_text(text):
    # Lemmatize the text and drop stop words, punctuation, and pronoun placeholders.
    doc = nlp(text.lower())
    result = []
    for token in doc:
        if (token.is_stop) or (token.is_punct) or (token.lemma_ == '-PRON-'):
            continue
        result.append(token.lemma_)
    return " ".join(result)
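# Illustrative example: process_text("The cats were running quickly") should return something like
# "cat run quickly" with this spaCy model (exact lemmas and stop-word choices vary by version).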
def clean_definition(syn):
    # Accepts either a synset name (string) or a Synset object and returns its definition with stop
    # words removed, which improves document-level similarity comparisons for differentiation.
    if type(syn) is str:
        synset = wordnet.synset(syn).definition()
    elif type(syn) is nltk.corpus.reader.wordnet.Synset:
        synset = syn.definition()
    definition = nlp(process_text(synset))
    return definition
def check_sim(a, b):
    if type(a) is str and type(b) is str:
        a = nlp(a)
        b = nlp(b)
    similarity = a.similarity(b)
    return similarity
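# Illustrative example (the exact value depends on the spaCy vectors):
#   check_sim(clean_definition('dog.n.01'), clean_definition('cat.n.01'))
# returns a float, roughly between 0 and 1, where higher means the two definitions are more similar.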
# Builds a dataframe dynamically from WordNet using NLTK.
def wordnet_df(word, POS=False, seed_definition=None):
    pos_options = ['NOUN', 'VERB', 'ADJ', 'ADV']
    synonyms, antonyms = syn_ant(word, POS, False)
    #print(synonyms, antonyms) #for QA purposes
    words = []
    cats = []
    # WordNet hates spaces, so replace them with underscores.
    m_word = word.replace(" ", "_")

    # Allow the user to pick a seed definition if one is not provided directly to the function.
    # The console version below is not currently working, so it is commented out (kept for posterity
    # and for anyone who wants to use the function without Streamlit).
    '''for d in range(len(seed_definitions)):
        print(f"{d}: {seed_definitions[d]}")
    choice = int(input("Which of the definitions above most aligns to your selection?"))
    seed_definition = seed_definitions[choice]'''
    if seed_definition is None:
        st.write("You did not supply a definition.")

    if POS in pos_options:
        for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS)):
            if check_sim(process_text(seed_definition), process_text(syn.definition())) > .7:
                cur_lemmas = syn.lemmas()
                hypos = syn.hyponyms()
                for hypo in hypos:
                    cur_lemmas.extend(hypo.lemmas())
                for lemma in cur_lemmas:
                    ll = lemma.name()
                    cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                    words.append(re.sub("_", " ", ll))
        if len(synonyms) > 0:
            for w in synonyms:
                w = w.replace(" ", "_")
                for syn in wordnet.synsets(w, pos=getattr(wordnet, POS)):
                    if check_sim(process_text(seed_definition), process_text(syn.definition())) > .6:
                        cur_lemmas = syn.lemmas()
                        hypos = syn.hyponyms()
                        for hypo in hypos:
                            cur_lemmas.extend(hypo.lemmas())
                        for lemma in cur_lemmas:
                            ll = lemma.name()
                            cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                            words.append(re.sub("_", " ", ll))
        if len(antonyms) > 0:
            for a in antonyms:
                a = a.replace(" ", "_")
                for syn in wordnet.synsets(a, pos=getattr(wordnet, POS)):
                    if check_sim(process_text(seed_definition), process_text(syn.definition())) > .26:
                        cur_lemmas = syn.lemmas()
                        hypos = syn.hyponyms()
                        for hypo in hypos:
                            cur_lemmas.extend(hypo.lemmas())
                        for lemma in cur_lemmas:
                            ll = lemma.name()
                            cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                            words.append(re.sub("_", " ", ll))
    else:
        for syn in wordnet.synsets(m_word):
            if check_sim(process_text(seed_definition), process_text(syn.definition())) > .7:
                cur_lemmas = syn.lemmas()
                hypos = syn.hyponyms()
                for hypo in hypos:
                    cur_lemmas.extend(hypo.lemmas())
                for lemma in cur_lemmas:
                    ll = lemma.name()
                    cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                    words.append(re.sub("_", " ", ll))
        if len(synonyms) > 0:
            for w in synonyms:
                w = w.replace(" ", "_")
                for syn in wordnet.synsets(w):
                    if check_sim(process_text(seed_definition), process_text(syn.definition())) > .6:
                        cur_lemmas = syn.lemmas()
                        hypos = syn.hyponyms()
                        for hypo in hypos:
                            cur_lemmas.extend(hypo.lemmas())
                        for lemma in cur_lemmas:
                            ll = lemma.name()
                            cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                            words.append(re.sub("_", " ", ll))
        if len(antonyms) > 0:
            for a in antonyms:
                a = a.replace(" ", "_")
                for syn in wordnet.synsets(a):
                    if check_sim(process_text(seed_definition), process_text(syn.definition())) > .26:
                        cur_lemmas = syn.lemmas()
                        hypos = syn.hyponyms()
                        for hypo in hypos:
                            cur_lemmas.extend(hypo.lemmas())
                        for lemma in cur_lemmas:
                            ll = lemma.name()
                            cats.append(re.sub("_", " ", syn.name().split(".")[0]))
                            words.append(re.sub("_", " ", ll))
    df = {"Categories": cats, "Words": words}
    df = pd.DataFrame(df)
    df = df.drop_duplicates().reset_index()
    df = df.drop("index", axis=1)
    return df
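# Illustrative usage (assumes the seed definition roughly matches one of the word's WordNet senses):
#   wordnet_df("bank", POS="NOUN", seed_definition="a financial institution that accepts deposits")
# returns a dataframe with a "Categories" column (the source synset's head word) and a "Words" column
# (lemmas drawn from similar synsets, their hyponyms, and related synonym/antonym synsets).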
def eval_pred_test(text, return_all=False):
    '''A basic function for evaluating the prediction from the model and turning it into a visualization friendly number.'''
    preds = pipe(text)
    neg_score = -1 * preds[0][0]['score']
    sent_neg = preds[0][0]['label']
    pos_score = preds[0][1]['score']
    sent_pos = preds[0][1]['label']
    prediction = 0
    sentiment = ''
    if pos_score > abs(neg_score):
        prediction = pos_score
        sentiment = sent_pos
    elif abs(neg_score) > pos_score:
        prediction = neg_score
        sentiment = sent_neg
    if return_all:
        return prediction, sentiment
    else:
        return prediction
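# Illustrative examples (exact scores depend on the model weights): eval_pred_test("I hate this.")
# should return a negative float (the NEGATIVE score times -1), while
# eval_pred_test("I love this.", return_all=True) should return something like (0.99, 'POSITIVE').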
def get_parallel(word, seed_definition, QA=False):
    # Collect "parallel" synsets: hyponyms of the word's hypernyms whose definitions are similar to the seed definition.
    cleaned = nlp(process_text(seed_definition))
    root_syns = wordnet.synsets(word)
    hypers = []
    new_hypos = []
    for syn in root_syns:
        hypers.extend(syn.hypernyms())
    for syn in hypers:
        new_hypos.extend(syn.hyponyms())
    hypos = list(set([syn for syn in new_hypos if cleaned.similarity(nlp(process_text(syn.definition()))) >= .75]))[:25]
    # with st.sidebar:
    #     st.write(f"The number of hypos is {len(hypos)} during get_parallel at similarity >= .75.") #QA
    # Loosen or tighten the similarity threshold based on how many candidates survive, capping each list at 25.
    if len(hypos) <= 1:
        hypos = root_syns
    elif len(hypos) < 3:
        hypos = list(set([syn for syn in new_hypos if cleaned.similarity(nlp(process_text(syn.definition()))) >= .5]))[:25]
    elif len(hypos) < 10:
        hypos = list(set([syn for syn in new_hypos if cleaned.similarity(nlp(process_text(syn.definition()))) >= .66]))[:25]
    elif len(hypos) >= 10:
        hypos = list(set([syn for syn in new_hypos if cleaned.similarity(nlp(process_text(syn.definition()))) >= .8]))[:25]
    if QA:
        print(hypers)
        print(hypos)
        return hypers, hypos
    else:
        return hypos
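# Illustrative usage: get_parallel("dog", "a domesticated animal kept for companionship") would return
# a list of WordNet Synset objects that are siblings of "dog" (other animals under its hypernyms) whose
# definitions clear the adaptive similarity threshold; the exact members depend on the spaCy vectors.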
# Builds a dataframe dynamically from WordNet using NLTK.
def wordnet_parallel_df(word, seed_definition=None):
    words = []
    cats = []
    # WordNet hates spaces, so replace them with underscores.
    m_word = word.replace(" ", "_")
    # Add synonyms and antonyms for diversity.
    synonyms, antonyms = syn_ant(word)
    words.extend(synonyms)
    cats.extend(["synonyms" for n in range(len(synonyms))])
    words.extend(antonyms)
    cats.extend(["antonyms" for n in range(len(antonyms))])
    try:
        hypos = get_parallel(m_word, seed_definition)
    except:
        st.write("You did not supply a definition.")

    # Allow the user to pick a seed definition if it is not provided directly to the function.
    '''if seed_definition is None:
        if POS in pos_options:
            seed_definitions = [syn.definition() for syn in wordnet.synsets(m_word, pos=getattr(wordnet, POS))]
        else:
            seed_definitions = [syn.definition() for syn in wordnet.synsets(m_word)]
        for d in range(len(seed_definitions)):
            print(f"{d}: {seed_definitions[d]}")
        choice = int(input("Which of the definitions above most aligns to your selection?"))
        seed_definition = seed_definitions[choice]'''

    # This is a QA section.
    # with st.sidebar:
    #     st.write(f"The number of hypos is {len(hypos)} during parallel df creation.") #QA

    # Transform the hypos (and their own hyponyms) into lemmas.
    for syn in hypos:
        cur_lemmas = syn.lemmas()
        cur_hypos = syn.hyponyms()  # renamed so the list being iterated over is not overwritten
        for hypo in cur_hypos:
            cur_lemmas.extend(hypo.lemmas())
        for lemma in cur_lemmas:
            ll = lemma.name()
            cats.append(re.sub("_", " ", syn.name().split(".")[0]))
            words.append(re.sub("_", " ", ll))
    # with st.sidebar:
    #     st.write(f'There are {len(words)} words in the dataframe at the beginning of df creation.') #QA
    df = {"Categories": cats, "Words": words}
    df = pd.DataFrame(df)
    df = df.drop_duplicates("Words").reset_index()
    df = df.drop("index", axis=1)
    return df
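# Illustrative usage: wordnet_parallel_df("dog", seed_definition=wordnet.synset('dog.n.01').definition())
# returns a dataframe whose "Categories" column mixes "synonyms", "antonyms", and the names of the
# parallel synsets found by get_parallel, with one candidate word per row.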
#@st.experimental_singleton(suppress_st_warning=True)
def cf_from_wordnet_df(seed, text, seed_definition=False):
    seed_token = nlp(seed)
    seed_POS = seed_token[0].pos_
    #print(seed_POS) #QA
    try:
        df = wordnet_parallel_df(seed, seed_definition)
    except:
        st.write("You did not supply a definition.")
    # Swap the seed word for each candidate word to build the counterfactual texts.
    # re.escape guards against seeds that contain regex metacharacters.
    df["text"] = df.Words.apply(lambda x: re.sub(r'\b' + re.escape(seed) + r'\b', x, text))
    df["similarity"] = df.Words.apply(lambda x: seed_token[0].similarity(nlp(x)[0]))
    df = df[df["similarity"] > 0].reset_index()
    df.drop("index", axis=1, inplace=True)
    df["pred"] = df.text.apply(eval_pred_test)
    # Flag the seed itself so we are sure it is present in the data we generate counterfactuals from,
    # which should make the end results better.
    df['seed'] = df.Words.apply(lambda x: 'seed' if x.lower() == seed.lower() else 'alternative')
    return df
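# Illustrative end-to-end usage inside the Streamlit app (the sentence and sort order are placeholders):
#   seed_def = wordnet.synsets("film")[0].definition()
#   cf_df = cf_from_wordnet_df("film", "This film was a complete waste of time.", seed_definition=seed_def)
#   st.dataframe(cf_df.sort_values("pred"))
# Each row holds a candidate replacement word, the rewritten sentence, its similarity to the seed,
# and the signed sentiment score from eval_pred_test.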