Spaces:
Sleeping
Sleeping
import torch | |
import spaces | |
from transformers import ( | |
AutoProcessor, | |
BitsAndBytesConfig, | |
LlavaForConditionalGeneration, | |
) | |
from PIL import Image | |
import gradio as gr | |
from threading import Thread | |
from transformers import TextIteratorStreamer, AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer | |
from dotenv import load_dotenv | |
import os | |
# Import Supabase functions | |
from db_client import get_user_history, update_user_history, delete_user_history | |
# Add these imports | |
from datetime import datetime | |
import pytz | |
# Load settings from a .env file, if one is present, before anything reads the environment.
load_dotenv()

# Runtime switches
TESTING = False  # set to True to load the alternate "vikhyatk/moondream1" checkpoint instead
IS_LOGGED_IN = True  # when True, bot_streaming saves each exchange to the user's history
USER_ID = "jeremie.feron@gmail.com"  # identifier passed to the history helpers

# Hugging Face model id
model_id = "blanchon/pixtral-nutrition-2"

# BitsAndBytes int-4 settings: NF4 quantization with double quantization and
# bfloat16 as the compute dtype.
_BNB_4BIT_OPTIONS = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_BNB_4BIT_OPTIONS)
# Load the model and its processor. TESTING selects an alternate checkpoint;
# otherwise the fine-tuned model is loaded with the 4-bit quantization config.
if TESTING:
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

# NOTE(review): the indentation of this file was lost, so it is unclear whether
# the two assignments below originally lived inside the `else:` branch. They
# are kept at module level here; confirm that the TESTING tokenizer also
# exposes `.tokenizer` before enabling TESTING.

# Chat template: user turns are wrapped in "<s>[INST] ... [/INST]" with one
# "[IMG]" marker per image item; assistant turns are closed with "</s>".
# The indentation below exists only for readability and is stripped before use.
_CHAT_TEMPLATE = """
{%- for message in messages %}
    {%- if message.role == "user" %}
        <s>[INST]
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- elif item.type == "image" %}
                \n[IMG]
            {%- endif %}
        {%- endfor %}
        [/INST]
    {%- elif message.role == "assistant" %}
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- endif %}
        {%- endfor %}
        </s>
    {%- endif %}
{%- endfor %}
"""

# Strip only the leading whitespace of each line. Removing every space (as a
# blanket `.replace(' ', "")` does) would also delete the spaces inside the
# Jinja tags, e.g. "{%-formessageinmessages%}", and make the template unparsable.
processor.chat_template = "\n".join(
    line.lstrip() for line in _CHAT_TEMPLATE.split("\n")
)

# Reuse the end-of-sequence token as the padding token.
processor.tokenizer.pad_token = processor.tokenizer.eos_token
# Question used when the user submits an image without typing anything.
_DEFAULT_QUESTION = (
    "Identify the food/ingredients in this image. Is this a healthy meal? "
    "Can you think of how to improve it?"
)


def _build_prompt(user_text, current_time):
    """Return the prompt for one turn: time stamp, persona, then the question.

    The user's own question is used when it is non-blank; otherwise the
    default meal-analysis question is used.
    """
    question = user_text.strip() if user_text else ""
    return (
        f"Current time: {current_time}. You are a nutrition expert. "
        f"{question or _DEFAULT_QUESTION}"
    )


