from dotenv import load_dotenv

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()

import pickle

import weave
from llama_index.core import (
    PromptTemplate,
    VectorStoreIndex,
    get_response_synthesizer,
)
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

# Load the pre-parsed documents that were pickled to disk.
data_dir = "data/raw_docs/documents.pkl"
with open(data_dir, "rb") as file:
    docs_files = pickle.load(file)

print(f"Number of files: {len(docs_files)}\n")

SYSTEM_PROMPT_TEMPLATE = """
Answer the user's question about the newly released Llama 3 405-billion-parameter model based on the context.
Provide a helpful and complete answer. The paper contains information about training, inference, evaluation, and many developments in machine learning.
Answer based only on the context provided in the documents. The answer should be technical and informative. Do not make things up.

User Query: {query_str}
Context: {context_str}
Answer:
"""


class SimpleRAGPipeline(weave.Model):
    chat_llm: str = "gpt-4"
    embedding_model: str = "text-embedding-3-small"
    temperature: float = 0.0
    similarity_top_k: int = 2
    # Chunking parameters kept for configurability; MarkdownNodeParser splits
    # on markdown structure and does not consume them directly.
    chunk_size: int = 512
    chunk_overlap: int = 128
    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
    query_engine: RetrieverQueryEngine = None

    def _get_llm(self):
        # Chat model used by the response synthesizer.
        return OpenAI(
            model=self.chat_llm,
            temperature=self.temperature,
            max_tokens=4096,
        )

    def _get_embedding_model(self):
        # Embedding model used to index and retrieve chunks.
        return OpenAIEmbedding(model=self.embedding_model)

    def _get_text_qa_template(self):
        return PromptTemplate(self.prompt_template)

    def _load_documents_and_chunk(self, documents: list):
        # Split the loaded documents into nodes along their markdown structure.
        parser = MarkdownNodeParser()
        nodes = parser.get_nodes_from_documents(documents)
        return nodes

    def _create_vector_index(self, nodes):
        # Embed the nodes and build an in-memory vector index.
        index = VectorStoreIndex(
            nodes,
            embed_model=self._get_embedding_model(),
            show_progress=True,
            insert_batch_size=128,
        )
        return index

    def _get_retriever(self, index):
        # Retrieve the top-k most similar chunks for a query.
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=self.similarity_top_k,
        )
        return retriever

    def _get_response_synthesizer(self):
        # Compact the retrieved chunks and answer using the QA prompt above.
        llm = self._get_llm()
        response_synthesizer = get_response_synthesizer(
            llm=llm,
            response_mode="compact",
            text_qa_template=self._get_text_qa_template(),
            streaming=True,
        )
        return response_synthesizer

    def build_query_engine(self):
        # Chunk -> index -> retriever -> synthesizer, wired into a query engine.
        nodes = self._load_documents_and_chunk(docs_files)
        index = self._create_vector_index(nodes)
        retriever = self._get_retriever(index)
        response_synthesizer = self._get_response_synthesizer()
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
        )

    @weave.op()
    def predict(self, question: str):
        response = self.query_engine.query(question)
        return response


if __name__ == "__main__":
    # Initialize Weave so weave.Model and @weave.op calls are traced.
    # The project name here is illustrative.
    weave.init("simple-rag-pipeline")

    rag_pipeline = SimpleRAGPipeline()
    rag_pipeline.build_query_engine()
    response = rag_pipeline.predict("What is the Llama 3 model?")
    # The synthesizer is configured with streaming=True, so the query returns
    # a StreamingResponse; stream the answer to stdout as it is generated.
    response.print_response_stream()
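

# ---------------------------------------------------------------------------
# Optional helper, not part of the pipeline above: a minimal sketch of how
# `data/raw_docs/documents.pkl` might be produced, assuming the source
# material is a directory of files that LlamaIndex's SimpleDirectoryReader
# can ingest. The input directory and this function's name are illustrative
# assumptions, not something the original script defines.
def build_documents_pickle(
    input_dir: str = "data/raw_docs",
    output_path: str = "data/raw_docs/documents.pkl",
) -> None:
    from llama_index.core import SimpleDirectoryReader

    # Read every file in the directory into LlamaIndex Document objects.
    documents = SimpleDirectoryReader(input_dir).load_data()
    # Persist the documents so the pipeline above can load them with pickle.
    with open(output_path, "wb") as out_file:
        pickle.dump(documents, out_file)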