paper_reader / rag /rag.py
ayut's picture
display trace url
34ccb7c
raw
history blame
6.15 kB
from dotenv import load_dotenv
load_dotenv()
import pickle
import weave
from llama_index.core import PromptTemplate, VectorStoreIndex, get_response_synthesizer
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
# Location of the pickled collection of source documents that gets indexed.
# (Despite the name, this is a file path, not a directory.)
data_dir = "data/raw_docs/documents.pkl"

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only point data_dir at a pickle produced by a trusted step of this project.
with open(data_dir, "rb") as file:
    docs_files = pickle.load(file)

# Stamp every document with its 1-based position so answers can cite a page.
# NOTE(review): this assumes the pickle holds one document per page, in page
# order -- confirm against whatever script produced the pickle.
for i, doc in enumerate(docs_files, start=1):
    doc.metadata["page"] = i
# Default question-answering prompt for the pipeline: it is the default value
# of SimpleRAGPipeline.prompt_template, which is wrapped in a PromptTemplate
# and handed to the response synthesizer as its text_qa_template.
#
# The text contains exactly two format placeholders:
#   {context_str} -- where the retrieved snippets are inserted
#   {query_str}   -- where the user's question is inserted
# Everything else (instructions, guidelines, the worked <example>) is literal
# prompt text and is sent to the model as written, so edits here directly
# change model behaviour.
SYSTEM_PROMPT_TEMPLATE = """
Answer the following question about the newly released Llama 3 405 billion parameter model based on provided snippets from the research paper.
Provide helpful, complete, and accurate answers to the question using only the information contained in these snippets.
Here are the relevant snippets from the Llama 3 405B model research paper:
<snippets>
{context_str}
</snippets>
To answer the question:
1. Carefully read and analyze the provided snippets.
2. Identify information that is directly relevant to the user's question.
3. Formulate a comprehensive answer based solely on the information in the snippets.
4. Do not include any information or claims that are not supported by the provided snippets.
Guidelines for your answer:
1. Be technical and informative, providing as much detail as the snippets allow.
2. If the snippets do not contain enough information to fully answer the question, state this clearly and provide what information you can based on the available snippets.
3. Do not make up or infer information beyond what is explicitly stated in the snippets.
4. If the question cannot be answered at all based on the provided snippets, state this clearly and explain why.
5. Use appropriate technical language and terminology as used in the snippets.
6. Cite the relevant sentences from the snippets and their page numbers to support your answer.
7. Answer in MFAQ format (Minimal Facts Answerable Question), providing the most concise and accurate response possible.
8. Use Markdown to format your response and include citation footnotes to indicate the snippets and the page number used to derive your answer.
9. Your answer must always contain footnotes citing the snippets used to derive the answer.
Here's an example of a question and an answer. You must use this as a template to format your response:
<example>
<question>
What was the main mix of the training data ? How much data was used to train the model ?
</question>
## Answer
The main mix of the training data for the Llama 3 405 billion parameter model is as follows:
- **General knowledge**: 50%
- **Mathematical and reasoning tokens**: 25%
- **Code tokens**: 17%
- **Multilingual tokens**: 8%[^1^].
Regarding the amount of data used to train the model, the snippets do not provide a specific total volume of data in terms of tokens or bytes. However, they do mention that the model was pre-trained on a large dataset containing knowledge until the end of 2023[^2^]. Additionally, the training process involved pre-training on 2.87 trillion tokens before further adjustments[^3^].
[^1^]: "Scaling Laws for Data Mix," page 6.
[^2^]: "Pre-Training Data," page 4.
[^3^]: "Initial Pre-Training," page 14.
</example>
Remember, your role is to accurately convey the information from the research paper snippets, not to speculate or provide information from other sources.
<question>
{query_str}
</question>
Answer:
"""
class SimpleRAGPipeline(weave.Model):
    """Retrieval-augmented question answering over the module-level documents.

    The pipeline parses the documents into nodes, embeds them into an
    in-memory vector index, retrieves the most similar nodes for a question
    and asks an OpenAI chat model to answer from them using
    ``prompt_template``.

    Typical use::

        pipeline = SimpleRAGPipeline()
        pipeline.build_query_engine()
        result = pipeline.predict("...")
        for token in result["response"].response_gen:
            ...
    """

    # Name of the OpenAI chat model used to write the answer.
    chat_llm: str = "gpt-4o"
    # Name of the OpenAI embedding model used to index and query the nodes.
    embedding_model: str = "text-embedding-3-small"
    # Sampling temperature passed to the chat model.
    temperature: float = 0.1
    # Number of nodes the retriever returns per question.
    similarity_top_k: int = 15
    # NOTE: chunk_size and chunk_overlap are not read by any method of this
    # class (MarkdownNodeParser is constructed without them). They are kept
    # so existing callers/configs that set them keep working.
    chunk_size: int = 512
    chunk_overlap: int = 128
    # Prompt text wrapped in a PromptTemplate for the response synthesizer.
    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
    # Populated by build_query_engine(); None until then.
    query_engine: RetrieverQueryEngine = None

    def _get_llm(self) -> OpenAI:
        """Return the OpenAI chat model configured from this pipeline's fields."""
        return OpenAI(
            model=self.chat_llm,
            temperature=self.temperature,
            max_tokens=4096,
        )

    def _get_embedding_model(self) -> OpenAIEmbedding:
        """Return the OpenAI embedding model named by ``embedding_model``."""
        return OpenAIEmbedding(model=self.embedding_model)

    def _get_text_qa_template(self) -> PromptTemplate:
        """Wrap ``prompt_template`` in a LlamaIndex ``PromptTemplate``."""
        return PromptTemplate(self.prompt_template)

    def _load_documents_and_chunk(self, documents: list) -> list:
        """Parse ``documents`` into nodes with ``MarkdownNodeParser``."""
        parser = MarkdownNodeParser()
        return parser.get_nodes_from_documents(documents)

    def _create_vector_index(self, nodes) -> VectorStoreIndex:
        """Embed ``nodes`` and return a ``VectorStoreIndex`` over them."""
        return VectorStoreIndex(
            nodes,
            embed_model=self._get_embedding_model(),
            show_progress=True,
            insert_batch_size=512,
        )

    def _get_retriever(self, index) -> VectorIndexRetriever:
        """Return a retriever over ``index`` yielding ``similarity_top_k`` nodes."""
        return VectorIndexRetriever(
            index=index,
            similarity_top_k=self.similarity_top_k,
        )

    def _get_response_synthesizer(self):
        """Return a streaming, "compact"-mode synthesizer using the QA prompt."""
        return get_response_synthesizer(
            llm=self._get_llm(),
            response_mode="compact",
            text_qa_template=self._get_text_qa_template(),
            streaming=True,
        )

    def build_query_engine(self) -> None:
        """Index the module-level ``docs_files`` and store the query engine.

        Parses the documents, builds the vector index, and assigns a
        ``RetrieverQueryEngine`` to ``self.query_engine``. Building the index
        embeds every node, so this is the expensive step of the pipeline.
        """
        nodes = self._load_documents_and_chunk(docs_files)
        index = self._create_vector_index(nodes)
        self.query_engine = RetrieverQueryEngine(
            retriever=self._get_retriever(index),
            response_synthesizer=self._get_response_synthesizer(),
        )

    @weave.op()
    def predict(self, question: str) -> dict:
        """Answer ``question`` and return the response with trace metadata.

        Args:
            question: The user's question about the indexed documents.

        Returns:
            A dict with three keys:
              * ``"response"``: the object returned by the query engine (the
                synthesizer is created with ``streaming=True``).
              * ``"call_id"``: id of the current Weave call, or ``None`` when
                no call is being traced.
              * ``"url"``: UI URL of the current Weave call, or ``None`` when
                no call is being traced.
        """
        # Build lazily so a forgotten build_query_engine() call does not
        # surface as "'NoneType' object has no attribute 'query'".
        if self.query_engine is None:
            self.build_query_engine()

        response = self.query_engine.query(question)

        # get_current_call() yields None when Weave tracing is not active
        # (e.g. weave.init was never called); look it up once and guard it.
        current_call = weave.get_current_call()
        return {
            "response": response,
            "call_id": current_call.id if current_call is not None else None,
            "url": current_call.ui_url if current_call is not None else None,
        }
# Manual smoke test: build the index, ask one question, stream the answer.
if __name__ == "__main__":
    rag_pipeline = SimpleRAGPipeline()
    rag_pipeline.build_query_engine()
    response = rag_pipeline.predict(
        "How does the model perform in comparision to gpt4 model?"
    )
    # predict() returns a dict; the streaming response object is stored under
    # the "response" key, so the token generator must be read from there
    # (a dict itself has no response_gen attribute).
    for resp in response["response"].response_gen:
        print(resp, end="")