LOUIS SANNA
feat(domains)
780c913
raw
history blame
3.19 kB
# https://github.com/langchain-ai/langchain/issues/8623
import logging
from typing import List

from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import Chroma
from langchain.vectorstores import VectorStore
# Values of the "report_type" metadata field that mark a document as a
# summary. The idea is that summaries are easier to exploit than full
# reports, so the retriever can search them in a separate pass. While this
# list is empty, no "report_type" filter is applied to either pass.
SUMMARY_TYPES = []
class QARetriever(BaseRetriever):
    """Retrieve scored documents from a vector store in up to two passes.

    When ``k_summary`` is positive, a first search asks for ``k_summary``
    documents (restricted to ``SUMMARY_TYPES`` report types when that list
    is non-empty). A second search then asks for the remaining
    ``k_total - <number of summaries kept>`` documents (excluding
    ``SUMMARY_TYPES`` report types when that list is non-empty). Results
    whose score is not strictly greater than ``threshold`` are discarded.
    """

    vectorstore: VectorStore
    # Allowed values of the "domain" metadata field; empty means no filter.
    domains: list = []
    # Only results with score strictly greater than this value are kept.
    # NOTE(review): whether a higher score means "more similar" depends on
    # the vector store; for distance-based scores this comparison keeps the
    # less similar documents -- confirm against the store in use.
    threshold: float = 22
    # Number of documents requested in the summaries pass (0 disables it).
    k_summary: int = 0
    # Upper bound on the number of documents requested across both passes.
    k_total: int = 10
    # Forwarded as the ``namespace`` keyword to the vector store search.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching *query*, annotated with their score.

        Each returned document is modified in place: its metadata gains
        ``similarity_score`` and ``content`` (the original page content),
        ``page_number`` is converted to ``int``, and ``page_content`` is
        prefixed with ``"Doc <rank> - <short_name>: "``.

        Raises:
            TypeError: if ``domains`` is not a list.
            ValueError: if ``k_total`` is not greater than ``k_summary``.
            KeyError: if a result lacks the ``page_number`` or
                ``short_name`` metadata entry.
        """
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``, which strips assert statements.
        if not isinstance(self.domains, list):
            raise TypeError("domains should be a list")
        if self.k_total <= self.k_summary:
            raise ValueError("k_total should be greater than k_summary")

        # Filters shared by both passes.
        base_filters = {}
        if self.domains:
            base_filters["domain"] = {"$in": self.domains}

        # Pass 1: up to k_summary documents from the summaries dataset.
        docs_summaries = []
        if self.k_summary > 0:
            filters_summaries = dict(base_filters)
            if SUMMARY_TYPES:
                filters_summaries["report_type"] = {"$in": SUMMARY_TYPES}
            docs_summaries = self._search_above_threshold(
                query, filters_summaries, self.k_summary
            )

        logging.getLogger(__name__).debug("filters %s", base_filters)

        # Pass 2: fill the remaining slots from the full reports dataset.
        filters_full = dict(base_filters)
        if SUMMARY_TYPES:
            filters_full["report_type"] = {"$nin": SUMMARY_TYPES}
        k_full = self.k_total - len(docs_summaries)
        docs_full = self._search_above_threshold(query, filters_full, k_full)

        # Summaries come first; ranks in the "Doc N" prefix follow that order.
        results = []
        for i, (doc, score) in enumerate(docs_summaries + docs_full):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)
        return results

    def _search_above_threshold(self, query, filters, k):
        """Run one scored search and keep the pairs scoring above ``threshold``."""
        scored = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [pair for pair in scored if pair[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a ``{field: condition}`` filter dict to the vector store.

        For a Chroma store, a dict with more than one field is rewritten as
        ``{"$and": [{field: condition}, ...]}``; in every other case the
        dict is returned unchanged.
        """
        # https://docs.trychroma.com/usage-guide#using-logical-operators
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {
                "$and": [
                    {field: condition} for field, condition in filters.items()
                ]
            }
        return filters