parambharat committed
Commit 049ff35 • 1 Parent(s): dbb0a0b

chore: improve rag pipeline

Files changed (2):
  1. app.py +17 -8
  2. rag/rag.py +72 -22
app.py CHANGED
@@ -1,15 +1,21 @@
 import os
-os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
-import weave
+os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
 import streamlit as st
+import weave
 from rag.rag import SimpleRAGPipeline
 
-st.set_page_config(page_title="Chat with the Llama 3 paper!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
+st.set_page_config(
+    page_title="Chat with the Llama 3 paper!",
+    page_icon="🦙",
+    layout="centered",
+    initial_sidebar_state="auto",
+    menu_items=None,
+)
 
-wandb_api_key = st.sidebar.text_input('WANDB_API_KEY', type='password')
-if len(wandb_api_key)>=10:
+wandb_api_key = st.sidebar.text_input("WANDB_API_KEY", type="password")
+if len(wandb_api_key) >= 10:
     os.environ["WANDB_API_KEY"] = wandb_api_key
 else:
     st.stop()
@@ -20,6 +26,7 @@ weave.init(f"{WANDB_PROJECT}")
 
 st.title("Chat with the Llama 3 paper 💬🦙")
 
+
 @st.cache_resource(show_spinner=False)
 def load_rag_pipeline():
     rag_pipeline = SimpleRAGPipeline()
@@ -27,6 +34,7 @@ def load_rag_pipeline():
 
     return rag_pipeline
 
+
 if "rag_pipeline" not in st.session_state.keys():
     st.session_state.rag_pipeline = load_rag_pipeline()
 
@@ -37,8 +45,9 @@ def generate_response(query):
     response = rag_pipeline.predict(query)
     st.write_stream(response.response_gen)
 
-with st.form('my_form'):
-    query = st.text_area('Ask your question about the Llama 3 paper here:')
-    submitted = st.form_submit_button('Submit')
+
+with st.form("my_form"):
+    query = st.text_area("Ask your question about the Llama 3 paper here:")
+    submitted = st.form_submit_button("Submit")
     if submitted:
         generate_response(query)
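Side note on the re-added `os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")` line: `os.getenv` returns `None` when the variable is unset, and assigning `None` into `os.environ` raises a `TypeError`, so the app still requires the key to be exported before launch. A more defensive variant (a sketch, not part of this commit) could fail with a clearer message:

```python
import os

# Hypothetical guard, not in the commit: fail fast with a readable error
# instead of a TypeError when OPENAI_API_KEY is missing.
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
    raise RuntimeError("OPENAI_API_KEY is not set; export it before `streamlit run app.py`.")
os.environ["OPENAI_API_KEY"] = openai_key
```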
rag/rag.py CHANGED
@@ -2,42 +2,89 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-import weave
-import pathlib
 import pickle
 
-from llama_index.core import PromptTemplate
+import weave
+from llama_index.core import PromptTemplate, VectorStoreIndex, get_response_synthesizer
 from llama_index.core.node_parser import MarkdownNodeParser
-from llama_index.core import VectorStoreIndex
-from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.query_engine import RetrieverQueryEngine
-from llama_index.core import get_response_synthesizer
-from llama_index.llms.openai import OpenAI
+from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.embeddings.openai import OpenAIEmbedding
-from llama_index.core import VectorStoreIndex
+from llama_index.llms.openai import OpenAI
 
 data_dir = "data/raw_docs/documents.pkl"
 with open(data_dir, "rb") as file:
     docs_files = pickle.load(file)
 
-print(f"Number of files: {len(docs_files)}\n")
+for i, doc in enumerate(docs_files[:], 1):
+    doc.metadata["page"] = i
 
 SYSTEM_PROMPT_TEMPLATE = """
-Answer to the user question about the newly released Llama 3 405 billion parameter model based on the context. Provide an helful and complete answer. The paper will have information about the training, inference, evaluation and many developments in Machine Learning.
+Answer the following question about the newly released Llama 3 405 billion parameter model based on provided snippets from the research paper.
+Provide helpful, complete, and accurate answers to the question using only the information contained in these snippets.
+
+Here are the relevant snippets from the Llama 3 405B model research paper:
+
+<snippets>
+{context_str}
+</snippets>
+
+<question>
+{query_str}
+</question>
+
+To answer this question:
+
+1. Carefully read and analyze the provided snippets.
+2. Identify information that is directly relevant to the user's question.
+3. Formulate a comprehensive answer based solely on the information in the snippets.
+4. Do not include any information or claims that are not supported by the provided snippets.
+
+Guidelines for your answer:
+
+1. Be technical and informative, providing as much detail as the snippets allow.
+2. If the snippets do not contain enough information to fully answer the question, state this clearly and provide what information you can based on the available snippets.
+3. Do not make up or infer information beyond what is explicitly stated in the snippets.
+4. If the question cannot be answered at all based on the provided snippets, state this clearly and explain why.
+5. Use appropriate technical language and terminology as used in the snippets.
+6. Cite the relevant sentences from the snippets and their page numbers to support your answer.
+7. Answer in MFAQ format (Minimal Facts Answerable Question), providing the most concise and accurate response possible.
+8. Use Markdown to format your response and include citations to indicate the snippets and the page number used to derive your answer.
+
+Here's an example of a question and an answer. You must use this as a template to format your response:
+
+<example>
+Question: What was the main mix of the training data? How much data was used to train the model?
+
+### Answer
+The main mix of the training data for the Llama 3 405 billion parameter model is as follows:
+
+- **General knowledge**: 50%
+- **Mathematical and reasoning tokens**: 25%
+- **Code tokens**: 17%
+- **Multilingual tokens**: 8%[^1^].
+
+Regarding the amount of data used to train the model, the snippets do not provide a specific total volume of data in terms of tokens or bytes. However, they do mention that the model was pre-trained on a large dataset containing knowledge until the end of 2023[^2^]. Additionally, the training process involved pre-training on 2.87 trillion tokens before further adjustments[^3^].
+
+### References
+
+[^1^]: "Scaling Laws for Data Mix," page 6.
+[^2^]: "Pre-Training Data," page 4.
+[^3^]: "Initial Pre-Training," page 14.
+
+</example>
 
-Answer based only on the context provided in the documents. The answer should be tehcnical and informative. Do not make up things.
+Remember, your role is to accurately convey the information from the research paper snippets, not to speculate or provide information from other sources.
 
-User Query: {query_str}
-Context: {context_str}
-Answer:
+Answer:
 """
 
 
 class SimpleRAGPipeline(weave.Model):
-    chat_llm: str = "gpt-4"
+    chat_llm: str = "gpt-4o"
     embedding_model: str = "text-embedding-3-small"
-    temperature: float = 0.0
-    similarity_top_k: int = 2
+    temperature: float = 0.1
+    similarity_top_k: int = 15
     chunk_size: int = 512
     chunk_overlap: int = 128
    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
@@ -46,7 +93,7 @@ class SimpleRAGPipeline(weave.Model):
     def _get_llm(self):
         return OpenAI(
             model=self.chat_llm,
-            temperature=0.0,
+            temperature=self.temperature,
             max_tokens=4096,
         )
 
@@ -56,9 +103,9 @@ class SimpleRAGPipeline(weave.Model):
     def _get_text_qa_template(self):
         return PromptTemplate(self.prompt_template)
 
-    def _load_documents_and_chunk(self, files: pathlib.PosixPath):
+    def _load_documents_and_chunk(self, documents: list):
         parser = MarkdownNodeParser()
-        nodes = parser.get_nodes_from_documents(docs_files)
+        nodes = parser.get_nodes_from_documents(documents)
         return nodes
 
     def _create_vector_index(self, nodes):
@@ -109,5 +156,8 @@ if __name__ == "__main__":
     rag_pipeline = SimpleRAGPipeline()
     rag_pipeline.build_query_engine()
 
-    response = rag_pipeline.predict("What is Llama 3 model?")
-    print(response["response"])
+    response = rag_pipeline.predict(
+        "How does the model perform in comparison to gpt4 model?"
+    )
+    for resp in response.response_gen:
+        print(resp, end="")
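The bodies of `_create_vector_index` and `build_query_engine` are unchanged by this commit and therefore elided from the diff. For orientation, the imports above typically compose along these lines in llama-index; the following is a sketch under those assumptions (names and defaults are hypothetical), not the repository's actual implementation:

```python
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever


def build_query_engine_sketch(nodes, embed_model, llm, text_qa_template, similarity_top_k=15):
    """Hypothetical wiring of the imported pieces, not the code elided from this diff."""
    # Embed the parsed Markdown nodes into an in-memory vector index.
    index = VectorStoreIndex(nodes, embed_model=embed_model)
    # Retrieve the top-k most similar chunks per query (the commit raises k from 2 to 15).
    retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k)
    # Streaming synthesis, since callers iterate response.response_gen.
    synthesizer = get_response_synthesizer(
        llm=llm, text_qa_template=text_qa_template, streaming=True
    )
    return RetrieverQueryEngine(retriever=retriever, response_synthesizer=synthesizer)
```

With a streaming synthesizer, the query engine returns a streaming response whose `response_gen` yields text deltas, which is what both the `__main__` block here and `st.write_stream` in `app.py` consume.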