# app.py
import os
import logging
import torch
from typing import Dict, List, Any
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import faiss
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from dataclasses import dataclass, field
from datetime import datetime
import json
from huggingface_hub import login
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Retrieve secrets from environment variables (the Kaggle and W&B keys are
# loaded here but not used elsewhere in this file)
kaggle_username = os.getenv("KAGGLE_USERNAME")
kaggle_key = os.getenv("KAGGLE_KEY")
hf_token = os.getenv("HF_TOKEN")
wandb_key = os.getenv("WANDB_API_KEY")
# Log in to Hugging Face
if hf_token:
login(token=hf_token)
else:
logger.warning("Hugging Face token not found in environment variables.")
@dataclass
class AdaptiveBotConfig:
"""Configuration for adaptive medical triage bot"""
MODEL_NAME: str = "google/gemma-7b"
EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
# LoRA parameters
LORA_R: int = 8
LORA_ALPHA: int = 16
LORA_DROPOUT: float = 0.1
LORA_TARGET_MODULES: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
# Training parameters
MAX_LENGTH: int = 512
BATCH_SIZE: int = 1
LEARNING_RATE: float = 1e-4
# Adaptive learning parameters
MIN_FEEDBACK_FOR_UPDATE: int = 5
FEEDBACK_HISTORY_SIZE: int = 100
LEARNING_RATE_DECAY: float = 0.95
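    # Note: with LoRA, the adapter's update is scaled by LORA_ALPHA / LORA_R
    # (16 / 8 = 2.0 here); r sets adapter capacity, alpha rescales its effect.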
class AdaptiveMedicalBot:
def __init__(self):
self.config = AdaptiveBotConfig()
self.setup_models()
self.load_datasets()
self.setup_adaptive_learning()
self.document_relevance = {}
def setup_adaptive_learning(self):
"""Initialize adaptive learning components"""
self.feedback_history = []
self.conversation_patterns = {}
self.learning_buffer = []
# Load existing learning data if available
try:
if os.path.exists('learning_data.json'):
with open('learning_data.json', 'r') as f:
data = json.load(f)
self.conversation_patterns = data.get('patterns', {})
self.feedback_history = data.get('feedback', [])
except Exception as e:
logger.warning(f"Could not load learning data: {e}")
def setup_models(self):
"""Initialize models with LoRA and quantization"""
try:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.MODEL_NAME,
trust_remote_code=True
)
base_model = AutoModelForCausalLM.from_pretrained(
self.config.MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
base_model = prepare_model_for_kbit_training(base_model)
lora_config = LoraConfig(
r=self.config.LORA_R,
lora_alpha=self.config.LORA_ALPHA,
target_modules=self.config.LORA_TARGET_MODULES,
lora_dropout=self.config.LORA_DROPOUT,
bias="none",
task_type=TaskType.CAUSAL_LM
)
self.model = get_peft_model(base_model, lora_config)
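            # Optional sanity check (print_trainable_parameters is a standard
            # PEFT helper): confirms only the small LoRA adapters are trainable
            self.model.print_trainable_parameters()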
self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
except Exception as e:
logger.error(f"Error setting up models: {e}")
raise
def load_datasets(self):
"""Load and process datasets for RAG"""
try:
datasets = {
"medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
"diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
"persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
}
self.documents = []
for dataset_name, dataset in datasets.items():
for item in dataset:
if dataset_name == "persona":
if isinstance(item.get('personality'), list):
self.documents.append({
'text': " ".join(item['personality']),
'type': 'persona'
})
else:
if 'input' in item and 'output' in item:
self.documents.append({
'text': f"{item['input']}\n{item['output']}",
'type': dataset_name
})
self._create_index()
except Exception as e:
logger.error(f"Error loading datasets: {e}")
raise
def _create_index(self):
"""Create FAISS index for RAG"""
sample_embedding = self.embedding_model.encode("sample text")
self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
batch_size = 32
for i in range(0, len(self.documents), batch_size):
batch = self.documents[i:i + batch_size]
texts = [doc['text'] for doc in batch]
embeddings = self.embedding_model.encode(texts)
self.index.add(np.array(embeddings))
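        # Note: IndexFlatIP ranks by raw inner product. For true cosine similarity,
        # embeddings would need L2 normalization (e.g. faiss.normalize_L2) both at
        # index time and at query time; behavior is left unchanged here.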
def analyze_conversation_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
"""Analyze conversation context to determine appropriate follow-up questions"""
try:
            # Extract key information (placeholder accumulators; keyword extraction
            # is not implemented yet, so these sets currently stay empty)
mentioned_symptoms = set()
time_indicators = set()
severity_indicators = set()
# Analyze current message and history
for msg in [message] + [h[0] for h in (history or [])]:
msg_lower = msg.lower()
# Update conversation patterns
pattern_key = self._extract_pattern_key(msg_lower)
if pattern_key in self.conversation_patterns:
self.conversation_patterns[pattern_key]['frequency'] += 1
else:
self.conversation_patterns[pattern_key] = {
'frequency': 1,
'successful_responses': []
}
return {
'needs_follow_up': True, # Always encourage follow-up questions
'conversation_depth': len(history) if history else 0,
'pattern_key': pattern_key
}
except Exception as e:
logger.error(f"Error analyzing conversation: {e}")
return {'needs_follow_up': True}
def _extract_pattern_key(self, message: str) -> str:
"""Extract conversation pattern key for learning"""
# Simplified pattern extraction - can be enhanced based on learning
words = message.lower().split()
return " ".join(sorted(set(words))[:5])
def generate_follow_up_questions(self, context: Dict[str, Any], message: str) -> List[str]:
"""Generate contextual follow-up questions"""
try:
# Use the model to generate follow-up questions
prompt = f"""Given the patient message: "{message}"
Generate relevant follow-up questions to better understand their situation.
Focus on: timing, severity, associated symptoms, impact on daily life.
Do not make diagnoses or suggest treatments.
Questions:"""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=self.config.MAX_LENGTH,
truncation=True
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=100,
temperature=0.7,
do_sample=True
)
            # Decode only the newly generated tokens so the prompt isn't echoed back
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            questions = self.tokenizer.decode(generated, skip_special_tokens=True)
            return [q.strip() for q in questions.split("\n") if q.strip()]
except Exception as e:
logger.error(f"Error generating follow-up questions: {e}")
return ["Could you tell me more about when this started?"]
def update_from_feedback(self, message: str, response: str, feedback: int):
"""Process feedback for adaptive learning"""
try:
self.feedback_history.append({
'message': message,
'response': response,
'feedback': feedback,
'timestamp': datetime.now().isoformat()
})
# Update conversation patterns
pattern_key = self._extract_pattern_key(message)
if pattern_key in self.conversation_patterns:
if feedback > 0:
self.conversation_patterns[pattern_key]['successful_responses'].append(response)
# Save learning data periodically
if len(self.feedback_history) % 10 == 0:
self._save_learning_data()
# Update model if enough feedback
if len(self.feedback_history) >= self.config.MIN_FEEDBACK_FOR_UPDATE:
self._update_model_from_feedback()
except Exception as e:
logger.error(f"Error processing feedback: {e}")
def store_interaction(self, interaction: Dict[str, Any]):
"""Store interaction for adaptive learning"""
try:
self.learning_buffer.append(interaction)
# Update learning patterns
if len(self.learning_buffer) >= 10:
self._update_learning_model()
self._save_learning_data()
self.learning_buffer = []
except Exception as e:
logger.error(f"Error storing interaction: {e}")
def _update_learning_model(self):
"""Update model based on accumulated learning"""
try:
# Process learning buffer
successful_interactions = [
interaction for interaction in self.learning_buffer
if interaction.get('feedback', 0) > 0
]
if successful_interactions:
# Update conversation patterns
for interaction in successful_interactions:
pattern_key = self._extract_pattern_key(interaction['message'])
if pattern_key in self.conversation_patterns:
self.conversation_patterns[pattern_key]['successful_responses'].append(
interaction['response']
)
# Update document relevance
for interaction in successful_interactions:
                    for doc in interaction.get('relevant_docs', []):
                        doc_key = doc['text'][:100]
                        # Create the counters on first sight; the original
                        # membership check meant no entry was ever initialized
                        entry = self.document_relevance.setdefault(
                            doc_key, {'usage_count': 0, 'success_count': 0}
                        )
                        entry['success_count'] += 1
logger.info("Updated learning model with new patterns")
except Exception as e:
logger.error(f"Error updating learning model: {e}")
def generate_context_questions(self, message: str, history: List[tuple], context: Dict[str, Any]) -> List[str]:
"""Generate context-aware follow-up questions"""
try:
# Create dynamic question generation prompt
prompt = f"""Based on the conversation context, generate appropriate follow-up questions.
Consider:
- Understanding the main concern
- Timeline and progression
- Impact on daily life
- Related symptoms or factors
- Previous treatments or consultations
Current message: {message}
Context: {context}
Generate questions:"""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=self.config.MAX_LENGTH,
truncation=True
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,
do_sample=True
)
            # Decode only the newly generated tokens, not the prompt
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            questions = self.tokenizer.decode(generated, skip_special_tokens=True)
            return [q.strip() for q in questions.split("\n") if "?" in q]
except Exception as e:
logger.error(f"Error generating context questions: {e}")
return ["Could you tell me more about your concerns?"]
def _save_learning_data(self):
"""Save learning data to disk"""
try:
data = {
'patterns': self.conversation_patterns,
                'feedback': self.feedback_history[-self.config.FEEDBACK_HISTORY_SIZE:]  # Keep only the most recent entries
}
with open('learning_data.json', 'w') as f:
json.dump(data, f)
except Exception as e:
logger.error(f"Error saving learning data: {e}")
def _update_model_from_feedback(self):
"""Update model based on feedback"""
try:
positive_feedback = [f for f in self.feedback_history if f['feedback'] > 0]
if len(positive_feedback) >= self.config.MIN_FEEDBACK_FOR_UPDATE:
# Prepare training data from successful interactions
training_data = []
for feedback in positive_feedback:
training_data.append({
'input_ids': self.tokenizer(
feedback['message'],
return_tensors='pt'
).input_ids,
'labels': self.tokenizer(
feedback['response'],
return_tensors='pt'
).input_ids
})
                # The actual fine-tuning step is intentionally omitted in this
                # simplified example; training_data is assembled above but unused
                logger.info("Updating model from feedback")
self.feedback_history = [] # Clear history after update
except Exception as e:
logger.error(f"Error updating model from feedback: {e}")
def analyze_symptom_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
"""Enhanced symptom analysis with context learning"""
try:
current_symptoms = set()
temporal_info = {}
related_conditions = set()
conversation_depth = len(history) if history else 0
            # Gather the full conversation; only depth is currently used, and these
            # lists are kept as hooks for richer analysis
            all_messages = [message] + [h[0] for h in (history or [])]
            all_responses = [h[1] for h in (history or [])]
# Extract conversation context
context = {
'symptoms_mentioned': current_symptoms,
'temporal_info': temporal_info,
'conversation_depth': conversation_depth,
'needs_clarification': True,
'specialist_referral_needed': False,
'previous_questions': set()
}
if history:
# Learn from previous interactions
for prev_msg, prev_resp in history:
if "?" in prev_resp:
context['previous_questions'].add(prev_resp.split("?")[0] + "?")
return context
except Exception as e:
logger.error(f"Error in symptom analysis: {e}")
return {'needs_clarification': True}
def generate_targeted_questions(self, symptoms: Dict[str, Any], history: List[tuple]) -> List[str]:
"""Generate context-aware follow-up questions"""
try:
# Use the model to generate relevant questions based on context
context_prompt = f"""Based on the patient's symptoms and conversation history, generate 3 specific follow-up questions.
Focus on:
1. Symptom details (duration, severity, patterns)
2. Impact on daily life
3. Related symptoms or conditions
4. Previous treatments or consultations
Do not ask about:
- Questions already asked
- Diagnostic conclusions
- Treatment recommendations
Current context: {symptoms}
Previous questions asked: {symptoms.get('previous_questions', set())}
Generate questions:"""
inputs = self.tokenizer(
context_prompt,
return_tensors="pt",
max_length=self.config.MAX_LENGTH,
truncation=True
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,
do_sample=True
)
            # Decode only the newly generated tokens, not the prompt
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            questions = self.tokenizer.decode(generated, skip_special_tokens=True)
            return [q.strip() for q in questions.split("\n") if "?" in q]
except Exception as e:
logger.error(f"Error generating questions: {e}")
return ["Could you tell me more about your symptoms?"]
def analyze_medical_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
"""Comprehensive medical context analysis"""
try:
# Initialize context tracking
context = {
'conversation_depth': len(history) if history else 0,
'needs_follow_up': True,
'previous_interactions': [],
'care_pathway': 'initial_triage',
'consultation_type': 'general',
}
# Analyze current conversation flow
all_messages = [message] + [h[0] for h in (history or [])]
# Build contextual understanding
for msg in all_messages:
msg_lower = msg.lower()
# Track conversation patterns
pattern_key = self._extract_pattern_key(msg_lower)
if pattern_key in self.conversation_patterns:
self.conversation_patterns[pattern_key]['frequency'] += 1
context['previous_patterns'] = self.conversation_patterns[pattern_key]
# Update learning patterns
self._update_learning_patterns(msg_lower, context)
return context
except Exception as e:
logger.error(f"Error in medical context analysis: {e}")
return {'needs_follow_up': True}
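    # The method below is called from analyze_medical_context but was missing from
    # the original file, so every call raised AttributeError and fell into the
    # except branch. This is a minimal sketch of the presumed intent, inferred from
    # how conversation_patterns is used elsewhere; treat it as an assumption.
    def _update_learning_patterns(self, message: str, context: Dict[str, Any]):
        """Ensure a pattern entry exists for this message and record it in the context (hypothetical helper)."""
        pattern_key = self._extract_pattern_key(message)
        self.conversation_patterns.setdefault(
            pattern_key, {'frequency': 0, 'successful_responses': []}
        )
        context.setdefault('observed_patterns', []).append(pattern_key)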
def generate_adaptive_response(self, message: str, history: List[tuple] = None) -> Dict[str, Any]:
"""Generate comprehensive triage response"""
try:
# Analyze medical context
context = self.analyze_medical_context(message, history)
# Retrieve relevant knowledge
query_embedding = self.embedding_model.encode([message])
_, indices = self.index.search(query_embedding, k=5)
relevant_docs = [self.documents[idx] for idx in indices[0]]
# Build conversation history
conv_history = "\n".join([f"Patient: {h[0]}\nPearly: {h[1]}" for h in (history or [])])
            # Create a dynamic prompt; join the retrieved passages as plain text
            # rather than interpolating the raw Python list repr into the prompt
            knowledge_context = "\n".join(doc['text'] for doc in relevant_docs)
            prompt = f"""As Pearly, a compassionate GP medical triage assistant, help assess the patient's needs and provide appropriate guidance.
Previous Conversation:
{conv_history}
Current Message: {message}
Medical Knowledge Context:
{knowledge_context}
Guidelines:
- Show empathy and understanding
- Ask relevant follow-up questions
- Guide to appropriate care level (GP, 111, emergency services)
- Consider all aspects of patient care
- Do not diagnose or recommend treatments
- Focus on understanding concerns and proper healthcare guidance
Response:"""
# Generate base response
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=self.config.MAX_LENGTH,
truncation=True
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
do_sample=True
)
            # Decode only the newly generated tokens, not the prompt
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(generated, skip_special_tokens=True)
# Generate contextual follow-up questions
if context['needs_follow_up']:
follow_ups = self.generate_context_questions(message, history, context)
if follow_ups:
response = f"{response}\n\n{follow_ups[0]}"
# Store interaction for learning
self.store_interaction({
'message': message,
'response': response,
'context': context,
'relevant_docs': relevant_docs,
'timestamp': datetime.now().isoformat()
})
return {
'response': response,
'context': context
}
except Exception as e:
logger.error(f"Error generating response: {e}")
return {
'response': "I apologize, but I'm having technical difficulties. If this is an emergency, please call 999 immediately. For urgent concerns, call 111.",
'context': {}
}
def _add_specialist_guidance(self, response: str, context: Dict[str, Any]) -> str:
"""Add specialist referral guidance to response"""
try:
specialist_prompt = f"""Based on the symptoms and context, suggest appropriate specialist care pathways.
Context: {context}
Current response: {response}
Add appropriate specialist referral guidance:"""
inputs = self.tokenizer(
specialist_prompt,
return_tensors="pt",
max_length=self.config.MAX_LENGTH,
truncation=True
).to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,
do_sample=True
)
            # Decode only the newly generated tokens, not the prompt
            generated = outputs[0][inputs["input_ids"].shape[1]:]
            specialist_guidance = self.tokenizer.decode(generated, skip_special_tokens=True)
return f"{response}\n\n{specialist_guidance}"
except Exception as e:
logger.error(f"Error adding specialist guidance: {e}")
return response
def update_learning_from_interaction(self, interaction: Dict[str, Any]):
"""Update adaptive learning system from interaction"""
try:
# Extract key information
message = interaction['message']
response = interaction['response']
context = interaction['context']
relevant_docs = interaction.get('relevant_docs', [])
# Update conversation patterns
pattern_key = self._extract_pattern_key(message)
if pattern_key in self.conversation_patterns:
self.conversation_patterns[pattern_key]['frequency'] += 1
if context.get('successful_response'):
self.conversation_patterns[pattern_key]['successful_responses'].append(response)
            # Update document relevance scores, creating counters on first use
            # (the original membership check meant no entry was ever initialized)
            for doc in relevant_docs:
                doc_key = doc['text'][:100]  # Use first 100 chars as key
                entry = self.document_relevance.setdefault(
                    doc_key, {'usage_count': 0, 'success_count': 0}
                )
                entry['usage_count'] += 1
                if context.get('successful_response'):
                    entry['success_count'] += 1
# Save learning data periodically
if len(self.learning_buffer) >= 10:
self._save_learning_data()
self.learning_buffer = []
except Exception as e:
logger.error(f"Error updating learning system: {e}")
def create_demo():
"""Create Gradio interface with proper message formatting"""
try:
bot = AdaptiveMedicalBot()
def chat(message: str, history: List[Dict[str, str]]):
try:
                # History arrives in Gradio "messages" format (role/content dicts),
                # not {"user", "bot"} dicts; pair each user turn with the assistant
                # reply that follows it (the opening greeting has no user turn)
                bot_history = []
                pending_user = None
                for h in history or []:
                    if h["role"] == "user":
                        pending_user = h["content"]
                    elif h["role"] == "assistant" and pending_user is not None:
                        bot_history.append((pending_user, h["content"]))
                        pending_user = None
# Generate response
response_data = bot.generate_adaptive_response(message, bot_history)
response = response_data['response']
# Format response for Gradio chat
history.append({"role": "user", "content": message})
history.append({"role": "assistant", "content": response})
return history
except Exception as e:
logger.error(f"Chat error: {e}")
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."}
]
def process_feedback(feedback: str, history: List[Dict[str, str]]):
"""Process feedback when received"""
try:
if history and len(history) >= 2:
last_user_msg = history[-2]["content"]
last_bot_msg = history[-1]["content"]
bot.update_from_feedback(
last_user_msg,
last_bot_msg,
                        1 if feedback == "👍" else -1
)
except Exception as e:
logger.error(f"Error processing feedback: {e}")
# Create Gradio interface with proper message handling
with gr.Blocks(theme="soft") as demo:
            gr.Markdown("""
            # GP Medical Triage Assistant - Pearly

            👋 Hello! I'm Pearly, your GP medical assistant.
            I can help you with:

            • Assessing your symptoms
            • Finding appropriate care
            • Booking appointments
            • General medical guidance

            For emergencies, always call 999.
            For urgent concerns, call 111.

            Please describe your concerns below.
            """)
# Initialize chat components
chatbot = gr.Chatbot(
value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"}],
height=500,
type="messages"
)
msg = gr.Textbox(
label="Your message",
placeholder="Type your message here...",
lines=2
)
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
feedback = gr.Radio(
                choices=["👍", "👎"],
label="Was this response helpful?",
visible=True
)
# Set up event handlers
submit.click(
chat,
inputs=[msg, chatbot],
outputs=[chatbot],
queue=False
).then(
lambda: "",
None,
msg # Clear input box after sending
)
msg.submit(
chat,
inputs=[msg, chatbot],
outputs=[chatbot],
queue=False
).then(
lambda: "",
None,
msg # Clear input box after sending
)
clear.click(
                lambda: ([], ""),  # Reset chat history and clear the input box
None,
[chatbot, msg],
queue=False
)
feedback.change(
process_feedback,
inputs=[feedback, chatbot],
outputs=[],
queue=False
)
# Add example messages
gr.Examples(
examples=[
"I've been having headaches recently",
"I need to book a routine checkup",
"I'm feeling very anxious lately"
],
inputs=msg
)
return demo
except Exception as e:
logger.error(f"Error creating demo: {e}")
raise
if __name__ == "__main__":
    # Re-initialize environment (redundant with the module-level calls above, but harmless)
    load_dotenv()
# Set up HuggingFace login if token exists
hf_token = os.getenv("HF_TOKEN")
if hf_token:
login(token=hf_token)
# Launch demo
demo = create_demo()
demo.launch()
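    # Assumption, not in the original: to serve beyond localhost, Gradio's
    # standard launch flags can be used instead, e.g.
    #     demo.launch(server_name="0.0.0.0", share=True)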