DHEIVER commited on
Commit
9278e48
·
verified ·
1 Parent(s): 6b248c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -111
app.py CHANGED
@@ -1,189 +1,283 @@
1
  import gradio as gr
2
  import os
 
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import Chroma
6
  from langchain.chains import ConversationalRetrievalChain
7
- from langchain_community.embeddings import HuggingFaceEmbeddings
8
- from langchain_community.llms import HuggingFaceEndpoint
 
9
  from langchain.memory import ConversationBufferMemory
 
 
10
  from pathlib import Path
11
  import chromadb
12
  from unidecode import unidecode
 
 
 
 
 
 
13
  import re
14
 
15
- # Lista de modelos LLM disponíveis
16
- list_llm = [
17
- "mistralai/Mistral-7B-Instruct-v0.2",
18
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
19
- "mistralai/Mistral-7B-Instruct-v0.1",
20
- "google/gemma-7b-it",
21
- "google/gemma-2b-it",
22
- "HuggingFaceH4/zephyr-7b-beta",
23
- "HuggingFaceH4/zephyr-7b-gemma-v0.1",
24
- "meta-llama/Llama-2-7b-chat-hf",
25
- "microsoft/phi-2",
26
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
27
- "mosaicml/mpt-7b-instruct",
28
- "tiiuae/falcon-7b-instruct",
29
  "google/flan-t5-xxl"
30
  ]
31
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
32
 
33
- # Função para carregar documentos PDF e dividir em chunks
34
  def load_doc(list_file_path, chunk_size, chunk_overlap):
 
 
 
35
  loaders = [PyPDFLoader(x) for x in list_file_path]
36
  pages = []
37
  for loader in loaders:
38
  pages.extend(loader.load())
 
39
  text_splitter = RecursiveCharacterTextSplitter(
40
- chunk_size=chunk_size,
41
- chunk_overlap=chunk_overlap
42
- )
43
  doc_splits = text_splitter.split_documents(pages)
44
  return doc_splits
45
 
46
- # Função para criar o banco de dados vetorial
 
47
  def create_db(splits, collection_name):
48
  embedding = HuggingFaceEmbeddings()
49
- # Usando PersistentClient para persistir o banco de dados
50
- new_client = chromadb.PersistentClient(path="./chroma_db")
51
  vectordb = Chroma.from_documents(
52
  documents=splits,
53
  embedding=embedding,
54
  client=new_client,
55
  collection_name=collection_name,
 
56
  )
57
  return vectordb
58
 
