# VHA1 / app.py: Hugging Face Space by lukiod
# Commit a548a89 ("Update app.py"), 10.3 kB
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import logging
from typing import List, Dict
import gc
# Send INFO-and-above records from every logger to the root handler so the
# model-loading progress messages are visible, then get this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class HealthAssistant:
    """Qwen2-VL chat assistant with a small in-memory health record.

    The latest recorded metrics and all recorded medications are sent to the
    model as a system message on every turn (see ``_get_health_context``).
    Records live only on this instance; nothing is persisted.
    """

    # How many past (user, assistant) exchanges are replayed to the model.
    max_history_turns = 3

    def __init__(self, model_name: str = "Qwen/Qwen2-VL-7B-Instruct"):
        """Create the assistant and load the model immediately.

        Args:
            model_name: Hugging Face Hub id or local path of a Qwen2-VL
                checkpoint. Defaults to the 7B instruct model.

        Raises:
            Exception: re-raised from ``initialize_model`` if loading fails.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.processor = None
        # Dicts with keys 'Weight', 'Steps', 'Sleep' (see add_metrics).
        self.metrics = []
        # Dicts with keys 'Medication', 'Dosage', 'Time', 'Notes' (see add_medication).
        self.medications = []
        self.initialize_model()

    @staticmethod
    def _select_attn_implementation():
        """Return "flash_attention_2" when it can be used here, else None.

        FlashAttention-2 needs the optional ``flash_attn`` package and a CUDA
        device; requesting it unconditionally makes model loading fail on
        machines without them. None lets transformers pick its default.
        """
        import importlib.util

        if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
            return "flash_attention_2"
        return None

    def initialize_model(self):
        """Load the model, tokenizer and processor for ``self.model_name``.

        Raises:
            Exception: any loading error, after it is logged with a traceback.
        """
        try:
            logger.info("Loading Qwen2-VL model...")
            attn_implementation = self._select_attn_implementation()
            if attn_implementation is None:
                logger.info("flash_attn or CUDA unavailable; using default attention implementation")
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_implementation,
                device_map="auto",
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Pixel budget per image, expressed in 28x28 units as in the Qwen2-VL
            # model card; only matters if visual input is ever supplied.
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                min_pixels=256 * 28 * 28,
                max_pixels=1280 * 28 * 28,
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

    def generate_response(self, message: str, history: List = None) -> str:
        """Generate a reply to ``message`` given the prior chat ``history``.

        Args:
            message: the user's new message.
            history: previous ``[user, assistant]`` pairs (Gradio "tuples"
                format); only the last ``max_history_turns`` are used.

        Returns:
            The stripped model reply, or a fixed apology string if anything
            fails (the error is logged with its traceback).
        """
        try:
            return self._run_inference(self._format_messages(message, history))
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            return "I apologize, but I encountered an error. Please try again."
        finally:
            # _run_inference's tensors are unreferenced once it has returned or
            # raised, so collecting here lets the CUDA cache actually be freed.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _run_inference(self, messages: List[Dict]) -> str:
        """Run one sampled generation for ``messages`` and return the decoded reply."""
        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Yields (None, None) for text-only conversations. None (not an empty
        # list) is what tells the processor there is no visual input.
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # generate() returns prompt + completion; keep only the new tokens.
        completion_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        reply = self.processor.batch_decode(
            completion_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return reply.strip()

    def _format_messages(self, message: str, history: List = None) -> List[Dict]:
        """Build the chat-template message list for Qwen2-VL.

        Layout: an optional system message carrying the health context, then
        the last ``max_history_turns`` exchanges of ``history``, then
        ``message`` as the final user turn.

        Args:
            message: the user's new message.
            history: ``[user_text, assistant_text]`` pairs; either side may be
                None (e.g. an unanswered turn) and is then skipped.
        """
        def text_turn(role, text):
            return {"role": role, "content": [{"type": "text", "text": text}]}

        messages = []
        health_context = self._get_health_context()
        if health_context:
            messages.append(text_turn("system", f"Current health information:\n{health_context}"))

        window = self.max_history_turns
        recent = list(history or [])[-window:] if window > 0 else []
        for user_msg, assistant_msg in recent:
            # Skip missing sides rather than sending the literal text "None".
            if user_msg is not None:
                messages.append(text_turn("user", user_msg))
            if assistant_msg is not None:
                messages.append(text_turn("assistant", assistant_msg))

        messages.append(text_turn("user", message))
        return messages

    def _get_health_context(self) -> str:
        """Render the latest metrics and all medications as plain text.

        Returns an empty string when nothing has been recorded. When both
        sections exist they are separated by a blank line.
        """
        sections = []
        if self.metrics:
            latest = self.metrics[-1]
            sections.append("\n".join([
                "Recent Health Metrics:",
                f"- Weight: {latest.get('Weight', 'N/A')} kg",
                f"- Steps: {latest.get('Steps', 'N/A')}",
                f"- Sleep: {latest.get('Sleep', 'N/A')} hours",
            ]))
        if self.medications:
            lines = ["Current Medications:"]
            for med in self.medications:
                med_info = f"- {med['Medication']} ({med['Dosage']}) at {med['Time']}"
                if med.get('Notes'):
                    med_info += f" | Note: {med['Notes']}"
                lines.append(med_info)
            sections.append("\n".join(lines))
        return "\n\n".join(sections)

    def add_metrics(self, weight: float, steps: int, sleep: float) -> bool:
        """Record one set of health metrics; return True on success."""
        try:
            self.metrics.append({
                'Weight': weight,
                'Steps': steps,
                'Sleep': sleep,
            })
            return True
        except Exception as e:
            logger.exception("Error adding metrics: %s", e)
            return False

    def add_medication(self, name: str, dosage: str, time: str, notes: str = "") -> bool:
        """Record one medication entry; return True on success."""
        try:
            self.medications.append({
                'Medication': name,
                'Dosage': dosage,
                'Time': time,
                'Notes': notes,
            })
            return True
        except Exception as e:
            logger.exception("Error adding medication: %s", e)
            return False
class GradioInterface:
    """Gradio Blocks front end for a HealthAssistant.

    Builds three tabs (chat, health metrics, medication manager) whose event
    handlers are the methods of this class.
    """

    def __init__(self):
        """Create the backing HealthAssistant (its constructor loads the model)."""
        self.assistant = HealthAssistant()

    def chat_response(self, message: str, history: List) -> tuple:
        """Handle one chat submission.

        Args:
            message: text from the input box (may be empty or None).
            history: current chatbot value, a list of ``[user, assistant]``
                pairs; None is treated as an empty history.

        Returns:
            ``("", updated_history)``; the empty string clears the input box.
        """
        history = list(history or [])
        if not (message or "").strip():
            return "", history
        # The assistant sees the history *before* this turn is appended.
        response = self.assistant.generate_response(message, history)
        history.append([message, response])
        return "", history

    def add_health_metrics(self, weight: float, steps: int, sleep: float) -> str:
        """Validate and store one set of metrics; return a Markdown status line."""
        values = (weight, steps, sleep)
        # An empty gr.Number arrives as None. Zero is a legitimate reading
        # (e.g. 0 steps), so only None counts as "missing".
        if any(value is None for value in values):
            return "⚠️ Please fill in all metrics."
        if any(value < 0 for value in values):
            return "⚠️ Metrics cannot be negative."
        if self.assistant.add_metrics(weight, steps, sleep):
            return "✅ Health metrics saved successfully!"
        return "❌ Error saving metrics."

    def add_medication_info(self, name: str, dosage: str, time: str, notes: str) -> str:
        """Validate and store one medication; return a Markdown status line.

        Surrounding whitespace is stripped, so whitespace-only required
        fields are rejected.
        """
        name, dosage, time, notes = ((value or "").strip() for value in (name, dosage, time, notes))
        if not all([name, dosage, time]):
            return "⚠️ Please fill in all required fields."
        if self.assistant.add_medication(name, dosage, time, notes):
            return "✅ Medication added successfully!"
        return "❌ Error adding medication."

    def create_interface(self):
        """Build and return the (not yet launched) gr.Blocks app."""
        with gr.Blocks(title="Health Assistant", theme=gr.themes.Soft()) as demo:
            gr.Markdown(
                """
                # 🏥 AI Health Assistant
                Powered by Qwen2-VL for intelligent health guidance and monitoring.
                """
            )
            with gr.Tabs():
                # Chat Interface
                with gr.Tab("💬 Health Chat"):
                    chatbot = gr.Chatbot(
                        height=450,
                        show_label=False,
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Ask your health question... (Press Enter)",
                            lines=2,
                            show_label=False,
                            scale=9,
                        )
                        send_btn = gr.Button("Send", scale=1)
                    clear_btn = gr.Button("Clear Chat")
                # Health Metrics
                with gr.Tab("📊 Health Metrics"):
                    with gr.Row():
                        weight_input = gr.Number(label="Weight (kg)")
                        # precision=0 makes the component return an int.
                        steps_input = gr.Number(label="Steps", precision=0)
                        sleep_input = gr.Number(label="Hours Slept")
                    metrics_btn = gr.Button("Save Metrics")
                    metrics_status = gr.Markdown()
                # Medication Manager
                with gr.Tab("💊 Medication Manager"):
                    with gr.Row():
                        med_name = gr.Textbox(label="Medication Name")
                        med_dosage = gr.Textbox(label="Dosage")
                        med_time = gr.Textbox(label="Time (e.g., 9:00 AM)")
                        med_notes = gr.Textbox(label="Notes (optional)")
                    med_btn = gr.Button("Add Medication")
                    med_status = gr.Markdown()

            # Event handlers: Enter and the Send button share one handler.
            msg.submit(self.chat_response, [msg, chatbot], [msg, chatbot])
            send_btn.click(self.chat_response, [msg, chatbot], [msg, chatbot])
            clear_btn.click(lambda: [], None, chatbot)
            metrics_btn.click(
                self.add_health_metrics,
                inputs=[weight_input, steps_input, sleep_input],
                outputs=[metrics_status],
            )
            med_btn.click(
                self.add_medication_info,
                inputs=[med_name, med_dosage, med_time, med_notes],
                outputs=[med_status],
            )
            gr.Markdown(
                """
                ### ⚠️ Important Note
                This AI assistant provides general health information only.
                Always consult healthcare professionals for medical advice.
                """
            )
        return demo
def main():
    """Build the Gradio UI and serve it; blocks until the server stops.

    Raises:
        Exception: any startup failure, after it is logged with a traceback,
            so the process exits non-zero instead of appearing to succeed.
    """
    try:
        interface = GradioInterface()
        demo = interface.create_interface()
        # `launch(enable_queue=...)` was removed in Gradio 4 (it raised
        # TypeError here). Request queuing is enabled through Blocks.queue(),
        # which Gradio 3 also supports.
        demo.queue()
        demo.launch(
            share=False,
            max_threads=4,
        )
    except Exception as e:
        logger.exception("Error starting application: %s", e)
        raise


if __name__ == "__main__":
    main()