PearlIsa commited on
Commit
4cabc00
1 Parent(s): f6e71bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -740
app.py CHANGED
@@ -9,16 +9,12 @@ from sentence_transformers import SentenceTransformer
9
  from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
10
  import faiss
11
  import numpy as np
12
- from tqdm import tqdm
13
  from datasets import load_dataset
14
- from dataclasses import dataclass, field
15
  from datetime import datetime
16
  import json
17
  from huggingface_hub import login
18
  from dotenv import load_dotenv
19
 
20
-
21
-
22
  # Load environment variables
23
  load_dotenv()
24
 
@@ -30,62 +26,32 @@ logging.basicConfig(
30
  logger = logging.getLogger(__name__)
31
 
32
  # Retrieve secrets securely from environment variables
33
- kaggle_username = os.getenv("KAGGLE_USERNAME")
34
- kaggle_key = os.getenv("KAGGLE_KEY")
35
  hf_token = os.getenv("HF_TOKEN")
36
- wandb_key = os.getenv("WANDB_API_KEY")
37
-
38
- # Log in to Hugging Face
39
  if hf_token:
40
  login(token=hf_token)
41
- else:
42
- logger.warning("Hugging Face token not found in environment variables.")
43
-
44
- @dataclass
45
- class AdaptiveBotConfig:
46
- """Configuration for adaptive medical triage bot"""
47
- MODEL_NAME: str = "google/gemma-7b"
48
- EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
49
-
50
- # LoRA parameters
51
- LORA_R: int = 8
52
- LORA_ALPHA: int = 16
53
- LORA_DROPOUT: float = 0.1
54
- LORA_TARGET_MODULES: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
55
-
56
- # Training parameters
57
- MAX_LENGTH: int = 512
58
- BATCH_SIZE: int = 1
59
- LEARNING_RATE: float = 1e-4
60
-
61
- # Adaptive learning parameters
62
- MIN_FEEDBACK_FOR_UPDATE: int = 5
63
- FEEDBACK_HISTORY_SIZE: int = 100
64
- LEARNING_RATE_DECAY: float = 0.95
65
 
66
  class AdaptiveMedicalBot:
67
  def __init__(self):
68
- self.config = AdaptiveBotConfig()
69
  self.setup_models()
70
  self.load_datasets()
71
  self.setup_adaptive_learning()
72
- self.document_relevance = {}
73
-
 
 
 
 
 
 
 
 
 
 
 
74
  def setup_adaptive_learning(self):
75
  """Initialize adaptive learning components"""
76
  self.feedback_history = []
77
- self.conversation_patterns = {}
78
- self.learning_buffer = []
79
-
80
- # Load existing learning data if available
81
- try:
82
- if os.path.exists('learning_data.json'):
83
- with open('learning_data.json', 'r') as f:
84
- data = json.load(f)
85
- self.conversation_patterns = data.get('patterns', {})
86
- self.feedback_history = data.get('feedback', [])
87
- except Exception as e:
88
- logger.warning(f"Could not load learning data: {e}")
89
 
90
  def setup_models(self):
91
  """Initialize models with LoRA and quantization"""
@@ -95,21 +61,15 @@ class AdaptiveMedicalBot:
95
  bnb_4bit_quant_type="nf4",
96
  bnb_4bit_compute_dtype=torch.float16
97
  )
98
-
99
- self.tokenizer = AutoTokenizer.from_pretrained(
100
- self.config.MODEL_NAME,
101
- trust_remote_code=True
102
- )
103
-
104
  base_model = AutoModelForCausalLM.from_pretrained(
105
  self.config.MODEL_NAME,
106
  quantization_config=bnb_config,
107
- device_map="auto",
108
- trust_remote_code=True
109
  )
110
-
111
  base_model = prepare_model_for_kbit_training(base_model)
112
-
113
  lora_config = LoraConfig(
114
  r=self.config.LORA_R,
115
  lora_alpha=self.config.LORA_ALPHA,
@@ -118,720 +78,185 @@ class AdaptiveMedicalBot:
118
  bias="none",
119
  task_type=TaskType.CAUSAL_LM
120
  )
121
-
122
  self.model = get_peft_model(base_model, lora_config)
123
  self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
124
-
125
  except Exception as e:
126
  logger.error(f"Error setting up models: {e}")
127
  raise
128
 
129
  def load_datasets(self):
130
- """Load and process datasets for RAG"""
131
  try:
132
  datasets = {
133
  "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
134
  "diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
135
  "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
136
  }
137
-
138
  self.documents = []
139
-
140
  for dataset_name, dataset in datasets.items():
141
  for item in dataset:
142
  if dataset_name == "persona":
143
  if isinstance(item.get('personality'), list):
144
- self.documents.append({
145
- 'text': " ".join(item['personality']),
146
- 'type': 'persona'
147
- })
148
  else:
149
  if 'input' in item and 'output' in item:
150
- self.documents.append({
151
- 'text': f"{item['input']}\n{item['output']}",
152
- 'type': dataset_name
153
- })
154
-
155
  self._create_index()
156
-
157
  except Exception as e:
158
  logger.error(f"Error loading datasets: {e}")
159
  raise
160
 
161
  def _create_index(self):
162
  """Create FAISS index for RAG"""
