ayut committed
Commit b50f20a • 1 Parent(s): efa4300

working rag

Files changed (2)
  1. app.py +33 -2
  2. rag/rag.py +111 -0
app.py CHANGED
@@ -1,4 +1,35 @@
+ import weave
+
  import streamlit as st
+ from rag.rag import SimpleRAGPipeline
+
+ WANDB_PROJECT = "paper_reader"
+
+ weave.init(f"{WANDB_PROJECT}")
+
+ st.set_page_config(page_title="Chat with the Llama 3 paper!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
+ st.title("Chat with the Llama 3 paper 💬🦙")
+
+ @st.cache_resource(show_spinner=False)
+ def load_rag_pipeline():
+     rag_pipeline = SimpleRAGPipeline()
+     rag_pipeline.build_query_engine()
+
+     return rag_pipeline
+
+ if "rag_pipeline" not in st.session_state.keys():
+     st.session_state.rag_pipeline = load_rag_pipeline()
+
+ rag_pipeline = st.session_state["rag_pipeline"]
+
+ # openai_api_key = st.sidebar.text_input('OpenAI API Key', type='password')
+
+ def generate_response(query):
+     response = rag_pipeline.predict(query)
+     st.write_stream(response.response_gen)
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ with st.form('my_form'):
+     query = st.text_area('Ask your question about the Llama 3 paper here:')
+     submitted = st.form_submit_button('Submit')
+     if submitted:
+         generate_response(query)
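
For context on the streaming hookup above: generate_response assumes predict returns a llama_index streaming response whose response_gen attribute is a plain generator of text chunks, which st.write_stream renders incrementally. A minimal sketch of that contract with a stubbed generator (the stub is illustrative, not part of this commit):

import streamlit as st

def stub_response_gen():
    # Stands in for llama_index's StreamingResponse.response_gen,
    # which yields incremental text deltas as strings.
    for chunk in ["Llama 3 ", "is a family ", "of open models."]:
        yield chunk

# st.write_stream consumes any generator of strings and renders the
# chunks one by one, as the app does with response.response_gen.
st.write_stream(stub_response_gen())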
rag/rag.py ADDED
@@ -0,0 +1,111 @@
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ import weave
+ import pickle
+
+ from llama_index.core import PromptTemplate
+ from llama_index.core.node_parser import MarkdownNodeParser
+ from llama_index.core import VectorStoreIndex
+ from llama_index.core.retrievers import VectorIndexRetriever
+ from llama_index.core.query_engine import RetrieverQueryEngine
+ from llama_index.core import get_response_synthesizer
+ from llama_index.llms.openai import OpenAI
+ from llama_index.embeddings.openai import OpenAIEmbedding
+
+ data_dir = "data/raw_docs/documents.pkl"
+ with open(data_dir, "rb") as file:
+     docs_files = pickle.load(file)
+
+ print(f"Number of documents: {len(docs_files)}\n")
+
+ SYSTEM_PROMPT_TEMPLATE = """
+ Answer the user's question about the newly released Llama 3 405-billion-parameter model based on the context. Provide a helpful and complete answer. The paper contains information about training, inference, evaluation, and many developments in machine learning.
+
+ Answer based only on the context provided in the documents. The answer should be technical and informative. Do not make things up.
+
+ User Query: {query_str}
+ Context: {context_str}
+ Answer:
+ """
+
+
+ class SimpleRAGPipeline(weave.Model):
+     chat_llm: str = "gpt-4"
+     embedding_model: str = "text-embedding-3-small"
+     temperature: float = 0.0
+     similarity_top_k: int = 2
+     chunk_size: int = 512
+     chunk_overlap: int = 128
+     prompt_template: str = SYSTEM_PROMPT_TEMPLATE
+     query_engine: RetrieverQueryEngine = None
+
+     def _get_llm(self):
+         return OpenAI(
+             model=self.chat_llm,
+             temperature=self.temperature,
+             max_tokens=4096,
+         )
+
+     def _get_embedding_model(self):
+         return OpenAIEmbedding(model=self.embedding_model)
+
+     def _get_text_qa_template(self):
+         return PromptTemplate(self.prompt_template)
+
+     def _load_documents_and_chunk(self, documents):
+         parser = MarkdownNodeParser()
+         nodes = parser.get_nodes_from_documents(documents)
+         return nodes
+
+     def _create_vector_index(self, nodes):
+         index = VectorStoreIndex(
+             nodes,
+             embed_model=self._get_embedding_model(),
+             show_progress=True,
+             insert_batch_size=128,
+         )
+
+         return index
+
+     def _get_retriever(self, index):
+         retriever = VectorIndexRetriever(
+             index=index,
+             similarity_top_k=self.similarity_top_k,
+         )
+         return retriever
+
+     def _get_response_synthesizer(self):
+         llm = self._get_llm()
+         response_synthesizer = get_response_synthesizer(
+             llm=llm,
+             response_mode="compact",
+             text_qa_template=self._get_text_qa_template(),
+             streaming=True,
+         )
+         return response_synthesizer
+
+     def build_query_engine(self):
+         nodes = self._load_documents_and_chunk(docs_files)
+         index = self._create_vector_index(nodes)
+         retriever = self._get_retriever(index)
+         response_synthesizer = self._get_response_synthesizer()
+
+         self.query_engine = RetrieverQueryEngine(
+             retriever=retriever,
+             response_synthesizer=response_synthesizer,
+         )
+
+     @weave.op()
+     def predict(self, question: str):
+         response = self.query_engine.query(question)
+         return response
+
+
+ if __name__ == "__main__":
+     rag_pipeline = SimpleRAGPipeline()
+     rag_pipeline.build_query_engine()
+
+     response = rag_pipeline.predict("What is Llama 3 model?")
+     response.print_response_stream()
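
One assumption this module makes: data/raw_docs/documents.pkl must already exist, since it is unpickled at import time but never created in this commit. A hedged sketch of how such a pickle of llama_index Document objects might be produced (the markdown path below is a hypothetical placeholder, not part of the commit):

import pathlib
import pickle

from llama_index.core import Document

# Hypothetical source file: a markdown export of the Llama 3 paper.
paper_md = pathlib.Path("data/raw_docs/llama3_paper.md")
documents = [Document(text=paper_md.read_text())]

# MarkdownNodeParser in rag/rag.py chunks Document objects,
# so the pickle just needs to hold a list of them.
with open("data/raw_docs/documents.pkl", "wb") as f:
    pickle.dump(documents, f)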