DHEIVER committed on
Commit 80f4c28 · verified · 1 Parent(s): 76e6194

Update app.py

Files changed (1):
  1. app.py +256 -249
app.py CHANGED
@@ -1,21 +1,30 @@
  import gradio as gr
  import os
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFacePipeline, HuggingFaceEndpoint
  from langchain.memory import ConversationBufferMemory
  from pathlib import Path
  import chromadb
  from unidecode import unidecode
  from transformers import AutoTokenizer
  import transformers
  import torch
  import re

- # Lista de modelos LLM disponíveis
  list_llm = [
      "mistralai/Mistral-7B-Instruct-v0.2",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -29,350 +38,348 @@ list_llm = [
      "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      "mosaicml/mpt-7b-instruct",
      "tiiuae/falcon-7b-instruct",
-     "google/flan-t5-xxl"
- ]
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]

- # Funções principais (mantidas as mesmas)
  def load_doc(list_file_path, chunk_size, chunk_overlap):
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
      for loader in loaders:
          pages.extend(loader.load())
      text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap)
-     return text_splitter.split_documents(pages)

  def create_db(splits, collection_name):
      embedding = HuggingFaceEmbeddings()
      new_client = chromadb.EphemeralClient()
-     return Chroma.from_documents(
          documents=splits,
          embedding=embedding,
          client=new_client,
-         collection_name=collection_name
      )

  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     progress(0.1, desc="Inicializando tokenizer...")
-     progress(0.5, desc="Configurando modelo...")

      if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
          llm = HuggingFaceEndpoint(
              repo_id=llm_model,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             top_k=top_k,
-             load_in_8bit=True,
          )
-     # ... (restante das condições para outros modelos)
-
-     progress(0.75, desc="Configurando memória...")
      memory = ConversationBufferMemory(
          memory_key="chat_history",
          output_key='answer',
          return_messages=True
      )
-
-     progress(0.8, desc="Configurando cadeia de recuperação...")
-     retriever = vector_db.as_retriever()
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
          chain_type="stuff",
          memory=memory,
          return_source_documents=True,
          verbose=False,
      )
-     progress(0.9, desc="Concluído!")
      return qa_chain

  def create_collection_name(filepath):
      collection_name = Path(filepath).stem
      collection_name = collection_name.replace(" ","-")
      collection_name = unidecode(collection_name)
      collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
      collection_name = collection_name[:50]
      if len(collection_name) < 3:
          collection_name = collection_name + 'xyz'
      if not collection_name[0].isalnum():
          collection_name = 'A' + collection_name[1:]
      if not collection_name[-1].isalnum():
          collection_name = collection_name[:-1] + 'Z'
      return collection_name

  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
      list_file_path = [x.name for x in list_file_obj if x is not None]
-     progress(0.1, desc="Criando nome da coleção...")
      collection_name = create_collection_name(list_file_path[0])
-     progress(0.25, desc="Carregando documento...")
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-     progress(0.5, desc="Gerando banco de dados vetorial...")
      vector_db = create_db(doc_splits, collection_name)
-     progress(0.9, desc="Concluído!")
-     return vector_db, collection_name, "Completo!"

  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
      llm_name = list_llm[llm_option]
      qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-     return qa_chain, "Completo!"

  def format_chat_history(message, chat_history):
      formatted_chat_history = []
      for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"Usuário: {user_message}")
-         formatted_chat_history.append(f"Assistente: {bot_message}")
      return formatted_chat_history

  def conversation(qa_chain, message, history):
      formatted_chat_history = format_chat_history(message, history)
      response = qa_chain({"question": message, "chat_history": formatted_chat_history})
      response_answer = response["answer"]
-     if response_answer.find("Resposta útil:") != -1:
-         response_answer = response_answer.split("Resposta útil:")[-1]
      response_sources = response["source_documents"]
      response_source1 = response_sources[0].page_content.strip()
      response_source2 = response_sources[1].page_content.strip()
      response_source3 = response_sources[2].page_content.strip()
      response_source1_page = response_sources[0].metadata["page"] + 1
      response_source2_page = response_sources[1].metadata["page"] + 1
      response_source3_page = response_sources[2].metadata["page"] + 1
      new_history = history + [(message, response_answer)]
      return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

