AamirAli123 committed on
Commit
0d7efba
·
verified ·
1 Parent(s): 1f7e67b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +245 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langchain_community.document_loaders import PyPDFLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain.chains import ConversationalRetrievalChain
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.llms import HuggingFacePipeline
10
+ from langchain.chains import ConversationChain
11
+ from langchain.memory import ConversationBufferMemory
12
+ from langchain.llms import HuggingFaceHub
13
+ from pathlib import Path
14
+ import chromadb
15
# Load variables from a .env file (if any) into the environment, then read the
# Hugging Face Hub API token; the value is None when the variable is not set.
load_dotenv()
huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
17
+
18
+ # default_persist_directory = './chroma_HF/'
19
# Hugging Face Hub repo ids of the models offered in the UI. The position in
# this list is what the "LLM models" radio button passes to the handlers.
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Display names: the part of each repo id that follows the last "/".
list_llm_simple = [repo_id.rpartition("/")[2] for repo_id in list_llm]
26
+
27
+ # Load PDF document and create doc splits
28
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into text chunks.

    Args:
        list_file_path: Iterable of PDF file paths; every path is loaded.
        chunk_size: Chunk size passed to RecursiveCharacterTextSplitter.
        chunk_overlap: Chunk overlap passed to RecursiveCharacterTextSplitter.

    Returns:
        The documents produced by splitting all loaded pages.
    """
    # Build every loader first, then gather the pages of all files in order.
    pdf_loaders = list(map(PyPDFLoader, list_file_path))
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
39
+
40
def load_doc_for_openai(list_file_path):
    """Load PDF files and split them with a fixed chunk size of 600 and overlap of 40.

    Args:
        list_file_path: Iterable of PDF file paths; every path is loaded.

    Returns:
        The documents produced by splitting all loaded pages.
    """
    # Same pipeline as load_doc, with the splitter settings pinned.
    return load_doc(list_file_path, 600, 40)
51
+
52
+ # Create vector database
53
def create_db(splits, collection_name):
    """Create a Chroma vector store from document splits.

    Args:
        splits: Documents to embed and index.
        collection_name: Name of the Chroma collection to create.

    Returns:
        The Chroma vector store built with default HuggingFaceEmbeddings and
        backed by a new chromadb ephemeral client.
    """
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )
64
+
65
+
66
+ # Load vector database
67
def load_db():
    """Return a Chroma vector store that uses default HuggingFaceEmbeddings."""
    return Chroma(embedding_function=HuggingFaceEmbeddings())
71
+
72
+
73
+ # Initialize langchain LLM chain
74
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain over ``vector_db`` for the chosen model.

    Args:
        llm_model: Hugging Face Hub repo id handed to HuggingFaceHub.
        temperature: Forwarded to the model as ``temperature``.
        max_tokens: Forwarded as ``max_new_tokens`` (TinyLlama always gets 250).
        top_k: Forwarded to the model as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` supplies the retriever.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        The configured ConversationalRetrievalChain.
    """
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    # Original author's note on trust_remote_code in model_kwargs:
    # https://github.com/langchain-ai/langchain/issues/6080
    generation_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    # Only two models deviate from the shared generation settings.
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        generation_kwargs["load_in_8bit"] = True
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        generation_kwargs["max_new_tokens"] = 250
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=generation_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )

    progress(0.8, desc="Defining retrieval chain...")
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return chain
117
+
118
+
119
+ # Initialize database
120
def _sanitize_collection_name(stem):
    """Derive a collection name from a file stem that starts and ends alphanumeric."""
    # Spaces become hyphens and the name is capped at 50 characters.
    name = stem.replace(" ", "-")[:50]
    # Strings are immutable, so the edge characters are replaced by slicing;
    # item assignment (name[0] = 'A') raises TypeError.
    if not name[0].isalnum():
        name = "A" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "Z"
    # NOTE(review): Chroma may enforce further naming rules (minimum length,
    # allowed characters) that are not handled here; confirm against its docs.
    return name