163
- sample_embedding = self.embedding_model.encode("sample text")
164
- self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
165
-
166
- batch_size = 32
167
- for i in range(0, len(self.documents), batch_size):
168
- batch = self.documents[i:i + batch_size]
169
- texts = [doc['text'] for doc in batch]
170
- embeddings = self.embedding_model.encode(texts)
171
- self.index.add(np.array(embeddings))
172
-
173
- def analyze_conversation_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
174
- """Analyze conversation context to determine appropriate follow-up questions"""
175
  try:
176
- # Extract key information
177
- mentioned_symptoms = set()
178
- time_indicators = set()
179
- severity_indicators = set()
180
-
181
- # Analyze current message and history
182
- for msg in [message] + [h[0] for h in (history or [])]:
183
- msg_lower = msg.lower()
184
-
185
- # Update conversation patterns
186
- pattern_key = self._extract_pattern_key(msg_lower)
187
- if pattern_key in self.conversation_patterns:
188
- self.conversation_patterns[pattern_key]['frequency'] += 1
189
- else:
190
- self.conversation_patterns[pattern_key] = {
191
- 'frequency': 1,
192
- 'successful_responses': []
193
- }
194
-
195
- return {
196
- 'needs_follow_up': True, # Always encourage follow-up questions
197
- 'conversation_depth': len(history) if history else 0,
198
- 'pattern_key': pattern_key
199
- }
200
-
201
- except Exception as e:
202
- logger.error(f"Error analyzing conversation: {e}")
203
- return {'needs_follow_up': True}
204
 
205
- def _extract_pattern_key(self, message: str) -> str:
206
- """Extract conversation pattern key for learning"""
207
- # Simplified pattern extraction - can be enhanced based on learning
208
- words = message.lower().split()
209
- return " ".join(sorted(set(words))[:5])
210
 
211
- def generate_follow_up_questions(self, context: Dict[str, Any], message: str) -> List[str]:
212
- """Generate contextual follow-up questions"""
213
  try:
214
- # Use the model to generate follow-up questions
215
- prompt = f"""Given the patient message: "{message}"
216
- Generate relevant follow-up questions to better understand their situation.
217
- Focus on: timing, severity, associated symptoms, impact on daily life.
218
- Do not make diagnoses or suggest treatments.
219
-
220
  Questions:"""
221
-
222
- inputs = self.tokenizer(
223
- prompt,
224
- return_tensors="pt",
225
- max_length=self.config.MAX_LENGTH,
226
- truncation=True
227
- ).to(self.model.device)
228
-
229
- outputs = self.model.generate(
230
- **inputs,
231
- max_new_tokens=100,
232
- temperature=0.7,
233
- do_sample=True
234
- )
235
-
236
  questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
237
- return questions.split('\n')
238
-
239
  except Exception as e:
240
  logger.error(f"Error generating follow-up questions: {e}")
241
  return ["Could you tell me more about when this started?"]
242
 
243
- def update_from_feedback(self, message: str, response: str, feedback: int):
244
- """Process feedback for adaptive learning"""
245
- try:
246
- self.feedback_history.append({
247
- 'message': message,
248
- 'response': response,
249
- 'feedback': feedback,
250
- 'timestamp': datetime.now().isoformat()
251
- })
252
-
253
- # Update conversation patterns
254
- pattern_key = self._extract_pattern_key(message)
255
- if pattern_key in self.conversation_patterns:
256
- if feedback > 0:
257
- self.conversation_patterns[pattern_key]['successful_responses'].append(response)
258
-
259
- # Save learning data periodically
260
- if len(self.feedback_history) % 10 == 0:
261
- self._save_learning_data()
262
-
263
- # Update model if enough feedback
264
- if len(self.feedback_history) >= self.config.MIN_FEEDBACK_FOR_UPDATE:
265
- self._update_model_from_feedback()
266
-
267
- except Exception as e:
268
- logger.error(f"Error processing feedback: {e}")
269
 
270
- def store_interaction(self, interaction: Dict[str, Any]):
271
- """Store interaction for adaptive learning"""
272
  try:
273
- self.learning_buffer.append(interaction)
274
-
275
- # Update learning patterns
276
- if len(self.learning_buffer) >= 10:
277
- self._update_learning_model()
278
- self._save_learning_data()
279
- self.learning_buffer = []
280
-
281
- except Exception as e:
282
- logger.error(f"Error storing interaction: {e}")
283
 
