# NOTE(review): the three lines previously here ("Spaces:" / "Runtime error" /
# "Runtime error") were residue from a hosting error page accidentally pasted
# above the module; they were not valid Python and are preserved only as this note.
# app.py
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List

import faiss
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import login
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Application-wide logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Credentials come from the environment only — never hard-coded.
kaggle_username = os.getenv("KAGGLE_USERNAME")
kaggle_key = os.getenv("KAGGLE_KEY")
hf_token = os.getenv("HF_TOKEN")
wandb_key = os.getenv("WANDB_API_KEY")

# Authenticate with the Hugging Face Hub when a token is available.
if hf_token:
    login(token=hf_token)
else:
    logger.warning("Hugging Face token not found in environment variables.")
@dataclass
class AdaptiveBotConfig:
    """Configuration for the adaptive medical triage bot.

    Declared as a dataclass so that ``field(default_factory=...)`` actually
    produces a fresh per-instance list. Without the decorator these are bare
    class annotations: the defaults never become instance attributes and
    ``LORA_TARGET_MODULES`` evaluates to a raw ``dataclasses.Field`` object.
    """

    # Base causal LM and sentence-embedding model identifiers.
    MODEL_NAME: str = "google/gemma-7b"
    EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"

    # LoRA parameters
    LORA_R: int = 8
    LORA_ALPHA: int = 16
    LORA_DROPOUT: float = 0.1
    LORA_TARGET_MODULES: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Training parameters
    MAX_LENGTH: int = 512
    BATCH_SIZE: int = 1
    LEARNING_RATE: float = 1e-4

    # Adaptive learning parameters
    MIN_FEEDBACK_FOR_UPDATE: int = 5
    FEEDBACK_HISTORY_SIZE: int = 100
    LEARNING_RATE_DECAY: float = 0.95
class AdaptiveMedicalBot:
    """Medical triage chatbot that adapts its behaviour from user feedback."""

    def __init__(self):
        # Configuration first: every subsequent setup step reads from it.
        self.config = AdaptiveBotConfig()
        self.setup_models()
        self.load_datasets()
        self.setup_adaptive_learning()
        # Per-document usage/success statistics for RAG relevance tracking.
        self.document_relevance = {}
