Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -63,369 +63,376 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection
Removed lines (old version):

 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


-
-
-
-    def process_medqa_batch(examples):
-        results = []
-        inputs = examples['input']
-        instructions = examples['instruction']
-        outputs = examples['output']
-
-        for inp, inst, out in zip(inputs, instructions, outputs):
-            results.append({
-                "input": f"{inp} {inst}",
-                "output": out
-            })
-        return results
-
-    def process_meddia_batch(examples):
-        results = []
-        inputs = examples['input']
-        outputs = examples['output']
-
-        for inp, out in zip(inputs, outputs):
-            results.append({
-                "input": inp,
-                "output": out
-            })
-        return results
-
-    def process_persona_batch(examples):
-        results = []
-        personalities = examples['personality']
-        utterances = examples['utterances']
-
-        for pers, utts in zip(personalities, utterances):
-            try:
-                # Process personality list
-                personality = ' '.join([
-                    p for p in pers
-                    if isinstance(p, str)
-                ])
-
-                # Process utterances
-                if utts and len(utts) > 0:
-                    utterance = utts[0]
-                    history = []
-
-                    # Process history
-                    if 'history' in utterance and utterance['history']:
-                        history = [
-                            h for h in utterance['history']
-                            if isinstance(h, str)
-                        ]
-
-                    history_text = ' '.join(history)
-
-                    # Get candidate response
-                    candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
-
-                    if personality or history_text:
-                        results.append({
-                            "input": f"{personality} {history_text}".strip(),
-                            "output": candidate
-                        })
-            except Exception as e:
-                print(f"Error processing persona batch item: {e}")
-                continue
-
-        return results
-
-    # Load and process each dataset separately
-    print("Processing MedQA dataset...")
-    medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
-    medqa_processed = []
-
-    for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
-        batch = medqa[i:i + batch_size]
-        medqa_processed.extend(process_medqa_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    print("Processing MedDiagnosis dataset...")
-    meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
-    meddia_processed = []
-
-    for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
-        batch = meddia[i:i + batch_size]
-        meddia_processed.extend(process_meddia_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    print("Processing Persona-Chat dataset...")
-    persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
-    persona_processed = []
-
-    for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
-        batch = persona[i:i + batch_size]
-        persona_processed.extend(process_persona_batch(batch))
-        if i % (batch_size * 5) == 0:
-            torch.cuda.empty_cache()
-
-    torch.cuda.empty_cache()
-
-    print("Creating final dataset...")
-    all_processed = persona_processed + medqa_processed + meddia_processed
-
-    valid_data = {
-        "input": [],
-        "output": []
-    }
-
-    for item in all_processed:
-        if item["input"].strip() and item["output"].strip():
-            valid_data["input"].append(item["input"])
-            valid_data["output"].append(item["output"])
-
-    final_dataset = Dataset.from_dict(valid_data)
-
-    print(f"Final dataset size: {len(final_dataset)}")
-    return final_dataset
-
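All three processors above normalize different source schemas into the same {"input", "output"} record shape. A quick standalone check of that contract on a toy batch (sketch; reuses the removed process_medqa_batch helper):

    toy_batch = {
        "input": ["Q: What causes fever?"],
        "instruction": ["Answer the question."],
        "output": ["Infection is a common cause."],
    }
    # -> [{'input': 'Q: What causes fever? Answer the question.',
    #      'output': 'Infection is a common cause.'}]
    print(process_medqa_batch(toy_batch))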
-def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
-    def tokenize_batch(examples):
-        formatted_texts = []
-
-        for i in range(0, len(examples['input']), batch_size):
-            sub_batch_inputs = examples['input'][i:i + batch_size]
-            sub_batch_outputs = examples['output'][i:i + batch_size]
-
-            for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
-                try:
-                    formatted_text = f"""<start_of_turn>user
-{input_text}
-<end_of_turn>
-<start_of_turn>assistant
-{output_text}
-<end_of_turn>"""
-                    formatted_texts.append(formatted_text)
-                except Exception as e:
-                    print(f"Error formatting text: {e}")
-                    continue
-
-        tokenized = tokenizer(
-            formatted_texts,
-            padding="max_length",
-            truncation=True,
-            max_length=max_length,
-            return_tensors=None
-        )
-
-        tokenized["labels"] = tokenized["input_ids"].copy()
-        return tokenized
-
-    print(f"Tokenizing dataset in small batches (size={batch_size})...")
-    tokenized_dataset = dataset.map(
-        tokenize_batch,
-        batched=True,
-        batch_size=batch_size,
-        remove_columns=dataset.column_names,
-        desc="Tokenizing dataset",
-        load_from_cache_file=False
-    )
-
-    return tokenized_dataset
-
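Note that labels is a verbatim copy of input_ids, so with padding="max_length" the loss is computed on prompt and pad tokens as well as on the response. A common refinement, not part of this commit, would mask padded positions with -100, the ignore index of the causal-LM loss (sketch):

    def mask_padding_in_labels(tokenized):
        # Positions where attention_mask == 0 are padding; -100 makes the loss skip them.
        tokenized["labels"] = [
            [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
            for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized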
-def setup_model_and_tokenizer(model_name="google/gemma-2b"):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.pad_token = tokenizer.eos_token
-
-    from transformers import BitsAndBytesConfig
-
-    bnb_config = BitsAndBytesConfig(
-        load_in_8bit=True,
-        bnb_8bit_compute_dtype=torch.float16,
-        llm_int8_enable_fp32_cpu_offload=True
-    )
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="auto",
-        quantization_config=bnb_config,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    )
-
-    model = prepare_model_for_kbit_training(model)
-
-    lora_config = LoraConfig(
-        r=4,
-        lora_alpha=16,
-        target_modules=["q_proj", "v_proj"],
-        lora_dropout=0.05,
-        bias="none",
-        task_type="CAUSAL_LM"
-    )
-
-    model = get_peft_model(model, lora_config)
-    model.print_trainable_parameters()
-
-    return model, tokenizer
-
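print_trainable_parameters() shows how small the LoRA update is with r=4 on q_proj/v_proj only; the same figure can be derived by hand from the PEFT model (sketch):

    model, tokenizer = setup_model_and_tokenizer()
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.4f}%)")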
-def setup_training_arguments(output_dir="./pearly_fine_tuned"):
-    return TrainingArguments(
-        output_dir=output_dir,
-        num_train_epochs=1,
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=16,
-        warmup_steps=50,
-        logging_steps=10,
-        save_steps=200,
-        learning_rate=2e-4,
-        fp16=True,
-        gradient_checkpointing=True,
-        gradient_checkpointing_kwargs={"use_reentrant": False},
-        optim="adamw_8bit",
-        max_grad_norm=0.3,
-        weight_decay=0.001,
-        logging_dir="./logs",
-        save_total_limit=2,
-        remove_unused_columns=False,
-        dataloader_pin_memory=False,
-        max_steps=500,
-        report_to=["none"],
-    )
-
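For orientation: per_device_train_batch_size=1 with gradient_accumulation_steps=16 gives an effective batch of 16 examples per optimizer step, and max_steps=500 caps the run at roughly 8,000 examples regardless of dataset size:

    effective_batch = 1 * 16                # per-device batch x gradient accumulation
    examples_seen = effective_batch * 500   # bounded by max_steps
    print(effective_batch, examples_seen)   # 16 8000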
-def main():
-    torch.backends.cuda.matmul.allow_tf32 = False
-    torch.backends.cudnn.allow_tf32 = False
-
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-
-    print("Preparing initial datasets...")
-    combined_dataset = prepare_initial_datasets(batch_size=4)
-
-    print(f"\nDataset size: {len(combined_dataset)}")
-    print(f"Column names: {combined_dataset.column_names}")
-
-    if len(combined_dataset) > 0:
-        print("\nSample input-output pair:")
-        print(f"Input: {combined_dataset[0]['input'][:100]}...")
-        print(f"Output: {combined_dataset[0]['output'][:100]}...")
-
-    print("\nSetting up model and tokenizer...")
-    model, tokenizer = setup_model_and_tokenizer()
-
-    print("\nPreparing dataset for training...")
-    processed_dataset = prepare_dataset(
-        combined_dataset,
-        tokenizer,
-        max_length=256,
-        batch_size=2
-    )
-
-    torch.cuda.empty_cache()
-
-    training_args = setup_training_arguments()
-
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=processed_dataset,
-        tokenizer=tokenizer,
-    )
-
-    print("\nStarting training...")
-    try:
-        trainer.train()
-    except Exception as e:
-        print(f"Training error: {e}")
-        torch.cuda.empty_cache()
-        raise e
-    finally:
-        torch.cuda.empty_cache()
-
-    print("\nSaving model...")
-    trainer.save_model()
-    print("Training completed!")
-
-DISCLAIMER = """
-IMPORTANT MEDICAL DISCLAIMER:
-Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
-Pearly DOES NOT:
-- Make medical diagnoses
-- Prescribe medications
-- Provide specific treatment recommendations
-- Replace professional medical advice
-
-Always consult qualified healthcare professionals for medical advice and treatment.
-In case of emergency, call 999 immediately.
-"""

 class PearlyBot:
-    def __init__(self
-
-
-
-        # Clean memory
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        # Load tokenizer and model directly from saved path
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            device_map="auto"
-        )
-
-        self.model.eval()  # Set to evaluation mode
-
-        # Initialize RAG components
-        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
-        self.vector_store = None
         self.conversation_history = []
-
-    def initialize_rag(self, documents_path="./knowledge_base"):
-        """Initialize RAG system"""
-        print("Loading knowledge base...")
-
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=300,
-            chunk_overlap=100,
-            separators=["\n\n", "\n", ".", "!", "?", ":"]
-        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
             return ""
-
-
-
-
-
-
-
-
-
-
-
-

 {context}

-Previous
-{

-
 1. Assess symptoms and severity
 2. Ask relevant follow-up questions
 3. Direct to appropriate care (999, 111, or GP)
@@ -433,18 +440,18 @@ Based on these guidelines, I will:
 5. Never diagnose or recommend treatments
 <end_of_turn>
 <start_of_turn>user
-{
 <end_of_turn>
 <start_of_turn>assistant"""

-
-
-
-
-
-
-
-
             outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=256,
@@ -452,21 +459,20 @@ Based on these guidelines, I will:
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.9,
-                repetition_penalty=1.2
-                pad_token_id=self.tokenizer.pad_token_id
             )
-
-
-
-
-
-
-
-
-
-
-

 def create_demo():
     """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
@@ -475,9 +481,9 @@ def create_demo():
     @gr.routes.get("/health")
     def health_check():
         return {"status": "healthy"}
-    bot =

-    def chat(message: str, history:
         try:
             if not message.strip():
                 return history
Added lines (new version):

 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class PearlyBot:
+    def __init__(self):
+        self.setup_model()
+        self.setup_rag()
         self.conversation_history = []

+    def setup_model(self):
+        try:
+            logger.info("Loading local checkpoint...")
+            # Load from the local checkpoint in your Space
+            checkpoint_path = "checkpoint-500.zip"  # Path to your uploaded checkpoint
+            base_model_id = "google/gemma-2b"  # Your base model
+
+            # Load tokenizer from base model
+            self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+
+            # Load model with checkpoint
+            self.model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_path,
+                device_map="auto",
+                load_in_8bit=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True
+            )
+            self.model.eval()
+            logger.info("Model loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading model: {str(e)}")
+            raise
+
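A likely source of the Space's runtime error: from_pretrained expects a Hub repo id or a local directory, not a .zip archive, so checkpoint-500.zip would presumably need to be extracted first. A minimal sketch, assuming the archive unpacks into a directory of the usual checkpoint files:

    import os
    import zipfile

    checkpoint_zip = "checkpoint-500.zip"  # the archive referenced above
    checkpoint_dir = "checkpoint-500"      # hypothetical extraction target

    if not os.path.isdir(checkpoint_dir):
        with zipfile.ZipFile(checkpoint_zip) as zf:
            zf.extractall(checkpoint_dir)
    # from_pretrained(checkpoint_dir, ...) can then load the extracted files.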
+    def setup_rag(self):
+        try:
+            logger.info("Setting up RAG system...")
+            # Load your knowledge base content
+            knowledge_base = {
+                "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
+
+EMERGENCY (999) SCENARIOS:
+1. Cardiovascular:
+- Chest pain/pressure
+- Heart attack symptoms
+- Irregular heartbeat with dizziness
+Response: Immediate 999 call, sit/lie down, chew aspirin if available
+
+2. Respiratory:
+- Severe breathing difficulty
+- Choking
+- Unable to speak full sentences
+Response: 999, sitting position, clear airway
+
+3. Neurological:
+- Stroke symptoms (FAST)
+- Seizures
+- Unconsciousness
+Response: 999, recovery position if unconscious
+
+4. Trauma:
+- Severe bleeding
+- Head injuries with confusion
+- Major burns
+Response: 999, apply direct pressure to bleeding
+
+URGENT CARE (111) SCENARIOS:
+1. Moderate Symptoms:
+- Persistent fever
+- Non-severe infections
+- Minor injuries
+Response: 111 contact, monitor symptoms
+
+2. Minor Emergencies:
+- Small cuts needing stitches
+- Sprains and strains
+- Mild allergic reactions
+Response: 111 or urgent care visit
+
+GP APPOINTMENT SCENARIOS:
+1. Routine Care:
+- Chronic condition review
+- Medication reviews
+- Non-urgent symptoms
+Response: Book routine GP appointment
+
+2. Preventive Care:
+- Vaccinations
+- Health screenings
+- Regular check-ups
+Response: Schedule with GP reception""",
+                "emergency_detection.txt": """Enhanced Emergency Detection Criteria:
+
+IMMEDIATE LIFE THREATS:
+1. Cardiac Symptoms:
+- Chest pain/pressure/tightness
+- Pain spreading to arms/jaw/neck
+- Sweating with nausea
+- Shortness of breath
+
+2. Breathing Problems:
+- Severe shortness of breath
+- Blue lips or face
+- Unable to complete sentences
+- Choking/airway blockage
+
+3. Neurological:
+- FAST (Face, Arms, Speech, Time)
+- Sudden confusion
+- Severe headache
+- Seizures
+- Loss of consciousness
+
+4. Severe Trauma:
+- Heavy bleeding
+- Deep wounds
+- Head injury with confusion
+- Severe burns
+- Broken bones with deformity
+
+5. Anaphylaxis:
+- Sudden swelling
+- Difficulty breathing
+- Rapid onset rash
+- Light-headedness
+
+URGENT BUT NOT IMMEDIATE:
+1. Moderate Symptoms:
+- Persistent fever
+- Dehydration
+- Non-severe infections
+- Minor injuries
+
+2. Worsening Conditions:
+- Increasing pain
+- Progressive symptoms
+- Medication reactions
+
+RESPONSE PROTOCOLS:
+1. For Life Threats:
+- Immediate 999 call
+- Clear first aid instructions
+- Stay on line until help arrives
+
+2. For Urgent Care:
+- 111 contact
+- Monitor for worsening
+- Document symptoms""",
+                "gp_booking.txt": """GP Appointment Booking Templates:
+
+APPOINTMENT TYPES:
+1. Routine Appointments:
+Template: "I need to book a routine appointment for [condition]. My availability is [times/dates]. My GP is Dr. [name] if available."
+
+2. Follow-up Appointments:
+Template: "I need a follow-up appointment regarding [condition] discussed on [date]. My previous appointment was with Dr. [name]."
+
+3. Medication Reviews:
+Template: "I need a medication review for [medication]. My last review was [date]."
+
+BOOKING INFORMATION NEEDED:
+1. Patient Details:
+- Full name
+- Date of birth
+- NHS number (if known)
+- Registered GP practice
+
+2. Appointment Details:
+- Nature of appointment
+- Preferred times/dates
+- Urgency level
+- Special requirements
+
+3. Contact Information:
+- Phone number
+- Alternative contact
+- Preferred contact method
+
+BOOKING PROCESS:
+1. Online Booking:
+- NHS app instructions
+- Practice website guidance
+- System navigation help
+
+2. Phone Booking:
+- Best times to call
+- Required information
+- Queue management tips
+
+3. Special Circumstances:
+- Interpreter needs
+- Accessibility requirements
+- Transport arrangements""",
+                "cultural_sensitivity.txt": """Cultural Sensitivity Guidelines:
+
+CULTURAL AWARENESS:
+1. Religious Considerations:
+- Prayer times
+- Religious observations
+- Dietary restrictions
+- Gender preferences for care
+- Religious festivals/fasting periods
+
+2. Language Support:
+- Interpreter services
+- Multi-language resources
+- Clear communication methods
+- Family involvement preferences
+
+3. Cultural Beliefs:
+- Traditional medicine practices
+- Cultural health beliefs
+- Family decision-making
+- Privacy customs
+
+COMMUNICATION APPROACHES:
+1. Respectful Interaction:
+- Use preferred names/titles
+- Appropriate greetings
+- Non-judgmental responses
+- Active listening
+
+2. Language Usage:
+- Clear, simple terms
+- Avoid medical jargon
+- Confirm understanding
+- Respect silence/pauses
+
+3. Non-verbal Communication:
+- Eye contact customs
+- Personal space
+- Body language awareness
+- Gesture sensitivity
+
+SPECIFIC CONSIDERATIONS:
+1. South Asian Communities:
+- Family involvement
+- Gender sensitivity
+- Traditional medicine
+- Language diversity
+
+2. Middle Eastern Communities:
+- Gender-specific care
+- Religious observations
+- Family hierarchies
+- Privacy concerns
+
+3. African/Caribbean Communities:
+- Traditional healers
+- Community involvement
+- Historical medical mistrust
+- Cultural specific conditions
+
+4. Eastern European Communities:
+- Direct communication
+- Family involvement
+- Medical documentation
+- Language support
+
+INCLUSIVE PRACTICES:
+1. Appointment Scheduling:
+- Religious holidays
+- Prayer times
+- Family availability
+- Interpreter needs
+
+2. Treatment Planning:
+- Cultural preferences
+- Traditional practices
+- Family involvement
+- Dietary requirements
+
+3. Support Services:
+- Community resources
+- Cultural organizations
+- Language services
+- Social support""",
+                "service_boundaries.txt": """Service Limitations and Professional Boundaries:
+
+CLEAR BOUNDARIES:
+1. Medical Advice:
+- No diagnoses
+- No prescriptions
+- No treatment recommendations
+- No medical procedures
+- No second opinions
+
+2. Emergency Services:
+- Clear referral criteria
+- Documented responses
+- Follow-up protocols
+- Handover procedures
+
+3. Information Sharing:
+- Confidentiality limits
+- Data protection
+- Record keeping
+- Information governance
+
+PROFESSIONAL CONDUCT:
+1. Communication:
+- Professional language
+- Emotional boundaries
+- Personal distance
+- Service scope
+
+2. Service Delivery:
+- No financial transactions
+- No personal relationships
+- Clear role definition
+- Professional limits"""
+            }
+
+            os.makedirs("knowledge_base", exist_ok=True)
+
+            # Create and process documents
+            documents = []
+            for filename, content in knowledge_base.items():
+                with open(f"knowledge_base/{filename}", "w") as f:
+                    f.write(content)
+                documents.append(content)
+
+            # Setup embeddings and vector store
+            self.embeddings = HuggingFaceEmbeddings(
+                model_name="sentence-transformers/all-MiniLM-L6-v2"
+            )
+
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=300,
+                chunk_overlap=100
+            )
+
+            texts = text_splitter.split_text("\n\n".join(documents))
+            self.vector_store = FAISS.from_texts(texts, self.embeddings)
+            logger.info("RAG system setup complete")
+
+        except Exception as e:
+            logger.error(f"Error setting up RAG: {str(e)}")
+            raise
+
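Rough sanity check of what the splitter produces before indexing; chunk sizes are in characters, and the import path may differ by langchain version (sketch):

    from langchain.text_splitter import RecursiveCharacterTextSplitter  # or langchain_text_splitters

    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    chunks = splitter.split_text("word " * 500)
    print(len(chunks), max(len(c) for c in chunks))  # every chunk is <= 300 chars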
+    def get_relevant_context(self, query):
+        try:
+            docs = self.vector_store.similarity_search(query, k=3)
+            return "\n".join(doc.page_content for doc in docs)
+        except Exception as e:
+            logger.error(f"Error retrieving context: {str(e)}")
             return ""
+
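The retrieval layer can be exercised on its own before wiring it into generation; k=3 mirrors get_relevant_context (hypothetical smoke test using the class above):

    bot = PearlyBot()  # builds the FAISS index in setup_rag
    for doc in bot.vector_store.similarity_search("chest pain and shortness of breath", k=3):
        print(doc.page_content[:80])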
+    @torch.inference_mode()
+    def generate_response(self, message: str, history: list) -> str:
+        try:
+            # Get RAG context
+            context = self.get_relevant_context(message)
+
+            # Format conversation history
+            conv_history = "\n".join([
+                f"User: {user}\nAssistant: {assistant}"
+                for user, assistant in history[-3:]  # Keep last 3 turns
+            ])
+
+            # Create prompt
+            prompt = f"""<start_of_turn>system
+Using these medical guidelines:

 {context}

+Previous conversation:
+{conv_history}

+Guidelines:
 1. Assess symptoms and severity
 2. Ask relevant follow-up questions
 3. Direct to appropriate care (999, 111, or GP)
 5. Never diagnose or recommend treatments
 <end_of_turn>
 <start_of_turn>user
+{message}
 <end_of_turn>
 <start_of_turn>assistant"""

+            # Generate response
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512
+            ).to(self.model.device)
+
             outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=256,
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.9,
+                repetition_penalty=1.2
             )
+
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("<start_of_turn>assistant")[-1].strip()
+            if "<end_of_turn>" in response:
+                response = response.split("<end_of_turn>")[0].strip()
+
+            return response
+
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            return "I apologize, but I encountered an error. Please try again."
+
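history is assumed to be a list of (user, assistant) pairs, matching the tuples format Gradio chat callbacks pass around; a hypothetical smoke test:

    bot = PearlyBot()
    history = [("I have a headache", "How long have you had it, and how severe is it?")]
    print(bot.generate_response("It started an hour ago and it is getting worse", history))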

 def create_demo():
     """Set up Gradio interface for the chatbot with enhanced styling and functionality."""

     @gr.routes.get("/health")
     def health_check():
         return {"status": "healthy"}
+    bot = PearlyBot()  # ✅

+    def chat(message: str, history: list):  # ✅
         try:
             if not message.strip():
                 return history