284
- def _update_learning_model(self):
285
- """Update model based on accumulated learning"""
286
- try:
287
- # Process learning buffer
288
- successful_interactions = [
289
- interaction for interaction in self.learning_buffer
290
- if interaction.get('feedback', 0) > 0
291
- ]
292
-
293
- if successful_interactions:
294
- # Update conversation patterns
295
- for interaction in successful_interactions:
296
- pattern_key = self._extract_pattern_key(interaction['message'])
297
- if pattern_key in self.conversation_patterns:
298
- self.conversation_patterns[pattern_key]['successful_responses'].append(
299
- interaction['response']
300
- )
301
-
302
- # Update document relevance
303
- for interaction in successful_interactions:
304
- for doc in interaction.get('relevant_docs', []):
305
- doc_key = doc['text'][:100]
306
- if doc_key in self.document_relevance:
307
- self.document_relevance[doc_key]['success_count'] += 1
308
-
309
- logger.info("Updated learning model with new patterns")
310
-
311
- except Exception as e:
312
- logger.error(f"Error updating learning model: {e}")
313
-
314
- def generate_context_questions(self, message: str, history: List[tuple], context: Dict[str, Any]) -> List[str]:
315
- """Generate context-aware follow-up questions"""
316
- try:
317
- # Create dynamic question generation prompt
318
- prompt = f"""Based on the conversation context, generate appropriate follow-up questions.
319
- Consider:
320
- - Understanding the main concern
321
- - Timeline and progression
322
- - Impact on daily life
323
- - Related symptoms or factors
324
- - Previous treatments or consultations
325
-
326
- Current message: {message}
327
- Context: {context}
328
-
329
- Generate questions:"""
330
-
331
- inputs = self.tokenizer(
332
- prompt,
333
- return_tensors="pt",
334
- max_length=self.config.MAX_LENGTH,
335
- truncation=True
336
- ).to(self.model.device)
337
-
338
- outputs = self.model.generate(
339
- **inputs,
340
- max_new_tokens=150,
341
- temperature=0.7,
342
- do_sample=True
343
- )
344
-
345
- questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
346
- return [q.strip() for q in questions.split("\n") if "?" in q]
347
-
348
- except Exception as e:
349
- logger.error(f"Error generating context questions: {e}")
350
- return ["Could you tell me more about your concerns?"]
351
 
352
- def _save_learning_data(self):
353
- """Save learning data to disk"""
354
- try:
355
- data = {
356
- 'patterns': self.conversation_patterns,
357
- 'feedback': self.feedback_history[-100:] # Keep last 100 entries
358
- }
359
- with open('learning_data.json', 'w') as f:
360
- json.dump(data, f)
361
- except Exception as e:
362
- logger.error(f"Error saving learning data: {e}")
363
 
364
- def _update_model_from_feedback(self):
365
- """Update model based on feedback"""
366
- try:
367
- positive_feedback = [f for f in self.feedback_history if f['feedback'] > 0]
368
- if len(positive_feedback) >= self.config.MIN_FEEDBACK_FOR_UPDATE:
369
- # Prepare training data from successful interactions
370
- training_data = []
371
- for feedback in positive_feedback:
372
- training_data.append({
373
- 'input_ids': self.tokenizer(
374
- feedback['message'],
375
- return_tensors='pt'
376
- ).input_ids,
377
- 'labels': self.tokenizer(
378
- feedback['response'],
379
- return_tensors='pt'
380
- ).input_ids
381
- })
382
-
383
- # Update model (simplified for example)
384
- logger.info("Updating model from feedback")
385
- self.feedback_history = [] # Clear history after update
386
-
387
- except Exception as e:
388
- logger.error(f"Error updating model from feedback: {e}")
389
 
390
- def analyze_symptom_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
391
- """Enhanced symptom analysis with context learning"""
392
- try:
393
- current_symptoms = set()
394
- temporal_info = {}
395
- related_conditions = set()
396
- conversation_depth = len(history) if history else 0
397
-
398
- # Analyze full conversation context
399
- all_messages = [message] + [h[0] for h in (history or [])]
400
- all_responses = [h[1] for h in (history or [])]
401
 
402
- # Extract conversation context
403
- context = {
404
- 'symptoms_mentioned': current_symptoms,
405
- 'temporal_info': temporal_info,
406
- 'conversation_depth': conversation_depth,
407
- 'needs_clarification': True,
408
- 'specialist_referral_needed': False,
409
- 'previous_questions': set()
410
- }
411
-
412
- if history:
413
- # Learn from previous interactions
414
- for prev_msg, prev_resp in history:
415
- if "?" in prev_resp:
416
- context['previous_questions'].add(prev_resp.split("?")[0] + "?")
417
-
418
- return context
419
-
420
- except Exception as e:
421
- logger.error(f"Error in symptom analysis: {e}")
422
- return {'needs_clarification': True}
423
 
424
- def generate_targeted_questions(self, symptoms: Dict[str, Any], history: List[tuple]) -> List[str]:
425
- """Generate context-aware follow-up questions"""
426
- try:
427
- # Use the model to generate relevant questions based on context
428
- context_prompt = f"""Based on the patient's symptoms and conversation history, generate 3 specific follow-up questions.
429
- Focus on:
430
- 1. Symptom details (duration, severity, patterns)
431
- 2. Impact on daily life
432
- 3. Related symptoms or conditions
433
- 4. Previous treatments or consultations
434
-
435
- Do not ask about:
436
- - Questions already asked
437
- - Diagnostic conclusions
438
- - Treatment recommendations
439
-
440
- Current context: {symptoms}
441
- Previous questions asked: {symptoms.get('previous_questions', set())}
442
-
443
- Generate questions:"""
444
-
445
- inputs = self.tokenizer(
446
- context_prompt,
447
- return_tensors="pt",
448
- max_length=self.config.MAX_LENGTH,
449
- truncation=True
450
- ).to(self.model.device)
451
-
452
- outputs = self.model.generate(
453
- **inputs,
454
- max_new_tokens=150,
455
- temperature=0.7,
456
- do_sample=True
457
- )
458
-
459
- questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
460
- return [q.strip() for q in questions.split("\n") if "?" in q]
461
-
462
- except Exception as e:
463
- logger.error(f"Error generating questions: {e}")
464
- return ["Could you tell me more about your symptoms?"]
465
 