def initialize_database(list_file_obj, chunk_size, chunk_overlap, vector_db, progress=gr.Progress()):
    """Split the uploaded PDFs and index them in a new vector database.

    Args:
        list_file_obj: Uploaded file objects (each exposing ``.name`` as its
            path); ``None`` entries are ignored.
        chunk_size: Chunk size forwarded to ``load_doc``.
        chunk_overlap: Chunk overlap forwarded to ``load_doc``.
        vector_db: Previous vector database state; unused, it is replaced by
            the newly created one.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A tuple ``(vector_db, collection_name, status_message)``.

    Raises:
        gr.Error: If no file was uploaded.
    """
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    if not list_file_path:
        # Without this guard the first-file lookup below fails with IndexError.
        raise gr.Error("Please upload at least one PDF document first.")

    # The collection is named after the first uploaded file.
    progress(0.1, desc="Creating collection name...")
    collection_name = _sanitize_collection_name(Path(list_file_path[0]).stem)
    print('Collection name: ', collection_name)

    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

    progress(0.7, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"
146
+
147
+
148
def re_initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build a QA chain for the model selected by index in ``list_llm``.

    Args:
        llm_option: Index into ``list_llm``.
        llm_temperature: Temperature forwarded to ``initialize_llmchain``.
        max_tokens: Max tokens forwarded to ``initialize_llmchain``.
        top_k: Top-k forwarded to ``initialize_llmchain``.
        vector_db: Vector store forwarded to ``initialize_llmchain``.

    Returns:
        The chain returned by ``initialize_llmchain``.
    """
    selected_model = list_llm[llm_option]
    print("llm_name: ", selected_model)
    return initialize_llmchain(selected_model, llm_temperature, max_tokens, top_k, vector_db)
153
+
154
+
155
def format_chat_history(message, chat_history):
    """Flatten (user, bot) message pairs into labelled lines.

    Args:
        message: Current user message; not used by this function.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs.

    Returns:
        A list alternating ``"User: ..."`` and ``"Assistant: ..."`` strings,
        in conversation order.
    """
    labelled_lines = []
    for user_turn, bot_turn in chat_history:
        labelled_lines += [f"User: {user_turn}", f"Assistant: {bot_turn}"]
    return labelled_lines
161
+
162
+
163
def conversation(qa_chain, message, history, llm_option):
    """Ask the QA chain one question and append the exchange to the chat history.

    Args:
        qa_chain: Callable chain; invoked with the question and formatted history.
        message: The user's question.
        history: List of ``(user_message, bot_message)`` pairs shown so far.
        llm_option: Selected model index; not used by this function.

    Returns:
        A tuple ``(qa_chain, textbox_update, new_history)`` where the update
        clears the message box and ``new_history`` is a new list including
        this exchange.
    """
    chain_input = {
        "question": message,
        "chat_history": format_chat_history(message, history),
    }
    raw_answer = qa_chain(chain_input)["answer"]
    # Keep only the text after the last "Helpful Answer:" marker; when the
    # marker is absent, rpartition leaves the whole answer in the last slot.
    answer = raw_answer.rpartition("Helpful Answer:")[2]
    return qa_chain, gr.update(value=""), history + [(message, answer)]
172
+
173
+
174
def upload_file(file_obj):
    """Return the paths of the uploaded files.

    Args:
        file_obj: Iterable of uploaded file objects, each exposing ``.name``
            as its path; ``None`` is treated as no files.

    Returns:
        A list with the ``.name`` of every file, in upload order.
    """
    # Read the name from each file, not from the containing list: the list
    # itself has no ``.name`` attribute, so the old lookup raised AttributeError.
    return [file.name for file in (file_obj or [])]
181
+
182
+
183
def demo():
    """Build the Gradio Blocks UI for the PDF chatbot, wire its events, and launch it.

    NOTE(review): the nesting of the layout blocks below (and the placement of
    the final launch call) was reconstructed from a listing without
    indentation; confirm it against the running app.
    """
    with gr.Blocks(theme = "base") as demo:
        # State components passed between event handlers: the vector
        # database, the QA chain and the collection name.
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        # Page title.
        gr.Markdown(
'''
<div style="text-align:center;">
<span style="font-size:3em; font-weight:bold;">PDF Document Chatbot</span>
</div>
''')
        with gr.Row():
            with gr.Row():
                with gr.Column():
                    # NOTE(review): Gradio documents extension filters with a
                    # leading dot (".pdf"); confirm "pdf" filters as intended.
                    document = gr.Files(file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
                    with gr.Row():
                        # Hidden: ChromaDB is the only option.
                        db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database", visible = False)
                    # Hidden splitter settings; the slider defaults (600 / 40)
                    # are still sent to initialize_database.
                    with gr.Accordion("Advanced options - Document text splitter", open=False, visible = False):
                        with gr.Row():
                            slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True, visible = False)
                        with gr.Row():
                            slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True, visible = False)
                    # type="index" makes the handlers receive a position in list_llm.
                    llm_btn = gr.Radio(list_llm_simple, label = "LLM models", type = "index", info = "Choose your LLM model")
                    db_progress = gr.Textbox(label="Vector database initialization", value="None")
                    with gr.Row():
                        submit_file = gr.Button("Submit File")
            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot()
                    msg = gr.Textbox(placeholder = "Type Your Message")
                    with gr.Accordion("Advanced options - LLM model", open = False):
                        with gr.Row():
                            slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                        with gr.Row():
                            slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                        with gr.Row():
                            slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
                    with gr.Row():
                        submit_btn = gr.Button("Submit")
                    # clear_btn = gr.ClearButton([msg2, chatbot])
        # Preprocessing events
        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
        # "Submit File" builds the vector database from the uploaded PDFs.
        submit_file.click(initialize_database, \
            inputs=[document, slider_chunk_size, slider_chunk_overlap, vector_db], \
            outputs = [vector_db, collection_name, db_progress])
        # Picking a model (re)builds the QA chain from the current vector_db state.
        # NOTE(review): if no file has been submitted yet, vector_db holds no
        # database and initialize_llmchain will fail on as_retriever() — confirm.
        llm_btn.change(
            re_initialize_LLM, \
            inputs = [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
            outputs = [qa_chain]
        )
        # Pressing Enter in the textbox and clicking "Submit" run the same handler.
        msg.submit(conversation, \
            inputs=[qa_chain, msg, chatbot, llm_btn], \
            outputs=[qa_chain, msg, chatbot], \
            queue=False)
        submit_btn.click(conversation, \
            inputs=[qa_chain, msg, chatbot, llm_btn], \
            outputs=[qa_chain, msg, chatbot], \
            queue=False)
    # Enable the event queue, then start the server with a share link and debug mode.
    demo.queue().launch(share = True, debug = True)
242
+
243
+
244
# Build and launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentence-transformers
4
+ langchain
5
+ tqdm
6
+ accelerate
7
+ pypdf
8
+ chromadb
9
+ langchain-community
10
+ weasyprint
11
+ openai
12
+ tiktoken
13
+ pypdf
14
+ pdf2image
15
+ gradio