DHEIVER commited on
Commit
f08873e
·
verified ·
1 Parent(s): 9694c2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -213
app.py CHANGED
@@ -1,287 +1,189 @@
1
  import gradio as gr
2
  import os
3
-
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain_community.vectorstores import Chroma
7
  from langchain.chains import ConversationalRetrievalChain
8
- from langchain_community.embeddings import HuggingFaceEmbeddings
9
- from langchain_community.llms import HuggingFacePipeline
10
- from langchain.chains import ConversationChain
11
- from langchain.memory import ConversationBufferMemory
12
  from langchain_community.llms import HuggingFaceEndpoint
13
-
14
  from pathlib import Path
15
  import chromadb
16
  from unidecode import unidecode
17
-
18
- from transformers import AutoTokenizer
19
- import transformers
20
- import torch
21
- import tqdm
22
- import accelerate
23
  import re
24
 
25
-
26
-
27
- # default_persist_directory = './chroma_HF/'
28
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
29
- "google/gemma-7b-it","google/gemma-2b-it", \
30
- "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
31
- "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
32
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
 
 
 
 
 
 
33
  "google/flan-t5-xxl"
34
  ]
35
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
36
 
37
- # Load PDF document and create doc splits
38
  def load_doc(list_file_path, chunk_size, chunk_overlap):
39
- # Processing for one document only
40
- # loader = PyPDFLoader(file_path)
41
- # pages = loader.load()
42
  loaders = [PyPDFLoader(x) for x in list_file_path]
43
  pages = []
44
  for loader in loaders:
45
  pages.extend(loader.load())
46
- # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
47
  text_splitter = RecursiveCharacterTextSplitter(
48
- chunk_size = chunk_size,
49
- chunk_overlap = chunk_overlap)
 
50
  doc_splits = text_splitter.split_documents(pages)
51
  return doc_splits
52
 
53
-
54
- # Create vector database
55
  def create_db(splits, collection_name):
56
  embedding = HuggingFaceEmbeddings()
57
- new_client = chromadb.EphemeralClient()
 
58
  vectordb = Chroma.from_documents(
59
  documents=splits,
60
  embedding=embedding,
61
  client=new_client,
62
  collection_name=collection_name,
63
- # persist_directory=default_persist_directory
64
  )
65
  return vectordb
66
 
67
-
68
- # Load vector database
69
- def load_db():
70
- embedding = HuggingFaceEmbeddings()
71
- vectordb = Chroma(
72
- # persist_directory=default_persist_directory,
73
- embedding_function=embedding)
74
- return vectordb
75
-
76
-
77
- # Initialize langchain LLM chain
78
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
79
- progress(0.1, desc="Initializing HF tokenizer...")
80
- # HuggingFacePipeline uses local model
81
- # Note: it will download model locally...
82
- # tokenizer=AutoTokenizer.from_pretrained(llm_model)
83
- # progress(0.5, desc="Initializing HF pipeline...")
84
- # pipeline=transformers.pipeline(
85
- # "text-generation",
86
- # model=llm_model,
87
- # tokenizer=tokenizer,
88
- # torch_dtype=torch.bfloat16,
89
- # trust_remote_code=True,
90
- # device_map="auto",
91
- # # max_length=1024,
92
- # max_new_tokens=max_tokens,
93
- # do_sample=True,
94
- # top_k=top_k,
95
- # num_return_sequences=1,
96
- # eos_token_id=tokenizer.eos_token_id
97
- # )
98
- # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
99
-
100
- # HuggingFaceHub uses HF inference endpoints
101
- progress(0.5, desc="Initializing HF Hub...")
102
- # Use of trust_remote_code as model_kwargs
103
- # Warning: langchain issue
104
- # URL: https://github.com/langchain-ai/langchain/issues/6080
105
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
106
  llm = HuggingFaceEndpoint(
107
- repo_id=llm_model,
108
- # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
109
- temperature = temperature,
110
- max_new_tokens = max_tokens,
111
- top_k = top_k,
112
- load_in_8bit = True,
113
- )
114
- elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
115
- raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
116
- llm = HuggingFaceEndpoint(
117
- repo_id=llm_model,
118
- temperature = temperature,
119
- max_new_tokens = max_tokens,
120
- top_k = top_k,
121
  )
 
 
