# PixDiet / app.py
# Author: blanchon · commit 3e661b5 ("Remove login", verified) · 9.93 kB
import spaces
from transformers import (
TextIteratorStreamer,
)
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
LlavaForConditionalGeneration,
)
from PIL import Image
import gradio as gr
from threading import Thread
from dotenv import load_dotenv
# Add these imports
from datetime import datetime
import pytz
from typing import Optional
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
import torch
from theme import Seafoam
load_dotenv()

# When True, a stand-in model is loaded instead of the PixDiet checkpoint.
TESTING = False

# Hugging Face model id of the fine-tuned nutrition model.
# Formerly: "mistral-community/pixtral-12b".
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# bitsandbytes 4-bit settings: NF4 quantisation with double quantisation
# enabled and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)
# Load the model and processor. With TESTING enabled, "vikhyatk/moondream1"
# and a bare tokenizer are loaded instead of the 4-bit quantised PixDiet
# Llava model.
if TESTING:
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

# Chat template: user turns render as "<s>[INST]...[/INST]" with a "\n[IMG]"
# placeholder per image item; assistant turns render their text followed by
# "</s>". The template is indented for readability only. Each line is
# stripped below so indentation never leaks into the rendered prompt, while
# the spaces *inside* Jinja tags (e.g. "for message in messages") are kept.
# Removing every space instead would make those tags unparseable.
_RAW_CHAT_TEMPLATE = """
{%- for message in messages %}
    {%- if message.role == "user" %}
        <s>[INST]
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- elif item.type == "image" %}
                \n[IMG]
            {%- endif %}
        {%- endfor %}
        [/INST]
    {%- elif message.role == "assistant" %}
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- endif %}
        {%- endfor %}
        </s>
    {%- endif %}
{%- endfor %}
"""
processor.chat_template = "\n".join(
    line.strip() for line in _RAW_CHAT_TEMPLATE.splitlines()
)

# Use EOS as the padding token. In TESTING mode `processor` is itself a
# tokenizer (no `.tokenizer` attribute), so fall back to it directly.
_tokenizer = getattr(processor, "tokenizer", processor)
_tokenizer.pad_token = _tokenizer.eos_token
def _build_prompt(user_text, current_time):
    """Return the instruction text sent to the model for the current turn.

    The fixed nutrition-expert instruction is always included. When the user
    typed a non-blank message, it is appended so the model can address it.
    """
    prompt = f"Current time: {current_time}. You are a nutrition expert. Identify the food/ingredients in this image. Is this a healthy meal? Can you think of how to improve it?"
    if isinstance(user_text, str) and user_text.strip():
        prompt = f"{prompt}\nUser question: {user_text.strip()}"
    return prompt


def _to_rgb_image(image_input):
    """Return *image_input* (a PIL image, or anything ``Image.fromarray``
    accepts) as an RGB PIL image."""
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    return Image.fromarray(image_input).convert("RGB")


def _run_generation(generation_kwargs, streamer, errors):
    """Run ``model.generate``. On failure, record the exception and close
    *streamer* so the consuming loop ends instead of blocking forever."""
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:  # re-raised in bot_streaming's thread
        errors.append(exc)
        streamer.end()


