Spaces:
Running
Running
File size: 5,292 Bytes
0188e45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from pathlib import Path
from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection
from langchain.schema import Document
from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    """Retriever that records scores (and optionally embeddings) in metadata.

    In addition to the search types accepted by the parent retriever, this
    class accepts ``"similarity_with_embeddings"``, which requires the wrapped
    vector store to provide ``advanced_similarity_search``.

    Metadata keys written by this retriever (existing values are never
    overwritten):

    * ``__similarity`` -- the score returned by the vector store.
    * ``__embeddings`` -- the stored embedding of the document
      (``"similarity_with_embeddings"`` only).
    """

    # Search types accepted for ``search_type``.
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Fetch documents for ``query`` according to ``self.search_type``.

        Args:
            query: The text to search for.
            run_manager: Callback manager; only used when the lookup is
                delegated to the parent class.

        Returns:
            The matching documents. For the two score-aware search types the
            documents' metadata is annotated in place as described on the
            class.
        """
        mode = self.search_type

        if mode == "similarity_with_embeddings":
            hits = self.vectorstore.advanced_similarity_search(
                query, **self.search_kwargs
            )
            found = []
            for document, score, vector in hits:
                # setdefault keeps any value already present in the metadata.
                document.metadata.setdefault("__embeddings", vector)
                # NOTE(review): the value stored here is whatever the store
                # reports as the score; for ChromaAdvancedRetrieval it is read
                # from the "distances" field of the query result -- confirm
                # that consumers interpret it accordingly.
                document.metadata.setdefault("__similarity", score)
                found.append(document)
            return found

        if mode == "similarity_score_threshold":
            scored = self.vectorstore.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
            found = []
            for document, relevance in scored:
                document.metadata.setdefault("__similarity", relevance)
                found.append(document)
            return found

        # Every other search type is handled by the stock retriever.
        return super()._get_relevant_documents(query, run_manager=run_manager)
class AdvancedVectorStore(VectorStore):
    """Vector store base class whose retriever is ``AdvancedVectorStoreRetriever``."""

    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        """Build an ``AdvancedVectorStoreRetriever`` backed by this store.

        Args:
            **kwargs: Forwarded to the retriever constructor. An optional
                ``tags`` entry (any iterable of strings) is combined with the
                tags returned by ``self._get_retriever_tags()``.

        Returns:
            A retriever wrapping this vector store.
        """
        # Copy the caller's tags into a fresh list: extending the object that
        # was passed in would mutate the caller's data (and would fail
        # outright for a tuple).
        tags = list(kwargs.pop("tags", None) or [])
        tags.extend(self._get_retriever_tags())
        return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    """Chroma vector store that can return each hit's score and embedding.

    Adds ``advanced_similarity_search`` /
    ``similarity_search_with_scores_and_embeddings``, which query the
    underlying collection with ``embeddings`` and ``distances`` included in
    the result.
    """

    def __init__(self, **kwargs):
        """Pass all keyword arguments unchanged to the parent initialiser."""
        super().__init__(**kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection.

        Exactly one of ``query_texts`` / ``query_embeddings`` must be given
        (enforced by the ``xor_args`` decorator).

        Args:
            query_texts: Raw query strings.
            query_embeddings: Pre-computed query embeddings.
            n_results: Number of results requested per query.
            where: Metadata filter passed to the collection.
            where_document: Document-content filter passed to the collection.
            **kwargs: Extra arguments forwarded to ``collection.query``
                (for example ``include``).

        Returns:
            The raw result of ``self._collection.query``.

        Raises:
            ValueError: If the ``chromadb`` package cannot be imported.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as exc:
            # ValueError is kept for compatibility with existing callers; the
            # original ImportError is chained for easier debugging.
            raise ValueError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from exc
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Search for ``query`` and return documents with score and embedding.

        Args:
            query: The text to search for.
            k: Number of results to request.
            filter: Metadata filter passed to the collection as ``where``.
            **kwargs: Forwarded to
                ``similarity_search_with_scores_and_embeddings`` so that
                options such as ``where_document`` are honoured rather than
                silently dropped.

        Returns:
            A list of ``(document, score, embedding)`` tuples.
        """
        return self.similarity_search_with_scores_and_embeddings(
            query, k, filter=filter, **kwargs
        )

    def similarity_search_with_scores_and_embeddings(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Query the collection, requesting distances and stored embeddings.

        Args:
            query: The text to search for.
            k: Number of results to request.
            filter: Metadata filter passed to the collection as ``where``.
            where_document: Document-content filter passed to the collection.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            A list of ``(document, score, embedding)`` tuples, where ``score``
            is taken from the ``distances`` field of the query result.
        """
        if self._embedding_function is None:
            # No embedding function configured: hand the raw text to the
            # collection.
            query_args: Dict[str, Any] = {"query_texts": [query]}
        else:
            query_embedding = self._embedding_function.embed_query(query)
            query_args = {"query_embeddings": [query_embedding]}

        results = self.__query_collection(
            **query_args,
            n_results=k,
            where=filter,
            where_document=where_document,
            include=['metadatas', 'documents', 'embeddings', 'distances'],
        )
        return _results_to_docs_scores_and_embeddings(results)
def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
    """Convert a query result mapping into ``(document, score, embedding)`` tuples.

    Only the first entry of each of the ``documents``, ``metadatas``,
    ``distances`` and ``embeddings`` fields is read, and the four sequences
    are walked in lockstep.

    Args:
        results: Mapping with ``documents``, ``metadatas``, ``distances`` and
            ``embeddings`` keys, each indexable at position 0.

    Returns:
        One tuple per hit: the ``Document`` (with an empty dict substituted
        for falsy metadata), the corresponding ``distances`` value, and the
        corresponding ``embeddings`` value.
    """
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    vectors = results["embeddings"][0]

    triples = []
    for text, metadata, distance, vector in zip(texts, metadatas, distances, vectors):
        document = Document(page_content=text, metadata=metadata or {})
        triples.append((document, distance, vector))
    return triples
|