122
  elif llm_model == "microsoft/phi-2":
123
- # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
124
  llm = HuggingFaceEndpoint(
125
- repo_id=llm_model,
126
- # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
127
- temperature = temperature,
128
- max_new_tokens = max_tokens,
129
- top_k = top_k,
130
- trust_remote_code = True,
131
- torch_dtype = "auto",
132
  )
133
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
134
  llm = HuggingFaceEndpoint(
135
- repo_id=llm_model,
136
- # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
137
- temperature = temperature,
138
- max_new_tokens = 250,
139
- top_k = top_k,
140
  )
141
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
142
- raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
143
- llm = HuggingFaceEndpoint(
144
- repo_id=llm_model,
145
- # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
146
- temperature = temperature,
147
- max_new_tokens = max_tokens,
148
- top_k = top_k,
149
- )
150
  else:
151
  llm = HuggingFaceEndpoint(
152
- repo_id=llm_model,
153
- # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
154
- # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
155
- temperature = temperature,
156
- max_new_tokens = max_tokens,
157
- top_k = top_k,
158
  )
159
-
160
- progress(0.75, desc="Defining buffer memory...")
161
  memory = ConversationBufferMemory(
162
  memory_key="chat_history",
163
  output_key='answer',
164
  return_messages=True
165
  )
166
- # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
167
- retriever=vector_db.as_retriever()
168
- progress(0.8, desc="Defining retrieval chain...")
169
  qa_chain = ConversationalRetrievalChain.from_llm(
170
  llm,
171
  retriever=retriever,
172
- chain_type="stuff",
173
  memory=memory,
174
- # combine_docs_chain_kwargs={"prompt": your_prompt})
175
  return_source_documents=True,
176
- #return_generated_question=False,
177
  verbose=False,
178
  )
179
- progress(0.9, desc="Done!")
180
  return qa_chain
181
 
182
-
183
- # Generate collection name for vector database
184
- # - Use filepath as input, ensuring unicode text
185
  def create_collection_name(filepath):
186
- # Extract filename without extension
187
  collection_name = Path(filepath).stem
188
- # Fix potential issues from naming convention
189
- ## Remove space
190
- collection_name = collection_name.replace(" ","-")
191
- ## ASCII transliterations of Unicode text
192
  collection_name = unidecode(collection_name)
193
- ## Remove special characters
194
- #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
195
  collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
196
- ## Limit length to 50 characters
197
  collection_name = collection_name[:50]
198
- ## Minimum length of 3 characters
199
  if len(collection_name) < 3:
200
  collection_name = collection_name + 'xyz'
201
- ## Enforce start and end as alphanumeric character
202
  if not collection_name[0].isalnum():
203
  collection_name = 'A' + collection_name[1:]
204
  if not collection_name[-1].isalnum():
205
  collection_name = collection_name[:-1] + 'Z'
206
- print('Filepath: ', filepath)
207
- print('Collection name: ', collection_name)
208
  return collection_name
209
 
210
-
211
- # Initialize database
212
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
213
- # Create list of documents (when valid)
214
  list_file_path = [x.name for x in list_file_obj if x is not None]
215
- # Create collection_name for vector database
216
- progress(0.1, desc="Creating collection name...")
217
  collection_name = create_collection_name(list_file_path[0])
218
- progress(0.25, desc="Loading document...")
219
- # Load document and create splits
220
  doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
221
- # Create or load vector database
222
- progress(0.5, desc="Generating vector database...")
223
- # global vector_db
224
  vector_db = create_db(doc_splits, collection_name)
225
- progress(0.9, desc="Done!")
226
- return vector_db, collection_name, "Complete!"
227
-
228
 
 
229
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
230
- # print("llm_option",llm_option)
231
  llm_name = list_llm[llm_option]
232
- print("llm_name: ",llm_name)
233
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
234
- return qa_chain, "Complete!"
235
-
236
 
 
237
  def format_chat_history(message, chat_history):
238
  formatted_chat_history = []
239
  for user_message, bot_message in chat_history:
240
- formatted_chat_history.append(f"User: {user_message}")
241
- formatted_chat_history.append(f"Assistant: {bot_message}")
242
  return formatted_chat_history
243
-
244
 
 
245
  def conversation(qa_chain, message, history):
