# Scraper residue (Hugging Face Spaces status banner), not source code —
# commented out so the module parses:
# Spaces:
# Sleeping
# Sleeping
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever

from src.settings import settings
class RAGPipeline: | |
def __init__( | |
self, | |
document_store, | |
template: str, | |
top_k: int, | |
) -> None: | |
self.text_embedder: SentenceTransformersTextEmbedder # type: ignore | |
self.retriever: QdrantEmbeddingRetriever # type: ignore | |
self.prompt_builder: PromptBuilder # type: ignore | |
self.llm_provider: OpenAIGenerator # type: ignore | |
self.pipeline: Pipeline | None = None | |
self.document_store = document_store | |
self.template = template | |
self.top_k = top_k | |
self.get_text_embedder() | |
self.get_retriever() | |
self.get_prompt_builder() | |
self.get_llm_provider() | |
def run(self, query: str, filter_selections: dict[str, list] | None = None) -> dict: | |
if not self.pipeline: | |
self.build_pipeline() | |
if self.pipeline: | |
filters = RAGPipeline.build_filter(filter_selections=filter_selections) | |
result = self.pipeline.run( | |
data={ | |
"text_embedder": {"text": query}, | |
"retriever": {"filters": filters}, | |
"prompt_builder": {"query": query}, | |
}, | |
include_outputs_from=["retriever", "llm"], | |
) | |
return result | |
def get_text_embedder(self) -> None: | |
self.text_embedder = SentenceTransformersTextEmbedder( | |
model=settings.qdrant_database.model | |
) | |
self.text_embedder.warm_up() | |
def get_retriever(self) -> None: | |
self.retriever = QdrantEmbeddingRetriever( | |
document_store=self.document_store, top_k=self.top_k | |
) | |
def get_prompt_builder(self) -> None: | |
self.prompt_builder = PromptBuilder(template=self.template) | |
def get_llm_provider(self) -> None: | |
self.llm_provider = OpenAIGenerator( | |
model=settings.llm_provider.model, | |
api_key=Secret.from_env_var("LLM_PROVIDER__API_KEY"), | |
max_retries=3, | |
generation_kwargs={"max_tokens": 5000, "temperature": 0.2}, | |
) | |
def build_filter(filter_selections: dict[str, list] | None = None) -> dict: | |
filters: dict[str, str | list[dict]] = {"operator": "AND", "conditions": []} | |
if filter_selections: | |
for meta_data_name, selections in filter_selections.items(): | |
filters["conditions"].append( # type: ignore | |
{ | |
"field": "meta." + meta_data_name, | |
"operator": "in", | |
"value": selections, | |
} | |
) | |
else: | |
filters = {} | |
return filters | |
def build_pipeline(self): | |
self.pipeline = Pipeline() | |
self.pipeline.add_component("text_embedder", self.text_embedder) | |
self.pipeline.add_component("retriever", self.retriever) | |
self.pipeline.add_component("prompt_builder", self.prompt_builder) | |
self.pipeline.add_component("llm", self.llm_provider) | |
self.pipeline.connect("text_embedder.embedding", "retriever.query_embedding") | |
self.pipeline.connect("retriever", "prompt_builder.documents") | |
self.pipeline.connect("prompt_builder", "llm") | |
if __name__ == "__main__":
    # Smoke-test the pipeline against the "inc_data" index.
    # NOTE(review): DocumentStore is never imported in this file, so this
    # script currently raises NameError — confirm the project module it
    # lives in (presumably a wrapper around QdrantDocumentStore) and add
    # the import at the top of the file.
    document_store = DocumentStore(index="inc_data")

    # Read the prompt template; explicit encoding avoids platform-dependent
    # defaults.
    with open(
        "src/rag/prompt_templates/inc_template.txt", "r", encoding="utf-8"
    ) as file:
        template = file.read()

    pipeline = RAGPipeline(
        document_store=document_store.document_store, template=template, top_k=5
    )
    filter_selections = {
        "author": ["Malaysia", "Australia"],
    }
    result = pipeline.run(
        "What is Malaysia's position on plastic waste?",
        filter_selections=filter_selections,
    )
    print(result)