import spaces
from transformers import (
TextIteratorStreamer,
)
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
LlavaForConditionalGeneration,
)
from PIL import Image
import gradio as gr
from threading import Thread
from dotenv import load_dotenv
# Timezone-aware timestamp support (datetime + pytz) for the prompt
from datetime import datetime
import pytz
from typing import Optional
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
import torch
from theme import Seafoam
# Pull settings from a .env file into the process environment.
load_dotenv()

# Switches model loading to the alternative checkpoint used for testing.
TESTING = False

# Hugging Face Hub id of the checkpoint to serve.
# Alternative checkpoint (disabled): mistral-community/pixtral-12b
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# bitsandbytes settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the model together with its matching preprocessor.
if TESTING:
    # Test mode: remote-code checkpoint paired with a plain tokenizer.
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    # Default mode: checkpoint quantized with `bnb_config`, device placement
    # delegated to `device_map="auto"`.
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

# Chat template: user turns are wrapped in [INST] ... [/INST], with an [IMG]
# placeholder emitted for every image item; assistant turns contribute only
# their text items.
#
# The template is kept flush-left and used verbatim. Stripping every space
# from it (`.replace(" ", "")`) would also delete the spaces inside the Jinja
# tags (`{%- for message in messages %}` -> `{%-formessageinmessages%}`) and
# leave a template that cannot be parsed.
processor.chat_template = """
{%- for message in messages %}
{%- if message.role == "user" %}
[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
"""

# Reuse EOS as the padding token. In TESTING mode `processor` is itself a
# tokenizer and is not expected to expose a nested `.tokenizer`, so fall back
# to the object itself instead of raising AttributeError.
_tokenizer = getattr(processor, "tokenizer", processor)
_tokenizer.pad_token = _tokenizer.eos_token
@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream a nutrition analysis into the last turn of ``chatbot``.

    The prompt sent to the model is always the fixed nutrition question
    prefixed with the current time in Paris; the text the user typed in the
    last turn is not forwarded to the model.

    Args:
        chatbot: Chat history as a list of mutable ``[request, answer]``
            pairs. The answer slot of the last pair is overwritten in place
            as text is generated.
        image_input: A ``PIL.Image.Image``, an array accepted by
            ``Image.fromarray``, or ``None`` for a text-only request.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The same ``chatbot`` list after each streamed chunk of text.

    Raises:
        IndexError: If ``chatbot`` is empty.
    """
    # Resolve the turn being answered up front so that an empty history fails
    # before any model work starts.
    current_turn = chatbot[-1]

    # Timestamp the prompt with the wall-clock time in Paris.
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")
    text_input = f"Current time: {current_time}. You are a nutrition expert. Identify the food/ingredients in this image. Is this a healthy meal? Can you think of how to improve it?"

    # Build the single user message; an image placeholder is appended only
    # when an image was actually supplied.
    images = []
    content = [{"type": "text", "text": text_input}]
    if image_input is not None:
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        content.append({"type": "image"})

    # The conversation is rebuilt on every call, so the name is always defined
    # and nothing is carried over between requests.
    messages = [{"role": "user", "content": content}]

    # Render the prompt and tokenize it (with pixel inputs when present).
    texts = processor.apply_chat_template(messages)
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")

    # Run generation on a worker thread so that decoded text can be consumed
    # from the streamer while the model is still producing tokens.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        current_turn[1] = response
        yield chatbot
    thread.join()

    # Debug output: earlier turns followed by the exchange just completed.
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")
# Instantiate Seafoam from the local `theme` module (presumably the UI theme
# handed to Gradio further down — confirm at the point of use).
seafoam = Seafoam()
# Define the HTML content for the header
html = """
🍽️ PixDiet