Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,984 Bytes
e921012 e35585c e921012 e35585c e921012 e35585c e921012 e35585c e921012 e35585c e921012 e35585c e921012 e35585c e921012 e35585c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
from functools import cached_property, lru_cache
from typing import Literal

from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
# Set gRPC's log verbosity to NONE; Qdrant clients in this module are created
# with prefer_grpc=True.
# NOTE(review): gRPC presumably reads this variable when its runtime
# initializes, so it may need to be set before grpc is first used — confirm.
os.environ.update(GRPC_VERBOSITY="NONE")
class RetrieversConfig:
    """Factory for LangChain retrievers over existing Qdrant collections.

    Every retriever uses hybrid search (``RetrievalMode.HYBRID``): dense
    vectors from ``OpenAIEmbeddings`` combined with sparse vectors from
    ``FastEmbedSparse``. The Qdrant endpoint and key come from the
    ``QDRANT_URL`` / ``QDRANT_API_KEY`` environment variables, and every name
    in ``REQUIRED_ENV_VARS`` must be set to a non-blank value when the object
    is constructed.

    Embedding models and vector stores are created lazily and cached on the
    instance itself rather than in module-level caches.
    """

    # Environment variables that must be present and non-blank.
    # NOTE(review): OPENAI_API_KEY is not read directly in this class;
    # presumably OpenAIEmbeddings picks it up from the environment — confirm.
    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and store connection/model settings.

        Args:
            dense_model_name: Model passed to ``OpenAIEmbeddings(model=...)``.
            sparse_model_name: Model passed to
                ``FastEmbedSparse(model_name=...)``.

        Raises:
            EnvironmentError: If any of ``REQUIRED_ENV_VARS`` is unset or
                contains only whitespace.
        """
        self._validate_environment()
        # Validation treats surrounding whitespace as empty, so strip it here
        # too: a trailing newline in a pasted secret must not end up in the
        # URL or API key handed to Qdrant.
        self.qdrant_url = os.getenv("QDRANT_URL", "").strip()
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "").strip()
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name
        # Vector stores keyed by (collection, dense vector, sparse vector), so
        # retrievers that differ only in ``k`` share one client.
        self._vector_stores = {}

    @classmethod
    def _validate_environment(cls) -> None:
        """Raise ``EnvironmentError`` naming every missing or blank variable.

        Reads ``cls.REQUIRED_ENV_VARS`` so subclasses can extend the list.
        """
        missing_vars = [
            var for var in cls.REQUIRED_ENV_VARS if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    # cached_property stores the value in the instance __dict__. Unlike
    # lru_cache on a method, it neither keeps the instance alive nor shares a
    # small eviction budget across instances.
    @cached_property
    def dense_embeddings(self) -> OpenAIEmbeddings:
        """Dense embedding model, built on first access and reused afterwards."""
        return OpenAIEmbeddings(model=self.dense_model_name)

    @cached_property
    def sparse_embeddings(self) -> FastEmbedSparse:
        """Sparse embedding model, built on first access and reused afterwards."""
        return FastEmbedSparse(model_name=self.sparse_model_name)

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Return a hybrid-search retriever over an existing Qdrant collection.

        The underlying ``QdrantVectorStore`` is created once per
        ``(collection_name, dense_vector_name, sparse_vector_name)`` on this
        instance and then reused. Only the retriever wrapper is built per call.

        Args:
            collection_name: Name of a collection that already exists in Qdrant.
            dense_vector_name: Name of the dense vector in that collection.
            sparse_vector_name: Name of the sparse vector in that collection.
            k: Forwarded as ``search_kwargs["k"]`` (presumably the number of
                documents returned per query).

        Returns:
            A ``VectorStoreRetriever`` backed by the cached vector store.
        """
        key = (collection_name, dense_vector_name, sparse_vector_name)
        qdrantdb = self._vector_stores.get(key)
        if qdrantdb is None:
            qdrantdb = QdrantVectorStore.from_existing_collection(
                embedding=self.dense_embeddings,
                sparse_embedding=self.sparse_embeddings,
                url=self.qdrant_url,
                api_key=self.qdrant_api_key,
                prefer_grpc=True,
                collection_name=collection_name,
                retrieval_mode=RetrievalMode.HYBRID,
                vector_name=dense_vector_name,
                sparse_vector_name=sparse_vector_name,
            )
            self._vector_stores[key] = qdrantdb
        return qdrantdb.as_retriever(search_kwargs={"k": k})

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a hybrid retriever over the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )
|