working rag
- app.py (+33 -2)
- rag/rag.py (+113 -0)
app.py
CHANGED
@@ -1,4 +1,35 @@
+import weave
+
 import streamlit as st
+from rag.rag import SimpleRAGPipeline
+
+WANDB_PROJECT = "paper_reader"
+
+weave.init(f"{WANDB_PROJECT}")
+
+st.set_page_config(page_title="Chat with the Llama 3 paper!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
+st.title("Chat with the Llama 3 paper 💬🦙")
+
+@st.cache_resource(show_spinner=False)
+def load_rag_pipeline():
+    rag_pipeline = SimpleRAGPipeline()
+    rag_pipeline.build_query_engine()
+
+    return rag_pipeline
+
+if "rag_pipeline" not in st.session_state.keys():
+    st.session_state.rag_pipeline = load_rag_pipeline()
+
+rag_pipeline = st.session_state["rag_pipeline"]
+
+# openai_api_key = st.sidebar.text_input('OpenAI API Key', type='password')
+
+def generate_response(query):
+    response = rag_pipeline.predict(query)
+    st.write_stream(response.response_gen)
 
-
-st.
+with st.form('my_form'):
+    query = st.text_area('Ask your question about the Llama 3 paper here:')
+    submitted = st.form_submit_button('Submit')
+    if submitted:
+        generate_response(query)
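st.write_stream consumes a generator of text chunks, and response.response_gen on the LlamaIndex streaming response provides exactly that, so tokens appear in the UI as they are generated. For reference, a minimal sketch of exercising the same pipeline outside the Streamlit UI (assumptions: OPENAI_API_KEY is set, data/raw_docs/documents.pkl exists, and the query string is illustrative only):

# query_cli.py - illustrative sketch, not part of this commit
from rag.rag import SimpleRAGPipeline

pipeline = SimpleRAGPipeline()
pipeline.build_query_engine()

# predict() returns a streaming response; consume its generator token by token
response = pipeline.predict("How was Llama 3 post-trained?")
for token in response.response_gen:
    print(token, end="", flush=True)
print()

The app itself is launched with streamlit run app.py; @st.cache_resource keeps one SimpleRAGPipeline (and its in-memory vector index) alive across reruns instead of rebuilding it on every interaction.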
rag/rag.py
ADDED
@@ -0,0 +1,113 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import weave
+import pickle
+
+from llama_index.core import PromptTemplate
+from llama_index.core.node_parser import MarkdownNodeParser
+from llama_index.core import VectorStoreIndex
+from llama_index.core.retrievers import VectorIndexRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core import get_response_synthesizer
+from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
+
+# documents.pkl is expected to hold a list of pre-parsed LlamaIndex Document objects
+data_dir = "data/raw_docs/documents.pkl"
+with open(data_dir, "rb") as file:
+    docs_files = pickle.load(file)
+
+print(f"Number of files: {len(docs_files)}\n")
+
+SYSTEM_PROMPT_TEMPLATE = """
+Answer the user's question about the newly released Llama 3 405 billion parameter model based on the context. Provide a helpful and complete answer. The paper contains information about training, inference, evaluation and many other developments in machine learning.
+
+Answer based only on the context provided in the documents. The answer should be technical and informative. Do not make things up.
+
+User Query: {query_str}
+Context: {context_str}
+Answer:
+"""
+
+
+class SimpleRAGPipeline(weave.Model):
+    chat_llm: str = "gpt-4"
+    embedding_model: str = "text-embedding-3-small"
+    temperature: float = 0.0
+    similarity_top_k: int = 2
+    chunk_size: int = 512
+    chunk_overlap: int = 128
+    prompt_template: str = SYSTEM_PROMPT_TEMPLATE
+    query_engine: RetrieverQueryEngine = None
+
+    def _get_llm(self):
+        return OpenAI(
+            model=self.chat_llm,
+            temperature=self.temperature,
+            max_tokens=4096,
+        )
+
+    def _get_embedding_model(self):
+        return OpenAIEmbedding(model=self.embedding_model)
+
+    def _get_text_qa_template(self):
+        return PromptTemplate(self.prompt_template)
+
+    def _load_documents_and_chunk(self, documents):
+        parser = MarkdownNodeParser()
+        nodes = parser.get_nodes_from_documents(documents)
+        return nodes
+
+    def _create_vector_index(self, nodes):
+        index = VectorStoreIndex(
+            nodes,
+            embed_model=self._get_embedding_model(),
+            show_progress=True,
+            insert_batch_size=128,
+        )
+
+        return index
+
+    def _get_retriever(self, index):
+        retriever = VectorIndexRetriever(
+            index=index,
+            similarity_top_k=self.similarity_top_k,
+        )
+        return retriever
+
+    def _get_response_synthesizer(self):
+        llm = self._get_llm()
+        response_synthesizer = get_response_synthesizer(
+            llm=llm,
+            response_mode="compact",
+            text_qa_template=self._get_text_qa_template(),
+            streaming=True,
+        )
+        return response_synthesizer
+
+    def build_query_engine(self):
+        nodes = self._load_documents_and_chunk(docs_files)
+        index = self._create_vector_index(nodes)
+        retriever = self._get_retriever(index)
+        response_synthesizer = self._get_response_synthesizer()
+
+        self.query_engine = RetrieverQueryEngine(
+            retriever=retriever,
+            response_synthesizer=response_synthesizer,
+        )
+
+    @weave.op()
+    def predict(self, question: str):
+        response = self.query_engine.query(question)
+        return response
+
+
+if __name__ == "__main__":
+    rag_pipeline = SimpleRAGPipeline()
+    rag_pipeline.build_query_engine()
+
+    response = rag_pipeline.predict("What is the Llama 3 model?")
+    # the response synthesizer streams, so print tokens as they arrive
+    response.print_response_stream()
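rag/rag.py assumes data/raw_docs/documents.pkl already contains a list of LlamaIndex Document objects holding markdown text (MarkdownNodeParser is applied to them). That preprocessing step is not part of this commit; the following is only a sketch of one way such a pickle could be produced, assuming the paper has been converted to markdown files under data/raw_docs/ (file names and paths are hypothetical):

# build_docs.py - illustrative sketch, not part of this commit
import pathlib
import pickle

from llama_index.core import Document

raw_dir = pathlib.Path("data/raw_docs")

# wrap each markdown file in a LlamaIndex Document, keeping the file name as metadata
documents = [
    Document(text=path.read_text(encoding="utf-8"), metadata={"source": path.name})
    for path in sorted(raw_dir.glob("*.md"))
]

with open(raw_dir / "documents.pkl", "wb") as f:
    pickle.dump(documents, f)

With the pickle in place and OPENAI_API_KEY available (the .env file loaded by load_dotenv is one option), python rag/rag.py builds the index and streams an answer to the sample question; weave.init is only called from app.py, so standalone runs are not traced.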