File size: 5,119 Bytes
f0fc5f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd
from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores import VectorStore
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from typing import List
from pydantic import Field
class ClimateQARetriever(BaseRetriever):
    """Retrieve IPCC/IPBES passages, taking summary reports first.

    Two filtered similarity searches are run against ``vectorstore``:

    1. up to ``k_summary`` hits from summary reports (``report_type`` in
       ``["SPM", "TS"]``);
    2. enough hits from the remaining (full) reports to bring the total up to
       ``k_total``.

    Only hits whose score is strictly greater than ``threshold`` are kept, so
    fewer than ``k_total`` documents may be returned. Each returned document
    has its metadata annotated and its ``page_content`` prefixed with a
    ``"Doc {n} - {short_name}: "`` label.
    """

    vectorstore: VectorStore
    # Report sources to search; each entry must be "IPCC" or "IPBES".
    sources: list = ["IPCC", "IPBES"]
    # Hits must score strictly above this value to be kept.
    # NOTE(review): 22 looks high for a similarity score; the legacy code in
    # this file used 0.555 -- confirm against the vectorstore's score scale.
    threshold: float = 22
    k_summary: int = 3  # max hits requested from summary reports (SPM/TS)
    k_total: int = 10  # upper bound on the number of documents returned
    namespace: str = "vectors"  # namespace passed to the vectorstore search

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the summary-first, threshold-filtered documents for ``query``.

        Args:
            query: Free-text search query.
            run_manager: Callback manager for this retrieval run; not used by
                this implementation.

        Returns:
            Summary-report hits followed by full-report hits, in search order.

        Raises:
            ValueError: If ``sources``, ``k_summary`` or ``k_total`` are invalid.
        """
        self._check_config()

        base_filter = {"source": {"$in": self.sources}}
        docs_summaries = self._scored_search(
            query,
            {**base_filter, "report_type": {"$in": ["SPM", "TS"]}},
            self.k_summary,
        )
        # Fill every slot the summaries did not use with full-report hits.
        docs_full = self._scored_search(
            query,
            {**base_filter, "report_type": {"$nin": ["SPM", "TS"]}},
            self.k_total - len(docs_summaries),
        )
        return self._to_results(docs_summaries + docs_full)

    def _check_config(self) -> None:
        """Raise ValueError if the retriever's settings are inconsistent."""
        # Explicit raises rather than `assert`, which is stripped under -O.
        if not isinstance(self.sources, list):
            raise ValueError("sources should be a list")
        unknown = [s for s in self.sources if s not in ("IPCC", "IPBES")]
        if unknown:
            raise ValueError(
                f"sources should only contain 'IPCC' or 'IPBES', got {unknown}"
            )
        if self.k_total <= self.k_summary:
            raise ValueError("k_total should be greater than k_summary")

    def _scored_search(self, query: str, metadata_filter: dict, k: int) -> list:
        """Run one filtered search; keep the (doc, score) pairs above threshold."""
        if k <= 0:
            # Nothing requested (e.g. k_summary == 0): skip the round trip.
            return []
        hits = self.vectorstore.similarity_search_with_score(
            query=query, namespace=self.namespace, filter=metadata_filter, k=k
        )
        return [(doc, score) for doc, score in hits if score > self.threshold]

    @staticmethod
    def _to_results(scored_docs: list) -> List[Document]:
        """Annotate each hit's metadata and prefix its text with a citation label."""
        results = []
        for rank, (doc, score) in enumerate(scored_docs, start=1):
            doc.metadata["similarity_score"] = score
            # Preserve the raw passage text before page_content is rewritten.
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"Doc {rank} - {doc.metadata['short_name']}: {doc.page_content}"
            )
            results.append(doc)
        return results
# def filter_summaries(df,k_summary = 3,k_total = 10):
# # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
# # # Filter by source
# # if source == "IPCC":
# # df = df.loc[df["source"]=="IPCC"]
# # elif source == "IPBES":
# # df = df.loc[df["source"]=="IPBES"]
# # else:
# # pass
# # Separate summaries and full reports
# df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
# df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
# # Find passages from summaries dataset
# passages_summaries = df_summaries.head(k_summary)
# # Find passages from full reports dataset
# passages_fullreports = df_full.head(k_total - len(passages_summaries))
# # Concatenate passages
# passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
# return passages
# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
# assert max_k > k_total
# validated_sources = ["IPCC","IPBES"]
# sources = [x for x in sources if x in validated_sources]
# filters = {
# "source": { "$in": sources },
# }
# print(filters)
# # Retrieve documents
# docs = retriever.retrieve(query,top_k = max_k,filters = filters)
# # Filter by score
# docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
# if len(docs) == 0:
# return []
# res = pd.DataFrame(docs)
# passages_df = filter_summaries(res,k_summary,k_total)
# if as_dict:
# contents = passages_df["content"].tolist()
# meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
# passages = []
# for i in range(len(contents)):
# passages.append({"content":contents[i],"meta":meta[i]})
# return passages
# else:
# return passages_df
# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
# print("hellooooo")
# # Reformulate queries
# reformulated_query,language = reformulate(query)
# print(reformulated_query)
# # Retrieve documents
# passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
# response = {
# "query":query,
# "reformulated_query":reformulated_query,
# "language":language,
# "sources":passages,
# "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
# }
# return response
|