59
- # Função para inicializar a cadeia de QA com o modelo LLM
 
 
 
 
 
 
 
 
 
 
60
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
61
- progress(0.1, desc="Inicializando tokenizer da HF...")
62
- progress(0.5, desc="Inicializando Hub da HF...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
64
  llm = HuggingFaceEndpoint(
65
- repo_id=llm_model,
66
- temperature=temperature,
67
- max_new_tokens=max_tokens,
68
- top_k=top_k,
69
- load_in_8bit=True,
 
 
 
 
 
 
 
 
 
70
  )
71
- elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
72
- raise gr.Error("O modelo LLM é muito grande para ser carregado automaticamente no endpoint de inferência gratuito")
73
  elif llm_model == "microsoft/phi-2":
 
74
  llm = HuggingFaceEndpoint(
75
- repo_id=llm_model,
76
- temperature=temperature,
77
- max_new_tokens=max_tokens,
78
- top_k=top_k,
79
- trust_remote_code=True,
80
- torch_dtype="auto",
 
81
  )
82
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
83
  llm = HuggingFaceEndpoint(
84
- repo_id=llm_model,
85
- temperature=temperature,
86
- max_new_tokens=250,
87
- top_k=top_k,
 
88
  )
89
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
90
- raise gr.Error("O modelo Llama-2-7b-chat-hf requer uma assinatura Pro...")
 
 
 
 
 
 
 
91
  else:
92
  llm = HuggingFaceEndpoint(
93
- repo_id=llm_model,
94
- temperature=temperature,
95
- max_new_tokens=max_tokens,
96
- top_k=top_k,
 
 
97
  )
98
-
99
- progress(0.75, desc="Definindo memória de buffer...")
100
  memory = ConversationBufferMemory(
101
  memory_key="chat_history",
102
  output_key='answer',
103
  return_messages=True
104
  )
105
- retriever = vector_db.as_retriever()
106
- progress(0.8, desc="Definindo cadeia de recuperação...")
 
107
  qa_chain = ConversationalRetrievalChain.from_llm(
108
  llm,
109
  retriever=retriever,
110
- chain_type="stuff",
111
  memory=memory,
 
112
  return_source_documents=True,
 
113
  verbose=False,
114
  )
115
- progress(0.9, desc="Concluído!")
116
  return qa_chain
117
 
118
- # Função para gerar um nome de coleção válido
 
 
119
  def create_collection_name(filepath):
 
120
  collection_name = Path(filepath).stem
121
- collection_name = collection_name.replace(" ", "-")
 
 
 
122
  collection_name = unidecode(collection_name)
 
 
123
  collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
 
124
  collection_name = collection_name[:50]
 
125
  if len(collection_name) < 3:
126
  collection_name = collection_name + 'xyz'
 
127
  if not collection_name[0].isalnum():
128
  collection_name = 'A' + collection_name[1:]
129
  if not collection_name[-1].isalnum():
130
  collection_name = collection_name[:-1] + 'Z'
131
- print('Caminho do arquivo: ', filepath)
132
- print('Nome da coleção: ', collection_name)
133
  return collection_name
134
 
135
- # Função para inicializar o banco de dados
 
136
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
 
137
  list_file_path = [x.name for x in list_file_obj if x is not None]
138
- progress(0.1, desc="Criando nome da coleção...")
 
139
  collection_name = create_collection_name(list_file_path[0])
140
- progress(0.25, desc="Carregando documento...")
 
141
  doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
142
- progress(0.5, desc="Gerando banco de dados vetorial...")
 
 
143
  vector_db = create_db(doc_splits, collection_name)
144
- progress(0.9, desc="Concluído!")
145
- return vector_db, collection_name, "Completo!"
 
146
 
147
- # Função para inicializar o modelo LLM
148
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
149
  llm_name = list_llm[llm_option]
150
- print("Nome do LLM: ", llm_name)
151
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
152
- return qa_chain, "Completo!"
 
153
 
154
- # Função para formatar o histórico de conversa
155
  def format_chat_history(message, chat_history):
156
  formatted_chat_history = []
157
  for user_message, bot_message in chat_history:
158
- formatted_chat_history.append(f"Usuário: {user_message}")
159
- formatted_chat_history.append(f"Assistente: {bot_message}")
160
  return formatted_chat_history
 
161
 
162
- # Função para realizar a conversa com o chatbot
163
  def conversation(qa_chain, message, history):
164
  formatted_chat_history = format_chat_history(message, history)
 
 
 
165
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
166
  response_answer = response["answer"]
167
- if response_answer.find("Resposta útil:") != -1:
168
- response_answer = response_answer.split("Resposta útil:")[-1]
169
  response_sources = response["source_documents"]
170
  response_source1 = response_sources[0].page_content.strip()
171
  response_source2 = response_sources[1].page_content.strip()
172
  response_source3 = response_sources[2].page_content.strip()
 
173
  response_source1_page = response_sources[0].metadata["page"] + 1
174
  response_source2_page = response_sources[1].metadata["page"] + 1
175
  response_source3_page = response_sources[2].metadata["page"] + 1
 
 
 
 
176
  new_history = history + [(message, response_answer)]
 
177
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
178
 
179
- # Função para carregar arquivos
180
  def upload_file(file_obj):
181
  list_file_path = []
182
  for idx, file in enumerate(file_obj):
183
  file_path = file_obj.name
184
  list_file_path.append(file_path)
 
 
185
  return list_file_path
186
 
 
187
  def demo():
188
  with gr.Blocks(theme="base") as demo:
189
  vector_db = gr.State()
@@ -191,68 +285,68 @@ def demo():
191
  collection_name = gr.State()
192
 
193
  gr.Markdown(
194
- """<center><h2>Chatbot baseado em PDF</center></h2>
195
- <h3>Faça qualquer pergunta sobre seus documentos PDF</h3>""")
196
  gr.Markdown(
197
- """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus documentos PDF. \
198
- A interface do usuário mostra explicitamente várias etapas para ajudar a entender o fluxo de trabalho do RAG.
199
- Este chatbot leva em consideração perguntas anteriores ao gerar respostas (via memória conversacional), e inclui referências documentais para maior clareza.<br>
200
- <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
201
  """)
202
 
203
- with gr.Tab("Etapa 1 - Carregar PDF"):
204
  with gr.Row():
205
- document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus documentos PDF (único ou múltiplos)")
206
- # upload_btn = gr.UploadButton("Carregando documento...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
207
 
208
- with gr.Tab("Etapa 2 - Processar documento"):
209
  with gr.Row():
210
- db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
211
- with gr.Accordion("Opções avançadas - Divisor de texto do documento", open=False):
212
  with gr.Row():
213
- slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
214
  with gr.Row():
215
- slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Sobreposição do bloco", info="Sobreposição do bloco", interactive=True)
216
  with gr.Row():
217
- db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial", value="Nenhum")
218
  with gr.Row():
219
- db_btn = gr.Button("Gerar banco de dados vetorial")
220
 
221
- with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
222
  with gr.Row():
223
  llm_btn = gr.Radio(list_llm_simple, \
224
- label="Modelos LLM", value = list_llm_simple[0], type="index", info="Escolha seu modelo LLM")
225
- with gr.Accordion("Opções avançadas - Modelo LLM", open=False):
226
  with gr.Row():
227
- slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura do modelo", interactive=True)
228
  with gr.Row():
229
- slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Máximo de Tokens", info="Máximo de tokens do modelo", interactive=True)
230
  with gr.Row():
231
- slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="Amostras top-k", info="Amostras top-k do modelo", interactive=True)
232
  with gr.Row():
