# Standard imports first
import os
import time
import logging
from datetime import datetime

import torch
import gradio as gr
from huggingface_hub import login
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from tqdm.auto import tqdm

# RAG dependencies (import paths assume a pre-0.2 LangChain layout; newer
# releases move these into langchain_community / langchain_text_splitters)
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SecretsManager:
    """Handles authentication and secrets management"""

    @staticmethod
    def setup_credentials():
        """Setup all required credentials"""
        try:
            # Load environment variables
            load_dotenv()

            # Get credentials
            credentials = {
                'KAGGLE_USERNAME': os.getenv('KAGGLE_USERNAME'),
                'KAGGLE_KEY': os.getenv('KAGGLE_KEY'),
                'HF_TOKEN': os.getenv('HF_TOKEN'),
                'WANDB_KEY': os.getenv('WANDB_KEY')
            }

            # Validate credentials
            missing_creds = [k for k, v in credentials.items() if not v]
            if missing_creds:
                logger.warning(f"Missing credentials: {', '.join(missing_creds)}")

            # Setup Hugging Face authentication
            if credentials['HF_TOKEN']:
                login(token=credentials['HF_TOKEN'])
                logger.info("Successfully logged in to Hugging Face")

            # Setup Kaggle credentials if available
            if credentials['KAGGLE_USERNAME'] and credentials['KAGGLE_KEY']:
                os.environ['KAGGLE_USERNAME'] = credentials['KAGGLE_USERNAME']
                os.environ['KAGGLE_KEY'] = credentials['KAGGLE_KEY']

            # Setup wandb if available
            if credentials['WANDB_KEY']:
                os.environ['WANDB_API_KEY'] = credentials['WANDB_KEY']

            return credentials
        except Exception as e:
            logger.error(f"Error setting up credentials: {e}")
            raise
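
# Example .env file consumed by SecretsManager (illustrative placeholder
# values only; the key names are the ones read above):
#
#   KAGGLE_USERNAME=your_kaggle_username
#   KAGGLE_KEY=your_kaggle_api_key
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   WANDB_KEY=your_wandb_api_key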
class ModelTrainer:
    """Handles model training pipeline"""

    def __init__(self):
        # Set memory optimization environment variables
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = (
            'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
        )
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

        # Initialize attributes
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.processed_dataset = None
        self.chunk_size = 300
        self.chunk_overlap = 100
        self.num_relevant_chunks = 3
        self.vector_store = None
        self.embeddings = None
        self.last_interaction_time = time.time()
        self.interaction_cooldown = 1.0  # seconds between generations

        # Setup GPU preferences
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    def prepare_initial_datasets(self, batch_size=8):
        print("Loading datasets with memory-optimized batch processing...")

        def process_medqa_batch(examples):
            results = []
            inputs = examples['input']
            instructions = examples['instruction']
            outputs = examples['output']
            for inp, inst, out in zip(inputs, instructions, outputs):
                results.append({
                    "input": f"{inp} {inst}",
                    "output": out
                })
            return results

        def process_meddia_batch(examples):
            results = []
            inputs = examples['input']
            outputs = examples['output']
            for inp, out in zip(inputs, outputs):
                results.append({
                    "input": inp,
                    "output": out
                })
            return results

        def process_persona_batch(examples):
            results = []
            personalities = examples['personality']
            utterances = examples['utterances']
            for pers, utts in zip(personalities, utterances):
                try:
                    # Process personality list
                    personality = ' '.join([
                        p for p in pers
                        if isinstance(p, str)
                    ])

                    # Process utterances
                    if utts and len(utts) > 0:
                        utterance = utts[0]
                        history = []

                        # Process history
                        if 'history' in utterance and utterance['history']:
                            history = [
                                h for h in utterance['history']
                                if isinstance(h, str)
                            ]
                        history_text = ' '.join(history)

                        # Get candidate response
                        candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''

                        if personality or history_text:
                            results.append({
                                "input": f"{personality} {history_text}".strip(),
                                "output": candidate
                            })
                except Exception as e:
                    print(f"Error processing persona batch item: {e}")
                    continue
            return results
        try:
            # Load and process each dataset separately
            print("Processing MedQA dataset...")
            medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
            medqa_processed = []
            for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
                batch = medqa[i:i + batch_size]
                medqa_processed.extend(process_medqa_batch(batch))
                if i % (batch_size * 5) == 0:
                    torch.cuda.empty_cache()

            print("Processing MedDiagnosis dataset...")
            meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
            meddia_processed = []
            for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
                batch = meddia[i:i + batch_size]
                meddia_processed.extend(process_meddia_batch(batch))
                if i % (batch_size * 5) == 0:
                    torch.cuda.empty_cache()

            print("Processing Persona-Chat dataset...")
            persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
            persona_processed = []
            for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
                batch = persona[i:i + batch_size]
                persona_processed.extend(process_persona_batch(batch))
                if i % (batch_size * 5) == 0:
                    torch.cuda.empty_cache()

            torch.cuda.empty_cache()

            print("Creating final dataset...")
            all_processed = persona_processed + medqa_processed + meddia_processed

            valid_data = {
                "input": [],
                "output": []
            }
            for item in all_processed:
                if item["input"].strip() and item["output"].strip():
                    valid_data["input"].append(item["input"])
                    valid_data["output"].append(item["output"])

            final_dataset = Dataset.from_dict(valid_data)
            print(f"Final dataset size: {len(final_dataset)}")
            return final_dataset
        except Exception as e:
            logger.error(f"Error preparing datasets: {e}")
            raise
    def prepare_dataset(self, dataset, tokenizer, max_length=256, batch_size=4):
        def tokenize_batch(examples):
            formatted_texts = []
            for i in range(0, len(examples['input']), batch_size):
                sub_batch_inputs = examples['input'][i:i + batch_size]
                sub_batch_outputs = examples['output'][i:i + batch_size]
                for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
                    try:
                        # Gemma-style turn markers, built without leading
                        # indentation so no stray spaces are tokenized
                        formatted_text = (
                            f"<start_of_turn>user\n{input_text}\n<end_of_turn>\n"
                            f"<start_of_turn>assistant\n{output_text}\n<end_of_turn>"
                        )
                        formatted_texts.append(formatted_text)
                    except Exception as e:
                        print(f"Error formatting text: {e}")
                        continue

            tokenized = tokenizer(
                formatted_texts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors=None
            )
            # Causal LM training: labels are a copy of the input ids
            tokenized["labels"] = tokenized["input_ids"].copy()
            return tokenized

        print(f"Tokenizing dataset in small batches (size={batch_size})...")
        tokenized_dataset = dataset.map(
            tokenize_batch,
            batched=True,
            batch_size=batch_size,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset",
            load_from_cache_file=False
        )
        return tokenized_dataset
    def setup_rag(self):
        """Initialize RAG components"""
        try:
            logger.info("Setting up RAG system...")

            # Load knowledge base
            knowledge_base = self._load_knowledge_base()

            # Setup embeddings
            self.embeddings = self._initialize_embeddings()

            # Process texts for vector store
            texts = self._split_texts(knowledge_base)

            # Create vector store with metadata
            self.vector_store = FAISS.from_texts(
                texts,
                self.embeddings,
                metadatas=[{"source": f"chunk_{i}"} for i in range(len(texts))]
            )

            # Validate RAG setup
            self._validate_rag_setup()
            logger.info("RAG system setup complete")
        except Exception as e:
            logger.error(f"Failed to setup RAG: {e}")
            raise
    # Load your knowledge base content
    def _load_knowledge_base(self):
        """Load and validate knowledge base content"""
        try:
            knowledge_base = {
                "triage_scenarios.txt": """Medical Triage Scenarios and Responses:
EMERGENCY (999) SCENARIOS:
1. Cardiovascular:
- Chest pain/pressure
- Heart attack symptoms
- Irregular heartbeat with dizziness
Response: Immediate 999 call, sit/lie down, chew aspirin if available
2. Respiratory:
- Severe breathing difficulty
- Choking
- Unable to speak full sentences
Response: 999, sitting position, clear airway
3. Neurological:
- Stroke symptoms (FAST)
- Seizures
- Unconsciousness
Response: 999, recovery position if unconscious
4. Trauma:
- Severe bleeding
- Head injuries with confusion
- Major burns
Response: 999, apply direct pressure to bleeding
URGENT CARE (111) SCENARIOS:
1. Moderate Symptoms:
- Persistent fever
- Non-severe infections
- Minor injuries
Response: 111 contact, monitor symptoms
2. Minor Emergencies:
- Small cuts needing stitches
- Sprains and strains
- Mild allergic reactions
Response: 111 or urgent care visit
GP APPOINTMENT SCENARIOS:
1. Routine Care:
- Chronic condition review
- Medication reviews
- Non-urgent symptoms
Response: Book routine GP appointment
2. Preventive Care:
- Vaccinations
- Health screenings
- Regular check-ups
Response: Schedule with GP reception""",
                "emergency_detection.txt": """Enhanced Emergency Detection Criteria:
IMMEDIATE LIFE THREATS:
1. Cardiac Symptoms:
- Chest pain/pressure/tightness
- Pain spreading to arms/jaw/neck
- Sweating with nausea
- Shortness of breath
2. Breathing Problems:
- Severe shortness of breath
- Blue lips or face
- Unable to complete sentences
- Choking/airway blockage
3. Neurological:
- FAST (Face, Arms, Speech, Time)
- Sudden confusion
- Severe headache
- Seizures
- Loss of consciousness
4. Severe Trauma:
- Heavy bleeding
- Deep wounds
- Head injury with confusion
- Severe burns
- Broken bones with deformity
5. Anaphylaxis:
- Sudden swelling
- Difficulty breathing
- Rapid onset rash
- Light-headedness
URGENT BUT NOT IMMEDIATE:
1. Moderate Symptoms:
- Persistent fever
- Dehydration
- Non-severe infections
- Minor injuries
2. Worsening Conditions:
- Increasing pain
- Progressive symptoms
- Medication reactions
RESPONSE PROTOCOLS:
1. For Life Threats:
- Immediate 999 call
- Clear first aid instructions
- Stay on line until help arrives
2. For Urgent Care:
- 111 contact
- Monitor for worsening
- Document symptoms""",
                "gp_booking.txt": """GP Appointment Booking Templates:
APPOINTMENT TYPES:
1. Routine Appointments:
Template: "I need to book a routine appointment for [condition]. My availability is [times/dates]. My GP is Dr. [name] if available."
2. Follow-up Appointments:
Template: "I need a follow-up appointment regarding [condition] discussed on [date]. My previous appointment was with Dr. [name]."
3. Medication Reviews:
Template: "I need a medication review for [medication]. My last review was [date]."
BOOKING INFORMATION NEEDED:
1. Patient Details:
- Full name
- Date of birth
- NHS number (if known)
- Registered GP practice
2. Appointment Details:
- Nature of appointment
- Preferred times/dates
- Urgency level
- Special requirements
3. Contact Information:
- Phone number
- Alternative contact
- Preferred contact method
BOOKING PROCESS:
1. Online Booking:
- NHS app instructions
- Practice website guidance
- System navigation help
2. Phone Booking:
- Best times to call
- Required information
- Queue management tips
3. Special Circumstances:
- Interpreter needs
- Accessibility requirements
- Transport arrangements""",
                "cultural_sensitivity.txt": """Cultural Sensitivity Guidelines:
CULTURAL AWARENESS:
1. Religious Considerations:
- Prayer times
- Religious observations
- Dietary restrictions
- Gender preferences for care
- Religious festivals/fasting periods
2. Language Support:
- Interpreter services
- Multi-language resources
- Clear communication methods
- Family involvement preferences
3. Cultural Beliefs:
- Traditional medicine practices
- Cultural health beliefs
- Family decision-making
- Privacy customs
COMMUNICATION APPROACHES:
1. Respectful Interaction:
- Use preferred names/titles
- Appropriate greetings
- Non-judgmental responses
- Active listening
2. Language Usage:
- Clear, simple terms
- Avoid medical jargon
- Confirm understanding
- Respect silence/pauses
3. Non-verbal Communication:
- Eye contact customs
- Personal space
- Body language awareness
- Gesture sensitivity
SPECIFIC CONSIDERATIONS:
1. South Asian Communities:
- Family involvement
- Gender sensitivity
- Traditional medicine
- Language diversity
2. Middle Eastern Communities:
- Gender-specific care
- Religious observations
- Family hierarchies
- Privacy concerns
3. African/Caribbean Communities:
- Traditional healers
- Community involvement
- Historical medical mistrust
- Cultural specific conditions
4. Eastern European Communities:
- Direct communication
- Family involvement
- Medical documentation
- Language support
INCLUSIVE PRACTICES:
1. Appointment Scheduling:
- Religious holidays
- Prayer times
- Family availability
- Interpreter needs
2. Treatment Planning:
- Cultural preferences
- Traditional practices
- Family involvement
- Dietary requirements
3. Support Services:
- Community resources
- Cultural organizations
- Language services
- Social support""",
                "service_boundaries.txt": """Service Limitations and Professional Boundaries:
CLEAR BOUNDARIES:
1. Medical Advice:
- No diagnoses
- No prescriptions
- No treatment recommendations
- No medical procedures
- No second opinions
2. Emergency Services:
- Clear referral criteria
- Documented responses
- Follow-up protocols
- Handover procedures
3. Information Sharing:
- Confidentiality limits
- Data protection
- Record keeping
- Information governance
PROFESSIONAL CONDUCT:
1. Communication:
- Professional language
- Emotional boundaries
- Personal distance
- Service scope
2. Service Delivery:
- No financial transactions
- No personal relationships
- Clear role definition
- Professional limits"""
            }

            # Create knowledge base directory
            os.makedirs("knowledge_base", exist_ok=True)

            # Write files and process documents
            documents = []
            for filename, content in knowledge_base.items():
                filepath = os.path.join("knowledge_base", filename)
                with open(filepath, "w", encoding="utf-8") as f:
                    f.write(content)
                documents.append(content)
                logger.info(f"Written knowledge base file: {filename}")

            return knowledge_base
        except Exception as e:
            logger.error(f"Error loading knowledge base: {str(e)}")
            raise
    def _validate_rag_setup(self):
        """Validate RAG system setup"""
        try:
            # Verify embeddings are working; embed_query() is the
            # HuggingFaceEmbeddings API (encode() belongs to the raw
            # sentence-transformers model)
            test_text = "This is a test embedding"
            test_embedding = self.embeddings.embed_query(test_text)
            assert len(test_embedding) > 0

            # Verify vector store is operational
            test_results = self.vector_store.similarity_search(test_text, k=1)
            assert len(test_results) > 0

            logger.info("RAG system validation successful")
            return True
        except Exception as e:
            logger.error(f"RAG system validation failed: {str(e)}")
            raise
    def setup_model_and_tokenizer(self, model_name="google/gemma-2b"):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        # 8-bit quantization, with fp32 CPU offload for modules that do not
        # fit on the GPU
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )
        model = prepare_model_for_kbit_training(model)

        # Low-rank adapters on the attention projections only
        lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        return model, tokenizer
    def setup_training_arguments(self, output_dir="./pearly_fine_tuned"):
        return TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=16,
            warmup_steps=50,
            logging_steps=10,
            save_steps=200,
            learning_rate=2e-4,
            fp16=True,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            optim="adamw_bnb_8bit",  # transformers' name for bitsandbytes 8-bit AdamW
            max_grad_norm=0.3,
            weight_decay=0.001,
            logging_dir="./logs",
            save_total_limit=2,
            remove_unused_columns=False,
            dataloader_pin_memory=False,
            max_steps=500,
            report_to=["none"],
        )
    def train(self):
        """Main training pipeline with RAG integration"""
        try:
            logger.info("Starting training pipeline")

            # Clear GPU memory
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            # Setup model, tokenizer, and RAG
            logger.info("Setting up model components...")
            self.model, self.tokenizer = self.setup_model_and_tokenizer()
            self.setup_rag()

            # Prepare and process datasets
            logger.info("Preparing datasets...")
            self.dataset = self.prepare_initial_datasets(batch_size=4)
            self.processed_dataset = self.prepare_dataset(
                self.dataset,
                self.tokenizer,
                max_length=256,
                batch_size=2
            )

            # Train model
            logger.info("Starting training...")
            training_args = self.setup_training_arguments()
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.processed_dataset,
                tokenizer=self.tokenizer
            )
            trainer.train()

            # Save and push to hub
            logger.info("Saving model...")
            trainer.save_model()
            if os.getenv('HF_TOKEN'):
                trainer.push_to_hub(
                    "Pearilsa/pearly_med_triage_chatbot_kagglex",
                    private=True
                )

            logger.info("Training completed successfully!")
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
        finally:
            torch.cuda.empty_cache()
    # Training entry point (run separately from the Gradio demo below):
    #   trainer = ModelTrainer()
    #   trainer.train()
    def _get_enhanced_context(self, query: str) -> str:
        """Get relevant context with scores"""
        try:
            # Get documents with similarity scores
            docs_and_scores = self.vector_store.similarity_search_with_score(
                query,
                k=self.num_relevant_chunks
            )

            # Filter and format relevant contexts (FAISS returns distances,
            # so a lower score means a closer match)
            relevant_contexts = []
            for doc, score in docs_and_scores:
                if score < 0.8:
                    source = doc.metadata.get('source', 'Unknown')
                    relevant_contexts.append(
                        f"[Source: {source}]\n{doc.page_content}"
                    )

            return "\n\n".join(relevant_contexts) if relevant_contexts else ""
        except Exception as e:
            logger.error(f"Error retrieving enhanced context: {e}")
            return ""
    def _initialize_embeddings(self):
        try:
            return HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                cache_folder="./embeddings_cache"  # cache downloaded weights
            )
        except Exception as e:
            logger.error(f"Failed to initialize embeddings: {str(e)}")
            raise
    def _split_texts(self, knowledge_base):
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            add_start_index=True
        )
        all_texts = []
        for content in knowledge_base.values():
            texts = splitter.split_text(content)
            all_texts.extend(texts)
        return all_texts
    def get_relevant_context(self, query):
        try:
            docs = self.vector_store.similarity_search(query, k=3)
            return "\n".join(doc.page_content for doc in docs)
        except Exception as e:
            logger.error(f"Error retrieving context: {str(e)}")
            return ""
    def generate_response(self, message: str, history: list) -> str:
        """Generate response using both fine-tuned model and RAG"""
        try:
            # Rate limiting and memory management
            current_time = time.time()
            if current_time - self.last_interaction_time < self.interaction_cooldown:
                time.sleep(self.interaction_cooldown)
            torch.cuda.empty_cache()

            # Get enhanced context from RAG
            context = self._get_enhanced_context(message)

            # Format conversation history; the Gradio demo passes messages as
            # {"role", "content"} dicts, so keep the last 6 entries (about 3
            # user/assistant exchanges)
            conv_history = "\n".join(
                f"{turn['role'].capitalize()}: {turn['content']}"
                for turn in history[-6:]
                if isinstance(turn, dict) and "content" in turn
            )

            # Create enhanced prompt with RAG context (flush-left so the turn
            # markers are not indented in the prompt)
            prompt = f"""<start_of_turn>system
Using these medical guidelines:
{context}
Previous conversation:
{conv_history}
Guidelines:
1. Assess symptoms and severity based on both your training and the provided guidelines
2. Ask relevant follow-up questions if needed
3. Direct to appropriate care (999, 111, or GP) according to symptom severity
4. Show empathy and cultural sensitivity
5. Never diagnose or recommend treatments
<end_of_turn>
<start_of_turn>user
{message}
<end_of_turn>
<start_of_turn>assistant"""

            # Generate response with model
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to(self.model.device)

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                min_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3
            )

            # Process response: keep only the assistant's final turn
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response.split("<start_of_turn>assistant")[-1].strip()
            if "<end_of_turn>" in response:
                response = response.split("<end_of_turn>")[0].strip()

            self.last_interaction_time = time.time()
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "I apologize, but I encountered an error. Please try again."
    def handle_feedback(self, message: str, response: str, feedback: int):
        """Handle user feedback for responses"""
        try:
            timestamp = datetime.now().isoformat()
            feedback_data = {
                "message": message,
                "response": response,
                "feedback": feedback,
                "timestamp": timestamp
            }

            # Log feedback
            logger.info(f"Feedback received: {feedback_data}")

            # Here you could:
            # 1. Store feedback in a database
            # 2. Send to monitoring system
            # 3. Use for model improvements
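            # For instance, a minimal JSONL sink would look like this
            # (illustrative sketch only; "feedback_log.jsonl" is an assumed
            # filename, and json would need importing at the top):
            #   with open("feedback_log.jsonl", "a", encoding="utf-8") as f:
            #       f.write(json.dumps(feedback_data) + "\n")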
            return True
        except Exception as e:
            logger.error(f"Error handling feedback: {e}")
            return False
    def __del__(self):
        """Cleanup resources"""
        try:
            if hasattr(self, 'model'):
                del self.model
            # Free cached GPU memory (the original called an undefined
            # ModelManager.clear_gpu_memory() helper here)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error in cleanup: {e}")
def create_demo():
    try:
        # Initialize bot. PearlyBot is assumed to be the inference wrapper
        # around the fine-tuned model (exposing generate_response() and
        # handle_feedback()); it is not defined in this file.
        bot = PearlyBot()

        def chat(message: str, history: list):
            """Handle chat interactions"""
            try:
                if not message.strip():
                    return history

                response = bot.generate_response(message, history)
                history.append({
                    "role": "user",
                    "content": message
                })
                history.append({
                    "role": "assistant",
                    "content": response
                })
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                return history + [{
                    "role": "assistant",
                    "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
                }]

        def process_feedback(positive: bool, comment: str, history: list):
            try:
                if not history or len(history) < 2:
                    return gr.update(value="")

                last_user_msg = history[-2]["content"] if isinstance(history[-2], dict) else history[-2][0]
                last_bot_msg = history[-1]["content"] if isinstance(history[-1], dict) else history[-1][1]
                bot.handle_feedback(
                    message=last_user_msg,
                    response=last_bot_msg,
                    feedback=1 if positive else -1
                )
                return gr.update(value="")
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")
                return gr.update(value="")
        # Create enhanced Gradio interface
        with gr.Blocks(theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="indigo",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter")
        )) as demo:
            # Custom CSS for enhanced styling
            gr.HTML("""
            <style>
                .container { max-width: 900px; margin: auto; }
                .header { text-align: center; padding: 20px; }
                .emergency-banner {
                    background-color: #ff4444;
                    color: white;
                    padding: 10px;
                    text-align: center;
                    font-weight: bold;
                    margin-bottom: 20px;
                }
                .feature-card {
                    padding: 15px;
                    border-radius: 10px;
                    text-align: center;
                    transition: transform 0.2s;
                    color: white;
                    font-weight: bold;
                }
                .feature-card:nth-child(1) { background: linear-gradient(135deg, #2193b0, #6dd5ed); }
                .feature-card:nth-child(2) { background: linear-gradient(135deg, #834d9b, #d04ed6); }
                .feature-card:nth-child(3) { background: linear-gradient(135deg, #ff4b1f, #ff9068); }
                .feature-card:nth-child(4) { background: linear-gradient(135deg, #38ef7d, #11998e); }
                .feature-card:hover {
                    transform: translateY(-5px);
                    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
                }
                .feature-card span.emoji {
                    font-size: 2em;
                    display: block;
                    margin-bottom: 10px;
                }
                .message-textbox textarea { resize: none; }
                #thumb-up, #thumb-down {
                    min-width: 60px;
                    padding: 8px;
                    margin: 5px;
                }
                .chatbot-message {
                    padding: 12px;
                    margin: 8px 0;
                    border-radius: 8px;
                }
                .user-message { background-color: #e3f2fd; }
                .assistant-message { background-color: #f5f5f5; }
                .feedback-section {
                    margin-top: 20px;
                    padding: 15px;
                    border-radius: 8px;
                    background-color: #f8f9fa;
                }
            </style>
            """)
            # Emergency Banner
            gr.HTML("""
            <div class="emergency-banner">
                🚨 For medical emergencies, always call 999 immediately 🚨
            </div>
            """)

            # Header Section
            with gr.Row(elem_classes="header"):
                gr.Markdown("""
                # GP Medical Triage Assistant - Pearly
                Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
                """)

            # Main Features Grid
            gr.HTML("""
            <div class="features-grid">
                <div class="feature-card">
                    <span class="emoji">🏥</span>
                    <div>GP Appointments</div>
                </div>
                <div class="feature-card">
                    <span class="emoji">🔍</span>
                    <div>Symptom Assessment</div>
                </div>
                <div class="feature-card">
                    <span class="emoji">⚡</span>
                    <div>Urgent Care Guide</div>
                </div>
                <div class="feature-card">
                    <span class="emoji">💊</span>
                    <div>Medical Advice</div>
                </div>
            </div>
            """)
            # Chat Interface
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(
                        value=[{
                            "role": "assistant",
                            "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
                        }],
                        height=500,
                        elem_id="chatbot",
                        type="messages",
                        show_label=False
                    )
                    with gr.Row():
                        # gr.Textbox has no submit_on_enter argument (Enter
                        # already triggers .submit), so that kwarg is dropped
                        msg = gr.Textbox(
                            label="Your message",
                            placeholder="Type your message here...",
                            lines=2,
                            scale=4,
                            autofocus=True
                        )
                        submit = gr.Button("Send", variant="primary", scale=1)
                with gr.Column(scale=1):
                    # Quick Actions Panel
                    gr.Markdown("### Quick Actions")
                    emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
                    nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
                    booking_btn = gr.Button("📅 GP Booking", variant="secondary")

                    # Controls and Feedback
                    gr.Markdown("### Controls")
                    clear = gr.Button("🗑️ Clear Chat")

                    gr.Markdown("### Feedback")
                    with gr.Row():
                        feedback_positive = gr.Button("👍", elem_id="thumb-up")
                        feedback_negative = gr.Button("👎", elem_id="thumb-down")
                    feedback_text = gr.Textbox(
                        label="Additional comments",
                        placeholder="Tell us more...",
                        lines=2,
                        visible=True
                    )
                    feedback_submit = gr.Button("Submit Feedback", visible=True)
            # Examples and Information
            with gr.Accordion("Example Messages", open=False):
                gr.Examples([
                    ["I've been having severe headaches for the past week"],
                    ["I need to book a routine checkup"],
                    ["I'm feeling very anxious lately and need help"],
                    ["My child has had a fever for 2 days"],
                    ["I need information about COVID-19 testing"]
                ], inputs=msg)

            with gr.Accordion("NHS Services Guide", open=False):
                gr.Markdown("""
                ### Emergency Services (999)
                - Life-threatening emergencies
                - Severe injuries
                - Suspected heart attack or stroke

                ### NHS 111
                - Urgent but non-emergency situations
                - Medical advice needed
                - Unsure where to go

                ### GP Services
                - Routine check-ups
                - Non-urgent medical issues
                - Prescription renewals
                """)
            def show_emergency_info():
                return (
                    "🚨 Emergency Services (999)\n"
                    "- For life-threatening emergencies\n"
                    "- Severe chest pain\n"
                    "- Difficulty breathing\n"
                    "- Severe bleeding\n"
                    "- Loss of consciousness"
                )

            def show_nhs_111_info():
                return (
                    "📞 NHS 111 Service\n"
                    "- Available 24/7\n"
                    "- Medical advice\n"
                    "- Local service information\n"
                    "- Urgent care guidance"
                )

            def show_booking_info():
                return (
                    "📅 GP Booking Options\n"
                    "- Online booking\n"
                    "- Phone booking\n"
                    "- Routine appointments\n"
                    "- Urgent appointments"
                )
            # Chat handlers (clear the textbox after each send)
            msg.submit(chat, [msg, chatbot], [chatbot]).then(
                lambda: gr.update(value=""), None, [msg]
            )
            submit.click(chat, [msg, chatbot], [chatbot]).then(
                lambda: gr.update(value=""), None, [msg]
            )

            # Quick action handlers
            emergency_btn.click(show_emergency_info, outputs=[msg])
            nhs_111_btn.click(show_nhs_111_info, outputs=[msg])
            booking_btn.click(show_booking_info, outputs=[msg])

            # Feedback handlers. The comment box is passed as an input
            # component; reading feedback_text.value inside a lambda (as the
            # original did) only ever sees the component's initial value.
            feedback_positive.click(
                lambda comment, h: process_feedback(True, comment, h),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )
            feedback_negative.click(
                lambda comment, h: process_feedback(False, comment, h),
                inputs=[feedback_text, chatbot],
                outputs=[feedback_text]
            )

            # Clear chat
            clear.click(lambda: None, None, chatbot)

        # Finally, add the queue (Gradio 4 removed queue()'s
        # concurrency_count argument, so only max_size is passed here)
        demo.queue(max_size=10)
        return demo
    except Exception as e:
        logger.error(f"Error creating demo: {e}")
        raise


if __name__ == "__main__":
    # Initialize logging and load env vars
    logging.basicConfig(level=logging.INFO)
    load_dotenv()

    # Create and launch demo
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
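
# Usage note (illustrative): running this file directly serves the Gradio app
# on http://0.0.0.0:7860, the standard port for Hugging Face Spaces. The
# fine-tuning pipeline is run separately, e.g. ModelTrainer().train().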