PearlIsa committed on
Commit 28ef8c8 · verified · 1 Parent(s): f510fb8

Update app.py

Files changed (1)
  1. app.py +152 -1146
app.py CHANGED
@@ -1,621 +1,75 @@
1
- # Standard imports first
 
 
2
  import os
 
3
  import torch
4
  import logging
 
 
5
  from datetime import datetime
6
  from huggingface_hub import login
7
  from dotenv import load_dotenv
8
- from datasets import load_dataset, Dataset
9
  from transformers import (
10
- AutoTokenizer,
11
- AutoModelForCausalLM,
12
- TrainingArguments,
13
- Trainer,
14
  BitsAndBytesConfig
15
  )
16
- from peft import (
17
- LoraConfig,
18
- get_peft_model,
19
- prepare_model_for_kbit_training
20
- )
21
- from tqdm.auto import tqdm
22
 
23
- # Setup logging
24
  logging.basicConfig(level=logging.INFO)
25
- logger = logging.getLogger(__name__)
26
 
27
  class SecretsManager:
28
- """Handles authentication and secrets management"""
29
-
30
  @staticmethod
31
- def setup_credentials():
32
- """Setup all required credentials"""
33
- try:
34
- # Load environment variables
35
- load_dotenv()
36
-
37
- # Get credentials
38
- credentials = {
39
- 'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
40
- 'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
41
- 'HF_TOKEN': os.getenv('HF_TOKEN'),
42
- 'WANDB_KEY': os.getenv('WANDB_KEY')
43
- }
44
-
45
- # Validate credentials
46
- missing_creds = [k for k, v in credentials.items() if not v]
47
- if missing_creds:
48
- logger.warning(f"Missing credentials: {', '.join(missing_creds)}")
49
-
50
- # Setup Hugging Face authentication
51
- if credentials['HF_TOKEN']:
52
- login(token=credentials['HF_TOKEN'])
53
- logger.info("Successfully logged in to Hugging Face")
54
- # Setup Kaggle credentials if available
55
- if credentials['KAGGLE_USERNAME'] and credentials['KAGGLE_KEY']:
56
- os.environ['KAGGLE_USERNAME'] = credentials['KAGGLE_USERNAME']
57
- os.environ['KAGGLE_KEY'] = credentials['KAGGLE_KEY']
58
-
59
- # Setup wandb if available
60
- if credentials['WANDB_KEY']:
61
- os.environ['WANDB_API_KEY'] = credentials['WANDB_KEY']
62
-
63
- return credentials
64
-
65
- except Exception as e:
66
- logger.error(f"Error setting up credentials: {e}")
67
- raise
68
- class ModelTrainer:
69
- """Handles model training pipeline"""
70
-
71
  def __init__(self):
72
- # Set memory optimization environment variables
73
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
74
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
75
-
76
- # Initialize attributes
77
- self.model = None
78
  self.tokenizer = None
79
- self.dataset = None
80
- self.processed_dataset = None
 
81
  self.chunk_size = 300
82
  self.chunk_overlap = 100
83
  self.num_relevant_chunks = 3
84
- self.vector_store = None
85
- self.embeddings = None
86
- self.last_interaction_time = time.time() # Add this
87
- self.interaction_cooldown = 1.0 # Add this
88
-
89
- # Setup GPU preferences
90
- torch.backends.cuda.matmul.allow_tf32 = False
91
- torch.backends.cudnn.allow_tf32 = False
92
-
93
- def prepare_initial_datasets(batch_size=8):
94
- print("Loading datasets with memory-optimized batch processing...")
95
-
96
- def process_medqa_batch(examples):
97
- results = []
98
- inputs = examples['input']
99
- instructions = examples['instruction']
100
- outputs = examples['output']
101
-
102
- for inp, inst, out in zip(inputs, instructions, outputs):
103
- results.append({
104
- "input": f"{inp} {inst}",
105
- "output": out
106
- })
107
- return results
108
-
109
- def process_meddia_batch(examples):
110
- results = []
111
- inputs = examples['input']
112
- outputs = examples['output']
113
-
114
- for inp, out in zip(inputs, outputs):
115
- results.append({
116
- "input": inp,
117
- "output": out
118
- })
119
- return results
120
 
121
- def process_persona_batch(examples):
122
- results = []
123
- personalities = examples['personality']
124
- utterances = examples['utterances']
125
-
126
- for pers, utts in zip(personalities, utterances):
127
- try:
128
- # Process personality list
129
- personality = ' '.join([
130
- p for p in pers
131
- if isinstance(p, str)
132
- ])
133
-
134
- # Process utterances
135
- if utts and len(utts) > 0:
136
- utterance = utts[0]
137
- history = []
138
-
139
- # Process history
140
- if 'history' in utterance and utterance['history']:
141
- history = [
142
- h for h in utterance['history']
143
- if isinstance(h, str)
144
- ]
145
-
146
- history_text = ' '.join(history)
147
-
148
- # Get candidate response
149
- candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
150
-
151
- if personality or history_text:
152
- results.append({
153
- "input": f"{personality} {history_text}".strip(),
154
- "output": candidate
155
- })
156
- except Exception as e:
157
- print(f"Error processing persona batch item: {e}")
158
- continue
159
-
160
- return results
161
- try:
162
- # Load and process each dataset separately
163
- print("Processing MedQA dataset...")
164
- medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
165
- medqa_processed = []
166
-
167
- for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
168
- batch = medqa[i:i + batch_size]
169
- medqa_processed.extend(process_medqa_batch(batch))
170
- if i % (batch_size * 5) == 0:
171
- torch.cuda.empty_cache()
172
-
173
- print("Processing MedDiagnosis dataset...")
174
- meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
175
- meddia_processed = []
176
-
177
- for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
178
- batch = meddia[i:i + batch_size]
179
- meddia_processed.extend(process_meddia_batch(batch))
180
- if i % (batch_size * 5) == 0:
181
- torch.cuda.empty_cache()
182
-
183
- print("Processing Persona-Chat dataset...")
184
- persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
185
- persona_processed = []
186
-
187
- for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
188
- batch = persona[i:i + batch_size]
189
- persona_processed.extend(process_persona_batch(batch))
190
- if i % (batch_size * 5) == 0:
191
- torch.cuda.empty_cache()
192
-
193
- torch.cuda.empty_cache()
194
-
195
- print("Creating final dataset...")
196
- all_processed = persona_processed + medqa_processed + meddia_processed
197
-
198
- valid_data = {
199
- "input": [],
200
- "output": []
201
- }
202
-
203
- for item in all_processed:
204
- if item["input"].strip() and item["output"].strip():
205
- valid_data["input"].append(item["input"])
206
- valid_data["output"].append(item["output"])
207
-
208
- final_dataset = Dataset.from_dict(valid_data)
209
-
210
- print(f"Final dataset size: {len(final_dataset)}")
211
- return final_dataset
212
-
213
- def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
214
- def tokenize_batch(examples):
215
- formatted_texts = []
216
-
217
- for i in range(0, len(examples['input']), batch_size):
218
- sub_batch_inputs = examples['input'][i:i + batch_size]
219
- sub_batch_outputs = examples['output'][i:i + batch_size]
220
-
221
- for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
222
- try:
223
- formatted_text = f"""<start_of_turn>user
224
- {input_text}
225
- <end_of_turn>
226
- <start_of_turn>assistant
227
- {output_text}
228
- <end_of_turn>"""
229
- formatted_texts.append(formatted_text)
230
- except Exception as e:
231
- print(f"Error formatting text: {e}")
232
- continue
233
-
234
- tokenized = tokenizer(
235
- formatted_texts,
236
- padding="max_length",
237
- truncation=True,
238
- max_length=max_length,
239
- return_tensors=None
240
- )
241
-
242
- tokenized["labels"] = tokenized["input_ids"].copy()
243
- return tokenized
244
-
245
- print(f"Tokenizing dataset in small batches (size={batch_size})...")
246
- tokenized_dataset = dataset.map(
247
- tokenize_batch,
248
- batched=True,
249
- batch_size=batch_size,
250
- remove_columns=dataset.column_names,
251
- desc="Tokenizing dataset",
252
- load_from_cache_file=False
253
- )
254
-
255
- return tokenized_dataset
256
-
257
- def setup_rag(self):
258
- """Initialize RAG components"""
259
- try:
260
- logger.info("Setting up RAG system...")
261
-
262
- # Load knowledge base
263
- knowledge_base = self._load_knowledge_base()
264
-
265
- # Setup embeddings
266
- self.embeddings = self._initialize_embeddings()
267
-
268
- # Process texts for vector store
269
- texts = self._split_texts(knowledge_base)
270
-
271
- # Create vector store with metadata
272
- self.vector_store = FAISS.from_texts(
273
- texts,
274
- self.embeddings,
275
- metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
276
- )
277
-
278
- # Validate RAG setup
279
- self._validate_rag_setup()
280
- logger.info("RAG system setup complete")
281
-
282
- except Exception as e:
283
- logger.error(f"Failed to setup RAG: {e}")
284
- raise
285
-
286
- # Load your knowledge base content
287
- def _load_knowledge_base(self):
288
- """Load and validate knowledge base content"""
289
- try:
290
- knowledge_base = {
291
- "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
292
-
293
- EMERGENCY (999) SCENARIOS:
294
- 1. Cardiovascular:
295
- - Chest pain/pressure
296
- - Heart attack symptoms
297
- - Irregular heartbeat with dizziness
298
- Response: Immediate 999 call, sit/lie down, chew aspirin if available
299
-
300
- 2. Respiratory:
301
- - Severe breathing difficulty
302
- - Choking
303
- - Unable to speak full sentences
304
- Response: 999, sitting position, clear airway
305
-
306
- 3. Neurological:
307
- - Stroke symptoms (FAST)
308
- - Seizures
309
- - Unconsciousness
310
- Response: 999, recovery position if unconscious
311
-
312
- 4. Trauma:
313
- - Severe bleeding
314
- - Head injuries with confusion
315
- - Major burns
316
- Response: 999, apply direct pressure to bleeding
317
-
318
- URGENT CARE (111) SCENARIOS:
319
- 1. Moderate Symptoms:
320
- - Persistent fever
321
- - Non-severe infections
322
- - Minor injuries
323
- Response: 111 contact, monitor symptoms
324
-
325
- 2. Minor Emergencies:
326
- - Small cuts needing stitches
327
- - Sprains and strains
328
- - Mild allergic reactions
329
- Response: 111 or urgent care visit
330
-
331
- GP APPOINTMENT SCENARIOS:
332
- 1. Routine Care:
333
- - Chronic condition review
334
- - Medication reviews
335
- - Non-urgent symptoms
336
- Response: Book routine GP appointment
337
-
338
- 2. Preventive Care:
339
- - Vaccinations
340
- - Health screenings
341
- - Regular check-ups
342
- Response: Schedule with GP reception""",
343
- "emergency_detection.txt": """Enhanced Emergency Detection Criteria:
344
-
345
- IMMEDIATE LIFE THREATS:
346
- 1. Cardiac Symptoms:
347
- - Chest pain/pressure/tightness
348
- - Pain spreading to arms/jaw/neck
349
- - Sweating with nausea
350
- - Shortness of breath
351
-
352
- 2. Breathing Problems:
353
- - Severe shortness of breath
354
- - Blue lips or face
355
- - Unable to complete sentences
356
- - Choking/airway blockage
357
-
358
- 3. Neurological:
359
- - FAST (Face, Arms, Speech, Time)
360
- - Sudden confusion
361
- - Severe headache
362
- - Seizures
363
- - Loss of consciousness
364
-
365
- 4. Severe Trauma:
366
- - Heavy bleeding
367
- - Deep wounds
368
- - Head injury with confusion
369
- - Severe burns
370
- - Broken bones with deformity
371
-
372
- 5. Anaphylaxis:
373
- - Sudden swelling
374
- - Difficulty breathing
375
- - Rapid onset rash
376
- - Light-headedness
377
-
378
- URGENT BUT NOT IMMEDIATE:
379
- 1. Moderate Symptoms:
380
- - Persistent fever
381
- - Dehydration
382
- - Non-severe infections
383
- - Minor injuries
384
-
385
- 2. Worsening Conditions:
386
- - Increasing pain
387
- - Progressive symptoms
388
- - Medication reactions
389
-
390
- RESPONSE PROTOCOLS:
391
- 1. For Life Threats:
392
- - Immediate 999 call
393
- - Clear first aid instructions
394
- - Stay on line until help arrives
395
-
396
- 2. For Urgent Care:
397
- - 111 contact
398
- - Monitor for worsening
399
- - Document symptoms""",
400
- "gp_booking.txt": """GP Appointment Booking Templates:
401
-
402
- APPOINTMENT TYPES:
403
- 1. Routine Appointments:
404
- Template: "I need to book a routine appointment for [condition]. My availability is [times/dates]. My GP is Dr. [name] if available."
405
-
406
- 2. Follow-up Appointments:
407
- Template: "I need a follow-up appointment regarding [condition] discussed on [date]. My previous appointment was with Dr. [name]."
408
-
409
- 3. Medication Reviews:
410
- Template: "I need a medication review for [medication]. My last review was [date]."
411
-
412
- BOOKING INFORMATION NEEDED:
413
- 1. Patient Details:
414
- - Full name
415
- - Date of birth
416
- - NHS number (if known)
417
- - Registered GP practice
418
-
419
- 2. Appointment Details:
420
- - Nature of appointment
421
- - Preferred times/dates
422
- - Urgency level
423
- - Special requirements
424
-
425
- 3. Contact Information:
426
- - Phone number
427
- - Alternative contact
428
- - Preferred contact method
429
-
430
- BOOKING PROCESS:
431
- 1. Online Booking:
432
- - NHS app instructions
433
- - Practice website guidance
434
- - System navigation help
435
-
436
- 2. Phone Booking:
437
- - Best times to call
438
- - Required information
439
- - Queue management tips
440
-
441
- 3. Special Circumstances:
442
- - Interpreter needs
443
- - Accessibility requirements
444
- - Transport arrangements""",
445
- "cultural_sensitivity.txt": """Cultural Sensitivity Guidelines:
446
-
447
- CULTURAL AWARENESS:
448
- 1. Religious Considerations:
449
- - Prayer times
450
- - Religious observations
451
- - Dietary restrictions
452
- - Gender preferences for care
453
- - Religious festivals/fasting periods
454
-
455
- 2. Language Support:
456
- - Interpreter services
457
- - Multi-language resources
458
- - Clear communication methods
459
- - Family involvement preferences
460
-
461
- 3. Cultural Beliefs:
462
- - Traditional medicine practices
463
- - Cultural health beliefs
464
- - Family decision-making
465
- - Privacy customs
466
-
467
- COMMUNICATION APPROACHES:
468
- 1. Respectful Interaction:
469
- - Use preferred names/titles
470
- - Appropriate greetings
471
- - Non-judgmental responses
472
- - Active listening
473
-
474
- 2. Language Usage:
475
- - Clear, simple terms
476
- - Avoid medical jargon
477
- - Confirm understanding
478
- - Respect silence/pauses
479
-
480
- 3. Non-verbal Communication:
481
- - Eye contact customs
482
- - Personal space
483
- - Body language awareness
484
- - Gesture sensitivity
485
-
486
- SPECIFIC CONSIDERATIONS:
487
- 1. South Asian Communities:
488
- - Family involvement
489
- - Gender sensitivity
490
- - Traditional medicine
491
- - Language diversity
492
-
493
- 2. Middle Eastern Communities:
494
- - Gender-specific care
495
- - Religious observations
496
- - Family hierarchies
497
- - Privacy concerns
498
-
499
- 3. African/Caribbean Communities:
500
- - Traditional healers
501
- - Community involvement
502
- - Historical medical mistrust
503
- - Cultural specific conditions
504
-
505
- 4. Eastern European Communities:
506
- - Direct communication
507
- - Family involvement
508
- - Medical documentation
509
- - Language support
510
-
511
- INCLUSIVE PRACTICES:
512
- 1. Appointment Scheduling:
513
- - Religious holidays
514
- - Prayer times
515
- - Family availability
516
- - Interpreter needs
517
-
518
- 2. Treatment Planning:
519
- - Cultural preferences
520
- - Traditional practices
521
- - Family involvement
522
- - Dietary requirements
523
-
524
- 3. Support Services:
525
- - Community resources
526
- - Cultural organizations
527
- - Language services
528
- - Social support""",
529
- "service_boundaries.txt": """Service Limitations and Professional Boundaries:
530
-
531
- CLEAR BOUNDARIES:
532
- 1. Medical Advice:
533
- - No diagnoses
534
- - No prescriptions
535
- - No treatment recommendations
536
- - No medical procedures
537
- - No second opinions
538
-
539
- 2. Emergency Services:
540
- - Clear referral criteria
541
- - Documented responses
542
- - Follow-up protocols
543
- - Handover procedures
544
-
545
- 3. Information Sharing:
546
- - Confidentiality limits
547
- - Data protection
548
- - Record keeping
549
- - Information governance
550
-
551
- PROFESSIONAL CONDUCT:
552
- 1. Communication:
553
- - Professional language
554
- - Emotional boundaries
555
- - Personal distance
556
- - Service scope
557
-
558
- 2. Service Delivery:
559
- - No financial transactions
560
- - No personal relationships
561
- - Clear role definition
562
- - Professional limits"""
563
- }
564
-
565
- # Create knowledge base directory
566
- os.makedirs("knowledge_base", exist_ok=True)
567
-
568
- # Write files and process documents
569
- documents = []
570
- for filename, content in knowledge_base.items():
571
- filepath = os.path.join("knowledge_base", filename)
572
- with open(filepath, "w", encoding="utf-8") as f:
573
- f.write(content)
574
- documents.append(content)
575
- logger.info(f"Written knowledge base file: {filename}")
576
-
577
- return knowledge_base
578
-
579
- except Exception as e:
580
- logger.error(f"Error loading knowledge base: {str(e)}")
581
- raise
582
-
583
- def _validate_rag_setup(self):
584
- """Validate RAG system setup"""
585
- try:
586
- # Verify embeddings are working
587
- test_text = "This is a test embedding"
588
- test_embedding = self.embeddings.encode(test_text)
589
- assert len(test_embedding) > 0
590
-
591
- # Verify vector store is operational
592
- test_results = self.vector_store.similarity_search(test_text, k=1)
593
- assert len(test_results) > 0
594
-
595
- logger.info("RAG system validation successful")
596
- return True
597
- except Exception as e:
598
- logger.error(f"RAG system validation failed: {str(e)}")
599
- raise
600
-
601
-
602
-
603
-
604
-
605
-
606
-
607
- def setup_model_and_tokenizer(model_name="google/gemma-2b"):
608
- tokenizer = AutoTokenizer.from_pretrained(model_name)
609
- tokenizer.pad_token = tokenizer.eos_token
610
-
611
- from transformers import BitsAndBytesConfig
612
-
613
  bnb_config = BitsAndBytesConfig(
614
- load_in_8bit=True,
615
- bnb_8bit_compute_dtype=torch.float16,
616
- llm_int8_enable_fp32_cpu_offload=True
 
617
  )
618
-
 
619
  model = AutoModelForCausalLM.from_pretrained(
620
  model_name,
621
  device_map="auto",
@@ -623,9 +77,7 @@ class ModelTrainer:
623
  torch_dtype=torch.float16,
624
  low_cpu_mem_usage=True
625
  )
626
-
627
  model = prepare_model_for_kbit_training(model)
628
-
629
  lora_config = LoraConfig(
630
  r=4,
631
  lora_alpha=16,
@@ -634,203 +86,106 @@ class ModelTrainer:
634
  bias="none",
635
  task_type="CAUSAL_LM"
636
  )
637
-
638
- model = get_peft_model(model, lora_config)
639
- model.print_trainable_parameters()
640
-
641
- return model, tokenizer
642
-
643
- def setup_training_arguments(output_dir="./pearly_fine_tuned"):
644
- return TrainingArguments(
645
- output_dir=output_dir,
646
- num_train_epochs=1,
647
- per_device_train_batch_size=1,
648
- gradient_accumulation_steps=16,
649
- warmup_steps=50,
650
- logging_steps=10,
651
- save_steps=200,
652
- learning_rate=2e-4,
653
- fp16=True,
654
- gradient_checkpointing=True,
655
- gradient_checkpointing_kwargs={"use_reentrant": False},
656
- optim="adamw_8bit",
657
- max_grad_norm=0.3,
658
- weight_decay=0.001,
659
- logging_dir="./logs",
660
- save_total_limit=2,
661
- remove_unused_columns=False,
662
- dataloader_pin_memory=False,
663
- max_steps=500,
664
- report_to=["none"],
665
- )
666
-
667
- def train(self):
668
- """Main training pipeline with RAG integration"""
669
- try:
670
- logger.info("Starting training pipeline")
671
-
672
- # Clear GPU memory
673
- torch.cuda.empty_cache()
674
- if torch.cuda.is_available():
675
- torch.cuda.reset_peak_memory_stats()
676
-
677
- # Setup model, tokenizer, and RAG
678
- logger.info("Setting up model components...")
679
- self.model, self.tokenizer = self.setup_model_and_tokenizer()
680
- self.setup_rag()
681
-
682
- # Prepare and process datasets
683
- logger.info("Preparing datasets...")
684
- self.dataset = self.prepare_initial_datasets(batch_size=4)
685
- self.processed_dataset = self.prepare_dataset(
686
- self.dataset,
687
- self.tokenizer,
688
- max_length=256,
689
- batch_size=2
690
- )
691
-
692
- # Train model
693
- logger.info("Starting training...")
694
- training_args = self.setup_training_arguments()
695
- trainer = Trainer(
696
- model=self.model,
697
- args=training_args,
698
- train_dataset=self.processed_dataset,
699
- tokenizer=self.tokenizer
700
- )
701
- trainer.train()
702
-
703
- # Save and push to hub
704
- logger.info("Saving model...")
705
- trainer.save_model()
706
- if os.getenv('HF_TOKEN'):
707
- trainer.push_to_hub(
708
- "Pearilsa/pearly_med_triage_chatbot_kagglex",
709
- private=True
710
- )
711
-
712
- logger.info("Training completed successfully!")
713
-
714
- except Exception as e:
715
- logger.error(f"Training failed: {e}")
716
- raise
717
- finally:
718
- torch.cuda.empty_cache()
719
-
720
- if __name__ == "__main__":
721
- # Initialize trainer
722
- trainer = ModelTrainer()
723
-
724
- # Train model
725
- trainer.train()
726
-
727
- def _get_enhanced_context(self, query: str) -> str:
728
- """Get relevant context with scores"""
729
- try:
730
- # Get documents with similarity scores
731
- docs_and_scores = self.vector_store.similarity_search_with_score(
732
- query,
733
- k=self.num_relevant_chunks
734
- )
735
-
736
- # Filter and format relevant contexts
737
- relevant_contexts = []
738
- for doc, score in docs_and_scores:
739
- if score < 0.8: # Lower score means more relevant
740
- source = doc.metadata.get('source', 'Unknown')
741
- relevant_contexts.append(
742
- f"[Source: {source}]\n{doc.page_content}"
743
- )
744
-
745
- return "\n\n".join(relevant_contexts) if relevant_contexts else ""
746
-
747
- except Exception as e:
748
- logger.error(f"Error retrieving enhanced context: {e}")
749
- return ""
750
-
751
- def _initialize_embeddings(self):
752
- try:
753
- return HuggingFaceEmbeddings(
754
  model_name="sentence-transformers/all-MiniLM-L6-v2",
755
- cache_folder="./embeddings_cache" # Added caching
756
  )
757
- except Exception as e:
758
- logger.error(f"Failed to initialize embeddings: {str(e)}")
759
- raise
760
 
761
- def _split_texts(self, knowledge_base):
762
  splitter = RecursiveCharacterTextSplitter(
763
  chunk_size=self.chunk_size,
764
  chunk_overlap=self.chunk_overlap,
765
  length_function=len,
766
  add_start_index=True
767
  )
768
-
769
- all_texts = []
770
- for content in knowledge_base.values():
771
- texts = splitter.split_text(content)
772
- all_texts.extend(texts)
773
- return all_texts
774
 
775
- def get_relevant_context(self, query):
776
  try:
777
- docs = self.vector_store.similarity_search(query, k=3)
778
- return "\n".join(doc.page_content for doc in docs)
 
 
779
  except Exception as e:
780
- logger.error(f"Error retrieving context: {str(e)}")
781
  return ""
782
 
783
  @torch.inference_mode()
784
- def generate_response(self, message: str, history: list) -> str:
785
- """Generate response using both fine-tuned model and RAG"""
786
  try:
787
- # Rate limiting and memory management
788
- current_time = time.time()
789
- if current_time - self.last_interaction_time < self.interaction_cooldown:
790
  time.sleep(self.interaction_cooldown)
791
- torch.cuda.empty_cache()
792
-
793
- # Get enhanced context from RAG
 
 
794
  context = self._get_enhanced_context(message)
795
-
796
- # Format conversation history
797
  conv_history = "\n".join([
798
- f"User: {turn['input']}\nAssistant: {turn['output']}"
799
- for turn in history[-3:] # Keep last 3 turns
800
  ])
801
-
802
- # Create enhanced prompt with RAG context
803
- prompt = f"""<start_of_turn>system
804
- Using these medical guidelines:
805
 
 
 
806
  {context}
807
-
808
- Previous conversation:
809
  {conv_history}
810
-
811
- Guidelines:
812
- 1. Assess symptoms and severity based on both your training and the provided guidelines
813
- 2. Ask relevant follow-up questions if needed
814
- 3. Direct to appropriate care (999, 111, or GP) according to symptom severity
815
- 4. Show empathy and cultural sensitivity
816
- 5. Never diagnose or recommend treatments
817
  <end_of_turn>
818
  <start_of_turn>user
819
  {message}
820
  <end_of_turn>
821
  <start_of_turn>assistant"""
822
 
823
- # Generate response with model
824
- inputs = self.tokenizer(
825
- prompt,
826
- return_tensors="pt",
827
- truncation=True,
828
- max_length=512
829
- ).to(self.model.device)
830
-
831
  outputs = self.model.generate(
832
  **inputs,
833
- max_new_tokens=256,
834
  min_new_tokens=20,
835
  do_sample=True,
836
  temperature=0.7,
@@ -838,395 +193,46 @@ Guidelines:
838
  repetition_penalty=1.2,
839
  no_repeat_ngram_size=3
840
  )
841
-
842
- # Process response
843
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
844
- response = response.split("<start_of_turn>assistant")[-1].strip()
845
- if "<end_of_turn>" in response:
846
- response = response.split("<end_of_turn>")[0].strip()
847
-
848
  self.last_interaction_time = time.time()
 
849
  return response
850
-
851
- except Exception as e:
852
- logger.error(f"Error generating response: {e}")
853
- return "I apologize, but I encountered an error. Please try again."
854
-
855
- def handle_feedback(self, message: str, response: str, feedback: int):
856
- """Handle user feedback for responses"""
857
- try:
858
- timestamp = datetime.now().isoformat()
859
- feedback_data = {
860
- "message": message,
861
- "response": response,
862
- "feedback": feedback,
863
- "timestamp": timestamp
864
- }
865
-
866
- # Log feedback
867
- logger.info(f"Feedback received: {feedback_data}")
868
-
869
- # Here you could:
870
- # 1. Store feedback in a database
871
- # 2. Send to monitoring system
872
- # 3. Use for model improvements
873
-
874
- return True
875
- except Exception as e:
876
- logger.error(f"Error handling feedback: {e}")
877
- return False
878
 
879
- def __del__(self):
880
- """Cleanup resources"""
881
- try:
882
- if hasattr(self, 'model'):
883
- del self.model
884
- ModelManager.clear_gpu_memory()
885
  except Exception as e:
886
- logger.error(f"Error in cleanup: {e}")
887
-
888
 
 
 
 
889
  def create_demo():
890
- try:
891
- # Initialize bot
892
- bot = PearlyBot()
893
-
894
- def chat(message: str, history: list):
895
- """Handle chat interactions"""
896
- try:
897
- if not message.strip():
898
- return history
899
-
900
- response = bot.generate_response(message, history)
901
- history.append({
902
- "role": "user",
903
- "content": message
904
- })
905
- history.append({
906
- "role": "assistant",
907
- "content": response
908
- })
909
- return history
910
- except Exception as e:
911
- logger.error(f"Chat error: {e}")
912
- return history + [{
913
- "role": "assistant",
914
- "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
915
- }]
916
-
917
- def process_feedback(positive: bool, comment: str, history: list):
918
- try:
919
- if not history or len(history) < 2:
920
- return gr.update(value="")
921
-
922
- last_user_msg = history[-2]["content"] if isinstance(history[-2], dict) else history[-2][0]
923
- last_bot_msg = history[-1]["content"] if isinstance(history[-1], dict) else history[-1][1]
924
-
925
- bot.handle_feedback(
926
- message=last_user_msg,
927
- response=last_bot_msg,
928
- feedback=1 if positive else -1
929
- )
930
- return gr.update(value="")
931
- except Exception as e:
932
- logger.error(f"Error processing feedback: {e}")
933
- return gr.update(value="")
934
-
935
- # Create Gradio interface
936
- with gr.Blocks(theme=gr.themes.Soft(...)) as demo:
937
- # 1. First, create all UI elements
938
- # CSS styles
939
- gr.HTML("""<style>...""")
940
-
941
- # Emergency Banner
942
- gr.HTML("""<div class="emergency-banner">...""")
943
-
944
- # Header
945
- with gr.Row(elem_classes="header"):
946
- gr.Markdown("""# GP Medical Triage Assistant...""")
947
-
948
- # Features Grid
949
- gr.HTML("""<div class="features-grid">...""")
950
-
951
- # Chat Interface
952
- with gr.Row():
953
- with gr.Column(scale=4):
954
- chatbot = gr.Chatbot(...)
955
- with gr.Row():
956
- msg = gr.Textbox(...)
957
- submit = gr.Button(...)
958
-
959
- with gr.Column(scale=1):
960
- # Quick Actions
961
- emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
962
- nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
963
- booking_btn = gr.Button("📅 GP Booking", variant="secondary")
964
-
965
- # Controls
966
- clear = gr.Button("🗑️ Clear Chat")
967
-
968
- # Feedback
969
- with gr.Row():
970
- feedback_positive = gr.Button("👍", elem_id="thumb-up")
971
- feedback_negative = gr.Button("👎", elem_id="thumb-down")
972
- feedback_text = gr.Textbox(...)
973
- feedback_submit = gr.Button(...)
974
-
975
- # Examples and Guide
976
- with gr.Accordion("Example Messages", open=False):
977
- gr.Examples([...])
978
-
979
- with gr.Accordion("NHS Services Guide", open=False):
980
- gr.Markdown("""...""")
981
-
982
-
983
- # Create enhanced Gradio interface
984
- with gr.Blocks(theme=gr.themes.Soft(
985
- primary_hue="blue",
986
- secondary_hue="indigo",
987
- neutral_hue="slate",
988
- font=gr.themes.GoogleFont("Inter")
989
- )) as demo:
990
- # Custom CSS for enhanced styling
991
- gr.HTML("""
992
- <style>
993
- .container { max-width: 900px; margin: auto; }
994
- .header { text-align: center; padding: 20px; }
995
- .emergency-banner {
996
- background-color: #ff4444;
997
- color: white;
998
- padding: 10px;
999
- text-align: center;
1000
- font-weight: bold;
1001
- margin-bottom: 20px;
1002
- }
1003
- .feature-card {
1004
- padding: 15px;
1005
- border-radius: 10px;
1006
- text-align: center;
1007
- transition: transform 0.2s;
1008
- color: white;
1009
- font-weight: bold;
1010
- }
1011
- .feature-card:nth-child(1) { background: linear-gradient(135deg, #2193b0, #6dd5ed); }
1012
- .feature-card:nth-child(2) { background: linear-gradient(135deg, #834d9b, #d04ed6); }
1013
- .feature-card:nth-child(3) { background: linear-gradient(135deg, #ff4b1f, #ff9068); }
1014
- .feature-card:nth-child(4) { background: linear-gradient(135deg, #38ef7d, #11998e); }
1015
- .feature-card:hover {
1016
- transform: translateY(-5px);
1017
- box-shadow: 0 5px 15px rgba(0,0,0,0.2);
1018
- }
1019
- .feature-card span.emoji {
1020
- font-size: 2em;
1021
- display: block;
1022
- margin-bottom: 10px;
1023
- }
1024
- .message-textbox textarea { resize: none; }
1025
- #thumb-up, #thumb-down {
1026
- min-width: 60px;
1027
- padding: 8px;
1028
- margin: 5px;
1029
- }
1030
- .chatbot-message {
1031
- padding: 12px;
1032
- margin: 8px 0;
1033
- border-radius: 8px;
1034
- }
1035
- .user-message { background-color: #e3f2fd; }
1036
- .assistant-message { background-color: #f5f5f5; }
1037
- .feedback-section {
1038
- margin-top: 20px;
1039
- padding: 15px;
1040
- border-radius: 8px;
1041
- background-color: #f8f9fa;
1042
- }
1043
- </style>
1044
- """)
1045
-
1046
- # Emergency Banner
1047
- gr.HTML("""
1048
- <div class="emergency-banner">
1049
- 🚨 For medical emergencies, always call 999 immediately 🚨
1050
- </div>
1051
- """)
1052
-
1053
- # Header Section
1054
- with gr.Row(elem_classes="header"):
1055
- gr.Markdown("""
1056
- # GP Medical Triage Assistant - Pearly
1057
- Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
1058
- """)
1059
-
1060
- # Main Features Grid
1061
- gr.HTML("""
1062
- <div class="features-grid">
1063
- <div class="feature-card">
1064
- <span class="emoji">🏥</span>
1065
- <div>GP Appointments</div>
1066
- </div>
1067
- <div class="feature-card">
1068
- <span class="emoji">🔍</span>
1069
- <div>Symptom Assessment</div>
1070
- </div>
1071
- <div class="feature-card">
1072
- <span class="emoji">⚡</span>
1073
- <div>Urgent Care Guide</div>
1074
- </div>
1075
- <div class="feature-card">
1076
- <span class="emoji">💊</span>
1077
- <div>Medical Advice</div>
1078
- </div>
1079
- </div>
1080
- """)
1081
-
1082
- # Chat Interface
1083
- with gr.Row():
1084
- with gr.Column(scale=4):
1085
- chatbot = gr.Chatbot(
1086
- value=[{
1087
- "role": "assistant",
1088
- "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
1089
- }],
1090
- height=500,
1091
- elem_id="chatbot",
1092
- type="messages",
1093
- show_label=False
1094
- )
1095
-
1096
- with gr.Row():
1097
- msg = gr.Textbox(
1098
- label="Your message",
1099
- placeholder="Type your message here...",
1100
- lines=2,
1101
- scale=4,
1102
- autofocus=True,
1103
- submit_on_enter=True
1104
- )
1105
- submit = gr.Button("Send", variant="primary", scale=1)
1106
-
1107
- with gr.Column(scale=1):
1108
- # Quick Actions Panel
1109
- gr.Markdown("### Quick Actions")
1110
- emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
1111
- nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
1112
- booking_btn = gr.Button("📅 GP Booking", variant="secondary")
1113
-
1114
- # Controls and Feedback
1115
- gr.Markdown("### Controls")
1116
- clear = gr.Button("🗑️ Clear Chat")
1117
-
1118
- gr.Markdown("### Feedback")
1119
- with gr.Row():
1120
- feedback_positive = gr.Button("👍", elem_id="thumb-up")
1121
- feedback_negative = gr.Button("👎", elem_id="thumb-down")
1122
-
1123
- feedback_text = gr.Textbox(
1124
- label="Additional comments",
1125
- placeholder="Tell us more...",
1126
- lines=2,
1127
- visible=True
1128
- )
1129
- feedback_submit = gr.Button("Submit Feedback", visible=True)
1130
-
1131
- # Examples and Information
1132
- with gr.Accordion("Example Messages", open=False):
1133
- gr.Examples([
1134
- ["I've been having severe headaches for the past week"],
1135
- ["I need to book a routine checkup"],
1136
- ["I'm feeling very anxious lately and need help"],
1137
- ["My child has had a fever for 2 days"],
1138
- ["I need information about COVID-19 testing"]
1139
- ], inputs=msg)
1140
-
1141
- with gr.Accordion("NHS Services Guide", open=False):
1142
- gr.Markdown("""
1143
- ### Emergency Services (999)
1144
- - Life-threatening emergencies
1145
- - Severe injuries
1146
- - Suspected heart attack or stroke
1147
-
1148
- ### NHS 111
1149
- - Urgent but non-emergency situations
1150
- - Medical advice needed
1151
- - Unsure where to go
1152
-
1153
- ### GP Services
1154
- - Routine check-ups
1155
- - Non-urgent medical issues
1156
- - Prescription renewals
1157
- """)
1158
-
1159
-
1160
-
1161
- def show_emergency_info():
1162
- return """🚨 Emergency Services (999)
1163
- - For life-threatening emergencies
1164
- - Severe chest pain
1165
- - Difficulty breathing
1166
- - Severe bleeding
1167
- - Loss of consciousness
1168
- """
1169
-
1170
- def show_nhs_111_info():
1171
- return """📞 NHS 111 Service
1172
- - Available 24/7
1173
- - Medical advice
1174
- - Local service information
1175
- - Urgent care guidance
1176
- """
1177
-
1178
- def show_booking_info():
1179
- return """📅 GP Booking Options
1180
- - Online booking
1181
- - Phone booking
1182
- - Routine appointments
1183
- - Urgent appointments
1184
- """
1185
-
1186
- # Chat handlers
1187
- msg.submit(chat, [msg, chatbot], [chatbot]).then(
1188
- lambda: gr.update(value=""), None, [msg]
1189
- )
1190
-
1191
- submit.click(chat, [msg, chatbot], [chatbot]).then(
1192
- lambda: gr.update(value=""), None, [msg]
1193
- )
1194
-
1195
- # Quick action handlers
1196
- emergency_btn.click(lambda: show_emergency_info(), outputs=[msg])
1197
- nhs_111_btn.click(lambda: show_nhs_111_info(), outputs=[msg])
1198
- booking_btn.click(lambda: show_booking_info(), outputs=[msg])
1199
-
1200
- # Feedback handlers
1201
- feedback_positive.click(
1202
- lambda h: process_feedback(True, feedback_text.value, h),
1203
- inputs=[chatbot],
1204
- outputs=[feedback_text]
1205
- )
1206
-
1207
- feedback_negative.click(
1208
- lambda h: process_feedback(False, feedback_text.value, h),
1209
- inputs=[chatbot],
1210
- outputs=[feedback_text]
1211
- )
1212
-
1213
- # Clear chat
1214
- clear.click(lambda: None, None, chatbot)
1215
-
1216
- # 3. Finally, add the queue
1217
- demo.queue(concurrency_count=1, max_size=10)
1218
-
1219
- return demo
1220
-
1221
- except Exception as e:
1222
- logger.error(f"Error creating demo: {e}")
1223
- raise
1224
-
1225
  if __name__ == "__main__":
1226
- # Initialize logging and load env vars
1227
- logging.basicConfig(level=logging.INFO)
1228
- load_dotenv()
1229
-
1230
- # Create and launch demo
1231
  demo = create_demo()
1232
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # Optimized Triage Chatbot Code for Hugging Face Space (NVIDIA T4 GPU)
2
+ # Covers: Memory optimizations, 4-bit quantization, lazy loading, FAISS caching, faster inference, safe Gradio UI
3
+
4
  import os
5
+ import time
6
  import torch
7
  import logging
8
+ import gradio as gr
9
+ import psutil
10
  from datetime import datetime
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
+ from datasets import load_dataset, load_from_disk, Dataset
14
  from transformers import (
15
+ AutoTokenizer, AutoModelForCausalLM,
16
+ TrainingArguments, Trainer,
 
 
17
  BitsAndBytesConfig
18
  )
19
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
20
+ from langchain.vectorstores import FAISS
21
+ from langchain.embeddings import HuggingFaceEmbeddings
22
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
23
 
24
+ # Logging Setup
25
  logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger("PearlyBot")
27
 
28
+ # ===========================
29
+ # 🧠 SECRETS MANAGER
30
+ # ===========================
31
  class SecretsManager:
 
 
32
  @staticmethod
33
+ def setup():
34
+ load_dotenv()
35
+ creds = {
36
+ 'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
37
+ 'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
38
+ 'HF_TOKEN': os.getenv('HF_TOKEN'),
39
+ 'WANDB_KEY': os.getenv('WANDB_KEY')
40
+ }
41
+ if creds['HF_TOKEN']:
42
+ login(token=creds['HF_TOKEN'])
43
+ logger.info("🔐 Logged in to Hugging Face")
44
+ return creds
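+ # Note: only HF_TOKEN is acted on here; the Kaggle and W&B keys are collected
+ # but left unused in this inference-only Space.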
45
+
46
+ # ===========================
47
+ # 🧠 CHATBOT CLASS
48
+ # ===========================
49
+ class PearlyBot:
50
  def __init__(self):
51
  self.tokenizer = None
52
+ self.model = None
53
+ self.embeddings = None
54
+ self.vector_store = None
55
  self.chunk_size = 300
56
  self.chunk_overlap = 100
57
  self.num_relevant_chunks = 3
58
+ self.last_interaction_time = time.time()
59
+ self.interaction_cooldown = 1.0
60
 
61
+ def setup_model_and_tokenizer(self, model_name="google/gemma-7b"):
62
+ if self.model is not None:
63
+ return
64
+ logger.info("🚀 Loading model & tokenizer")
65
  bnb_config = BitsAndBytesConfig(
66
+ load_in_4bit=True,
67
+ bnb_4bit_use_double_quant=True,
68
+ bnb_4bit_quant_type="nf4",
69
+ bnb_4bit_compute_dtype=torch.float16
70
  )
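+ # Note: NF4 4-bit weights with double quantization take roughly a quarter of
+ # the memory of fp16, which is what lets a 7B model sit comfortably on a
+ # 16 GB T4; compute still runs in fp16 for speed.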
71
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
72
+ self.tokenizer.pad_token = self.tokenizer.eos_token
73
  model = AutoModelForCausalLM.from_pretrained(
74
  model_name,
75
  device_map="auto",
 
77
  torch_dtype=torch.float16,
78
  low_cpu_mem_usage=True
79
  )
 
80
  model = prepare_model_for_kbit_training(model)
 
81
  lora_config = LoraConfig(
82
  r=4,
83
  lora_alpha=16,
 
86
  bias="none",
87
  task_type="CAUSAL_LM"
88
  )
89
+ self.model = get_peft_model(model, lora_config)
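+ # LoRA wraps the quantized base with small rank-4 adapters; only the adapter
+ # weights are trainable, the 4-bit base model stays frozen.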
90
+ # No explicit .to() here: device_map="auto" already placed the model, and
+ # bitsandbytes-quantized models do not support .to(device).
91
+ logger.info("✅ Model & tokenizer ready")
92
+
93
+ def setup_embeddings(self):
94
+ if self.embeddings is None:
95
+ logger.info("📌 Loading sentence-transformer embeddings")
96
+ self.embeddings = HuggingFaceEmbeddings(
 
97
  model_name="sentence-transformers/all-MiniLM-L6-v2",
98
+ cache_folder="./embeddings_cache"
99
  )
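+ # cache_folder keeps the downloaded MiniLM weights on disk, so restarts skip
+ # the network fetch.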
100
 
101
+ def load_faiss_index(self):
102
+ logger.info("📁 Loading FAISS index")
103
+ if os.path.exists("index_store/index.faiss"):
104
+ self.vector_store = FAISS.load_local("index_store", self.embeddings)
105
+ else:
106
+ self.build_faiss_index()
107
+
108
+ def build_faiss_index(self):
109
+ logger.info("🔧 Building FAISS index from knowledge base")
110
+ knowledge_base = self._load_knowledge_base()
111
+ self.setup_embeddings()
112
+ texts = self._split_texts(knowledge_base)
113
+ self.vector_store = FAISS.from_texts(
114
+ texts,
115
+ self.embeddings,
116
+ metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
117
+ )
118
+ self.vector_store.save_local("index_store")
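+ # The index is persisted to disk, so load_faiss_index() can skip re-embedding
+ # the knowledge base on subsequent launches.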
119
+
120
+ def _load_knowledge_base(self):
121
+ kb_dir = "knowledge_base"
122
+ os.makedirs(kb_dir, exist_ok=True)
123
+ kb_files = {
124
+ "triage.txt": "Severe chest pain? Call 999. Persistent cough? Book GP.",
125
+ "emergency.txt": "Unconscious? 999. Breathing issues? 999.",
126
+ "cultural.txt": "Respect prayer times, language needs, traditional remedies.",
127
+ "gp_booking.txt": "I need to book a GP for routine care next week."
128
+ }
129
+ for file, content in kb_files.items():
130
+ with open(os.path.join(kb_dir, file), 'w') as f:
131
+ f.write(content)
132
+ return kb_files
133
+
134
+ def _split_texts(self, kb):
135
  splitter = RecursiveCharacterTextSplitter(
136
  chunk_size=self.chunk_size,
137
  chunk_overlap=self.chunk_overlap,
138
  length_function=len,
139
  add_start_index=True
140
  )
141
+ texts = []
142
+ for text in kb.values():
143
+ chunks = splitter.split_text(text)
144
+ texts.extend(chunks)
145
+ return texts
 
146
 
147
+ def _get_enhanced_context(self, query):
148
  try:
149
+ results = self.vector_store.similarity_search_with_score(query, k=self.num_relevant_chunks)
150
+ context = [f"[Source: {doc.metadata.get('source', 'unknown')}]:\n{doc.page_content}"
151
+ for doc, score in results if score < 0.8]
152
+ return "\n\n".join(context)
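+ # FAISS returns L2 distances, so smaller scores mean closer matches; the
+ # `score < 0.8` filter above drops weakly related chunks. The 0.8 cutoff is
+ # a heuristic, not a calibrated threshold.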
153
  except Exception as e:
154
+ logger.error(f"Context error: {e}")
155
  return ""
156
 
157
  @torch.inference_mode()
158
+ def generate_response(self, message, history):
 
159
  try:
160
+ # Throttle
161
+ if time.time() - self.last_interaction_time < self.interaction_cooldown:
 
162
  time.sleep(self.interaction_cooldown)
163
+
164
+ self.setup_model_and_tokenizer()
165
+ self.setup_embeddings()
166
+ self.load_faiss_index()
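+ # First request pays the one-off load cost; later requests reuse the cached
+ # model, embeddings, and index.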
167
+
168
  context = self._get_enhanced_context(message)
 
 
169
  conv_history = "\n".join([
170
+ f"User: {turn['content']}" if turn['role'] == 'user' else f"Assistant: {turn['content']}"
171
+ for turn in history[-3:]
172
  ])
173
 
174
+ prompt = f"""<start_of_turn>system
175
+ Context:
176
  {context}
177
+ Conversation:
 
178
  {conv_history}
179
  <end_of_turn>
180
  <start_of_turn>user
181
  {message}
182
  <end_of_turn>
183
  <start_of_turn>assistant"""
184
 
185
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.model.device)
186
  outputs = self.model.generate(
187
  **inputs,
188
+ max_new_tokens=128,
189
  min_new_tokens=20,
190
  do_sample=True,
191
  temperature=0.7,
 
193
  repetition_penalty=1.2,
194
  no_repeat_ngram_size=3
195
  )
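+ # Sampling at temperature 0.7 with repetition penalties keeps replies varied
+ # but on-topic; max_new_tokens=128 caps per-request latency on the T4.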
 
 
196
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
197
+ response = response.split("<start_of_turn>assistant")[-1].strip().split("<end_of_turn>")[0].strip()
198
  self.last_interaction_time = time.time()
199
+ logger.info(f"💬 Memory use: {psutil.virtual_memory().percent}%")
200
  return response
201
 
202
  except Exception as e:
203
+ logger.error(f"Generation error: {e}")
204
+ return "I encountered an error. Please try again."
205
 
206
+ # ===========================
207
+ # 💬 GRADIO UI
208
+ # ===========================
209
  def create_demo():
210
+ bot = PearlyBot()
211
+
212
+ def chat(message, history):
213
+ if not message.strip():
214
+ return history
215
+ response = bot.generate_response(message, history)
216
+ history.append({"role": "user", "content": message})
217
+ history.append({"role": "assistant", "content": response})
218
+ return history
219
+
220
+ with gr.Blocks() as demo:
221
+ chatbot = gr.Chatbot([], height=500, show_label=False, type="messages")
222
+ msg = gr.Textbox(label="Type your message")
223
+ send = gr.Button("Send")
224
+ clear = gr.Button("Clear Chat")
225
+
226
+ msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: gr.update(value=""), None, [msg])
227
+ send.click(chat, [msg, chatbot], [chatbot]).then(lambda: gr.update(value=""), None, [msg])
228
+ clear.click(lambda: [], None, chatbot)
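+ # Wiring: submit/click append the new turn to the Chatbot state, then the
+ # chained .then() clears the textbox; Clear resets history to an empty list.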
229
+
230
+ return demo
231
+
232
+ # ===========================
233
+ # 🚀 MAIN
234
+ # ===========================
235
  if __name__ == "__main__":
236
+ SecretsManager.setup()
237
  demo = create_demo()
238
  demo.launch(server_name="0.0.0.0", server_port=7860)