vtiyyal1's picture
Upload rerank.py
f8e7b59 verified
# reranks the top articles from a given csv file
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import DocArrayInMemorySearch
from sentence_transformers import CrossEncoder
import pandas as pd
import time
"""
Selects the top articles (15 -> 4) from a given CSV via vector-store retrieval,
then sends them to an LLM to answer the question.
Input:
csv_path: str
question: str
source: str
top_n: int
Output:
response: str (the LLM's answer)
sources: list of str (the source value of each retrieved document)
Apart from langchain_with_sources, the other functions in this file do not
send articles to an LLM.
Created using LangChain RAG functions. Deprecated.
Update: Use langchain_RAG instead.
"""
def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
    """Answer ``question`` with an LLM over articles retrieved from a CSV.

    Deprecated: use ``langchain_RAG`` instead.

    Every row of ``csv_path`` is loaded as a document and indexed in an
    in-memory vector store. The ``top_n`` most similar documents are
    "stuffed" into a single ``ChatOpenAI`` prompt.

    Args:
        csv_path: Path to the CSV file holding the articles.
        question: The user question to answer.
        source: Name of the CSV column reported as each document's source
            (default ``'url'``).
        top_n: Number of documents to retrieve and pass to the LLM.

    Returns:
        A ``(response, sources)`` tuple. ``response`` is the LLM answer
        text. ``sources`` is the list of source-column values of the
        retrieved documents.
    """
    # Imported locally: this name was used but never imported at module level.
    from langchain.indexes import VectorstoreIndexCreator

    llm = ChatOpenAI(temperature=0.0)
    # Honour the ``source`` argument instead of hard-coding "url".
    loader = CSVLoader(csv_path, source_column=source)
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # ``top_n`` was previously ignored. LangChain's similarity search
        # defaults to k=4, so the default call retrieves the same number.
        retriever=index.vectorstore.as_retriever(search_kwargs={"k": top_n}),
        verbose=False,
        return_source_documents=True,
    )
    answer = qa.invoke({"query": question})
    sources_out = [doc.metadata['source'] for doc in answer['source_documents']]
    return answer['result'], sources_out
"""
Langchain with sources.
This function is deprecated. Use langchain_RAG instead.
"""
def langchain_with_sources(csv_path, question, top_n=4):
    """Answer ``question`` with LangChain's sources-aware retrieval QA chain.

    Deprecated: use ``langchain_RAG`` instead.

    Args:
        csv_path: Path to the CSV file holding the articles. Each document's
            source is taken from the ``uuid`` column.
        question: The user question to answer.
        top_n: Number of documents to retrieve and pass to the LLM.

    Returns:
        An ``(answer, sources)`` tuple holding the chain's ``'answer'`` and
        ``'sources'`` outputs.
    """
    # Imported locally: these names were used but never imported at module level.
    from langchain.chains import RetrievalQAWithSourcesChain
    from langchain.indexes import VectorstoreIndexCreator

    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="uuid")
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])
    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # ``top_n`` was previously ignored. LangChain's default k is 4, which
        # matches the default here.
        retriever=index.vectorstore.as_retriever(search_kwargs={"k": top_n}),
    )
    output = qa.invoke({"question": question})
    return output['answer'], output['sources']
"""
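Reranks the top articles using a cross-encoder.
Uses cross-encoder/ms-marco-MiniLM-L-6-v2 to score each (question, content) pair.
Input:
csv_path: str
question: str
top_n: int
Output:
out_values: list of (score, content, uuid, title, domain, published_date) tuples,
sorted by descending score. Entries with a negative score are dropped, but the
best entry is kept if all of the top entries are negative.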
"""
# returns list of top n similar articles using crossencoder
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    """Rerank whole articles from a CSV against ``question`` with a cross-encoder.

    Each article's ``content`` is scored as a ``(question, content)`` pair by
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``.

    Args:
        csv_path: CSV with ``content``, ``uuid``, ``title`` and
            ``published_date`` columns, plus an optional ``domain`` column.
        question: Query to score the articles against.
        top_n: Maximum number of articles to return.

    Returns:
        Up to ``top_n`` ``(score, content, uuid, title, domain,
        published_date)`` tuples sorted by descending score. Entries with a
        negative score are dropped. If the best of the top ``top_n`` is
        itself negative, that single entry is returned instead. An empty CSV
        yields ``[]``.
    """
    import itertools

    articles = pd.read_csv(csv_path)
    if articles.empty:
        # Nothing to rank. At least some sentence-transformers versions of
        # CrossEncoder.predict index the first input and fail on an empty batch.
        return []
    # Missing content is read by pandas as NaN, which the tokenizer rejects.
    contents = articles['content'].fillna("").tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    published_dates = articles['published_date'].tolist()
    # Biencoder retrieval output does not have a domain column.
    if 'domain' in articles:
        domains = articles['domain'].tolist()
    else:
        domains = [""] * len(contents)

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, content] for content in contents])
    ranked = sorted(
        zip(cross_scores, contents, uuids, titles, domains, published_dates),
        key=lambda row: row[0],
        reverse=True,
    )
    top = ranked[:top_n]
    # Scores are sorted, so keep the leading run of non-negative scores. If
    # the first score is already negative, fall back to the single best row.
    kept = list(itertools.takewhile(lambda row: row[0] >= 0, top))
    return kept or top[:1]
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    """Rerank the individual sentences of the articles in a CSV.

    Every article's ``content`` is split into sentences. Each sentence is
    scored as a ``(question, sentence)`` pair by
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``.

    Args:
        csv_path: CSV with ``content``, ``uuid``, ``title`` and
            ``published_date`` columns, plus an optional ``domain`` column.
        question: Query to score the sentences against.
        top_n: Maximum number of sentences to return.

    Returns:
        Up to ``top_n`` ``(score, sentence, uuid, title, domain,
        published_date)`` tuples sorted by descending score. Each tuple
        carries the metadata of the sentence's article. Entries with a
        negative score are dropped. If the best of the top ``top_n`` is
        itself negative, that single entry is returned instead. Returns
        ``[]`` when there are no sentences to score.
    """
    import itertools

    articles = pd.read_csv(csv_path)
    # Missing content is read by pandas as NaN, which the tokenizer rejects.
    contents = articles['content'].fillna("").tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    published_dates = articles['published_date'].tolist()
    if 'domain' in articles:
        domains = articles['domain'].tolist()
    else:
        domains = [""] * len(contents)

    # One (sentence, uuid, title, domain, published_date) row per sentence.
    rows = []
    for content, uid, title, domain, published_date in zip(
        contents, uuids, titles, domains, published_dates
    ):
        # NOTE(review): sent_tokenize is never imported in this file.
        # Presumably `from nltk.tokenize import sent_tokenize` is intended;
        # confirm.
        for sentence in sent_tokenize(content):
            rows.append((sentence, uid, title, domain, published_date))
    if not rows:
        # Nothing to score. At least some sentence-transformers versions of
        # CrossEncoder.predict fail on an empty batch.
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, row[0]] for row in rows])
    ranked = sorted(
        ((score, *row) for score, row in zip(cross_scores, rows)),
        key=lambda row: row[0],
        reverse=True,
    )
    top = ranked[:top_n]
    # Keep the leading run of non-negative scores, falling back to the best row.
    kept = list(itertools.takewhile(lambda row: row[0] >= 0, top))
    return kept or top[:1]
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    """Rerank overlapping windows of consecutive sentences with a cross-encoder.

    Each article is split into sentences. Every run of ``chunk_size``
    consecutive sentences (a sliding window advancing one sentence at a
    time) is joined with spaces and scored against ``question`` by
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``. An article with fewer than
    ``chunk_size`` sentences becomes a single chunk.

    Args:
        csv_path: CSV with ``content``, ``uuid`` and ``title`` columns, plus
            an optional ``domain`` column.
        question: Query to score the chunks against.
        top_n: Maximum number of chunks to return.
        chunk_size: Number of consecutive sentences per chunk; must be >= 1.

    Returns:
        Up to ``top_n`` ``(score, chunk, uuid, title, domain)`` tuples sorted
        by descending score. Unlike the sentence-level reranker, these
        tuples have no published_date. Entries with a negative score are
        dropped. If the best of the top ``top_n`` is itself negative, that
        single entry is returned instead. Returns ``[]`` for an empty CSV.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    import itertools

    if chunk_size < 1:
        # A zero or negative window would slice out empty/garbage chunks.
        raise ValueError("chunk_size must be at least 1")

    articles = pd.read_csv(csv_path)
    # Missing content is read by pandas as NaN, which the tokenizer rejects.
    contents = articles['content'].fillna("").tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    # Embedding-based retrieval output does not have a domain column.
    if 'domain' in articles:
        domains = articles['domain'].tolist()
    else:
        domains = [""] * len(contents)

    # One (chunk, uuid, title, domain) row per chunk.
    rows = []
    for content, uid, title, domain in zip(contents, uuids, titles, domains):
        # NOTE(review): sent_tokenize is never imported in this file.
        # Presumably `from nltk.tokenize import sent_tokenize` is intended;
        # confirm.
        sents = sent_tokenize(content)
        if len(sents) < chunk_size:
            # Too short for a full window: the whole article is one chunk.
            chunks = [' '.join(sents)]
        else:
            chunks = [
                ' '.join(sents[start:start + chunk_size])
                for start in range(len(sents) - chunk_size + 1)
            ]
        rows.extend((chunk, uid, title, domain) for chunk in chunks)
    if not rows:
        # Nothing to score. At least some sentence-transformers versions of
        # CrossEncoder.predict fail on an empty batch.
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, row[0]] for row in rows])
    ranked = sorted(
        ((score, *row) for score, row in zip(cross_scores, rows)),
        key=lambda row: row[0],
        reverse=True,
    )
    top = ranked[:top_n]
    # Keep the leading run of non-negative scores, falling back to the best row.
    kept = list(itertools.takewhile(lambda row: row[0] >= 0, top))
    return kept or top[:1]
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    """Rank whole articles by their single best-matching sentence.

    Every sentence of every article is scored against ``question`` by
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``. Each article (identified by
    ``uuid``) is then ranked by its highest-scoring sentence.

    Args:
        csv_path: CSV readable by ``load_articles``.
        question: Query to score the sentences against.
        top_n: Maximum number of articles to return.

    Returns:
        Up to ``top_n`` ``(score, content, uuid, title, domain)`` tuples, one
        per distinct uuid, sorted by descending score. ``score`` is the
        article's best sentence score and ``content`` is the full article
        text. Unlike the other cross-encoder functions, negative scores are
        not filtered out. Returns ``[]`` when there are no sentences to score.
    """
    contents, uuids, titles, domains = load_articles(csv_path)

    sentences = []
    # Full article tuple for each sentence, aligned with ``sentences``.
    rows = []
    for content, uid, title, domain in zip(contents, uuids, titles, domains):
        # Missing content (NaN from read_csv) has no sentences to score.
        # NOTE(review): sent_tokenize is never imported in this file.
        # Presumably `from nltk.tokenize import sent_tokenize` is intended;
        # confirm.
        sents = sent_tokenize(content) if isinstance(content, str) else []
        sentences.extend(sents)
        rows.extend([(content, uid, title, domain)] * len(sents))
    if not sentences:
        # Nothing to score. At least some sentence-transformers versions of
        # CrossEncoder.predict fail on an empty batch.
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, sent] for sent in sentences])
    ranked = sorted(
        ((score, *row) for score, row in zip(cross_scores, rows)),
        key=lambda row: row[0],
        reverse=True,
    )
    # Keep the first (highest-scoring) row per uuid. Dicts preserve insertion
    # order, so this replaces the previous O(n^2) list-membership scan.
    best_per_article = {}
    for row in ranked:
        best_per_article.setdefault(row[2], row)
    return list(best_per_article.values())[:top_n]
def no_rerank(csv_path, question, top_n=4):
    """Return the first ``top_n`` articles of the CSV in file order, unscored.

    ``question`` is not used. It is presumably accepted so the signature
    matches the other rerankers.

    Returns:
        A list of ``(content, uuid, title, domain)`` tuples.
    """
    columns = load_articles(csv_path)
    articles_in_order = list(zip(*columns))
    return articles_in_order[:top_n]
def load_articles(csv_path: str):
    """Read the article columns used by the rerankers from a CSV.

    Args:
        csv_path: CSV containing ``content``, ``uuid`` and ``title`` columns,
            plus an optional ``domain`` column.

    Returns:
        A ``(contents, uuids, titles, domains)`` tuple of equal-length lists.
        ``domains`` is all empty strings when the CSV has no ``domain`` column.
    """
    frame = pd.read_csv(csv_path)
    contents, uuids, titles = (
        frame[column].tolist() for column in ('content', 'uuid', 'title')
    )
    domains = frame['domain'].tolist() if 'domain' in frame else [""] * len(contents)
    return contents, uuids, titles, domains