466
- def analyze_medical_context(self, message: str, history: List[tuple]) -> Dict[str, Any]:
467
- """Comprehensive medical context analysis"""
468
- try:
469
- # Initialize context tracking
470
- context = {
471
- 'conversation_depth': len(history) if history else 0,
472
- 'needs_follow_up': True,
473
- 'previous_interactions': [],
474
- 'care_pathway': 'initial_triage',
475
- 'consultation_type': 'general',
476
- }
477
-
478
- # Analyze current conversation flow
479
- all_messages = [message] + [h[0] for h in (history or [])]
480
-
481
- # Build contextual understanding
482
- for msg in all_messages:
483
- msg_lower = msg.lower()
484
-
485
- # Track conversation patterns
486
- pattern_key = self._extract_pattern_key(msg_lower)
487
- if pattern_key in self.conversation_patterns:
488
- self.conversation_patterns[pattern_key]['frequency'] += 1
489
- context['previous_patterns'] = self.conversation_patterns[pattern_key]
490
-
491
- # Update learning patterns
492
- self._update_learning_patterns(msg_lower, context)
493
-
494
- return context
495
-
496
  except Exception as e:
497
- logger.error(f"Error in medical context analysis: {e}")
498
- return {'needs_follow_up': True}
499
-
500
- def generate_adaptive_response(self, message: str, history: List[tuple] = None) -> Dict[str, Any]:
501
- """Generate comprehensive triage response"""
 
 
502
  try:
503
- # Analyze medical context
504
- context = self.analyze_medical_context(message, history)
505
-
506
- # Retrieve relevant knowledge
507
- query_embedding = self.embedding_model.encode([message])
508
- _, indices = self.index.search(query_embedding, k=5)
509
- relevant_docs = [self.documents[idx] for idx in indices[0]]
510
-
511
- # Build conversation history
512
- conv_history = "\n".join([f"Patient: {h[0]}\nPearly: {h[1]}" for h in (history or [])])
513
-
514
- # Create dynamic prompt based on context
515
- prompt = f"""As Pearly, a compassionate GP medical triage assistant, help assess the patient's needs and provide appropriate guidance.
516
-
517
- Previous Conversation:
518
- {conv_history}
519
-
520
- Current Message: {message}
521
-
522
- Medical Knowledge Context:
523
- {[doc['text'] for doc in relevant_docs]}
524
-
525
- Guidelines:
526
- - Show empathy and understanding
527
- - Ask relevant follow-up questions
528
- - Guide to appropriate care level (GP, 111, emergency services)
529
- - Consider all aspects of patient care
530
- - Do not diagnose or recommend treatments
531
- - Focus on understanding concerns and proper healthcare guidance
532
-
533
- Response:"""
534
-
535
- # Generate base response
536
- inputs = self.tokenizer(
537
- prompt,
538
- return_tensors="pt",
539
- max_length=self.config.MAX_LENGTH,
540
- truncation=True
541
- ).to(self.model.device)
542
-
543
- outputs = self.model.generate(
544
- **inputs,
545
- max_new_tokens=300,
546
- temperature=0.7,
547
- do_sample=True
548
- )
549
-
550
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
551
-
552
- # Generate contextual follow-up questions
553
- if context['needs_follow_up']:
554
- follow_ups = self.generate_context_questions(message, history, context)
555
- if follow_ups:
556
- response = f"{response}\n\n{follow_ups[0]}"
557
-
558
- # Store interaction for learning
559
- self.store_interaction({
560
  'message': message,
561
  'response': response,
562
- 'context': context,
563
- 'relevant_docs': relevant_docs,
564
  'timestamp': datetime.now().isoformat()
565
  })
566
-
567
- return {
568
- 'response': response,
569
- 'context': context
570
- }
571
-
572
- except Exception as e:
573
- logger.error(f"Error generating response: {e}")
574
- return {
575
- 'response': "I apologize, but I'm having technical difficulties. If this is an emergency, please call 999 immediately. For urgent concerns, call 111.",
576
- 'context': {}
577
- }
578
 
