from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever

# The file uses DocumentStore in the __main__ block but never imports it; the
# path below is an assumed location for the project's DocumentStore wrapper.
from src.document_store import DocumentStore
from src.settings import settings
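

# RAGPipeline wires a Haystack 2.x query pipeline: embed the user query with a
# sentence-transformers model, retrieve matching documents from Qdrant, render
# them into a prompt, and generate an answer with an OpenAI-compatible LLM.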
class RAGPipeline:
    def __init__(
        self,
        document_store,
        template: str,
        top_k: int,
    ) -> None:
        self.text_embedder: SentenceTransformersTextEmbedder  # type: ignore
        self.retriever: QdrantEmbeddingRetriever  # type: ignore
        self.prompt_builder: PromptBuilder  # type: ignore
        self.llm_provider: OpenAIGenerator  # type: ignore
        self.pipeline: Pipeline | None = None

        self.document_store = document_store
        self.template = template
        self.top_k = top_k

        self.get_text_embedder()
        self.get_retriever()
        self.get_prompt_builder()
        self.get_llm_provider()
    def run(self, query: str, filter_selections: dict[str, list] | None = None) -> dict:
        # Build the pipeline lazily on first use.
        if self.pipeline is None:
            self.build_pipeline()
        assert self.pipeline is not None  # build_pipeline() always assigns it

        filters = RAGPipeline.build_filter(filter_selections=filter_selections)
        return self.pipeline.run(
            data={
                "text_embedder": {"text": query},
                "retriever": {"filters": filters},
                "prompt_builder": {"query": query},
            },
            include_outputs_from=["retriever", "llm"],
        )
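
    # Assuming Haystack 2.x defaults, the generated answer is at
    # result["llm"]["replies"][0]; include_outputs_from additionally surfaces
    # the retrieved documents at result["retriever"]["documents"].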

    def get_text_embedder(self) -> None:
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=settings.qdrant_database.model
        )
        # Load the embedding model once up front rather than on the first query.
        self.text_embedder.warm_up()

    def get_retriever(self) -> None:
        self.retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store, top_k=self.top_k
        )

    def get_prompt_builder(self) -> None:
        self.prompt_builder = PromptBuilder(template=self.template)

    def get_llm_provider(self) -> None:
        # The API key is read from the LLM_PROVIDER__API_KEY environment variable.
        self.llm_provider = OpenAIGenerator(
            model=settings.llm_provider.model,
            api_key=Secret.from_env_var("LLM_PROVIDER__API_KEY"),
            max_retries=3,
            generation_kwargs={"max_tokens": 5000, "temperature": 0.2},
        )

    @staticmethod
    def build_filter(filter_selections: dict[str, list] | None = None) -> dict:
        # An empty dict means no filtering is applied.
        if not filter_selections:
            return {}
        conditions: list[dict] = [
            {
                "field": "meta." + meta_data_name,
                "operator": "in",
                "value": selections,
            }
            for meta_data_name, selections in filter_selections.items()
        ]
        return {"operator": "AND", "conditions": conditions}
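
    # Example (derived from the logic above): build_filter({"author": ["Malaysia"]})
    # returns
    #     {"operator": "AND",
    #      "conditions": [{"field": "meta.author", "operator": "in",
    #                      "value": ["Malaysia"]}]}
    # which is the metadata-filter syntax the Qdrant retriever expects.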

    def build_pipeline(self):
        self.pipeline = Pipeline()
        self.pipeline.add_component("text_embedder", self.text_embedder)
        self.pipeline.add_component("retriever", self.retriever)
        self.pipeline.add_component("prompt_builder", self.prompt_builder)
        self.pipeline.add_component("llm", self.llm_provider)
        self.pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder", "llm")


if __name__ == "__main__":
    document_store = DocumentStore(index="inc_data")
    with open("src/rag/prompt_templates/inc_template.txt", "r") as file:
        template = file.read()
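    # The template file's contents are not shown here; a minimal Jinja2 prompt
    # for PromptBuilder might look like this (hypothetical example):
    #     Answer the question using only the context below.
    #     {% for doc in documents %}{{ doc.content }}{% endfor %}
    #     Question: {{ query }}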

    pipeline = RAGPipeline(
        document_store=document_store.document_store, template=template, top_k=5
    )
    filter_selections = {
        "author": ["Malaysia", "Australia"],
    }
    result = pipeline.run(
        "What is Malaysia's position on plastic waste?",
        filter_selections=filter_selections,
    )
    # Print the generated answer (OpenAIGenerator returns its text under "replies").
    print(result["llm"]["replies"][0])