PearlIsa committed on
Commit 1298db9 • 1 Parent(s): e2d6407

Update app.py

Files changed (1)
  1. app.py +107 -1158
app.py CHANGED
@@ -1,1200 +1,149 @@
- # Core imports
  import os
  import logging
  import torch
- from typing import Dict, List, Any
  import gradio as gr
- from huggingface_hub import login
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer
- import faiss
- import numpy as np
- from tqdm import tqdm
- from datetime import datetime
- from dataclasses import dataclass, field
- from dotenv import load_dotenv
- import wandb
- from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
- import bitsandbytes as bnb
- from accelerate import infer_auto_device_map, init_empty_weights
- from transformers import BitsAndBytesConfig
- from datasets import load_dataset
- import time

-
- # Load environment variables
- load_dotenv()
-
- # Set up logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
  logger = logging.getLogger(__name__)

- # Retrieve secrets securely from environment variables
- kaggle_username = os.getenv("KAGGLE_USERNAME")
- kaggle_key = os.getenv("KAGGLE_KEY")
- hf_token = os.getenv("HF_TOKEN")
- wandb_key = os.getenv("WANDB_API_KEY")
-
- # Log in to Hugging Face
- if hf_token:
-     login(token=hf_token)
- else:
-     logger.warning("Hugging Face token not found in environment variables.")
-
- @dataclass
- class MedicalConfig:
-     """Enhanced configuration for medical chatbot"""
-     # LoRA parameters
-     LORA_WEIGHTS_PATH: str = "medical_lora_weights"
-     LORA_R: int = 16
-     LORA_ALPHA: int = 32
-     LORA_DROPOUT: float = 0.1
-     LORA_TARGET_MODULES: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj"])
-
-     # Training parameters
-     TRAINING_BATCH_SIZE: int = 4
-     LEARNING_RATE: float = 2e-5
-     NUM_EPOCHS: int = 3
-     MAX_LENGTH: int = 2048
-     INDEX_BATCH_SIZE: int = 32
-
-     # Medical specific parameters
-     EMERGENCY_KEYWORDS: List[str] = field(default_factory=lambda: [
-         'chest pain', 'breathing difficulty', 'stroke', 'heart attack', 'unconscious',
-         'severe bleeding', 'seizure', 'anaphylaxis', 'severe burn', 'choking',
-         'severe head injury', 'spinal injury', 'drowning', 'electric shock',
-         'severe allergic reaction', 'poisoning', 'overdose', 'self-harm',
-         'suicidal thoughts', 'severe trauma'
-     ])
-
-     URGENT_KEYWORDS: List[str] = field(default_factory=lambda: [
-         'infection', 'high fever', 'severe pain', 'vomiting', 'dehydration',
-         'anxiety attack', 'panic attack', 'mental health crisis', 'broken bone',
-         'deep cut', 'asthma attack', 'migraine', 'severe rash', 'eye injury',
-         'dental emergency', 'pregnancy complications', 'severe back pain',
-         'severe abdominal pain', 'concussion', 'severe allergies'
-     ])

-     # UK Healthcare specific
-     EMERGENCY_NUMBERS: List[str] = field(default_factory=lambda: ["999", "112", "111"])
-     GP_SERVICES: Dict[str, Dict[str, str]] = field(default_factory=lambda: {
-         "EMERGENCY": {
-             "name": "A&E",
-             "wait_time": "4 hours target",
-             "when_to_use": "Life-threatening emergencies"
-         },
-         "URGENT": {
-             "name": "Urgent Care Center",
-             "wait_time": "2-4 hours typically",
-             "when_to_use": "Urgent but not life-threatening conditions"
-         },
-         "NON_URGENT": {
-             "name": "GP Practice",
-             "wait_time": "Same day to 2 weeks",
-             "when_to_use": "Routine medical care"
-         }
-     })

-     # Cultural considerations
-     CULTURAL_CONTEXTS: List[Dict[str, str]] = field(default_factory=lambda: [
-         {
-             "group": "South Asian",
-             "considerations": [
-                 "Different presentation of skin conditions",
-                 "Higher diabetes risk",
-                 "Cultural dietary practices",
-                 "Language preferences"
-             ]
-         },
-         {
-             "group": "African/Caribbean",
-             "considerations": [
-                 "Different presentation of skin conditions",
-                 "Higher hypertension risk",
-                 "Specific hair/scalp conditions",
-                 "Cultural health beliefs"
-             ]
-         },
-         {
-             "group": "Middle Eastern",
-             "considerations": [
-                 "Cultural modesty requirements",
-                 "Ramadan considerations",
-                 "Gender preferences for healthcare providers",
-                 "Traditional medicine practices"
-             ]
-         }
-     ])
- class GPUOptimizedRAG:
-     def __init__(
-         self,
-         model_path: str = "google/gemma-7b",
-         embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
-         config: MedicalConfig = MedicalConfig(),
-         use_cpu_fallback: bool = False
-     ):
-         """Initialize RAG with T4 optimization"""
          try:
-             # Initialize conversation memory
-             self.conversation_memory = {
-                 'name': 'Pearly',
-                 'role': 'GP Medical Assistant',
-                 'style': 'professional, empathetic, and clear',
-                 'system_prompt': None,
-                 'past_interactions': []
-             }
-
-             # Log GPU info
-             if torch.cuda.is_available():
-                 logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
-                 logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f}GB")
-
-             # Initialize tokenizer
              self.tokenizer = AutoTokenizer.from_pretrained(
-                 model_path,
                  trust_remote_code=True
              )
-             logger.info("Tokenizer loaded successfully")
-
-             # Initialize model with T4 optimizations
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 model_path,
-                 torch_dtype=torch.float16,  # Use fp16 for memory efficiency
-                 device_map="auto",  # Let accelerate handle memory mapping
-                 trust_remote_code=True,
-                 max_memory={0: "14GB"},  # Reserve 14GB for model, leaving 2GB for other operations
-                 load_in_8bit=True,  # Use 8-bit quantization for additional memory savings
-             )
-             logger.info(f"Model loaded successfully on {self.model.device}")
-
-             # Set up device
-             self.device = torch.device("cuda")
-             self.use_cpu_fallback = use_cpu_fallback
-
-             # Initialize embedding model
-             self.embedding_model = SentenceTransformer(embedding_model)
-             self.embedding_model.to(self.device)
-             logger.info("Embedding model loaded successfully")
-
-             # Add clinical quality metrics
-             self.clinical_metrics = {
-                 'terminology_accuracy': 0.0,
-                 'assessment_accuracy': 0.0,
-                 'guideline_adherence': 0.0,
-                 'symptom_recognition': 0.0
-             }
-
-             # Initialize other components
-             self.config = config

-             # Setup FAISS index with GPU support
-             self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
-             if torch.cuda.is_available() and not use_cpu_fallback:
-                 try:
-                     self.index = faiss.IndexFlatIP(self.embedding_dim)
-                     res = faiss.StandardGpuResources()
-                     self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
-                     logger.info("FAISS GPU index initialized successfully")
-                 except Exception as e:
-                     logger.warning(f"GPU FAISS initialization failed: {e}, using CPU index")
-                     self.index = faiss.IndexFlatIP(self.embedding_dim)
-             else:
-                 self.index = faiss.IndexFlatIP(self.embedding_dim)
-
-             # Setup LoRA after model initialization
-             self.setup_lora()
-
-         except Exception as e:
-             logger.error(f"Error in initialization: {e}")
-             raise
-
-         # Add learning components
-         self.learning_buffer = []
-         self.feedback_history = []
-         self.learning_rate = 0.0001
-         self.min_feedback_threshold = 10
-
-         # Initialize learning storage
-         self.storage_path = os.path.join(os.getcwd(), "learning_data")
-         os.makedirs(self.storage_path, exist_ok=True)
-
-     def setup_lora(self):
-         """Configure and apply LoRA with T4 optimization"""
-         try:
-             # Prepare model for k-bit training
-             model = prepare_model_for_kbit_training(self.model)
-
-             lora_config = LoraConfig(
-                 r=self.config.LORA_R,
-                 lora_alpha=self.config.LORA_ALPHA,
-                 target_modules=self.config.LORA_TARGET_MODULES,
-                 lora_dropout=self.config.LORA_DROPOUT,
-                 bias="none",
-                 task_type=TaskType.CAUSAL_LM,
-                 inference_mode=False,
              )
-
-             self.model = get_peft_model(model, lora_config)
-             logger.info("LoRA configuration applied successfully")

-             # Monitor memory after LoRA setup
-             if torch.cuda.is_available():
-                 monitor_gpu_memory()
-
          except Exception as e:
-             logger.error(f"Error setting up LoRA: {e}")
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
              raise

-     def store_interaction(self, message: str, response: str):
-         """Store interaction for learning"""
-         self.learning_buffer.append({
-             'message': message,
-             'response': response,
-             'timestamp': datetime.now().isoformat(),
-             'feedback_pending': True
-         })

-         # Periodically save buffer
-         if len(self.learning_buffer) >= 100:
-             self._save_buffer()
-
-     def update_from_feedback(self, message: str, response: str, feedback: int):
-         """Process and store feedback"""
          try:
-             # Store feedback
-             self.feedback_history.append({
-                 'message': message,
-                 'response': response,
-                 'feedback': feedback,
-                 'timestamp': datetime.now().isoformat()
-             })

-             # Update model if we have enough feedback
-             if len(self.feedback_history) >= self.min_feedback_threshold:
-                 self._update_model()
-
-         except Exception as e:
-             logger.error(f"Error in feedback processing: {e}")
-
-     def _update_model(self):
-         """Update model weights based on feedback"""
-         try:
-             positive_samples = [f for f in self.feedback_history if f['feedback'] > 0]
-             if len(positive_samples) >= self.min_feedback_threshold:
-                 # Prepare training data
-                 train_data = self._prepare_training_data(positive_samples)
-
-                 # Update model weights
-                 self._fine_tune_model(train_data)
-
-                 # Clear history after successful update
-                 self.feedback_history = []
-                 logger.info("Model updated successfully")
-
-         except Exception as e:
-             logger.error(f"Error updating model: {e}")

-     def _prepare_training_data(self, samples):
-         """Convert feedback samples to training data"""
-         return [
-             {
-                 'input_ids': self.tokenizer(
-                     s['message'],
-                     max_length=self.config.MAX_LENGTH,
-                     truncation=True,
-                     return_tensors='pt'
-                 ).input_ids,
-                 'labels': self.tokenizer(
-                     s['response'],
-                     max_length=self.config.MAX_LENGTH,
-                     truncation=True,
-                     return_tensors='pt'
-                 ).input_ids
-             }
-             for s in samples
-         ]
-
-     def evaluate_clinical_quality(self, response: str, expected_elements: List[str]) -> Dict[str, float]:
-         """Add clinical quality evaluation matching test requirements"""
-         quality_metrics = {
-             'terminology_accuracy': self._evaluate_terminology(response, expected_elements),
-             'assessment_accuracy': self._evaluate_assessment(response),
-             'guideline_adherence': self._evaluate_guidelines(response),
-             'symptom_recognition': self._evaluate_symptoms(response, expected_elements)
-         }
-         return quality_metrics
-
-     def assess_urgency(self, symptoms: str) -> Dict[str, Any]:
-         """Enhanced symptom assessment with detailed analysis"""
-         symptoms_lower = symptoms.lower()
-
-         # Initialize response
-         assessment = {
-             'level': 'NON-URGENT',
-             'reasons': [],
-             'recommendations': [],
-             'follow_up_needed': False
-         }
-
-         # Check emergency keywords
-         emergency_matches = [kw for kw in self.config.EMERGENCY_KEYWORDS
-                              if kw in symptoms_lower]
-         if emergency_matches:
-             assessment.update({
-                 'level': 'EMERGENCY',
-                 'reasons': emergency_matches,
-                 'recommendations': [
-                     'Call 999 immediately',
-                     'Do not move if spinal injury suspected',
-                     'Stay on the line for guidance'
-                 ],
-                 'follow_up_needed': True
-             })
-             return assessment
-
-         # Check urgent keywords
-         urgent_matches = [kw for kw in self.config.URGENT_KEYWORDS
-                           if kw in symptoms_lower]
-         if urgent_matches:
-             assessment.update({
-                 'level': 'URGENT',
-                 'reasons': urgent_matches,
-                 'recommendations': [
-                     'Visit urgent care center',
-                     'Book emergency GP appointment',
-                     'Monitor symptoms closely'
-                 ],
-                 'follow_up_needed': True
-             })
-             return assessment

-         # Non-urgent default
-         assessment.update({
-             'recommendations': [
-                 'Book routine GP appointment',
-                 'Monitor symptoms',
-                 'Try self-care measures'
-             ],
-             'follow_up_needed': False
-         })
-         return assessment

-     def prepare_documents(self, documents: List[Dict]):
-         """Enhanced document preparation with improved batching and memory management"""
-         self.documents = documents
-         embeddings = []

-         try:
-             for i in tqdm(range(0, len(documents), self.config.INDEX_BATCH_SIZE),
-                           desc="Processing documents"):
-                 batch = documents[i:i + self.config.INDEX_BATCH_SIZE]
-                 texts = [doc['text'] for doc in batch]
-
-                 with torch.amp.autocast(device_type='cuda'):
-                     batch_embeddings = self.embedding_model.encode(
-                         texts,
-                         convert_to_tensor=True,
-                         show_progress_bar=False,
-                         batch_size=8
-                     )
-
-                 embeddings.append(batch_embeddings.cpu().numpy())

-             all_embeddings = np.vstack(embeddings)
-             self.index.add(all_embeddings)
-             logger.info(f"Indexed {len(documents)} documents successfully")
-
-         except Exception as e:
-             logger.error(f"Error preparing documents: {e}")
-             raise

-
-     def generate_cultural_considerations(self, symptoms: str) -> List[str]:
-         """Generate culturally-aware medical considerations"""
-         considerations = []
-         symptoms_lower = symptoms.lower()
-
-         for context in self.config.CULTURAL_CONTEXTS:
-             relevant_considerations = [
-                 cons for cons in context['considerations']
-                 if any(keyword in symptoms_lower for keyword in cons.lower().split())
-             ]
-
-             if relevant_considerations:
-                 considerations.extend([
-                     f"{context['group']}: {consideration}"
-                     for consideration in relevant_considerations
-                 ])
-
-         return considerations if considerations else ["No specific cultural considerations identified"]
-
-     def retrieve(self, query: str, k: int = 5) -> List[Dict]:
-         """Enhanced document retrieval with T4 optimization"""
-         try:
-             # Use mixed precision for query encoding
-             with torch.cuda.amp.autocast():
-                 query_embedding = self.embedding_model.encode(
-                     query,
-                     convert_to_tensor=True,
-                     show_progress_bar=False
-                 )
-             # Move to CPU for FAISS search
-             query_embedding = query_embedding.cpu().numpy().reshape(1, -1)
-
-             # Perform search
-             scores, indices = self.index.search(query_embedding, k)
-
-             # Filter and format results
-             results = [
-                 {
-                     'document': self.documents[idx],
-                     'score': float(score),
-                     'relevance_metrics': {
-                         'semantic_similarity': float(score),
-                         'keyword_match': self._calculate_keyword_match(query, self.documents[idx]['text'])
-                     }
-                 }
-                 for score, idx in zip(scores[0], indices[0])
-                 if score > 0.5  # Relevance threshold
-             ]
-
-             # Log retrieval metrics
-             if results:
-                 avg_score = np.mean([r['score'] for r in results])
-                 logger.info(f"Retrieved {len(results)} documents with average score: {avg_score:.3f}")
-
-             return results
-
-         except Exception as e:
-             logger.error(f"Error in retrieval: {e}")
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-             return []
-
-     def _calculate_keyword_match(self, query: str, doc_text: str) -> float:
-         """Calculate keyword match score between query and document"""
-         query_words = set(query.lower().split())
-         doc_words = set(doc_text.lower().split())
-         matches = query_words.intersection(doc_words)
-         return len(matches) / len(query_words) if query_words else 0.0
-
-     def generate_report(self, results: Dict) -> Dict:
-         """Generate enhanced summary report with T4 metrics"""
-         try:
-             total_cases = sum(cat['total'] for cat in results.values())
-             total_correct = sum(cat['correct'] for cat in results.values())
-
-             # Basic performance metrics
-             performance_metrics = {
-                 'timestamp': datetime.now().isoformat(),
-                 'triage_performance': {
-                     'emergency_accuracy': results['emergency']['correct'] / results['emergency']['total'],
-                     'urgent_accuracy': results['urgent']['correct'] / results['urgent']['total'],
-                     'non_urgent_accuracy': results['non_urgent']['correct'] / results['non_urgent']['total'],
-                     'overall_accuracy': total_correct / total_cases
-                 }
-             }
-
-             # Add document processing metrics if available
-             if hasattr(self, 'document_metrics'):
-                 performance_metrics['document_processing'] = self.document_metrics
-
-             # Add GPU metrics if available
-             if torch.cuda.is_available():
-                 gpu_metrics = {
-                     'gpu_name': torch.cuda.get_device_name(0),
-                     'gpu_memory_allocated': torch.cuda.memory_allocated() / 1024**2,  # MB
-                     'gpu_memory_cached': torch.cuda.memory_reserved() / 1024**2,  # MB
-                 }
-                 performance_metrics['gpu_metrics'] = gpu_metrics
-
-             return performance_metrics
-
-         except Exception as e:
-             logger.error(f"Error generating report: {e}")
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-             return {
-                 'timestamp': datetime.now().isoformat(),
-                 'error': str(e)
-             }
-
-     def get_booking_template(self, urgency_level: str) -> str:
-         """Get appropriate booking template based on urgency level"""
-         service_info = self.config.GP_SERVICES[urgency_level]
-
-         templates = {
-             "EMERGENCY": f"""
- 🚨 EMERGENCY SERVICES REQUIRED 🚨
-
- Service: {service_info['name']}
- Target Wait Time: {service_info['wait_time']}
- When to Use: {service_info['when_to_use']}
-
- IMMEDIATE ACTIONS:
- 1. 🚑 Call 999 (or 112)
- 2. 🏥 Nearest A&E: [Location Placeholder]
- 3. 🚨 Stay on line for guidance
-
- Type '999' to initiate emergency call
- """,
-             "URGENT": f"""
- ⚡ URGENT CARE NEEDED ⚡
-
- Service: {service_info['name']}
- Expected Wait: {service_info['wait_time']}
- When to Use: {service_info['when_to_use']}
-
- OPTIONS:
- 1. 🏥 Find nearest urgent care
- 2. 📅 Book urgent GP slot
- 3. 🔍 Locate walk-in clinic
-
- Reply with option number (1-3)
- """,
-             "NON_URGENT": f"""
- 📋 ROUTINE CARE BOOKING 📋
-
- Service: {service_info['name']}
- Typical Wait: {service_info['wait_time']}
- When to Use: {service_info['when_to_use']}
-
- OPTIONS:
- 1. 📅 Schedule GP visit
- 2. 👨‍⚕️ Find local GP
- 3. ℹ️ Self-care advice
-
- Reply with option number (1-3)
- """
-         }
-
-         return templates.get(urgency_level, templates["NON_URGENT"])
-
-     def generate_response(self, query: str, chat_history: List[tuple] = None) -> Dict[str, Any]:
-         """Generate response with enhanced conversational context and T4 optimization"""
-         try:
-             # Update conversation memory
-             if chat_history:
-                 self.conversation_memory['past_interactions'] = chat_history[-3:]
-
-             # Use mixed precision for T4
-             with torch.cuda.amp.autocast():
-                 # Retrieve relevant documents with boosted weights for persona matches
-                 retrieved_docs = self.retrieve(query, k=7)
-
-                 # Separate documents by type
-                 medical_docs = [doc for doc in retrieved_docs if doc['document']['type'] in ['medical_qa', 'diagnosis']]
-                 persona_docs = [doc for doc in retrieved_docs if doc['document']['type'] in ['persona', 'conversation', 'GP_template']]
-
-                 # Build context with weighted emphasis on different document types
-                 medical_context = " ".join([doc['document']['text'] for doc in medical_docs])
-                 persona_context = " ".join([doc['document']['text'] for doc in persona_docs])
-
-                 # Assess urgency and get considerations
-                 urgency_assessment = self.assess_urgency(query)
-                 cultural_considerations = self.generate_cultural_considerations(query)
-
-                 # Build conversation history context
-                 history_context = ""
-                 if chat_history:
-                     history_context = "\n".join([f"Human: {h}\nPearly: {a}" for h, a in chat_history[-3:]])
-
-                 # Add persona reminder
-                 persona_reminder = f"""
- I am {self.conversation_memory['name']}, a {self.conversation_memory['role']}.
- My communication style is {self.conversation_memory['style']}.
- """
-
-                 # Create enhanced prompt with persona integration
-                 prompt = f"""Context:
- Medical Information: {medical_context}
-
- {persona_reminder}
-
- Previous Interactions:
- {history_context}
-
- Current Query: {query}
-
- Maintain my identity as {self.conversation_memory['name']}, the {self.conversation_memory['role']},
- providing clear, professional guidance following NHS protocols.
- Urgency Level: {urgency_assessment['level']}
- Cultural Considerations: {', '.join(cultural_considerations)}
-
- Respond in a clear, caring manner, always referring to myself as {self.conversation_memory['name']}.
-
- Response:"""
-
-                 # Generate response with T4 optimizations
-                 inputs = self.tokenizer(
-                     prompt,
-                     return_tensors="pt",
-                     max_length=self.config.MAX_LENGTH,
-                     truncation=True
-                 ).to(self.device)
-
-                 outputs = self.model.generate(
-                     **inputs,
-                     max_new_tokens=512,
-                     do_sample=True,
-                     top_p=0.9,
-                     temperature=0.7,
-                     num_return_sequences=1,
-                     pad_token_id=self.tokenizer.eos_token_id,
-                     use_cache=True,  # Enable KV cache
-                     low_cpu_mem_usage=True
-                 )
-
-                 # Clean up CUDA cache after generation
-                 if torch.cuda.is_available():
-                     torch.cuda.empty_cache()
-
-                 response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-                 response = response.split("Response:")[-1].strip()
-
-                 # Add booking template for emergency/urgent cases
-                 if urgency_assessment['level'] in ["EMERGENCY", "URGENT"]:
-                     booking_template = self.get_booking_template(urgency_assessment['level'])
-                     response = f"{response}\n\n{booking_template}"
-
-                 return {
-                     'response': response,
-                     'urgency_assessment': urgency_assessment,
-                     'cultural_considerations': cultural_considerations
-                 }
-
          except Exception as e:
              logger.error(f"Error generating response: {e}")
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()  # Clean up on error
-             return {
-                 'response': "I apologize, but I encountered an error. If this is an emergency, please call 999 immediately.",
-                 'urgency_assessment': {'level': 'UNKNOWN'},
-                 'cultural_considerations': []
-             }
-
-
-
-     def enhance_response_generation(self):
-         """Add test-aligned response enhancement"""
-         self.response_enhancers = {
-             'demographic_sensitivity': self._enhance_demographic_sensitivity,
-             'cultural_competency': self._enhance_cultural_competency,
-             'clinical_quality': self._enhance_clinical_quality,
-             'follow_up_generation': self._enhance_follow_up
-         }
-
-     def _enhance_demographic_sensitivity(self, response: str, demographic: str) -> str:
-         """Add demographic-specific enhancements matching test requirements"""
-         demographic_patterns = {
-             'pediatric': ['age-appropriate', 'child-friendly', 'developmental'],
-             'elderly': ['mobility', 'cognitive', 'fall risk'],
-             'pregnant': ['trimester', 'fetal', 'pregnancy-safe'],
-             'chronic_condition': ['management', 'monitoring', 'ongoing care']
-         }
-         return response  # Placeholder implementation
-
-     def process_appointment_booking(self, message, patient_info):
-         """Process appointment booking queries"""
-         return "I can help you book an appointment. Please provide further details."

-     def check_urgency_accuracy(self, predicted: str, expected: str) -> float:
-         """Check if urgency level matches expected"""
-         return 1.0 if predicted == expected else 0.0
-
-     def check_action_accuracy(self, response: str, expected_actions: List[str]) -> float:
-         """Check if recommended actions match expected"""
-         if not expected_actions:
-             return 1.0
-         found_actions = sum(1 for action in expected_actions
-                             if action.lower() in response.lower())
-         return found_actions / len(expected_actions)
-
-     def assess_conversation_quality(self, response: str) -> float:
-         """Assess conversation quality metrics"""
-         metrics = {
-             'empathy': any(word in response.lower()
-                            for word in ['understand', 'hear you', 'sorry']),
-             'clarity': len(response.split('.')) <= 5,  # Check for concise sentences
-             'follow_up': '?' in response,  # Check for follow-up questions
-             'structure': any(word in response.lower()
-                              for word in ['first', 'then', 'next', 'finally'])
-         }
-         return sum(metrics.values()) / len(metrics)
-
-     def check_cultural_sensitivity(self, response_data: Dict, context: str) -> float:
-         """Check cultural sensitivity of response"""
-         if not context:
-             return 1.0

-         cultural_considerations = response_data.get('cultural_considerations', [])
-         return 1.0 if any(context.lower() in cons.lower()
-                           for cons in cultural_considerations) else 0.0
-
-
-
-
- def monitor_gpu_memory():
-     """Monitor GPU memory usage"""
-     if torch.cuda.is_available():
-         device = torch.cuda.current_device()
-         allocated = torch.cuda.memory_allocated(device) / 1024**2
-         reserved = torch.cuda.memory_reserved(device) / 1024**2
-         logger.info(f"GPU Memory: Allocated: {allocated:.2f}MB, Reserved: {reserved:.2f}MB")
-
- def prepare_medical_documents():
-     """Prepare medical knowledge base documents with enhanced conversation flow"""
-     try:
-         logger.info("Loading medical and persona datasets...")
-         datasets = {
-             "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]"),
-             "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
-             "meddia": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
-         }
-
-         documents = []
-
-         # Process Persona dataset for enhanced conversational style
-         logger.info("Processing persona dataset...")
-         for item in datasets["persona"]:
-             if isinstance(item.get('personality'), list):
-                 personality = " ".join(item['personality'])
-                 documents.append({
-                     'text': f"""
- Conversation Style Guide:
- Personality: {personality}
- Role: Pearly - Medical Assistant
- Core Traits: Professional, empathetic, clear
- Key Behaviors:
- - Always introduce as Pearly
- - Show empathy for symptoms
- - Ask relevant follow-up questions
- - Offer practical assistance
- - Maintain professional tone while being approachable
- """,
-                     'type': 'persona'
-                 })
-
-             # Process conversation examples with enhanced structure
-             if isinstance(item.get('utterances'), list):
-                 for utterance in item['utterances']:
-                     if isinstance(utterance, dict) and 'history' in utterance:
-                         conversation = ' '.join(utterance['history'])
-                         documents.append({
-                             'text': f"""
- Medical Consultation Pattern:
- Conversation: {conversation}
- Key Elements:
- - Show understanding of symptoms
- - Ask clarifying questions
- - Provide clear guidance
- - Offer next steps
- - Check if assistance needed
- """,
-                             'type': 'conversation_pattern'
-                         })
-
-         # Process MedQA dataset with enhanced medical context
-         logger.info("Processing medical QA dataset...")
-         for item in datasets["medqa"]:
-             if 'input' in item and 'output' in item:
-                 input_text = item['input']
-                 if input_text.startswith('Q:'):
-                     input_text = input_text[2:]
-
-                 documents.append({
-                     'text': f"""
- Medical Knowledge Base:
- Question: {input_text}
- Answer: {item['output']}
- Application:
- - Use information to inform recommendations
- - Adapt to patient's situation
- - Maintain clinical accuracy
- - Explain in clear terms
- """,
-                     'type': 'medical_qa'
-                 })
-
-         # Process diagnosis dataset with structured guidance
-         logger.info("Processing diagnosis dataset...")
-         for item in datasets["meddia"]:
-             if 'input' in item and 'output' in item:
-                 documents.append({
-                     'text': f"""
- Clinical Assessment Framework:
- Symptoms: {item['input']}
- Assessment and Plan: {item['output']}
- Response Structure:
- 1. Acknowledge symptoms
- 2. Ask about severity and duration
- 3. Inquire about related symptoms
- 4. Provide clear recommendations
- 5. Offer assistance with next steps
- """,
-                     'type': 'diagnosis_guidance'
-                 })
-
-         # Add enhanced conversation templates
-         conversation_templates = [
-             {
-                 'text': """
- Consultation Framework:
- 1. Initial Response:
- - Acknowledge the concern
- - Show empathy
- - Ask about duration/severity
-
- 2. Follow-up Questions:
- - Ask specific, relevant questions
- - Clarify symptoms
- - Check for related issues
-
- 3. Assessment:
- - Summarize findings
- - Explain reasoning
- - State level of concern
-
- 4. Recommendations:
- - Provide clear guidance
- - List specific actions
- - Offer assistance
-
- 5. Next Steps:
- - Suggest appropriate care level
- - Offer to help with appointments
- - Provide relevant resources
-
- 6. Safety Checks:
- - Verify understanding
- - Confirm action plan
- - Ensure patient comfort
-
- Response Patterns:
- Emergency:
- "I understand you're experiencing [symptom]. This requires immediate attention..."
-
- Urgent:
- "I'm concerned about [symptom]. To better assess the urgency..."
-
- Non-urgent:
- "I hear you about [symptom]. Let's understand this better..."
-
- Follow-ups:
- "Would you like help with [specific assistance]?"
- "Shall I help you [book/find/connect]?"
-
- Guidance Structure:
- 1. Acknowledge
- 2. Assess
- 3. Advise
- 4. Assist
- 5. Arrange follow-up
- """,
-                 'type': 'conversation_framework'
-             }
-         ]
-         documents.extend(conversation_templates)
-
-         # Add your existing GP templates
-         # Define GP templates inside the function
-         GP_templates = [
-             {
-                 'text': """
- Role: GP Medical Assistant
- Context: Emergency situation - Chest Pain
- Style: Calm, clear, and authoritative
- Conversation Pattern:
-
- Patient Input: Chest pain symptoms
- Response Framework:
- 1. Initial acknowledgment
- 2. Immediate triage questions:
- - Duration of pain
- - Pain characteristics
- - Associated symptoms
- 3. Emergency guidance
- 4. Clear action steps
-
- Example Flow:
- Patient: I'm having chest pain
- Pearly: I understand you're experiencing chest pain. This requires immediate attention. How long have you been experiencing this pain?
-
- Patient: About 30 minutes now
- Pearly: Thank you for letting me know. Is the pain sharp or dull, and where exactly is it located? Also, are you experiencing any other symptoms like shortness of breath or sweating?
-
- Patient: It's sharp and in the center of my chest. Yes, I'm sweating a lot
- Pearly: Given your symptoms, this could be a serious cardiac event requiring immediate emergency care.
-
- IMMEDIATE ACTIONS:
- 1. Call 999 immediately - I can help you do this
- 2. Sit down and try to stay calm
- 3. If available and you're not allergic, take one adult aspirin (300mg)
- 4. Loosen any tight clothing
- 5. Keep the line open - I'll stay with you until help arrives
-
- Would you like me to connect you to emergency services now?
- """,
-                 'type': 'GP_template'
-             },
-             {
-                 'text': """
- Role: GP Medical Assistant
- Context: Non-emergency consultation - Persistent Cough
- Style: Empathetic and thorough
- Conversation Pattern:
-
- Patient Input: Cough symptoms
- Response Framework:
- 1. Show empathy
- 2. Gather specific details:
- - Type of cough
- - Duration
- - Associated symptoms
- 3. Risk assessment
- 4. Clear recommendations
-
- Example Flow:
- Patient: I've had a cough for 5 days now
- Pearly: I'm sorry to hear you've been unwell. To better understand your condition, is it a dry cough or are you producing any phlegm?
-
- Patient: It's mostly dry but sometimes I cough up clear phlegm
- Pearly: Thank you for that detail. Have you experienced any fever, and are you having any difficulty breathing?
-
- Patient: No fever but I feel a bit short of breath sometimes
- Pearly: I understand. A few more important questions:
- 1. Have you been exposed to anyone with COVID-19 or other respiratory infections?
- 2. Do you have any underlying health conditions?
- 3. Are your symptoms worse at any particular time of day?
-
- Would you like help booking an appointment with your GP?
- """,
-                 'type': 'GP_template'
-             },
-             {
-                 'text': """
- Role: GP Medical Assistant
- Context: Mental Health Support
- Style: Empathetic, patient, and supportive
- Conversation Pattern:
-
- Patient Input: Mental health concerns
- Response Framework:
- 1. Validate feelings
- 2. Gentle exploration:
- - Duration
- - Impact on daily life
- - Support systems
- 3. Clear support options
- 4. Immediate resources
-
- Example Flow:
- Patient: I've been feeling really anxious and can't sleep
- Pearly: I hear you, and I want you to know it's brave to speak up about this. How long have you been experiencing these feelings?
-
- Patient: About a month now, getting worse
- Pearly: Thank you for sharing that. To better understand your situation:
- 1. Are there any specific triggers you've noticed?
- 2. How is this affecting your daily activities?
- 3. Have you had any previous mental health support?
-
- Would you like information about immediate support services or help booking a GP appointment?
- """,
-                 'type': 'GP_template'
-             }
-         ]
-
-         # Add all templates to documents
-         documents.extend(GP_templates)
-
-         logger.info(f"Prepared {len(documents)} documents including:")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'persona'])} persona guides")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'conversation_pattern'])} conversation patterns")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'medical_qa'])} medical QA pairs")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'diagnosis_guidance'])} diagnosis guidelines")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'conversation_framework'])} conversation frameworks")
-         logger.info(f"- {len([d for d in documents if d['type'] == 'GP_template'])} GP templates")
-
-         return documents
-
-     except Exception as e:
-         logger.error(f"Error preparing medical documents: {e}")
-         # Print sample data for debugging
-         for dataset_name, dataset in datasets.items():
-             try:
-                 sample = dataset[0]
-                 logger.error(f"\nSample from {dataset_name}:")
-                 logger.error(f"Keys: {list(sample.keys())}")
-                 logger.error(f"Sample content: {str(sample)[:500]}")
-             except Exception as debug_e:
-                 logger.error(f"Error inspecting {dataset_name}: {debug_e}")
-         raise
-
-
- def setup_wandb(config: MedicalConfig):
-     """Setup Weights & Biases tracking"""
      try:
-         wandb.init(
-             project="medical-chatbot",
-             config={
-                 "learning_rate": config.LEARNING_RATE,
-                 "epochs": config.NUM_EPOCHS,
-                 "batch_size": config.TRAINING_BATCH_SIZE,
-                 "lora_r": config.LORA_R,
-                 "lora_alpha": config.LORA_ALPHA
-             }
-         )
-         logger.info("Weights & Biases initialized successfully")
      except Exception as e:
-         logger.warning(f"Failed to initialize Weights & Biases: {e}")
-         logger.warning("Continuing without wandb tracking")
-
-
- # Global process_chat_response function
- def process_chat_response(response_data: Dict[str, Any], message: str, history: List[tuple]) -> str:
-     """Format chat response based on context"""
-     try:
-         if not history or message.lower().startswith(("hi", "hello", "hey", "good")):
-             return "Hi! I'm Pearly, your medical triaging assistant. I'm here to help assess your symptoms and provide guidance. How may I assist you today?"
-
-         urgency_level = response_data['urgency_assessment']['level']
-         response_text = response_data['response']
-
-         if urgency_level == "EMERGENCY":
-             return f"🚨 EMERGENCY ALERT 🚨\n\n{response_text}\n\nWould you like me to help connect you to emergency services?"
-         elif urgency_level == "URGENT":
-             return f"⚠️ URGENT CARE NEEDED ⚠️\n\n{response_text}\n\nWould you like help finding your nearest urgent care center?"
-         else:
-             return f"{response_text}\n\nWould you like help booking a GP appointment or finding more NHS resources?"
-     except Exception as e:
-         logger.error(f"Error processing chat response: {e}")
-         return (
-             "I'm Pearly, and I apologize for the technical difficulty. For your safety:\n\n"
-             "- Call 999 for emergencies\n"
-             "- Call 111 for urgent medical advice\n"
-             "- Visit NHS 111 online for non-urgent concerns\n\n"
-             "Would you like to try asking your question again?"
-         )
-
- # Global chat function
- def chat(message: str, history: List[tuple]) -> List[tuple]:
-     """Enhanced chat function for Hugging Face Space"""
-     try:
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
-         # Convert history format for RAG system
-         rag_history = [(h["content"], a["content"])
-                        for h, a in zip(history[::2], history[1::2])] if history else []
-
-
-         response_data = rag_system.generate_response(message, history)
-         response = process_chat_response(response_data, message, history)
-
-         if not history:
-             history = []
-         # Convert to message format
-         history.append({"role": "user", "content": message})
-         history.append({"role": "assistant", "content": response})
-
-         # Store interaction for learning
-         rag_system.store_interaction(message, response)
-
-         return history

-     except Exception as e:
-         logger.error(f"Error in chat: {e}")
-         return handle_error(message, history)
-
- # Initialize RAG system globally
- rag_system = GPUOptimizedRAG(
-     config=MedicalConfig(),
-     use_cpu_fallback=False,
-     model_path="google/gemma-7b",
-     embedding_model="sentence-transformers/all-MiniLM-L6-v2"
- )
-
- # Initialize medical documents globally
- medical_documents = prepare_medical_documents()
- rag_system.prepare_documents(medical_documents)
-
- def handle_feedback(feedback: str, message: str, response: str) -> None:
-     """Process user feedback for adaptive learning"""
-     try:
-         feedback_value = 1 if feedback == "👍 Helpful" else -1
-         rag_system.update_from_feedback(message, response, feedback_value)
-     except Exception as e:
-         logger.error(f"Error processing feedback: {e}")
-
- # Define the Gradio interface
- title = """
- <div style="text-align: center; max-width: 700px; margin: 0 auto;">
- <h1>Pearly Medical Assistant</h1>
- <p>Hi! I'm Pearly, your GP medical assistant. I can help assess your symptoms,
- provide medical guidance and assist with finding appropriate care.</p>
- <p style="color: #666; font-size: 0.9em;">For emergencies, always call 999</p>
- </div>
- """
-
- def create_demo() -> gr.Blocks:
-     """Create and return the Gradio demo interface"""
-     def on_chat(message, history):
-         """Chat handler with feedback visibility"""
-         result = chat(message, history)
-         feedback.update(visible=True)
-         return result, gr.update(value="")
-
-     with gr.Blocks(css="footer {display: none}") as demo:
-         gr.HTML(title)
-         chatbot = gr.Chatbot(
-             value=[{"role": "assistant", "content": "Hi! I'm Pearly, your GP medical assistant. How can I help you today?"}],
-             height=600,
-             type="messages",
-             label="Medical Consultation"
-         )
-         msg = gr.Textbox(
-             label="Your Message",
-             placeholder="Describe your symptoms or ask a medical question...",
-             lines=2
-         )
-         submit = gr.Button("Send", variant="primary")
-         clear = gr.Button("Clear Conversation")
-
-         # Feedback mechanism
-         with gr.Row():
-             feedback = gr.Radio(
-                 choices=["👍 Helpful", "👎 Not Helpful"],
-                 label="Was this response helpful?",
-                 visible=False
-             )
-
-         # Set up event handlers
-         msg.submit(on_chat, [msg, chatbot], [chatbot, msg])
-         submit.click(on_chat, [msg, chatbot], [chatbot, msg])
-         clear.click(lambda: None, None, chatbot, queue=False)

-         feedback.change(
-             handle_feedback,
-             [feedback, msg, chatbot],
-             None
-         )
-
-
-         gr.HTML("""
- <div style="text-align: center; margin-top: 20px; font-size: 0.8em; color: #666;">
- <p>This is a medical triage assistant. For emergencies, always call 999.</p>
- <p>Your privacy is important. Conversations are not stored permanently.</p>
- </div>
- """)

      return demo
-
-
  if __name__ == "__main__":
      demo = create_demo()
      demo.launch()
 
+ # app.py
  import os
  import logging
  import torch
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer

+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ class MedicalTriageBot:
+     def __init__(self):
+         self.setup_models()
+         self.emergency_keywords = [
+             'chest pain', 'breathing', 'stroke', 'heart attack',
+             'unconscious', 'bleeding', 'seizure'
+         ]
+         self.urgent_keywords = [
+             'infection', 'fever', 'severe pain', 'vomiting',
+             'mental health', 'broken'
+         ]

+     def setup_models(self):
+         """Initialize models with optimal T4 settings"""
+         # Configure quantization
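+         # 4-bit NF4 weights with fp16 compute cut the 7B model to roughly a quarter of its fp16 size, fitting a 16GB T4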
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16
+         )

          try:
+             # Load tokenizer
              self.tokenizer = AutoTokenizer.from_pretrained(
+                 "google/gemma-7b",
                  trust_remote_code=True
              )

+             # Load model with quantization
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 "google/gemma-7b",
+                 quantization_config=bnb_config,
+                 device_map="auto",
+                 trust_remote_code=True
              )

          except Exception as e:
+             logger.error(f"Error loading models: {e}")
              raise

+     def assess_urgency(self, message):
+         """Determine message urgency"""
+         message_lower = message.lower()
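+         # Simple substring matching; emergency keywords are checked before urgent ones, so they take precedence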

+         if any(keyword in message_lower for keyword in self.emergency_keywords):
+             return "EMERGENCY"
+         elif any(keyword in message_lower for keyword in self.urgent_keywords):
+             return "URGENT"
+         return "NON_URGENT"
+
+     def generate_response(self, message, history):
+         """Generate appropriate response based on context"""
          try:
+             urgency = self.assess_urgency(message)

+             # Build prompt with context
+             conversation_history = "\n".join([f"User: {h[0]}\nAssistant: {h[1]}" for h in (history or [])])

+             prompt = f"""You are Pearly, an empathetic NHS medical triage assistant. Consider the conversation history and respond appropriately:

+ Previous conversation:
+ {conversation_history}

+ Current message: {message}
+ Urgency Level: {urgency}

+ Respond professionally and empathetically. For emergencies, recommend calling 999. For urgent cases, recommend NHS 111. For non-urgent cases, offer GP booking assistance. Ask relevant follow-up questions to better understand the situation.

+ Response:"""
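+             # "Response:" also serves as the delimiter for stripping the echoed prompt from the decoded output below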

+             # Generate response
+             inputs = self.tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 max_length=512,
+                 truncation=True
+             ).to(self.model.device)
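+             # Move inputs to the device chosen by device_map="auto"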
+
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=200,
+                 temperature=0.7,
+                 do_sample=True
+             )
+
+             # generate() returns the prompt plus the continuation, so keep only the text after "Response:"
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).split("Response:")[-1].strip()
+
+             # Format based on urgency
+             if urgency == "EMERGENCY":
+                 return f"🚨 EMERGENCY: Please call 999 immediately.\n\n{response}\n\nWould you like first aid guidance while waiting for emergency services?"
+             elif urgency == "URGENT":
+                 return f"⚠️ URGENT: {response}\n\nPlease consider calling NHS 111 for immediate advice. Would you like information about urgent care centers?"
+             else:
+                 return f"{response}\n\nWould you like help booking a GP appointment?"
+
          except Exception as e:
              logger.error(f"Error generating response: {e}")
+             return "I apologize, but I'm having technical difficulties. If this is an emergency, please call 999 immediately. For urgent concerns, call 111."

+ # Create Gradio interface
+ def create_demo():
+     bot = MedicalTriageBot()

+     def chat(message, history):
          try:
+             response = bot.generate_response(message, history)
+             return response
          except Exception as e:
+             logger.error(f"Chat error: {e}")
+             return "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
+
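+     # gr.ChatInterface passes (message, history) to chat() and renders the returned string as the assistant reply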
+     demo = gr.ChatInterface(
+         fn=chat,
+         title="NHS Medical Triage Assistant - Pearly",
+         description="""
+ 👋 Hello! I'm Pearly, your NHS medical triage assistant. I can help assess your symptoms and direct you to appropriate care.

+ 🚨 For emergencies, always call 999
+ ⚠️ For urgent concerns, call NHS 111
+ 👩‍⚕️ For routine care, I can help book GP appointments

+ Please describe your symptoms or concerns below.
+ """,
+         examples=[
+             "I've been having chest pain for the last hour",
+             "I have a fever and sore throat",
+             "I'd like to book a routine checkup",
+         ],
+         theme="soft"
+     )

      return demo
+
+ # Launch the app
  if __name__ == "__main__":
      demo = create_demo()
      demo.launch()