@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream the model's answer for the latest chat turn.

    Args:
        chatbot: Chat history as a list of (user, assistant) pairs. The last
            pair is the pending turn; its assistant slot is assigned in place,
            so the pair must be mutable.
        image_input: Meal image (PIL image or array), or ``None``.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        ``chatbot`` after each streamed chunk, holding the partial reply.

    Raises:
        Any exception raised by ``model.generate``, re-raised after the
        stream is closed.
    """
    if not chatbot:
        return

    # The prompt embeds the current wall-clock time in Paris.
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")
    text_input = _build_prompt(chatbot[-1][0], current_time)

    # Only the current turn is sent to the model; earlier turns are printed
    # in the debug output below but are not part of the prompt.
    images = []
    content = [{"type": "text", "text": text_input}]
    if image_input is not None:
        images.append(_to_rgb_image(image_input))
        content.append({"type": "image"})
    messages = [{"role": "user", "content": content}]

    texts = processor.apply_chat_template(messages)
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")

    # Generation runs in a worker thread and pushes text into the streamer,
    # which this generator drains chunk by chunk.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    errors = []
    thread = Thread(
        target=_run_generation, args=(generation_kwargs, streamer, errors)
    )
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()
    if errors:
        raise errors[0]

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")
# Gradio theme instance; Seafoam comes from the local `theme` module.
seafoam = Seafoam()
# Header HTML shown at the top of the app: the title, a one-paragraph
# description, and a row of partner logos loaded from external URLs.
# NOTE(review): the "Background image positioned behind everything" HTML
# comment inside the string actually precedes the logo row, not a background
# image. It is left untouched because it is part of the emitted markup.
html = """
<!-- Foreground content -->
<p align="center" style="font-size: 2.5em; line-height: 1; ">
<span style="display: inline-block; vertical-align: middle;">🍽️</span>
<span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center>
<font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font>
</center>
<!-- Background image positioned behind everything -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/alan.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
</div>
"""
footer_html = """
<!-- Footer content -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare//VariantVariant6-Photoroom.png" alt="Background Image"
style="height: 100px; width: 100%; object-fit: scale-down;">
</div>
<div>
Made with ❤️ during the Mistral AI x Alan Hackathon.
</div>
</div>
"""
# Math delimiters recognised by the chat window, written as
# (left, right, display) triples; display=True marks block (display-mode) math.
latex_delimiters_set = [
    {"left": left, "right": right, "display": display}
    for left, right, display in (
        (r"\(", r"\)", False),
        (r"\begin{equation}", r"\end{equation}", True),
        (r"\begin{align}", r"\end{align}", True),
        (r"\begin{alignat}", r"\end{alignat}", True),
        (r"\begin{gather}", r"\end{gather}", True),
        (r"\begin{CD}", r"\end{CD}", True),
        (r"\[", r"\]", True),
    )
]
# Assemble the Gradio UI. The layout is the header HTML, then a row with a
# narrow left column (profile, meal image, examples) and a wide right column
# (chat window, message box, buttons), then the footer HTML.
with gr.Blocks(
    title="PixDiet",
    theme=seafoam,
    css="footer{display:none !important}",
) as demo:
    gr.HTML(html)

    with gr.Row():
        # Left column (scale 3 of 10).
        with gr.Column(scale=3):
            # NOTE(review): about_you is filled by the examples but is not
            # among bot_streaming's inputs, so the model never sees it.
            about_you = gr.Textbox(
                label="About you",
                placeholder="Add information about you here...",
                lines=3,
                interactive=True,
            )
            image_input = gr.Image(
                label="Upload your meal image",
                height=350,
                type="pil",
            )
            # Each example row fills (image_input, about_you) in that order.
            example_rows = [
                [
                    "./examples/mistral_breakfast.jpeg",
                    "John, 45 years old, 80kg, lactose intolerant. Training for his first triathlon.",
                ],
                [
                    "./examples/mistral_desert.jpeg",
                    "Emma, 26 years old, 55kg, iron deficiency. Training for her first Ironman competition.",
                ],
                [
                    "./examples/mistral_snacks.jpeg",
                    "Paul, 34 years old, 62kg, no known pathologies. Focused on improving strength for weightlifting competitions.",
                ],
                [
                    "./examples/mistral_pasta.jpeg",
                    "Carla, 52 years old, 58kg, no known pathologies. Currently training for her first marathon.",
                ],
            ]
            gr.Examples(examples=example_rows, inputs=[image_input, about_you])

        # Right column (scale 7 of 10).
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=700,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
                type=None,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
                interactive=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary", visible=True)
                clear_btn = gr.Button("Delete my history", variant="stop", visible=True)

    def submit_chat(chatbot, text_input):
        """Record the user's message as a new turn with an empty reply.

        Returns the updated history and an empty string, which clears the
        message box.
        """
        chatbot.append((text_input, ""))
        return chatbot, ""

    def clear_chat():
        """Return values that empty the chat, the image and the message box."""
        return [], None, ""

    # "Send" and pressing Enter in the message box behave the same way: first
    # record the user turn, then stream the model's reply into the chat.
    _send_turn = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    )
    send_click_event = _send_turn.then(
        bot_streaming, [chatbot, image_input], chatbot
    )
    _submit_turn = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    )
    submit_event = _submit_turn.then(
        bot_streaming, [chatbot, image_input], chatbot
    )
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

    gr.HTML(footer_html)
if __name__ == "__main__":
demo.launch(debug=False, share=False, show_api=False)