# app.py
import os
import logging
from datetime import datetime
from typing import Dict, List, Any

import torch
import numpy as np
import faiss
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from datasets import load_dataset
from huggingface_hub import login
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Retrieve secrets securely from environment variables
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)


class AdaptiveMedicalBot:
    class AdaptiveBotConfig:
        MODEL_NAME = "google/gemma-7b"
        EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
        LORA_R = 8
        LORA_ALPHA = 16
        LORA_DROPOUT = 0.1
        LORA_TARGET_MODULES = ["q_proj", "v_proj"]
        MAX_LENGTH = 512
        BATCH_SIZE = 1
        LEARNING_RATE = 1e-4

    def __init__(self):
        self.config = self.AdaptiveBotConfig()
        self.setup_models()
        self.load_datasets()
        self.setup_adaptive_learning()
        self.conversation_history = []  # Maintain conversation history

    def setup_adaptive_learning(self):
        """Initialize adaptive learning components"""
        self.feedback_history = []

    def setup_models(self):
        """Initialize models with LoRA and 4-bit quantization"""
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
            base_model = AutoModelForCausalLM.from_pretrained(
                self.config.MODEL_NAME,
                quantization_config=bnb_config,
                device_map="auto"
            )
            base_model = prepare_model_for_kbit_training(base_model)
            lora_config = LoraConfig(
                r=self.config.LORA_R,
                lora_alpha=self.config.LORA_ALPHA,
                target_modules=self.config.LORA_TARGET_MODULES,
                lora_dropout=self.config.LORA_DROPOUT,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(base_model, lora_config)
            self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
        except Exception as e:
            logger.error(f"Error setting up models: {e}")
            raise

    def load_datasets(self):
        """Load and prepare datasets for RAG"""
        try:
            datasets = {
                "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
                "diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
                "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
            }
            self.documents = []
            for dataset_name, dataset in datasets.items():
                for item in dataset:
                    if dataset_name == "persona":
                        if isinstance(item.get('personality'), list):
                            self.documents.append({'text': " ".join(item['personality']), 'type': 'persona'})
                    elif 'input' in item and 'output' in item:
                        self.documents.append({'text': f"{item['input']}\n{item['output']}", 'type': dataset_name})
            self._create_index()
        except Exception as e:
            logger.error(f"Error loading datasets: {e}")
            raise

    def _create_index(self):
        """Create FAISS index for RAG"""
        try:
            sample_embedding = self.embedding_model.encode("sample text")
            self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
            # Encode all documents in one batch rather than one call per document
            embeddings = self.embedding_model.encode([doc['text'] for doc in self.documents])
            self.index.add(np.asarray(embeddings, dtype=np.float32))
        except Exception as e:
            logger.error(f"Error creating FAISS index: {e}")
            raise
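
    # Note on the retrieval metric (an assumption about intent, not in the
    # original code): IndexFlatIP ranks by raw inner product. If cosine
    # similarity is the goal, L2-normalize both document and query embeddings,
    # e.g. faiss.normalize_L2(array) in place before index.add()/index.search().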

    def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
        """Generate follow-up questions based on context"""
        try:
            prompt = f"""Patient message: "{message}"
Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
Questions:"""
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.config.MAX_LENGTH).to(self.model.device)
            outputs = self.model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=50, temperature=0.7, do_sample=True)
            # Decode only the newly generated tokens, not the echoed prompt
            questions = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            return questions.split("\n")
        except Exception as e:
            logger.error(f"Error generating follow-up questions: {e}")
            return ["Could you tell me more about when this started?"]

    def assess_symptom_severity(self, message: str) -> str:
        """Assess severity based on keywords in the message"""
        if "severe" in message.lower() or "emergency" in message.lower():
            return "emergency"
        elif "persistent" in message.lower() or "moderate" in message.lower():
            return "urgent"
        return "routine"

    def generate_response(self, message: str) -> Dict[str, Any]:
        """Generate a response based on the message"""
        try:
            severity = self.assess_symptom_severity(message)

            # Retrieve relevant documents from FAISS
            query_embedding = self.embedding_model.encode([message])
            _, indices = self.index.search(np.asarray(query_embedding, dtype=np.float32), 5)
            relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]

            prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
Consider relevant knowledge and the following documents:
{relevant_docs}
Respond with empathy, follow-up questions, and care guidance."""
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.config.MAX_LENGTH).to(self.model.device)
            outputs = self.model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=100, temperature=0.7, do_sample=True)
            response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            follow_ups = self.generate_follow_up_questions(message, {})
            response += f"\n{follow_ups[0]}"

            # Append the exchange to conversation history
            self.conversation_history.append((message, response))

            # Add care-level guidance
            if severity == "emergency":
                response += "\nThis seems urgent. Please call 999 immediately."
            elif severity == "urgent":
                response += "\nConsider calling NHS 111 for urgent assistance."

            return {'response': response}
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return {
                'response': "I'm experiencing technical issues. If this is an emergency, please call 999 immediately.",
            }
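
    # A possible shape for the learning-update placeholder in handle_feedback
    # below (a sketch; the original leaves it unimplemented): tokenize the
    # positively-rated (message, response) pairs into a small supervised batch,
    # then take a few steps with torch.optim.AdamW(self.model.parameters(),
    # lr=self.config.LEARNING_RATE); with PEFT, only the LoRA adapter weights
    # have requires_grad=True, so the base model stays frozen.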

    def handle_feedback(self, message: str, response: str, feedback: int):
        """Update model based on feedback"""
        try:
            self.feedback_history.append({
                'message': message,
                'response': response,
                'feedback': feedback,
                'timestamp': datetime.now().isoformat()
            })
            if len(self.feedback_history) >= 10:
                # Implement learning updates from feedback
                self.feedback_history = []  # Reset history after learning update
        except Exception as e:
            logger.error(f"Error processing feedback: {e}")


def create_demo():
    """Set up Gradio interface for the chatbot"""
    try:
        bot = AdaptiveMedicalBot()

        def chat(message: str, history: List[Dict[str, str]]):
            try:
                response_data = bot.generate_response(message)
                response = response_data['response']
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                return history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": "I'm experiencing technical difficulties. For emergencies, call 999."}
                ]

        def process_feedback(feedback: str, history: List[Dict[str, str]], comment: str = ""):
            try:
                if history and len(history) >= 2:
                    last_user_msg = history[-2]["content"]
                    last_bot_msg = history[-1]["content"]
                    bot.handle_feedback(last_user_msg, last_bot_msg, 1 if feedback == "👍" else -1)
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")

        with gr.Blocks() as demo:
            chatbot = gr.Chatbot(
                value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP Triage medical assistant. How can I help you today?"}],
                height=500,
                elem_id="chatbot",
                type="messages",
                show_label=False
            )
            msg = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2
            )
            submit = gr.Button("Send", variant="primary")
            feedback = gr.Radio(
                choices=["👍", "👎"],
                label="Was this response helpful?",
                visible=True
            )
            feedback_text = gr.Textbox(
                label="Additional comments (optional)",
                placeholder="Tell us more about your experience...",
                lines=2
            )

            # Event Handlers
            submit.click(
                fn=chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: "", None, msg
            )
            msg.submit(
                fn=chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: "", None, msg
            )
            feedback.change(
                fn=process_feedback,
                inputs=[feedback, chatbot, feedback_text],
                outputs=[]
            )

            # Clear Chat Handler
            clear = gr.Button("🗑️ Clear Chat")
            # Multiple outputs take a tuple of values, one per component
            clear.click(lambda: ([], ""), None, [chatbot, msg])
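
            # Note: with type="messages", the chatbot history is a list of
            # {"role": ..., "content": ...} dicts; chat() and process_feedback()
            # above both rely on that shape when appending and indexing entries.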

            # Additional Information Sections
            gr.HTML("""
                <h3>Quick Actions</h3>
            """)

            gr.Markdown("### Example Messages")
            gr.Examples(
                examples=[
                    ["I've been having severe headaches for the past week"],
                    ["I need to book a routine checkup"],
                    ["I'm feeling very anxious lately and need help"],
                    ["My child has had a fever for 2 days"],
                    ["I need information about COVID-19 testing"]
                ],
                inputs=msg
            )

            gr.Markdown("""
                ### NHS Services Guide

                **999 - Emergency Services**: For life-threatening emergencies, severe injuries, heart attack, stroke.

                **NHS 111**: Available 24/7 for urgent but non-life-threatening situations, medical advice, and guidance.

                **GP Services**: Routine check-ups, non-urgent medical issues, and prescription renewals.
            """)

        return demo
    except Exception as e:
        logger.error(f"Error creating demo: {e}")
        raise


if __name__ == "__main__":
    # Launch Gradio Interface
    demo = create_demo()
    demo.launch()