AuRoRA / retrieval_utils.py
'''
Modified from https://github.com/RuochenZhao/Verify-and-Edit
'''
import wikipedia
import wikipediaapi
import spacy
import numpy as np
import ngram
import torch
import sklearn.metrics  # plain "import sklearn" does not reliably expose sklearn.metrics.pairwise
#from textblob import TextBlob
from nltk import tokenize
from sentence_transformers import SentenceTransformer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoder, DPRContextEncoderTokenizer
from llm_utils import decoder_for_gpt3
from utils import entity_cleansing, knowledge_cleansing
import nltk
nltk.download('punkt')
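# Module-level resources: English Wikipedia client, spaCy NER model, the entity
# types we keep, and the DPR question/context encoders used for dense retrieval.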
wiki_wiki = wikipediaapi.Wikipedia('en')
nlp = spacy.load("en_core_web_sm")
ENT_TYPE = ['EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'NORP', 'ORG', 'PERSON', 'PRODUCT', 'WORK_OF_ART']
CTX_ENCODER = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
CTX_TOKENIZER = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", model_max_length = 512)
Q_ENCODER = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
Q_TOKENIZER = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base", model_max_length = 512)
## todo: extract entities from ConceptNet
def find_ents(text, engine):
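    '''
    Extract entities from text with spaCy NER, keeping only the labels in ENT_TYPE.
    If spaCy finds none, fall back to asking the LLM to list the entities.
    '''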
doc = nlp(text)
valid_ents = []
for ent in doc.ents:
if ent.label_ in ENT_TYPE:
valid_ents.append(ent.text)
    # in case the entity list is empty, resort to the LLM to extract entities
    if valid_ents == []:
        prompt = "Question: " + "[ " + text + "]\n"
        prompt += "Output the entities in Question separated by comma: "
        response = decoder_for_gpt3(prompt, 32, engine=engine)
        valid_ents = entity_cleansing(response)
return valid_ents
def relevant_pages_for_ents(valid_ents, topk = 5):
    '''
    Input: a list of valid entities
    Output: a list of lists, each containing up to topk Wikipedia page titles for one entity
    '''
    if valid_ents == []:
        return []
    titles = []
    for ve in valid_ents:
        titles_for_ve = wikipedia.search(ve)[:topk]
        titles.append(titles_for_ve)
#titles = list(dict.fromkeys(titles))
return titles
def relevant_pages_for_text(text, topk = 5):
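    '''Return up to topk Wikipedia page titles matching the raw text.'''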
return wikipedia.search(text)[:topk]
def get_wiki_objs(pages):
    '''
    Input: a list of lists of page titles
    Output: a list of lists of wikipediaapi page objects, mirroring the input structure
    '''
if pages == []:
return []
obj_pages = []
for titles_for_ve in pages:
pages_for_ve = [wiki_wiki.page(title) for title in titles_for_ve]
obj_pages.append(pages_for_ve)
return obj_pages
def get_linked_pages(wiki_pages, topk = 5):
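    '''
    Collect pages linked from the given wikipediaapi pages and truncate the
    combined list to topk entries; topk = -1 keeps all linked pages.
    '''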
linked_ents = []
for wp in wiki_pages:
linked_ents += list(wp.links.values())
if topk != -1:
linked_ents = linked_ents[:topk]
return linked_ents
def get_texts_to_pages(pages, topk = 2):
    '''
    Input: list of lists of wikipediaapi page objects
    Output: list of lists of texts (the first topk sentences of each page)
    '''
total_texts = []
for ve_pages in pages:
ve_texts = []
for p in ve_pages:
text = p.text
text = tokenize.sent_tokenize(text)[:topk]
text = ' '.join(text)
ve_texts.append(text)
total_texts.append(ve_texts)
return total_texts
def DPR_embeddings(q_encoder, q_tokenizer, question):
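    '''
    Encode text with a DPR encoder/tokenizer pair and return the pooled embedding
    as a 1-D numpy array (768 dimensions for the single-nq-base DPR models).
    '''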
    question_embedding = q_tokenizer(question, return_tensors="pt", max_length=512, truncation=True)  # truncate to the encoders' 512-token limit
with torch.no_grad():
        try:
            question_embedding = q_encoder(**question_embedding)[0][0]
        except Exception:
            # print the offending input before re-raising to keep the original traceback
            print(question)
            print(question_embedding['input_ids'].size())
            raise
question_embedding = question_embedding.numpy()
return question_embedding
def model_embeddings(sentence, model):
embedding = model.encode([sentence])
    return embedding[0]  # a 1-D array of 384 dimensions for paraphrase-MiniLM-L6-v2
##todo: plus overlap filtering
def filtering_retrieved_texts(question, ent_texts, retr_method="wikipedia_dpr", topk=1):
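    '''
    For each entity's candidate texts, keep the topk texts most relevant to the question,
    scored by unigram overlap ("ngram"), DPR embedding distance ("wikipedia_dpr"),
    or SentenceTransformer embedding distance (any other retr_method), then deduplicate.
    '''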
filtered_texts = []
for texts in ent_texts:
        if texts != []:  # skip entities with no retrieved texts
            if retr_method == "ngram":
                # unigram overlap with the question; higher score = more relevant
                pars = np.array([ngram.NGram.compare(question, sent, N=1) for sent in texts])
                # argsort is ascending, so reverse it and keep the topk most similar texts
                pars = pars.argsort()[::-1][:topk]
            else:
                if retr_method == "wikipedia_dpr":
                    sen_embeds = [DPR_embeddings(Q_ENCODER, Q_TOKENIZER, question)]
                    par_embeds = [DPR_embeddings(CTX_ENCODER, CTX_TOKENIZER, s) for s in texts]
                else:
                    embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
                    sen_embeds = [model_embeddings(question, embedding_model)]
                    par_embeds = [model_embeddings(s, embedding_model) for s in texts]
                # smaller embedding distance = more relevant; keep the topk closest texts
                pars = sklearn.metrics.pairwise.pairwise_distances(sen_embeds, par_embeds)
                pars = pars.argsort(axis=1)[0][:topk]
            filtered_texts += [texts[i] for i in pars]
filtered_texts = list(dict.fromkeys(filtered_texts))
return filtered_texts
def join_knowledge(filtered_texts):
if filtered_texts == []:
return ""
return " ".join(filtered_texts)
def retrieve_for_question_kb(question, engine, know_type="entity_know", no_links=False):
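    '''
    Retrieve knowledge from Wikipedia for a question: extract entities, search for
    relevant pages (resolving disambiguation pages and optionally following links),
    take the leading sentences of each page, filter them by relevance to the
    question, and join the survivors into one knowledge string.
    Returns (valid_ents, joint_knowledge).
    '''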
valid_ents = find_ents(question, engine)
print(valid_ents)
# find pages
page_titles = []
if "entity" in know_type:
pages_for_ents = relevant_pages_for_ents(valid_ents, topk = 5) #list of list
if pages_for_ents != []:
page_titles += pages_for_ents
if "question" in know_type:
pages_for_question = relevant_pages_for_text(question, topk = 5)
if pages_for_question != []:
page_titles += pages_for_question
pages = get_wiki_objs(page_titles) #list of list
    if pages == []:
        return valid_ents, ""  # keep the (entities, knowledge) return shape expected by callers
new_pages = []
assert page_titles != []
assert pages != []
print(page_titles)
#print(pages)
for i, ve_pt in enumerate(page_titles):
new_ve_pages = []
for j, pt in enumerate(ve_pt):
if 'disambiguation' in pt:
new_ve_pages += get_linked_pages([pages[i][j]], topk=-1)
else:
new_ve_pages += [pages[i][j]]
new_pages.append(new_ve_pages)
pages = new_pages
    if not no_links:
        # add linked pages, then deduplicate while preserving order
        # (rebinding the loop variable would not update `pages`, so build a new list)
        deduped_pages = []
        for ve_pages in pages:
            ve_pages = ve_pages + get_linked_pages(ve_pages, topk=5)
            deduped_pages.append(list(dict.fromkeys(ve_pages)))
        pages = deduped_pages
#get texts
texts = get_texts_to_pages(pages, topk=1)
filtered_texts = filtering_retrieved_texts(question, texts)
joint_knowledge = join_knowledge(filtered_texts)
return valid_ents, joint_knowledge
def retrieve_for_question(question, engine, retrieve_source="llm_kb"):
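    '''
    Retrieve knowledge for a question from the LLM itself ("llm" in retrieve_source),
    from Wikipedia ("kb" in retrieve_source), or both.
    Returns (entities, self_retrieve_knowledge, kb_retrieve_knowledge).
    '''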
    # default values so the return statement is always well-defined
    entities, self_retrieve_knowledge, kb_retrieve_knowledge = [], "", ""
    # Retrieve knowledge from LLM
    if "llm" in retrieve_source:
self_retrieve_prompt = "Question: " + "[ " + question + "]\n"
self_retrieve_prompt += "Necessary knowledge about the question by not answering the question: "
self_retrieve_knowledge = decoder_for_gpt3(self_retrieve_prompt, 256, engine=engine)
self_retrieve_knowledge = knowledge_cleansing(self_retrieve_knowledge)
print("------Self_Know------")
print(self_retrieve_knowledge)
# Retrieve knowledge from KB
if "kb" in retrieve_source:
entities, kb_retrieve_knowledge = retrieve_for_question_kb(question, engine, no_links=True)
if kb_retrieve_knowledge != "":
print("------KB_Know------")
print(kb_retrieve_knowledge)
return entities, self_retrieve_knowledge, kb_retrieve_knowledge
def refine_for_question(question, engine, self_retrieve_knowledge, kb_retrieve_knowledge, retrieve_source="llm_kb"):
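    '''
    Merge the self-retrieved (LLM) and KB-retrieved knowledge into one refined
    knowledge string, depending on retrieve_source ("llm_only", "kb_only", or "llm_kb").
    '''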
    refine_knowledge = ""
    # Refine knowledge
if retrieve_source == "llm_only":
refine_knowledge = self_retrieve_knowledge
elif retrieve_source == "kb_only":
if kb_retrieve_knowledge != "":
refine_prompt = "Question: " + "[ " + question + "]\n"
refine_prompt += "Knowledge: " + "[ " + kb_retrieve_knowledge + "]\n"
refine_prompt += "Based on Knowledge, output the brief and refined knowledge necessary for Question by not giving the answer: "
refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
print("------Refined_Know------")
print(refine_knowledge)
else:
refine_knowledge = ""
elif retrieve_source == "llm_kb":
if kb_retrieve_knowledge != "":
#refine_prompt = "Question: " + "[ " + question + "]\n"
refine_prompt = "Knowledge_1: " + "[ " + self_retrieve_knowledge + "]\n"
refine_prompt += "Knowledge_2: " + "[ " + kb_retrieve_knowledge + "]\n"
#refine_prompt += "By using Knowledge_2 to check Knowledge_1, output the brief and correct knowledge necessary for Question: "
refine_prompt += "By using Knowledge_2 to check Knowledge_1, output the brief and correct knowledge: "
refine_knowledge = decoder_for_gpt3(refine_prompt, 256, engine=engine)
refine_knowledge = knowledge_cleansing(refine_knowledge)
#refine_knowledge = kb_retrieve_knowledge + refine_knowledge
print("------Refined_Know------")
print(refine_knowledge)
else:
refine_knowledge = self_retrieve_knowledge
return refine_knowledge
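

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): assumes llm_utils is
# configured with valid API credentials and that decoder_for_gpt3 accepts the
# engine name used below (both are assumptions; adjust to your setup).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    question = "Who wrote the novel that the film Blade Runner is based on?"
    engine = "gpt-3.5-turbo"  # hypothetical engine name; replace with your own
    entities, self_know, kb_know = retrieve_for_question(question, engine, retrieve_source="llm_kb")
    refined = refine_for_question(question, engine, self_know, kb_know, retrieve_source="llm_kb")
    print(entities)
    print(refined)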