File size: 2,984 Bytes
e921012
e35585c
 
e921012
 
 
 
 
 
 
 
 
e35585c
 
e921012
 
 
 
 
 
 
e35585c
e921012
 
e35585c
 
e921012
e35585c
 
e921012
e35585c
 
 
e921012
 
 
 
 
 
e35585c
 
 
 
 
 
 
 
 
 
 
e921012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e35585c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from functools import cached_property, lru_cache
from typing import Literal

from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode

# Silence gRPC's native logging: the Qdrant vector stores in this module are
# created with prefer_grpc=True.
# NOTE(review): gRPC reads this variable when its core initializes; confirm this
# assignment runs before that happens (langchain_qdrant is imported above).
os.environ.update(GRPC_VERBOSITY="NONE")


class RetrieversConfig:
    """Build hybrid (dense + sparse) Qdrant retrievers from environment config.

    The Qdrant endpoint and credentials come from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables. ``OPENAI_API_KEY`` is also
    required up front; it is not passed to ``OpenAIEmbeddings`` explicitly, so
    presumably that class reads it from the environment itself.

    Embedding models, vector stores and retrievers are created lazily and
    cached *on the instance*. The caches are therefore released together with
    the instance. Separate instances never evict each other's cached objects.
    """

    REQUIRED_ENV_VARS = ["QDRANT_API_KEY", "QDRANT_URL", "OPENAI_API_KEY"]

    def __init__(
        self,
        dense_model_name: Literal["text-embedding-3-small"] = "text-embedding-3-small",
        sparse_model_name: Literal[
            "prithivida/Splade_PP_en_v1"
        ] = "prithivida/Splade_PP_en_v1",
    ):
        """Validate the environment and record connection/model settings.

        Args:
            dense_model_name: OpenAI embedding model used for dense vectors.
            sparse_model_name: FastEmbed sparse model used for sparse vectors.

        Raises:
            EnvironmentError: if any of ``REQUIRED_ENV_VARS`` is unset or blank.
        """
        self._validate_environment()
        # Validation only checks that the stripped value is non-empty, so strip
        # here too; otherwise stray whitespace would reach the Qdrant client.
        self.qdrant_url = os.getenv("QDRANT_URL", "").strip()
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "").strip()
        self.dense_model_name = dense_model_name
        self.sparse_model_name = sparse_model_name
        # Per-instance caches. Keeping them here rather than in a functools
        # lru_cache on the methods avoids holding every instance alive
        # (ruff B019).
        # (collection, dense vector name, sparse vector name) -> QdrantVectorStore
        self._vector_stores = {}
        # (collection, dense vector name, sparse vector name, k) -> retriever
        self._retrievers = {}

    @classmethod
    def _validate_environment(cls):
        """Raise ``EnvironmentError`` if a required env var is unset or blank.

        Reads ``cls.REQUIRED_ENV_VARS``, so subclasses may extend the list.
        """
        missing_vars = [
            var for var in cls.REQUIRED_ENV_VARS if not os.getenv(var, "").strip()
        ]
        if missing_vars:
            raise EnvironmentError(
                f"Missing or empty environment variable(s): {', '.join(missing_vars)}"
            )

    @cached_property
    def dense_embeddings(self) -> OpenAIEmbeddings:
        """Dense embedding model, built on first access and reused afterwards."""
        return OpenAIEmbeddings(model=self.dense_model_name)

    @cached_property
    def sparse_embeddings(self) -> FastEmbedSparse:
        """Sparse embedding model, built on first access and reused afterwards."""
        return FastEmbedSparse(model_name=self.sparse_model_name)

    def _get_vector_store(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
    ) -> QdrantVectorStore:
        """Return the cached hybrid vector store for a collection, creating it once.

        The store does not depend on ``k``, so retrievers that differ only in
        ``k`` share one store. ``from_existing_collection`` is presumably the
        expensive step (client setup / collection lookup).
        """
        key = (collection_name, dense_vector_name, sparse_vector_name)
        store = self._vector_stores.get(key)
        if store is None:
            store = QdrantVectorStore.from_existing_collection(
                embedding=self.dense_embeddings,
                sparse_embedding=self.sparse_embeddings,
                url=self.qdrant_url,
                api_key=self.qdrant_api_key,
                prefer_grpc=True,
                collection_name=collection_name,
                retrieval_mode=RetrievalMode.HYBRID,
                vector_name=dense_vector_name,
                sparse_vector_name=sparse_vector_name,
            )
            self._vector_stores[key] = store
        return store

    def get_qdrant_retriever(
        self,
        collection_name: str,
        dense_vector_name: str,
        sparse_vector_name: str,
        k: int = 5,
    ) -> VectorStoreRetriever:
        """Return a hybrid-search retriever over an existing Qdrant collection.

        Repeated calls with the same arguments return the same retriever
        object. The arguments are normalised into one key, so positional and
        keyword calls hit the same cache entry.

        Args:
            collection_name: Name of the existing Qdrant collection.
            dense_vector_name: Name of the dense vector field in the collection.
            sparse_vector_name: Name of the sparse vector field in the collection.
            k: Number of documents the retriever returns per query.

        Returns:
            A retriever configured with ``search_kwargs={"k": k}``.
        """
        key = (collection_name, dense_vector_name, sparse_vector_name, k)
        retriever = self._retrievers.get(key)
        if retriever is None:
            store = self._get_vector_store(
                collection_name, dense_vector_name, sparse_vector_name
            )
            retriever = store.as_retriever(search_kwargs={"k": k})
            self._retrievers[key] = retriever
        return retriever

    def get_practitioners_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a retriever over the ``practitioners_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="practitioners_hybrid_db",
            dense_vector_name="practitioners_dense_vectors",
            sparse_vector_name="practitioners_sparse_vectors",
            k=k,
        )

    def get_documents_retriever(self, k: int = 5) -> VectorStoreRetriever:
        """Return a retriever over the ``docs_hybrid_db`` collection."""
        return self.get_qdrant_retriever(
            collection_name="docs_hybrid_db",
            dense_vector_name="docs_dense_vectors",
            sparse_vector_name="docs_sparse_vectors",
            k=k,
        )