246
  formatted_chat_history = format_chat_history(message, history)
247
- #print("formatted_chat_history",formatted_chat_history)
248
-
249
- # Generate response using QA chain
250
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
251
  response_answer = response["answer"]
252
- if response_answer.find("Helpful Answer:") != -1:
253
- response_answer = response_answer.split("Helpful Answer:")[-1]
254
  response_sources = response["source_documents"]
255
  response_source1 = response_sources[0].page_content.strip()
256
  response_source2 = response_sources[1].page_content.strip()
257
  response_source3 = response_sources[2].page_content.strip()
258
- # Langchain sources are zero-based
259
  response_source1_page = response_sources[0].metadata["page"] + 1
260
  response_source2_page = response_sources[1].metadata["page"] + 1
261
  response_source3_page = response_sources[2].metadata["page"] + 1
262
- # print ('chat response: ', response_answer)
263
- # print('DB source', response_sources)
264
-
265
- # Append user message and response to chat history
266
  new_history = history + [(message, response_answer)]
267
- # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
268
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
269
-
270
 
 
271
  def upload_file(file_obj):
272
  list_file_path = []
273
  for idx, file in enumerate(file_obj):
274
  file_path = file_obj.name
275
  list_file_path.append(file_path)
276
- # print(file_path)
277
- # initialize_database(file_path, progress)
278
  return list_file_path
279
 
280
-
281
- import gradio as gr
282
-
283
- import gradio as gr
284
-
285
  def demo():
286
  with gr.Blocks(theme="base") as demo:
287
  vector_db = gr.State()
@@ -289,82 +191,85 @@ def demo():
289
  collection_name = gr.State()
290
 
291
  gr.Markdown(
292
- """<center><h2>PDF-based chatbot</center></h2>
293
- <h3>Ask any questions about your PDF documents</h3>""")
294
  gr.Markdown(
295
- """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
296
- The user interface explicitly shows multiple steps to help understand the RAG workflow.
297
- This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
298
- <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
299
  """)
300
 
301
- with gr.Tab("Step 1 - Upload PDF"):
302
  with gr.Row():
303
- document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
 
304
 
305
- with gr.Tab("Step 2 - Process document"):
306
  with gr.Row():
307
- db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
308
- with gr.Accordion("Advanced options - Document text splitter", open=False):
309
  with gr.Row():
310
- slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
311
  with gr.Row():
312
- slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
313
  with gr.Row():
314
- db_progress = gr.Textbox(label="Vector database initialization", value="None")
315
  with gr.Row():
316
- db_btn = gr.Button("Generate vector database")
317
 
318
- with gr.Tab("Step 3 - Initialize QA chain"):
319
  with gr.Row():
320
- llm_btn = gr.Radio(["LLM1", "LLM2"], label="LLM models", value = "LLM1", type="index", info="Choose your LLM model")
321
- with gr.Accordion("Advanced options - LLM model", open=False):
 
322
  with gr.Row():
323
- slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
324
  with gr.Row():
325
- slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
326
  with gr.Row():
327
- slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
328
  with gr.Row():
329
- llm_progress = gr.Textbox(value="None",label="QA chain initialization")
330
  with gr.Row():
331
- qachain_btn = gr.Button("Initialize Question Answering chain")
332
 
333
- with gr.Tab("Step 4 - Chatbot"):
334
  chatbot = gr.Chatbot(height=300)
335
- with gr.Accordion("Advanced - Document references", open=False):
336
  with gr.Row():
337
- doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
338
- source1_page = gr.Number(label="Page", scale=1)
339
  with gr.Row():
340
- doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
341
- source2_page = gr.Number(label="Page", scale=1)
342
  with gr.Row():
343
- doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
344
- source3_page = gr.Number(label="Page", scale=1)
345
  with gr.Row():
346
- msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
347
  with gr.Row():
348
- submit_btn = gr.Button("Submit message")
349
- clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
350
 
351
- # Preprocessing events
352
- db_btn.click(lambda: ("Vector DB Initialized", "Collection Name", "DB Progress"), \
 
353
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
354
  outputs=[vector_db, collection_name, db_progress])
355
- qachain_btn.click(lambda: ("QA Chain Initialized", "LLM Progress"), \
356
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
357
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
358
  inputs=None, \
359
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
360
  queue=False)
