DHEIVER commited on
Commit
099bb87
·
verified ·
1 Parent(s): 17c064a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -72
app.py CHANGED
@@ -12,120 +12,273 @@ import chromadb
12
  from unidecode import unidecode
13
  import re
14
 
15
- # Modelos LLM disponíveis
16
  list_llm = [
17
- "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
18
- "google/gemma-7b-it", "google/gemma-2b-it", "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
19
- "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
20
- "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl"
 
 
 
 
 
 
 
 
 
21
  ]
22
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
23
 
24
- # Função de carregamento e divisão de documentos
25
- def load_and_split_documents(list_file_path, chunk_size, chunk_overlap):
26
  loaders = [PyPDFLoader(x) for x in list_file_path]
27
  pages = []
28
  for loader in loaders:
29
  pages.extend(loader.load())
30
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
31
- return text_splitter.split_documents(pages)
 
 
 
 
32
 
33
- # Função para criar banco de dados vetorial com ChromaDB
34
- def create_vector_db(splits, collection_name):
35
  embedding = HuggingFaceEmbeddings()
 
36
  new_client = chromadb.PersistentClient(path="./chroma_db")
37
- return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)
 
 
 
 
 
 
38
 
39
- # Função para inicializar a cadeia de QA
40
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
41
- progress(0.1, desc="Inicializando tokenizer e Hub...")
42
- llm = HuggingFaceEndpoint(
43
- repo_id=llm_model, temperature=temperature, max_new_tokens=max_tokens, top_k=top_k, load_in_8bit=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
45
- progress(0.5, desc="Definindo memória de buffer e cadeia de recuperação...")
46
- memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
47
  retriever = vector_db.as_retriever()
48
- qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True)
 
 
 
 
 
 
 
 
49
  progress(0.9, desc="Concluído!")
50
  return qa_chain
51
 
52
  # Função para gerar um nome de coleção válido
53
  def create_collection_name(filepath):
54
  collection_name = Path(filepath).stem
55
- collection_name = unidecode(collection_name.replace(" ", "-"))
56
- return re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Função para inicializar o banco de dados e o modelo LLM
59
- def initialize_database_and_llm(list_file_obj, chunk_size, chunk_overlap, llm_option, llm_temperature, max_tokens, top_k, progress=gr.Progress()):
60
  list_file_path = [x.name for x in list_file_obj if x is not None]
61
  progress(0.1, desc="Criando nome da coleção...")
62
  collection_name = create_collection_name(list_file_path[0])
63
- progress(0.25, desc="Carregando e dividindo documentos...")
64
- doc_splits = load_and_split_documents(list_file_path, chunk_size, chunk_overlap)
65
  progress(0.5, desc="Gerando banco de dados vetorial...")
66
- vector_db = create_vector_db(doc_splits, collection_name)
67
- progress(0.75, desc="Inicializando modelo LLM...")
 
 
 
 
68
  llm_name = list_llm[llm_option]
 
69
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
70
- progress(0.9, desc="Concluído!")
71
- return vector_db, collection_name, qa_chain
72
 
73
- # Função de interação com o chatbot
 
 
 
 
 
 
 
 
74
  def conversation(qa_chain, message, history):
75
- formatted_chat_history = [f"Usuário: {user_message}\nAssistente: {bot_message}" for user_message, bot_message in history]
76
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
77
- response_answer = response["answer"].split("Resposta útil:")[-1]
78
- response_sources = [doc.page_content.strip() for doc in response["source_documents"]]
79
- response_pages = [doc.metadata["page"] + 1 for doc in response["source_documents"]]
 
 
 
 
 
 
 
80
  new_history = history + [(message, response_answer)]
81
- return qa_chain, gr.update(value=""), new_history, *response_sources, *response_pages
82
 
83
- # Função de carregamento de arquivos
84
  def upload_file(file_obj):
85
- return [file_obj.name for file_obj in file_obj if file_obj is not None]
 
 
 
 
86
 
87
- # Interface Gradio
88
  def demo():
89
  with gr.Blocks(theme="base") as demo:
90
- vector_db, qa_chain, collection_name = gr.State(), gr.State(), gr.State()
91
- gr.Markdown("<center><h2>Chatbot baseado em PDF</center></h2><h3>Faça qualquer pergunta sobre seus documentos PDF</h3>")
92
-
 
 
 
 
 
 
 
 
 
 
 
93
  with gr.Tab("Etapa 1 - Carregar PDF"):
94
- document = gr.Files(height=100, file_count="multiple", file_types=["pdf"])
95
-
 
 
96
  with gr.Tab("Etapa 2 - Processar documento"):