- # Interface Gradio em português
  def demo():
-     css = """
-     .gradio-container {max-width: 1200px !important}
-     .message.user {background: #e3f2fd; padding: 10px; border-radius: 5px;}
-     .message.bot {background: #f5f5f5; padding: 10px; border-radius: 5px;}
-     """
-
-     with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
          vector_db = gr.State()
          qa_chain = gr.State()
          collection_name = gr.State()

-         gr.Markdown("""
-         <center><h1>🤖 Assistente de Documentos PDF</h1></center>
-         <h3>Faça perguntas sobre seus documentos PDF e obtenha respostas inteligentes</h3>
          """)

-         with gr.Tabs():
-             with gr.Tab("📄 1. Carregar PDF", id=1):
-                 gr.Markdown("### Carregue seus documentos PDF para análise")
                  with gr.Row():
-                     document = gr.Files(
-                         height=100,
-                         file_count="multiple",
-                         file_types=["pdf"],
-                         interactive=True,
-                         label="Arraste e solte seus PDFs aqui"
-                     )

-             with gr.Tab("⚙️ 2. Processar Documento", id=2):
-                 gr.Markdown("### Configure o processamento do documento")
                  with gr.Row():
-                     db_btn = gr.Radio(
-                         ["ChromaDB"],
-                         label="Tipo de banco de dados vetorial",
-                         value="ChromaDB",
-                         type="index",
-                         info="Escolha o banco de dados para armazenar os vetores"
-                     )
-
-                 with gr.Accordion("⚙️ Opções avançadas - Divisão de texto", open=False):
-                     gr.Markdown("Ajuste como o texto será dividido para análise:")
-                     with gr.Row():
-                         slider_chunk_size = gr.Slider(
-                             minimum=100,
-                             maximum=1000,
-                             value=600,
-                             step=20,
-                             label="Tamanho do bloco",
-                             info="Quantidade de caracteres por bloco"
-                         )
-                     with gr.Row():
-                         slider_chunk_overlap = gr.Slider(
-                             minimum=10,
-                             maximum=200,
-                             value=40,
-                             step=10,
-                             label="Sobreposição de blocos",
-                             info="Quantidade de caracteres sobrepostos entre blocos"
-                         )
-
                  with gr.Row():
-                     db_progress = gr.Textbox(
-                         label="Progresso da inicialização",
-                         value="Aguardando processamento...",
-                         interactive=False
-                     )
-
                  with gr.Row():
-                     db_btn = gr.Button(
-                         "Processar Documento",
-                         variant="primary"
-                     )
-
-             with gr.Tab("🧠 3. Configurar IA", id=3):
-                 gr.Markdown("### Escolha e configure o modelo de linguagem")
                  with gr.Row():
-                     llm_btn = gr.Radio(
-                         list_llm_simple,
-                         label="Modelos disponíveis",
-                         value=list_llm_simple[0],
-                         type="index",
-                         info="Selecione o modelo de linguagem"
-                     )
-
-                 with gr.Accordion("⚙️ Opções avançadas - Configurações do modelo", open=False):
-                     gr.Markdown("Ajuste os parâmetros do modelo de linguagem:")
-                     with gr.Row():
-                         slider_temperature = gr.Slider(
-                             minimum=0.01,
-                             maximum=1.0,
-                             value=0.7,
-                             step=0.1,
-                             label="Temperatura",
-                             info="Controla a criatividade das respostas"
-                         )
-                     with gr.Row():
-                         slider_maxtokens = gr.Slider(
-                             minimum=224,
-                             maximum=4096,
-                             value=1024,
-                             step=32,
-                             label="Máximo de tokens",
-                             info="Limite de tamanho das respostas"
-                         )
-                     with gr.Row():
-                         slider_topk = gr.Slider(
-                             minimum=1,
-                             maximum=10,
-                             value=3,
-                             step=1,
-                             label="Amostras top-k",
-                             info="Número de opções consideradas"
-                         )
-
                  with gr.Row():