233
- llm_progress = gr.Textbox(value="Nenhum",label="Inicialização da cadeia QA")
234
  with gr.Row():
235
- qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
236
 
237
- with gr.Tab("Etapa 4 - Chatbot"):
238
  chatbot = gr.Chatbot(height=300)
239
- with gr.Accordion("Avançado - Referências do documento", open=False):
240
  with gr.Row():
241
- doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
242
- source1_page = gr.Number(label="Página", scale=1)
243
  with gr.Row():
244
- doc_source2 = gr.Textbox(label="Referência 2", lines=2, container=True, scale=20)
245
- source2_page = gr.Number(label="Página", scale=1)
246
  with gr.Row():
247
- doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
248
- source3_page = gr.Number(label="Página", scale=1)
249
  with gr.Row():
250
- msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Sobre o que é este documento?')", container=True)
251
  with gr.Row():
252
- submit_btn = gr.Button("Enviar mensagem")
253
- clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
254
 
255
- # Eventos de pré-processamento
256
  #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
257
  db_btn.click(initialize_database, \
258
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
@@ -264,7 +358,7 @@ def demo():
264
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
265
  queue=False)
266
 
267
- # Eventos do Chatbot
268
  msg.submit(conversation, \
269
  inputs=[qa_chain, msg, chatbot], \
270
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
 
1
  import gradio as gr
2
  import os
3
+
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain_community.vectorstores import Chroma
7
  from langchain.chains import ConversationalRetrievalChain
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.llms import HuggingFacePipeline
10
+ from langchain.chains import ConversationChain
11
  from langchain.memory import ConversationBufferMemory
12
+ from langchain_community.llms import HuggingFaceEndpoint
13
+
14
  from pathlib import Path
15
  import chromadb
16
  from unidecode import unidecode
17
+
18
+ from transformers import AutoTokenizer
19
+ import transformers
20
+ import torch
21
+ import tqdm
22
+ import accelerate
23
  import re
24
 
25
+
26
+
27
+ # default_persist_directory = './chroma_HF/'
28
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
29
+ "google/gemma-7b-it","google/gemma-2b-it", \
30
+ "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
31
+ "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
32
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
 
 
 
 
 
 
33
  "google/flan-t5-xxl"
34
  ]