361
 
362
- # Chatbot events
363
- msg.submit(lambda: ("QA Chain", "Message", "Chatbot", "Doc Source 1", 1, "Doc Source 2", 2, "Doc Source 3", 3), \
364
  inputs=[qa_chain, msg, chatbot], \
365
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
366
  queue=False)
367
- submit_btn.click(lambda: ("QA Chain", "Message", "Chatbot", "Doc Source 1", 1, "Doc Source 2", 2, "Doc Source 3", 3), \
368
  inputs=[qa_chain, msg, chatbot], \
369
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
370
  queue=False)
@@ -374,5 +279,6 @@ def demo():
374
  queue=False)
375
  demo.queue().launch(debug=True)
376
 
 
377
  if __name__ == "__main__":
378
- demo()
 
1
  import gradio as gr
2
  import os
 
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import Chroma
6
  from langchain.chains import ConversationalRetrievalChain
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
 
8
  from langchain_community.llms import HuggingFaceEndpoint
9
+ from langchain.memory import ConversationBufferMemory
10
  from pathlib import Path
11
  import chromadb
12
  from unidecode import unidecode
 
 
 
 
 
 
13
  import re
14
 
15
+ # Lista de modelos LLM disponíveis
16
+ list_llm = [
17
+ "mistralai/Mistral-7B-Instruct-v0.2",
18
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
19
+ "mistralai/Mistral-7B-Instruct-v0.1",
20
+ "google/gemma-7b-it",
21
+ "google/gemma-2b-it",
22
+ "HuggingFaceH4/zephyr-7b-beta",
23
+ "HuggingFaceH4/zephyr-7b-gemma-v0.1",
24
+ "meta-llama/Llama-2-7b-chat-hf",
25
+ "microsoft/phi-2",
26
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
27
+ "mosaicml/mpt-7b-instruct",
28
+ "tiiuae/falcon-7b-instruct",
29
  "google/flan-t5-xxl"
30
  ]
31
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
32
 
33
+ # Função para carregar documentos PDF e dividir em chunks
34
  def load_doc(list_file_path, chunk_size, chunk_overlap):
 
 
 
35
  loaders = [PyPDFLoader(x) for x in list_file_path]
36
  pages = []
37
  for loader in loaders:
38
  pages.extend(loader.load())
 
39
  text_splitter = RecursiveCharacterTextSplitter(
40
+ chunk_size=chunk_size,
41
+ chunk_overlap=chunk_overlap
42
+ )
43
  doc_splits = text_splitter.split_documents(pages)
44
  return doc_splits
45
 
46
+ # Função para criar o banco de dados vetorial
 
47
  def create_db(splits, collection_name):
48
  embedding = HuggingFaceEmbeddings()
49
+ # Usando PersistentClient para persistir o banco de dados
50
+ new_client = chromadb.PersistentClient(path="./chroma_db")
51
  vectordb = Chroma.from_documents(
52
  documents=splits,
53
  embedding=embedding,
54
  client=new_client,
55
  collection_name=collection_name,
 
56
  )
57
  return vectordb
58
 