-                     llm_progress = gr.Textbox(
-                         value="Aguardando configuração...",
-                         label="Status da IA",
-                         interactive=False
-                     )
-
                  with gr.Row():
-                     qachain_btn = gr.Button(
-                         "Inicializar Assistente",
-                         variant="primary"
-                     )

-             with gr.Tab("💬 4. Conversar", id=4):
-                 gr.Markdown("### Converse com o assistente sobre o documento")
-                 chatbot = gr.Chatbot(
-                     height=400,
-                     label="Histórico da Conversa",
-                     bubble_full_width=False
-                 )
-
-                 with gr.Accordion("🔍 Referências do documento", open=False):
-                     gr.Markdown("Trechos do documento usados para gerar as respostas:")
-                     with gr.Row():
-                         doc_source1 = gr.Textbox(
-                             label="Referência 1",
-                             lines=2,
-                             container=True,
-                             scale=20
-                         )
-                         source1_page = gr.Number(
-                             label="Página",
-                             scale=1
-                         )
-                     with gr.Row():
-                         doc_source2 = gr.Textbox(
-                             label="Referência 2",
-                             lines=2,
-                             container=True,
-                             scale=20
-                         )
-                         source2_page = gr.Number(
-                             label="Página",
-                             scale=1
-                         )
-                     with gr.Row():
-                         doc_source3 = gr.Textbox(
-                             label="Referência 3",
-                             lines=2,
-                             container=True,
-                             scale=20
-                         )
-                         source3_page = gr.Number(
-                             label="Página",
-                             scale=1
-                         )
-
-                 with gr.Row():
-                     msg = gr.Textbox(
-                         placeholder="Digite sua mensagem...",
-                         container=True,
-                         scale=4
-                     )
-                     submit_btn = gr.Button(
-                         "Enviar",
-                         variant="primary",
-                         scale=1
-                     )
-
-                 with gr.Row():
-                     clear_btn = gr.ClearButton(
-                         [msg, chatbot],
-                         value="Limpar Conversa",
-                         variant="secondary"
-                     )
-
-         # Conexões de eventos
-         db_btn.click(
-             initialize_database,
-             inputs=[document, slider_chunk_size, slider_chunk_overlap],
-             outputs=[vector_db, collection_name, db_progress]
-         )
-
-         qachain_btn.click(
-             initialize_LLM,
-             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
-             outputs=[qa_chain, llm_progress]
-         ).then(
-             lambda: [None, "", 0, "", 0, "", 0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
-         )
-
-         msg.submit(
-             conversation,
-             inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
-         )
-
-         submit_btn.click(
-             conversation,
-             inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
-         )
-
-         clear_btn.click(
-             lambda: [None, "", 0, "", 0, "", 0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
-         )
-
      demo.queue().launch(debug=True)

  if __name__ == "__main__":
      demo()

  import gradio as gr
  import os
+
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.vectorstores import Chroma
  from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.chains import ConversationChain
  from langchain.memory import ConversationBufferMemory
+ from langchain_community.llms import HuggingFaceEndpoint
+
  from pathlib import Path
  import chromadb
  from unidecode import unidecode
+
  from transformers import AutoTokenizer
  import transformers
  import torch
+ import tqdm
+ import accelerate
  import re

+
+
+ # default_persist_directory = './chroma_HF/'
  list_llm = [
      "mistralai/Mistral-7B-Instruct-v0.2",
      "mistralai/Mixtral-8x7B-Instruct-v0.1",

      "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      "mosaicml/mpt-7b-instruct",
      "tiiuae/falcon-7b-instruct",
+     "google/flan-t5-xxl"]
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]

+ # Load PDF document and create doc splits
  def load_doc(list_file_path, chunk_size, chunk_overlap):
+     # Processing for one document only
+     # loader = PyPDFLoader(file_path)
+     # pages = loader.load()
      loaders = [PyPDFLoader(x) for x in list_file_path]
      pages = []
      for loader in loaders:
          pages.extend(loader.load())
+     # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
      text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size = chunk_size,
+         chunk_overlap = chunk_overlap)
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
+

+ # Create vector database
  def create_db(splits, collection_name):
      embedding = HuggingFaceEmbeddings()
      new_client = chromadb.EphemeralClient()
