# app.py
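"""Pearly: an adaptive GP triage assistant.

Combines a 4-bit quantized Gemma-7B base model with LoRA adapters, a FAISS
retrieval index over medical Q&A datasets, and a Gradio chat interface that
collects user feedback for later adaptation.
"""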
import os
import logging
import torch
from typing import Dict, List, Any
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import faiss
import numpy as np
from datasets import load_dataset
from datetime import datetime
from huggingface_hub import login
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Retrieve secrets securely from environment variables
hf_token = os.getenv("HF_TOKEN")
if hf_token:
login(token=hf_token)
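# Note: google/gemma-7b is a gated model, so the token must belong to an account
# that has accepted the Gemma license on the Hugging Face Hub.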
class AdaptiveMedicalBot:
def __init__(self):
self.config = self.AdaptiveBotConfig()
self.setup_models()
self.load_datasets()
self.setup_adaptive_learning()
self.conversation_history = [] # Maintain conversation history
class AdaptiveBotConfig:
MODEL_NAME = "google/gemma-7b"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = ["q_proj", "v_proj"]
MAX_LENGTH = 512
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
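        # LoRA scales its update by LORA_ALPHA / LORA_R (16 / 8 = 2.0); only the
        # q_proj and v_proj attention matrices receive trainable adapters.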
def setup_adaptive_learning(self):
"""Initialize adaptive learning components"""
self.feedback_history = []
def setup_models(self):
"""Initialize models with LoRA and quantization"""
try:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
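            # NF4 4-bit quantization shrinks the 7B base model enough to fit on a
            # single consumer GPU; compute still runs in float16.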
self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
base_model = AutoModelForCausalLM.from_pretrained(
self.config.MODEL_NAME,
quantization_config=bnb_config,
device_map="auto"
)
base_model = prepare_model_for_kbit_training(base_model)
lora_config = LoraConfig(
r=self.config.LORA_R,
lora_alpha=self.config.LORA_ALPHA,
target_modules=self.config.LORA_TARGET_MODULES,
lora_dropout=self.config.LORA_DROPOUT,
bias="none",
task_type=TaskType.CAUSAL_LM
)
self.model = get_peft_model(base_model, lora_config)
self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
except Exception as e:
logger.error(f"Error setting up models: {e}")
raise
def load_datasets(self):
"""Load and prepare datasets for RAG"""
try:
datasets = {
"medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
"diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
"persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
}
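            # Only the first 500 rows of each dataset are indexed, trading recall
            # for fast startup and a small memory footprint.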
self.documents = []
for dataset_name, dataset in datasets.items():
for item in dataset:
if dataset_name == "persona":
if isinstance(item.get('personality'), list):
self.documents.append({'text': " ".join(item['personality']), 'type': 'persona'})
else:
if 'input' in item and 'output' in item:
self.documents.append({'text': f"{item['input']}\n{item['output']}", 'type': dataset_name})
self._create_index()
except Exception as e:
logger.error(f"Error loading datasets: {e}")
raise
    def _create_index(self):
        """Create FAISS index for RAG"""
        try:
            sample_embedding = self.embedding_model.encode("sample text")
            self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
            # Batch-encode all documents; FAISS expects float32, and L2-normalizing
            # makes the inner-product index behave as cosine similarity
            embeddings = np.array(
                self.embedding_model.encode([doc['text'] for doc in self.documents]),
                dtype="float32"
            )
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings)
        except Exception as e:
            logger.error(f"Error creating FAISS index: {e}")
            raise
def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
"""Generate follow-up questions based on context"""
try:
prompt = f"""Patient message: "{message}"
Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
Questions:"""
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.config.MAX_LENGTH).to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
            # Decode only the newly generated tokens so the prompt is not echoed back
            questions = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            return [q.strip() for q in questions.split("\n") if q.strip()]
except Exception as e:
logger.error(f"Error generating follow-up questions: {e}")
return ["Could you tell me more about when this started?"]
def assess_symptom_severity(self, message: str) -> str:
"""Assess severity based on keywords in the message"""
if "severe" in message.lower() or "emergency" in message.lower():
return "emergency"
elif "persistent" in message.lower() or "moderate" in message.lower():
return "urgent"
return "routine"
def generate_response(self, message: str) -> Dict[str, Any]:
"""Generate a response based on the message"""
try:
            severity = self.assess_symptom_severity(message)
            # Retrieve relevant documents from FAISS, preparing the query the same way the index was built
            query_embedding = np.array(self.embedding_model.encode([message]), dtype="float32")
            faiss.normalize_L2(query_embedding)
            _, indices = self.index.search(query_embedding, k=5)
            relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]
            context_block = "\n".join(relevant_docs)
            prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
Consider relevant knowledge and the following documents:
{context_block}
Respond with empathy, follow-up questions, and care guidance."""
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.config.MAX_LENGTH).to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
            # Decode only the newly generated tokens so the prompt is not echoed back
            response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            follow_ups = self.generate_follow_up_questions(message, {})
            if follow_ups:
                response += f"\n{follow_ups[0]}"
# Append response to conversation history
self.conversation_history.append((message, response))
# Add care level guidance
if severity == "emergency":
response += "\nThis seems urgent. Please call 999 immediately."
elif severity == "urgent":
response += "\nConsider calling NHS 111 for urgent assistance."
return {'response': response}
except Exception as e:
logger.error(f"Error generating response: {e}")
return {
'response': "I'm experiencing technical issues. If this is an emergency, please call 999 immediately.",
}
def handle_feedback(self, message: str, response: str, feedback: int):
"""Update model based on feedback"""
try:
self.feedback_history.append({
'message': message,
'response': response,
'feedback': feedback,
'timestamp': datetime.now().isoformat()
})
            if len(self.feedback_history) >= 10:
                # Placeholder: run a learning update (e.g. LoRA fine-tuning) on the accumulated feedback here
                self.feedback_history = []  # Reset history after the update
except Exception as e:
logger.error(f"Error processing feedback: {e}")
def create_demo():
"""Set up Gradio interface for the chatbot"""
try:
bot = AdaptiveMedicalBot()
def chat(message: str, history: List[Dict[str, str]]):
try:
response_data = bot.generate_response(message)
response = response_data['response']
history.append({"role": "user", "content": message})
history.append({"role": "assistant", "content": response})
return history
except Exception as e:
logger.error(f"Chat error: {e}")
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "I'm experiencing technical difficulties. For emergencies, call 999."}
]
def process_feedback(feedback: str, history: List[Dict[str, str]], comment: str = ""):
try:
if history and len(history) >= 2:
last_user_msg = history[-2]["content"]
last_bot_msg = history[-1]["content"]
bot.handle_feedback(last_user_msg, last_bot_msg, 1 if feedback == "πŸ‘" else -1)
except Exception as e:
logger.error(f"Error processing feedback: {e}")
with gr.Blocks() as demo:
            chatbot = gr.Chatbot(
                value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP Triage medical assistant. How can I help you today?"}],
                height=500,
                elem_id="chatbot",
                type="messages",
                show_label=False
            )
msg = gr.Textbox(
label="Your message",
placeholder="Type your message here...",
lines=2
)
submit = gr.Button("Send", variant="primary")
feedback = gr.Radio(
choices=["πŸ‘", "πŸ‘Ž"],
label="Was this response helpful?",
visible=True
)
feedback_text = gr.Textbox(
label="Additional comments (optional)",
placeholder="Tell us more about your experience...",
lines=2
)
# Event Handlers
submit.click(
fn=chat,
inputs=[msg, chatbot],
outputs=[chatbot]
).then(
lambda: "",
None,
msg
)
msg.submit(
fn=chat,
inputs=[msg, chatbot],
outputs=[chatbot]
).then(
lambda: "",
None,
msg
)
feedback.change(
fn=process_feedback,
inputs=[feedback, chatbot, feedback_text],
outputs=[]
)
# Clear Chat Handler
clear = gr.Button("πŸ—‘οΈ Clear Chat")
            clear.click(lambda: ([], ""), None, [chatbot, msg])
# Additional Information Sections
gr.HTML("""
<div style="padding: 20px;">
<h2>Quick Actions</h2>
<button onclick="document.getElementById('chatbot').value += 'Emergency Info: For emergencies, call 999.'">Emergency Info</button>
<button onclick="document.getElementById('chatbot').value += 'NHS 111 Info: For urgent but non-emergency situations, call 111.'">NHS 111 Info</button>
<button onclick="document.getElementById('chatbot').value += 'GP Booking Info: For routine appointments with your GP.'">GP Booking</button>
</div>
""")
gr.Markdown("### Example Messages")
gr.Examples(
examples=[
["I've been having severe headaches for the past week"],
["I need to book a routine checkup"],
["I'm feeling very anxious lately and need help"],
["My child has had a fever for 2 days"],
["I need information about COVID-19 testing"]
],
inputs=msg
)
gr.Markdown("""
### NHS Services Guide
**999 - Emergency Services**: For life-threatening emergencies, severe injuries, heart attack, stroke.
**NHS 111**: Available 24/7 for urgent but non-life-threatening situations, medical advice, and guidance.
**GP Services**: Routine check-ups, non-urgent medical issues, and prescription renewals.
""")
return demo
except Exception as e:
logger.error(f"Error creating demo: {e}")
raise
if __name__ == "__main__":
# Launch Gradio Interface
demo = create_demo()
demo.launch()
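    # When running outside Hugging Face Spaces, demo.launch(share=True) would also
    # expose a temporary public URL.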