PearlIsa committed on
Commit 9120f67
1 Parent(s): 459f450

Update app.py

Files changed (1)
  1. app.py +440 -201
app.py CHANGED
@@ -1,164 +1,271 @@
- # Standard Libraries
  import os
- import json
- import time
- import asyncio
  import logging
- import gc
- import re
- import traceback
- from pathlib import Path
  from datetime import datetime
- from typing import List, Dict, Union, Tuple, Optional, Any
- from dataclasses import dataclass, field
- import zipfile
-
- # Machine Learning and Deep Learning Libraries
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.cuda.amp import autocast
- from torch.utils.data import DataLoader
-
- # Hugging Face and Transformers
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
- from sentence_transformers import SentenceTransformer
- from datasets import load_dataset, Dataset, concatenate_datasets
  from huggingface_hub import login
-
- # FAISS and PEFT
- import faiss
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel
-
- # LangChain - updated imports as per recent deprecations
- from langchain_community.vectorstores import FAISS  # Updated import
- from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
- from langchain_community.document_loaders import TextLoader  # Updated import
- from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-
- # External Tools and APIs
- import wandb
- import requests
- import gradio as gr
- import IPython.display as display  # Required for IPython display functionality
  from dotenv import load_dotenv
  from tqdm.auto import tqdm

- # Suppress Warnings
- import warnings
- warnings.filterwarnings('ignore')
-
-
- # Ensure Hugging Face login
- try:
-     hf_token = os.getenv("HF_TOKEN")
-     if hf_token:
-         login(token=hf_token)
-         print("Login successful!")
- except Exception as e:
-     print("Hugging Face Login failed:", e)
-
-
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
-
-
  # Setup logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

-
-
- class ModelManager:
-     """Handles model loading and resource management"""
-
-     @staticmethod
-     def verify_and_extract_model(checkpoint_zip_path: str, extracted_model_dir: str) -> str:
-         """Verify and extract the model if it's not already extracted"""
-         if not os.path.exists(extracted_model_dir):
-             # Unzip the model if it hasn’t been extracted yet
-             with zipfile.ZipFile(checkpoint_zip_path, 'r') as zip_ref:
-                 zip_ref.extractall(extracted_model_dir)
-             logger.info(f"Extracted model to: {extracted_model_dir}")
-         else:
-             logger.info(f"Model already extracted: {extracted_model_dir}")
-
-         return extracted_model_dir

      @staticmethod
-     def clear_gpu_memory():
-         """Clear GPU memory cache"""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         gc.collect()
-
- class PearlyBot:
-     def __init__(self):
-         try:
-             # Use the correct model path from your space
-             self.repo_id = "Pearilsa/pearly_med_triage_chatbot_kagglex"
-             self.model_filename = "pearly_model.zip"
-             self.setup_model()
-             self.setup_rag()
-             self.conversation_history = []
-             self.last_interaction_time = time.time()
-             self.interaction_cooldown = 1.0
-         except Exception as e:
-             logger.error(f"Error initializing bot: {e}")
-             raise
-
-     def setup_model(self):
-         """Initialize model from Hugging Face space"""
          try:
-             logger.info(f"Loading model from {self.repo_id}")

-             # Download and prepare model path
-             local_model_path = os.path.join(os.getcwd(), "models")
-             os.makedirs(local_model_path, exist_ok=True)

-             # Load tokenizer and model from the space
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 self.repo_id,
-                 token=os.getenv("HF_TOKEN"),  # Use your Hugging Face token
-                 cache_dir=local_model_path
-             )
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-             logger.info("Tokenizer loaded successfully")
-
-             # Load model with 8-bit quantization
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 self.repo_id,
-                 token=os.getenv("HF_TOKEN"),
-                 device_map="auto",
-                 load_in_8bit=True,
-                 torch_dtype=torch.float16,
-                 low_cpu_mem_usage=True,
-                 cache_dir=local_model_path
-             )
-             self.model.eval()
-             logger.info("Model loaded successfully")

          except Exception as e:
-             logger.error(f"Error in model setup: {str(e)}")
              raise

  def setup_rag(self):
          try:
-             # Add configuration options
-             self.chunk_size = 300
-             self.chunk_overlap = 100
-             self.num_relevant_chunks = 3

              # Load knowledge base
              knowledge_base = self._load_knowledge_base()

-             # Setup embeddings with error handling
              self.embeddings = self._initialize_embeddings()

-             # Enhanced text splitting
              texts = self._split_texts(knowledge_base)

              # Create vector store with metadata
@@ -168,13 +275,15 @@ class PearlyBot:
                  metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
              )

-             # Add validation
              self._validate_rag_setup()

          except Exception as e:
-             logger.error(f"RAG setup failed: {str(e)}")
              raise
-     # Load your knowledge base content

      def _load_knowledge_base(self):
          """Load and validate knowledge base content"""
          try:
@@ -488,6 +597,156 @@ class PearlyBot:
          except Exception as e:
              logger.error(f"RAG system validation failed: {str(e)}")
              raise

      def _initialize_embeddings(self):
          try:
@@ -525,24 +784,22 @@ class PearlyBot:
      def generate_response(self, message: str, history: list) -> str:
          """Generate response using both fine-tuned model and RAG"""
          try:
-             # Rate limiting
              current_time = time.time()
              if current_time - self.last_interaction_time < self.interaction_cooldown:
                  time.sleep(self.interaction_cooldown)

-             # Clear GPU memory before generation
-             ModelManager.clear_gpu_memory()
-
-             # Get RAG context
-             context = self.get_relevant_context(message)

              # Format conversation history
              conv_history = "\n".join([
-                 f"User: {user}\nAssistant: {assistant}"
-                 for user, assistant in history[-3:]  # Keep last 3 turns
              ])

-             # Create prompt
              prompt = f"""<start_of_turn>system
  Using these medical guidelines:

@@ -552,9 +809,9 @@ Previous conversation:
  {conv_history}

  Guidelines:
- 1. Assess symptoms and severity
- 2. Ask relevant follow-up questions
- 3. Direct to appropriate care (999, 111, or GP)
  4. Show empathy and cultural sensitivity
  5. Never diagnose or recommend treatments
  <end_of_turn>
@@ -563,41 +820,36 @@ Guidelines:
  <end_of_turn>
  <start_of_turn>assistant"""

-             # Generate response
-             try:
-                 inputs = self.tokenizer(
-                     prompt,
-                     return_tensors="pt",
-                     truncation=True,
-                     max_length=512
-                 ).to(self.model.device)
-
-                 outputs = self.model.generate(
-                     **inputs,
-                     max_new_tokens=256,
-                     min_new_tokens=20,
-                     do_sample=True,
-                     temperature=0.7,
-                     top_p=0.9,
-                     repetition_penalty=1.2,
-                     no_repeat_ngram_size=3
-                 )
-
-                 response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-                 response = response.split("<start_of_turn>assistant")[-1].strip()
-                 if "<end_of_turn>" in response:
-                     response = response.split("<end_of_turn>")[0].strip()
-
-                 self.last_interaction_time = time.time()
-                 return response
-
-             except torch.cuda.OutOfMemoryError:
-                 ModelManager.clear_gpu_memory()
-                 logger.error("GPU out of memory, cleared cache and retrying...")
-                 return "I apologize, but I'm experiencing technical difficulties. Please try again."
-
          except Exception as e:
-             logger.error(f"Error generating response: {str(e)}")
              return "I apologize, but I encountered an error. Please try again."

      def handle_feedback(self, message: str, response: str, feedback: int):
@@ -971,23 +1223,10 @@ def create_demo():
          raise

  if __name__ == "__main__":
-     try:
-         # Initialize logging
-         logging.basicConfig(level=logging.INFO)
-
-         # Load environment variables
-         load_dotenv()
-
-         # Create and launch demo
-         demo = create_demo()
-         demo.launch(
-             server_name="0.0.0.0",
-             server_port=7860,
-             show_error=True
-         )
-
-     except Exception as e:
-         logger.error(f"Application startup failed: {e}")
-         raise
-
-
 
+ # Standard imports first
  import os
+ import torch
+ import time
  import logging
  from datetime import datetime
  from huggingface_hub import login
  from dotenv import load_dotenv
+ from datasets import load_dataset, Dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig
+ )
+ from peft import (
+     LoraConfig,
+     get_peft_model,
+     prepare_model_for_kbit_training
+ )
  from tqdm.auto import tqdm

  # Setup logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ class SecretsManager:
+     """Handles authentication and secrets management"""

      @staticmethod
+     def setup_credentials():
+         """Setup all required credentials"""
          try:
+             # Load environment variables
+             load_dotenv()

+             # Get credentials
+             credentials = {
+                 'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
+                 'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
+                 'HF_TOKEN': os.getenv('HF_TOKEN'),
+                 'WANDB_KEY': os.getenv('WANDB_KEY')
+             }

+             # Validate credentials
+             missing_creds = [k for k, v in credentials.items() if not v]
+             if missing_creds:
+                 logger.warning(f"Missing credentials: {', '.join(missing_creds)}")
+
+             # Setup Hugging Face authentication
+             if credentials['HF_TOKEN']:
+                 login(token=credentials['HF_TOKEN'])
+                 logger.info("Successfully logged in to Hugging Face")
+             # Setup Kaggle credentials if available
+             if credentials['KAGGLE_USERNAME'] and credentials['KAGGLE_KEY']:
+                 os.environ['KAGGLE_USERNAME'] = credentials['KAGGLE_USERNAME']
+                 os.environ['KAGGLE_KEY'] = credentials['KAGGLE_KEY']
+
+             # Setup wandb if available
+             if credentials['WANDB_KEY']:
+                 os.environ['WANDB_API_KEY'] = credentials['WANDB_KEY']
+
+             return credentials

          except Exception as e:
+             logger.error(f"Error setting up credentials: {e}")
              raise
+ class ModelTrainer:
+     """Handles model training pipeline"""
+
+     def __init__(self):
+         # Set memory optimization environment variables
+         os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
+         os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+         # Initialize attributes
+         self.model = None
+         self.tokenizer = None
+         self.dataset = None
+         self.processed_dataset = None
+         self.chunk_size = 300
+         self.chunk_overlap = 100
+         self.num_relevant_chunks = 3
+         self.vector_store = None
+         self.embeddings = None
+         self.last_interaction_time = time.time()  # Add this
+         self.interaction_cooldown = 1.0  # Add this
+
+         # Setup GPU preferences
+         torch.backends.cuda.matmul.allow_tf32 = False
+         torch.backends.cudnn.allow_tf32 = False
+
+     def prepare_initial_datasets(self, batch_size=8):
+         print("Loading datasets with memory-optimized batch processing...")
+
+         def process_medqa_batch(examples):
+             results = []
+             inputs = examples['input']
+             instructions = examples['instruction']
+             outputs = examples['output']
+
+             for inp, inst, out in zip(inputs, instructions, outputs):
+                 results.append({
+                     "input": f"{inp} {inst}",
+                     "output": out
+                 })
+             return results
+
+         def process_meddia_batch(examples):
+             results = []
+             inputs = examples['input']
+             outputs = examples['output']
+
+             for inp, out in zip(inputs, outputs):
+                 results.append({
+                     "input": inp,
+                     "output": out
+                 })
+             return results
+
+         def process_persona_batch(examples):
+             results = []
+             personalities = examples['personality']
+             utterances = examples['utterances']
+
+             for pers, utts in zip(personalities, utterances):
+                 try:
+                     # Process personality list
+                     personality = ' '.join([
+                         p for p in pers
+                         if isinstance(p, str)
+                     ])
+
+                     # Process utterances
+                     if utts and len(utts) > 0:
+                         utterance = utts[0]
+                         history = []
+
+                         # Process history
+                         if 'history' in utterance and utterance['history']:
+                             history = [
+                                 h for h in utterance['history']
+                                 if isinstance(h, str)
+                             ]
+
+                         history_text = ' '.join(history)
+
+                         # Get candidate response
+                         candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
+
+                         if personality or history_text:
+                             results.append({
+                                 "input": f"{personality} {history_text}".strip(),
+                                 "output": candidate
+                             })
+                 except Exception as e:
+                     print(f"Error processing persona batch item: {e}")
+                     continue

+             return results
+         try:
+             # Load and process each dataset separately
+             print("Processing MedQA dataset...")
+             medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
+             medqa_processed = []
+
+             for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
+                 batch = medqa[i:i + batch_size]
+                 medqa_processed.extend(process_medqa_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             print("Processing MedDiagnosis dataset...")
+             meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
+             meddia_processed = []
+
+             for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
+                 batch = meddia[i:i + batch_size]
+                 meddia_processed.extend(process_meddia_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             print("Processing Persona-Chat dataset...")
+             persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
+             persona_processed = []
+
+             for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
+                 batch = persona[i:i + batch_size]
+                 persona_processed.extend(process_persona_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             torch.cuda.empty_cache()
+
+             print("Creating final dataset...")
+             all_processed = persona_processed + medqa_processed + meddia_processed
+
+             valid_data = {
+                 "input": [],
+                 "output": []
+             }
+
+             for item in all_processed:
+                 if item["input"].strip() and item["output"].strip():
+                     valid_data["input"].append(item["input"])
+                     valid_data["output"].append(item["output"])
+
+             final_dataset = Dataset.from_dict(valid_data)
+
+             print(f"Final dataset size: {len(final_dataset)}")
+             return final_dataset
+         except Exception as e:
+             logger.error(f"Error preparing datasets: {e}")
+             raise

+     def prepare_dataset(self, dataset, tokenizer, max_length=256, batch_size=4):
+         def tokenize_batch(examples):
+             formatted_texts = []
+
+             for i in range(0, len(examples['input']), batch_size):
+                 sub_batch_inputs = examples['input'][i:i + batch_size]
+                 sub_batch_outputs = examples['output'][i:i + batch_size]
+
+                 for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
+                     try:
+                         formatted_text = f"""<start_of_turn>user
+ {input_text}
+ <end_of_turn>
+ <start_of_turn>assistant
+ {output_text}
+ <end_of_turn>"""
+                         formatted_texts.append(formatted_text)
+                     except Exception as e:
+                         print(f"Error formatting text: {e}")
+                         continue
+
+             tokenized = tokenizer(
+                 formatted_texts,
+                 padding="max_length",
+                 truncation=True,
+                 max_length=max_length,
+                 return_tensors=None
+             )
+
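+             # Causal-LM objective: labels are a copy of input_ids; the model shifts them internally when computing the loss.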
+             tokenized["labels"] = tokenized["input_ids"].copy()
+             return tokenized
+
+         print(f"Tokenizing dataset in small batches (size={batch_size})...")
+         tokenized_dataset = dataset.map(
+             tokenize_batch,
+             batched=True,
+             batch_size=batch_size,
+             remove_columns=dataset.column_names,
+             desc="Tokenizing dataset",
+             load_from_cache_file=False
+         )
+
+         return tokenized_dataset
+
      def setup_rag(self):
+         """Initialize RAG components"""
          try:
+             logger.info("Setting up RAG system...")

              # Load knowledge base
              knowledge_base = self._load_knowledge_base()

+             # Setup embeddings
              self.embeddings = self._initialize_embeddings()

+             # Process texts for vector store
              texts = self._split_texts(knowledge_base)

              # Create vector store with metadata

                  metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
              )

+             # Validate RAG setup
              self._validate_rag_setup()
+             logger.info("RAG system setup complete")

          except Exception as e:
+             logger.error(f"Failed to setup RAG: {e}")
              raise
+
+     # Load your knowledge base content
      def _load_knowledge_base(self):
          """Load and validate knowledge base content"""
          try:

          except Exception as e:
              logger.error(f"RAG system validation failed: {str(e)}")
              raise
+
+
+     def setup_model_and_tokenizer(self, model_name="google/gemma-2b"):
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         tokenizer.pad_token = tokenizer.eos_token
+
+         from transformers import BitsAndBytesConfig
+
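+         # Quantization config: load weights in 8-bit via bitsandbytes, run matmuls in fp16,
+         # and allow layers that do not fit on the GPU to be offloaded to CPU in fp32.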
+         bnb_config = BitsAndBytesConfig(
+             load_in_8bit=True,
+             bnb_8bit_compute_dtype=torch.float16,
+             llm_int8_enable_fp32_cpu_offload=True
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             device_map="auto",
+             quantization_config=bnb_config,
+             torch_dtype=torch.float16,
+             low_cpu_mem_usage=True
+         )
+
+         model = prepare_model_for_kbit_training(model)
+
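+         # LoRA adapters: only small rank-4 update matrices on the attention q_proj/v_proj layers
+         # are trained; the quantized base weights stay frozen.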
+         lora_config = LoraConfig(
+             r=4,
+             lora_alpha=16,
+             target_modules=["q_proj", "v_proj"],
+             lora_dropout=0.05,
+             bias="none",
+             task_type="CAUSAL_LM"
+         )
+
+         model = get_peft_model(model, lora_config)
+         model.print_trainable_parameters()
+
+         return model, tokenizer
+
+     def setup_training_arguments(self, output_dir="./pearly_fine_tuned"):
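+         # Effective batch size = 1 per device x 16 gradient-accumulation steps = 16 sequences per optimizer update.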
+         return TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=1,
+             per_device_train_batch_size=1,
+             gradient_accumulation_steps=16,
+             warmup_steps=50,
+             logging_steps=10,
+             save_steps=200,
+             learning_rate=2e-4,
+             fp16=True,
+             gradient_checkpointing=True,
+             gradient_checkpointing_kwargs={"use_reentrant": False},
+             optim="adamw_8bit",
+             max_grad_norm=0.3,
+             weight_decay=0.001,
+             logging_dir="./logs",
+             save_total_limit=2,
+             remove_unused_columns=False,
+             dataloader_pin_memory=False,
+             max_steps=500,
+             report_to=["none"],
+         )
+
+     def train(self):
+         """Main training pipeline with RAG integration"""
+         try:
+             logger.info("Starting training pipeline")
+
+             # Clear GPU memory
+             torch.cuda.empty_cache()
+             if torch.cuda.is_available():
+                 torch.cuda.reset_peak_memory_stats()
+
+             # Setup model, tokenizer, and RAG
+             logger.info("Setting up model components...")
+             self.model, self.tokenizer = self.setup_model_and_tokenizer()
+             self.setup_rag()
+
+             # Prepare and process datasets
+             logger.info("Preparing datasets...")
+             self.dataset = self.prepare_initial_datasets(batch_size=4)
+             self.processed_dataset = self.prepare_dataset(
+                 self.dataset,
+                 self.tokenizer,
+                 max_length=256,
+                 batch_size=2
+             )
+
+             # Train model
+             logger.info("Starting training...")
+             training_args = self.setup_training_arguments()
+             trainer = Trainer(
+                 model=self.model,
+                 args=training_args,
+                 train_dataset=self.processed_dataset,
+                 tokenizer=self.tokenizer
+             )
+             trainer.train()
+
+             # Save and push to hub
+             logger.info("Saving model...")
+             trainer.save_model()
+             if os.getenv('HF_TOKEN'):
+                 trainer.push_to_hub(
+                     "Pearilsa/pearly_med_triage_chatbot_kagglex",
+                     private=True
+                 )
+
+             logger.info("Training completed successfully!")
+
+         except Exception as e:
+             logger.error(f"Training failed: {e}")
+             raise
+         finally:
+             torch.cuda.empty_cache()
+
+ if __name__ == "__main__":
+     # Initialize trainer
+     trainer = ModelTrainer()
+
+     # Train model
+     trainer.train()
+
+     def _get_enhanced_context(self, query: str) -> str:
+         """Get relevant context with scores"""
+         try:
+             # Get documents with similarity scores
+             docs_and_scores = self.vector_store.similarity_search_with_score(
+                 query,
+                 k=self.num_relevant_chunks
+             )
+
+             # Filter and format relevant contexts
+             relevant_contexts = []
+             for doc, score in docs_and_scores:
+                 if score < 0.8:  # Lower score means more relevant
+                     source = doc.metadata.get('source', 'Unknown')
+                     relevant_contexts.append(
+                         f"[Source: {source}]\n{doc.page_content}"
+                     )
+
+             return "\n\n".join(relevant_contexts) if relevant_contexts else ""
+
+         except Exception as e:
+             logger.error(f"Error retrieving enhanced context: {e}")
+             return ""

      def _initialize_embeddings(self):
          try:

      def generate_response(self, message: str, history: list) -> str:
          """Generate response using both fine-tuned model and RAG"""
          try:
+             # Rate limiting and memory management
              current_time = time.time()
              if current_time - self.last_interaction_time < self.interaction_cooldown:
                  time.sleep(self.interaction_cooldown)
+             torch.cuda.empty_cache()

+             # Get enhanced context from RAG
+             context = self._get_enhanced_context(message)

              # Format conversation history
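+             # Each turn in `history` is expected to be a dict with "input" and "output" keys.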
              conv_history = "\n".join([
+                 f"User: {turn['input']}\nAssistant: {turn['output']}"
+                 for turn in history[-3:]  # Keep last 3 turns
              ])

+             # Create enhanced prompt with RAG context
              prompt = f"""<start_of_turn>system
  Using these medical guidelines:

  {conv_history}

  Guidelines:
+ 1. Assess symptoms and severity based on both your training and the provided guidelines
+ 2. Ask relevant follow-up questions if needed
+ 3. Direct to appropriate care (999, 111, or GP) according to symptom severity
  4. Show empathy and cultural sensitivity
  5. Never diagnose or recommend treatments
  <end_of_turn>

  <end_of_turn>
  <start_of_turn>assistant"""

+             # Generate response with model
+             inputs = self.tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512
+             ).to(self.model.device)
+
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=256,
+                 min_new_tokens=20,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 repetition_penalty=1.2,
+                 no_repeat_ngram_size=3
+             )
+
+             # Process response
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             response = response.split("<start_of_turn>assistant")[-1].strip()
+             if "<end_of_turn>" in response:
+                 response = response.split("<end_of_turn>")[0].strip()
+
+             self.last_interaction_time = time.time()
+             return response
+
          except Exception as e:
+             logger.error(f"Error generating response: {e}")
              return "I apologize, but I encountered an error. Please try again."

      def handle_feedback(self, message: str, response: str, feedback: int):

          raise

  if __name__ == "__main__":
+     # Initialize logging and load env vars
+     logging.basicConfig(level=logging.INFO)
+     load_dotenv()
+
+     # Create and launch demo
+     demo = create_demo()
+     demo.launch(server_name="0.0.0.0", server_port=7860)