# app.py
import os
import logging
from datetime import datetime
from typing import List, Dict

import torch
import gradio as gr
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm.auto import tqdm
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Ensure Hugging Face login
try:
hf_token = os.getenv("HF_TOKEN")
if hf_token:
login(token=hf_token)
print("Login successful!")
except Exception as e:
print("Hugging Face Login failed:", e)
# CUDA and Memory Configurations
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
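# Dataset preparation: the helpers below flatten MedQA, synthetic MedDiagnosis,
# and Persona-Chat samples into plain {"input", "output"} pairs that are
# concatenated into a single instruction-tuning dataset.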
def prepare_initial_datasets(batch_size=8):
print("Loading datasets with memory-optimized batch processing...")
def process_medqa_batch(examples):
results = []
inputs = examples['input']
instructions = examples['instruction']
outputs = examples['output']
for inp, inst, out in zip(inputs, instructions, outputs):
results.append({
"input": f"{inp} {inst}",
"output": out
})
return results
def process_meddia_batch(examples):
results = []
inputs = examples['input']
outputs = examples['output']
for inp, out in zip(inputs, outputs):
results.append({
"input": inp,
"output": out
})
return results
def process_persona_batch(examples):
results = []
personalities = examples['personality']
utterances = examples['utterances']
for pers, utts in zip(personalities, utterances):
try:
# Process personality list
personality = ' '.join([
p for p in pers
if isinstance(p, str)
])
# Process utterances
if utts and len(utts) > 0:
utterance = utts[0]
history = []
# Process history
if 'history' in utterance and utterance['history']:
history = [
h for h in utterance['history']
if isinstance(h, str)
]
history_text = ' '.join(history)
# Get candidate response
candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
if personality or history_text:
results.append({
"input": f"{personality} {history_text}".strip(),
"output": candidate
})
except Exception as e:
print(f"Error processing persona batch item: {e}")
continue
return results
# Load and process each dataset separately
print("Processing MedQA dataset...")
medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
medqa_processed = []
for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
batch = medqa[i:i + batch_size]
medqa_processed.extend(process_medqa_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
print("Processing MedDiagnosis dataset...")
meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
meddia_processed = []
for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
batch = meddia[i:i + batch_size]
meddia_processed.extend(process_meddia_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
print("Processing Persona-Chat dataset...")
persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
persona_processed = []
for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
batch = persona[i:i + batch_size]
persona_processed.extend(process_persona_batch(batch))
if i % (batch_size * 5) == 0:
torch.cuda.empty_cache()
torch.cuda.empty_cache()
print("Creating final dataset...")
all_processed = persona_processed + medqa_processed + meddia_processed
valid_data = {
"input": [],
"output": []
}
for item in all_processed:
if item["input"].strip() and item["output"].strip():
valid_data["input"].append(item["input"])
valid_data["output"].append(item["output"])
final_dataset = Dataset.from_dict(valid_data)
print(f"Final dataset size: {len(final_dataset)}")
return final_dataset
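# Tokenization: each pair is wrapped in Gemma-style <start_of_turn>/<end_of_turn>
# markers; labels are a copy of input_ids so the Trainer computes a standard
# causal-LM loss over the full sequence.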
def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
def tokenize_batch(examples):
formatted_texts = []
for i in range(0, len(examples['input']), batch_size):
sub_batch_inputs = examples['input'][i:i + batch_size]
sub_batch_outputs = examples['output'][i:i + batch_size]
for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
try:
formatted_text = f"""<start_of_turn>user
{input_text}
<end_of_turn>
<start_of_turn>assistant
{output_text}
<end_of_turn>"""
formatted_texts.append(formatted_text)
except Exception as e:
print(f"Error formatting text: {e}")
continue
tokenized = tokenizer(
formatted_texts,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors=None
)
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
print(f"Tokenizing dataset in small batches (size={batch_size})...")
tokenized_dataset = dataset.map(
tokenize_batch,
batched=True,
batch_size=batch_size,
remove_columns=dataset.column_names,
desc="Tokenizing dataset",
load_from_cache_file=False
)
return tokenized_dataset
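# Model setup: Gemma-2B is loaded in 8-bit via bitsandbytes, then wrapped with a
# small LoRA adapter (rank 4 on the attention q/v projections) so fine-tuning
# fits in limited GPU memory.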
def setup_model_and_tokenizer(model_name="google/gemma-2b"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_compute_dtype=torch.float16,
llm_int8_enable_fp32_cpu_offload=True
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r=4,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model, tokenizer
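# Training configuration: batch size 1 with 16-step gradient accumulation, fp16,
# gradient checkpointing, and an 8-bit optimizer keep peak memory low; max_steps
# caps the run for quick iteration.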
def setup_training_arguments(output_dir="./pearly_fine_tuned"):
return TrainingArguments(
output_dir=output_dir,
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
warmup_steps=50,
logging_steps=10,
save_steps=200,
learning_rate=2e-4,
fp16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="adamw_8bit",
max_grad_norm=0.3,
weight_decay=0.001,
logging_dir="./logs",
save_total_limit=2,
remove_unused_columns=False,
dataloader_pin_memory=False,
max_steps=500,
report_to=["none"],
)
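# Fine-tuning entry point (not run when the Space serves the Gradio app; call
# main() manually to retrain and write ./pearly_fine_tuned).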
def main():
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
print("Preparing initial datasets...")
combined_dataset = prepare_initial_datasets(batch_size=4)
print(f"\nDataset size: {len(combined_dataset)}")
print(f"Column names: {combined_dataset.column_names}")
if len(combined_dataset) > 0:
print("\nSample input-output pair:")
print(f"Input: {combined_dataset[0]['input'][:100]}...")
print(f"Output: {combined_dataset[0]['output'][:100]}...")
print("\nSetting up model and tokenizer...")
model, tokenizer = setup_model_and_tokenizer()
print("\nPreparing dataset for training...")
processed_dataset = prepare_dataset(
combined_dataset,
tokenizer,
max_length=256,
batch_size=2
)
torch.cuda.empty_cache()
training_args = setup_training_arguments()
trainer = Trainer(
model=model,
args=training_args,
train_dataset=processed_dataset,
tokenizer=tokenizer,
)
print("\nStarting training...")
try:
trainer.train()
except Exception as e:
print(f"Training error: {e}")
torch.cuda.empty_cache()
raise e
finally:
torch.cuda.empty_cache()
print("\nSaving model...")
trainer.save_model()
print("Training completed!")
DISCLAIMER = """
IMPORTANT MEDICAL DISCLAIMER:
Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
Pearly DOES NOT:
- Make medical diagnoses
- Prescribe medications
- Provide specific treatment recommendations
- Replace professional medical advice
Always consult qualified healthcare professionals for medical advice and treatment.
In case of emergency, call 999 immediately.
"""
class PearlyBot:
def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
print("Loading saved model...")
print(DISCLAIMER)
# Clean memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Load tokenizer and model directly from saved path
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto"
)
self.model.eval() # Set to evaluation mode
# Initialize RAG components
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.vector_store = None
        self.retriever = None
        self.conversation_history = []
def initialize_rag(self, documents_path="./knowledge_base"):
"""Initialize RAG system"""
print("Loading knowledge base...")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=300,
chunk_overlap=100,
separators=["\n\n", "\n", ".", "!", "?", ":"]
)
documents = []
for filename in os.listdir(documents_path):
if filename.endswith('.txt'):
loader = TextLoader(os.path.join(documents_path, filename))
documents.extend(loader.load())
texts = text_splitter.split_documents(documents)
self.vector_store = FAISS.from_documents(texts, self.embeddings)
self.retriever = self.vector_store.as_retriever(
search_type="similarity",
search_kwargs={"k": 5}
)
print("Knowledge base loaded successfully!")
def get_relevant_context(self, user_input):
if not self.retriever:
return ""
docs = self.retriever.get_relevant_documents(user_input)
return "\n\n".join([doc.page_content for doc in docs])
def generate_response(self, user_input):
context = self.get_relevant_context(user_input)
history = "\n".join([
f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
for turn in self.conversation_history[-3:]
])
prompt = f"""<start_of_turn>system
As Pearly, I use the following medical guidelines to help triage patients:
{context}
Previous Conversation:
{history}
Based on these guidelines, I will:
1. Assess symptoms and severity
2. Ask relevant follow-up questions
3. Direct to appropriate care (999, 111, or GP)
4. Show empathy and cultural sensitivity
5. Never diagnose or recommend treatments
<end_of_turn>
<start_of_turn>user
{user_input}
<end_of_turn>
<start_of_turn>assistant"""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=512
).to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
min_new_tokens=20,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2,
pad_token_id=self.tokenizer.pad_token_id
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("<start_of_turn>assistant")[-1].strip()
if "<end_of_turn>" in response:
response = response.split("<end_of_turn>")[0].strip()
self.conversation_history.append({
"user": user_input,
"assistant": response
})
return response
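# Gradio UI: chat interface with quick-action buttons, feedback capture, and
# NHS service guidance, returned as a Blocks demo for launch().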
def create_demo():
"""Set up Gradio interface for the chatbot with enhanced styling and functionality."""
try:
        # Simple health-check helper (Gradio Blocks exposes no route decorator;
        # mount the demo on a FastAPI app if an HTTP /health endpoint is needed)
        def health_check():
            return {"status": "healthy"}
        bot = PearlyBot()
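        # Optional RAG bootstrap: a minimal sketch assuming a ./knowledge_base
        # folder of .txt guideline files; without it the bot answers with no
        # retrieved context.
        if os.path.isdir("./knowledge_base"):
            bot.initialize_rag("./knowledge_base")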
def chat(message: str, history: List[Dict[str, str]]):
try:
if not message.strip():
return history
bot_response = bot.generate_response(message)
# Add user message
history.append({
"role": "user",
"content": message
})
# Add bot response
history.append({
"role": "assistant",
"content": bot_response['response']
})
return history
except Exception as e:
logger.error(f"Chat error: {e}")
history.append({
"role": "user",
"content": message
})
history.append({
"role": "assistant",
"content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
})
return history
def process_feedback(is_positive: bool, comment: str, history: List[Dict[str, str]]):
try:
if not history:
return
last_interaction = history[-2:] # Get last user message and bot response
if len(last_interaction) == 2:
user_msg = last_interaction[0]["content"]
bot_msg = last_interaction[1]["content"]
feedback_data = {
"user_message": user_msg,
"bot_response": bot_msg,
"feedback": 1 if is_positive else -1,
"comment": comment,
"timestamp": datetime.now().isoformat()
}
                    # PearlyBot has no online feedback-learning hook; log the
                    # structured feedback and forward it only if a handler exists
                    logger.info(f"Feedback received: {feedback_data}")
                    if hasattr(bot, "handle_feedback"):
                        bot.handle_feedback(
                            message=user_msg,
                            response=bot_msg,
                            feedback=1 if is_positive else -1,
                        )
# Clear feedback inputs
return gr.update(value="")
except Exception as e:
logger.error(f"Error processing feedback: {e}")
# Create enhanced Gradio interface
with gr.Blocks(theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="indigo",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter")
)) as demo:
# Custom CSS for enhanced styling
gr.HTML("""
<style>
.container { max-width: 900px; margin: auto; }
.header { text-align: center; padding: 20px; }
.emergency-banner {
background-color: #ff4444;
color: white;
padding: 10px;
text-align: center;
font-weight: bold;
margin-bottom: 20px;
}
.feature-card {
padding: 15px;
border-radius: 10px;
text-align: center;
transition: transform 0.2s;
color: white;
font-weight: bold;
}
/* Specific backgrounds for each emoji card */
.feature-card:nth-child(1) {
background: linear-gradient(135deg, #2193b0, #6dd5ed); /* Hospital emoji */
}
.feature-card:nth-child(2) {
background: linear-gradient(135deg, #834d9b, #d04ed6); /* Magnifying glass emoji */
}
.feature-card:nth-child(3) {
background: linear-gradient(135deg, #ff4b1f, #ff9068); /* Emergency emoji */
}
.feature-card:nth-child(4) {
background: linear-gradient(135deg, #38ef7d, #11998e); /* Medical advice emoji */
}
/* Hover effect */
.feature-card:hover {
transform: translateY(-5px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
/* Make emojis larger and more visible */
.feature-card span.emoji {
font-size: 2em;
display: block;
margin-bottom: 10px;
}
</style>
""")
# Emergency Banner
gr.HTML("""
<div class="emergency-banner">
🚨 For medical emergencies, always call 999 immediately 🚨
</div>
""")
# Header Section
with gr.Row(elem_classes="header"):
gr.Markdown("""
# GP Medical Triage Assistant - Pearly
Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
""")
# Main Features Grid
gr.HTML("""
<div class="features-grid">
<div class="feature-card">
<span class="emoji">πŸ₯</span>
<div>GP Appointments</div>
</div>
<div class="feature-card">
<span class="emoji">πŸ”</span>
<div>Symptom Assessment</div>
</div>
<div class="feature-card">
<span class="emoji">⚑</span>
<div>Urgent Care Guide</div>
</div>
<div class="feature-card">
<span class="emoji">πŸ’Š</span>
<div>Medical Advice</div>
</div>
</div>
<style>
/* Enable Enter key submission */
.message-textbox textarea {
resize: none;
}
/* Improve feedback buttons */
#thumb-up, #thumb-down {
min-width: 60px;
padding: 8px;
margin: 5px;
}
/* Improve chat messages */
.chatbot-message {
padding: 12px;
margin: 8px 0;
border-radius: 8px;
}
.user-message {
background-color: #e3f2fd;
}
.assistant-message {
background-color: #f5f5f5;
}
/* Feedback section */
.feedback-section {
margin-top: 20px;
padding: 15px;
border-radius: 8px;
background-color: #f8f9fa;
}
</style>
""")
# Chat Interface
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(
value=[{
"role": "assistant",
"content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
}],
height=500,
elem_id="chatbot",
type="messages",
show_label=False
)
with gr.Row():
msg = gr.Textbox(
label="Your message",
placeholder="Type your message here...",
lines=2,
scale=4,
                        autofocus=True
)
submit = gr.Button("Send", variant="primary", scale=1)
with gr.Column(scale=1):
# Quick Actions Panel
gr.Markdown("### Quick Actions")
emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
nhs_111_btn = gr.Button("πŸ“ž NHS 111 Info", variant="secondary")
booking_btn = gr.Button("πŸ“… GP Booking", variant="secondary")
# Conversation Controls
gr.Markdown("### Controls")
clear = gr.Button("πŸ—‘οΈ Clear Chat")
# Feedback Section
gr.Markdown("### Feedback")
with gr.Row():
feedback_positive = gr.Button("πŸ‘", elem_id="thumb-up")
feedback_negative = gr.Button("πŸ‘Ž", elem_id="thumb-down")
feedback_text = gr.Textbox(
label="Additional comments (optional)",
placeholder="Tell us more about your experience...",
lines=2,
visible=True
)
feedback_submit = gr.Button("Submit Feedback", visible=True)
# Examples Section
with gr.Accordion("Example Messages", open=False):
gr.Examples(
examples=[
["I've been having severe headaches for the past week"],
["I need to book a routine checkup"],
["I'm feeling very anxious lately and need help"],
["My child has had a fever for 2 days"],
["I need information about COVID-19 testing"]
],
inputs=msg
)
# Information Accordions
with gr.Accordion("NHS Services Guide", open=False):
gr.Markdown("""
### Emergency Services (999)
- Life-threatening emergencies
- Severe injuries
- Suspected heart attack or stroke
### NHS 111
- Urgent but non-emergency situations
- Medical advice needed
- Unsure where to go
### GP Services
- Routine check-ups
- Non-urgent medical issues
- Prescription renewals
""")
# Event Handlers
# Message submission handlers
msg.submit(
chat,
inputs=[msg, chatbot],
outputs=[chatbot]
).then(
lambda: gr.update(value=""),
None,
[msg]
)
submit.click(
chat,
inputs=[msg, chatbot],
outputs=[chatbot]
).then(
lambda: gr.update(value=""),
None,
[msg]
)
# Feedback handlers
            feedback_positive.click(
                lambda comment, history: process_feedback(True, comment, history),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )
            feedback_negative.click(
                lambda comment, history: process_feedback(False, comment, history),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )
feedback_submit.click(
lambda: gr.update(value=""),
outputs=[feedback_text]
)
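            # Clear-chat handler: a minimal sketch (assumes messages-format
            # history) that empties the chat window and the bot's stored turns
            def clear_chat():
                bot.conversation_history = []
                return []
            clear.click(clear_chat, outputs=[chatbot])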
# Quick Action Button Handlers
def show_emergency_info():
return """🚨 Emergency Services (999)
- For life-threatening emergencies
- Severe chest pain
- Difficulty breathing
- Severe bleeding
- Loss of consciousness
"""
def show_nhs_111_info():
return """πŸ“ž NHS 111 Service
- Available 24/7
- Medical advice
- Local service information
- Urgent care guidance
"""
def show_booking_info():
return """πŸ“… GP Booking Options
- Online booking
- Phone booking
- Routine appointments
- Urgent appointments
"""
emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
booking_btn.click(lambda: show_booking_info(), outputs=[msg])
return demo
except Exception as e:
logger.error(f"Error creating demo: {e}")
raise
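# Entry point: load environment variables (e.g. HF_TOKEN), build the Gradio demo,
# and launch it; share=True requests a public tunnel, which matters mainly for
# local runs.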
if __name__ == "__main__":
    load_dotenv()  # Load environment variables
    demo = create_demo()  # Build the Gradio app
    demo.launch(share=True)