Parth211 committed
Commit 6ccd6dd
1 Parent(s): cb48794
Files changed (1)
  1. app.py +65 -280
app.py CHANGED
@@ -22,15 +22,6 @@ import tqdm
 import accelerate
 import re
 
-
-import torch
-from sacrebleu import corpus_bleu
-from rouge_score import rouge_scorer
-from bert_score import score
-from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
-import nltk
-from nltk.util import ngrams
-
 api_key = os.getenv('API_KEY')
 
 
@@ -87,6 +78,25 @@ def load_db():
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
+    # HuggingFacePipeline uses local model
+    # Note: it will download model locally...
+    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+    # progress(0.5, desc="Initializing HF pipeline...")
+    # pipeline=transformers.pipeline(
+    #     "text-generation",
+    #     model=llm_model,
+    #     tokenizer=tokenizer,
+    #     torch_dtype=torch.bfloat16,
+    #     trust_remote_code=True,
+    #     device_map="auto",
+    #     # max_length=1024,
+    #     max_new_tokens=max_tokens,
+    #     do_sample=True,
+    #     top_k=top_k,
+    #     num_return_sequences=1,
+    #     eos_token_id=tokenizer.eos_token_id
+    # )
+    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
 
     # HuggingFaceHub uses HF inference endpoints
     progress(0.5, desc="Initializing HF Hub...")
@@ -237,138 +247,32 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
+
 
-#----------------------------------------------------------------------------------
-def load_gpt2_model():
-    model = GPT2LMHeadModel.from_pretrained('gpt2')
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-    return model, tokenizer
-
-gpt2_model, gpt2_tokenizer = load_gpt2_model()
-bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english")
-
-def evaluate_bleu_rouge(candidates, references):
-    bleu_score = corpus_bleu(candidates, [references]).score
-    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
-    rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
-    rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores)
-    return bleu_score, rouge1
-
-def evaluate_bert_score(candidates, references):
-    P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased')
-    return P.mean().item(), R.mean().item(), F1.mean().item()
-
-def evaluate_perplexity(text, model, tokenizer):
-    encodings = tokenizer(text, return_tensors='pt')
-    max_length = model.config.n_positions
-    stride = 512
-    lls = []
-    for i in range(0, encodings.input_ids.size(1), stride):
-        begin_loc = max(i + stride - max_length, 0)
-        end_loc = min(i + stride, encodings.input_ids.size(1))
-        trg_len = end_loc - i
-        input_ids = encodings.input_ids[:, begin_loc:end_loc]
-        target_ids = input_ids.clone()
-        target_ids[:, :-trg_len] = -100
-        with torch.no_grad():
-            outputs = model(input_ids, labels=target_ids)
-            log_likelihood = outputs[0] * trg_len
-        lls.append(log_likelihood)
-    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
-    return ppl.item()
-
-def evaluate_diversity(texts):
-    all_tokens = [tok for text in texts for tok in text.split()]
-    unique_bigrams = set(ngrams(all_tokens, 2))
-    diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0
-    return diversity_score
-
-def evaluate_racial_bias(text, pipeline):
-    results = pipeline([text], candidate_labels=["hate speech", "not hate speech"])
-    bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')]
-    return bias_score
-
-def evaluate_all(question, response, reference, gpt2_model, gpt2_tokenizer, bias_pipeline):
-    candidates = [response]
-    references = [reference]
-    bleu, rouge1 = evaluate_bleu_rouge(candidates, references)
-    bert_p, bert_r, bert_f1 = evaluate_bert_score(candidates, references)
-    perplexity = evaluate_perplexity(response, gpt2_model, gpt2_tokenizer)
-    diversity = evaluate_diversity(candidates)
-    racial_bias = evaluate_racial_bias(response, bias_pipeline)
-    return {
-        "BLEU": bleu,
-        "ROUGE-1": rouge1,
-        "BERT P": bert_p,
-        "BERT R": bert_r,
-        "BERT F1": bert_f1,
-        "Perplexity": perplexity,
-        "Diversity": diversity,
-        "Racial Bias": racial_bias
-    }
-
-#---------------------------------------------------------------------------------
-
-def display_metrics(metrics):
-    result = ""
-    for k, v in metrics.items():
-        if k == 'BLEU':
-            result += f"BLEU measures the overlap between the generated output and reference text based on n-grams. Higher scores indicate better match. Score obtained: {v}\n\n"
-        elif k == "ROUGE-1":
-            result += f"ROUGE-1 measures the overlap of unigrams between the generated output and reference text. Higher scores indicate better match. Score obtained: {v}\n\n"
-        elif k == 'BERT P':
-            result += "BERTScore evaluates the semantic similarity between the generated output and reference text using BERT embeddings.\n\n"
-            result += f"**BERT Precision**: {metrics['BERT P']}\n"
-            result += f"**BERT Recall**: {metrics['BERT R']}\n"
-            result += f"**BERT F1 Score**: {metrics['BERT F1']}\n\n"
-        elif k == 'Perplexity':
-            result += f"Perplexity measures how well a language model predicts the text. Lower values indicate better fluency and coherence. Score obtained: {v}\n\n"
-        elif k == 'Diversity':
-            result += f"Diversity measures the uniqueness of bigrams in the generated output. Higher values indicate more diverse and varied output. Score obtained: {v}\n\n"
-        elif k == 'Racial Bias':
-            result += f"Racial Bias score indicates the presence of biased language in the generated output. Higher scores indicate more bias. Score obtained: {v}\n\n"
-    return result
-#---------------------------------------------------------------------------------------------------------------------------------------------------
-
-
-
-
-
-
-def conversation(qa_chain, message, history, gpt2_model, gpt2_tokenizer, bias_pipeline):
+def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
-    question_by_user = message
-
+    #print("formatted_chat_history",formatted_chat_history)
+
+    # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
-    answer_of_question = response["answer"]
     if response_answer.find("Helpful Answer:") != -1:
         response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
-    context = " ".join([d.page_content for d in response_sources])
-
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
     response_source3 = response_sources[2].page_content.strip()
-
+    # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
-
-    new_history = history + [(message, response_answer)]
+    # print ('chat response: ', response_answer)
+    # print('DB source', response_sources)
 
-    # Evaluate the metrics
-    metrics = evaluate_all(question_by_user, answer_of_question, context,gpt2_model, gpt2_tokenizer, bias_pipeline)
-    evaluation_metrics = display_metrics(metrics)
-
-    return (qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page,
-            response_source2, response_source2_page, response_source3, response_source3_page,
-            evaluation_metrics)
-
-
-# def interact(qa_chain, message, history):
-#     return conversation(qa_chain, message, history, evaluator)
-
+    # Append user message and response to chat history
+    new_history = history + [(message, response_answer)]
+    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
 
 def upload_file(file_obj):
@@ -380,21 +284,19 @@ def upload_file(file_obj):
     # initialize_database(file_path, progress)
     return list_file_path
 
-####################################
 
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
-        history = gr.State([])  # Initialize history as an empty list
 
         gr.Markdown(
        """<center><h2>PDF-based chatbot</center></h2>
        <h3>Ask any questions about your PDF documents</h3>""")
        gr.Markdown(
        """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
-        The user interface explicitly shows multiple steps to help understand the RAG workflow.
+        The user interface explicitely shows multiple steps to help understand the RAG workflow.
        This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
        """)
@@ -402,15 +304,16 @@ def demo():
         with gr.Tab("Step 1 - Upload PDF"):
             with gr.Row():
                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+                # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
 
         with gr.Tab("Step 2 - Process document"):
             with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
+                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
             with gr.Accordion("Advanced options - Document text splitter", open=False):
                 with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                    slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+                    slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
            with gr.Row():
                db_progress = gr.Textbox(label="Vector database initialization", value="None")
            with gr.Row():
@@ -418,16 +321,17 @@ def demo():
 
         with gr.Tab("Step 3 - Initialize QA chain"):
             with gr.Row():
-                llm_btn = gr.Radio(list_llm_simple, label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
+                llm_btn = gr.Radio(list_llm_simple, \
+                    label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
             with gr.Accordion("Advanced options - LLM model", open=False):
                 with gr.Row():
-                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                    slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                with gr.Row():
-                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+                    slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
            with gr.Row():
-                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+                llm_progress = gr.Textbox(value="None",label="QA chain initialization")
            with gr.Row():
                qachain_btn = gr.Button("Initialize Question Answering chain")
 
@@ -448,153 +352,34 @@ def demo():
             with gr.Row():
                 submit_btn = gr.Button("Submit message")
                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-            with gr.Row("Metrics"):
-                metrics_output = gr.Textbox(lines=10, label="Evaluation Metrics")
-
+
         # Preprocessing events
-        db_btn.click(initialize_database,
-                     inputs=[document, slider_chunk_size, slider_chunk_overlap],
-                     outputs=[vector_db, collection_name, db_progress])
-
-        qachain_btn.click(initialize_LLM,
-                          inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
-                          outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
-                          inputs=None,
-                          outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-                          queue=False)
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database, \
+            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+            outputs=[vector_db, collection_name, db_progress])
+        qachain_btn.click(initialize_LLM, \
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
+            inputs=None, \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+            queue=False)
 
         # Chatbot events
         msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot, gpt2_model, gpt2_tokenizer, bias_pipeline], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page,metrics_output], \
+            inputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+            queue=False)
+        submit_btn.click(conversation, \
+            inputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+            queue=False)
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
+            inputs=None, \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
-
-        submit_btn.click(conversation,
-                         inputs=[qa_chain, msg, history, gpt2_model, gpt2_tokenizer, bias_pipeline],
-                         outputs=[qa_chain, chatbot, history, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page, metrics_output])
-
-        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
-                        inputs=None,
-                        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-                        queue=False)
-
    demo.queue().launch(debug=True)
 
-if __name__ == "__main__":
-    demo()
-
-
-
-
 
-###################################
-# def demo():
-#     with gr.Blocks(theme="base") as demo:
-#         vector_db = gr.State()
-#         qa_chain = gr.State()
-#         collection_name = gr.State()
-#         history = gr.State()
-
-#         gr.Markdown(
-#         """<center><h2>PDF-based chatbot</center></h2>
-#         <h3>Ask any questions about your PDF documents</h3>""")
-#         gr.Markdown(
-#         """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
-#         The user interface explicitely shows multiple steps to help understand the RAG workflow.
-#         This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
-#         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
-#         """)
-
-#         with gr.Tab("Step 1 - Upload PDF"):
-#             with gr.Row():
-#                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-#                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
-
-#         with gr.Tab("Step 2 - Process document"):
-#             with gr.Row():
-#                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
-#             with gr.Accordion("Advanced options - Document text splitter", open=False):
-#                 with gr.Row():
-#                     slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
-#                 with gr.Row():
-#                     slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
-#                 with gr.Row():
-#                     db_progress = gr.Textbox(label="Vector database initialization", value="None")
-#                 with gr.Row():
-#                     db_btn = gr.Button("Generate vector database")
-
-#         with gr.Tab("Step 3 - Initialize QA chain"):
-#             with gr.Row():
-#                 llm_btn = gr.Radio(list_llm_simple, \
-#                     label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
-#             with gr.Accordion("Advanced options - LLM model", open=False):
-#                 with gr.Row():
-#                     slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
-#                 with gr.Row():
-#                     slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
-#                 with gr.Row():
-#                     slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
-#                 with gr.Row():
-#                     llm_progress = gr.Textbox(value="None",label="QA chain initialization")
-#                 with gr.Row():
-#                     qachain_btn = gr.Button("Initialize Question Answering chain")
-
-#         with gr.Tab("Step 4 - Chatbot"):
-#             chatbot = gr.Chatbot(height=300)
-#             with gr.Accordion("Advanced - Document references", open=False):
-#                 with gr.Row():
-#                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-#                     source1_page = gr.Number(label="Page", scale=1)
-#                 with gr.Row():
-#                     doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-#                     source2_page = gr.Number(label="Page", scale=1)
-#                 with gr.Row():
-#                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-#                     source3_page = gr.Number(label="Page", scale=1)
-#             with gr.Row():
-#                 msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-#             with gr.Row():
-#                 submit_btn = gr.Button("Submit message")
-#                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-#             with gr.Row("Metrics"):
-#                 metrics_output = gr.Textbox(lines=10, label="Evaluation Metrics")
-
-
-#             # Preprocessing events
-#             #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
-#             db_btn.click(initialize_database, \
-#                 inputs=[document, slider_chunk_size, slider_chunk_overlap], \
-#                 outputs=[vector_db, collection_name, db_progress])
-#             qachain_btn.click(initialize_LLM, \
-#                 inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-#                 outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
-#                 inputs=None, \
-#                 outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-#                 queue=False)
-
-# Chatbot events
-# msg.submit(interact, inputs=[gr.State(),qa_chain, msg, history], outputs=[
-#     gr.State(), chatbot, history, response_source1, response_source1_page,
-#     response_source2, response_source2_page, response_source3, response_source3_page,
-#     None, None, None, metrics_output
-# ],queue=False)
-
-
-#             submit_btn.click(conversation, \
-#                 inputs=[qa_chain, msg, chatbot], \
-#                 outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-#                 queue=False)
-#             clear_btn.click(lambda:[None,"",0,"",0,"",0], \
-#                 inputs=None, \
-#                 outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-#                 queue=False)
-
-
-
-
-
-#     demo.queue().launch(debug=True)
-
-
-# if __name__ == "__main__":
-#     demo()
+if __name__ == "__main__":
+    demo()
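
This commit strips the in-app evaluation pipeline (BLEU, ROUGE-1, BERTScore, perplexity, diversity, and bias scoring) along with its GPT-2 and zero-shot classifier downloads, and returns conversation() to the plain RAG flow. If the reference-based scores are still wanted, they can be computed offline on saved answers instead of inside the Space. Below is a minimal sketch that reuses the same sacrebleu and rouge_score calls the removed evaluate_bleu_rouge() made; the script and the sample strings are illustrative and not part of the repository:

# offline_eval.py (illustrative, not part of this repo): recompute the two
# reference-based metrics that this commit removes from app.py.
from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer

def offline_bleu_rouge1(candidates, references):
    # Corpus BLEU over the candidate/reference pairs, as in the removed code.
    bleu = corpus_bleu(candidates, [references]).score
    # Average ROUGE-1 F-measure across pairs.
    scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
    scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
    rouge1 = sum(s['rouge1'].fmeasure for s in scores) / len(scores)
    return bleu, rouge1

if __name__ == "__main__":
    # Hypothetical saved chatbot answer and reference text.
    cand = ["The document describes a retrieval-augmented PDF chatbot."]
    ref = ["The paper presents a retrieval-augmented chatbot for PDF documents."]
    print(offline_bleu_rouge1(cand, ref))

Running this over a handful of saved question/answer/reference triples gives the same BLEU and ROUGE-1 numbers the old metrics panel showed; BERTScore and perplexity would require pulling the heavier models back in, which is exactly the cost this commit avoids on the free CPU hardware.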