Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -19,9 +19,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
 from torch.utils.data import DataLoader
-import tensorflow as tf
-import keras
-import numpy as np

 # Hugging Face and Transformers
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
@@ -39,16 +36,6 @@ from langchain_community.embeddings import HuggingFaceEmbeddings # Updated impo
 from langchain_community.document_loaders import TextLoader  # Updated import
 from langchain.text_splitter import RecursiveCharacterTextSplitter

-# Data Science and Visualization Libraries
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
-from matplotlib.gridspec import GridSpec
-from sklearn.metrics import classification_report, confusion_matrix
-
-# Development and Testing
-import pytest
-from unittest.mock import Mock, patch

 # External Tools and APIs
 import wandb
@@ -108,54 +95,90 @@ class ModelManager:
         gc.collect()

 class PearlyBot:
-    def __init__(self, model_zip_path, model_dir):
-        self.model_dir = ModelManager.verify_and_extract_model(model_zip_path, model_dir)
-        self.setup_model(self.model_dir)
-        self.setup_rag()
-        self.conversation_history = []
-        self.last_interaction_time = time.time()
-        self.interaction_cooldown = 1.0  # seconds
-
-    def setup_model(self, model_path: str):
-        """Initialize the model with proper error handling"""
+    def __init__(self):
         try:
-
-
+            # Use the correct model path from your space
+            self.repo_id = "Pearilsa/pearly_med_triage_chatbot_kagglex"
+            self.model_filename = "pearly_model.zip"
+            self.setup_model()
+            self.setup_rag()
+            self.conversation_history = []
+            self.last_interaction_time = time.time()
+            self.interaction_cooldown = 1.0
+        except Exception as e:
+            logger.error(f"Error initializing bot: {e}")
+            raise
+
+    def setup_model(self):
+        """Initialize model from Hugging Face space"""
+        try:
+            logger.info(f"Loading model from {self.repo_id}")

-            #
-
-
-
-
-
-
-
+            # Download and prepare model path
+            local_model_path = os.path.join(os.getcwd(), "models")
+            os.makedirs(local_model_path, exist_ok=True)
+
+            # Load tokenizer and model from the space
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.repo_id,
+                token=os.getenv("HF_TOKEN"),  # Use your Hugging Face token
+                cache_dir=local_model_path
+            )
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            logger.info("Tokenizer loaded successfully")
+
+            # Load model with 8-bit quantization
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.repo_id,
+                token=os.getenv("HF_TOKEN"),
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                cache_dir=local_model_path
+            )
+            self.model.eval()
+            logger.info("Model loaded successfully")

-            # Load model
-            try:
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    device_map="auto",
-                    load_in_8bit=True,
-                    torch_dtype=torch.float16,
-                    low_cpu_mem_usage=True
-                )
-                self.model.eval()
-                logger.info("Model loaded successfully")
-            except Exception as e:
-                logger.error(f"Failed to load model: {str(e)}")
-                raise
-
         except Exception as e:
             logger.error(f"Error in model setup: {str(e)}")
             raise
+

     def setup_rag(self):
         try:
-
+            # Add configuration options
+            self.chunk_size = 300
+            self.chunk_overlap = 100
+            self.num_relevant_chunks = 3
+
+            # Load knowledge base
+            knowledge_base = self._load_knowledge_base()
+
+            # Setup embeddings with error handling
+            self.embeddings = self._initialize_embeddings()
+
+            # Enhanced text splitting
+            texts = self._split_texts(knowledge_base)
+
+            # Create vector store with metadata
+            self.vector_store = FAISS.from_texts(
+                texts,
+                self.embeddings,
+                metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
+            )
+
+            # Add validation
+            self._validate_rag_setup()
+
+        except Exception as e:
+            logger.error(f"RAG setup failed: {str(e)}")
+            raise
             # Load your knowledge base content
-
-
+    def _load_knowledge_base(self):
+        # Add validation and error handling for knowledge base loading
+        return {
+            "triage_scenarios.txt": """Medical Triage Scenarios and Responses:

 EMERGENCY (999) SCENARIOS:
 1. Cardiovascular:
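Note on the 8-bit load added above: recent transformers releases deprecate passing load_in_8bit directly to from_pretrained in favour of a BitsAndBytesConfig. A minimal sketch of the equivalent call, assuming bitsandbytes is installed and reusing the repo id from the diff:

    import os
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    repo_id = "Pearilsa/pearly_med_triage_chatbot_kagglex"
    quant_config = BitsAndBytesConfig(load_in_8bit=True)  # requires bitsandbytes + CUDA

    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=os.getenv("HF_TOKEN"))
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        token=os.getenv("HF_TOKEN"),
        device_map="auto",
        quantization_config=quant_config,  # replaces the bare load_in_8bit=True
        low_cpu_mem_usage=True,
    )
    model.eval()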
@@ -456,6 +479,48 @@ PROFESSIONAL CONDUCT:
             logger.error(f"Error setting up RAG: {str(e)}")
             raise

+    def _validate_rag_setup(self):
+        """Validate RAG system setup"""
+        try:
+            # Verify embeddings are working
+            test_text = "This is a test embedding"
+            test_embedding = self.embeddings.encode(test_text)
+            assert len(test_embedding) > 0
+
+            # Verify vector store is operational
+            test_results = self.vector_store.similarity_search(test_text, k=1)
+            assert len(test_results) > 0
+
+            logger.info("RAG system validation successful")
+            return True
+        except Exception as e:
+            logger.error(f"RAG system validation failed: {str(e)}")
+            raise
+
+    def _initialize_embeddings(self):
+        try:
+            return HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2",
+                cache_folder="./embeddings_cache"  # Added caching
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize embeddings: {str(e)}")
+            raise
+
+    def _split_texts(self, knowledge_base):
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            length_function=len,
+            add_start_index=True
+        )
+
+        all_texts = []
+        for content in knowledge_base.values():
+            texts = splitter.split_text(content)
+            all_texts.extend(texts)
+        return all_texts
+
     def get_relevant_context(self, query):
         try:
             docs = self.vector_store.similarity_search(query, k=3)
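One likely source of this Space's runtime error: LangChain's HuggingFaceEmbeddings wrapper exposes embed_query()/embed_documents() rather than encode(), so the _validate_rag_setup added above would raise AttributeError on its first embedding call. A sketch of the same validation against the wrapper's actual API:

    def _validate_rag_setup(self):
        """Validate RAG system setup via the LangChain embeddings API."""
        test_text = "This is a test embedding"

        # embed_query returns a list[float]; encode() does not exist on the wrapper
        test_embedding = self.embeddings.embed_query(test_text)
        assert len(test_embedding) > 0

        # Verify the vector store answers similarity queries
        test_results = self.vector_store.similarity_search(test_text, k=1)
        assert len(test_results) > 0

        logger.info("RAG system validation successful")
        return True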
@@ -576,28 +641,8 @@ Guidelines:
         except Exception as e:
             logger.error(f"Error in cleanup: {e}")

-def process_feedback(positive: bool, comment: str, history: List[Dict[str, str]]):
-    try:
-        if not history or len(history) < 2:
-            return gr.update(value="")
-
-        last_user_msg = history[-2]["content"] if isinstance(history[-2], dict) else history[-2][0]
-        last_bot_msg = history[-1]["content"] if isinstance(history[-1], dict) else history[-1][1]
-
-        bot.handle_feedback(
-            message=last_user_msg,
-            response=last_bot_msg,
-            feedback=1 if positive else -1
-        )
-
-        return gr.update(value="")
-
-    except Exception as e:
-        logger.error(f"Error processing feedback: {e}")
-        return gr.update(value="")

 def create_demo():
-    """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
     try:
         # Initialize bot
         bot = PearlyBot()
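The deleted module-level process_feedback referenced a name bot that only exists inside create_demo(), which would raise NameError at call time; the commit re-adds the function inside create_demo (visible as context lines in the hunks below), where it closes over the local instance. A minimal sketch of that pattern, assuming the dict-style history used elsewhere in the diff:

    def create_demo():
        bot = PearlyBot()  # local to create_demo

        def process_feedback(positive: bool, comment: str, history: list):
            # closes over `bot` above instead of relying on a global
            if not history or len(history) < 2:
                return gr.update(value="")
            last_user_msg = history[-2]["content"]
            last_bot_msg = history[-1]["content"]
            bot.handle_feedback(message=last_user_msg, response=last_bot_msg,
                                feedback=1 if positive else -1)
            return gr.update(value="")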
@@ -608,10 +653,7 @@ def create_demo():
             if not message.strip():
                 return history

-            # Generate response
             response = bot.generate_response(message, history)
-
-            # Update history with proper formatting
             history.append({
                 "role": "user",
                 "content": message
@@ -621,7 +663,6 @@ def create_demo():
                 "content": response
             })
             return history
-
         except Exception as e:
             logger.error(f"Chat error: {e}")
             return history + [{
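The role/content dicts appended here follow Gradio's openai-style "messages" history format, which only renders if the Chatbot component is created with type="messages" (an assumption about the Gradio version; the component arguments are elided in this diff). A sketch of a compatible handler, mirroring the diff's ordering:

    def chat(message: str, history: list) -> list:
        history = history or []
        response = bot.generate_response(message, history)  # `bot` from create_demo scope
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history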
@@ -642,11 +683,57 @@ def create_demo():
                 response=last_bot_msg,
                 feedback=1 if positive else -1
             )
-
             return gr.update(value="")
         except Exception as e:
             logger.error(f"Error processing feedback: {e}")
             return gr.update(value="")
+
+        # Create Gradio interface
+        with gr.Blocks(theme=gr.themes.Soft(...)) as demo:
+            # 1. First, create all UI elements
+            # CSS styles
+            gr.HTML("""<style>...""")
+
+            # Emergency Banner
+            gr.HTML("""<div class="emergency-banner">...""")
+
+            # Header
+            with gr.Row(elem_classes="header"):
+                gr.Markdown("""# GP Medical Triage Assistant...""")
+
+            # Features Grid
+            gr.HTML("""<div class="features-grid">...""")
+
+            # Chat Interface
+            with gr.Row():
+                with gr.Column(scale=4):
+                    chatbot = gr.Chatbot(...)
+                    with gr.Row():
+                        msg = gr.Textbox(...)
+                        submit = gr.Button(...)
+
+                with gr.Column(scale=1):
+                    # Quick Actions
+                    emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
+                    nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
+                    booking_btn = gr.Button("📅 GP Booking", variant="secondary")
+
+                    # Controls
+                    clear = gr.Button("🗑️ Clear Chat")
+
+                    # Feedback
+                    with gr.Row():
+                        feedback_positive = gr.Button("👍", elem_id="thumb-up")
+                        feedback_negative = gr.Button("👎", elem_id="thumb-down")
+                    feedback_text = gr.Textbox(...)
+                    feedback_submit = gr.Button(...)
+
+            # Examples and Guide
+            with gr.Accordion("Example Messages", open=False):
+                gr.Examples([...])
+
+            with gr.Accordion("NHS Services Guide", open=False):
+                gr.Markdown("""...""")


         # Create enhanced Gradio interface
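The skeleton above appears to contain literal `...` placeholders (e.g. gr.Chatbot(...)), which Python parses as Ellipsis arguments and which would fail at component construction; this is another runtime-error candidate. A minimal concrete version of the elided layout, with illustrative (not the Space's actual) arguments:

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(label="Pearly", height=500)
                with gr.Row():
                    msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4)
                    submit = gr.Button("Send", variant="primary")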
@@ -711,33 +798,6 @@ def create_demo():
         }
         </style>
         """)
-        # Event Handlers - Moved inside the gr.Blocks context
-        msg.submit(chat, [msg, chatbot], [chatbot]).then(
-            lambda: gr.update(value=""), None, [msg]
-        )
-
-        submit.click(chat, [msg, chatbot], [chatbot]).then(
-            lambda: gr.update(value=""), None, [msg]
-        )
-
-        # Feedback handlers
-        feedback_positive.click(
-            lambda h: process_feedback(True, feedback_text.value, h),
-            inputs=[chatbot],
-            outputs=[feedback_text]
-        )
-
-        feedback_negative.click(
-            lambda h: process_feedback(False, feedback_text.value, h),
-            inputs=[chatbot],
-            outputs=[feedback_text]
-        )
-
-        # Clear chat
-        clear.click(lambda: None, None, chatbot)
-
-        # Add queue for handling multiple users
-        demo.queue(concurrency_count=1, max_size=10)

         # Emergency Banner
         gr.HTML("""
@@ -854,6 +914,64 @@ def create_demo():



+            def show_emergency_info():
+                return """🚨 Emergency Services (999)
+                - For life-threatening emergencies
+                - Severe chest pain
+                - Difficulty breathing
+                - Severe bleeding
+                - Loss of consciousness
+                """
+
+            def show_nhs_111_info():
+                return """📞 NHS 111 Service
+                - Available 24/7
+                - Medical advice
+                - Local service information
+                - Urgent care guidance
+                """
+
+            def show_booking_info():
+                return """📅 GP Booking Options
+                - Online booking
+                - Phone booking
+                - Routine appointments
+                - Urgent appointments
+                """
+
+            # Chat handlers
+            msg.submit(chat, [msg, chatbot], [chatbot]).then(
+                lambda: gr.update(value=""), None, [msg]
+            )
+
+            submit.click(chat, [msg, chatbot], [chatbot]).then(
+                lambda: gr.update(value=""), None, [msg]
+            )
+
+            # Quick action handlers
+            emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
+            nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
+            booking_btn.click(lambda: show_booking_info(), outputs=[msg])
+
+            # Feedback handlers
+            feedback_positive.click(
+                lambda h: process_feedback(True, feedback_text.value, h),
+                inputs=[chatbot],
+                outputs=[feedback_text]
+            )
+
+            feedback_negative.click(
+                lambda h: process_feedback(False, feedback_text.value, h),
+                inputs=[chatbot],
+                outputs=[feedback_text]
+            )
+
+            # Clear chat
+            clear.click(lambda: None, None, chatbot)
+
+            # 3. Finally, add the queue
+            demo.queue(concurrency_count=1, max_size=10)
+
         return demo

     except Exception as e:
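The re-added demo.queue(concurrency_count=1, max_size=10) is a further runtime-error candidate: Gradio 4.x removed the concurrency_count parameter from queue(). If the Space runs Gradio 4 (an assumption; the version is not shown in this diff), the rough equivalent is:

    # Gradio 4.x sketch: per-event concurrency replaces concurrency_count
    demo.queue(default_concurrency_limit=1, max_size=10)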