Spaces:
Sleeping
Sleeping
import spaces | |
from transformers import ( | |
TextIteratorStreamer, | |
) | |
from transformers import ( | |
AutoProcessor, | |
BitsAndBytesConfig, | |
LlavaForConditionalGeneration, | |
) | |
from PIL import Image | |
import gradio as gr | |
from threading import Thread | |
from dotenv import load_dotenv | |
# Add these imports | |
from datetime import datetime | |
import pytz | |
from typing import Optional | |
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer | |
import torch | |
from theme import Seafoam | |
load_dotenv()

# Switch used by the model-loading code to select an alternate test checkpoint.
TESTING = False

# Hugging Face Hub id of the served model. (Earlier revisions of this Space
# pointed at "mistral-community/pixtral-12b".)
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# bitsandbytes 4-bit loading: NF4 quantization with nested (double)
# quantization of the quantization constants; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the model and its processor. The normal path loads the 4-bit quantized
# LLaVA-architecture checkpoint plus its AutoProcessor; TESTING swaps in the
# "vikhyatk/moondream1" checkpoint with a CodeGen fast tokenizer.
if not TESTING:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id)
else:
    # NOTE: trust_remote_code executes Python code shipped in the Hub repo.
    # NOTE(review): the rest of the file uses ``processor.tokenizer`` and
    # ``processor(images=...)``, which a bare tokenizer likely does not
    # support — confirm the TESTING path still runs end to end.
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
# Chat template installed on the processor (rendered by apply_chat_template):
# - each user turn becomes "<s>[INST] ... [/INST]", emitting the text items
#   and an "[IMG]" placeholder (preceded by a newline) for each image item;
# - each assistant turn emits its text items followed by "</s>".
# The lines are deliberately written flush-left so no whitespace needs to be
# stripped afterwards. Post-processing such as ``.replace(" ", "")`` would
# also delete the spaces inside the Jinja tags
# (``{%- for message in messages %}`` -> ``{%-formessageinmessages%}``) and
# leave a template Jinja cannot parse.
processor.chat_template = """
{%- for message in messages %}
{%- if message.role == "user" %}
<s>[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
</s>
{%- endif %}
{%- endfor %}
"""
# Reuse the EOS token for padding.
processor.tokenizer.pad_token = processor.tokenizer.eos_token
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream the model's reply to the latest chat turn into ``chatbot``.

    Generator used as a Gradio event handler. It builds a single-turn prompt
    (earlier turns are not sent to the model) from the last user message and
    the optional meal image, runs ``model.generate`` on a background thread,
    and yields the history after every decoded chunk so the UI updates live.

    Args:
        chatbot: Chat history as ``[user_message, bot_message]`` pairs. The
            last pair is the turn being answered; its bot slot is overwritten
            with the growing reply.
        image_input: Meal image as a ``PIL.Image.Image``, anything accepted by
            ``PIL.Image.fromarray``, or ``None`` for a text-only prompt.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        ``chatbot`` after each streamed chunk.
    """
    # The prompt preamble carries the current wall-clock time in Paris.
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")
    preamble = f"Current time: {current_time}. You are a nutrition expert."
    default_question = (
        "Identify the food/ingredients in this image. Is this a healthy meal? "
        "Can you think of how to improve it?"
    )

    # The message box is optional: use the user's own question when one was
    # typed, otherwise fall back to the default meal-analysis request.
    user_text = chatbot[-1][0]
    question = user_text.strip() if isinstance(user_text, str) else ""
    text_input = f"{preamble} {question or default_question}"

    # Build the single user message: text first, then one image placeholder
    # matching the image handed to the processor.
    images = []
    content = [{"type": "text", "text": text_input}]
    if image_input is not None:
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        content.append({"type": "image"})
    messages = [{"role": "user", "content": content}]

    # Apply chat template and tokenize (with pixel values when an image exists).
    texts = processor.apply_chat_template(messages)
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")

    # generate() only returns when decoding is finished, so run it on a worker
    # thread and consume decoded text from the streamer as it arrives.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        # Rebuild the pair rather than mutating it, so tuple turns work too.
        chatbot[-1] = [chatbot[-1][0], response]
        yield chatbot
    thread.join()

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")
# Instance of the project-local ``Seafoam`` theme (imported from ``theme``),
# used as the ``theme=`` of the gr.Blocks app.
seafoam = Seafoam()
# Page header markup: title with emoji, one-line pitch, and the Alan /
# Mistral AI logos. (The "Background image" HTML comment inside the string is
# a leftover: the div it precedes holds the two logos.)
html = """
<!-- Foreground content -->
<p align="center" style="font-size: 2.5em; line-height: 1; ">
<span style="display: inline-block; vertical-align: middle;">🍽️</span>
<span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center>
<font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font>
</center>
<!-- Background image positioned behind everything -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/alan.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
<img src="https://dropshare.blanchon.xyz/public/dropshare/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
</div>
"""

# Page footer markup: a decorative banner image and the hackathon credit.
footer_html = """
<!-- Footer content -->
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; margin-top: 20px; width: 100%;">
<div style="display: flex; justify-content: center; width: 100%;">
<img src="https://dropshare.blanchon.xyz/public/dropshare//VariantVariant6-Photoroom.png" alt="Background Image"
style="height: 100px; width: 100%; object-fit: scale-down;">
</div>
<div>
Made with ❤️ during the Mistral AI x Alan Hackathon.
</div>
</div>
"""

# Math delimiters handed to the chatbot's LaTeX renderer, as
# (left, right, display) triples. Only ``\( ... \)`` is inline; every other
# pair renders as a display block.
latex_delimiters_set = [
    {"left": left, "right": right, "display": display}
    for left, right, display in (
        (r"\(", r"\)", False),
        (r"\begin{equation}", r"\end{equation}", True),
        (r"\begin{align}", r"\end{align}", True),
        (r"\begin{alignat}", r"\end{alignat}", True),
        (r"\begin{gather}", r"\end{gather}", True),
        (r"\begin{CD}", r"\end{CD}", True),
        (r"\[", r"\]", True),
    )
]
# Create the Gradio interface: header, a left column for the user profile and
# meal image, a right column for the conversation, then the footer.
with gr.Blocks(
    title="PixDiet", theme=seafoam, css="footer{display:none !important}"
) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            # NOTE(review): ``about_you`` is filled by the examples below but is
            # not among the inputs wired to ``bot_streaming``, so the model
            # never receives it.
            about_you = gr.Textbox(
                label="About you",
                placeholder="Add information about you here...",
                lines=3,
                interactive=True,
            )
            image_input = gr.Image(
                label="Upload your meal image", height=350, type="pil"
            )
            gr.Examples(
                examples=[
                    [
                        "./examples/mistral_breakfast.jpeg",
                        "John, 45 years old, 80kg, lactose intolerant. Training for his first triathlon.",
                    ],
                    [
                        "./examples/mistral_desert.jpeg",
                        "Emma, 26 years old, 55kg, iron deficiency. Training for her first Ironman competition.",
                    ],
                    [
                        "./examples/mistral_snacks.jpeg",
                        "Paul, 34 years old, 62kg, no known pathologies. Focused on improving strength for weightlifting competitions.",
                    ],
                    [
                        "./examples/mistral_pasta.jpeg",
                        "Carla, 52 years old, 58kg, no known pathologies. Currently training for her first marathon.",
                    ],
                ],
                inputs=[image_input, about_you],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=700,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
                # The handlers index turns as [user, bot] pairs, so pin the
                # "tuples" history format explicitly rather than leaving
                # ``type`` unset (None) and relying on Gradio's implicit fallback.
                type="tuples",
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
                interactive=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary", visible=True)
                clear_btn = gr.Button(
                    "Delete my history",
                    variant="stop",
                    visible=True,
                )

    def submit_chat(history, message):
        """Record the user's message as a new turn with an empty reply.

        Returns the updated history and an empty string that clears the
        message box. A ``None`` history is treated as an empty chat.
        """
        history = list(history or [])
        history.append((message, ""))
        return history, ""

    def clear_chat():
        """Reset the chat history, the uploaded image and the message box."""
        return [], None, ""

    # The Send button and pressing Enter in the textbox behave the same way:
    # first record the user turn, then stream the model's reply into the chat.
    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(bot_streaming, [chatbot, image_input], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
    gr.HTML(footer_html)
# Script entry point: start the Gradio server for ``demo``.
if __name__ == "__main__":
    # share=False: no temporary public share link; show_api=False: hide the
    # API usage page; debug=False: do not run in Gradio's debug mode.
    demo.launch(debug=False, share=False, show_api=False)