35
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
36
 
37
+ # Load PDF document and create doc splits
38
  def load_doc(list_file_path, chunk_size, chunk_overlap):
39
+ # Processing for one document only
40
+ # loader = PyPDFLoader(file_path)
41
+ # pages = loader.load()
42
  loaders = [PyPDFLoader(x) for x in list_file_path]
43
  pages = []
44
  for loader in loaders:
45
  pages.extend(loader.load())
46
+ # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
47
  text_splitter = RecursiveCharacterTextSplitter(
48
+ chunk_size = chunk_size,
49
+ chunk_overlap = chunk_overlap)
 
50
  doc_splits = text_splitter.split_documents(pages)
51
  return doc_splits
52
 
53
+
54
+ # Create vector database
55
  def create_db(splits, collection_name):
56
  embedding = HuggingFaceEmbeddings()
57
+ new_client = chromadb.EphemeralClient()
 
58
  vectordb = Chroma.from_documents(
59
  documents=splits,
60
  embedding=embedding,
61
  client=new_client,
62
  collection_name=collection_name,
63
+ # persist_directory=default_persist_directory
64
  )
65
  return vectordb
66
 
67
+
68
+ # Load vector database
69
+ def load_db():
70
+ embedding = HuggingFaceEmbeddings()
71
+ vectordb = Chroma(
72
+ # persist_directory=default_persist_directory,
73
+ embedding_function=embedding)
74
+ return vectordb
75
+
76
+
77
+ # Initialize langchain LLM chain
78
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
79
+ progress(0.1, desc="Initializing HF tokenizer...")
80
+ # HuggingFacePipeline uses local model
81
+ # Note: it will download model locally...
82
+ # tokenizer=AutoTokenizer.from_pretrained(llm_model)
83
+ # progress(0.5, desc="Initializing HF pipeline...")
84
+ # pipeline=transformers.pipeline(
85
+ # "text-generation",
86
+ # model=llm_model,
87
+ # tokenizer=tokenizer,
88
+ # torch_dtype=torch.bfloat16,
89
+ # trust_remote_code=True,
90
+ # device_map="auto",
91
+ # # max_length=1024,
92
+ # max_new_tokens=max_tokens,
93
+ # do_sample=True,
94
+ # top_k=top_k,
95
+ # num_return_sequences=1,
96
+ # eos_token_id=tokenizer.eos_token_id
97
+ # )
98
+ # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
99
+
100
+ # HuggingFaceHub uses HF inference endpoints
101
+ progress(0.5, desc="Initializing HF Hub...")
102
+ # Use of trust_remote_code as model_kwargs
103
+ # Warning: langchain issue
104
+ # URL: https://github.com/langchain-ai/langchain/issues/6080
105
  if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
106
  llm = HuggingFaceEndpoint(
107
+ repo_id=llm_model,
108
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
109
+ temperature = temperature,
110
+ max_new_tokens = max_tokens,
111
+ top_k = top_k,
112
+ load_in_8bit = True,
113
+ )
114
+ elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]:
115
+ raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
116
+ llm = HuggingFaceEndpoint(
117
+ repo_id=llm_model,
118
+ temperature = temperature,
119
+ max_new_tokens = max_tokens,
120
+ top_k = top_k,
121
  )
 
 
122
  elif llm_model == "microsoft/phi-2":
123
+ # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
124
  llm = HuggingFaceEndpoint(
125
+ repo_id=llm_model,
126
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
127
+ temperature = temperature,
128
+ max_new_tokens = max_tokens,
129
+ top_k = top_k,
130
+ trust_remote_code = True,
131
+ torch_dtype = "auto",
132
  )
133
  elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
134
  llm = HuggingFaceEndpoint(
135
+ repo_id=llm_model,
136
+ # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
137
+ temperature = temperature,
138
+ max_new_tokens = 250,
139
+ top_k = top_k,
140
  )
