# app.py
import os
import logging
import torch
from typing import Dict, List, Any
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import faiss
import numpy as np
from datasets import load_dataset
from datetime import datetime
from huggingface_hub import login
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Retrieve secrets securely from environment variables
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
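
# A sample .env for local development might look like this (the value is an
# illustrative placeholder, not a real credential):
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxx
#
# On Hugging Face Spaces, set HF_TOKEN as a repository secret instead of
# committing a .env file.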


class AdaptiveMedicalBot:
    def __init__(self):
        self.config = self.AdaptiveBotConfig()
        self.setup_models()
        self.load_datasets()
        self.setup_adaptive_learning()
        self.conversation_history = []  # Maintain conversation history

    class AdaptiveBotConfig:
        """Hyperparameters for the base model, embeddings, and LoRA adapters."""
        MODEL_NAME = "google/gemma-7b"
        EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
        LORA_R = 8                # LoRA rank
        LORA_ALPHA = 16           # LoRA scaling factor
        LORA_DROPOUT = 0.1
        LORA_TARGET_MODULES = ["q_proj", "v_proj"]  # attention projections to adapt
        MAX_LENGTH = 512
        BATCH_SIZE = 1
        LEARNING_RATE = 1e-4

    def setup_adaptive_learning(self):
        """Initialize adaptive learning components"""
        self.feedback_history = []

    def setup_models(self):
        """Initialize models with LoRA adapters and 4-bit quantization"""
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
            base_model = AutoModelForCausalLM.from_pretrained(
                self.config.MODEL_NAME,
                quantization_config=bnb_config,
                device_map="auto"
            )
            # Cast norm/embedding layers and enable gradient support for k-bit training
            base_model = prepare_model_for_kbit_training(base_model)
            lora_config = LoraConfig(
                r=self.config.LORA_R,
                lora_alpha=self.config.LORA_ALPHA,
                target_modules=self.config.LORA_TARGET_MODULES,
                lora_dropout=self.config.LORA_DROPOUT,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(base_model, lora_config)
            # self.model.print_trainable_parameters()  # handy sanity check
            self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
        except Exception as e:
            logger.error(f"Error setting up models: {e}")
            raise

    def load_datasets(self):
        """Load and prepare datasets for RAG"""
        try:
            datasets = {
                "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
                "diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
                "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
            }
            self.documents = []
            for dataset_name, dataset in datasets.items():
                for item in dataset:
                    if dataset_name == "persona":
                        if isinstance(item.get('personality'), list):
                            self.documents.append({'text': " ".join(item['personality']), 'type': 'persona'})
                    elif 'input' in item and 'output' in item:
                        self.documents.append({'text': f"{item['input']}\n{item['output']}", 'type': dataset_name})
            self._create_index()
        except Exception as e:
            logger.error(f"Error loading datasets: {e}")
            raise

    def _create_index(self):
        """Create a FAISS inner-product index over document embeddings for RAG"""
        try:
            # Encode in one batch and normalize so inner product equals cosine similarity
            embeddings = self.embedding_model.encode(
                [doc['text'] for doc in self.documents],
                normalize_embeddings=True
            )
            self.index = faiss.IndexFlatIP(embeddings.shape[1])
            self.index.add(np.asarray(embeddings, dtype=np.float32))
        except Exception as e:
            logger.error(f"Error creating FAISS index: {e}")
            raise

    def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
        """Generate follow-up questions based on context"""
        try:
            prompt = f"""Patient message: "{message}"
Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
Questions:"""
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True,
                max_length=self.config.MAX_LENGTH
            ).to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
            # Decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[0][inputs['input_ids'].shape[-1]:]
            questions = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return [q.strip() for q in questions.split("\n") if q.strip()] or \
                   ["Could you tell me more about when this started?"]
        except Exception as e:
            logger.error(f"Error generating follow-up questions: {e}")
            return ["Could you tell me more about when this started?"]

    def assess_symptom_severity(self, message: str) -> str:
        """Assess severity based on keywords in the message"""
        lowered = message.lower()
        if "severe" in lowered or "emergency" in lowered:
            return "emergency"
        elif "persistent" in lowered or "moderate" in lowered:
            return "urgent"
        return "routine"

    def generate_response(self, message: str) -> Dict[str, Any]:
        """Generate a response based on the message"""
        try:
            severity = self.assess_symptom_severity(message)
            # Retrieve relevant documents from FAISS (query normalized to match the index)
            query_embedding = self.embedding_model.encode([message], normalize_embeddings=True)
            _, indices = self.index.search(np.asarray(query_embedding, dtype=np.float32), k=5)
            relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]
            context_block = "\n".join(relevant_docs)
            prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
Consider relevant knowledge and the following documents:
{context_block}
Respond with empathy, follow-up questions, and care guidance."""
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True,
                max_length=self.config.MAX_LENGTH
            ).to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
            # Decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[0][inputs['input_ids'].shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            follow_ups = self.generate_follow_up_questions(message, {})
            response += f"\n{follow_ups[0]}"
            # Append exchange to conversation history
            self.conversation_history.append((message, response))
            # Add care-level guidance
            if severity == "emergency":
                response += "\nThis seems urgent. Please call 999 immediately."
            elif severity == "urgent":
                response += "\nConsider calling NHS 111 for urgent assistance."
            return {'response': response}
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return {
                'response': "I'm experiencing technical issues. If this is an emergency, please call 999 immediately.",
            }

    def handle_feedback(self, message: str, response: str, feedback: int):
        """Update model based on feedback"""
        try:
            self.feedback_history.append({
                'message': message,
                'response': response,
                'feedback': feedback,
                'timestamp': datetime.now().isoformat()
            })
            if len(self.feedback_history) >= 10:
                # Learning-update placeholder; _update_from_feedback below sketches one approach
                self.feedback_history = []  # Reset history after learning update
        except Exception as e:
            logger.error(f"Error processing feedback: {e}")


def create_demo():
    """Set up Gradio interface for the chatbot"""
    try:
        bot = AdaptiveMedicalBot()

        def chat(message: str, history: List[Dict[str, str]]):
            try:
                response_data = bot.generate_response(message)
                response = response_data['response']
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": "I'm experiencing technical difficulties. For emergencies, call 999."}
                ]

        def process_feedback(feedback: str, history: List[Dict[str, str]], comment: str = ""):
            try:
                if history and len(history) >= 2:
                    last_user_msg = history[-2]["content"]
                    last_bot_msg = history[-1]["content"]
                    bot.handle_feedback(last_user_msg, last_bot_msg, 1 if feedback == "👍" else -1)
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")

        with gr.Blocks() as demo:
            chatbot = gr.Chatbot(
                value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP Triage medical assistant. How can I help you today?"}],
                height=500,
                elem_id="chatbot",
                type="messages",
                show_label=False
            )
            msg = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2,
                elem_id="msg-input"  # referenced by the Quick Actions buttons below
            )
            submit = gr.Button("Send", variant="primary")
            feedback = gr.Radio(
                choices=["👍", "👎"],
                label="Was this response helpful?",
                visible=True
            )
            feedback_text = gr.Textbox(
                label="Additional comments (optional)",
                placeholder="Tell us more about your experience...",
                lines=2
            )
            # Event handlers: send on button click or Enter, then clear the textbox
            submit.click(
                fn=chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: "",
                None,
                msg
            )
            msg.submit(
                fn=chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: "",
                None,
                msg
            )
            feedback.change(
                fn=process_feedback,
                inputs=[feedback, chatbot, feedback_text],
                outputs=[]
            )
            # Clear chat handler
            clear = gr.Button("🗑️ Clear Chat")
            clear.click(lambda: ([], ""), None, [chatbot, msg])
            # Additional information sections. The original onclick handlers wrote to
            # the chatbot div, which has no .value; instead, write into the message
            # textbox (elem_id "msg-input") and fire an input event so Gradio picks
            # up the change.
            gr.HTML("""
            <div style="padding: 20px;">
                <h2>Quick Actions</h2>
                <button onclick="const b = document.querySelector('#msg-input textarea'); b.value = 'Emergency Info: For emergencies, call 999.'; b.dispatchEvent(new Event('input', {bubbles: true}));">Emergency Info</button>
                <button onclick="const b = document.querySelector('#msg-input textarea'); b.value = 'NHS 111 Info: For urgent but non-emergency situations, call 111.'; b.dispatchEvent(new Event('input', {bubbles: true}));">NHS 111 Info</button>
                <button onclick="const b = document.querySelector('#msg-input textarea'); b.value = 'GP Booking Info: For routine appointments with your GP.'; b.dispatchEvent(new Event('input', {bubbles: true}));">GP Booking</button>
            </div>
            """)
gr.Markdown("### Example Messages") | |
gr.Examples( | |
examples=[ | |
["I've been having severe headaches for the past week"], | |
["I need to book a routine checkup"], | |
["I'm feeling very anxious lately and need help"], | |
["My child has had a fever for 2 days"], | |
["I need information about COVID-19 testing"] | |
], | |
inputs=msg | |
) | |
gr.Markdown(""" | |
### NHS Services Guide | |
**999 - Emergency Services**: For life-threatening emergencies, severe injuries, heart attack, stroke. | |
**NHS 111**: Available 24/7 for urgent but non-life-threatening situations, medical advice, and guidance. | |
**GP Services**: Routine check-ups, non-urgent medical issues, and prescription renewals. | |
""") | |
return demo | |
except Exception as e: | |
logger.error(f"Error creating demo: {e}") | |
raise | |
if __name__ == "__main__": | |
# Launch Gradio Interface | |
demo = create_demo() | |
demo.launch() | |
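
# An illustrative requirements.txt for this Space, inferred from the imports
# above (an assumption; pin versions to match your environment):
#   torch
#   transformers
#   peft
#   bitsandbytes
#   accelerate
#   sentence-transformers
#   faiss-cpu
#   datasets
#   gradio
#   python-dotenv
#   huggingface_hub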