+     vectordb = Chroma.from_documents(
          documents=splits,
          embedding=embedding,
          client=new_client,
+         collection_name=collection_name,
+         # persist_directory=default_persist_directory
      )
+     return vectordb

+
+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         # persist_directory=default_persist_directory,
+         embedding_function=embedding)
+     return vectordb
+
+
+ # Initialize langchain LLM chain
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     progress(0.1, desc="Initializing HF tokenizer...")
+     # HuggingFacePipeline uses local model
+     # Note: it will download model locally...
+     # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+     # progress(0.5, desc="Initializing HF pipeline...")
+     # pipeline=transformers.pipeline(
+     #     "text-generation",
+     #     model=llm_model,
+     #     tokenizer=tokenizer,
+     #     torch_dtype=torch.bfloat16,
+     #     trust_remote_code=True,
+     #     device_map="auto",
+     #     # max_length=1024,
+     #     max_new_tokens=max_tokens,
+     #     do_sample=True,
+     #     top_k=top_k,
+     #     num_return_sequences=1,
+     #     eos_token_id=tokenizer.eos_token_id
+     # )
+     # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})

+     # HuggingFaceHub uses HF inference endpoints
+     progress(0.5, desc="Initializing HF Hub...")
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
      if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
          llm = HuggingFaceEndpoint(
              repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+             load_in_8bit = True,
          )
+     elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
+         raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+     elif llm_model == "microsoft/phi-2":
+         # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+             trust_remote_code = True,
+             torch_dtype = "auto",
+         )
+     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = 250,
+             top_k = top_k,
+         )
+     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature = temperature,
+             max_new_tokens = max_tokens,
+             top_k = top_k,
+         )
+
+     progress(0.75, desc="Defining buffer memory...")
      memory = ConversationBufferMemory(
          memory_key="chat_history",
          output_key='answer',
          return_messages=True
      )
+     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+     retriever=vector_db.as_retriever()
+     progress(0.8, desc="Defining retrieval chain...")
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
          chain_type="stuff",
          memory=memory,
+         # combine_docs_chain_kwargs={"prompt": your_prompt})
          return_source_documents=True,
+         # return_generated_question=False,
          verbose=False,
      )
+     progress(0.9, desc="Done!")
      return qa_chain

+
+ # Generate collection name for vector database
+ # - Use filepath as input, ensuring unicode text
  def create_collection_name(filepath):
+     # Extract filename without extension
      collection_name = Path(filepath).stem
+     # Fix potential issues from naming convention
+     ## Remove space
      collection_name = collection_name.replace(" ","-")
+     ## ASCII transliterations of Unicode text
      collection_name = unidecode(collection_name)
+     ## Remove special characters
+     # collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
      collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+     ## Limit length to 50 characters
      collection_name = collection_name[:50]
+     ## Minimum length of 3 characters
      if len(collection_name) < 3:
          collection_name = collection_name + 'xyz'
+     ## Enforce start and end as alphanumeric character
      if not collection_name[0].isalnum():
          collection_name = 'A' + collection_name[1:]
      if not collection_name[-1].isalnum():
          collection_name = collection_name[:-1] + 'Z'
+     print('Filepath: ', filepath)
+     print('Collection name: ', collection_name)
      return collection_name

+
+ # Initialize database
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+     # Create list of documents (when valid)
      list_file_path = [x.name for x in list_file_obj if x is not None]
+     # Create collection_name for vector database
+     progress(0.1, desc="Creating collection name...")
      collection_name = create_collection_name(list_file_path[0])
+     progress(0.25, desc="Loading document...")
+     # Load document and create splits
      doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     # Create or load vector database
+     progress(0.5, desc="Generating vector database...")
+     # global vector_db
      vector_db = create_db(doc_splits, collection_name)
+     progress(0.9, desc="Done!")
+     return vector_db, collection_name, "Complete!"
+

  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     # print("llm_option",llm_option)
      llm_name = list_llm[llm_option]
+     print("llm_name: ",llm_name)
      qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+     return qa_chain, "Complete!"
