# Standard imports first
import os
import torch
import logging
from datetime import datetime
from huggingface_hub import login
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
BitsAndBytesConfig
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training
)
from tqdm.auto import tqdm
import time
import gradio as gr
# RAG components used below; these import paths assume the langchain / langchain-community split (adjust for older langchain versions)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecretsManager:
"""Handles authentication and secrets management"""
@staticmethod
def setup_credentials():
"""Setup all required credentials"""
try:
# Load environment variables
load_dotenv()
# Get credentials
credentials = {
'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
'HF_TOKEN': os.getenv('HF_TOKEN'),
'WANDB_KEY': os.getenv('WANDB_KEY')
}
# Validate credentials
missing_creds = [k for k, v in credentials.items() if not v]
if missing_creds:
logger.warning(f"Missing credentials: {', '.join(missing_creds)}")
# Setup Hugging Face authentication
if credentials['HF_TOKEN']:
login(token=credentials['HF_TOKEN'])
logger.info("Successfully logged in to Hugging Face")
# Setup Kaggle credentials if available
if credentials['KAGGLE_USERNAME'] and credentials['KAGGLE_KEY']:
os.environ['KAGGLE_USERNAME'] = credentials['KAGGLE_USERNAME']
os.environ['KAGGLE_KEY'] = credentials['KAGGLE_KEY']
# Setup wandb if available
if credentials['WANDB_KEY']:
os.environ['WANDB_API_KEY'] = credentials['WANDB_KEY']
return credentials
except Exception as e:
logger.error(f"Error setting up credentials: {e}")
raise
class ModelTrainer:
"""Handles model training pipeline"""
def __init__(self):
# Set memory optimization environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
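# Note: CUDA_LAUNCH_BLOCKING=1 serializes GPU kernel launches, which makes CUDA errors easier to trace but slows training; it can be dropped once the pipeline is stable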
# Initialize attributes
self.model = None
self.tokenizer = None
self.dataset = None
self.processed_dataset = None
self.chunk_size = 300
self.chunk_overlap = 100
self.num_relevant_chunks = 3
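# Retrieval settings: the knowledge base is split into ~300-character chunks with 100 characters of overlap, and the top 3 matches are injected into the prompt (see _split_texts and _get_enhanced_context)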
self.vector_store = None
self.embeddings = None
self.last_interaction_time = time.time()  # rate-limiting state used by generate_response
self.interaction_cooldown = 1.0  # minimum seconds between generations
# Setup GPU preferences
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
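# TF32 is disabled so matmuls run in full float32, trading a little speed for reproducible numerics across runs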
def prepare_initial_datasets(self, batch_size=8):
print("Loading datasets with memory-optimized batch processing...")
def process_medqa_batch(examples):
results = []
inputs = examples['input']
instructions = examples['instruction']
outputs = examples['output']
for inp, inst, out in zip(inputs, instructions, outputs):
results.append({
"input": f"{inp} {inst}",
"output": out
})
return results
def process_meddia_batch(examples):
results = []
inputs = examples['input']
outputs = examples['output']
for inp, out in zip(inputs, outputs):
results.append({
"input": inp,
"output": out
})
return results
def process_persona_batch(examples):
results = []
personalities = examples['personality']
utterances = examples['utterances']
for pers, utts in zip(personalities, utterances):
try:
# Process personality list
personality = ' '.join([
p for p in pers
if isinstance(p, str)
])
# Process utterances
if utts and len(utts) > 0:
utterance = utts[0]
history = []
history_text = ""
# Process history
if 'history' in utterance and utterance['history']:
history = [
h for h in utterance['history']
if isinstance(h, str)
]
history_text = ' '.join(history)
# Get candidate response
candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
if personality or history_text:
results.append({
"input": f"{personality} {history_text}".strip(),
"output": candidate
})
except Exception as e:
print(f"Error processing persona batch item: {e}")
continue
return results
try:
# Load and process each dataset separately
print("Processing MedQA dataset...")
medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
medqa_processed = []
for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
batch = medqa[i:i + batch_size]
medqa_processed.extend(process_medqa_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
print("Processing MedDiagnosis dataset...")
meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
meddia_processed = []
for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
batch = meddia[i:i + batch_size]
meddia_processed.extend(process_meddia_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
print("Processing Persona-Chat dataset...")
persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
persona_processed = []
for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
batch = persona[i:i + batch_size]
persona_processed.extend(process_persona_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
torch.cuda.empty_cache()
print("Creating final dataset...")
all_processed = persona_processed + medqa_processed + meddia_processed
valid_data = {
"input": [],
"output": []
}
for item in all_processed:
if item["input"].strip() and item["output"].strip():
valid_data["input"].append(item["input"])
valid_data["output"].append(item["output"])
final_dataset = Dataset.from_dict(valid_data)
print(f"Final dataset size: {len(final_dataset)}")
return final_dataset
except Exception as e:
logger.error(f"Error preparing initial datasets: {e}")
raise
def prepare_dataset(self, dataset, tokenizer, max_length=256, batch_size=4):
def tokenize_batch(examples):
formatted_texts = []
for i in range(0, len(examples['input']), batch_size):
sub_batch_inputs = examples['input'][i:i + batch_size]
sub_batch_outputs = examples['output'][i:i + batch_size]
for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
try:
formatted_text = f"""<start_of_turn>user
{input_text}
<end_of_turn>
<start_of_turn>assistant
{output_text}
<end_of_turn>"""
formatted_texts.append(formatted_text)
except Exception as e:
print(f"Error formatting text: {e}")
continue
tokenized = tokenizer(
formatted_texts,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors=None
)
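# For causal-LM fine-tuning the labels mirror the input ids; padding positions could additionally be masked to -100 so they are ignored by the loss (not done here)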
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
print(f"Tokenizing dataset in small batches (size={batch_size})...")
tokenized_dataset = dataset.map(
tokenize_batch,
batched=True,
batch_size=batch_size,
remove_columns=dataset.column_names,
desc="Tokenizing dataset",
load_from_cache_file=False
)
return tokenized_dataset
def setup_rag(self):
"""Initialize RAG components"""
try:
logger.info("Setting up RAG system...")
# Load knowledge base
knowledge_base = self._load_knowledge_base()
# Setup embeddings
self.embeddings = self._initialize_embeddings()
# Process texts for vector store
texts = self._split_texts(knowledge_base)
# Create vector store with metadata
self.vector_store = FAISS.from_texts(
texts,
self.embeddings,
metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
)
# Validate RAG setup
self._validate_rag_setup()
logger.info("RAG system setup complete")
except Exception as e:
logger.error(f"Failed to setup RAG: {e}")
raise
# Load your knowledge base content
def _load_knowledge_base(self):
"""Load and validate knowledge base content"""
try:
knowledge_base = {
"triage_scenarios.txt": """Medical Triage Scenarios and Responses:
EMERGENCY (999) SCENARIOS:
1. Cardiovascular:
- Chest pain/pressure
- Heart attack symptoms
- Irregular heartbeat with dizziness
Response: Immediate 999 call, sit/lie down, chew aspirin if available
2. Respiratory:
- Severe breathing difficulty
- Choking
- Unable to speak full sentences
Response: 999, sitting position, clear airway
3. Neurological:
- Stroke symptoms (FAST)
- Seizures
- Unconsciousness
Response: 999, recovery position if unconscious
4. Trauma:
- Severe bleeding
- Head injuries with confusion
- Major burns
Response: 999, apply direct pressure to bleeding
URGENT CARE (111) SCENARIOS:
1. Moderate Symptoms:
- Persistent fever
- Non-severe infections
- Minor injuries
Response: 111 contact, monitor symptoms
2. Minor Emergencies:
- Small cuts needing stitches
- Sprains and strains
- Mild allergic reactions
Response: 111 or urgent care visit
GP APPOINTMENT SCENARIOS:
1. Routine Care:
- Chronic condition review
- Medication reviews
- Non-urgent symptoms
Response: Book routine GP appointment
2. Preventive Care:
- Vaccinations
- Health screenings
- Regular check-ups
Response: Schedule with GP reception""",
"emergency_detection.txt": """Enhanced Emergency Detection Criteria:
IMMEDIATE LIFE THREATS:
1. Cardiac Symptoms:
- Chest pain/pressure/tightness
- Pain spreading to arms/jaw/neck
- Sweating with nausea
- Shortness of breath
2. Breathing Problems:
- Severe shortness of breath
- Blue lips or face
- Unable to complete sentences
- Choking/airway blockage
3. Neurological:
- FAST (Face, Arms, Speech, Time)
- Sudden confusion
- Severe headache
- Seizures
- Loss of consciousness
4. Severe Trauma:
- Heavy bleeding
- Deep wounds
- Head injury with confusion
- Severe burns
- Broken bones with deformity
5. Anaphylaxis:
- Sudden swelling
- Difficulty breathing
- Rapid onset rash
- Light-headedness
URGENT BUT NOT IMMEDIATE:
1. Moderate Symptoms:
- Persistent fever
- Dehydration
- Non-severe infections
- Minor injuries
2. Worsening Conditions:
- Increasing pain
- Progressive symptoms
- Medication reactions
RESPONSE PROTOCOLS:
1. For Life Threats:
- Immediate 999 call
- Clear first aid instructions
- Stay on line until help arrives
2. For Urgent Care:
- 111 contact
- Monitor for worsening
- Document symptoms""",
"gp_booking.txt": """GP Appointment Booking Templates:
APPOINTMENT TYPES:
1. Routine Appointments:
Template: "I need to book a routine appointment for [condition]. My availability is [times/dates]. My GP is Dr. [name] if available."
2. Follow-up Appointments:
Template: "I need a follow-up appointment regarding [condition] discussed on [date]. My previous appointment was with Dr. [name]."
3. Medication Reviews:
Template: "I need a medication review for [medication]. My last review was [date]."
BOOKING INFORMATION NEEDED:
1. Patient Details:
- Full name
- Date of birth
- NHS number (if known)
- Registered GP practice
2. Appointment Details:
- Nature of appointment
- Preferred times/dates
- Urgency level
- Special requirements
3. Contact Information:
- Phone number
- Alternative contact
- Preferred contact method
BOOKING PROCESS:
1. Online Booking:
- NHS app instructions
- Practice website guidance
- System navigation help
2. Phone Booking:
- Best times to call
- Required information
- Queue management tips
3. Special Circumstances:
- Interpreter needs
- Accessibility requirements
- Transport arrangements""",
"cultural_sensitivity.txt": """Cultural Sensitivity Guidelines:
CULTURAL AWARENESS:
1. Religious Considerations:
- Prayer times
- Religious observations
- Dietary restrictions
- Gender preferences for care
- Religious festivals/fasting periods
2. Language Support:
- Interpreter services
- Multi-language resources
- Clear communication methods
- Family involvement preferences
3. Cultural Beliefs:
- Traditional medicine practices
- Cultural health beliefs
- Family decision-making
- Privacy customs
COMMUNICATION APPROACHES:
1. Respectful Interaction:
- Use preferred names/titles
- Appropriate greetings
- Non-judgmental responses
- Active listening
2. Language Usage:
- Clear, simple terms
- Avoid medical jargon
- Confirm understanding
- Respect silence/pauses
3. Non-verbal Communication:
- Eye contact customs
- Personal space
- Body language awareness
- Gesture sensitivity
SPECIFIC CONSIDERATIONS:
1. South Asian Communities:
- Family involvement
- Gender sensitivity
- Traditional medicine
- Language diversity
2. Middle Eastern Communities:
- Gender-specific care
- Religious observations
- Family hierarchies
- Privacy concerns
3. African/Caribbean Communities:
- Traditional healers
- Community involvement
- Historical medical mistrust
- Culture-specific conditions
4. Eastern European Communities:
- Direct communication
- Family involvement
- Medical documentation
- Language support
INCLUSIVE PRACTICES:
1. Appointment Scheduling:
- Religious holidays
- Prayer times
- Family availability
- Interpreter needs
2. Treatment Planning:
- Cultural preferences
- Traditional practices
- Family involvement
- Dietary requirements
3. Support Services:
- Community resources
- Cultural organizations
- Language services
- Social support""",
"service_boundaries.txt": """Service Limitations and Professional Boundaries:
CLEAR BOUNDARIES:
1. Medical Advice:
- No diagnoses
- No prescriptions
- No treatment recommendations
- No medical procedures
- No second opinions
2. Emergency Services:
- Clear referral criteria
- Documented responses
- Follow-up protocols
- Handover procedures
3. Information Sharing:
- Confidentiality limits
- Data protection
- Record keeping
- Information governance
PROFESSIONAL CONDUCT:
1. Communication:
- Professional language
- Emotional boundaries
- Personal distance
- Service scope
2. Service Delivery:
- No financial transactions
- No personal relationships
- Clear role definition
- Professional limits"""
}
# Create knowledge base directory
os.makedirs("knowledge_base", exist_ok=True)
# Write files and process documents
documents = []
for filename, content in knowledge_base.items():
filepath = os.path.join("knowledge_base", filename)
with open(filepath, "w", encoding="utf-8") as f:
f.write(content)
documents.append(content)
logger.info(f"Written knowledge base file: {filename}")
return knowledge_base
except Exception as e:
logger.error(f"Error loading knowledge base: {str(e)}")
raise
def _validate_rag_setup(self):
"""Validate RAG system setup"""
try:
# Verify embeddings are working
test_text = "This is a test embedding"
test_embedding = self.embeddings.embed_query(test_text)
assert len(test_embedding) > 0
# Verify vector store is operational
test_results = self.vector_store.similarity_search(test_text, k=1)
assert len(test_results) > 0
logger.info("RAG system validation successful")
return True
except Exception as e:
logger.error(f"RAG system validation failed: {str(e)}")
raise
def setup_model_and_tokenizer(self, model_name="google/gemma-2b"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# 8-bit quantization keeps the base weights small; FP32 CPU offload catches layers that do not fit on the GPU
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=True
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
model = prepare_model_for_kbit_training(model)
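# prepare_model_for_kbit_training freezes the quantized base weights, upcasts layer norms to fp32 for numerical stability, and enables input gradients so gradient checkpointing works with the LoRA adapters added next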
lora_config = LoraConfig(
r=4,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
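# With rank-4 LoRA adapters on q_proj/v_proj only, the trainable parameters reported above are typically a small fraction of a percent of Gemma-2B, which is what makes single-GPU fine-tuning feasible here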
return model, tokenizer
def setup_training_arguments(self, output_dir="./pearly_fine_tuned"):
return TrainingArguments(
output_dir=output_dir,
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
warmup_steps=50,
logging_steps=10,
save_steps=200,
learning_rate=2e-4,
fp16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="adamw_8bit",
max_grad_norm=0.3,
weight_decay=0.001,
logging_dir="./logs",
save_total_limit=2,
remove_unused_columns=False,
dataloader_pin_memory=False,
max_steps=500,
report_to=["none"],
)
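# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps = 1 * 16 = 16 sequences per optimizer step; with max_steps=500 this caps a run at roughly 8,000 training examples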
def train(self):
"""Main training pipeline with RAG integration"""
try:
logger.info("Starting training pipeline")
# Clear GPU memory
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Setup model, tokenizer, and RAG
logger.info("Setting up model components...")
self.model, self.tokenizer = self.setup_model_and_tokenizer()
self.setup_rag()
# Prepare and process datasets
logger.info("Preparing datasets...")
self.dataset = self.prepare_initial_datasets(batch_size=4)
self.processed_dataset = self.prepare_dataset(
self.dataset,
self.tokenizer,
max_length=256,
batch_size=2
)
# Train model
logger.info("Starting training...")
training_args = self.setup_training_arguments()
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=self.processed_dataset,
tokenizer=self.tokenizer
)
trainer.train()
# Save and push to hub
logger.info("Saving model...")
trainer.save_model()
if os.getenv('HF_TOKEN'):
trainer.push_to_hub(
"Pearilsa/pearly_med_triage_chatbot_kagglex",
private=True
)
logger.info("Training completed successfully!")
except Exception as e:
logger.error(f"Training failed: {e}")
raise
finally:
torch.cuda.empty_cache()
if __name__ == "__main__":
# Initialize trainer
trainer = ModelTrainer()
# Train model
trainer.train()
def _get_enhanced_context(self, query: str) -> str:
"""Get relevant context with scores"""
try:
# Get documents with similarity scores
docs_and_scores = self.vector_store.similarity_search_with_score(
query,
k=self.num_relevant_chunks
)
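# similarity_search_with_score returns (Document, score) pairs; with LangChain's default FAISS index the score is an L2 distance, so smaller values mean closer matches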
# Filter and format relevant contexts
relevant_contexts = []
for doc, score in docs_and_scores:
if score < 0.8: # Lower score means more relevant
source = doc.metadata.get('source', 'Unknown')
relevant_contexts.append(
f"[Source: {source}]\n{doc.page_content}"
)
return "\n\n".join(relevant_contexts) if relevant_contexts else ""
except Exception as e:
logger.error(f"Error retrieving enhanced context: {e}")
return ""
def _initialize_embeddings(self):
try:
return HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
cache_folder="./embeddings_cache" # Added caching
)
except Exception as e:
logger.error(f"Failed to initialize embeddings: {str(e)}")
raise
def _split_texts(self, knowledge_base):
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
length_function=len,
add_start_index=True
)
all_texts = []
for content in knowledge_base.values():
texts = splitter.split_text(content)
all_texts.extend(texts)
return all_texts
def get_relevant_context(self, query):
try:
docs = self.vector_store.similarity_search(query, k=3)
return "\n".join(doc.page_content for doc in docs)
except Exception as e:
logger.error(f"Error retrieving context: {str(e)}")
return ""
@torch.inference_mode()
def generate_response(self, message: str, history: list) -> str:
"""Generate response using both fine-tuned model and RAG"""
try:
# Rate limiting and memory management
current_time = time.time()
if current_time - self.last_interaction_time < self.interaction_cooldown:
time.sleep(self.interaction_cooldown)
torch.cuda.empty_cache()
# Get enhanced context from RAG
context = self._get_enhanced_context(message)
# Format conversation history
# Format recent conversation history (the Gradio chatbot passes messages-style dicts with 'role'/'content')
conv_history = "\n".join(
f"{turn.get('role', 'user').capitalize()}: {turn.get('content', '')}"
for turn in history[-6:]  # keep the last 3 exchanges
)
# Create enhanced prompt with RAG context
prompt = f"""<start_of_turn>system
Using these medical guidelines:
{context}
Previous conversation:
{conv_history}
Guidelines:
1. Assess symptoms and severity based on both your training and the provided guidelines
2. Ask relevant follow-up questions if needed
3. Direct to appropriate care (999, 111, or GP) according to symptom severity
4. Show empathy and cultural sensitivity
5. Never diagnose or recommend treatments
<end_of_turn>
<start_of_turn>user
{message}
<end_of_turn>
<start_of_turn>assistant"""
# Generate response with model
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=512
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
min_new_tokens=20,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2,
no_repeat_ngram_size=3
)
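# Sampling settings: temperature=0.7 with top_p=0.9 keeps answers varied without drifting, while repetition_penalty and no_repeat_ngram_size suppress repeated phrases in longer replies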
# Process response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("<start_of_turn>assistant")[-1].strip()
if "<end_of_turn>" in response:
response = response.split("<end_of_turn>")[0].strip()
self.last_interaction_time = time.time()
return response
except Exception as e:
logger.error(f"Error generating response: {e}")
return "I apologize, but I encountered an error. Please try again."
def handle_feedback(self, message: str, response: str, feedback: int):
"""Handle user feedback for responses"""
try:
timestamp = datetime.now().isoformat()
feedback_data = {
"message": message,
"response": response,
"feedback": feedback,
"timestamp": timestamp
}
# Log feedback
logger.info(f"Feedback received: {feedback_data}")
# Here you could:
# 1. Store feedback in a database
# 2. Send to monitoring system
# 3. Use for model improvements
return True
except Exception as e:
logger.error(f"Error handling feedback: {e}")
return False
def __del__(self):
"""Cleanup resources"""
try:
if hasattr(self, 'model'):
del self.model
torch.cuda.empty_cache()
except Exception as e:
logger.error(f"Error in cleanup: {e}")
def create_demo():
try:
# Initialize the chatbot backend (PearlyBot is assumed to be the inference wrapper for the fine-tuned model, defined elsewhere in this Space)
bot = PearlyBot()
def chat(message: str, history: list):
"""Handle chat interactions"""
try:
if not message.strip():
return history
response = bot.generate_response(message, history)
history.append({
"role": "user",
"content": message
})
history.append({
"role": "assistant",
"content": response
})
return history
except Exception as e:
logger.error(f"Chat error: {e}")
return history + [{
"role": "assistant",
"content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
}]
def process_feedback(positive: bool, comment: str, history: list):
try:
if not history or len(history) < 2:
return gr.update(value="")
last_user_msg = history[-2]["content"] if isinstance(history[-2], dict) else history[-2][0]
last_bot_msg = history[-1]["content"] if isinstance(history[-1], dict) else history[-1][1]
bot.handle_feedback(
message=last_user_msg,
response=last_bot_msg,
feedback=1 if positive else -1
)
return gr.update(value="")
except Exception as e:
logger.error(f"Error processing feedback: {e}")
return gr.update(value="")
# Create enhanced Gradio interface
with gr.Blocks(theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="indigo",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter")
)) as demo:
# Custom CSS for enhanced styling
gr.HTML("""
<style>
.container { max-width: 900px; margin: auto; }
.header { text-align: center; padding: 20px; }
.emergency-banner {
background-color: #ff4444;
color: white;
padding: 10px;
text-align: center;
font-weight: bold;
margin-bottom: 20px;
}
.feature-card {
padding: 15px;
border-radius: 10px;
text-align: center;
transition: transform 0.2s;
color: white;
font-weight: bold;
}
.feature-card:nth-child(1) { background: linear-gradient(135deg, #2193b0, #6dd5ed); }
.feature-card:nth-child(2) { background: linear-gradient(135deg, #834d9b, #d04ed6); }
.feature-card:nth-child(3) { background: linear-gradient(135deg, #ff4b1f, #ff9068); }
.feature-card:nth-child(4) { background: linear-gradient(135deg, #38ef7d, #11998e); }
.feature-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.feature-card span.emoji {
font-size: 2em;
display: block;
margin-bottom: 10px;
}
.message-textbox textarea { resize: none; }
#thumb-up, #thumb-down {
min-width: 60px;
padding: 8px;
margin: 5px;
}
.chatbot-message {
padding: 12px;
margin: 8px 0;
border-radius: 8px;
}
.user-message { background-color: #e3f2fd; }
.assistant-message { background-color: #f5f5f5; }
.feedback-section {
margin-top: 20px;
padding: 15px;
border-radius: 8px;
background-color: #f8f9fa;
}
</style>
""")
# Emergency Banner
gr.HTML("""
<div class="emergency-banner">
🚨 For medical emergencies, always call 999 immediately 🚨
</div>
""")
# Header Section
with gr.Row(elem_classes="header"):
gr.Markdown("""
# GP Medical Triage Assistant - Pearly
Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
""")
# Main Features Grid
gr.HTML("""
<div class="features-grid">
<div class="feature-card">
<span class="emoji">πŸ₯</span>
<div>GP Appointments</div>
</div>
<div class="feature-card">
<span class="emoji">πŸ”</span>
<div>Symptom Assessment</div>
</div>
<div class="feature-card">
<span class="emoji">⚑</span>
<div>Urgent Care Guide</div>
</div>
<div class="feature-card">
<span class="emoji">πŸ’Š</span>
<div>Medical Advice</div>
</div>
</div>
""")
# Chat Interface
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(
value=[{
"role": "assistant",
"content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
}],
height=500,
elem_id="chatbot",
type="messages",
show_label=False
)
with gr.Row():
msg = gr.Textbox(
label="Your message",
placeholder="Type your message here...",
lines=2,
scale=4,
autofocus=True
)
submit = gr.Button("Send", variant="primary", scale=1)
with gr.Column(scale=1):
# Quick Actions Panel
gr.Markdown("### Quick Actions")
emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
nhs_111_btn = gr.Button("πŸ“ž NHS 111 Info", variant="secondary")
booking_btn = gr.Button("πŸ“… GP Booking", variant="secondary")
# Controls and Feedback
gr.Markdown("### Controls")
clear = gr.Button("πŸ—‘οΈ Clear Chat")
gr.Markdown("### Feedback")
with gr.Row():
feedback_positive = gr.Button("πŸ‘", elem_id="thumb-up")
feedback_negative = gr.Button("πŸ‘Ž", elem_id="thumb-down")
feedback_text = gr.Textbox(
label="Additional comments",
placeholder="Tell us more...",
lines=2,
visible=True
)
feedback_submit = gr.Button("Submit Feedback", visible=True)
# Examples and Information
with gr.Accordion("Example Messages", open=False):
gr.Examples([
["I've been having severe headaches for the past week"],
["I need to book a routine checkup"],
["I'm feeling very anxious lately and need help"],
["My child has had a fever for 2 days"],
["I need information about COVID-19 testing"]
], inputs=msg)
with gr.Accordion("NHS Services Guide", open=False):
gr.Markdown("""
### Emergency Services (999)
- Life-threatening emergencies
- Severe injuries
- Suspected heart attack or stroke
### NHS 111
- Urgent but non-emergency situations
- Medical advice needed
- Unsure where to go
### GP Services
- Routine check-ups
- Non-urgent medical issues
- Prescription renewals
""")
def show_emergency_info():
return """🚨 Emergency Services (999)
- For life-threatening emergencies
- Severe chest pain
- Difficulty breathing
- Severe bleeding
- Loss of consciousness
"""
def show_nhs_111_info():
return """πŸ“ž NHS 111 Service
- Available 24/7
- Medical advice
- Local service information
- Urgent care guidance
"""
def show_booking_info():
return """πŸ“… GP Booking Options
- Online booking
- Phone booking
- Routine appointments
- Urgent appointments
"""
# Chat handlers
msg.submit(chat, [msg, chatbot], [chatbot]).then(
lambda: gr.update(value=""), None, [msg]
)
submit.click(chat, [msg, chatbot], [chatbot]).then(
lambda: gr.update(value=""), None, [msg]
)
# Quick action handlers
emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
booking_btn.click(lambda: show_booking_info(), outputs=[msg])
# Feedback handlers
# Reading feedback_text.value inside a lambda would only return its initial value, so the textbox is passed as an event input instead
feedback_positive.click(
lambda history, comment: process_feedback(True, comment, history),
inputs=[chatbot, feedback_text],
outputs=[feedback_text]
)
feedback_negative.click(
lambda history, comment: process_feedback(False, comment, history),
inputs=[chatbot, feedback_text],
outputs=[feedback_text]
)
# Clear chat
clear.click(lambda: None, None, chatbot)
# Finally, add the request queue (with a messages-type chatbot this targets Gradio 4.x, where queue() no longer accepts concurrency_count)
demo.queue(max_size=10)
return demo
except Exception as e:
logger.error(f"Error creating demo: {e}")
raise
if __name__ == "__main__":
# Initialize logging and load env vars
logging.basicConfig(level=logging.INFO)
load_dotenv()
# Create and launch demo
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860)
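# Assumed runtime dependencies (normally pinned in requirements.txt, not listed here): torch, transformers, peft,
# bitsandbytes, datasets, sentence-transformers, faiss-cpu (or faiss-gpu), langchain + langchain-community,
# gradio, python-dotenv, huggingface_hub, and tqdm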