PearlIsa committed
Commit e49dc64
1 Parent(s): faf3edb

Update app.py

Files changed (1):
  1. app.py +385 -379

app.py CHANGED
@@ -63,369 +63,376 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection
 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


-def prepare_initial_datasets(batch_size=8):
-    print("Loading datasets with memory-optimized batch processing...")
-
-    def process_medqa_batch(examples):
-        results = []
-        inputs = examples['input']
-        instructions = examples['instruction']
-        outputs = examples['output']
-
-        for inp, inst, out in zip(inputs, instructions, outputs):
-            results.append({
-                "input": f"{inp} {inst}",
-                "output": out
-            })
-        return results
-
-    def process_meddia_batch(examples):
-        results = []
-        inputs = examples['input']
-        outputs = examples['output']
-
-        for inp, out in zip(inputs, outputs):
-            results.append({
-                "input": inp,
-                "output": out
-            })
-        return results
-
-    def process_persona_batch(examples):
-        results = []
-        personalities = examples['personality']
-        utterances = examples['utterances']
-
-        for pers, utts in zip(personalities, utterances):
-            try:
-                # Process personality list
-                personality = ' '.join([
-                    p for p in pers
-                    if isinstance(p, str)
-                ])
-
-                # Process utterances
-                if utts and len(utts) > 0:
-                    utterance = utts[0]
-                    history = []
-
-                    # Process history
-                    if 'history' in utterance and utterance['history']:
-                        history = [
-                            h for h in utterance['history']
-                            if isinstance(h, str)
-                        ]
-
-                    history_text = ' '.join(history)
-
-                    # Get candidate response
-                    candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
-
-                    if personality or history_text:
-                        results.append({
-                            "input": f"{personality} {history_text}".strip(),
-                            "output": candidate
-                        })
-            except Exception as e:
-                print(f"Error processing persona batch item: {e}")
-                continue
-
-        return results
-
-    # Load and process each dataset separately
-    print("Processing MedQA dataset...")
-    medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
-    medqa_processed = []
-
-    for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
-        batch = medqa[i:i + batch_size]
-        medqa_processed.extend(process_medqa_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    print("Processing MedDiagnosis dataset...")
-    meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
-    meddia_processed = []
-
-    for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
-        batch = meddia[i:i + batch_size]
-        meddia_processed.extend(process_meddia_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    print("Processing Persona-Chat dataset...")
-    persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
-    persona_processed = []
-
-    for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
-        batch = persona[i:i + batch_size]
-        persona_processed.extend(process_persona_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    torch.cuda.empty_cache()
-
-    print("Creating final dataset...")
-    all_processed = persona_processed + medqa_processed + meddia_processed
-
-    valid_data = {
-        "input": [],
-        "output": []
-    }
-
-    for item in all_processed:
-        if item["input"].strip() and item["output"].strip():
-            valid_data["input"].append(item["input"])
-            valid_data["output"].append(item["output"])
-
-    final_dataset = Dataset.from_dict(valid_data)
-
-    print(f"Final dataset size: {len(final_dataset)}")
-    return final_dataset
-
-def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
-    def tokenize_batch(examples):
-        formatted_texts = []
-
-        for i in range(0, len(examples['input']), batch_size):
-            sub_batch_inputs = examples['input'][i:i + batch_size]
-            sub_batch_outputs = examples['output'][i:i + batch_size]
-
-            for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
-                try:
-                    formatted_text = f"""<start_of_turn>user
-{input_text}
-<end_of_turn>
-<start_of_turn>assistant
-{output_text}
-<end_of_turn>"""
-                    formatted_texts.append(formatted_text)
-                except Exception as e:
-                    print(f"Error formatting text: {e}")
-                    continue
-
-        tokenized = tokenizer(
-            formatted_texts,
-            padding="max_length",
-            truncation=True,
-            max_length=max_length,
-            return_tensors=None
-        )
-
-        tokenized["labels"] = tokenized["input_ids"].copy()
-        return tokenized
-
-    print(f"Tokenizing dataset in small batches (size={batch_size})...")
-    tokenized_dataset = dataset.map(
-        tokenize_batch,
-        batched=True,
-        batch_size=batch_size,
-        remove_columns=dataset.column_names,
-        desc="Tokenizing dataset",
-        load_from_cache_file=False
-    )
-
-    return tokenized_dataset
-
-def setup_model_and_tokenizer(model_name="google/gemma-2b"):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.pad_token = tokenizer.eos_token
-
-    from transformers import BitsAndBytesConfig
-
-    bnb_config = BitsAndBytesConfig(
-        load_in_8bit=True,
-        bnb_8bit_compute_dtype=torch.float16,
-        llm_int8_enable_fp32_cpu_offload=True
-    )
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="auto",
-        quantization_config=bnb_config,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    )
-
-    model = prepare_model_for_kbit_training(model)
-
-    lora_config = LoraConfig(
-        r=4,
-        lora_alpha=16,
-        target_modules=["q_proj", "v_proj"],
-        lora_dropout=0.05,
-        bias="none",
-        task_type="CAUSAL_LM"
-    )
-
-    model = get_peft_model(model, lora_config)
-    model.print_trainable_parameters()
-
-    return model, tokenizer
-
-def setup_training_arguments(output_dir="./pearly_fine_tuned"):
-    return TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=1,
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=16,
-        warmup_steps=50,
-        logging_steps=10,
-        save_steps=200,
-        learning_rate=2e-4,
-        fp16=True,
-        gradient_checkpointing=True,
-        gradient_checkpointing_kwargs={"use_reentrant": False},
-        optim="adamw_8bit",
-        max_grad_norm=0.3,
-        weight_decay=0.001,
-        logging_dir="./logs",
-        save_total_limit=2,
-        remove_unused_columns=False,
-        dataloader_pin_memory=False,
-        max_steps=500,
-        report_to=["none"],
-    )
-
-def main():
-    torch.backends.cuda.matmul.allow_tf32 = False
-    torch.backends.cudnn.allow_tf32 = False
-
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-
-    print("Preparing initial datasets...")
-    combined_dataset = prepare_initial_datasets(batch_size=4)
-
-    print(f"\nDataset size: {len(combined_dataset)}")
-    print(f"Column names: {combined_dataset.column_names}")
-
-    if len(combined_dataset) > 0:
-        print("\nSample input-output pair:")
-        print(f"Input: {combined_dataset[0]['input'][:100]}...")
-        print(f"Output: {combined_dataset[0]['output'][:100]}...")
-
-    print("\nSetting up model and tokenizer...")
-    model, tokenizer = setup_model_and_tokenizer()
-
-    print("\nPreparing dataset for training...")
-    processed_dataset = prepare_dataset(
-        combined_dataset,
-        tokenizer,
-        max_length=256,
-        batch_size=2
-    )
-
-    torch.cuda.empty_cache()
-
-    training_args = setup_training_arguments()
-
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=processed_dataset,
-        tokenizer=tokenizer,
-    )
-
-    print("\nStarting training...")
-    try:
-        trainer.train()
-    except Exception as e:
-        print(f"Training error: {e}")
-        torch.cuda.empty_cache()
-        raise e
-    finally:
-        torch.cuda.empty_cache()
-
-    print("\nSaving model...")
-    trainer.save_model()
-    print("Training completed!")
-
-DISCLAIMER = """
-IMPORTANT MEDICAL DISCLAIMER:
-Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
-Pearly DOES NOT:
-- Make medical diagnoses
-- Prescribe medications
-- Provide specific treatment recommendations
-- Replace professional medical advice
-
-Always consult qualified healthcare professionals for medical advice and treatment.
-In case of emergency, call 999 immediately.
-"""
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class PearlyBot:
-    def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
-        print("Loading saved model...")
-        print(DISCLAIMER)
-
-        # Clean memory
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        # Load tokenizer and model directly from saved path
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            device_map="auto"
-        )
-
-        self.model.eval() # Set to evaluation mode
-
-        # Initialize RAG components
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
-        self.vector_store = None
+    def __init__(self):
+        self.setup_model()
+        self.setup_rag()
         self.conversation_history = []
-
-    def initialize_rag(self, documents_path="./knowledge_base"):
-        """Initialize RAG system"""
-        print("Loading knowledge base...")
-
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=300,
-            chunk_overlap=100,
-            separators=["\n\n", "\n", ".", "!", "?", ":"]
-        )

-        documents = []
-        for filename in os.listdir(documents_path):
-            if filename.endswith('.txt'):
-                loader = TextLoader(os.path.join(documents_path, filename))
-                documents.extend(loader.load())
-
-        texts = text_splitter.split_documents(documents)
-        self.vector_store = FAISS.from_documents(texts, self.embeddings)
-        self.retriever = self.vector_store.as_retriever(
-            search_type="similarity",
-            search_kwargs={"k": 5}
-        )
-        print("Knowledge base loaded successfully!")
-
-    def get_relevant_context(self, user_input):
-        if not self.retriever:
+    def setup_model(self):
+        try:
+            logger.info("Loading local checkpoint...")
+            # Load from the local checkpoint in your Space
+            checkpoint_path = "checkpoint-500.zip" # Path to your uploaded checkpoint
+            base_model_id = "google/gemma-2b" # Your base model
+
+            # Load tokenizer from base model
+            self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+
+            # Load model with checkpoint
+            self.model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_path,
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+            self.model.eval()
+            logger.info("Model loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading model: {str(e)}")
+            raise
+
+    def setup_rag(self):
+        try:
+            logger.info("Setting up RAG system...")
+            # Load your knowledge base content
+            knowledge_base = {
+                "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
+
+EMERGENCY (999) SCENARIOS:
+1. Cardiovascular:
+- Chest pain/pressure
+- Heart attack symptoms
+- Irregular heartbeat with dizziness
+Response: Immediate 999 call, sit/lie down, chew aspirin if available
+
+2. Respiratory:
+- Severe breathing difficulty
+- Choking
+- Unable to speak full sentences
+Response: 999, sitting position, clear airway
+
+3. Neurological:
+- Stroke symptoms (FAST)
+- Seizures
+- Unconsciousness
+Response: 999, recovery position if unconscious
+
+4. Trauma:
+- Severe bleeding
+- Head injuries with confusion
+- Major burns
+Response: 999, apply direct pressure to bleeding
+
+URGENT CARE (111) SCENARIOS:
+1. Moderate Symptoms:
+- Persistent fever
+- Non-severe infections
+- Minor injuries
+Response: 111 contact, monitor symptoms
+
+2. Minor Emergencies:
+- Small cuts needing stitches
+- Sprains and strains
+- Mild allergic reactions
+Response: 111 or urgent care visit
+
+GP APPOINTMENT SCENARIOS:
+1. Routine Care:
+- Chronic condition review
+- Medication reviews
+- Non-urgent symptoms
+Response: Book routine GP appointment
+
+2. Preventive Care:
+- Vaccinations
+- Health screenings
+- Regular check-ups
+Response: Schedule with GP reception""",
+                "emergency_detection.txt": """Enhanced Emergency Detection Criteria:
+
+IMMEDIATE LIFE THREATS:
+1. Cardiac Symptoms:
+- Chest pain/pressure/tightness
+- Pain spreading to arms/jaw/neck
+- Sweating with nausea
+- Shortness of breath
+
+2. Breathing Problems:
+- Severe shortness of breath
+- Blue lips or face
+- Unable to complete sentences
+- Choking/airway blockage
+
+3. Neurological:
+- FAST (Face, Arms, Speech, Time)
+- Sudden confusion
+- Severe headache
+- Seizures
+- Loss of consciousness
+
+4. Severe Trauma:
+- Heavy bleeding
+- Deep wounds
+- Head injury with confusion
+- Severe burns
+- Broken bones with deformity
+
+5. Anaphylaxis:
+- Sudden swelling
+- Difficulty breathing
+- Rapid onset rash
+- Light-headedness
+
+URGENT BUT NOT IMMEDIATE:
+1. Moderate Symptoms:
+- Persistent fever
+- Dehydration
+- Non-severe infections
+- Minor injuries
+
+2. Worsening Conditions:
+- Increasing pain
+- Progressive symptoms
+- Medication reactions
+
+RESPONSE PROTOCOLS:
+1. For Life Threats:
+- Immediate 999 call
+- Clear first aid instructions
+- Stay on line until help arrives
+
+2. For Urgent Care:
+- 111 contact
+- Monitor for worsening
+- Document symptoms""",
+                "gp_booking.txt": """GP Appointment Booking Templates:
+
+APPOINTMENT TYPES:
+1. Routine Appointments:
+Template: "I need to book a routine appointment for [condition]. My availability is [times/dates]. My GP is Dr. [name] if available."
+
+2. Follow-up Appointments:
+Template: "I need a follow-up appointment regarding [condition] discussed on [date]. My previous appointment was with Dr. [name]."
+
+3. Medication Reviews:
+Template: "I need a medication review for [medication]. My last review was [date]."
+
+BOOKING INFORMATION NEEDED:
+1. Patient Details:
+- Full name
+- Date of birth
+- NHS number (if known)
+- Registered GP practice
+
+2. Appointment Details:
+- Nature of appointment
+- Preferred times/dates
+- Urgency level
+- Special requirements
+
+3. Contact Information:
+- Phone number
+- Alternative contact
+- Preferred contact method
+
+BOOKING PROCESS:
+1. Online Booking:
+- NHS app instructions
+- Practice website guidance
+- System navigation help
+
+2. Phone Booking:
+- Best times to call
+- Required information
+- Queue management tips
+
+3. Special Circumstances:
+- Interpreter needs
+- Accessibility requirements
+- Transport arrangements""",
+                "cultural_sensitivity.txt": """Cultural Sensitivity Guidelines:
+
+CULTURAL AWARENESS:
+1. Religious Considerations:
+- Prayer times
+- Religious observations
+- Dietary restrictions
+- Gender preferences for care
+- Religious festivals/fasting periods
+
+2. Language Support:
+- Interpreter services
+- Multi-language resources
+- Clear communication methods
+- Family involvement preferences
+
+3. Cultural Beliefs:
+- Traditional medicine practices
+- Cultural health beliefs
+- Family decision-making
+- Privacy customs
+
+COMMUNICATION APPROACHES:
+1. Respectful Interaction:
+- Use preferred names/titles
+- Appropriate greetings
+- Non-judgmental responses
+- Active listening
+
+2. Language Usage:
+- Clear, simple terms
+- Avoid medical jargon
+- Confirm understanding
+- Respect silence/pauses
+
+3. Non-verbal Communication:
+- Eye contact customs
+- Personal space
+- Body language awareness
+- Gesture sensitivity
+
+SPECIFIC CONSIDERATIONS:
+1. South Asian Communities:
+- Family involvement
+- Gender sensitivity
+- Traditional medicine
+- Language diversity
+
+2. Middle Eastern Communities:
+- Gender-specific care
+- Religious observations
+- Family hierarchies
+- Privacy concerns
+
+3. African/Caribbean Communities:
+- Traditional healers
+- Community involvement
+- Historical medical mistrust
+- Cultural specific conditions
+
+4. Eastern European Communities:
+- Direct communication
+- Family involvement
+- Medical documentation
+- Language support
+
+INCLUSIVE PRACTICES:
+1. Appointment Scheduling:
+- Religious holidays
+- Prayer times
+- Family availability
+- Interpreter needs
+
+2. Treatment Planning:
+- Cultural preferences
+- Traditional practices
+- Family involvement
+- Dietary requirements
+
+3. Support Services:
+- Community resources
+- Cultural organizations
+- Language services
+- Social support""",
+                "service_boundaries.txt": """Service Limitations and Professional Boundaries:
+
+CLEAR BOUNDARIES:
+1. Medical Advice:
+- No diagnoses
+- No prescriptions
+- No treatment recommendations
+- No medical procedures
+- No second opinions
+
+2. Emergency Services:
+- Clear referral criteria
+- Documented responses
+- Follow-up protocols
+- Handover procedures
+
+3. Information Sharing:
+- Confidentiality limits
+- Data protection
+- Record keeping
+- Information governance
+
+PROFESSIONAL CONDUCT:
+1. Communication:
+- Professional language
+- Emotional boundaries
+- Personal distance
+- Service scope
+
+2. Service Delivery:
+- No financial transactions
+- No personal relationships
+- Clear role definition
+- Professional limits"""
+            }
+
+            os.makedirs("knowledge_base", exist_ok=True)
+
+            # Create and process documents
+            documents = []
+            for filename, content in knowledge_base.items():
+                with open(f"knowledge_base/{filename}", "w") as f:
+                    f.write(content)
+                documents.append(content)
+
+            # Setup embeddings and vector store
+            self.embeddings = HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2"
+            )
+
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=300,
+                chunk_overlap=100
+            )
+
+            texts = text_splitter.split_text("\n\n".join(documents))
+            self.vector_store = FAISS.from_texts(texts, self.embeddings)
+            logger.info("RAG system setup complete")
+
+        except Exception as e:
+            logger.error(f"Error setting up RAG: {str(e)}")
+            raise
+
+    def get_relevant_context(self, query):
+        try:
+            docs = self.vector_store.similarity_search(query, k=3)
+            return "\n".join(doc.page_content for doc in docs)
+        except Exception as e:
+            logger.error(f"Error retrieving context: {str(e)}")
             return ""
-        docs = self.retriever.get_relevant_documents(user_input)
-        return "\n\n".join([doc.page_content for doc in docs])
-
-    def generate_response(self, user_input):
-        context = self.get_relevant_context(user_input)
-        history = "\n".join([
-            f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
-            for turn in self.conversation_history[-3:]
-        ])
-
-        prompt = f"""<start_of_turn>system
-As Pearly, I use the following medical guidelines to help triage patients:
+
+    @torch.inference_mode()
+    def generate_response(self, message: str, history: list) -> str:
+        try:
+            # Get RAG context
+            context = self.get_relevant_context(message)
+
+            # Format conversation history
+            conv_history = "\n".join([
+                f"User: {user}\nAssistant: {assistant}"
+                for user, assistant in history[-3:] # Keep last 3 turns
+            ])
+
+            # Create prompt
+            prompt = f"""<start_of_turn>system
+Using these medical guidelines:

 {context}

-Previous Conversation:
-{history}
+Previous conversation:
+{conv_history}

-Based on these guidelines, I will:
+Guidelines:
 1. Assess symptoms and severity
 2. Ask relevant follow-up questions
 3. Direct to appropriate care (999, 111, or GP)
@@ -433,18 +440,18 @@ Based on these guidelines, I will:
 5. Never diagnose or recommend treatments
 <end_of_turn>
 <start_of_turn>user
-{user_input}
+{message}
 <end_of_turn>
 <start_of_turn>assistant"""

-        inputs = self.tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            max_length=512
-        ).to(self.model.device)
-
-        with torch.no_grad():
+            # Generate response
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512
+            ).to(self.model.device)
+
             outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=256,
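
Note: with truncation=True and max_length=512, the tokenizer trims tokens from the right by default, so an over-long retrieved context can push the user turn and the closing <start_of_turn>assistant cue out of the prompt entirely. A small sketch of left-sided truncation, which keeps the newest tokens instead (the long prompt here is a stand-in):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
tokenizer.truncation_side = "left"  # drop the oldest tokens first
long_prompt = "symptom history " * 500  # stand-in for an over-long triage prompt
inputs = tokenizer(long_prompt, return_tensors="pt", truncation=True, max_length=512)
print(inputs["input_ids"].shape)  # torch.Size([1, 512])
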
@@ -452,21 +459,20 @@ Based on these guidelines, I will:
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.9,
-                repetition_penalty=1.2,
-                pad_token_id=self.tokenizer.pad_token_id
+                repetition_penalty=1.2
             )
-
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        response = response.split("<start_of_turn>assistant")[-1].strip()
-        if "<end_of_turn>" in response:
-            response = response.split("<end_of_turn>")[0].strip()
-
-        self.conversation_history.append({
-            "user": user_input,
-            "assistant": response
-        })
-
-        return response
+
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("<start_of_turn>assistant")[-1].strip()
+            if "<end_of_turn>" in response:
+                response = response.split("<end_of_turn>")[0].strip()
+
+            return response
+
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            return "I apologize, but I encountered an error. Please try again."
+

 def create_demo():
     """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
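
Note: decode returns the whole sequence (prompt plus completion), so the split on "<start_of_turn>assistant" relies on the prompt being echoed back. The extraction logic itself can be sanity-checked offline on a canned string:

# Quick offline check of the turn-extraction logic used above
raw = (
    "<start_of_turn>user\nI have a headache\n<end_of_turn>\n"
    "<start_of_turn>assistant\nHow long has it lasted?\n<end_of_turn>"
)
response = raw.split("<start_of_turn>assistant")[-1].strip()
if "<end_of_turn>" in response:
    response = response.split("<end_of_turn>")[0].strip()
assert response == "How long has it lasted?"
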
@@ -475,9 +481,9 @@ def create_demo():
     @gr.routes.get("/health")
     def health_check():
         return {"status": "healthy"}
-    bot = AdaptiveMedicalBot()
+    bot = PearlyBot() # ✅

-    def chat(message: str, history: List[Dict[str, str]]):
+    def chat(message: str, history: list): # ✅
         try:
             if not message.strip():
                 return history
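
Note: generate_response unpacks each history item as a (user, assistant) pair, which matches Gradio's tuple-style chat history. The rest of the chat handler is cut off in this diff; a hypothetical wiring consistent with that shape:

def chat(message: str, history: list):
    # history arrives as [(user_msg, assistant_msg), ...]
    if not message.strip():
        return history
    reply = bot.generate_response(message, history)
    history.append((message, reply))
    return history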