141
  elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
142
+ raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
143
+ llm = HuggingFaceEndpoint(
144
+ repo_id=llm_model,
145
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
146
+ temperature = temperature,
147
+ max_new_tokens = max_tokens,
148
+ top_k = top_k,
149
+ )
150
  else:
151
  llm = HuggingFaceEndpoint(
152
+ repo_id=llm_model,
153
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
154
+ # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
155
+ temperature = temperature,
156
+ max_new_tokens = max_tokens,
157
+ top_k = top_k,
158
  )
159
+
160
+ progress(0.75, desc="Defining buffer memory...")
161
  memory = ConversationBufferMemory(
162
  memory_key="chat_history",
163
  output_key='answer',
164
  return_messages=True
165
  )
166
+ # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
167
+ retriever=vector_db.as_retriever()
168
+ progress(0.8, desc="Defining retrieval chain...")
169
  qa_chain = ConversationalRetrievalChain.from_llm(
170
  llm,
171
  retriever=retriever,
172
+ chain_type="stuff",
173
  memory=memory,
174
+ # combine_docs_chain_kwargs={"prompt": your_prompt})
175
  return_source_documents=True,
176
+ #return_generated_question=False,
177
  verbose=False,
178
  )
179
+ progress(0.9, desc="Done!")
180
  return qa_chain
181
 
182
+
183
+ # Generate collection name for vector database
184
+ # - Use filepath as input, ensuring unicode text
185
  def create_collection_name(filepath):
186
+ # Extract filename without extension
187
  collection_name = Path(filepath).stem
188
+ # Fix potential issues from naming convention
189
+ ## Remove space
190
+ collection_name = collection_name.replace(" ","-")
191
+ ## ASCII transliterations of Unicode text
192
  collection_name = unidecode(collection_name)
193
+ ## Remove special characters
194
+ #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
195
  collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
196
+ ## Limit length to 50 characters
197
  collection_name = collection_name[:50]
198
+ ## Minimum length of 3 characters
199
  if len(collection_name) < 3:
200
  collection_name = collection_name + 'xyz'
201
+ ## Enforce start and end as alphanumeric character
202
  if not collection_name[0].isalnum():
203
  collection_name = 'A' + collection_name[1:]
204
  if not collection_name[-1].isalnum():
205
  collection_name = collection_name[:-1] + 'Z'
206
+ print('Filepath: ', filepath)
207
+ print('Collection name: ', collection_name)
208
  return collection_name
209
 
210
+
211
+ # Initialize database
212
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
213
+ # Create list of documents (when valid)
214
  list_file_path = [x.name for x in list_file_obj if x is not None]
215
+ # Create collection_name for vector database
216
+ progress(0.1, desc="Creating collection name...")
217
  collection_name = create_collection_name(list_file_path[0])
218
+ progress(0.25, desc="Loading document...")
219
+ # Load document and create splits
220
  doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
221
+ # Create or load vector database
222
+ progress(0.5, desc="Generating vector database...")
223
+ # global vector_db
224
  vector_db = create_db(doc_splits, collection_name)
225
+ progress(0.9, desc="Done!")
226
+ return vector_db, collection_name, "Complete!"
227
+
228
 
 
229
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
230
+ # print("llm_option",llm_option)
231
  llm_name = list_llm[llm_option]
232
+ print("llm_name: ",llm_name)
233
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
234
+ return qa_chain, "Complete!"
235
+
236
 
 
237
  def format_chat_history(message, chat_history):
238
  formatted_chat_history = []
239
  for user_message, bot_message in chat_history:
240
+ formatted_chat_history.append(f"User: {user_message}")
241
+ formatted_chat_history.append(f"Assistant: {bot_message}")
242
  return formatted_chat_history
243
+
244
 
 
245
  def conversation(qa_chain, message, history):
246
  formatted_chat_history = format_chat_history(message, history)
247
+ #print("formatted_chat_history",formatted_chat_history)
248
+
249
+ # Generate response using QA chain
250
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
251
  response_answer = response["answer"]
