Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,164 +1,271 @@
- # Standard
  import os
- import
- import time
- import asyncio
  import logging
- import gc
- import re
- import traceback
- from pathlib import Path
  from datetime import datetime
- from typing import List, Dict, Union, Tuple, Optional, Any
- from dataclasses import dataclass, field
- import zipfile
-
- # Machine Learning and Deep Learning Libraries
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.cuda.amp import autocast
- from torch.utils.data import DataLoader
-
- # Hugging Face and Transformers
- from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
- from sentence_transformers import SentenceTransformer
- from datasets import load_dataset, Dataset, concatenate_datasets
  from huggingface_hub import login
-
- # FAISS and PEFT
- import faiss
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel
-
- # LangChain - updated imports as per recent deprecations
- from langchain_community.vectorstores import FAISS  # Updated import
- from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
- from langchain_community.document_loaders import TextLoader  # Updated import
- from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-
- # External Tools and APIs
- import wandb
- import requests
- import gradio as gr
- import IPython.display as display  # Required for IPython display functionality
  from dotenv import load_dotenv
  from tqdm.auto import tqdm

- # Suppress Warnings
- import warnings
- warnings.filterwarnings('ignore')
-
-
- # Ensure Hugging Face login
- try:
-     hf_token = os.getenv("HF_TOKEN")
-     if hf_token:
-         login(token=hf_token)
-         print("Login successful!")
- except Exception as e:
-     print("Hugging Face Login failed:", e)
-
-
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
-
  # Setup logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

-
- class ModelManager:
-     """Handles model loading and resource management"""
-
-     @staticmethod
-     def verify_and_extract_model(checkpoint_zip_path: str, extracted_model_dir: str) -> str:
-         """Verify and extract the model if it's not already extracted"""
-         if not os.path.exists(extracted_model_dir):
-             # Unzip the model if it hasn't been extracted yet
-             with zipfile.ZipFile(checkpoint_zip_path, 'r') as zip_ref:
-                 zip_ref.extractall(extracted_model_dir)
-             logger.info(f"Extracted model to: {extracted_model_dir}")
-         else:
-             logger.info(f"Model already extracted: {extracted_model_dir}")
-
-         return extracted_model_dir

      @staticmethod
-     def clear_gpu_memory():
-         """Clear GPU memory"""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         gc.collect()
-
- class PearlyBot:
-     def __init__(self):
-         try:
-             # Use the correct model path from your space
-             self.repo_id = "Pearilsa/pearly_med_triage_chatbot_kagglex"
-             self.model_filename = "pearly_model.zip"
-             self.setup_model()
-             self.setup_rag()
-             self.conversation_history = []
-             self.last_interaction_time = time.time()
-             self.interaction_cooldown = 1.0
-         except Exception as e:
-             logger.error(f"Error initializing bot: {e}")
-             raise
-
-     def setup_model(self):
-         """Initialize model from Hugging Face space"""
          try:
-             self.model.eval()
-             logger.info("Model loaded successfully")

          except Exception as e:
-             logger.error(f"Error
              raise

      def setup_rag(self):
          try:
-
-             self.chunk_size = 300
-             self.chunk_overlap = 100
-             self.num_relevant_chunks = 3

              # Load knowledge base
              knowledge_base = self._load_knowledge_base()

-             # Setup embeddings
              self.embeddings = self._initialize_embeddings()

              texts = self._split_texts(knowledge_base)

              # Create vector store with metadata
@@ -168,13 +275,15 @@ class PearlyBot:
                  metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
              )

              self._validate_rag_setup()

          except Exception as e:
-             logger.error(f"
              raise
-
      def _load_knowledge_base(self):
          """Load and validate knowledge base content"""
          try:
@@ -488,6 +597,156 @@ class PearlyBot:
          except Exception as e:
              logger.error(f"RAG system validation failed: {str(e)}")
              raise

      def _initialize_embeddings(self):
          try:
@@ -525,24 +784,22 @@ class PearlyBot:
      def generate_response(self, message: str, history: list) -> str:
          """Generate response using both fine-tuned model and RAG"""
          try:
-             # Rate limiting
              current_time = time.time()
              if current_time - self.last_interaction_time < self.interaction_cooldown:
                  time.sleep(self.interaction_cooldown)

-
-
-             # Get RAG context
-             context = self.get_relevant_context(message)

              # Format conversation history
              conv_history = "\n".join([
-                 f"User: {
-                 for
              ])

-             # Create prompt
              prompt = f"""<start_of_turn>system
  Using these medical guidelines:

@@ -552,9 +809,9 @@ Previous conversation:
  {conv_history}

  Guidelines:
- 1. Assess symptoms and severity
- 2. Ask relevant follow-up questions
- 3. Direct to appropriate care (999, 111, or GP)
  4. Show empathy and cultural sensitivity
  5. Never diagnose or recommend treatments
  <end_of_turn>
@@ -563,41 +820,36 @@ Guidelines:
  <end_of_turn>
  <start_of_turn>assistant"""

-             # Generate response
-         except torch.cuda.OutOfMemoryError:
-             ModelManager.clear_gpu_memory()
-             logger.error("GPU out of memory, cleared cache and retrying...")
-             return "I apologize, but I'm experiencing technical difficulties. Please try again."
-
          except Exception as e:
-             logger.error(f"Error generating response: {e}")
              return "I apologize, but I encountered an error. Please try again."

      def handle_feedback(self, message: str, response: str, feedback: int):
@@ -971,23 +1223,10 @@ def create_demo():
          raise

  if __name__ == "__main__":
-
-     # Create and launch demo
-     demo = create_demo()
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True
-     )
-
-     except Exception as e:
-         logger.error(f"Application startup failed: {e}")
-         raise
-
+ # Standard imports first
  import os
+ import torch
+ import time  # needed for time.time() and time.sleep() used below
+ import gradio as gr  # still needed by create_demo() later in the file
  import logging
  from datetime import datetime
  from huggingface_hub import login
  from dotenv import load_dotenv
+ from datasets import load_dataset, Dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig
+ )
+ from peft import (
+     LoraConfig,
+     get_peft_model,
+     prepare_model_for_kbit_training
+ )
  from tqdm.auto import tqdm

  # Setup logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ class SecretsManager:
+     """Handles authentication and secrets management"""

      @staticmethod
+     def setup_credentials():
+         """Setup all required credentials"""
          try:
+             # Load environment variables
+             load_dotenv()

+             # Get credentials
+             credentials = {
+                 'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
+                 'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
+                 'HF_TOKEN': os.getenv('HF_TOKEN'),
+                 'WANDB_KEY': os.getenv('WANDB_KEY')
+             }

+             # Validate credentials
+             missing_creds = [k for k, v in credentials.items() if not v]
+             if missing_creds:
+                 logger.warning(f"Missing credentials: {', '.join(missing_creds)}")
+
+             # Setup Hugging Face authentication
+             if credentials['HF_TOKEN']:
+                 login(token=credentials['HF_TOKEN'])
+                 logger.info("Successfully logged in to Hugging Face")
+             # Setup Kaggle credentials if available
+             if credentials['KAGGLE_USERNAME'] and credentials['KAGGLE_KEY']:
+                 os.environ['KAGGLE_USERNAME'] = credentials['KAGGLE_USERNAME']
+                 os.environ['KAGGLE_KEY'] = credentials['KAGGLE_KEY']
+
+             # Setup wandb if available
+             if credentials['WANDB_KEY']:
+                 os.environ['WANDB_API_KEY'] = credentials['WANDB_KEY']
+
+             return credentials

          except Exception as e:
+             logger.error(f"Error setting up credentials: {e}")
              raise
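
Note: the diff does not show where SecretsManager.setup_credentials() is invoked; a minimal usage sketch, assuming it is called once at startup and that the .env keys match the dictionary above (everything else is illustrative):

    creds = SecretsManager.setup_credentials()  # reads .env, warns about any missing key
    if not creds.get('HF_TOKEN'):
        logger.warning("No HF_TOKEN found; pushing to the Hub will be skipped")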
+ class ModelTrainer:
+     """Handles model training pipeline"""
+
+     def __init__(self):
+         # Set memory optimization environment variables
+         os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
+         os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+         # Initialize attributes
+         self.model = None
+         self.tokenizer = None
+         self.dataset = None
+         self.processed_dataset = None
+         self.chunk_size = 300
+         self.chunk_overlap = 100
+         self.num_relevant_chunks = 3
+         self.vector_store = None
+         self.embeddings = None
+         self.last_interaction_time = time.time()  # Add this
+         self.interaction_cooldown = 1.0  # Add this
+
+         # Setup GPU preferences
+         torch.backends.cuda.matmul.allow_tf32 = False
+         torch.backends.cudnn.allow_tf32 = False
+
+     def prepare_initial_datasets(self, batch_size=8):
+         print("Loading datasets with memory-optimized batch processing...")
+
+         def process_medqa_batch(examples):
+             results = []
+             inputs = examples['input']
+             instructions = examples['instruction']
+             outputs = examples['output']
+
+             for inp, inst, out in zip(inputs, instructions, outputs):
+                 results.append({
+                     "input": f"{inp} {inst}",
+                     "output": out
+                 })
+             return results
+
+         def process_meddia_batch(examples):
+             results = []
+             inputs = examples['input']
+             outputs = examples['output']
+
+             for inp, out in zip(inputs, outputs):
+                 results.append({
+                     "input": inp,
+                     "output": out
+                 })
+             return results
+
+         def process_persona_batch(examples):
+             results = []
+             personalities = examples['personality']
+             utterances = examples['utterances']
+
+             for pers, utts in zip(personalities, utterances):
+                 try:
+                     # Process personality list
+                     personality = ' '.join([
+                         p for p in pers
+                         if isinstance(p, str)
+                     ])
+
+                     # Process utterances
+                     if utts and len(utts) > 0:
+                         utterance = utts[0]
+                         history = []
+
+                         # Process history
+                         if 'history' in utterance and utterance['history']:
+                             history = [
+                                 h for h in utterance['history']
+                                 if isinstance(h, str)
+                             ]
+
+                         history_text = ' '.join(history)
+
+                         # Get candidate response
+                         candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
+
+                         if personality or history_text:
+                             results.append({
+                                 "input": f"{personality} {history_text}".strip(),
+                                 "output": candidate
+                             })
+                 except Exception as e:
+                     print(f"Error processing persona batch item: {e}")
+                     continue

+             return results
+         try:
+             # Load and process each dataset separately
+             print("Processing MedQA dataset...")
+             medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
+             medqa_processed = []
+
+             for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
+                 batch = medqa[i:i + batch_size]
+                 medqa_processed.extend(process_medqa_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             print("Processing MedDiagnosis dataset...")
+             meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
+             meddia_processed = []
+
+             for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
+                 batch = meddia[i:i + batch_size]
+                 meddia_processed.extend(process_meddia_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             print("Processing Persona-Chat dataset...")
+             persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
+             persona_processed = []
+
+             for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
+                 batch = persona[i:i + batch_size]
+                 persona_processed.extend(process_persona_batch(batch))
+                 if i % (batch_size * 5) == 0:
+                     torch.cuda.empty_cache()
+
+             torch.cuda.empty_cache()
+
+             print("Creating final dataset...")
+             all_processed = persona_processed + medqa_processed + meddia_processed
+
+             valid_data = {
+                 "input": [],
+                 "output": []
+             }
+
+             for item in all_processed:
+                 if item["input"].strip() and item["output"].strip():
+                     valid_data["input"].append(item["input"])
+                     valid_data["output"].append(item["output"])
+
+             final_dataset = Dataset.from_dict(valid_data)
+
+             print(f"Final dataset size: {len(final_dataset)}")
+             return final_dataset
+         except Exception as e:  # handler added so the try block above is closed
+             logger.error(f"Error preparing datasets: {e}")
+             raise

+     def prepare_dataset(self, dataset, tokenizer, max_length=256, batch_size=4):
+         def tokenize_batch(examples):
+             formatted_texts = []
+
+             for i in range(0, len(examples['input']), batch_size):
+                 sub_batch_inputs = examples['input'][i:i + batch_size]
+                 sub_batch_outputs = examples['output'][i:i + batch_size]
+
+                 for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
+                     try:
+                         formatted_text = f"""<start_of_turn>user
+ {input_text}
+ <end_of_turn>
+ <start_of_turn>assistant
+ {output_text}
+ <end_of_turn>"""
+                         formatted_texts.append(formatted_text)
+                     except Exception as e:
+                         print(f"Error formatting text: {e}")
+                         continue
+
+             tokenized = tokenizer(
+                 formatted_texts,
+                 padding="max_length",
+                 truncation=True,
+                 max_length=max_length,
+                 return_tensors=None
+             )
+
+             tokenized["labels"] = tokenized["input_ids"].copy()
+             return tokenized
+
+         print(f"Tokenizing dataset in small batches (size={batch_size})...")
+         tokenized_dataset = dataset.map(
+             tokenize_batch,
+             batched=True,
+             batch_size=batch_size,
+             remove_columns=dataset.column_names,
+             desc="Tokenizing dataset",
+             load_from_cache_file=False
+         )
+
+         return tokenized_dataset
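
The `tokenized["labels"] = tokenized["input_ids"].copy()` line works because Hugging Face causal-LM models shift the labels internally, so passing labels identical to the inputs yields the standard next-token cross-entropy. A small illustrative check, assuming `model` and `tokenizer` are the objects returned by setup_model_and_tokenizer() below:

    batch = tokenizer("Patient reports chest pain", return_tensors="pt").to(model.device)
    out = model(**batch, labels=batch["input_ids"])  # loss is computed over shifted next tokens
    print(out.loss.item())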
+
      def setup_rag(self):
+         """Initialize RAG components"""
          try:
+             logger.info("Setting up RAG system...")

              # Load knowledge base
              knowledge_base = self._load_knowledge_base()

+             # Setup embeddings
              self.embeddings = self._initialize_embeddings()

+             # Process texts for vector store
              texts = self._split_texts(knowledge_base)

              # Create vector store with metadata
                  metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
              )

+             # Validate RAG setup
              self._validate_rag_setup()
+             logger.info("RAG system setup complete")

          except Exception as e:
+             logger.error(f"Failed to setup RAG: {e}")
              raise
+
+     # Load your knowledge base content
      def _load_knowledge_base(self):
          """Load and validate knowledge base content"""
          try:

          except Exception as e:
              logger.error(f"RAG system validation failed: {str(e)}")
              raise
+
+
+     def setup_model_and_tokenizer(self, model_name="google/gemma-2b"):
+         # Defined as a ModelTrainer method so train() can call it via self
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         tokenizer.pad_token = tokenizer.eos_token
+
+         from transformers import BitsAndBytesConfig
+
+         bnb_config = BitsAndBytesConfig(
+             load_in_8bit=True,
+             bnb_8bit_compute_dtype=torch.float16,
+             llm_int8_enable_fp32_cpu_offload=True
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             device_map="auto",
+             quantization_config=bnb_config,
+             torch_dtype=torch.float16,
+             low_cpu_mem_usage=True
+         )
+
+         model = prepare_model_for_kbit_training(model)
+
+         lora_config = LoraConfig(
+             r=4,
+             lora_alpha=16,
+             target_modules=["q_proj", "v_proj"],
+             lora_dropout=0.05,
+             bias="none",
+             task_type="CAUSAL_LM"
+         )
+
+         model = get_peft_model(model, lora_config)
+         model.print_trainable_parameters()
+
+         return model, tokenizer
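
Only the LoRA adapter weights (rank 4 on q_proj and v_proj) are trained on top of the frozen 8-bit base model. A minimal sketch of how the saved adapter could later be reloaded for inference, assuming trainer.save_model() wrote it to ./pearly_fine_tuned (path and dtype choices here are illustrative):

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b", device_map="auto", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, "./pearly_fine_tuned")  # attach the LoRA adapter
    model.eval()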
+
+     def setup_training_arguments(self, output_dir="./pearly_fine_tuned"):
+         return TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=1,
+             per_device_train_batch_size=1,
+             gradient_accumulation_steps=16,
+             warmup_steps=50,
+             logging_steps=10,
+             save_steps=200,
+             learning_rate=2e-4,
+             fp16=True,
+             gradient_checkpointing=True,
+             gradient_checkpointing_kwargs={"use_reentrant": False},
+             optim="adamw_8bit",
+             max_grad_norm=0.3,
+             weight_decay=0.001,
+             logging_dir="./logs",
+             save_total_limit=2,
+             remove_unused_columns=False,
+             dataloader_pin_memory=False,
+             max_steps=500,
+             report_to=["none"],
+         )
+
+     def train(self):
+         """Main training pipeline with RAG integration"""
+         try:
+             logger.info("Starting training pipeline")
+
+             # Clear GPU memory
+             torch.cuda.empty_cache()
+             if torch.cuda.is_available():
+                 torch.cuda.reset_peak_memory_stats()
+
+             # Setup model, tokenizer, and RAG
+             logger.info("Setting up model components...")
+             self.model, self.tokenizer = self.setup_model_and_tokenizer()
+             self.setup_rag()
+
+             # Prepare and process datasets
+             logger.info("Preparing datasets...")
+             self.dataset = self.prepare_initial_datasets(batch_size=4)
+             self.processed_dataset = self.prepare_dataset(
+                 self.dataset,
+                 self.tokenizer,
+                 max_length=256,
+                 batch_size=2
+             )
+
+             # Train model
+             logger.info("Starting training...")
+             training_args = self.setup_training_arguments()
+             trainer = Trainer(
+                 model=self.model,
+                 args=training_args,
+                 train_dataset=self.processed_dataset,
+                 tokenizer=self.tokenizer
+             )
+             trainer.train()
+
+             # Save and push to hub
+             logger.info("Saving model...")
+             trainer.save_model()
+             if os.getenv('HF_TOKEN'):
+                 trainer.push_to_hub(
+                     "Pearilsa/pearly_med_triage_chatbot_kagglex",
+                     private=True
+                 )
+
+             logger.info("Training completed successfully!")
+
+         except Exception as e:
+             logger.error(f"Training failed: {e}")
+             raise
+         finally:
+             torch.cuda.empty_cache()
+
+ if __name__ == "__main__":
+     # Initialize trainer
+     trainer = ModelTrainer()
+
+     # Train model
+     trainer.train()
+
+     def _get_enhanced_context(self, query: str) -> str:
+         """Get relevant context with scores"""
+         try:
+             # Get documents with similarity scores
+             docs_and_scores = self.vector_store.similarity_search_with_score(
+                 query,
+                 k=self.num_relevant_chunks
+             )
+
+             # Filter and format relevant contexts
+             relevant_contexts = []
+             for doc, score in docs_and_scores:
+                 if score < 0.8:  # Lower score means more relevant
+                     source = doc.metadata.get('source', 'Unknown')
+                     relevant_contexts.append(
+                         f"[Source: {source}]\n{doc.page_content}"
+                     )
+
+             return "\n\n".join(relevant_contexts) if relevant_contexts else ""
+
+         except Exception as e:
+             logger.error(f"Error retrieving enhanced context: {e}")
+             return ""

      def _initialize_embeddings(self):
          try:

      def generate_response(self, message: str, history: list) -> str:
          """Generate response using both fine-tuned model and RAG"""
          try:
+             # Rate limiting and memory management
              current_time = time.time()
              if current_time - self.last_interaction_time < self.interaction_cooldown:
                  time.sleep(self.interaction_cooldown)
+             torch.cuda.empty_cache()

+             # Get enhanced context from RAG
+             context = self._get_enhanced_context(message)

              # Format conversation history
              conv_history = "\n".join([
+                 f"User: {turn['input']}\nAssistant: {turn['output']}"
+                 for turn in history[-3:]  # Keep last 3 turns
              ])

+             # Create enhanced prompt with RAG context
              prompt = f"""<start_of_turn>system
  Using these medical guidelines:

  {conv_history}

  Guidelines:
+ 1. Assess symptoms and severity based on both your training and the provided guidelines
+ 2. Ask relevant follow-up questions if needed
+ 3. Direct to appropriate care (999, 111, or GP) according to symptom severity
  4. Show empathy and cultural sensitivity
  5. Never diagnose or recommend treatments
  <end_of_turn>

  <end_of_turn>
  <start_of_turn>assistant"""

+             # Generate response with model
+             inputs = self.tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512
+             ).to(self.model.device)
+
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=256,
+                 min_new_tokens=20,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 repetition_penalty=1.2,
+                 no_repeat_ngram_size=3
+             )
+
+             # Process response
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             response = response.split("<start_of_turn>assistant")[-1].strip()
+             if "<end_of_turn>" in response:
+                 response = response.split("<end_of_turn>")[0].strip()
+
+             self.last_interaction_time = time.time()
+             return response
+
          except Exception as e:
+             logger.error(f"Error generating response: {e}")
              return "I apologize, but I encountered an error. Please try again."

      def handle_feedback(self, message: str, response: str, feedback: int):

          raise

  if __name__ == "__main__":
+     # Initialize logging and load env vars
+     logging.basicConfig(level=logging.INFO)
+     load_dotenv()
+
+     # Create and launch demo
+     demo = create_demo()
+     demo.launch(server_name="0.0.0.0", server_port=7860)
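
`create_demo()` is referenced here but its body is unchanged and not shown in this diff. A minimal sketch of how such a function might wire generate_response into a Gradio chat UI; the PearlyBot constructor, the history conversion, and the title are assumptions, and the real Space may differ:

    import gradio as gr

    def create_demo():
        bot = PearlyBot()  # assumption: the chatbot class defined earlier in app.py

        def respond(message, history):
            # adapt Gradio's (user, assistant) pairs to the bot's input/output turn format
            turns = [{"input": u, "output": a} for u, a in history]
            return bot.generate_response(message, turns)

        return gr.ChatInterface(fn=respond, title="Pearly Medical Triage Chatbot")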