"""PixDiet: a Gradio chat app that gives nutrition advice about a meal photo.

The model (``model_id`` below, or a test model when ``TESTING`` is set) is
loaded once at import time. Each chat turn builds a single-turn prompt
(optionally with the meal image), streams the generated answer into the
chat window, and prints the exchange to stdout for debugging.
"""

# Keep `spaces` as the very first import: Hugging Face ZeroGPU expects it to be
# imported before torch / CUDA are touched (NOTE(review): per ZeroGPU docs).
import spaces

from datetime import datetime
from threading import Thread
from typing import Optional

import gradio as gr
import pytz
import torch
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    CodeGenTokenizerFast as Tokenizer,
    LlavaForConditionalGeneration,
    TextIteratorStreamer,
)

from theme import Seafoam

load_dotenv()

# When True, load `vikhyatk/moondream1` with a plain tokenizer instead of the
# PixDiet model (no 4-bit quantization, no multimodal processor).
TESTING = False

# Hugging Face model id.
# Alternative checkpoint: "mistral-community/pixtral-12b"
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

if TESTING:
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

# A multimodal processor exposes its tokenizer as `.tokenizer`; the TESTING
# branch loads a bare tokenizer, so fall back to the object itself instead of
# raising AttributeError at import time.
tokenizer = getattr(processor, "tokenizer", processor)

# Chat template: [INST] ... [/INST] around user turns, "\n[IMG]" for each image
# item. The 4-space indentation is only for readability and is stripped below.
processor.chat_template = """
{%- for message in messages %}
    {%- if message.role == "user" %}
        [INST]
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- elif item.type == "image" %}
                \n[IMG]
            {%- endif %}
        {%- endfor %}
        [/INST]
    {%- elif message.role == "assistant" %}
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- endif %}
        {%- endfor %}
    {%- endif %}
{%- endfor %}
""".replace("    ", "")
tokenizer.pad_token = tokenizer.eos_token

# Fixed instruction sent on every turn (after the current-time prefix).
_BASE_PROMPT = (
    "You are a nutrition expert. Identify the food/ingredients in this image. "
    "Is this a healthy meal? Can you think of how to improve it?"
)


def _current_paris_time() -> str:
    """Return the current Europe/Paris time formatted like "07:45PM"."""
    return datetime.now(pytz.timezone("Europe/Paris")).strftime("%I:%M%p")


def _build_prompt(
    current_time: str, user_text: Optional[str] = "", about_you: Optional[str] = None
) -> str:
    """Compose the text instruction sent to the model for one turn.

    With no user text and no profile the result is exactly the fixed
    nutrition prompt; otherwise the profile and the user's message are
    appended on their own lines so they actually reach the model.
    """
    prompt = f"Current time: {current_time}. {_BASE_PROMPT}"
    profile = (about_you or "").strip()
    if profile:
        prompt += f"\nAbout the user: {profile}"
    question = (user_text or "").strip()
    if question:
        prompt += f"\nUser message: {question}"
    return prompt


def _to_rgb(image_input) -> Image.Image:
    """Return ``image_input`` (a PIL image or an array accepted by
    ``Image.fromarray``) as an RGB PIL image."""
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")
    return Image.fromarray(image_input).convert("RGB")


@spaces.GPU
def bot_streaming(
    chatbot, image_input, max_new_tokens=250, about_you: Optional[str] = None
):
    """Stream the model's answer for the latest chat turn into ``chatbot``.

    Args:
        chatbot: Chat history as ``[user_message, bot_message]`` rows; the
            last row is the turn being answered and its reply slot is
            replaced as text streams in.
        image_input: Optional meal image (PIL image or array); ``None`` sends
            a text-only request.
        max_new_tokens: Upper bound on generated tokens.
        about_you: Optional free-text user profile appended to the prompt.

    Yields:
        ``chatbot`` after each streamed chunk of the answer.
    """
    user_text = chatbot[-1][0]
    prompt = _build_prompt(_current_paris_time(), user_text, about_you)

    # Single-turn request: earlier chat rows are not sent to the model.
    images = []
    content = [{"type": "text", "text": prompt}]
    if image_input is not None:
        images.append(_to_rgb(image_input))
        content.append({"type": "image"})
    messages = [{"role": "user", "content": content}]

    # tokenize=False so both a processor and a bare tokenizer return a string.
    texts = processor.apply_chat_template(messages, tokenize=False)
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(
        tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # generate() blocks, so run it in a worker thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        # Replace the whole row rather than item-assigning, so tuple rows work too.
        chatbot[-1] = [user_text, response]
        yield chatbot
    thread.join()

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", prompt)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")


seafoam = Seafoam()

# Define the HTML content for the header.
# NOTE(review): this renders as plain text; logo images implied by
# "Alan AI Logo" / "Mistral AI Logo" may have lost their markup — confirm.
html = """

🍽️ PixDiet

PixDiet is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.
Alan AI Logo Mistral AI Logo
"""

footer_html = """
Background Image
Made with ❤️ during the Mistral AI x Alan Hackathon.
"""

# Define LaTeX delimiters
latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    {"left": "\\[", "right": "\\]", "display": True},
]

# Create the Gradio interface
with gr.Blocks(
    title="PixDiet", theme=seafoam, css="footer{display:none !important}"
) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            about_you = gr.Textbox(
                label="About you",
                placeholder="Add information about you here...",
                lines=3,
                interactive=True,
            )
            image_input = gr.Image(
                label="Upload your meal image", height=350, type="pil"
            )
            gr.Examples(
                examples=[
                    [
                        "./examples/mistral_breakfast.jpeg",
                        "John, 45 years old, 80kg, lactose intolerant. Training for his first triathlon.",
                    ],
                    [
                        "./examples/mistral_desert.jpeg",
                        "Emma, 26 years old, 55kg, iron deficiency. Training for her first Ironman competition.",
                    ],
                    [
                        "./examples/mistral_snacks.jpeg",
                        "Paul, 34 years old, 62kg, no known pathologies. Focused on improving strength for weightlifting competitions.",
                    ],
                    [
                        "./examples/mistral_pasta.jpeg",
                        "Carla, 52 years old, 58kg, no known pathologies. Currently training for her first marathon.",
                    ],
                ],
                inputs=[image_input, about_you],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat with PixDiet",
                layout="panel",
                height=700,
                show_copy_button=True,
                latex_delimiters=latex_delimiters_set,
                type=None,
            )
            text_input = gr.Textbox(
                label="Ask about your meal",
                placeholder="(Optional) Enter your message here...",
                lines=1,
                container=False,
                interactive=True,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary", visible=True)
                clear_btn = gr.Button(
                    "Delete my history",
                    variant="stop",
                    visible=True,
                )

    def submit_chat(chatbot, text_input):
        """Append the user's message with an empty reply and clear the textbox."""
        response = ""
        chatbot.append((text_input, response))
        return chatbot, ""

    def clear_chat():
        """Reset the chat history, image and message box."""
        return [], None, ""

    def respond(chat_history, image, profile):
        """Relay ``bot_streaming`` output, forwarding the "About you" text.

        Gradio maps inputs positionally, so a wrapper is needed to pass the
        profile by keyword without it landing in ``max_new_tokens``.
        """
        yield from bot_streaming(chat_history, image, about_you=profile)

    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(respond, [chatbot, image_input, about_you], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(respond, [chatbot, image_input, about_you], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])

    gr.HTML(footer_html)

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)