samsonleegh committed on
Commit
6d513c3
1 Parent(s): 0eee40c

Create app.py

Files changed (1)
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
+ import chromadb
+ #from llama_index.llms.ollama import Ollama
+ from llama_index.llms.groq import Groq
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex
+ from llama_index.core import StorageContext, get_response_synthesizer
+ from llama_index.core.retrievers import VectorIndexRetriever
+ from llama_index.core.query_engine import RetrieverQueryEngine
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core import load_index_from_storage
+ import os
+ from dotenv import load_dotenv
+ from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler, CBEventType
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.core.postprocessor import SimilarityPostprocessor
+ import time
+ import gradio as gr
+ from llama_index.core.memory import ChatMemoryBuffer
+ from llama_parse import LlamaParse
+ from llama_index.core import PromptTemplate
+ from llama_index.core.llms import ChatMessage, MessageRole
+ from llama_index.core.chat_engine import CondenseQuestionChatEngine
+
+
+ load_dotenv()
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
+ LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY')
+
+ # set up callback manager for debug tracing
+ llama_debug = LlamaDebugHandler(print_trace_on_end=True)
+ callback_manager = CallbackManager([llama_debug])
+ Settings.callback_manager = callback_manager
+
+ # set up LLM
+ llm = Groq(model="llama3-70b-8192", api_key=GROQ_API_KEY)  # alternative: "llama3-8b-8192"
+ Settings.llm = llm
+
+ # set up embedding model
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
+ Settings.embed_model = embed_model
+
+ # create splitter
+ splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
+ Settings.transformations = [splitter]
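+ # NOTE: bge-small-en-v1.5 embeds at most 512 tokens per input, so these
+ # 2048-token chunks are truncated at embed time; a smaller chunk_size may
+ # retrieve more precisely.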
+
+ # create parser
+ parser = LlamaParse(
+     api_key=LLAMAINDEX_API_KEY,
+     result_type="markdown",  # "markdown" and "text" are available
+     verbose=True,
+ )
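+ # LlamaParse sends documents to the hosted LlamaCloud service for parsing,
+ # so LLAMAINDEX_API_KEY above must be a valid LlamaCloud API key.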
+
+ # create index: load from disk if persisted, otherwise build from ./data
+ if os.path.exists("./vectordb"):
+     print("Index Exists!")
+     storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
+     index = load_index_from_storage(storage_context)
+ else:
+     filename_fn = lambda filename: {"file_name": filename}
+     required_exts = [".pdf", ".docx"]
+     file_extractor = {".pdf": parser}
+     reader = SimpleDirectoryReader(
+         input_dir="./data",
+         file_extractor=file_extractor,
+         required_exts=required_exts,
+         recursive=True,
+         file_metadata=filename_fn
+     )
+     documents = reader.load_data()
+     print("index creating with %d documents" % len(documents))
+     index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, transformations=[splitter])
+     index.storage_context.persist(persist_dir="./vectordb")
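+ # optional sanity check once the index is built, e.g.:
+ #print(index.as_query_engine().query("What documents are available?"))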
+
+ """
+ # create document summary index
+ if os.path.exists("./docsummarydb"):
+     print("Index Exists!")
+     storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb")
+     doc_index = load_index_from_storage(storage_context)
+ else:
+     filename_fn = lambda filename: {"file_name": filename}
+     required_exts = [".pdf", ".docx"]
+     reader = SimpleDirectoryReader(
+         input_dir="./data",
+         required_exts=required_exts,
+         recursive=True,
+         file_metadata=filename_fn
+     )
+     documents = reader.load_data()
+     print("index creating with %d documents" % len(documents))
+
+     response_synthesizer = get_response_synthesizer(
+         response_mode="tree_summarize", use_async=True
+     )
+     doc_index = DocumentSummaryIndex.from_documents(
+         documents,
+         llm=llm,
+         transformations=[splitter],
+         response_synthesizer=response_synthesizer,
+         show_progress=True
+     )
+     doc_index.storage_context.persist(persist_dir="./docsummarydb")
+ """
+ """
+ # requires: from llama_index.core.indices.document_summary import DocumentSummaryIndexEmbeddingRetriever
+ retriever = DocumentSummaryIndexEmbeddingRetriever(
+     doc_index,
+     similarity_top_k=5,
+ )
+ """
+
+ # set up retriever
+ retriever = VectorIndexRetriever(
+     index=index,
+     similarity_top_k=10,
+     #vector_store_query_mode="mmr",
+     #vector_store_kwargs={"mmr_threshold": 0.4}
+ )
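+ # the commented-out MMR settings would diversify the top-k results
+ # (maximal marginal relevance) at some cost to raw similarity ranking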
+
+ # set up response synthesizer
+ response_synthesizer = get_response_synthesizer(response_mode="tree_summarize", verbose=True)
+
+ ### customising prompts worsened the result ###
+ """
+ # set up prompt template
+ qa_prompt_tmpl = (
+     "Context information from multiple sources is below.\n"
+     "---------------------\n"
+     "{context_str}\n"
+     "---------------------\n"
+     "Given the information from multiple sources and not prior knowledge, "
+     "answer the query.\n"
+     "Query: {query_str}\n"
+     "Answer: "
+ )
+ qa_prompt = PromptTemplate(qa_prompt_tmpl)
+ """
+ # setting up query engine
+ query_engine = RetrieverQueryEngine(
+     retriever=retriever,
+     node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.53)],
+     response_synthesizer=response_synthesizer
+ )
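+ # SimilarityPostprocessor drops retrieved nodes whose similarity score falls
+ # below 0.53; tree_summarize then folds the surviving chunks into one answer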
+ print(query_engine.get_prompts())
+
+ #response = query_engine.query("What happens if the distributor wants its own warehouse for pizzahood?")
+ #print(response)
+
+
+ memory = ChatMemoryBuffer.from_defaults(token_limit=10000)
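+ # the buffer keeps roughly the last 10,000 tokens of conversation,
+ # trimming the oldest turns once the limit is exceeded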
+
+ custom_prompt = PromptTemplate(
+     """\
+ Given a conversation (between Human and Assistant) and a follow up message from Human, \
+ rewrite the message to be a standalone question that captures all relevant context \
+ from the conversation. If you are unsure, ask for more information.
+
+ <Chat History>
+ {chat_history}
+
+ <Follow Up Message>
+ {question}
+
+ <Standalone question>
+ """
+ )
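+ # the chat engine built below applies this prompt on every turn, rewriting
+ # chat history + the follow-up into one standalone question that is then
+ # sent to the query engine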
+
+ # list of `ChatMessage` objects
+ custom_chat_history = [
+     ChatMessage(
+         role=MessageRole.USER,
+         content="Hello assistant.",
+     ),
+     ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."),
+ ]
+
+ chat_engine = CondenseQuestionChatEngine.from_defaults(
+     query_engine=query_engine,
+     condense_question_prompt=custom_prompt,
+     chat_history=custom_chat_history,
+     verbose=True,
+     memory=memory
+ )
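+ # optional smoke test, e.g.:
+ #print(chat_engine.chat("What happens if the distributor wants its own warehouse for pizzahood?"))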
+
+ # gradio with streaming support
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(label="⏎ for sending",
+                      placeholder="Ask me something",)
+     clear = gr.Button("Clear")
+
+     def user(user_message, history):
+         # append the user turn with an empty assistant slot
+         return "", history + [[user_message, None]]
+
+     def bot(history):
+         user_message = history[-1][0]
+         #bot_message = chat_engine.chat(user_message)
+         bot_message = query_engine.query(user_message + " Let's think step by step to get the correct answer. If you cannot provide an answer, say you don't know.")
+         history[-1][1] = ""
+         for character in bot_message.response:
+             history[-1][1] += character
+             time.sleep(0.01)
+             yield history
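+ # this replays the completed response character by character; for true token
+ # streaming one could instead use a query engine built with streaming=True
+ # (e.g. index.as_query_engine(streaming=True)) and yield deltas as they arrive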
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+ demo.queue()
+ demo.launch(share=False)