def setup_adaptive_learning(self): | |
"""Initialize adaptive learning components""" | |
self.feedback_history = [] | |
self.conversation_patterns = {} | |
self.learning_buffer = [] | |
# Load existing learning data if available | |
try: | |
if os.path.exists('learning_data.json'): | |
with open('learning_data.json', 'r') as f: | |
data = json.load(f) | |
self.conversation_patterns = data.get('patterns', {}) | |
self.feedback_history = data.get('feedback', []) | |
except Exception as e: | |
logger.warning(f"Could not load learning data: {e}") | |
def setup_models(self): | |
"""Initialize models with LoRA and quantization""" | |
try: | |
bnb_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.float16 | |
) | |
self.tokenizer = AutoTokenizer.from_pretrained( | |
self.config.MODEL_NAME, | |
trust_remote_code=True | |
) | |
base_model = AutoModelForCausalLM.from_pretrained( | |
self.config.MODEL_NAME, | |
quantization_config=bnb_config, | |
device_map="auto", | |
trust_remote_code=True | |
) | |
base_model = prepare_model_for_kbit_training(base_model) | |
lora_config = LoraConfig( | |
r=self.config.LORA_R, | |
lora_alpha=self.config.LORA_ALPHA, | |
target_modules=self.config.LORA_TARGET_MODULES, | |
lora_dropout=self.config.LORA_DROPOUT, | |
bias="none", | |
task_type=TaskType.CAUSAL_LM | |
) | |
self.model = get_peft_model(base_model, lora_config) | |
self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL) | |
except Exception as e: | |
logger.error(f"Error setting up models: {e}") | |
raise | |
def load_datasets(self): | |
"""Load and process datasets for RAG""" | |
try: | |
datasets = { | |
"medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"), | |
"diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"), | |
"persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]") | |
} | |
self.documents = [] | |
for dataset_name, dataset in datasets.items(): | |
for item in dataset: | |
if dataset_name == "persona": | |
if isinstance(item.get('personality'), list): | |
self.documents.append({ | |
'text': " ".join(item['personality']), | |
'type': 'persona' | |
}) | |
else: | |
if 'input' in item and 'output' in item: | |
self.documents.append({ | |
'text': f"{item['input']}\n{item['output']}", | |
'type': dataset_name | |
}) | |
self._create_index() | |
except Exception as e: | |
logger.error(f"Error loading datasets: {e}") | |
raise | |
def _create_index(self): | |
"""Create FAISS index for RAG""" | |
sample_embedding = self.embedding_model.encode("sample text") | |
self.index = faiss.IndexFlatIP(sample_embedding.shape[0]) | |
batch_size = 32 | |
for i in range(0, len(self.documents), batch_size): | |
batch = self.documents[i:i + batch_size] | |
texts = [doc['text'] for doc in batch] | |
embeddings = self.embedding_model.encode(texts) | |
self.index.add(np.array(embeddings)) | |
def analyze_conversation_context(self, message: str, history: List[tuple]) -> Dict[str, Any]: | |
"""Analyze conversation context to determine appropriate follow-up questions""" | |
try: | |
# Extract key information | |
mentioned_symptoms = set() | |
time_indicators = set() | |
severity_indicators = set() | |
# Analyze current message and history | |
for msg in [message] + [h[0] for h in (history or [])]: | |
msg_lower = msg.lower() | |
# Update conversation patterns | |
pattern_key = self._extract_pattern_key(msg_lower) | |
if pattern_key in self.conversation_patterns: | |
self.conversation_patterns[pattern_key]['frequency'] += 1 | |
else: | |
self.conversation_patterns[pattern_key] = { | |
'frequency': 1, | |
'successful_responses': [] | |
} | |
return { | |
'needs_follow_up': True, # Always encourage follow-up questions | |
'conversation_depth': len(history) if history else 0, | |
'pattern_key': pattern_key | |
} | |
except Exception as e: | |
logger.error(f"Error analyzing conversation: {e}") | |
return {'needs_follow_up': True} | |
def _extract_pattern_key(self, message: str) -> str: | |
"""Extract conversation pattern key for learning""" | |
# Simplified pattern extraction - can be enhanced based on learning | |
words = message.lower().split() | |
return " ".join(sorted(set(words))[:5]) | |
def generate_follow_up_questions(self, context: Dict[str, Any], message: str) -> List[str]: | |
"""Generate contextual follow-up questions""" | |
try: | |
# Use the model to generate follow-up questions | |
prompt = f"""Given the patient message: "{message}" | |
Generate relevant follow-up questions to better understand their situation. | |
Focus on: timing, severity, associated symptoms, impact on daily life. | |
Do not make diagnoses or suggest treatments. | |
Questions:""" | |
inputs = self.tokenizer( | |
prompt, | |
return_tensors="pt", | |
max_length=self.config.MAX_LENGTH, | |
truncation=True | |
).to(self.model.device) | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=100, | |
temperature=0.7, | |
do_sample=True | |
) | |
questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return questions.split('\n') | |
except Exception as e: | |
logger.error(f"Error generating follow-up questions: {e}") | |
return ["Could you tell me more about when this started?"] | |
def update_from_feedback(self, message: str, response: str, feedback: int): | |
"""Process feedback for adaptive learning""" | |
try: | |
self.feedback_history.append({ | |
'message': message, | |
'response': response, | |
'feedback': feedback, | |
'timestamp': datetime.now().isoformat() | |
}) | |
# Update conversation patterns | |
pattern_key = self._extract_pattern_key(message) | |
if pattern_key in self.conversation_patterns: | |
if feedback > 0: | |
self.conversation_patterns[pattern_key]['successful_responses'].append(response) | |
# Save learning data periodically | |
if len(self.feedback_history) % 10 == 0: | |
self._save_learning_data() | |
# Update model if enough feedback | |
if len(self.feedback_history) >= self.config.MIN_FEEDBACK_FOR_UPDATE: | |
self._update_model_from_feedback() | |
except Exception as e: | |
logger.error(f"Error processing feedback: {e}") | |
def store_interaction(self, interaction: Dict[str, Any]): | |
"""Store interaction for adaptive learning""" | |
try: | |
self.learning_buffer.append(interaction) | |
# Update learning patterns | |
if len(self.learning_buffer) >= 10: | |
self._update_learning_model() | |
self._save_learning_data() | |
self.learning_buffer = [] | |
except Exception as e: | |
logger.error(f"Error storing interaction: {e}") | |
def _update_learning_model(self): | |
"""Update model based on accumulated learning""" | |
try: | |
# Process learning buffer | |
successful_interactions = [ | |
interaction for interaction in self.learning_buffer | |
if interaction.get('feedback', 0) > 0 | |
] | |
if successful_interactions: | |
# Update conversation patterns | |
for interaction in successful_interactions: | |
pattern_key = self._extract_pattern_key(interaction['message']) | |
if pattern_key in self.conversation_patterns: | |
self.conversation_patterns[pattern_key]['successful_responses'].append( | |
interaction['response'] | |
) | |
# Update document relevance | |
for interaction in successful_interactions: | |
for doc in interaction.get('relevant_docs', []): | |
doc_key = doc['text'][:100] | |
if doc_key in self.document_relevance: | |
self.document_relevance[doc_key]['success_count'] += 1 | |
logger.info("Updated learning model with new patterns") | |
except Exception as e: | |
logger.error(f"Error updating learning model: {e}") | |
def generate_context_questions(self, message: str, history: List[tuple], context: Dict[str, Any]) -> List[str]: | |
"""Generate context-aware follow-up questions""" | |
try: | |
# Create dynamic question generation prompt | |
prompt = f"""Based on the conversation context, generate appropriate follow-up questions. | |
Consider: | |
- Understanding the main concern | |
- Timeline and progression | |
- Impact on daily life | |
- Related symptoms or factors | |
- Previous treatments or consultations | |
Current message: {message} | |
Context: {context} | |
Generate questions:""" | |
inputs = self.tokenizer( | |
prompt, | |
return_tensors="pt", | |
max_length=self.config.MAX_LENGTH, | |
truncation=True | |
).to(self.model.device) | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=150, | |
temperature=0.7, | |
do_sample=True | |
) | |
questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return [q.strip() for q in questions.split("\n") if "?" in q] | |
except Exception as e: | |
logger.error(f"Error generating context questions: {e}") | |
return ["Could you tell me more about your concerns?"] | |
def _save_learning_data(self): | |
"""Save learning data to disk""" | |
try: | |
data = { | |
'patterns': self.conversation_patterns, | |
'feedback': self.feedback_history[-100:] # Keep last 100 entries | |
} | |
with open('learning_data.json', 'w') as f: | |
json.dump(data, f) | |
except Exception as e: | |
logger.error(f"Error saving learning data: {e}") | |
def _update_model_from_feedback(self): | |
"""Update model based on feedback""" | |
try: | |
positive_feedback = [f for f in self.feedback_history if f['feedback'] > 0] | |
if len(positive_feedback) >= self.config.MIN_FEEDBACK_FOR_UPDATE: | |
# Prepare training data from successful interactions | |
training_data = [] | |
for feedback in positive_feedback: | |
training_data.append({ | |
'input_ids': self.tokenizer( | |
feedback['message'], | |
return_tensors='pt' | |
).input_ids, | |
'labels': self.tokenizer( | |
feedback['response'], | |
return_tensors='pt' | |
).input_ids | |
}) | |
# Update model (simplified for example) | |
logger.info("Updating model from feedback") | |
self.feedback_history = [] # Clear history after update | |
except Exception as e: | |
logger.error(f"Error updating model from feedback: {e}") | |
def analyze_symptom_context(self, message: str, history: List[tuple]) -> Dict[str, Any]: | |
"""Enhanced symptom analysis with context learning""" | |
try: | |
current_symptoms = set() | |
temporal_info = {} | |
related_conditions = set() | |
conversation_depth = len(history) if history else 0 | |
# Analyze full conversation context | |
all_messages = [message] + [h[0] for h in (history or [])] | |
all_responses = [h[1] for h in (history or [])] | |
# Extract conversation context | |
context = { | |
'symptoms_mentioned': current_symptoms, | |
'temporal_info': temporal_info, | |
'conversation_depth': conversation_depth, | |
'needs_clarification': True, | |
'specialist_referral_needed': False, | |
'previous_questions': set() | |
} | |
if history: | |
# Learn from previous interactions | |
for prev_msg, prev_resp in history: | |
if "?" in prev_resp: | |
context['previous_questions'].add(prev_resp.split("?")[0] + "?") | |
return context | |
except Exception as e: | |
logger.error(f"Error in symptom analysis: {e}") | |
return {'needs_clarification': True} | |
def generate_targeted_questions(self, symptoms: Dict[str, Any], history: List[tuple]) -> List[str]: | |
"""Generate context-aware follow-up questions""" | |
try: | |
# Use the model to generate relevant questions based on context | |
context_prompt = f"""Based on the patient's symptoms and conversation history, generate 3 specific follow-up questions. | |
Focus on: | |
1. Symptom details (duration, severity, patterns) | |
2. Impact on daily life | |
3. Related symptoms or conditions | |
4. Previous treatments or consultations | |
Do not ask about: | |
- Questions already asked | |
- Diagnostic conclusions | |
- Treatment recommendations | |
Current context: {symptoms} | |
Previous questions asked: {symptoms.get('previous_questions', set())} | |
Generate questions:""" | |
inputs = self.tokenizer( | |
context_prompt, | |
return_tensors="pt", | |
max_length=self.config.MAX_LENGTH, | |
truncation=True | |
).to(self.model.device) | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=150, | |
temperature=0.7, | |
do_sample=True | |
) | |
questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return [q.strip() for q in questions.split("\n") if "?" in q] | |
except Exception as e: | |
logger.error(f"Error generating questions: {e}") | |
return ["Could you tell me more about your symptoms?"] | |
def analyze_medical_context(self, message: str, history: List[tuple]) -> Dict[str, Any]: | |
"""Comprehensive medical context analysis""" | |
try: | |
# Initialize context tracking | |
context = { | |
'conversation_depth': len(history) if history else 0, | |
'needs_follow_up': True, | |
'previous_interactions': [], | |
'care_pathway': 'initial_triage', | |
'consultation_type': 'general', | |
} | |
# Analyze current conversation flow | |
all_messages = [message] + [h[0] for h in (history or [])] | |
# Build contextual understanding | |
for msg in all_messages: | |
msg_lower = msg.lower() | |
# Track conversation patterns | |
pattern_key = self._extract_pattern_key(msg_lower) | |
if pattern_key in self.conversation_patterns: | |
self.conversation_patterns[pattern_key]['frequency'] += 1 | |
context['previous_patterns'] = self.conversation_patterns[pattern_key] | |
# Update learning patterns | |
self._update_learning_patterns(msg_lower, context) | |
return context | |
except Exception as e: | |
logger.error(f"Error in medical context analysis: {e}") | |
return {'needs_follow_up': True} | |
def generate_adaptive_response(self, message: str, history: List[tuple] = None) -> Dict[str, Any]: | |
"""Generate comprehensive triage response""" | |
try: | |
# Analyze medical context | |
context = self.analyze_medical_context(message, history) | |
# Retrieve relevant knowledge | |
query_embedding = self.embedding_model.encode([message]) | |
_, indices = self.index.search(query_embedding, k=5) | |
relevant_docs = [self.documents[idx] for idx in indices[0]] | |
# Build conversation history | |
conv_history = "\n".join([f"Patient: {h[0]}\nPearly: {h[1]}" for h in (history or [])]) | |
# Create dynamic prompt based on context | |
prompt = f"""As Pearly, a compassionate GP medical triage assistant, help assess the patient's needs and provide appropriate guidance. | |
Previous Conversation: | |
{conv_history} | |
Current Message: {message} | |
Medical Knowledge Context: | |
{[doc['text'] for doc in relevant_docs]} | |
Guidelines: | |
- Show empathy and understanding | |
- Ask relevant follow-up questions | |
- Guide to appropriate care level (GP, 111, emergency services) | |
- Consider all aspects of patient care | |
- Do not diagnose or recommend treatments | |
- Focus on understanding concerns and proper healthcare guidance | |
Response:""" | |
# Generate base response | |
inputs = self.tokenizer( | |
prompt, | |
return_tensors="pt", | |
max_length=self.config.MAX_LENGTH, | |
truncation=True | |
).to(self.model.device) | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=300, | |
temperature=0.7, | |
do_sample=True | |
) | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Generate contextual follow-up questions | |
if context['needs_follow_up']: | |
follow_ups = self.generate_context_questions(message, history, context) | |
if follow_ups: | |
response = f"{response}\n\n{follow_ups[0]}" | |
# Store interaction for learning | |
self.store_interaction({ | |
'message': message, | |
'response': response, | |
'context': context, | |
'relevant_docs': relevant_docs, | |
'timestamp': datetime.now().isoformat() | |
}) | |
return { | |
'response': response, | |
'context': context | |
} | |
except Exception as e: | |
logger.error(f"Error generating response: {e}") | |
return { | |
'response': "I apologize, but I'm having technical difficulties. If this is an emergency, please call 999 immediately. For urgent concerns, call 111.", | |
'context': {} | |
} | |
def _add_specialist_guidance(self, response: str, context: Dict[str, Any]) -> str: | |
"""Add specialist referral guidance to response""" | |
try: | |
specialist_prompt = f"""Based on the symptoms and context, suggest appropriate specialist care pathways. | |
Context: {context} | |
Current response: {response} | |
Add appropriate specialist referral guidance:""" | |
inputs = self.tokenizer( | |
specialist_prompt, | |
return_tensors="pt", | |
max_length=self.config.MAX_LENGTH, | |
truncation=True | |
).to(self.model.device) | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=150, | |
temperature=0.7, | |
do_sample=True | |
) | |
specialist_guidance = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return f"{response}\n\n{specialist_guidance}" | |
except Exception as e: | |
logger.error(f"Error adding specialist guidance: {e}") | |
return response | |
def update_learning_from_interaction(self, interaction: Dict[str, Any]): | |
"""Update adaptive learning system from interaction""" | |
try: | |
# Extract key information | |
message = interaction['message'] | |
response = interaction['response'] | |
context = interaction['context'] | |
relevant_docs = interaction.get('relevant_docs', []) | |
# Update conversation patterns | |
pattern_key = self._extract_pattern_key(message) | |
if pattern_key in self.conversation_patterns: | |
self.conversation_patterns[pattern_key]['frequency'] += 1 | |
if context.get('successful_response'): | |
self.conversation_patterns[pattern_key]['successful_responses'].append(response) | |
# Update document relevance scores | |
for doc in relevant_docs: | |
doc_key = doc['text'][:100] # Use first 100 chars as key | |
if doc_key in self.document_relevance: | |
self.document_relevance[doc_key]['usage_count'] += 1 | |
if context.get('successful_response'): | |
self.document_relevance[doc_key]['success_count'] += 1 | |
# Save learning data periodically | |
if len(self.learning_buffer) >= 10: | |
self._save_learning_data() | |
self.learning_buffer = [] | |
except Exception as e: | |
logger.error(f"Error updating learning system: {e}") | |
def create_demo():
    """Create the Gradio chat interface.

    Fixes over the original:
    - ``chat`` now parses Gradio "messages"-format history (role/content
      dicts). It previously indexed ``h["user"]``/``h["bot"]``, which raises
      KeyError on every turn with ``type="messages"`` chat history, so every
      reply fell through to the error fallback.
    - Mojibake repaired: the feedback Radio choices were two identical garbled
      glyphs, so every rating compared equal to the "positive" choice;
      restored to distinct thumbs up/down. Greeting emoji and bullet
      characters restored likewise.
    """
    try:
        bot = AdaptiveMedicalBot()

        def chat(message: str, history: List[Dict[str, str]]):
            """Handle one user turn and return the updated messages list."""
            try:
                # Rebuild (user, assistant) pairs from the flat role/content
                # messages list the Chatbot component provides.
                bot_history = []
                pending_user = None
                for entry in history or []:
                    if entry.get("role") == "user":
                        pending_user = entry.get("content", "")
                    elif entry.get("role") == "assistant" and pending_user is not None:
                        bot_history.append((pending_user, entry.get("content", "")))
                        pending_user = None
                # Generate response
                response_data = bot.generate_adaptive_response(message, bot_history)
                response = response_data['response']
                # Append the new exchange in messages format.
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."}
                ]

        def process_feedback(feedback: str, history: List[Dict[str, str]]):
            """Forward a thumbs rating on the last exchange to the bot."""
            try:
                if history and len(history) >= 2:
                    last_user_msg = history[-2]["content"]
                    last_bot_msg = history[-1]["content"]
                    bot.update_from_feedback(
                        last_user_msg,
                        last_bot_msg,
                        1 if feedback == "👍" else -1
                    )
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")

        # Build the UI.
        with gr.Blocks(theme="soft") as demo:
            gr.Markdown("""
            # GP Medical Triage Assistant - Pearly
            👋 Hello! I'm Pearly, your GP medical assistant.
            I can help you with:
            • Assessing your symptoms
            • Finding appropriate care
            • Booking appointments
            • General medical guidance
            For emergencies, always call 999.
            For urgent concerns, call 111.
            Please describe your concerns below.
            """)

            # Chat components.
            chatbot = gr.Chatbot(
                value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"}],
                height=500,
                type="messages"
            )
            msg = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2
            )
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.Button("Clear")
            feedback = gr.Radio(
                choices=["👍", "👎"],
                label="Was this response helpful?",
                visible=True
            )

            # Event wiring.
            submit.click(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot],
                queue=False
            ).then(
                lambda: "",
                None,
                msg  # Clear input box after sending
            )
            msg.submit(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot],
                queue=False
            ).then(
                lambda: "",
                None,
                msg  # Clear input box after sending
            )
            clear.click(
                lambda: [[], ""],
                None,
                [chatbot, msg],
                queue=False
            )
            feedback.change(
                process_feedback,
                inputs=[feedback, chatbot],
                outputs=[],
                queue=False
            )

            # Example messages to get users started.
            gr.Examples(
                examples=[
                    "I've been having headaches recently",
                    "I need to book a routine checkup",
                    "I'm feeling very anxious lately"
                ],
                inputs=msg
            )
        return demo
    except Exception as e:
        logger.error(f"Error creating demo: {e}")
        raise
if __name__ == "__main__":
    # Environment setup (idempotent if already performed at import time).
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    # Build and serve the Gradio app.
    demo = create_demo()
    demo.launch()