+

  def format_chat_history(message, chat_history):
      formatted_chat_history = []
      for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
      return formatted_chat_history
+

  def conversation(qa_chain, message, history):
      formatted_chat_history = format_chat_history(message, history)
+     # print("formatted_chat_history",formatted_chat_history)
+
+     # Generate response using QA chain
      response = qa_chain({"question": message, "chat_history": formatted_chat_history})
      response_answer = response["answer"]
+     if response_answer.find("Helpful Answer:") != -1:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
      response_sources = response["source_documents"]
      response_source1 = response_sources[0].page_content.strip()
      response_source2 = response_sources[1].page_content.strip()
      response_source3 = response_sources[2].page_content.strip()
+     # Langchain sources are zero-based
      response_source1_page = response_sources[0].metadata["page"] + 1
      response_source2_page = response_sources[1].metadata["page"] + 1
      response_source3_page = response_sources[2].metadata["page"] + 1
+     # print('chat response: ', response_answer)
+     # print('DB source', response_sources)
+
+     # Append user message and response to chat history
      new_history = history + [(message, response_answer)]
+     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
      return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+
+ def upload_file(file_obj):
+     list_file_path = []
+     for idx, file in enumerate(file_obj):
+         file_path = file_obj.name
+         list_file_path.append(file_path)
+     # print(file_path)
+     # initialize_database(file_path, progress)
+     return list_file_path
+

  def demo():
+     with gr.Blocks(theme="base") as demo:
          vector_db = gr.State()
          qa_chain = gr.State()
          collection_name = gr.State()

+         gr.Markdown(
+         """<center><h2>PDF-based chatbot</center></h2>
+         <h3>Ask any questions about your PDF documents</h3>""")
+         gr.Markdown(
+         """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
+         The user interface explicitely shows multiple steps to help understand the RAG workflow.
+         This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
+         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
          """)

+         with gr.Tab("Step 1 - Upload PDF"):
+             with gr.Row():
+                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+
+         with gr.Tab("Step 2 - Process document"):
+             with gr.Row():
+                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
+             with gr.Accordion("Advanced options - Document text splitter", open=False):
+                 with gr.Row():
+                     slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
                  with gr.Row():
+                     slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+             with gr.Row():
+                 db_progress = gr.Textbox(label="Vector database initialization", value="None")
+             with gr.Row():
+                 db_btn = gr.Button("Generate vector database")

+         with gr.Tab("Step 3 - Initialize QA chain"):
+             with gr.Row():
+                 llm_btn = gr.Radio(list_llm_simple, \
+                     label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
+             with gr.Accordion("Advanced options - LLM model", open=False):
                  with gr.Row():
+                     slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                  with gr.Row():
+                     slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                  with gr.Row():
+                     slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+             with gr.Row():
+                 llm_progress = gr.Textbox(value="None",label="QA chain initialization")
+             with gr.Row():
+                 qachain_btn = gr.Button("Initialize Question Answering chain")
+
+         with gr.Tab("Step 4 - Chatbot"):
+             chatbot = gr.Chatbot(height=300)
+             with gr.Accordion("Advanced - Document references", open=False):
                  with gr.Row():
+                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                     source1_page = gr.Number(label="Page", scale=1)
                  with gr.Row():
+                     doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                     source2_page = gr.Number(label="Page", scale=1)
                  with gr.Row():
+                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                     source3_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
+             with gr.Row():
+                 submit_btn = gr.Button("Submit message")
+                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

+         # Preprocessing events
+         # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+         db_btn.click(initialize_database, \
+             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+             outputs=[vector_db, collection_name, db_progress])
+         qachain_btn.click(initialize_LLM, \
+             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+             outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
+             inputs=None, \
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+
+         # Chatbot events
+         msg.submit(conversation, \
+             inputs=[qa_chain, msg, chatbot], \
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+         submit_btn.click(conversation, \
+             inputs=[qa_chain, msg, chatbot], \
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+         clear_btn.click(lambda:[None,"",0,"",0,"",0], \
+             inputs=None, \
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
      demo.queue().launch(debug=True)

+
  if __name__ == "__main__":
      demo()