59
+ # Função para inicializar a cadeia de QA com o modelo LLM
 
 
 
 
 
 
 
 
 
 
60
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
61
+ progress(0.1, desc="Inicializando tokenizer da HF...")
62
+ progress(0.5, desc="Inicializando Hub da HF...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
64
  llm = HuggingFaceEndpoint(
65
+ repo_id=llm_model,
66
+ temperature=temperature,
67
+ max_new_tokens=max_tokens,
68
+ top_k=top_k,
69
+ load_in_8bit=True,
 
 
 
 
 
 
 
 
 
70
  )
71
+ elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
72
+ raise gr.Error("O modelo LLM é muito grande para ser carregado automaticamente no endpoint de inferência gratuito")
73
  elif llm_model == "microsoft/phi-2":
 
74
  llm = HuggingFaceEndpoint(
75
+ repo_id=llm_model,
76
+ temperature=temperature,
77
+ max_new_tokens=max_tokens,
78
+ top_k=top_k,
79
+ trust_remote_code=True,
80
+ torch_dtype="auto",
 
81
  )
82
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
83
  llm = HuggingFaceEndpoint(
84
+ repo_id=llm_model,
85
+ temperature=temperature,
86
+ max_new_tokens=250,
87
+ top_k=top_k,
 
88
  )
89
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
90
+ raise gr.Error("O modelo Llama-2-7b-chat-hf requer uma assinatura Pro...")
 
 
 
 
 
 
 
91
  else:
92
  llm = HuggingFaceEndpoint(
93
+ repo_id=llm_model,
94
+ temperature=temperature,
95
+ max_new_tokens=max_tokens,
96
+ top_k=top_k,
 
 
97
  )
98
+
99
+ progress(0.75, desc="Definindo memória de buffer...")
100
  memory = ConversationBufferMemory(
101
  memory_key="chat_history",
102
  output_key='answer',
103
  return_messages=True
104
  )
105
+ retriever = vector_db.as_retriever()
106
+ progress(0.8, desc="Definindo cadeia de recuperação...")
 
107
  qa_chain = ConversationalRetrievalChain.from_llm(
108
  llm,
109
  retriever=retriever,
110
+ chain_type="stuff",
111
  memory=memory,
 
112
  return_source_documents=True,
 
113
  verbose=False,
114
  )
115
+ progress(0.9, desc="Concluído!")
116
  return qa_chain
117
 
118
+ # Função para gerar um nome de coleção válido
 
 
119
  def create_collection_name(filepath):
 
120
  collection_name = Path(filepath).stem
121
+ collection_name = collection_name.replace(" ", "-")
 
 
 
122
  collection_name = unidecode(collection_name)
 
 
123
  collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
 
124
  collection_name = collection_name[:50]
 
125
  if len(collection_name) < 3:
126
  collection_name = collection_name + 'xyz'
 
127
  if not collection_name[0].isalnum():
128
  collection_name = 'A' + collection_name[1:]
129
  if not collection_name[-1].isalnum():
130
  collection_name = collection_name[:-1] + 'Z'
131
+ print('Caminho do arquivo: ', filepath)
132
+ print('Nome da coleção: ', collection_name)
133
  return collection_name
134
 
135
+ # Função para inicializar o banco de dados
 
136
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
 
137
  list_file_path = [x.name for x in list_file_obj if x is not None]
138
+ progress(0.1, desc="Criando nome da coleção...")
 
139
  collection_name = create_collection_name(list_file_path[0])
140
+ progress(0.25, desc="Carregando documento...")
 
141
  doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
142
+ progress(0.5, desc="Gerando banco de dados vetorial...")
 
 
143
  vector_db = create_db(doc_splits, collection_name)
144
+ progress(0.9, desc="Concluído!")
145
+ return vector_db, collection_name, "Completo!"
 
146
 
147
+ # Função para inicializar o modelo LLM
148
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
149
  llm_name = list_llm[llm_option]
150
+ print("Nome do LLM: ", llm_name)
151
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
152
+ return qa_chain, "Completo!"
 
153
 
154
+ # Função para formatar o histórico de conversa
155
  def format_chat_history(message, chat_history):
156
  formatted_chat_history = []
157
  for user_message, bot_message in chat_history:
158
+ formatted_chat_history.append(f"Usuário: {user_message}")
159
+ formatted_chat_history.append(f"Assistente: {bot_message}")
160
  return formatted_chat_history
 
161
 
162
+ # Função para realizar a conversa com o chatbot
163
  def conversation(qa_chain, message, history):
164
  formatted_chat_history = format_chat_history(message, history)
 
 
 
165
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
166
  response_answer = response["answer"]
167
+ if response_answer.find("Resposta útil:") != -1:
168
+ response_answer = response_answer.split("Resposta útil:")[-1]
169
  response_sources = response["source_documents"]
170
  response_source1 = response_sources[0].page_content.strip()
171
  response_source2 = response_sources[1].page_content.strip()
172
  response_source3 = response_sources[2].page_content.strip()
 
173
  response_source1_page = response_sources[0].metadata["page"] + 1
174
  response_source2_page = response_sources[1].metadata["page"] + 1
175
  response_source3_page = response_sources[2].metadata["page"] + 1
 
 
 
 
176
  new_history = history + [(message, response_answer)]
 
177
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
178
 
179
+ # Função para carregar arquivos
180
  def upload_file(file_obj):
181
  list_file_path = []
182
  for idx, file in enumerate(file_obj):
183
  file_path = file_obj.name
184
  list_file_path.append(file_path)
 
 
185
  return list_file_path
186
 
 
 
 
 
 
187
  def demo():
188
  with gr.Blocks(theme="base") as demo:
189
  vector_db = gr.State()
 
191
  collection_name = gr.State()
192
 
193
  gr.Markdown(
194
+ """<center><h2>Chatbot baseado em PDF</center></h2>
195
+ <h3>Faça qualquer pergunta sobre seus documentos PDF</h3>""")
196
  gr.Markdown(
197
+ """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus documentos PDF. \
198
+ A interface do usuário mostra explicitamente várias etapas para ajudar a entender o fluxo de trabalho do RAG.
199
+ Este chatbot leva em consideração perguntas anteriores ao gerar respostas (via memória conversacional), e inclui referências documentais para maior clareza.<br>
200
+ <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
201
  """)
202
 
203
+ with gr.Tab("Etapa 1 - Carregar PDF"):
204
  with gr.Row():
205
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus documentos PDF (único ou múltiplos)")
206
+ # upload_btn = gr.UploadButton("Carregando documento...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
207
 
208
+ with gr.Tab("Etapa 2 - Processar documento"):
209
  with gr.Row():
210
+ db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
211
+ with gr.Accordion("Opções avançadas - Divisor de texto do documento", open=False):
212
  with gr.Row():
213
+ slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
214
  with gr.Row():
215
+ slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Sobreposição do bloco", info="Sobreposição do bloco", interactive=True)
216
  with gr.Row():
217
+ db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial", value="Nenhum")
218
  with gr.Row():
219
+ db_btn = gr.Button("Gerar banco de dados vetorial")
220
 
221
+ with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
222
  with gr.Row():
223
+ llm_btn = gr.Radio(list_llm_simple, \
224
+ label="Modelos LLM", value = list_llm_simple[0], type="index", info="Escolha seu modelo LLM")
225
+ with gr.Accordion("Opções avançadas - Modelo LLM", open=False):
226
  with gr.Row():
227
+ slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura do modelo", interactive=True)
228
  with gr.Row():
229
+ slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Máximo de Tokens", info="Máximo de tokens do modelo", interactive=True)
230
  with gr.Row():
231
+ slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="Amostras top-k", info="Amostras top-k do modelo", interactive=True)
232
  with gr.Row():
