# app.py
import os
import json
import gc
import time
import asyncio
import logging
import traceback
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict, Union, Tuple, Optional, Any

import numpy as np
import pandas as pd
import requests
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

import keras
import tensorflow as tf
import faiss
import wandb
import pytest
from unittest.mock import Mock, patch
from sklearn.metrics import classification_report, confusion_matrix
import IPython.display as display

import gradio as gr
from dotenv import load_dotenv
from datasets import Dataset, load_dataset, concatenate_datasets
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from sentence_transformers import SentenceTransformer
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
    PeftModel,
)
from tqdm.auto import tqdm
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Ensure Hugging Face login
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Login successful!")
except Exception as e:
    print("Hugging Face Login failed:", e)

# CUDA and memory configuration
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def prepare_initial_datasets(batch_size=8):
    print("Loading datasets with memory-optimized batch processing...")

    def process_medqa_batch(examples):
        results = []
        inputs = examples['input']
        instructions = examples['instruction']
        outputs = examples['output']
        for inp, inst, out in zip(inputs, instructions, outputs):
            results.append({
                "input": f"{inp} {inst}",
                "output": out
            })
        return results

    def process_meddia_batch(examples):
        results = []
        inputs = examples['input']
        outputs = examples['output']
        for inp, out in zip(inputs, outputs):
            results.append({
                "input": inp,
                "output": out
            })
        return results

    def process_persona_batch(examples):
        results = []
        personalities = examples['personality']
        utterances = examples['utterances']
        for pers, utts in zip(personalities, utterances):
            try:
                # Process personality list
                personality = ' '.join([p for p in pers if isinstance(p, str)])
                # Process utterances
                if utts and len(utts) > 0:
                    utterance = utts[0]
                    history = []
                    history_text = ""
                    # Process history
                    if 'history' in utterance and utterance['history']:
                        history = [h for h in utterance['history'] if isinstance(h, str)]
                        history_text = ' '.join(history)
                    # Get candidate response
                    candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
                    if personality or history_text:
                        results.append({
                            "input": f"{personality} {history_text}".strip(),
                            "output": candidate
                        })
            except Exception as e:
                print(f"Error processing persona batch item: {e}")
                continue
        return results

    # Load and process each dataset separately
    print("Processing MedQA dataset...")
    medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
    medqa_processed = []
    for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
        batch = medqa[i:i + batch_size]
        medqa_processed.extend(process_medqa_batch(batch))
        if i % (batch_size * 5) == 0:
            torch.cuda.empty_cache()

    print("Processing MedDiagnosis dataset...")
    meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
    meddia_processed = []
    for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
        batch = meddia[i:i + batch_size]
        meddia_processed.extend(process_meddia_batch(batch))
        if i % (batch_size * 5) == 0:
            torch.cuda.empty_cache()

    print("Processing Persona-Chat dataset...")
    persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
    persona_processed = []
    for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
        batch = persona[i:i + batch_size]
        persona_processed.extend(process_persona_batch(batch))
        if i % (batch_size * 5) == 0:
            torch.cuda.empty_cache()

    torch.cuda.empty_cache()

    print("Creating final dataset...")
    all_processed = persona_processed + medqa_processed + meddia_processed
    valid_data = {"input": [], "output": []}
    for item in all_processed:
        if item["input"].strip() and item["output"].strip():
            valid_data["input"].append(item["input"])
            valid_data["output"].append(item["output"])

    final_dataset = Dataset.from_dict(valid_data)
    print(f"Final dataset size: {len(final_dataset)}")
    return final_dataset
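# Sketch of the record shape produced by prepare_initial_datasets (illustrative
# values only; the exact text depends on the loaded splits):
#   dataset = prepare_initial_datasets(batch_size=4)
#   dataset[0]  # -> {"input": "<question + instruction>", "output": "<reference answer>"}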
def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
    def tokenize_batch(examples):
        formatted_texts = []
        for i in range(0, len(examples['input']), batch_size):
            sub_batch_inputs = examples['input'][i:i + batch_size]
            sub_batch_outputs = examples['output'][i:i + batch_size]
            for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
                try:
                    formatted_text = f"""<start_of_turn>user
{input_text}
<end_of_turn>
<start_of_turn>assistant
{output_text}
<end_of_turn>"""
                    formatted_texts.append(formatted_text)
                except Exception as e:
                    print(f"Error formatting text: {e}")
                    continue

        tokenized = tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors=None
        )
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    print(f"Tokenizing dataset in small batches (size={batch_size})...")
    tokenized_dataset = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=batch_size,
        remove_columns=dataset.column_names,
        desc="Tokenizing dataset",
        load_from_cache_file=False
    )
    return tokenized_dataset
def setup_model_and_tokenizer(model_name="google/gemma-2b"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, tokenizer
def setup_training_arguments(output_dir="./pearly_fine_tuned"):
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        warmup_steps=50,
        logging_steps=10,
        save_steps=200,
        learning_rate=2e-4,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_8bit",
        max_grad_norm=0.3,
        weight_decay=0.001,
        logging_dir="./logs",
        save_total_limit=2,
        remove_unused_columns=False,
        dataloader_pin_memory=False,
        max_steps=500,
        report_to=["none"],
    )
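# Note (worked out from the values above): per_device_train_batch_size=1 with
# gradient_accumulation_steps=16 gives an effective batch of 16 examples per
# optimizer step, and max_steps=500 overrides num_train_epochs, so training
# stops after 500 optimizer steps regardless of dataset size.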
def main():
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    print("Preparing initial datasets...")
    combined_dataset = prepare_initial_datasets(batch_size=4)
    print(f"\nDataset size: {len(combined_dataset)}")
    print(f"Column names: {combined_dataset.column_names}")
    if len(combined_dataset) > 0:
        print("\nSample input-output pair:")
        print(f"Input: {combined_dataset[0]['input'][:100]}...")
        print(f"Output: {combined_dataset[0]['output'][:100]}...")

    print("\nSetting up model and tokenizer...")
    model, tokenizer = setup_model_and_tokenizer()

    print("\nPreparing dataset for training...")
    processed_dataset = prepare_dataset(
        combined_dataset,
        tokenizer,
        max_length=256,
        batch_size=2
    )
    torch.cuda.empty_cache()

    training_args = setup_training_arguments()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    try:
        trainer.train()
    except Exception as e:
        print(f"Training error: {e}")
        torch.cuda.empty_cache()
        raise e
    finally:
        torch.cuda.empty_cache()

    print("\nSaving model...")
    trainer.save_model()
    print("Training completed!")
DISCLAIMER = """
IMPORTANT MEDICAL DISCLAIMER:
Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
Pearly DOES NOT:
- Make medical diagnoses
- Prescribe medications
- Provide specific treatment recommendations
- Replace professional medical advice
Always consult qualified healthcare professionals for medical advice and treatment.
In case of emergency, call 999 immediately.
"""
class PearlyBot:
    def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
        print("Loading saved model...")
        print(DISCLAIMER)

        # Clean memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Load tokenizer and model directly from the saved path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"
        )
        self.model.eval()  # Set to evaluation mode

        # Initialize RAG components
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.vector_store = None
        self.retriever = None
        self.conversation_history = []

    def initialize_rag(self, documents_path="./knowledge_base"):
        """Initialize the RAG system from plain-text documents."""
        print("Loading knowledge base...")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=100,
            separators=["\n\n", "\n", ".", "!", "?", ":"]
        )
        documents = []
        for filename in os.listdir(documents_path):
            if filename.endswith('.txt'):
                loader = TextLoader(os.path.join(documents_path, filename))
                documents.extend(loader.load())
        texts = text_splitter.split_documents(documents)
        self.vector_store = FAISS.from_documents(texts, self.embeddings)
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5}
        )
        print("Knowledge base loaded successfully!")

    def get_relevant_context(self, user_input):
        if not self.retriever:
            return ""
        docs = self.retriever.get_relevant_documents(user_input)
        return "\n\n".join([doc.page_content for doc in docs])

    def generate_response(self, user_input):
        context = self.get_relevant_context(user_input)
        history = "\n".join([
            f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
            for turn in self.conversation_history[-3:]
        ])

        prompt = f"""<start_of_turn>system
As Pearly, I use the following medical guidelines to help triage patients:
{context}
Previous Conversation:
{history}
Based on these guidelines, I will:
1. Assess symptoms and severity
2. Ask relevant follow-up questions
3. Direct to appropriate care (999, 111, or GP)
4. Show empathy and cultural sensitivity
5. Never diagnose or recommend treatments
<end_of_turn>
<start_of_turn>user
{user_input}
<end_of_turn>
<start_of_turn>assistant"""

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                min_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("<start_of_turn>assistant")[-1].strip()
        if "<end_of_turn>" in response:
            response = response.split("<end_of_turn>")[0].strip()

        self.conversation_history.append({
            "user": user_input,
            "assistant": response
        })
        return response
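# Minimal usage sketch for PearlyBot (assumes a fine-tuned checkpoint exists at
# ./pearly_fine_tuned and a ./knowledge_base directory containing .txt guideline
# files; without initialize_rag() the bot still answers, just without retrieval):
#   bot = PearlyBot()
#   bot.initialize_rag("./knowledge_base")
#   reply = bot.generate_response("I've had a sore throat and mild fever for two days.")
#   print(reply)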
def create_demo():
    """Set up the Gradio interface for the chatbot with enhanced styling and functionality."""
    try:
        # Health check
        def health_check():
            return {"status": "healthy"}

        bot = PearlyBot()

        def chat(message: str, history: List[Dict[str, str]]):
            try:
                if not message.strip():
                    return history
                bot_response = bot.generate_response(message)
                # Add user message
                history.append({
                    "role": "user",
                    "content": message
                })
                # Add bot response
                history.append({
                    "role": "assistant",
                    "content": bot_response
                })
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                history.append({
                    "role": "user",
                    "content": message
                })
                history.append({
                    "role": "assistant",
                    "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
                })
                return history

        def process_feedback(is_positive: bool, comment: str, history: List[Dict[str, str]]):
            try:
                if not history:
                    return gr.update(value="")
                last_interaction = history[-2:]  # Get last user message and bot response
                if len(last_interaction) == 2:
                    user_msg = last_interaction[0]["content"]
                    bot_msg = last_interaction[1]["content"]
                    feedback_data = {
                        "user_message": user_msg,
                        "bot_response": bot_msg,
                        "feedback": 1 if is_positive else -1,
                        "comment": comment,
                        "timestamp": datetime.now().isoformat()
                    }
                    logger.info(f"Feedback received: {feedback_data}")
                    if hasattr(bot, "handle_feedback"):
                        bot.handle_feedback(
                            message=user_msg,
                            response=bot_msg,
                            feedback=1 if is_positive else -1
                        )
                # Clear feedback inputs
                return gr.update(value="")
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")
                return gr.update(value="")

        # Create enhanced Gradio interface
        with gr.Blocks(theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="indigo",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter")
        )) as demo:
            # Custom CSS for enhanced styling
            gr.HTML("""
                <style>
                    .container { max-width: 900px; margin: auto; }
                    .header { text-align: center; padding: 20px; }
                    .emergency-banner {
                        background-color: #ff4444;
                        color: white;
                        padding: 10px;
                        text-align: center;
                        font-weight: bold;
                        margin-bottom: 20px;
                    }
                    .feature-card {
                        padding: 15px;
                        border-radius: 10px;
                        text-align: center;
                        transition: transform 0.2s;
                        color: white;
                        font-weight: bold;
                    }
                    /* Specific backgrounds for each emoji card */
                    .feature-card:nth-child(1) {
                        background: linear-gradient(135deg, #2193b0, #6dd5ed); /* Hospital emoji */
                    }
                    .feature-card:nth-child(2) {
                        background: linear-gradient(135deg, #834d9b, #d04ed6); /* Magnifying glass emoji */
                    }
                    .feature-card:nth-child(3) {
                        background: linear-gradient(135deg, #ff4b1f, #ff9068); /* Emergency emoji */
                    }
                    .feature-card:nth-child(4) {
                        background: linear-gradient(135deg, #38ef7d, #11998e); /* Medical advice emoji */
                    }
                    /* Hover effect */
                    .feature-card:hover {
                        transform: translateY(-5px);
                        box-shadow: 0 5px 15px rgba(0,0,0,0.2);
                    }
                    /* Make emojis larger and more visible */
                    .feature-card span.emoji {
                        font-size: 2em;
                        display: block;
                        margin-bottom: 10px;
                    }
                </style>
            """)

            # Emergency Banner
            gr.HTML("""
                <div class="emergency-banner">
                    🚨 For medical emergencies, always call 999 immediately 🚨
                </div>
            """)
            # Header Section
            with gr.Row(elem_classes="header"):
                gr.Markdown("""
                # GP Medical Triage Assistant - Pearly
                Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
                """)

            # Main Features Grid
            gr.HTML("""
                <div class="features-grid">
                    <div class="feature-card">
                        <span class="emoji">🏥</span>
                        <div>GP Appointments</div>
                    </div>
                    <div class="feature-card">
                        <span class="emoji">🔍</span>
                        <div>Symptom Assessment</div>
                    </div>
                    <div class="feature-card">
                        <span class="emoji">⚡</span>
                        <div>Urgent Care Guide</div>
                    </div>
                    <div class="feature-card">
                        <span class="emoji">💊</span>
                        <div>Medical Advice</div>
                    </div>
                </div>
                <style>
                    /* Message textbox styling */
                    .message-textbox textarea {
                        resize: none;
                    }
                    /* Improve feedback buttons */
                    #thumb-up, #thumb-down {
                        min-width: 60px;
                        padding: 8px;
                        margin: 5px;
                    }
                    /* Improve chat messages */
                    .chatbot-message {
                        padding: 12px;
                        margin: 8px 0;
                        border-radius: 8px;
                    }
                    .user-message {
                        background-color: #e3f2fd;
                    }
                    .assistant-message {
                        background-color: #f5f5f5;
                    }
                    /* Feedback section */
                    .feedback-section {
                        margin-top: 20px;
                        padding: 15px;
                        border-radius: 8px;
                        background-color: #f8f9fa;
                    }
                </style>
            """)
            # Chat Interface
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(
                        value=[{
                            "role": "assistant",
                            "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
                        }],
                        height=500,
                        elem_id="chatbot",
                        type="messages",
                        show_label=False
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            label="Your message",
                            placeholder="Type your message here...",
                            lines=2,
                            scale=4,
                            autofocus=True
                        )
                        submit = gr.Button("Send", variant="primary", scale=1)

                with gr.Column(scale=1):
                    # Quick Actions Panel
                    gr.Markdown("### Quick Actions")
                    emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
                    nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
                    booking_btn = gr.Button("📅 GP Booking", variant="secondary")

                    # Conversation Controls
                    gr.Markdown("### Controls")
                    clear = gr.Button("🗑️ Clear Chat")

                    # Feedback Section
                    gr.Markdown("### Feedback")
                    with gr.Row():
                        feedback_positive = gr.Button("👍", elem_id="thumb-up")
                        feedback_negative = gr.Button("👎", elem_id="thumb-down")
                    feedback_text = gr.Textbox(
                        label="Additional comments (optional)",
                        placeholder="Tell us more about your experience...",
                        lines=2,
                        visible=True
                    )
                    feedback_submit = gr.Button("Submit Feedback", visible=True)

            # Examples Section
            with gr.Accordion("Example Messages", open=False):
                gr.Examples(
                    examples=[
                        ["I've been having severe headaches for the past week"],
                        ["I need to book a routine checkup"],
                        ["I'm feeling very anxious lately and need help"],
                        ["My child has had a fever for 2 days"],
                        ["I need information about COVID-19 testing"]
                    ],
                    inputs=msg
                )
            # Information Accordions
            with gr.Accordion("NHS Services Guide", open=False):
                gr.Markdown("""
                ### Emergency Services (999)
                - Life-threatening emergencies
                - Severe injuries
                - Suspected heart attack or stroke

                ### NHS 111
                - Urgent but non-emergency situations
                - Medical advice needed
                - Unsure where to go

                ### GP Services
                - Routine check-ups
                - Non-urgent medical issues
                - Prescription renewals
                """)
            # Event Handlers
            # Message submission handlers
            msg.submit(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: gr.update(value=""),
                None,
                [msg]
            )
            submit.click(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: gr.update(value=""),
                None,
                [msg]
            )

            # Clear chat handler
            clear.click(lambda: [], outputs=[chatbot])

            # Feedback handlers
            feedback_positive.click(
                lambda comment, history: process_feedback(True, comment, history),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )
            feedback_negative.click(
                lambda comment, history: process_feedback(False, comment, history),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )
            feedback_submit.click(
                lambda: gr.update(value=""),
                outputs=[feedback_text]
            )
            # Quick Action Button Handlers
            def show_emergency_info():
                return """🚨 Emergency Services (999)
- For life-threatening emergencies
- Severe chest pain
- Difficulty breathing
- Severe bleeding
- Loss of consciousness
"""

            def show_nhs_111_info():
                return """📞 NHS 111 Service
- Available 24/7
- Medical advice
- Local service information
- Urgent care guidance
"""

            def show_booking_info():
                return """📅 GP Booking Options
- Online booking
- Phone booking
- Routine appointments
- Urgent appointments
"""

            emergency_btn.click(show_emergency_info, outputs=[msg])
            nhs_111_btn.click(show_nhs_111_info, outputs=[msg])
            booking_btn.click(show_booking_info, outputs=[msg])

        return demo

    except Exception as e:
        logger.error(f"Error creating demo: {e}")
        raise
if __name__ == "__main__":
    load_dotenv()  # Load environment variables
    demo = create_demo()
    demo.launch(share=True)  # Launch the Gradio app