DHEIVER commited on
Commit
6b7ae1b
·
verified ·
1 Parent(s): 34bde69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -12
app.py CHANGED
@@ -1,6 +1,188 @@
1
- # gradio_interface.py
2
  import gradio as gr
3
- from rag_functions import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def demo():
6
  with gr.Blocks(theme="base") as demo:
@@ -9,22 +191,24 @@ def demo():
9
  collection_name = gr.State()
10
 
11
  gr.Markdown(
12
- """<center><h2>Assistente de Análise de Relatórios de Metrologia</center></h2>
13
- <h3>Faça qualquer pergunta sobre seus relatórios de metrologia</h3>""")
14
  gr.Markdown(
15
- """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus relatórios de metrologia em formato PDF. \
16
- A interface do usuário está organizada para facilitar o entendimento do fluxo de trabalho do RAG. Este chatbot leva em consideração perguntas anteriores ao gerar respostas, e inclui referências documentais para maior clareza.<br>
 
17
  <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
18
  """)
19
 
20
- with gr.Tab("Etapa 1 - Carregar Relatórios"):
21
  with gr.Row():
22
- document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus relatórios de metrologia (único ou múltiplos)")
 
23
 
24
- with gr.Tab("Etapa 2 - Processar Relatórios"):
25
  with gr.Row():
26
  db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
27
- with gr.Accordion("Opções avançadas - Divisor de texto do relatório", open=False):
28
  with gr.Row():
29
  slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
30
  with gr.Row():
@@ -52,7 +236,7 @@ def demo():
52
 
53
  with gr.Tab("Etapa 4 - Chatbot"):
54
  chatbot = gr.Chatbot(height=300)
55
- with gr.Accordion("Avançado - Referências do Relatório", open=False):
56
  with gr.Row():
57
  doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
58
  source1_page = gr.Number(label="Página", scale=1)
@@ -63,12 +247,13 @@ def demo():
63
  doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
64
  source3_page = gr.Number(label="Página", scale=1)
65
  with gr.Row():
66
- msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Qual a precisão dos instrumentos utilizados?')", container=True)
67
  with gr.Row():
68
  submit_btn = gr.Button("Enviar mensagem")
69
  clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
70
 
71
  # Eventos de pré-processamento
 
72
  db_btn.click(initialize_database, \
73
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
74
  outputs=[vector_db, collection_name, db_progress])
 
 
1
  import gradio as gr