97
- db_btn = gr.Button("Gerar banco de dados vetorial")
98
- slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Tamanho do bloco")
99
- slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Sobreposição do bloco")
100
- db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial")
101
-
 
 
 
 
 
 
 
102
  with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
103
- llm_btn = gr.Radio(list_llm_simple, label="Modelos LLM")
104
- slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperatura")
105
- slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Máximo de Tokens")
106
- slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Amostras top-k")
107
- llm_progress = gr.Textbox(value="Nenhum", label="Inicialização da cadeia QA")
108
- qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
 
 
 
 
 
 
 
 
109
 
110
  with gr.Tab("Etapa 4 - Chatbot"):
111
  chatbot = gr.Chatbot(height=300)
112
- doc_source1, doc_source2, doc_source3 = gr.Textbox(label="Referência 1"), gr.Textbox(label="Referência 2"), gr.Textbox(label="Referência 3")
113
- source1_page, source2_page, source3_page = gr.Number(label="Página 1"), gr.Number(label="Página 2"), gr.Number(label="Página 3")
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # Campo de texto para enviar mensagens
116
- user_input = gr.Textbox(label="Sua mensagem")
117
-
118
- # Implementação de lógica de interação de conversa
119
- def send_message(message, history, qa_chain):
120
- formatted_chat_history = [f"Usuário: {user_message}\nAssistente: {bot_message}" for user_message, bot_message in history]
121
- response = qa_chain({"question": message, "chat_history": formatted_chat_history})
122
- response_answer = response["answer"].split("Resposta útil:")[-1]
123
- response_sources = [doc.page_content.strip() for doc in response["source_documents"]]
124
- response_pages = [doc.metadata["page"] + 1 for doc in response["source_documents"]]
125
- new_history = history + [(message, response_answer)]
126
- return qa_chain, gr.update(value=""), new_history, *response_sources, *response_pages
127
-
128
- user_input.submit(send_message, inputs=[user_input, chatbot.history, qa_chain], outputs=[qa_chain, gr.update(value=""), chatbot.history, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
129
-
130
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
131
 
 
 
 
12
  from unidecode import unidecode
13
  import re
14
 
15
+ # Lista de modelos LLM disponíveis
16
  list_llm = [
17
+ "mistralai/Mistral-7B-Instruct-v0.2",
18
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
19
+ "mistralai/Mistral-7B-Instruct-v0.1",
20
+ "google/gemma-7b-it",
21
+ "google/gemma-2b-it",
22
+ "HuggingFaceH4/zephyr-7b-beta",
23
+ "HuggingFaceH4/zephyr-7b-gemma-v0.1",
24
+ "meta-llama/Llama-2-7b-chat-hf",
25
+ "microsoft/phi-2",
26
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
27
+ "mosaicml/mpt-7b-instruct",
28
+ "tiiuae/falcon-7b-instruct",
29
+ "google/flan-t5-xxl"
30
  ]
31
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
32
 
33
+ # Função para carregar documentos PDF e dividir em chunks
34
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
35
  loaders = [PyPDFLoader(x) for x in list_file_path]
36
  pages = []
37
  for loader in loaders:
38
  pages.extend(loader.load())
39
+ text_splitter = RecursiveCharacterTextSplitter(
40
+ chunk_size=chunk_size,
41
+ chunk_overlap=chunk_overlap
42
+ )
43
+ doc_splits = text_splitter.split_documents(pages)
44
+ return doc_splits
45
 
46
+ # Função para criar o banco de dados vetorial
47
+ def create_db(splits, collection_name):
48
  embedding = HuggingFaceEmbeddings()
49
+ # Usando PersistentClient para persistir o banco de dados
50
  new_client = chromadb.PersistentClient(path="./chroma_db")
51
+ vectordb = Chroma.from_documents(
52
+ documents=splits,
53
+ embedding=embedding,
54
+ client=new_client,
55
+ collection_name=collection_name,
56
+ )
57
+ return vectordb
58
 
59
+ # Função para inicializar a cadeia de QA com o modelo LLM
60
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
61
+ progress(0.1, desc="Inicializando tokenizer da HF...")
62
+ progress(0.5, desc="Inicializando Hub da HF...")
63
+ if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
64
+ llm = HuggingFaceEndpoint(
65
+ repo_id=llm_model,
66
+ temperature=temperature,
67
+ max_new_tokens=max_tokens,
68
+ top_k=top_k,
69
+ load_in_8bit=True,
70
+ )
71
+ elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
72
+ raise gr.Error("O modelo LLM é muito grande para ser carregado automaticamente no endpoint de inferência gratuito")
73
+ elif llm_model == "microsoft/phi-2":
74
+ llm = HuggingFaceEndpoint(
75
+ repo_id=llm_model,
76
+ temperature=temperature,
77
+ max_new_tokens=max_tokens,
78
+ top_k=top_k,
79
+ trust_remote_code=True,
80
+ torch_dtype="auto",
81
+ )
82
+ elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
83
+ llm = HuggingFaceEndpoint(
84
+ repo_id=llm_model,
85
+ temperature=temperature,
86
+ max_new_tokens=250,
87
+ top_k=top_k,
88
+ )
89
+ elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
90
+ raise gr.Error("O modelo Llama-2-7b-chat-hf requer uma assinatura Pro...")
91
+ else:
92
+ llm = HuggingFaceEndpoint(
93
+ repo_id=llm_model,
94
+ temperature=temperature,
95
+ max_new_tokens=max_tokens,
96
+ top_k=top_k,
97
+ )
98
+
99
+ progress(0.75, desc="Definindo memória de buffer...")
100
+ memory = ConversationBufferMemory(
101
+ memory_key="chat_history",
102
+ output_key='answer',
103
+ return_messages=True
104
  )
 
 
105
  retriever = vector_db.as_retriever()
106
+ progress(0.8, desc="Definindo cadeia de recuperação...")
107
+ qa_chain = ConversationalRetrievalChain.from_llm(
108
+ llm,
109
+ retriever=retriever,
110
+ chain_type="stuff",
111
+ memory=memory,
112
+ return_source_documents=True,
113
+ verbose=False,
114
+ )
115
  progress(0.9, desc="Concluído!")
116
  return qa_chain
117
 
118
  # Função para gerar um nome de coleção válido
119
  def create_collection_name(filepath):
120
  collection_name = Path(filepath).stem
121
+ collection_name = collection_name.replace(" ", "-")
122
+ collection_name = unidecode(collection_name)
123
+ collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
124
+ collection_name = collection_name[:50]
125
+ if len(collection_name) < 3:
126
+ collection_name = collection_name + 'xyz'
127
+ if not collection_name[0].isalnum():
128
+ collection_name = 'A' + collection_name[1:]
129
+ if not collection_name[-1].isalnum():
130
+ collection_name = collection_name[:-1] + 'Z'
131
+ print('Caminho do arquivo: ', filepath)
132
+ print('Nome da coleção: ', collection_name)
133
+ return collection_name
134
 
135
+ # Função para inicializar o banco de dados
136
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
137
  list_file_path = [x.name for x in list_file_obj if x is not None]
138
  progress(0.1, desc="Criando nome da coleção...")
139
  collection_name = create_collection_name(list_file_path[0])
140
+ progress(0.25, desc="Carregando documento...")
141
+ doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
142
  progress(0.5, desc="Gerando banco de dados vetorial...")
143
+ vector_db = create_db(doc_splits, collection_name)
144
+ progress(0.9, desc="Concluído!")
145
+ return vector_db, collection_name, "Completo!"
146
+
147
+ # Função para inicializar o modelo LLM
148
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
149
  llm_name = list_llm[llm_option]
150
+ print("Nome do LLM: ", llm_name)
151
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
152
+ return qa_chain, "Completo!"
 
153
 
154
+ # Função para formatar o histórico de conversa
155
+ def format_chat_history(message, chat_history):
156
+ formatted_chat_history = []
157
+ for user_message, bot_message in chat_history:
158
+ formatted_chat_history.append(f"Usuário: {user_message}")
159
+ formatted_chat_history.append(f"Assistente: {bot_message}")
160
+ return formatted_chat_history
161
+
162
+ # Função para realizar a conversa com o chatbot
163
  def conversation(qa_chain, message, history):
164
+ formatted_chat_history = format_chat_history(message, history)
165
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
166
+ response_answer = response["answer"]
167
+ if response_answer.find("Resposta útil:") != -1:
168
+ response_answer = response_answer.split("Resposta útil:")[-1]
169
+ response_sources = response["source_documents"]
170
+ response_source1 = response_sources[0].page_content.strip()
171
+ response_source2 = response_sources[1].page_content.strip()
172
+ response_source3 = response_sources[2].page_content.strip()
173
+ response_source1_page = response_sources[0].metadata["page"] + 1
174
+ response_source2_page = response_sources[1].metadata["page"] + 1
175
+ response_source3_page = response_sources[2].metadata["page"] + 1
176
  new_history = history + [(message, response_answer)]
177
+ return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
178
 
179
+ # Função para carregar arquivos
180
  def upload_file(file_obj):
181
+ list_file_path = []
182
+ for idx, file in enumerate(file_obj):
183
+ file_path = file_obj.name
184
+ list_file_path.append(file_path)
185
+ return list_file_path
186
 
 
187
  def demo():
188
  with gr.Blocks(theme="base") as demo:
189
+ vector_db = gr.State()
190
+ qa_chain = gr.State()
191
+ collection_name = gr.State()
192
+
193
+ gr.Markdown(
194
+ """<center><h2>Chatbot baseado em PDF</center></h2>
195
+ <h3>Faça qualquer pergunta sobre seus documentos PDF</h3>""")
196
+ gr.Markdown(
197
+ """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus documentos PDF. \
198
+ A interface do usuário mostra explicitamente várias etapas para ajudar a entender o fluxo de trabalho do RAG.
199
+ Este chatbot leva em consideração perguntas anteriores ao gerar respostas (via memória conversacional), e inclui referências documentais para maior clareza.<br>
200
+ <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
201
+ """)
202
+
203
  with gr.Tab("Etapa 1 - Carregar PDF"):
204
+ with gr.Row():
205
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus documentos PDF (único ou múltiplos)")
206
+ # upload_btn = gr.UploadButton("Carregando documento...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
207
+
208
  with gr.Tab("Etapa 2 - Processar documento"):
209
+ with gr.Row():
210
+ db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
211
+ with gr.Accordion("Opções avançadas - Divisor de texto do documento", open=False):
212
+ with gr.Row():
213
+ slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
214
+ with gr.Row():
215
+ slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Sobreposição do bloco", info="Sobreposição do bloco", interactive=True)
216
+ with gr.Row():
217
+ db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial", value="Nenhum")
218
+ with gr.Row():
219
+ db_btn = gr.Button("Gerar banco de dados vetorial")
220
+
221
  with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
222
+ with gr.Row():
223
+ llm_btn = gr.Radio(list_llm_simple, \
224
+ label="Modelos LLM", value = list_llm_simple[0], type="index", info="Escolha seu modelo LLM")
225
+ with gr.Accordion("Opções avançadas - Modelo LLM", open=False):
226
+ with gr.Row():
227
+ slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura do modelo", interactive=True)
228
+ with gr.Row():
229
+ slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Máximo de Tokens", info="Máximo de tokens do modelo", interactive=True)
230
+ with gr.Row():
231
+ slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="Amostras top-k", info="Amostras top-k do modelo", interactive=True)
232
+ with gr.Row():
233
+ llm_progress = gr.Textbox(value="Nenhum",label="Inicialização da cadeia QA")
234
+ with gr.Row():
235
+ qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
236
 
237
  with gr.Tab("Etapa 4 - Chatbot"):
238
  chatbot = gr.Chatbot(height=300)
239
+ with gr.Accordion("Avançado - Referências do documento", open=False):
240
+ with gr.Row():
241
+ doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
242
+ source1_page = gr.Number(label="Página", scale=1)
243
+ with gr.Row():
244
+ doc_source2 = gr.Textbox(label="Referência 2", lines=2, container=True, scale=20)
245
+ source2_page = gr.Number(label="Página", scale=1)
246
+ with gr.Row():
247
+ doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
248
+ source3_page = gr.Number(label="Página", scale=1)
249
+ with gr.Row():
250
+ msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Sobre o que é este documento?')", container=True)
251
+ with gr.Row():
252
+ submit_btn = gr.Button("Enviar mensagem")
253
+ clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
254
 
255
+ # Eventos de pré-processamento
256
+ #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
257
+ db_btn.click(initialize_database, \
258
+ inputs=[document, slider_chunk_size, slider_chunk_overlap], \
259
+ outputs=[vector_db, collection_name, db_progress])
260
+ qachain_btn.click(initialize_LLM, \
261
+ inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
262
+ outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
263
+ inputs=None, \
264
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
265
+ queue=False)
266
+
267
+ # Eventos do Chatbot
268
+ msg.submit(conversation, \
269
+ inputs=[qa_chain, msg, chatbot], \
270
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
271
+ queue=False)
272
+ submit_btn.click(conversation, \
273
+ inputs=[qa_chain, msg, chatbot], \
274
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
275
+ queue=False)
276
+ clear_btn.click(lambda:[None,"",0,"",0,"",0], \
277
+ inputs=None, \
278
+ outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
279
+ queue=False)
280
+ demo.queue().launch(debug=True)
281
+
282
 
283
+ if __name__ == "__main__":
284
+ demo()