File size: 4,137 Bytes
d064c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from src.settings import settings


class RAGPipeline:
    def __init__(
        self,
        document_store,
        template: str,
        top_k: int,
    ) -> None:
        self.text_embedder: SentenceTransformersTextEmbedder  # type: ignore
        self.retriever: QdrantEmbeddingRetriever  # type: ignore
        self.prompt_builder: PromptBuilder  # type: ignore
        self.llm_provider: OpenAIGenerator  # type: ignore
        self.pipeline: Pipeline | None = None
        self.document_store = document_store
        self.template = template
        self.top_k = top_k

        self.get_text_embedder()
        self.get_retriever()
        self.get_prompt_builder()
        self.get_llm_provider()

    def run(self, query: str, filter_selections: dict[str, list] | None = None) -> dict:
        if not self.pipeline:
            self.build_pipeline()
        if self.pipeline:
            filters = RAGPipeline.build_filter(filter_selections=filter_selections)
            result = self.pipeline.run(
                data={
                    "text_embedder": {"text": query},
                    "retriever": {"filters": filters},
                    "prompt_builder": {"query": query},
                },
                include_outputs_from=["retriever", "llm"],
            )
        return result

    def get_text_embedder(self) -> None:
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=settings.qdrant_database.model
        )
        self.text_embedder.warm_up()

    def get_retriever(self) -> None:
        self.retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store, top_k=self.top_k
        )

    def get_prompt_builder(self) -> None:
        self.prompt_builder = PromptBuilder(template=self.template)

    def get_llm_provider(self) -> None:
        self.llm_provider = OpenAIGenerator(
            model=settings.llm_provider.model,
            api_key=Secret.from_env_var("LLM_PROVIDER__API_KEY"),
            max_retries=3,
            generation_kwargs={"max_tokens": 5000, "temperature": 0.2},
        )

    @staticmethod
    def build_filter(filter_selections: dict[str, list] | None = None) -> dict:
        filters: dict[str, str | list[dict]] = {"operator": "AND", "conditions": []}
        if filter_selections:
            for meta_data_name, selections in filter_selections.items():
                filters["conditions"].append(  # type: ignore
                    {
                        "field": "meta." + meta_data_name,
                        "operator": "in",
                        "value": selections,
                    }
                )
        else:
            filters = {}
        return filters

    def build_pipeline(self):
        self.pipeline = Pipeline()
        self.pipeline.add_component("text_embedder", self.text_embedder)
        self.pipeline.add_component("retriever", self.retriever)
        self.pipeline.add_component("prompt_builder", self.prompt_builder)
        self.pipeline.add_component("llm", self.llm_provider)

        self.pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder", "llm")


if __name__ == "__main__":
    # Smoke-test the pipeline against the "inc_data" index.
    # NOTE(review): `DocumentStore` is not imported anywhere in this file, so
    # this block raises NameError as written -- confirm the correct import
    # (presumably a project-local Qdrant document-store wrapper) and add it.
    document_store = DocumentStore(index="inc_data")

    # Explicit encoding: template files should not depend on the platform default.
    with open("src/rag/prompt_templates/inc_template.txt", "r", encoding="utf-8") as file:
        template = file.read()

    pipeline = RAGPipeline(
        document_store=document_store.document_store, template=template, top_k=5
    )
    filter_selections = {
        "author": ["Malaysia", "Australia"],
    }
    result = pipeline.run(
        "What is Malaysia's position on plastic waste?",
        filter_selections=filter_selections,
    )