Fecalisboa commited on
Commit
42a8df2
·
verified ·
1 Parent(s): 23d8c97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -62
app.py CHANGED
@@ -97,40 +97,26 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, in
97
  progress(0.9, desc="Done!")
98
  return qa_chain
99
 
100
- # Generate collection name for vector database
101
- def create_collection_name(filepath):
102
- collection_name = Path(filepath).stem
103
- collection_name = collection_name.replace(" ", "-")
104
- collection_name = unidecode(collection_name)
105
- collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
106
- collection_name = collection_name[:50]
107
- if len(collection_name) < 3:
108
- collection_name = collection_name + 'xyz'
109
- if not collection_name[0].isalnum():
110
- collection_name = 'A' + collection_name[1:]
111
- if not collection_name[-1].isalnum():
112
- collection_name = collection_name[:-1] + 'Z'
113
- print('Filepath: ', filepath)
114
- print('Collection name: ', collection_name)
115
- return collection_name
116
-
117
- # Initialize database
118
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, db_type, progress=gr.Progress()):
119
- list_file_path = [x.name for x in list_file_obj if x is not None]
120
- progress(0.1, desc="Creating collection name...")
121
- collection_name = create_collection_name(list_file_path[0])
122
- progress(0.25, desc="Loading document...")
123
- doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
124
- progress(0.5, desc="Generating vector database...")
125
- vector_db = create_db(doc_splits, collection_name, db_type)
126
  progress(0.9, desc="Done!")
127
- return vector_db, collection_name, "Complete!"
128
-
129
- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
130
- llm_name = list_llm[llm_option]
131
- print("llm_name: ", llm_name)
132
- qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, initial_prompt, progress)
133
- return qa_chain, "Complete!"
134
 
135
  def format_chat_history(message, chat_history):
136
  formatted_chat_history = []
@@ -156,27 +142,6 @@ def conversation(qa_chain, message, history):
156
  new_history = history + [(message, response_answer)]
157
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
158
 
159
- def initialize_llm_no_doc(llm_model, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
160
- progress(0.1, desc="Initializing HF tokenizer...")
161
- progress(0.5, desc="Initializing HF Hub...")
162
- llm = HuggingFaceEndpoint(
163
- repo_id=llm_model,
164
- huggingfacehub_api_token=api_token,
165
- temperature=temperature,
166
- max_new_tokens=max_tokens,
167
- top_k=top_k,
168
- )
169
- progress(0.75, desc="Defining buffer memory...")
170
- memory = ConversationBufferMemory(
171
- memory_key="chat_history",
172
- output_key='answer',
173
- return_messages=True
174
- )
175
- conversation_chain = ConversationChain(llm=llm, memory=memory, verbose=False)
176
- conversation_chain({"question": initial_prompt})
177
- progress(0.9, desc="Done!")
178
- return conversation_chain
179
-
180
  def conversation_no_doc(llm, message, history):
181
  formatted_chat_history = format_chat_history(message, history)
182
  response = llm({"question": message, "chat_history": formatted_chat_history})
@@ -222,19 +187,13 @@ def demo():
222
  db_progress = gr.Textbox(label="Vector database initialization", value="None")
223
  with gr.Row():
224
  db_btn = gr.Button("Generate vector database")
225
- # Define o estado para o prompt inicial
226
- initial_prompt = gr.State("")
227
-
228
- # Define a aba "Set Initial Prompt"
229
  with gr.Tab("Step 3 - Set Initial Prompt"):
230
  with gr.Row():
231
  prompt_input = gr.Textbox(label="Initial Prompt", lines=5, value="Você é um advogado sênior, onde seu papel é analisar e trazer as informações sem inventar, dando a sua melhor opinião sempre trazendo contexto e referência. Aprenda o que é jurisprudência.")
232
  with gr.Row():
233
  set_prompt_btn = gr.Button("Set Prompt")
234
 
235
- # Atualiza o estado do prompt inicial ao clicar no botão "Set Prompt"
236
- set_prompt_btn.click(fn=lambda prompt: prompt, inputs=prompt_input, outputs=initial_prompt)
237
-
238
  with gr.Tab("Step 4 - Initialize QA chain"):
239
  with gr.Row():
240
  llm_btn = gr.Radio(list_llm_simple,
@@ -295,7 +254,7 @@ def demo():
295
  db_btn.click(initialize_database,
296
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
297
  outputs=[vector_db, collection_name, db_progress])
298
- set_prompt_btn.click(lambda prompt: prompt,
299
  inputs=prompt_input,
300
  outputs=initial_prompt)
301
  qachain_btn.click(initialize_LLM,
 
97
  progress(0.9, desc="Done!")
98
  return qa_chain
99
 
100
+ def initialize_llm_no_doc(llm_model, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
101
+ progress(0.1, desc="Initializing HF tokenizer...")
102
+ progress(0.5, desc="Initializing HF Hub...")
103
+ llm = HuggingFaceEndpoint(
104
+ repo_id=llm_model,
105
+ huggingfacehub_api_token=api_token,
106
+ temperature=temperature,
107
+ max_new_tokens=max_tokens,
108
+ top_k=top_k,
109
+ )
110
+ progress(0.75, desc="Defining buffer memory...")
111
+ memory = ConversationBufferMemory(
112
+ memory_key="chat_history",
113
+ output_key='answer',
114
+ return_messages=True
115
+ )
116
+ conversation_chain = ConversationChain(llm=llm, memory=memory, verbose=False)
117
+ conversation_chain({"question": initial_prompt})
 
 
 
 
 
 
 
 
118
  progress(0.9, desc="Done!")
119
+ return conversation_chain
 
 
 
 
 
 
120
 
121
  def format_chat_history(message, chat_history):
122
  formatted_chat_history = []
 
142
  new_history = history + [(message, response_answer)]
143
  return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def conversation_no_doc(llm, message, history):
146
  formatted_chat_history = format_chat_history(message, history)
147
  response = llm({"question": message, "chat_history": formatted_chat_history})
 
187
  db_progress = gr.Textbox(label="Vector database initialization", value="None")
188
  with gr.Row():
189
  db_btn = gr.Button("Generate vector database")
190
+
 
 
 
191
  with gr.Tab("Step 3 - Set Initial Prompt"):
192
  with gr.Row():
193
  prompt_input = gr.Textbox(label="Initial Prompt", lines=5, value="Você é um advogado sênior, onde seu papel é analisar e trazer as informações sem inventar, dando a sua melhor opinião sempre trazendo contexto e referência. Aprenda o que é jurisprudência.")
194
  with gr.Row():
195
  set_prompt_btn = gr.Button("Set Prompt")
196
 
 
 
 
197
  with gr.Tab("Step 4 - Initialize QA chain"):
198
  with gr.Row():
199
  llm_btn = gr.Radio(list_llm_simple,
 
254
  db_btn.click(initialize_database,
255
  inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
256
  outputs=[vector_db, collection_name, db_progress])
257
+ set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
258
  inputs=prompt_input,
259
  outputs=initial_prompt)
260
  qachain_btn.click(initialize_LLM,