579
-
580
- def _add_specialist_guidance(self, response: str, context: Dict[str, Any]) -> str:
581
- """Add specialist referral guidance to response"""
582
- try:
583
- specialist_prompt = f"""Based on the symptoms and context, suggest appropriate specialist care pathways.
584
- Context: {context}
585
- Current response: {response}
586
-
587
- Add appropriate specialist referral guidance:"""
588
-
589
- inputs = self.tokenizer(
590
- specialist_prompt,
591
- return_tensors="pt",
592
- max_length=self.config.MAX_LENGTH,
593
- truncation=True
594
- ).to(self.model.device)
595
-
596
- outputs = self.model.generate(
597
- **inputs,
598
- max_new_tokens=150,
599
- temperature=0.7,
600
- do_sample=True
601
- )
602
-
603
- specialist_guidance = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
604
-
605
- return f"{response}\n\n{specialist_guidance}"
606
-
607
- except Exception as e:
608
- logger.error(f"Error adding specialist guidance: {e}")
609
- return response
610
-
611
- def update_learning_from_interaction(self, interaction: Dict[str, Any]):
612
- """Update adaptive learning system from interaction"""
613
- try:
614
- # Extract key information
615
- message = interaction['message']
616
- response = interaction['response']
617
- context = interaction['context']
618
- relevant_docs = interaction.get('relevant_docs', [])
619
-
620
- # Update conversation patterns
621
- pattern_key = self._extract_pattern_key(message)
622
- if pattern_key in self.conversation_patterns:
623
- self.conversation_patterns[pattern_key]['frequency'] += 1
624
- if context.get('successful_response'):
625
- self.conversation_patterns[pattern_key]['successful_responses'].append(response)
626
-
627
- # Update document relevance scores
628
- for doc in relevant_docs:
629
- doc_key = doc['text'][:100] # Use first 100 chars as key
630
- if doc_key in self.document_relevance:
631
- self.document_relevance[doc_key]['usage_count'] += 1
632
- if context.get('successful_response'):
633
- self.document_relevance[doc_key]['success_count'] += 1
634
-
635
- # Save learning data periodically
636
- if len(self.learning_buffer) >= 10:
637
- self._save_learning_data()
638
- self.learning_buffer = []
639
-
640
  except Exception as e:
641
- logger.error(f"Error updating learning system: {e}")
642
 
643
  def create_demo():
644
- """Create enhanced Gradio interface with advanced features and styling"""
645
  try:
646
  bot = AdaptiveMedicalBot()
647
-
648
- def chat(message: str, history: List[Dict[str, str]], urgency_filter: str = "all"):
649
  try:
650
- # Convert history to the format expected by the bot
651
  bot_history = [(h["user"], h["bot"]) for h in history] if history else []
652
-
653
- # Generate response
654
- response_data = bot.generate_adaptive_response(message, bot_history)
655
  response = response_data['response']
656
-
657
- # Format response for Gradio chat
658
  history.append({"role": "user", "content": message})
659
  history.append({"role": "assistant", "content": response})
660
-
661
  return history
662
-
663
  except Exception as e:
664
  logger.error(f"Chat error: {e}")
665
  return history + [
666
  {"role": "user", "content": message},
667
- {"role": "assistant", "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."}
668
  ]
669
 
670
  def process_feedback(feedback: str, history: List[Dict[str, str]], comment: str = ""):
671
- """Process feedback with optional comment"""
672
  try:
673
  if history and len(history) >= 2:
674
  last_user_msg = history[-2]["content"]
675
  last_bot_msg = history[-1]["content"]
676
- bot.update_from_feedback(
677
- last_user_msg,
678
- last_bot_msg,
679
- 1 if feedback == "👍" else -1,
680
- comment=comment
681
- )
682
  except Exception as e:
683
  logger.error(f"Error processing feedback: {e}")
684
 
685
- # Create enhanced Gradio interface
686
- with gr.Blocks(theme=gr.themes.Soft(
687
- primary_hue="blue",
688
- secondary_hue="indigo",
689
- neutral_hue="slate",
690
- font=gr.themes.GoogleFont("Inter")
691
- )) as demo:
692
- # Custom CSS for enhanced styling
693
- gr.HTML("""
694
- <style>
695
- .container { max-width: 900px; margin: auto; }
696
- .header { text-align: center; padding: 20px; }
697
- .emergency-banner {
698
- background-color: #ff4444;
699
- color: white;
700
- padding: 10px;
701
- text-align: center;
702
- font-weight: bold;
703
- margin-bottom: 20px;
704
- }
705
- .features-grid {
706
- display: grid;
707
- grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
708
- gap: 20px;
709
- padding: 20px;
710
- }
711
- .feature-card {
712
- background: #f8f9fa;
713
- padding: 15px;
714
- border-radius: 10px;
715
- text-align: center;
716
- }
717
- </style>
718
- """)
719
 