252
+ if response_answer.find("Helpful Answer:") != -1:
253
+ response_answer = response_answer.split("Helpful Answer:")[-1]
254
  response_sources = response["source_documents"]
255
  response_source1 = response_sources[0].page_content.strip()
256
  response_source2 = response_sources[1].page_content.strip()
257
  response_source3 = response_sources[2].page_content.strip()
258
+ # Langchain sources are zero-based
259
  response_source1_page = response_sources[0].metadata["page"] + 1
260
  response_source2_page = response_sources[1].metadata["page"] + 1
261
  response_source3_page = response_sources[2].metadata["page"] + 1
262
+ # print ('chat response: ', response_answer)
263
+ # print('DB source', response_sources)
264
+
265
+ # Append user message and response to chat history
266
  new_history = history + [(message, response_answer)]
267
+ # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
268
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
269
+
270
 
 
271
  def upload_file(file_obj):
272
  list_file_path = []
273
  for idx, file in enumerate(file_obj):
274
  file_path = file_obj.name
275
  list_file_path.append(file_path)
276
+ # print(file_path)
277
+ # initialize_database(file_path, progress)
278
  return list_file_path
279
 
280
+
281
  def demo():
282
  with gr.Blocks(theme="base") as demo:
283
  vector_db = gr.State()
 
285
  collection_name = gr.State()
286
 
287
  gr.Markdown(
288
+ """<center><h2>PDF-based chatbot</center></h2>
289
+ <h3>Ask any questions about your PDF documents</h3>""")
290
  gr.Markdown(
291
+ """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
292
+ The user interface explicitely shows multiple steps to help understand the RAG workflow.
293
+ This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
294
+ <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
295
  """)
296
 
297
+ with gr.Tab("Step 1 - Upload PDF"):
298
  with gr.Row():
299
+ document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
300
+ # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
301
 
302
+ with gr.Tab("Step 2 - Process document"):
303
  with gr.Row():
304
+ db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
305
+ with gr.Accordion("Advanced options - Document text splitter", open=False):
306
  with gr.Row():
307
+ slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
308
  with gr.Row():
309
+ slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
310
  with gr.Row():
311
+ db_progress = gr.Textbox(label="Vector database initialization", value="None")
312
  with gr.Row():
313
+ db_btn = gr.Button("Generate vector database")
314
 
315
+ with gr.Tab("Step 3 - Initialize QA chain"):
316
  with gr.Row():
317
  llm_btn = gr.Radio(list_llm_simple, \
318
+ label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
319
+ with gr.Accordion("Advanced options - LLM model", open=False):
320
  with gr.Row():
321
+ slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
322
  with gr.Row():
323
+ slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
324
  with gr.Row():
325
+ slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
326
  with gr.Row():
327
+ llm_progress = gr.Textbox(value="None",label="QA chain initialization")
328
  with gr.Row():
329
+ qachain_btn = gr.Button("Initialize Question Answering chain")
330
 
331
+ with gr.Tab("Step 4 - Chatbot"):
332
  chatbot = gr.Chatbot(height=300)
333
+ with gr.Accordion("Advanced - Document references", open=False):
334
  with gr.Row():
335
+ doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
336
+ source1_page = gr.Number(label="Page", scale=1)
337
  with gr.Row():
338
+ doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
339
+ source2_page = gr.Number(label="Page", scale=1)
340
  with gr.Row():
341
+ doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
342
+ source3_page = gr.Number(label="Page", scale=1)
343
  with gr.Row():
344
+ msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
345
  with gr.Row():
346
+ submit_btn = gr.Button("Submit message")
347
+ clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
348
 
349
+ # Preprocessing events
350
  #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
351
  db_btn.click(initialize_database, \
352
  inputs=[document, slider_chunk_size, slider_chunk_overlap], \
 
358
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
359
  queue=False)
360
 
361
+ # Chatbot events
362
  msg.submit(conversation, \
363
  inputs=[qa_chain, msg, chatbot], \
364
  outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \