parambharat committed • Commit 049ff35
1 Parent(s): dbb0a0b

chore: improve rag pipeline

Files changed:
- app.py +17 -8
- rag/rag.py +72 -22
app.py
CHANGED
@@ -1,15 +1,21 @@
 import os
-os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
-
+os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
 import streamlit as st
+import weave
 from rag.rag import SimpleRAGPipeline
 
-st.set_page_config(…
+st.set_page_config(
+    page_title="Chat with the Llama 3 paper!",
+    page_icon="🦙",
+    layout="centered",
+    initial_sidebar_state="auto",
+    menu_items=None,
+)
 
-wandb_api_key = st.sidebar.text_input(…
-if len(wandb_api_key)>=10:
+wandb_api_key = st.sidebar.text_input("WANDB_API_KEY", type="password")
+if len(wandb_api_key) >= 10:
     os.environ["WANDB_API_KEY"] = wandb_api_key
 else:
     st.stop()
@@ -20,6 +26,7 @@ weave.init(f"{WANDB_PROJECT}")
 
 st.title("Chat with the Llama 3 paper 💬🦙")
 
+
 @st.cache_resource(show_spinner=False)
 def load_rag_pipeline():
     rag_pipeline = SimpleRAGPipeline()
@@ -27,6 +34,7 @@ def load_rag_pipeline():
 
     return rag_pipeline
 
+
 if "rag_pipeline" not in st.session_state.keys():
     st.session_state.rag_pipeline = load_rag_pipeline()
 
@@ -37,8 +45,9 @@ def generate_response(query):
     response = rag_pipeline.predict(query)
     st.write_stream(response.response_gen)
 
-…
-…
-…
+
+with st.form("my_form"):
+    query = st.text_area("Ask your question about the Llama 3 paper here:")
+    submitted = st.form_submit_button("Submit")
 if submitted:
     generate_response(query)
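The reworked app.py follows a standard Streamlit pattern: build the expensive pipeline once behind @st.cache_resource, collect the query in an st.form, and stream tokens to the page with st.write_stream. A minimal self-contained sketch of that pattern, runnable without the RAG backend (fake_token_stream is a hypothetical stand-in for response.response_gen):

    import time

    import streamlit as st

    @st.cache_resource(show_spinner=False)
    def load_pipeline():
        # Built once per process; reused across script reruns.
        return object()  # stand-in for SimpleRAGPipeline()

    pipeline = load_pipeline()

    def fake_token_stream(query: str):
        # Hypothetical stand-in for response.response_gen:
        # st.write_stream accepts any generator of strings.
        for token in f"You asked: {query}".split():
            yield token + " "
            time.sleep(0.05)

    with st.form("demo_form"):
        query = st.text_area("Ask a question:")
        submitted = st.form_submit_button("Submit")

    if submitted:
        # Renders tokens incrementally as the generator yields them.
        st.write_stream(fake_token_stream(query))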
rag/rag.py
CHANGED
@@ -2,42 +2,89 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-import weave
-import pathlib
 import pickle
 
-
+import weave
+from llama_index.core import PromptTemplate, VectorStoreIndex, get_response_synthesizer
 from llama_index.core.node_parser import MarkdownNodeParser
-from llama_index.core import VectorStoreIndex
-from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.query_engine import RetrieverQueryEngine
-from llama_index.core import …
-from llama_index.llms.openai import OpenAI
+from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.embeddings.openai import OpenAIEmbedding
-from llama_index.…
+from llama_index.llms.openai import OpenAI
 
 data_dir = "data/raw_docs/documents.pkl"
 with open(data_dir, "rb") as file:
     docs_files = pickle.load(file)
 
-
+for i, doc in enumerate(docs_files[:], 1):
+    doc.metadata["page"] = i
 
 SYSTEM_PROMPT_TEMPLATE = """
-Answer …
-
-
-Context: {context_str}
-Answer:
+Answer the following question about the newly released Llama 3 405 billion parameter model based on provided snippets from the research paper.
+Provide helpful, complete, and accurate answers to the question using only the information contained in these snippets.
+
+Here are the relevant snippets from the Llama 3 405B model research paper:
+
+<snippets>
+{context_str}
+</snippets>
+
+<question>
+{query_str}
+</question>
+
+To answer this question:
+
+1. Carefully read and analyze the provided snippets.
+2. Identify information that is directly relevant to the user's question.
+3. Formulate a comprehensive answer based solely on the information in the snippets.
+4. Do not include any information or claims that are not supported by the provided snippets.
+
+Guidelines for your answer:
+
+1. Be technical and informative, providing as much detail as the snippets allow.
+2. If the snippets do not contain enough information to fully answer the question, state this clearly and provide what information you can based on the available snippets.
+3. Do not make up or infer information beyond what is explicitly stated in the snippets.
+4. If the question cannot be answered at all based on the provided snippets, state this clearly and explain why.
+5. Use appropriate technical language and terminology as used in the snippets.
+6. Cite the relevant sentences from the snippets and their page numbers to support your answer.
+7. Answer in MFAQ format (Minimal Facts Answerable Question), providing the most concise and accurate response possible.
+8. Use Markdown to format your response and include citations to indicate the snippets and the page number used to derive your answer.
+
+Here's an example of a question and an answer. You must use this as a template to format your response:
+
+<example>
+Question: What was the main mix of the training data? How much data was used to train the model?
+
+### Answer
+The main mix of the training data for the Llama 3 405 billion parameter model is as follows:
+
+- **General knowledge**: 50%
+- **Mathematical and reasoning tokens**: 25%
+- **Code tokens**: 17%
+- **Multilingual tokens**: 8%[^1^].
+
+Regarding the amount of data used to train the model, the snippets do not provide a specific total volume of data in terms of tokens or bytes. However, they do mention that the model was pre-trained on a large dataset containing knowledge until the end of 2023[^2^]. Additionally, the training process involved pre-training on 2.87 trillion tokens before further adjustments[^3^].
+
+### References
+
+[^1^]: "Scaling Laws for Data Mix," page 6.
+[^2^]: "Pre-Training Data," page 4.
+[^3^]: "Initial Pre-Training," page 14.
+
+</example>
+
+Remember, your role is to accurately convey the information from the research paper snippets, not to speculate or provide information from other sources.
+
+Answer:
 """
 
 
 class SimpleRAGPipeline(weave.Model):
-    chat_llm: str = "gpt-…
+    chat_llm: str = "gpt-4o"
     embedding_model: str = "text-embedding-3-small"
-    temperature: float = 0.…
-    similarity_top_k: int = …
+    temperature: float = 0.1
+    similarity_top_k: int = 15
     chunk_size: int = 512
     chunk_overlap: int = 128
     prompt_template: str = SYSTEM_PROMPT_TEMPLATE
@@ -46,7 +93,7 @@ class SimpleRAGPipeline(weave.Model):
     def _get_llm(self):
        return OpenAI(
            model=self.chat_llm,
-           temperature=…
+           temperature=self.temperature,
            max_tokens=4096,
        )
 
@@ -56,9 +103,9 @@ class SimpleRAGPipeline(weave.Model):
     def _get_text_qa_template(self):
        return PromptTemplate(self.prompt_template)
 
-    def _load_documents_and_chunk(self, …
+    def _load_documents_and_chunk(self, documents: list):
        parser = MarkdownNodeParser()
-       nodes = parser.get_nodes_from_documents(…
+       nodes = parser.get_nodes_from_documents(documents)
        return nodes
 
     def _create_vector_index(self, nodes):
@@ -109,5 +156,8 @@ if __name__ == "__main__":
     rag_pipeline = SimpleRAGPipeline()
     rag_pipeline.build_query_engine()
 
-    response = rag_pipeline.predict(…
-    …
+    response = rag_pipeline.predict(
+        "How does the model perform in comparison to gpt4 model?"
+    )
+    for resp in response.response_gen:
+        print(resp, end="")
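The hunks above change imports and defaults but never show the bodies of _create_vector_index and build_query_engine. For orientation, here is a minimal sketch of how the imported pieces typically compose in llama-index, using the defaults from the diff; the function body and wiring are assumptions, not the commit's actual code:

    from llama_index.core import VectorStoreIndex, get_response_synthesizer
    from llama_index.core.query_engine import RetrieverQueryEngine
    from llama_index.core.retrievers import VectorIndexRetriever
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    def build_query_engine(nodes, qa_template=None):
        # Assumed wiring: embed the Markdown nodes into an in-memory vector index.
        index = VectorStoreIndex(
            nodes, embed_model=OpenAIEmbedding(model="text-embedding-3-small")
        )
        # Fetch the 15 most similar chunks per query (similarity_top_k in the diff).
        retriever = VectorIndexRetriever(index=index, similarity_top_k=15)
        # streaming=True is what exposes response.response_gen to callers.
        synthesizer = get_response_synthesizer(
            llm=OpenAI(model="gpt-4o", temperature=0.1, max_tokens=4096),
            text_qa_template=qa_template,
            streaming=True,
        )
        return RetrieverQueryEngine(retriever=retriever, response_synthesizer=synthesizer)

With wiring like this, query_engine.query("...") returns a streaming response whose response_gen yields text deltas, which is what both the __main__ block and the Streamlit app iterate over.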