720
- # Emergency Banner
721
- gr.HTML("""
722
- <div class="emergency-banner">
723
- 🚨 For medical emergencies, always call 999 immediately 🚨
724
- </div>
725
- """)
726
-
727
- # Header Section
728
- with gr.Row(elem_classes="header"):
729
- gr.Markdown("""
730
- # GP Medical Triage Assistant - Pearly
731
-
732
- Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
733
- """)
734
-
735
- # Main Features Grid
736
- gr.HTML("""
737
- <div class="features-grid">
738
- <div class="feature-card">
739
- 🏥 GP Appointments
740
- </div>
741
- <div class="feature-card">
742
- 🔍 Symptom Assessment
743
- </div>
744
- <div class="feature-card">
745
- ⚡ Urgent Care Guide
746
- </div>
747
- <div class="feature-card">
748
- 💊 Medical Advice
749
- </div>
750
- </div>
751
- """)
752
-
753
- # Chat Interface
754
- with gr.Row():
755
- with gr.Column(scale=4):
756
- chatbot = gr.Chatbot(
757
- value=[{
758
- "role": "assistant",
759
- "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
760
- }],
761
- height=500,
762
- elem_id="chatbot",
763
- type="messages",
764
- show_label=False
765
- )
766
-
767
- with gr.Row():
768
- msg = gr.Textbox(
769
- label="Your message",
770
- placeholder="Type your message here...",
771
- lines=2,
772
- scale=4
773
- )
774
- submit = gr.Button("Send", variant="primary", scale=1)
775
-
776
- with gr.Column(scale=1):
777
- # Quick Actions Panel
778
- gr.Markdown("### Quick Actions")
779
- emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
780
- nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
781
- booking_btn = gr.Button("📅 GP Booking", variant="secondary")
782
-
783
- # Conversation Controls
784
- gr.Markdown("### Controls")
785
- clear = gr.Button("🗑️ Clear Chat")
786
-
787
- # Feedback Section
788
- gr.Markdown("### Feedback")
789
- feedback = gr.Radio(
790
- choices=["👍", "👎"],
791
- label="Was this response helpful?",
792
- visible=True
793
- )
794
- feedback_text = gr.Textbox(
795
- label="Additional comments (optional)",
796
- placeholder="Tell us more about your experience...",
797
- lines=2
798
- )
799
-
800
- # Examples Section
801
- with gr.Accordion("Example Messages", open=False):
802
- gr.Examples(
803
- examples=[
804
- ["I've been having severe headaches for the past week"],
805
- ["I need to book a routine checkup"],
806
- ["I'm feeling very anxious lately and need help"],
807
- ["My child has had a fever for 2 days"],
808
- ["I need information about COVID-19 testing"]
809
- ],
810
- inputs=msg
811
- )
812
-
813
- # Information Accordions
814
- with gr.Accordion("NHS Services Guide", open=False):
815
- gr.Markdown("""
816
- ### Emergency Services (999)
817
- - Life-threatening emergencies
818
- - Severe injuries
819
- - Suspected heart attack or stroke
820
-
821
- ### NHS 111
822
- - Urgent but non-emergency situations
823
- - Medical advice needed
824
- - Unsure where to go
825
-
826
- ### GP Services
827
- - Routine check-ups
828
- - Non-urgent medical issues
829
- - Prescription renewals
830
- """)
831
-
832
  # Event Handlers
833
  submit.click(
834
- chat,
835
  inputs=[msg, chatbot],
836
  outputs=[chatbot]
837
  ).then(
@@ -841,7 +266,7 @@ def create_demo():
841
  )
842
 
843
  msg.submit(
844
- chat,
845
  inputs=[msg, chatbot],
846
  outputs=[chatbot]
847
  ).then(
@@ -850,63 +275,56 @@ def create_demo():
850
  msg
851
  )
852
 
853
- clear.click(
854
- lambda: [[], ""],
855
- None,
856
- [chatbot, msg]
857
- )
858
-
859
  feedback.change(
860
- process_feedback,
861
  inputs=[feedback, chatbot, feedback_text],
862
  outputs=[]
863
  )
864
-
865
- # Quick Action Button Handlers
866
- def show_emergency_info():
867
- return """🚨 Emergency Services (999)
868
- - For life-threatening emergencies
869
- - Severe chest pain
870
- - Difficulty breathing
871
- - Severe bleeding
872
- - Loss of consciousness
873
- """
874
-
875
- def show_nhs_111_info():
876
- return """📞 NHS 111 Service
877
- - Available 24/7
878
- - Medical advice
879
- - Local service information
880
- - Urgent care guidance
881
- """
882
-
883
- def show_booking_info():
884
- return """📅 GP Booking Options
885
- - Online booking
886
- - Phone booking
887
- - Routine appointments
888
- - Urgent appointments
889
- """
890
-
891
- emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
892
- nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
893
- booking_btn.click(lambda: show_booking_info(), outputs=[msg])
 
 
 
 
 
894
 
895
  return demo
896
-
897
  except Exception as e:
898
  logger.error(f"Error creating demo: {e}")
899
  raise
900
 
901
  if __name__ == "__main__":
902
- # Initialize environment
903
- load_dotenv()
904
-
905
- # Set up HuggingFace login if token exists
906
- hf_token = os.getenv("HF_TOKEN")
907
- if hf_token:
908
- login(token=hf_token)
909
-
910
- # Launch demo
911
  demo = create_demo()
912
- demo.launch()
 
 
 
9
  from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
10
  import faiss
11
  import numpy as np
 
12
  from datasets import load_dataset
 
13
  from datetime import datetime
14
  import json
15
  from huggingface_hub import login
16
  from dotenv import load_dotenv
17
 
 
 
18
  # Load environment variables
19
  load_dotenv()
20
 
 
26
  logger = logging.getLogger(__name__)
27
 
28
  # Retrieve secrets securely from environment variables
 
 
29
  hf_token = os.getenv("HF_TOKEN")
 
 
 
30
  if hf_token:
31
  login(token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  class AdaptiveMedicalBot:
34
  def __init__(self):
35
+ self.config = self.AdaptiveBotConfig()
36
  self.setup_models()
37
  self.load_datasets()
38
  self.setup_adaptive_learning()
39
+ self.conversation_history = [] # Maintain conversation history
40
+
41
+ class AdaptiveBotConfig:
42
+ MODEL_NAME = "google/gemma-7b"
43
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
44
+ LORA_R = 8
45
+ LORA_ALPHA = 16
46
+ LORA_DROPOUT = 0.1
47
+ LORA_TARGET_MODULES = ["q_proj", "v_proj"]
48
+ MAX_LENGTH = 512
49
+ BATCH_SIZE = 1
50
+ LEARNING_RATE = 1e-4
51
+
52
  def setup_adaptive_learning(self):
53
  """Initialize adaptive learning components"""
54
  self.feedback_history = []
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def setup_models(self):
57
  """Initialize models with LoRA and quantization"""
 
61
  bnb_4bit_quant_type="nf4",
62
  bnb_4bit_compute_dtype=torch.float16
63
  )