233
+ llm_progress = gr.Textbox(value="Nenhum",label="Inicialização da cadeia QA")
234
  with gr.Row():
235
+ qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
236
 
237
+ with gr.Tab("Etapa 4 - Chatbot"):
238
  chatbot = gr.Chatbot(height=300)
239
+ with gr.Accordion("Avançado - Referências do documento", open=False):
240
  with gr.Row():
241
+ doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
242
+ source1_page = gr.Number(label="Página", scale=1)
243
  with gr.Row():
244
+ doc_source2 = gr.Textbox(label="Referência 2", lines=2, container=True, scale=20)
245
+ source2_page = gr.Number(label="Página", scale=1)
246
  with gr.Row():
247
+ doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
248
+ source3_page = gr.Number(label="Página", scale=1)
249
  with gr.Row():
250
+ msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Sobre o que é este documento?')", container=True)
251
  with gr.Row():
252
+ submit_btn = gr.Button("Enviar mensagem")
253
+ clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
254
 
255
+ # Eventos de pré-processamento
256
+ #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
257
+ db_btn.click(initialize_database, \
258
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
259
  outputs=[vector_db, collection_name, db_progress])
260
+ qachain_btn.click(initialize_LLM, \
261
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
262
  outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
263
  inputs=None, \
264
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
265
  queue=False)
266
 
267
+ # Eventos do Chatbot
268
+ msg.submit(conversation, \
269
  inputs=[qa_chain, msg, chatbot], \
270
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
271
  queue=False)
272
+ submit_btn.click(conversation, \
273
  inputs=[qa_chain, msg, chatbot], \
274
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
275
  queue=False)
 
279
  queue=False)
280
  demo.queue().launch(debug=True)
281
 
282
+
283
  if __name__ == "__main__":
284
+ demo()