2
+ import os
3
+ from langchain_community.document_loaders import PyPDFLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.vectorstores import Chroma
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.llms import HuggingFaceEndpoint
9
+ from langchain.memory import ConversationBufferMemory
10
+ from pathlib import Path
11
+ import chromadb
12
+ from unidecode import unidecode
13
+ import re
14
+
15
+ # Lista de modelos LLM disponíveis
16
+ list_llm = [
17
+ "mistralai/Mistral-7B-Instruct-v0.2",
18
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
19
+ "mistralai/Mistral-7B-Instruct-v0.1",
20
+ "google/gemma-7b-it",
21
+ "google/gemma-2b-it",
22
+ "HuggingFaceH4/zephyr-7b-beta",
23
+ "HuggingFaceH4/zephyr-7b-gemma-v0.1",
24
+ "meta-llama/Llama-2-7b-chat-hf",
25
+ "microsoft/phi-2",
26
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
27
+ "mosaicml/mpt-7b-instruct",
28
+ "tiiuae/falcon-7b-instruct",
29
+ "google/flan-t5-xxl"
30
+ ]
31
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
32
+
33
+ # Função para carregar documentos PDF e dividir em chunks
34
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
35
+ loaders = [PyPDFLoader(x) for x in list_file_path]
36
+ pages = []
37
+ for loader in loaders:
38
+ pages.extend(loader.load())
39
+ text_splitter = RecursiveCharacterTextSplitter(
40
+ chunk_size=chunk_size,
41
+ chunk_overlap=chunk_overlap
42
+ )
43
+ doc_splits = text_splitter.split_documents(pages)
44
+ return doc_splits
45
+
46
+ # Função para criar o banco de dados vetorial
47
+ def create_db(splits, collection_name):
48
+ embedding = HuggingFaceEmbeddings()
49
+ # Usando PersistentClient para persistir o banco de dados
50
+ new_client = chromadb.PersistentClient(path="./chroma_db")
51
+ vectordb = Chroma.from_documents(
52
+ documents=splits,
53
+ embedding=embedding,
54
+ client=new_client,
55
+ collection_name=collection_name,
56
+ )
57
+ return vectordb
58
+
59
+ # Função para inicializar a cadeia de QA com o modelo LLM
60
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
61
+ progress(0.1, desc="Inicializando tokenizer da HF...")
62
+ progress(0.5, desc="Inicializando Hub da HF...")
63
+ if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
64
+ llm = HuggingFaceEndpoint(
65
+ repo_id=llm_model,
66
+ temperature=temperature,
67
+ max_new_tokens=max_tokens,
68
+ top_k=top_k,
69
+ load_in_8bit=True,
70
+ )
71
+ elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
72
+ raise gr.Error("O modelo LLM é muito grande para ser carregado automaticamente no endpoint de inferência gratuito")
73
+ elif llm_model == "microsoft/phi-2":
74
+ llm = HuggingFaceEndpoint(
75
+ repo_id=llm_model,
76
+ temperature=temperature,
77
+ max_new_tokens=max_tokens,
78
+ top_k=top_k,
79
+ trust_remote_code=True,
80
+ torch_dtype="auto",
81
+ )
82
+ elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
83
+ llm = HuggingFaceEndpoint(
84
+ repo_id=llm_model,
85
+ temperature=temperature,
86
+ max_new_tokens=250,
87
+ top_k=top_k,
88
+ )
89
+ elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
90
+ raise gr.Error("O modelo Llama-2-7b-chat-hf requer uma assinatura Pro...")
91
+ else:
92
+ llm = HuggingFaceEndpoint(
93
+ repo_id=llm_model,
94
+ temperature=temperature,
95
+ max_new_tokens=max_tokens,
96
+ top_k=top_k,
97
+ )
98
+
99
+ progress(0.75, desc="Definindo memória de buffer...")
100
+ memory = ConversationBufferMemory(
101
+ memory_key="chat_history",
102
+ output_key='answer',
103
+ return_messages=True
104
+ )
105
+ retriever = vector_db.as_retriever()
106
+ progress(0.8, desc="Definindo cadeia de recuperação...")
107
+ qa_chain = ConversationalRetrievalChain.from_llm(
108
+ llm,
109
+ retriever=retriever,
110
+ chain_type="stuff",
111
+ memory=memory,
112
+ return_source_documents=True,
113
+ verbose=False,
114
+ )
115
+ progress(0.9, desc="Concluído!")
116
+ return qa_chain
117
+
118
+ # Função para gerar um nome de coleção válido
119
+ def create_collection_name(filepath):
120
+ collection_name = Path(filepath).stem
121
+ collection_name = collection_name.replace(" ", "-")
122
+ collection_name = unidecode(collection_name)
123
+ collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
124
+ collection_name = collection_name[:50]
125
+ if len(collection_name) < 3:
126
+ collection_name = collection_name + 'xyz'
127
+ if not collection_name[0].isalnum():
128
+ collection_name = 'A' + collection_name[1:]
129
+ if not collection_name[-1].isalnum():
130
+ collection_name = collection_name[:-1] + 'Z'
131
+ print('Caminho do arquivo: ', filepath)
132
+ print('Nome da coleção: ', collection_name)
133
+ return collection_name
134
+
135
+ # Função para inicializar o banco de dados
136
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
137
+ list_file_path = [x.name for x in list_file_obj if x is not None]
138
+ progress(0.1, desc="Criando nome da coleção...")
139
+ collection_name = create_collection_name(list_file_path[0])
140
+ progress(0.25, desc="Carregando documento...")
141
+ doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
142
+ progress(0.5, desc="Gerando banco de dados vetorial...")
143
+ vector_db = create_db(doc_splits, collection_name)
144
+ progress(0.9, desc="Concluído!")
145
+ return vector_db, collection_name, "Completo!"
146
+
147
+ # Função para inicializar o modelo LLM
148
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
149
+ llm_name = list_llm[llm_option]
150
+ print("Nome do LLM: ", llm_name)
151
+ qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
152
+ return qa_chain, "Completo!"
153
+
154
+ # Função para formatar o histórico de conversa
155
+ def format_chat_history(message, chat_history):
156
+ formatted_chat_history = []
157
+ for user_message, bot_message in chat_history:
158
+ formatted_chat_history.append(f"Usuário: {user_message}")
159
+ formatted_chat_history.append(f"Assistente: {bot_message}")
160
+ return formatted_chat_history
161
+
162
+ # Função para realizar a conversa com o chatbot
163
+ def conversation(qa_chain, message, history):
164
+ formatted_chat_history = format_chat_history(message, history)
165
+ response = qa_chain({"question": message, "chat_history": formatted_chat_history})
166
+ response_answer = response["answer"]
167
+ if response_answer.find("Resposta útil:") != -1:
168
+ response_answer = response_answer.split("Resposta útil:")[-1]
169
+ response_sources = response["source_documents"]
170
+ response_source1 = response_sources[0].page_content.strip()
171
+ response_source2 = response_sources[1].page_content.strip()
172
+ response_source3 = response_sources[2].page_content.strip()
173
+ response_source1_page = response_sources[0].metadata["page"] + 1
174
+ response_source2_page = response_sources[1].metadata["page"] + 1
175
+ response_source3_page = response_sources[2].metadata["page"] + 1
176
+ new_history = history + [(message, response_answer)]
177
+ return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
178
+
179
+ # Função para carregar arquivos
180
+ def upload_file(file_obj):
181
+ list_file_path = []
182
+ for idx, file in enumerate(file_obj):
183
+ file_path = file_obj.name
184
+ list_file_path.append(file_path)
185
+ return list_file_path
186
 
