from dotenv import load_dotenv

# Load OPENAI_API_KEY (and any other secrets) from a local .env file before
# the OpenAI-backed llama_index components are used.
load_dotenv()

import pickle
from typing import Optional

import weave
from llama_index.core import (
    PromptTemplate,
    VectorStoreIndex,
    get_response_synthesizer,
)
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
data_dir = "data/raw_docs/documents.pkl" | |
with open(data_dir, "rb") as file: | |
docs_files = pickle.load(file) | |
for i, doc in enumerate(docs_files[:], 1): | |
doc.metadata["page"] = i | |
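# The pickle is assumed to hold a list of llama_index `Document` objects, one
# per page of the paper. A minimal sketch of how such a file could be produced
# (illustrative; the document-extraction step is not part of this script):
#
#   from llama_index.core import Document
#   page_texts = [...]  # markdown text extracted per page, e.g. by a PDF parser
#   docs = [Document(text=page_text) for page_text in page_texts]
#   with open("data/raw_docs/documents.pkl", "wb") as f:
#       pickle.dump(docs, f)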
SYSTEM_PROMPT_TEMPLATE = """ | |
Answer the following question about the newly released Llama 3 405 billion parameter model based on provided snippets from the research paper. | |
Provide helpful, complete, and accurate answers to the question using only the information contained in these snippets. | |
Here are the relevant snippets from the Llama 3 405B model research paper: | |
<snippets> | |
{context_str} | |
</snippets> | |
To answer the question: | |
1. Carefully read and analyze the provided snippets. | |
2. Identify information that is directly relevant to the user's question. | |
3. Formulate a comprehensive answer based solely on the information in the snippets. | |
4. Do not include any information or claims that are not supported by the provided snippets. | |
Guidelines for your answer: | |
1. Be technical and informative, providing as much detail as the snippets allow. | |
2. If the snippets do not contain enough information to fully answer the question, state this clearly and provide what information you can based on the available snippets. | |
3. Do not make up or infer information beyond what is explicitly stated in the snippets. | |
4. If the question cannot be answered at all based on the provided snippets, state this clearly and explain why. | |
5. Use appropriate technical language and terminology as used in the snippets. | |
6. Cite the relevant sentences from the snippets and their page numbers to support your answer. | |
7. Answer in MFAQ format (Minimal Facts Answerable Question), providing the most concise and accurate response possible. | |
8. Use Markdown to format your response and include citation footnotes to indicate the snippets and the page number used to derive your answer. | |
9. Your answer must always contain footnotes citing the snippets used to derive the answer. | |
Here's an example of a question and an answer. You must use this as a template to format your response: | |
<example> | |
<question> | |
What was the main mix of the training data ? How much data was used to train the model ? | |
</question> | |
## Answer | |
The main mix of the training data for the Llama 3 405 billion parameter model is as follows: | |
- **General knowledge**: 50% | |
- **Mathematical and reasoning tokens**: 25% | |
- **Code tokens**: 17% | |
- **Multilingual tokens**: 8%[^1^]. | |
Regarding the amount of data used to train the model, the snippets do not provide a specific total volume of data in terms of tokens or bytes. However, they do mention that the model was pre-trained on a large dataset containing knowledge until the end of 2023[^2^]. Additionally, the training process involved pre-training on 2.87 trillion tokens before further adjustments[^3^]. | |
[^1^]: "Scaling Laws for Data Mix," page 6. | |
[^2^]: "Pre-Training Data," page 4. | |
[^3^]: "Initial Pre-Training," page 14. | |
</example> | |
Remember, your role is to accurately convey the information from the research paper snippets, not to speculate or provide information from other sources. | |
<question> | |
{query_str} | |
</question> | |
Answer: | |
""" | |
class SimpleRAGPipeline(weave.Model):
    chat_llm: str = "gpt-4o"
    embedding_model: str = "text-embedding-3-small"
    temperature: float = 0.1
    similarity_top_k: int = 15
    chunk_size: int = 512
    chunk_overlap: int = 128
    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
    query_engine: Optional[RetrieverQueryEngine] = None
    def _get_llm(self):
        return OpenAI(
            model=self.chat_llm,
            temperature=self.temperature,
            max_tokens=4096,
        )

    def _get_embedding_model(self):
        return OpenAIEmbedding(model=self.embedding_model)

    def _get_text_qa_template(self):
        return PromptTemplate(self.prompt_template)
    def _load_documents_and_chunk(self, documents: list):
        # MarkdownNodeParser splits on markdown structure (headers/sections),
        # so the chunk_size/chunk_overlap fields above are not consumed here.
        parser = MarkdownNodeParser()
        nodes = parser.get_nodes_from_documents(documents)
        return nodes

    def _create_vector_index(self, nodes):
        index = VectorStoreIndex(
            nodes,
            embed_model=self._get_embedding_model(),
            show_progress=True,
            insert_batch_size=512,
        )
        return index

    def _get_retriever(self, index):
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=self.similarity_top_k,
        )
        return retriever
    def _get_response_synthesizer(self):
        llm = self._get_llm()
        response_synthesizer = get_response_synthesizer(
            llm=llm,
            response_mode="compact",
            text_qa_template=self._get_text_qa_template(),
            streaming=True,
        )
        return response_synthesizer

    def build_query_engine(self):
        nodes = self._load_documents_and_chunk(docs_files)
        index = self._create_vector_index(nodes)
        retriever = self._get_retriever(index)
        response_synthesizer = self._get_response_synthesizer()
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
        )
    @weave.op()
    def predict(self, question: str):
        response = self.query_engine.query(question)
        # weave.get_current_call() only returns a call when executing inside a
        # weave op (and after weave.init() has run), hence the decorator above.
        current_call = weave.get_current_call()
        return {
            "response": response,
            "call_id": current_call.id,
            "url": current_call.ui_url,
        }
if __name__ == "__main__":
    # The project name is illustrative; weave.init() must run before any weave
    # op is called so that calls are tracked.
    weave.init("simple-rag-pipeline")

    rag_pipeline = SimpleRAGPipeline()
    rag_pipeline.build_query_engine()

    response = rag_pipeline.predict(
        "How does the model perform in comparison to the GPT-4 model?"
    )

    # streaming=True in the synthesizer yields a StreamingResponse, so tokens
    # are printed as they are generated.
    for resp in response["response"].response_gen:
        print(resp, end="")