PearlIsa committed
Commit
5bb1aa7
•
1 Parent(s): 12ccdbe

Update app.py

Browse files
Files changed (1)
  1. app.py +221 -103
app.py CHANGED
@@ -19,9 +19,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
 from torch.utils.data import DataLoader
-import tensorflow as tf
-import keras
-import numpy as np
 
 # Hugging Face and Transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
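Note: `torch.cuda.amp.autocast`, which this commit keeps, is deprecated in recent PyTorch releases in favour of the device-agnostic `torch.amp` entry point. A minimal sketch of the replacement (assuming PyTorch 2.x; the tensor names are illustrative only):

    import torch
    from torch.amp import autocast  # replaces torch.cuda.amp.autocast

    x = torch.randn(4, 4)
    # enabled=False turns the context into a no-op on CPU-only machines
    with autocast(device_type="cuda", enabled=torch.cuda.is_available()):
        y = x @ x.T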
@@ -39,16 +36,6 @@ from langchain_community.embeddings import HuggingFaceEmbeddings # Updated impo
 from langchain_community.document_loaders import TextLoader # Updated import
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
-# Data Science and Visualization Libraries
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
-from matplotlib.gridspec import GridSpec
-from sklearn.metrics import classification_report, confusion_matrix
-
-# Development and Testing
-import pytest
-from unittest.mock import Mock, patch
 
 # External Tools and APIs
 import wandb
@@ -108,54 +95,90 @@ class ModelManager:
         gc.collect()
 
 class PearlyBot:
-    def __init__(self, model_zip_path: str = "./checkpoint-500.zip", model_dir: str = "./checkpoint-500"):
-        self.model_dir = ModelManager.verify_and_extract_model(model_zip_path, model_dir)
-        self.setup_model(self.model_dir)
-        self.setup_rag()
-        self.conversation_history = []
-        self.last_interaction_time = time.time()
-        self.interaction_cooldown = 1.0 # seconds
-
-    def setup_model(self, model_path: str):
-        """Initialize the model with proper error handling"""
+    def __init__(self):
         try:
-            logger.info("Starting model initialization...")
-            ModelManager.clear_gpu_memory()
+            # Use the correct model path from your space
+            self.repo_id = "Pearilsa/pearly_med_triage_chatbot_kagglex"
+            self.model_filename = "pearly_model.zip"
+            self.setup_model()
+            self.setup_rag()
+            self.conversation_history = []
+            self.last_interaction_time = time.time()
+            self.interaction_cooldown = 1.0
+        except Exception as e:
+            logger.error(f"Error initializing bot: {e}")
+            raise
+
+    def setup_model(self):
+        """Initialize model from Hugging Face space"""
+        try:
+            logger.info(f"Loading model from {self.repo_id}")
 
-            # Load tokenizer
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-                logger.info("Tokenizer loaded successfully")
-            except Exception as e:
-                logger.error(f"Failed to load tokenizer: {str(e)}")
-                raise
+            # Download and prepare model path
+            local_model_path = os.path.join(os.getcwd(), "models")
+            os.makedirs(local_model_path, exist_ok=True)
+
+            # Load tokenizer and model from the space
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.repo_id,
+                token=os.getenv("HF_TOKEN"),  # Use your Hugging Face token
+                cache_dir=local_model_path
+            )
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            logger.info("Tokenizer loaded successfully")
+
+            # Load model with 8-bit quantization
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.repo_id,
+                token=os.getenv("HF_TOKEN"),
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                cache_dir=local_model_path
+            )
+            self.model.eval()
+            logger.info("Model loaded successfully")
 
-            # Load model
-            try:
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    device_map="auto",
-                    load_in_8bit=True,
-                    torch_dtype=torch.float16,
-                    low_cpu_mem_usage=True
-                )
-                self.model.eval()
-                logger.info("Model loaded successfully")
-            except Exception as e:
-                logger.error(f"Failed to load model: {str(e)}")
-                raise
-
         except Exception as e:
             logger.error(f"Error in model setup: {str(e)}")
             raise
+
 
     def setup_rag(self):
         try:
-            logger.info("Setting up RAG system...")
+            # Add configuration options
+            self.chunk_size = 300
+            self.chunk_overlap = 100
+            self.num_relevant_chunks = 3
+
+            # Load knowledge base
+            knowledge_base = self._load_knowledge_base()
+
+            # Setup embeddings with error handling
+            self.embeddings = self._initialize_embeddings()
+
+            # Enhanced text splitting
+            texts = self._split_texts(knowledge_base)
+
+            # Create vector store with metadata
+            self.vector_store = FAISS.from_texts(
+                texts,
+                self.embeddings,
+                metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
+            )
+
+            # Add validation
+            self._validate_rag_setup()
+
+        except Exception as e:
+            logger.error(f"RAG setup failed: {str(e)}")
+            raise
             # Load your knowledge base content
-            knowledge_base = {
-                "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
+    def _load_knowledge_base(self):
+        # Add validation and error handling for knowledge base loading
+        return {
+            "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
 
 EMERGENCY (999) SCENARIOS:
 1. Cardiovascular:
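Note: both the old and the new `setup_model` pass `load_in_8bit=True` straight to `from_pretrained`; recent transformers releases deprecate that flag in favour of a `BitsAndBytesConfig`. A hedged sketch of the equivalent call (assumes the `bitsandbytes` and `accelerate` packages are installed, which `device_map="auto"` already requires):

    import os
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "Pearilsa/pearly_med_triage_chatbot_kagglex",
        token=os.getenv("HF_TOKEN"),
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # replaces load_in_8bit=True
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.eval()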
@@ -456,6 +479,48 @@ PROFESSIONAL CONDUCT:
             logger.error(f"Error setting up RAG: {str(e)}")
             raise
 
+    def _validate_rag_setup(self):
+        """Validate RAG system setup"""
+        try:
+            # Verify embeddings are working
+            test_text = "This is a test embedding"
+            test_embedding = self.embeddings.encode(test_text)
+            assert len(test_embedding) > 0
+
+            # Verify vector store is operational
+            test_results = self.vector_store.similarity_search(test_text, k=1)
+            assert len(test_results) > 0
+
+            logger.info("RAG system validation successful")
+            return True
+        except Exception as e:
+            logger.error(f"RAG system validation failed: {str(e)}")
+            raise
+
+    def _initialize_embeddings(self):
+        try:
+            return HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2",
+                cache_folder="./embeddings_cache"  # Added caching
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize embeddings: {str(e)}")
+            raise
+
+    def _split_texts(self, knowledge_base):
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            length_function=len,
+            add_start_index=True
+        )
+
+        all_texts = []
+        for content in knowledge_base.values():
+            texts = splitter.split_text(content)
+            all_texts.extend(texts)
+        return all_texts
+
     def get_relevant_context(self, query):
         try:
             docs = self.vector_store.similarity_search(query, k=3)
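One likely bug in the added `_validate_rag_setup`: LangChain's `HuggingFaceEmbeddings` exposes `embed_query`/`embed_documents`, not `encode` (that is a method of the underlying sentence-transformers model), so `self.embeddings.encode(test_text)` should raise `AttributeError`. A corrected sketch of the same check:

    def _validate_rag_setup(self):
        """Validate RAG system setup."""
        test_text = "This is a test embedding"
        # embed_query is the LangChain embeddings API; encode() is not exposed here
        test_embedding = self.embeddings.embed_query(test_text)
        assert len(test_embedding) > 0

        # Verify the vector store is operational
        test_results = self.vector_store.similarity_search(test_text, k=1)
        assert len(test_results) > 0

        logger.info("RAG system validation successful")
        return True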
@@ -576,28 +641,8 @@ Guidelines:
         except Exception as e:
             logger.error(f"Error in cleanup: {e}")
 
-    def process_feedback(positive: bool, comment: str, history: List[Dict[str, str]]):
-        try:
-            if not history or len(history) < 2:
-                return gr.update(value="")
-
-            last_user_msg = history[-2]["content"] if isinstance(history[-2], dict) else history[-2][0]
-            last_bot_msg = history[-1]["content"] if isinstance(history[-1], dict) else history[-1][1]
-
-            bot.handle_feedback(
-                message=last_user_msg,
-                response=last_bot_msg,
-                feedback=1 if positive else -1
-            )
-
-            return gr.update(value="")
-
-        except Exception as e:
-            logger.error(f"Error processing feedback: {e}")
-            return gr.update(value="")
 
 def create_demo():
-    """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
    try:
        # Initialize bot
        bot = PearlyBot()
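For context on this deletion: the module-level `process_feedback` referenced `bot`, a name that only exists inside `create_demo`, so calling it would have raised `NameError`; the surviving copy inside `create_demo` (see the later hunk) works because the nested function captures `bot` as a closure. A stripped-down sketch of the pattern:

    def create_demo():
        bot = PearlyBot()  # defined in create_demo's scope

        def process_feedback(positive, comment, history):
            # `bot` resolves through the enclosing scope, which the
            # deleted module-level version could not do
            bot.handle_feedback(
                message=history[-2]["content"],
                response=history[-1]["content"],
                feedback=1 if positive else -1,
            )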
@@ -608,10 +653,7 @@ def create_demo():
             if not message.strip():
                 return history
 
-            # Generate response
             response = bot.generate_response(message, history)
-
-            # Update history with proper formatting
             history.append({
                 "role": "user",
                 "content": message
@@ -621,7 +663,6 @@ def create_demo():
                 "content": response
             })
             return history
-
         except Exception as e:
             logger.error(f"Chat error: {e}")
             return history + [{
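`chat()` appends `{"role": ..., "content": ...}` dicts to the history, which Gradio only renders when the Chatbot component is created in messages mode; since the commit elides the component kwargs as `gr.Chatbot(...)`, the instantiation below is an assumption about what is intended:

    chatbot = gr.Chatbot(type="messages")  # dict-based history requires messages mode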
@@ -642,11 +683,57 @@ def create_demo():
                 response=last_bot_msg,
                 feedback=1 if positive else -1
             )
-
             return gr.update(value="")
         except Exception as e:
             logger.error(f"Error processing feedback: {e}")
             return gr.update(value="")
+
+        # Create Gradio interface
+        with gr.Blocks(theme=gr.themes.Soft(...)) as demo:
+            # 1. First, create all UI elements
+            # CSS styles
+            gr.HTML("""<style>...""")
+
+            # Emergency Banner
+            gr.HTML("""<div class="emergency-banner">...""")
+
+            # Header
+            with gr.Row(elem_classes="header"):
+                gr.Markdown("""# GP Medical Triage Assistant...""")
+
+            # Features Grid
+            gr.HTML("""<div class="features-grid">...""")
+
+            # Chat Interface
+            with gr.Row():
+                with gr.Column(scale=4):
+                    chatbot = gr.Chatbot(...)
+                    with gr.Row():
+                        msg = gr.Textbox(...)
+                        submit = gr.Button(...)
+
+                with gr.Column(scale=1):
+                    # Quick Actions
+                    emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
+                    nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
+                    booking_btn = gr.Button("📅 GP Booking", variant="secondary")
+
+                    # Controls
+                    clear = gr.Button("🗑️ Clear Chat")
+
+                    # Feedback
+                    with gr.Row():
+                        feedback_positive = gr.Button("👍", elem_id="thumb-up")
+                        feedback_negative = gr.Button("👎", elem_id="thumb-down")
+                    feedback_text = gr.Textbox(...)
+                    feedback_submit = gr.Button(...)
+
+            # Examples and Guide
+            with gr.Accordion("Example Messages", open=False):
+                gr.Examples([...])
+
+            with gr.Accordion("NHS Services Guide", open=False):
+                gr.Markdown("""...""")
 
 
         # Create enhanced Gradio interface
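The restructured `create_demo` added above follows the ordering Gradio requires: create components first, wire events while the `gr.Blocks` context is still open, queue last. A self-contained sketch of that skeleton (options the commit elides with `...` are left at defaults; `type="messages"` is assumed to match the dict history used by `chat`):

    import gradio as gr

    def create_demo():
        with gr.Blocks() as demo:
            # 1. create all UI elements first
            chatbot = gr.Chatbot(type="messages")
            msg = gr.Textbox()
            submit = gr.Button("Send")

            def chat(message, history):
                history = history or []
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": "stub reply"},
                ]

            # 2. wire events while the Blocks context is open
            submit.click(chat, [msg, chatbot], [chatbot])

            # 3. finally, add the queue
            demo.queue(max_size=10)
        return demo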
@@ -711,33 +798,6 @@ def create_demo():
            }
            </style>
        """)
-        # Event Handlers - Moved inside the gr.Blocks context
-        msg.submit(chat, [msg, chatbot], [chatbot]).then(
-            lambda: gr.update(value=""), None, [msg]
-        )
-
-        submit.click(chat, [msg, chatbot], [chatbot]).then(
-            lambda: gr.update(value=""), None, [msg]
-        )
-
-        # Feedback handlers
-        feedback_positive.click(
-            lambda h: process_feedback(True, feedback_text.value, h),
-            inputs=[chatbot],
-            outputs=[feedback_text]
-        )
-
-        feedback_negative.click(
-            lambda h: process_feedback(False, feedback_text.value, h),
-            inputs=[chatbot],
-            outputs=[feedback_text]
-        )
-
-        # Clear chat
-        clear.click(lambda: None, None, chatbot)
-
-        # Add queue for handling multiple users
-        demo.queue(concurrency_count=1, max_size=10)
 
        # Emergency Banner
        gr.HTML("""
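`demo.queue(concurrency_count=1, max_size=10)`, deleted here and re-added at the end of the file, uses Gradio 3.x syntax; Gradio 4 removed `concurrency_count`, so the call needs renaming if the Space's pinned version is ever bumped (assumption: the Space currently pins gradio 3.x):

    # Gradio 3.x, as committed:
    demo.queue(concurrency_count=1, max_size=10)
    # Gradio 4.x equivalent:
    demo.queue(default_concurrency_limit=1, max_size=10)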
@@ -854,6 +914,64 @@ def create_demo():
 
 
 
+        def show_emergency_info():
+            return """🚨 Emergency Services (999)
+            - For life-threatening emergencies
+            - Severe chest pain
+            - Difficulty breathing
+            - Severe bleeding
+            - Loss of consciousness
+            """
+
+        def show_nhs_111_info():
+            return """📞 NHS 111 Service
+            - Available 24/7
+            - Medical advice
+            - Local service information
+            - Urgent care guidance
+            """
+
+        def show_booking_info():
+            return """📅 GP Booking Options
+            - Online booking
+            - Phone booking
+            - Routine appointments
+            - Urgent appointments
+            """
+
+        # Chat handlers
+        msg.submit(chat, [msg, chatbot], [chatbot]).then(
+            lambda: gr.update(value=""), None, [msg]
+        )
+
+        submit.click(chat, [msg, chatbot], [chatbot]).then(
+            lambda: gr.update(value=""), None, [msg]
+        )
+
+        # Quick action handlers
+        emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
+        nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
+        booking_btn.click(lambda: show_booking_info(), outputs=[msg])
+
+        # Feedback handlers
+        feedback_positive.click(
+            lambda h: process_feedback(True, feedback_text.value, h),
+            inputs=[chatbot],
+            outputs=[feedback_text]
+        )
+
+        feedback_negative.click(
+            lambda h: process_feedback(False, feedback_text.value, h),
+            inputs=[chatbot],
+            outputs=[feedback_text]
+        )
+
+        # Clear chat
+        clear.click(lambda: None, None, chatbot)
+
+        # 3. Finally, add the queue
+        demo.queue(concurrency_count=1, max_size=10)
+
        return demo
 
    except Exception as e:
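Minor cleanup available in the quick-action handlers added above: the zero-argument `lambda` wrappers are redundant, since the named functions can be passed to `.click()` directly:

    emergency_btn.click(show_emergency_info, outputs=[msg])
    nhs_111_btn.click(show_nhs_111_info, outputs=[msg])
    booking_btn.click(show_booking_info, outputs=[msg])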