|
|
|
|
|
|
|
from langchain.schema.retriever import BaseRetriever, Document |
|
from langchain.vectorstores import VectorStore |
|
from langchain.vectorstores import Chroma |
|
from typing import List |
|
|
|
|
|
|
|
# Values of the "report_type" metadata field that QARetriever treats as
# summaries: the summary search restricts to them with "$in" and the full
# search excludes them with "$nin". While this list is empty, neither search
# adds a "report_type" condition.
SUMMARY_TYPES = [] |
|
|
|
|
|
class QARetriever(BaseRetriever):
    """Retriever that runs up to two filtered similarity searches on a vector store.

    An optional "summary" search (enabled when ``k_summary > 0``) is restricted
    to documents whose ``report_type`` metadata is listed in ``SUMMARY_TYPES``.
    A "full" search then fills the remaining ``k_total`` budget from documents
    whose ``report_type`` is *not* in ``SUMMARY_TYPES``. Both searches can also
    be restricted to the configured ``domains``. Hits are kept only when their
    score is strictly greater than ``threshold``.

    NOTE(review): while ``SUMMARY_TYPES`` is empty, both searches use the same
    filter, so enabling ``k_summary`` may return the same document twice --
    confirm whether that is intended.
    """

    # Backing store; queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Values matched against the "domain" metadata field with "$in".
    # An empty list disables the domain filter.
    # NOTE(review): presumably the base class copies this mutable default per
    # instance (pydantic-style model) -- verify.
    domains: list = []
    # Hits are kept only when ``score > threshold``.
    # NOTE(review): whether a higher score means "more similar" depends on the
    # vector store backend -- confirm for the store in use.
    threshold: float = 22
    # Number of hits requested from the summary search; 0 skips that search.
    k_summary: int = 0
    # Overall result budget shared by the summary and full searches.
    k_total: int = 10
    # Forwarded as ``namespace=`` to every vector store query.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching ``query``, summaries first.

        Each returned document is modified in place:

        * ``metadata["similarity_score"]`` receives the score from the store,
        * ``metadata["content"]`` receives the original ``page_content``,
        * ``metadata["page_number"]`` is converted with ``int()``,
        * ``page_content`` is prefixed with ``"Doc <rank> - <short_name>: "``,
          where ``<rank>`` is the 1-based position in the returned list.

        Args:
            query: Text passed to the vector store's similarity search.

        Returns:
            Summary hits followed by full-search hits, all scoring above
            ``threshold``.

        Raises:
            AssertionError: If ``domains`` is not a list or ``k_total`` is not
                strictly greater than ``k_summary``.
            KeyError: If a hit's metadata lacks ``"page_number"`` or
                ``"short_name"``.
        """
        # Function-scope import so this block does not rely on a module-level
        # ``import logging`` being present in the file.
        import logging

        # NOTE: these checks are skipped under ``python -O``. They stay as
        # asserts so the exception type seen by existing callers is unchanged.
        assert isinstance(self.domains, list)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Conditions shared by both searches.
        base_filters = {}
        if self.domains:
            base_filters["domain"] = {"$in": self.domains}
        # Debug-level logging instead of printing to stdout on every call.
        logging.getLogger(__name__).debug("filters %s", base_filters)

        # Optional summary search, limited to the summary report types.
        summary_hits = []
        if self.k_summary > 0:
            summary_filters = dict(base_filters)
            if SUMMARY_TYPES:
                summary_filters["report_type"] = {"$in": SUMMARY_TYPES}
            summary_hits = self._search_above_threshold(
                query, summary_filters, self.k_summary
            )

        # Full search: everything that is not a summary, sized to whatever
        # budget the surviving summary hits left over.
        full_filters = dict(base_filters)
        if SUMMARY_TYPES:
            full_filters["report_type"] = {"$nin": SUMMARY_TYPES}
        full_hits = self._search_above_threshold(
            query, full_filters, self.k_total - len(summary_hits)
        )

        results = []
        for rank, (doc, score) in enumerate(summary_hits + full_hits, start=1):
            doc.metadata["similarity_score"] = score
            # Keep the raw text before page_content gets its display prefix.
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {rank} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)

        return results

    def _search_above_threshold(self, query, filters, k):
        """Run one similarity search and keep the hits scoring above ``threshold``.

        Returns a list of ``(document, score)`` pairs in the order the vector
        store returned them.
        """
        hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [hit for hit in hits if hit[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a ``{field: condition}`` mapping to the vector store's filter syntax.

        For a ``Chroma`` store with more than one condition, the conditions are
        wrapped as ``{"$and": [{field: condition}, ...]}``. In every other case
        the mapping is returned unchanged (the same object, not a copy).
        """
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {
                "$and": [{field: condition} for field, condition in filters.items()]
            }
        return filters
|
|