def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream the model's reply to the most recent chat turn.

    Args:
        chatbot: Chat history as a list of ``[user_text, bot_text]`` pairs.
            The last pair is the pending turn; its reply slot is filled in
            place as text arrives.
        image_input: Optional meal image, either a ``PIL.Image.Image`` or an
            array accepted by ``Image.fromarray``. ``None`` for a text-only turn.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The updated ``chatbot`` list after each streamed chunk of text.
    """
    # Copy the stored history so appending below never mutates the object
    # returned by the history helper; tolerate a missing (None) history.
    messages = list(get_user_history(USER_ID) or [])
    images = []

    # Get current time in Paris timezone
    paris_tz = pytz.timezone('Europe/Paris')
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")

    # Previously both branches of the empty/non-empty check built the same
    # default prompt, so the user's typed question was silently discarded.
    text_input = _build_prompt(chatbot[-1][0], current_time)

    # Add current message
    content = [{"type": "text", "text": text_input}]
    if image_input is not None:
        # The image may arrive as a PIL image or as a raw array.
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        content.append({"type": "image"})
    messages.append({"role": "user", "content": content})

    # NOTE(review): stored history keeps {"type": "image"} placeholders from
    # earlier turns, but only the current image is passed to the processor.
    # Confirm the processor tolerates placeholders without a matching image.

    # Apply chat template
    texts = processor.apply_chat_template(messages)

    # Process inputs
    if not images:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    # Generation runs in a background thread so the streamer can be consumed
    # here and partial replies pushed to the UI as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()

    # Debug output
    print('*'*60)
    print('*'*60)
    print('BOT_STREAMING_CONV_START')
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f'Q{i}:\n {request}')
        print(f'A{i}:\n {answer}')
    print('New_Q:\n', text_input)
    print('New_A:\n', response)
    print('BOT_STREAMING_CONV_END')

    # Persist the exchange (user turn + assistant reply) for logged-in users.
    if IS_LOGGED_IN:
        new_history = messages + [{"role": "assistant", "content": [{"type": "text", "text": response}]}]
        update_user_history(USER_ID, new_history)
# Header markup shown at the top of the page: title, tagline and partner logos.
# A plain string is enough here because nothing is interpolated.
html = """
<p align="center" style="font-size: 2.5em; line-height: 1;">
<span style="display: inline-block; vertical-align: middle;">🍽️</span>
<span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center><font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font></center>
<div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
<img src="https://zozs97eh0bkqexza.public.blob.vercel-storage.com/alan-VD7bRf1rKuEBL6EDAjw0eLGVodhoh8.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
<img src="https://seeklogo.com/images/M/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
"""

# LaTeX delimiters for the chatbot: inline "\( \)" first, then one display
# entry per "\begin{env} ... \end{env}" environment, then display "\[ \]".
_DISPLAY_ENVIRONMENTS = ("equation", "align", "alignat", "gather", "CD")

latex_delimiters_set = [{"left": "\\(", "right": "\\)", "display": False}]
latex_delimiters_set.extend(
    {"left": "\\begin{" + env + "}", "right": "\\end{" + env + "}", "display": True}
    for env in _DISPLAY_ENVIRONMENTS
)
latex_delimiters_set.append({"left": "\\[", "right": "\\]", "display": True})
# Build the Gradio UI: meal image and examples on the left, chat on the right.
# NOTE(review): the nesting below is reconstructed (the file's indentation was
# lost); in particular the button row is assumed to sit inside the chat column.
with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)

    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="Upload your meal image", height=350, type="pil")
            # Each example pairs an image path with an empty value for a hidden textbox.
            gr.Examples(
                examples=[
                    [f"./examples/mistral_{meal}.jpeg", ""]
                    for meal in ("breakfast", "desert", "snacks", "pasta")
                ],
                inputs=[image_input, gr.Textbox(visible=False)],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=600,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Delete my historic", variant="huggingface")

    def submit_chat(chatbot, text_input):
        """Append the user's message as a new turn with an empty reply and clear the textbox."""
        chatbot.append((text_input, ''))
        return chatbot, ''

    def clear_chat():
        """Delete the stored history and reset the chat, image and textbox values."""
        delete_user_history(USER_ID)
        return [], None, ""

    # The Send button and pressing Enter in the textbox trigger the same chain:
    # record the user's turn, then stream the model's reply into the chat.
    submit_step = dict(fn=submit_chat, inputs=[chatbot, text_input], outputs=[chatbot, text_input])
    stream_step = dict(fn=bot_streaming, inputs=[chatbot, image_input], outputs=chatbot)

    send_click_event = send_btn.click(**submit_step).then(**stream_step)
    submit_event = text_input.submit(**submit_step).then(**stream_step)

    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch(share=False, debug=False, show_api=False)