import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import List
import gc

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Bound CPU thread usage; this app runs inference on CPU only.
torch.set_num_threads(4)


class HealthAssistant:
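    """Lightweight health-chat assistant backed by a local Hugging Face causal LM.

    Minimal usage sketch:

        assistant = HealthAssistant(use_smaller_model=True)
        reply = assistant.generate_response("How much sleep do adults need?")
    """
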
    def __init__(self, use_smaller_model=True):
        if use_smaller_model:
            self.model_name = "facebook/opt-125m"
        else:
            # Note: the original referenced Qwen/Qwen2-VL-7B-Instruct, a
            # vision-language model that AutoModelForCausalLM cannot load;
            # the text-only instruct variant is used here instead.
            self.model_name = "Qwen/Qwen2-7B-Instruct"

        self.model = None
        self.tokenizer = None
        self.metrics = []
        self.medications = []
        self.initialize_model()

    def initialize_model(self):
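        """Load the tokenizer and model onto CPU; re-raises on failure."""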
        try:
            logger.info(f"Starting model initialization: {self.model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            logger.info("Tokenizer loaded")

            # float32 is the safe default on CPU; half precision is poorly
            # supported there and can produce NaNs.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            self.model = self.model.to("cpu")

            # Models such as OPT ship without a pad token; reuse EOS so
            # padded inputs tokenize cleanly.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Error in model initialization: {str(e)}")
            raise

    def is_initialized(self):
        return (self.model is not None and
                self.tokenizer is not None and
                hasattr(self.model, 'generate'))

    def generate_response(self, message: str, history: List = None) -> str:
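        """Generate a reply to `message`, conditioned on recent chat history."""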
        try:
            if not self.is_initialized():
                return "System is still initializing. Please try again in a moment."

            prompt = self._prepare_prompt(message, history)

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    # Pass the attention mask explicitly; omitting it when the
                    # pad token equals EOS can cut generation short.
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=128,
                    num_beams=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, skipping the echoed prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )

            # Release tensors eagerly; helps on memory-constrained CPU hosts.
            del outputs, inputs
            gc.collect()

            return response.strip()

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."

    def _prepare_prompt(self, message: str, history: List = None) -> str:
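        """Assemble the prompt: system line, health context, recent turns,
        then the new user message."""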
        parts = [
            "You are a helpful healthcare assistant providing accurate and helpful medical information.",
            self._get_health_context() or "No health data available yet."
        ]

        if history:
            # Only the last three turns are kept, which bounds prompt length
            # under the 512-token truncation used in generate_response.
            parts.append("Previous conversation:")
            for h in history[-3:]:
                parts.extend([
                    f"User: {h[0]}",
                    f"Assistant: {h[1]}"
                ])

        parts.extend([
            f"User: {message}",
            "Assistant:"
        ])

        return "\n\n".join(parts)

    def _get_health_context(self) -> str:
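        """Summarize stored metrics and medications as prompt-ready text."""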
        context_parts = []

        if self.metrics:
            latest = self.metrics[-1]
            context_parts.extend([
                "Recent Health Metrics:",
                f"- Weight: {latest.get('Weight', 'N/A')} kg",
                f"- Steps: {latest.get('Steps', 'N/A')}",
                f"- Sleep: {latest.get('Sleep', 'N/A')} hours"
            ])

        if self.medications:
            context_parts.append("\nCurrent Medications:")
            for med in self.medications:
                med_info = f"- {med['Medication']} ({med['Dosage']}) at {med['Time']}"
                if med.get('Notes'):
                    med_info += f" | Note: {med['Notes']}"
                context_parts.append(med_info)

        return "\n".join(context_parts) if context_parts else ""

    def add_metrics(self, weight: float, steps: int, sleep: float) -> bool:
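        """Store one health-metrics snapshot; returns True on success."""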
        try:
            self.metrics.append({
                'Weight': weight,
                'Steps': steps,
                'Sleep': sleep
            })
            return True
        except Exception as e:
            logger.error(f"Error adding metrics: {e}")
            return False

    def add_medication(self, name: str, dosage: str, time: str, notes: str = "") -> bool:
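        """Record a medication entry; returns True on success."""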
        try:
            self.medications.append({
                'Medication': name,
                'Dosage': dosage,
                'Time': time,
                'Notes': notes
            })
            return True
        except Exception as e:
            logger.error(f"Error adding medication: {e}")
            return False


class GradioInterface:
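    """Gradio UI around HealthAssistant: chat, metrics, and medications tabs."""
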
    def __init__(self):
        try:
            logger.info("Initializing Health Assistant...")
            self.assistant = HealthAssistant(use_smaller_model=True)
            if not self.assistant.is_initialized():
                raise RuntimeError("Health Assistant failed to initialize properly")
            logger.info("Health Assistant initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Health Assistant: {e}")
            raise

    def chat_response(self, message: str, history: List) -> tuple:
        if not message.strip():
            return "", history

        response = self.assistant.generate_response(message, history)
        history.append([message, response])
        # The empty string clears the input textbox after each send.
        return "", history

    def add_health_metrics(self, weight: float, steps: int, sleep: float) -> str:
        if not all([weight is not None, steps is not None, sleep is not None]):
            return "⚠️ Please fill in all metrics."

        if self.assistant.add_metrics(weight, steps, sleep):
            return "✅ Health metrics saved successfully!"
        return "❌ Error saving metrics."

    def add_medication_info(self, name: str, dosage: str, time: str, notes: str) -> str:
        if not all([name, dosage, time]):
            return "⚠️ Please fill in all required fields."

        if self.assistant.add_medication(name, dosage, time, notes):
            return "✅ Medication added successfully!"
        return "❌ Error adding medication."

    def create_interface(self):
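        """Build the Blocks layout, wire UI events, and return the demo."""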
        with gr.Blocks(title="Health Assistant") as demo:
            gr.Markdown("# 🏥 AI Health Assistant")

            with gr.Tabs():
                with gr.Tab("💬 Health Chat"):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=450
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Ask your health question... (Press Enter)",
                            lines=2,
                            show_label=False,
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)
                    clear_btn = gr.Button("Clear Chat")

                with gr.Tab("📊 Health Metrics"):
                    with gr.Row():
                        weight_input = gr.Number(label="Weight (kg)")
                        steps_input = gr.Number(label="Steps")
                        sleep_input = gr.Number(label="Hours Slept")
                    metrics_btn = gr.Button("Save Metrics")
                    metrics_status = gr.Markdown()

                with gr.Tab("💊 Medication Manager"):
                    with gr.Row():
                        med_name = gr.Textbox(label="Medication Name")
                        med_dosage = gr.Textbox(label="Dosage")
                        med_time = gr.Textbox(label="Time (e.g., 9:00 AM)")
                    med_notes = gr.Textbox(label="Notes (optional)")
                    med_btn = gr.Button("Add Medication")
                    med_status = gr.Markdown()

            # Enter key and Send button share the same handler.
            msg.submit(self.chat_response, [msg, chatbot], [msg, chatbot])
            send_btn.click(self.chat_response, [msg, chatbot], [msg, chatbot])
            clear_btn.click(lambda: [], None, chatbot)

            metrics_btn.click(
                self.add_health_metrics,
                inputs=[weight_input, steps_input, sleep_input],
                outputs=[metrics_status]
            )

            med_btn.click(
                self.add_medication_info,
                inputs=[med_name, med_dosage, med_time, med_notes],
                outputs=[med_status]
            )

        demo.queue()
        return demo


def main():
    try:
        logger.info("Starting application...")
        interface = GradioInterface()
        demo = interface.create_interface()
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Error starting application: {e}")
        raise


if __name__ == "__main__":
    main()
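
# Quick local run (assuming this file is saved as app.py and gradio, torch,
# and transformers are installed):
#
#   python app.py
#
# then open http://localhost:7860 in a browser.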