Fecalisboa committed on
Commit 6120503
1 Parent(s): 42a8df2

Update app.py

Files changed (1): app.py +85 -23
app.py CHANGED
@@ -14,13 +14,51 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
+from huggingface_hub import InferenceClient
 import torch
+
 api_token = os.getenv("HF_TOKEN")
 
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.3"
+)
+
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Load PDF document and create doc splits
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
@@ -30,7 +68,6 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
 def create_db(splits, collection_name, db_type):
     embedding = HuggingFaceEmbeddings()
 
@@ -63,10 +100,8 @@ def create_db(splits, collection_name, db_type):
 
     return vectordb
 
-# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
-
     progress(0.5, desc="Initializing HF Hub...")
 
     llm = HuggingFaceEndpoint(
@@ -229,27 +264,58 @@ def demo():
             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
         with gr.Tab("Step 6 - Chatbot without document"):
-            with gr.Row():
-                llm_no_doc_btn = gr.Radio(list_llm_simple,
-                    label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model for chatbot without document")
-            with gr.Accordion("Advanced options - LLM model", open=False):
-                with gr.Row():
-                    slider_temperature_no_doc = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-                with gr.Row():
-                    slider_maxtokens_no_doc = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-                with gr.Row():
-                    slider_topk_no_doc = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_no_doc_progress = gr.Textbox(value="None", label="LLM initialization for chatbot without document")
-            with gr.Row():
-                llm_no_doc_init_btn = gr.Button("Initialize LLM for Chatbot without document")
             chatbot_no_doc = gr.Chatbot(height=300)
+            additional_inputs=[
+                gr.Slider(
+                    label="Temperature",
+                    value=0.9,
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values produce more diverse outputs",
+                ),
+                gr.Slider(
+                    label="Max new tokens",
+                    value=256,
+                    minimum=0,
+                    maximum=1048,
+                    step=64,
+                    interactive=True,
+                    info="The maximum numbers of new tokens",
+                ),
+                gr.Slider(
+                    label="Top-p (nucleus sampling)",
+                    value=0.90,
+                    minimum=0.0,
+                    maximum=1,
+                    step=0.05,
+                    interactive=True,
+                    info="Higher values sample more low-probability tokens",
+                ),
+                gr.Slider(
+                    label="Repetition penalty",
+                    value=1.2,
+                    minimum=1.0,
+                    maximum=2.0,
+                    step=0.05,
+                    interactive=True,
+                    info="Penalize repeated tokens",
+                )
+            ]
             with gr.Row():
                 msg_no_doc = gr.Textbox(placeholder="Type message to chat with lucIAna", container=True)
             with gr.Row():
                 submit_btn_no_doc = gr.Button("Submit message")
                 clear_btn_no_doc = gr.ClearButton([msg_no_doc, chatbot_no_doc], value="Clear conversation")
 
+            gr.ChatInterface(
+                fn=generate,
+                chatbot=chatbot_no_doc,
+                additional_inputs=additional_inputs,
+                title="Mistral 7B v0.3"
+            )
+
         # Preprocessing events
         db_btn.click(initialize_database,
             inputs=[document, slider_chunk_size, slider_chunk_overlap, db_type_radio],
@@ -257,7 +323,7 @@ def demo():
         set_prompt_btn.click(lambda prompt: gr.update(value=prompt),
             inputs=prompt_input,
             outputs=initial_prompt)
-        qachain_btn.click(initialize_LLM,
+        qachain_btn.click(initialize_llmchain,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db, initial_prompt],
            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0],
                inputs=None,
@@ -279,10 +345,6 @@ def demo():
                 queue=False)
 
         # Initialize LLM without document for conversation
-        llm_no_doc_init_btn.click(initialize_llm_no_doc,
-            inputs=[llm_no_doc_btn, slider_temperature_no_doc, slider_maxtokens_no_doc, slider_topk_no_doc, initial_prompt],
-            outputs=[llm_no_doc, llm_no_doc_progress])
-
         submit_btn_no_doc.click(conversation_no_doc,
             inputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
             outputs=[llm_no_doc, msg_no_doc, chatbot_no_doc],
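
For reference, a minimal standalone sketch of the streaming call that the new generate() helper wraps. This assumes HF_TOKEN is available in the environment and huggingface_hub is installed; the prompt string is a made-up example, not taken from app.py:

# Sketch only: mirrors the client/text_generation usage added in this commit.
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

prompt = "<s>[INST] Summarize this repository in one sentence. [/INST]"  # hypothetical prompt
output = ""
for chunk in client.text_generation(
    prompt,
    max_new_tokens=256,
    temperature=0.9,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
    seed=42,
    stream=True,        # yield tokens one at a time instead of a single final string
    details=True,       # each chunk carries token metadata (chunk.token.text)
    return_full_text=False,
):
    output += chunk.token.text
    print(chunk.token.text, end="", flush=True)

Streaming with details=True is what lets generate() yield a growing output string, which Gradio can render incrementally.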
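Similarly, a minimal sketch of how the new ChatInterface block in the Step 6 tab fits together. The component names and slider subset are illustrative, and the stub generator stands in for the real generate() in app.py:

import gradio as gr

def generate(message, history, temperature=0.9, max_new_tokens=256):
    # Stand-in for the real streaming generate() defined in app.py.
    yield f"(temperature={temperature}) You said: {message}"

with gr.Blocks() as demo:
    chatbot_no_doc = gr.Chatbot(height=300)
    temperature = gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05)
    max_new_tokens = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64)
    gr.ChatInterface(
        fn=generate,              # ChatInterface calls fn(message, history, *additional_inputs)
        chatbot=chatbot_no_doc,   # reuse the existing Chatbot component
        additional_inputs=[temperature, max_new_tokens],
        title="Mistral 7B v0.3",
    )

demo.launch()

ChatInterface passes the current slider values on every submit, so generate()'s keyword defaults only matter when it is called directly.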