64
+
65
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
 
 
 
 
66
  base_model = AutoModelForCausalLM.from_pretrained(
67
  self.config.MODEL_NAME,
68
  quantization_config=bnb_config,
69
+ device_map="auto"
 
70
  )
71
+
72
  base_model = prepare_model_for_kbit_training(base_model)
 
73
  lora_config = LoraConfig(
74
  r=self.config.LORA_R,
75
  lora_alpha=self.config.LORA_ALPHA,
 
78
  bias="none",
79
  task_type=TaskType.CAUSAL_LM
80
  )
81
+
82
  self.model = get_peft_model(base_model, lora_config)
83
  self.embedding_model = SentenceTransformer(self.config.EMBEDDING_MODEL)
 
84
  except Exception as e:
85
  logger.error(f"Error setting up models: {e}")
86
  raise
87
 
88
  def load_datasets(self):
89
+ """Load and prepare datasets for RAG"""
90
  try:
91
  datasets = {
92
  "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
93
  "diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
94
  "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
95
  }
96
+
97
  self.documents = []
 
98
  for dataset_name, dataset in datasets.items():
99
  for item in dataset:
100
  if dataset_name == "persona":
101
  if isinstance(item.get('personality'), list):
102
+ self.documents.append({'text': " ".join(item['personality']), 'type': 'persona'})
 
 
 
103
  else:
104
  if 'input' in item and 'output' in item:
105
+ self.documents.append({'text': f"{item['input']}\n{item['output']}", 'type': dataset_name})
106
+
 
 
 
107
  self._create_index()
 
108
  except Exception as e:
109
  logger.error(f"Error loading datasets: {e}")
110
  raise
111
 
112
  def _create_index(self):
113
  """Create FAISS index for RAG"""
 
 
 
 
 
 
 
 
 
 
 
 
114
  try:
115
+ sample_embedding = self.embedding_model.encode("sample text")
116
+ self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ embeddings = [self.embedding_model.encode(doc['text']) for doc in self.documents]
119
+ self.index.add(np.array(embeddings))
120
+ except Exception as e:
121
+ logger.error(f"Error creating FAISS index: {e}")
122
+ raise
123
 
124
+ def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
125
+ """Generate follow-up questions based on context"""
126
  try:
127
+ prompt = f"""Patient message: "{message}"
128
+ Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
 
 
 
 
129
  Questions:"""
130
+
131
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
132
+ outputs = self.model.generate(inputs['input_ids'], max_new_tokens=50, temperature=0.7, do_sample=True)
 
 
 
 
 
 
 
 
 
 
 
 
133
  questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
134
+ return questions.split("\n")
 
135
  except Exception as e:
136
  logger.error(f"Error generating follow-up questions: {e}")
137
  return ["Could you tell me more about when this started?"]
138
 
139
+ def assess_symptom_severity(self, message: str) -> str:
140
+ """Assess severity based on keywords in the message"""
141
+ if "severe" in message.lower() or "emergency" in message.lower():
142
+ return "emergency"
143
+ elif "persistent" in message.lower() or "moderate" in message.lower():
144
+ return "urgent"
145
+ return "routine"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ def generate_response(self, message: str) -> Dict[str, Any]:
148
+ """Generate a response based on the message"""
149
  try:
150
+ severity = self.assess_symptom_severity(message)
151
+ response = ""
 
 
 
 
 
 
 
 
152
 
