Update app.py

app.py CHANGED

@@ -1,251 +1,472 @@
 # app.py
 import os
-import
+import json
+import keras
+from datasets import load_dataset
+import tensorflow as tf
+from huggingface_hub import login
 import torch
-from
-import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer)
 from sentence_transformers import SentenceTransformer
-from
+from typing import List, Dict, Union, Tuple
 import faiss
 import numpy as np
-from datasets import
+from datasets import Dataset
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+import gc
+from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel)
+from tqdm.auto import tqdm
+from torch.utils.data import DataLoader
+import logging
+import wandb
+from pathlib import Path
+from typing import List, Dict, Union, Optional, Any
+import torch.nn as nn
+from dataclasses import dataclass, field
+import time
+import asyncio
+import pytest
+from unittest.mock import Mock, patch
+from sklearn.metrics import classification_report, confusion_matrix
+import gradio as gr
+import matplotlib.pyplot as plt
 from datetime import datetime
-import
+import requests
+import pandas as pd
+import seaborn as sns
+import traceback
+from matplotlib.gridspec import GridSpec
+from datasets import load_dataset, concatenate_datasets
+from langchain.vectorstores import FAISS
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.document_loaders import TextLoader
+from google.colab import output
+import IPython.display as display
+from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
+
+
+# Ensure Hugging Face login
 try:
+    hf_token = os.getenv("HF_TOKEN")
+    if hf_token:
+        login(token=hf_token)
+        print("Login successful!")
 except Exception as e:
     print("Hugging Face Login failed:", e)

-    def
-        trust_remote_code=True
-    )
-        quantization_config=bnb_config,
-        device_map="auto",
-        trust_remote_code=True,
-        use_auth_token=True
-    )
-        r=self.config.LORA_R,
-        lora_alpha=self.config.LORA_ALPHA,
-        target_modules=self.config.LORA_TARGET_MODULES,
-        lora_dropout=self.config.LORA_DROPOUT,
-        bias="none",
-        task_type=TaskType.CAUSAL_LM
-    )
-        ""
-        raise
-    def _create_index(self):
-        """Create FAISS index for RAG"""
-        try:
-            sample_embedding = self.embedding_model.encode("sample text")
-            self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
-
-            embeddings = [self.embedding_model.encode(doc['text']) for doc in self.documents]
-            self.index.add(np.array(embeddings))
-        except Exception as e:
-            logger.error(f"Error creating FAISS index: {e}")
-            raise
-
-    def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
-        """Generate follow-up questions based on context"""
-        try:
-            prompt = f"""Patient message: "{message}"
-Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
-Questions:"""
-
-            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
-            outputs = self.model.generate(inputs['input_ids'], max_new_tokens=50, temperature=0.7, do_sample=True)
-            questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            return questions.split("\n")
-        except Exception as e:
-            logger.error(f"Error generating follow-up questions: {e}")
-            return ["Could you tell me more about when this started?"]
-
-    def assess_symptom_severity(self, message: str) -> str:
-        """Assess severity based on keywords in the message"""
-        if "severe" in message.lower() or "emergency" in message.lower():
-            return "emergency"
-        elif "persistent" in message.lower() or "moderate" in message.lower():
-            return "urgent"
-        return "routine"
-
-    def generate_response(self, message: str) -> Dict[str, Any]:
-        """Generate a response based on the message"""
-        try:
-            severity = self.assess_symptom_severity(message)
-            response = ""
-
-            # Retrieve relevant documents from FAISS
-            query_embedding = self.embedding_model.encode([message])
-            _, indices = self.index.search(query_embedding, k=5)
-            relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]
-
-            prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
-Consider relevant knowledge and the following documents:\n{relevant_docs}.
-Respond with empathy, follow-up questions, and care guidance."""
-
-            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
-            outputs = self.model.generate(inputs['input_ids'], max_new_tokens=100, temperature=0.7, do_sample=True)
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            follow_ups = self.generate_follow_up_questions(message, {})
-            response += f"\n{follow_ups[0]}"
+# CUDA and Memory Configurations
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+

+def prepare_initial_datasets(batch_size=8):
+    print("Loading datasets with memory-optimized batch processing...")
+
+    def process_medqa_batch(examples):
+        results = []
+        inputs = examples['input']
+        instructions = examples['instruction']
+        outputs = examples['output']
+
+        for inp, inst, out in zip(inputs, instructions, outputs):
+            results.append({
+                "input": f"{inp} {inst}",
+                "output": out
+            })
+        return results
+
+    def process_meddia_batch(examples):
+        results = []
+        inputs = examples['input']
+        outputs = examples['output']
+
+        for inp, out in zip(inputs, outputs):
+            results.append({
+                "input": inp,
+                "output": out
+            })
+        return results
+
+    def process_persona_batch(examples):
+        results = []
+        personalities = examples['personality']
+        utterances = examples['utterances']
+
+        for pers, utts in zip(personalities, utterances):
+            try:
+                # Process personality list
+                personality = ' '.join([
+                    p for p in pers
+                    if isinstance(p, str)
+                ])
+
+                # Process utterances
+                if utts and len(utts) > 0:
+                    utterance = utts[0]
+                    history = []
+
+                    # Process history
+                    if 'history' in utterance and utterance['history']:
+                        history = [
+                            h for h in utterance['history']
+                            if isinstance(h, str)
+                        ]
+
+                    history_text = ' '.join(history)
+
+                    # Get candidate response
+                    candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
+
+                    if personality or history_text:
+                        results.append({
+                            "input": f"{personality} {history_text}".strip(),
+                            "output": candidate
+                        })
+            except Exception as e:
+                print(f"Error processing persona batch item: {e}")
+                continue

+        return results
+
+    # Load and process each dataset separately
+    print("Processing MedQA dataset...")
+    medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
+    medqa_processed = []

+    for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
+        batch = medqa[i:i + batch_size]
+        medqa_processed.extend(process_medqa_batch(batch))
+        if i % (batch_size * 5) == 0:
+            torch.cuda.empty_cache()

+    print("Processing MedDiagnosis dataset...")
+    meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
+    meddia_processed = []

+    for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
+        batch = meddia[i:i + batch_size]
+        meddia_processed.extend(process_meddia_batch(batch))
+        if i % (batch_size * 5) == 0:
+            torch.cuda.empty_cache()

+    print("Processing Persona-Chat dataset...")
+    persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
+    persona_processed = []
+
+    for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
+        batch = persona[i:i + batch_size]
+        persona_processed.extend(process_persona_batch(batch))
+        if i % (batch_size * 5) == 0:
+            torch.cuda.empty_cache()
+
+    torch.cuda.empty_cache()
+
+    print("Creating final dataset...")
+    all_processed = persona_processed + medqa_processed + meddia_processed
+
+    valid_data = {
+        "input": [],
+        "output": []
+    }
+
+    for item in all_processed:
+        if item["input"].strip() and item["output"].strip():
+            valid_data["input"].append(item["input"])
+            valid_data["output"].append(item["output"])
+
+    final_dataset = Dataset.from_dict(valid_data)
+
+    print(f"Final dataset size: {len(final_dataset)}")
+    return final_dataset
+
+def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
+    def tokenize_batch(examples):
+        formatted_texts = []
+
+        for i in range(0, len(examples['input']), batch_size):
+            sub_batch_inputs = examples['input'][i:i + batch_size]
+            sub_batch_outputs = examples['output'][i:i + batch_size]

+            for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
+                try:
+                    formatted_text = f"""<start_of_turn>user
+{input_text}
+<end_of_turn>
+<start_of_turn>assistant
+{output_text}
+<end_of_turn>"""
+                    formatted_texts.append(formatted_text)
+                except Exception as e:
+                    print(f"Error formatting text: {e}")
+                    continue
+
+        tokenized = tokenizer(
+            formatted_texts,
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors=None
+        )
+
+        tokenized["labels"] = tokenized["input_ids"].copy()
+        return tokenized
+
+    print(f"Tokenizing dataset in small batches (size={batch_size})...")
+    tokenized_dataset = dataset.map(
+        tokenize_batch,
+        batched=True,
+        batch_size=batch_size,
+        remove_columns=dataset.column_names,
+        desc="Tokenizing dataset",
+        load_from_cache_file=False
+    )
+
+    return tokenized_dataset

+def setup_model_and_tokenizer(model_name="google/gemma-2b"):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    from transformers import BitsAndBytesConfig
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_8bit=True,
+        bnb_8bit_compute_dtype=torch.float16,
+        llm_int8_enable_fp32_cpu_offload=True
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="auto",
+        quantization_config=bnb_config,
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True
+    )
+
+    model = prepare_model_for_kbit_training(model)
+
+    lora_config = LoraConfig(
+        r=4,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM"
+    )
+
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
+    return model, tokenizer
+
+def setup_training_arguments(output_dir="./pearly_fine_tuned"):
+    return TrainingArguments(
+        output_dir=output_dir,
+        num_train_epochs=1,
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=16,
+        warmup_steps=50,
+        logging_steps=10,
+        save_steps=200,
+        learning_rate=2e-4,
+        fp16=True,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="adamw_8bit",
+        max_grad_norm=0.3,
+        weight_decay=0.001,
+        logging_dir="./logs",
+        save_total_limit=2,
+        remove_unused_columns=False,
+        dataloader_pin_memory=False,
+        max_steps=500,
+        report_to=["none"],
+    )
+
+def main():
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+
+    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+
+    print("Preparing initial datasets...")
+    combined_dataset = prepare_initial_datasets(batch_size=4)
+
+    print(f"\nDataset size: {len(combined_dataset)}")
+    print(f"Column names: {combined_dataset.column_names}")
+
+    if len(combined_dataset) > 0:
+        print("\nSample input-output pair:")
+        print(f"Input: {combined_dataset[0]['input'][:100]}...")
+        print(f"Output: {combined_dataset[0]['output'][:100]}...")
+
+    print("\nSetting up model and tokenizer...")
+    model, tokenizer = setup_model_and_tokenizer()
+
+    print("\nPreparing dataset for training...")
+    processed_dataset = prepare_dataset(
+        combined_dataset,
+        tokenizer,
+        max_length=256,
+        batch_size=2
+    )
+
+    torch.cuda.empty_cache()
+
+    training_args = setup_training_arguments()
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=processed_dataset,
+        tokenizer=tokenizer,
+    )
+
+    print("\nStarting training...")
+    try:
+        trainer.train()
+    except Exception as e:
+        print(f"Training error: {e}")
+        torch.cuda.empty_cache()
+        raise e
+    finally:
+        torch.cuda.empty_cache()
+
+    print("\nSaving model...")
+    trainer.save_model()
+    print("Training completed!")
+
+DISCLAIMER = """
+IMPORTANT MEDICAL DISCLAIMER:
+Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
+Pearly DOES NOT:
+- Make medical diagnoses
+- Prescribe medications
+- Provide specific treatment recommendations
+- Replace professional medical advice
+
+Always consult qualified healthcare professionals for medical advice and treatment.
+In case of emergency, call 999 immediately.
+"""
+
+class PearlyBot:
+    def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
+        print("Loading saved model...")
+        print(DISCLAIMER)
+
+        # Clean memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Load tokenizer and model directly from saved path
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            device_map="auto"
+        )
+
+        self.model.eval() # Set to evaluation mode
+
+        # Initialize RAG components
+        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
+        self.vector_store = None
+        self.conversation_history = []
+
+    def initialize_rag(self, documents_path="./knowledge_base"):
+        """Initialize RAG system"""
+        print("Loading knowledge base...")
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=300,
+            chunk_overlap=100,
+            separators=["\n\n", "\n", ".", "!", "?", ":"]
+        )
+
+        documents = []
+        for filename in os.listdir(documents_path):
+            if filename.endswith('.txt'):
+                loader = TextLoader(os.path.join(documents_path, filename))
+                documents.extend(loader.load())
+
+        texts = text_splitter.split_documents(documents)
+        self.vector_store = FAISS.from_documents(texts, self.embeddings)
+        self.retriever = self.vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs={"k": 5}
+        )
+        print("Knowledge base loaded successfully!")
+
+    def get_relevant_context(self, user_input):
+        if not self.retriever:
+            return ""
+        docs = self.retriever.get_relevant_documents(user_input)
+        return "\n\n".join([doc.page_content for doc in docs])
+
+    def generate_response(self, user_input):
+        context = self.get_relevant_context(user_input)
+        history = "\n".join([
+            f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
+            for turn in self.conversation_history[-3:]
+        ])
+
+        prompt = f"""<start_of_turn>system
+As Pearly, I use the following medical guidelines to help triage patients:
+
+{context}
+
+Previous Conversation:
+{history}
+
+Based on these guidelines, I will:
+1. Assess symptoms and severity
+2. Ask relevant follow-up questions
+3. Direct to appropriate care (999, 111, or GP)
+4. Show empathy and cultural sensitivity
+5. Never diagnose or recommend treatments
+<end_of_turn>
+<start_of_turn>user
+{user_input}
+<end_of_turn>
+<start_of_turn>assistant"""
+
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=256,
+                min_new_tokens=20,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
+                repetition_penalty=1.2,
+                pad_token_id=self.tokenizer.pad_token_id
+            )
+
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response = response.split("<start_of_turn>assistant")[-1].strip()
+        if "<end_of_turn>" in response:
+            response = response.split("<end_of_turn>")[0].strip()
+
+        self.conversation_history.append({
+            "user": user_input,
+            "assistant": response
+        })
+
+        return response

 def create_demo():
     """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
@@ -607,17 +828,8 @@ def create_demo():
         raise

 if __name__ == "__main__":
-    #
-    # Set up Hugging Face login if token exists
-    hf_token = os.getenv("HF_TOKEN")
-    if hf_token:
-        login(token=hf_token)
-    # Launch demo
-    os.environ.pop("HF_HUB_OFFLINE", None) # Ensure online mode
-    demo = create_demo()
+    load_dotenv() # Load environment variables
+    demo = create_demo() # Launch the Gradio app
     demo.launch(share=True)

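
For reference, a minimal usage sketch (not part of the commit) of the PearlyBot class this change introduces. It assumes the fine-tuned weights in ./pearly_fine_tuned and a ./knowledge_base folder of .txt guideline files already exist locally (both are the defaults defined in the class above), and the patient message is illustrative only.

# Usage sketch: exercising the PearlyBot class defined in app.py,
# assuming the fine-tuned model and knowledge base directories already exist.
from app import PearlyBot

bot = PearlyBot(model_path="./pearly_fine_tuned")      # loads the saved tokenizer and fine-tuned causal LM
bot.initialize_rag(documents_path="./knowledge_base")  # splits the .txt guidelines and builds the FAISS retriever
print(bot.generate_response("I've had a persistent headache for three days."))  # illustrative patient message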