# PixDiet / app.py
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    CodeGenTokenizerFast as Tokenizer,
    LlavaForConditionalGeneration,
    TextIteratorStreamer,
)
from PIL import Image
import gradio as gr
from threading import Thread
from dotenv import load_dotenv
import os

# Supabase-backed helpers for persisting per-user chat history
from db_client import get_user_history, update_user_history, delete_user_history

# Timezone-aware timestamps for the prompt
from datetime import datetime
import pytz

load_dotenv()

# Set to True to load a lightweight stand-in model for local testing
TESTING = False
IS_LOGGED_IN = True
USER_ID = "jeremie.feron@gmail.com"
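
# db_client (not shown here) is assumed to persist the conversation in the same
# message format consumed by the chat template below:
#   get_user_history(user_id)  -> list of {"role": ..., "content": [{"type": "text", ...}, ...]}
#   update_user_history(user_id, messages) -> overwrites that list
#   delete_user_history(user_id) -> clears it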
# Hugging Face model id
model_id = "blanchon/pixtral-nutrition-2"

# BitsAndBytesConfig: 4-bit NF4 quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
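
# NF4 weights with double quantization and bfloat16 compute should keep the
# quantized checkpoint small enough for a single Space GPU; exact memory use
# depends on the underlying model size.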
# Model and processor initialization
if TESTING:
    # Lightweight stand-in checkpoint for quick local testing
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)
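
# Note: the prompt construction and processor calls below follow the Llava
# processor API; the TESTING checkpoint is only a lightweight placeholder and
# is not expected to exercise the full multimodal path.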
# Set the chat template on the processor (Pixtral-style [INST]/[IMG] prompt format)
processor.chat_template = """
{%- for message in messages %}
{%- if message.role == "user" %}
<s>[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
</s>
{%- endif %}
{%- endfor %}
""".replace(' ', "")
processor.tokenizer.pad_token = processor.tokenizer.eos_token
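
# For reference, a user turn with one image renders roughly as
#   <s>[INST]
#   {user text}
#
#   [IMG]
#   [/INST]
# and an assistant turn is rendered as its text followed by </s>.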
@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    # Load the stored conversation and preprocess the new turn
    messages = get_user_history(USER_ID)
    images = []
    text_input = chatbot[-1][0]

    # Get current time in Paris timezone
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")

    # Build the prompt: base nutrition instructions plus the user's question, if any
    base_prompt = (
        f"Current time: {current_time}. You are a nutrition expert. "
        "Identify the food/ingredients in this image. Is this a healthy meal? "
        "Can you think of how to improve it?"
    )
    text_input = f"{base_prompt} {text_input}" if text_input else base_prompt

    # Add the current message (with an image placeholder when an image is attached)
    if image_input is not None:
        # image_input may already be a PIL Image or a numpy array
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": text_input}, {"type": "image"}],
        })
    else:
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": text_input}],
        })

    # Apply chat template
    texts = processor.apply_chat_template(messages)

    # Process inputs
    if not images:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    # Stream the generation token by token in a background thread
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")

    if IS_LOGGED_IN:
        new_history = messages + [
            {"role": "assistant", "content": [{"type": "text", "text": response}]}
        ]
        update_user_history(USER_ID, new_history)
# Define the HTML content for the header
html = f"""
<p align="center" style="font-size: 2.5em; line-height: 1;">
<span style="display: inline-block; vertical-align: middle;">🍽️</span>
<span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center><font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font></center>
<div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
<img src="https://zozs97eh0bkqexza.public.blob.vercel-storage.com/alan-VD7bRf1rKuEBL6EDAjw0eLGVodhoh8.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
<img src="https://seeklogo.com/images/M/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
"""
# Define LaTeX delimiters
latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    {"left": "\\[", "right": "\\]", "display": True},
]
# Create the Gradio interface
with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="Upload your meal image", height=350, type="pil")
            gr.Examples(
                examples=[
                    ["./examples/mistral_breakfast.jpeg", ""],
                    ["./examples/mistral_desert.jpeg", ""],
                    ["./examples/mistral_snacks.jpeg", ""],
                    ["./examples/mistral_pasta.jpeg", ""],
                ],
                inputs=[image_input, gr.Textbox(visible=False)],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=600,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Delete my history", variant="huggingface")

    def submit_chat(chatbot, text_input):
        # Append the new turn as a mutable [user, assistant] pair so the
        # streaming handler can fill in the reply in place
        chatbot.append([text_input, ""])
        return chatbot, ""

    def clear_chat():
        delete_user_history(USER_ID)
        return [], None, ""

    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)