187
  def demo():
188
  with gr.Blocks(theme="base") as demo:
 
191
  collection_name = gr.State()
192
 
193
  gr.Markdown(
194
+ """<center><h2>Chatbot baseado em PDF</center></h2>
195
+ <h3>Faça qualquer pergunta sobre seus documentos PDF</h3>""")
196
  gr.Markdown(
197
+ """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus documentos PDF. \
198
+ A interface do usuário mostra explicitamente várias etapas para ajudar a entender o fluxo de trabalho do RAG.
199
+ Este chatbot leva em consideração perguntas anteriores ao gerar respostas (via memória conversacional), e inclui referências documentais para maior clareza.<br>
200
  <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
201
  """)
202
 
203
+ with gr.Tab("Etapa 1 - Carregar PDF"):
204
  with gr.Row():
205
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus documentos PDF (único ou múltiplos)")
206
+ # upload_btn = gr.UploadButton("Carregando documento...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
207
 
208
+ with gr.Tab("Etapa 2 - Processar documento"):
209
  with gr.Row():
210
  db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
211
+ with gr.Accordion("Opções avançadas - Divisor de texto do documento", open=False):
212
  with gr.Row():
213
  slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
214
  with gr.Row():
 
236
 
237
  with gr.Tab("Etapa 4 - Chatbot"):
238
  chatbot = gr.Chatbot(height=300)
239
+ with gr.Accordion("Avançado - Referências do documento", open=False):
240
  with gr.Row():
241
  doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
242
  source1_page = gr.Number(label="Página", scale=1)
 
247
  doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
248
  source3_page = gr.Number(label="Página", scale=1)
249
  with gr.Row():
250
+ msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Sobre o que é este documento?')", container=True)
251
  with gr.Row():
252
  submit_btn = gr.Button("Enviar mensagem")
253
  clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
254
 
255
  # Eventos de pré-processamento
256
+ #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
257
  db_btn.click(initialize_database, \
258
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
259
  outputs=[vector_db, collection_name, db_progress])