153
+ # Retrieve relevant documents from FAISS
154
+ query_embedding = self.embedding_model.encode([message])
155
+ _, indices = self.index.search(query_embedding, k=5)
156
+ relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
159
+ Consider relevant knowledge and the following documents:\n{relevant_docs}.
160
+ Respond with empathy, follow-up questions, and care guidance."""
 
 
 
 
 
 
 
 
161
 
162
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
163
+ outputs = self.model.generate(inputs['input_ids'], max_new_tokens=100, temperature=0.7, do_sample=True)
164
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ follow_ups = self.generate_follow_up_questions(message, {})
167
+ response += f"\n{follow_ups[0]}"
 
 
 
 
 
 
 
 
 
168
 
169
+ # Append response to conversation history
170
+ self.conversation_history.append((message, response))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ # Add care level guidance
173
+ if severity == "emergency":
174
+ response += "\nThis seems urgent. Please call 999 immediately."
175
+ elif severity == "urgent":
176
+ response += "\nConsider calling NHS 111 for urgent assistance."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ return {'response': response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  except Exception as e:
180
+ logger.error(f"Error generating response: {e}")
181
+ return {
182
+ 'response': "I'm experiencing technical issues. If this is an emergency, please call 999 immediately.",
183
+ }
184
+
185
+ def handle_feedback(self, message: str, response: str, feedback: int):
186
+ """Update model based on feedback"""
187
  try:
188
+ self.feedback_history.append({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  'message': message,
190
  'response': response,
191
+ 'feedback': feedback,
 
192
  'timestamp': datetime.now().isoformat()
193
  })
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ if len(self.feedback_history) >= 10:
196
+ # Implement learning updates from feedback
197
+ self.feedback_history = [] # Reset history after learning update
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
+ logger.error(f"Error processing feedback: {e}")
200
 
201
  def create_demo():
202
+ """Set up Gradio interface for the chatbot"""
203
  try:
204
  bot = AdaptiveMedicalBot()
205
+
206
+ def chat(message: str, history: List[Dict[str, str]]):
207
  try:
 
208
  bot_history = [(h["user"], h["bot"]) for h in history] if history else []
209
+ response_data = bot.generate_response(message)
 
 
210
  response = response_data['response']
211
+
 
212
  history.append({"role": "user", "content": message})
213
  history.append({"role": "assistant", "content": response})
 
214
  return history
 
215
  except Exception as e:
216
  logger.error(f"Chat error: {e}")
217
  return history + [
218
  {"role": "user", "content": message},
219
+ {"role": "assistant", "content": "I'm experiencing technical difficulties. For emergencies, call 999."}
220
  ]
221
 
222
  def process_feedback(feedback: str, history: List[Dict[str, str]], comment: str = ""):
 
223
  try:
224
  if history and len(history) >= 2:
225
  last_user_msg = history[-2]["content"]
226
  last_bot_msg = history[-1]["content"]
227
+ bot.handle_feedback(last_user_msg, last_bot_msg, 1 if feedback == "👍" else -1)
 
 
 
 
 
228
  except Exception as e:
229
  logger.error(f"Error processing feedback: {e}")
230
 
231
+ with gr.Blocks() as demo:
232
+ chatbot = gr.Chatbot(value=[{"role": "assistant", "content": "Hello! I'm Pearly, your GP Triage medical assistant. How can I help you today?"}],
233
+ height=500,
234
+ elem_id="chatbot",
235
+ type="messages",
236
+ show_label=False
237
+ )
238
+
239
+ msg = gr.Textbox(
240
+ label="Your message",
241
+ placeholder="Type your message here...",
242
+ lines=2
243
+ )
244
+ submit = gr.Button("Send", variant="primary")
245
+
246
+ feedback = gr.Radio(
247
+ choices=["👍", "👎"],
248
+ label="Was this response helpful?",
249
+ visible=True
250
+ )
251
+ feedback_text = gr.Textbox(
252
+ label="Additional comments (optional)",
253
+ placeholder="Tell us more about your experience...",
254
+ lines=2
255
+ )
 
 
 
 
 
 
 
 
 
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # Event Handlers
258
  submit.click(
259
+ fn=chat,
260
  inputs=[msg, chatbot],
261
  outputs=[chatbot]
262
  ).then(
 
266
  )
267
 
268
  msg.submit(
269
+ fn=chat,
270
  inputs=[msg, chatbot],
271
  outputs=[chatbot]
272
  ).then(
 
275
  msg
276
  )
277
 
 
 
 
 
 
 
278
  feedback.change(
279
+ fn=process_feedback,
280
  inputs=[feedback, chatbot, feedback_text],
281
  outputs=[]
282
  )
283
+
284
+ # Clear Chat Handler
285
+ clear = gr.Button("🗑️ Clear Chat")
286
+ clear.click(lambda: [[], ""], None, [chatbot, msg])
287
+
288
+ # Additional Information Sections
289
+ gr.HTML("""
290
+ <div style="padding: 20px;">
291
+ <h2>Quick Actions</h2>
292
+ <button onclick="document.getElementById('chatbot').value += 'Emergency Info: For emergencies, call 999.'">Emergency Info</button>
293
+ <button onclick="document.getElementById('chatbot').value += 'NHS 111 Info: For urgent but non-emergency situations, call 111.'">NHS 111 Info</button>
294
+ <button onclick="document.getElementById('chatbot').value += 'GP Booking Info: For routine appointments with your GP.'">GP Booking</button>
295
+ </div>
296
+ """)
297
+
298
+ gr.Markdown("### Example Messages")
299
+ gr.Examples(
300
+ examples=[
301
+ ["I've been having severe headaches for the past week"],
302
+ ["I need to book a routine checkup"],
303
+ ["I'm feeling very anxious lately and need help"],
304
+ ["My child has had a fever for 2 days"],
305
+ ["I need information about COVID-19 testing"]
306
+ ],
307
+ inputs=msg
308
+ )
309
+
310
+ gr.Markdown("""
311
+ ### NHS Services Guide
312
+ **999 - Emergency Services**: For life-threatening emergencies, severe injuries, heart attack, stroke.
313
+
314
+ **NHS 111**: Available 24/7 for urgent but non-life-threatening situations, medical advice, and guidance.
315
+
316
+ **GP Services**: Routine check-ups, non-urgent medical issues, and prescription renewals.
317
+ """)
318
 
319
  return demo
320
+
321
  except Exception as e:
322
  logger.error(f"Error creating demo: {e}")
323
  raise
324
 
325
  if __name__ == "__main__":
326
+ # Launch Gradio Interface
 
 
 
 
 
 
 
 
327
  demo = create_demo()
328
+ demo.launch()
329
+
330
+