"""Gradio demo exposing the Phi-3.5 text and vision models in one app."""

import os
import subprocess
import sys
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Install flash-attention at start-up.
# `env=` REPLACES the child's environment, so the flag is merged into a copy of
# the current environment; otherwise PATH, HOME, proxy settings, etc. are lost.
# Running pip through the current interpreter installs into the environment
# this process is actually using and avoids a shell-built command line.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort: a failed install must not abort start-up here
)

# Constants
# NOTE(review): the blank lines around the text suggest HTML markup was
# stripped from this literal at some point — confirm the intended value.
TITLE = """

Phi 3.5 Multimodal (Text + Vision)

"""
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantization config for text model (4-bit NF4, bfloat16 compute)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)

vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
).to(device).eval()

vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)


# Helper functions
@spaces.GPU
def stream_text_chat(message, history, system_prompt, temperature=0.8,
                     max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a chat completion from the text model.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs; ``None`` is
            treated as an empty history.
        system_prompt: Content of the leading ``system`` message.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        ``history + [[message, partial_answer]]`` after every streamed chunk.
    """
    history = history or []

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(
        text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    # eos_token_id is deliberately NOT overridden: ids such as
    # 128001/128008/128009 belong to the Llama-3 vocabulary and never occur in
    # Phi-3.5 output, so passing them disables early stopping. The stop tokens
    # from the checkpoint's own generation config are used instead.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are meaningless (and trigger warnings) for greedy decoding.
        generate_kwargs.update(top_p=top_p, top_k=top_k, temperature=temperature)

    # No torch.no_grad() here: grad mode is thread-local, so a context opened
    # in this thread would not reach the worker; generate() disables gradients
    # on its own.
    worker = Thread(target=text_model.generate, kwargs=generate_kwargs)
    worker.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield history + [[message, buffer]]

    worker.join()
@spaces.GPU
def process_vision_query(image, text_input):
    """Answer a question about an image with the vision model.

    Args:
        image: The uploaded image. A PIL image is used directly; anything else
            is converted with ``Image.fromarray``.
        text_input: The question to ask about the image.

    Returns:
        The decoded model response with the prompt tokens removed.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # The image component is configured with type="pil", so a PIL image is
    # the expected input; array input is still accepted.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)

    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated part is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = vision_processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response


# Custom CSS (adjacent literals are concatenated; one rule per group)
custom_css = (
    " body { background-color: #0b0f19; color: #e2e8f0; font-family: 'Arial', sans-serif;}"
    " #custom-header { text-align: center; padding: 20px 0; background-color: #1a202c;"
    " margin-bottom: 20px; border-radius: 10px;}"
    " #custom-header h1 { font-size: 2.5rem; margin-bottom: 0.5rem;}"
    " #custom-header h1 .blue { color: #60a5fa;}"
    " #custom-header h1 .pink { color: #f472b6;}"
    " #custom-header h2 { font-size: 1.5rem; color: #94a3b8;}"
    " .suggestions { display: flex; justify-content: center; flex-wrap: wrap; gap: 1rem;"
    " margin: 20px 0;}"
    " .suggestion { background-color: #1e293b; border-radius: 0.5rem; padding: 1rem;"
    " display: flex; align-items: center; transition: transform 0.3s ease; width: 200px;}"
    " .suggestion:hover { transform: translateY(-5px);}"
    " .suggestion-icon { font-size: 1.5rem; margin-right: 1rem; background-color: #2d3748;"
    " padding: 0.5rem; border-radius: 50%;}"
    " .gradio-container { max-width: 100% !important;}"
    " #component-0, #component-1, #component-2 { max-width: 100% !important;}"
    " footer { text-align: center; margin-top: 2rem; color: #64748b;}"
    " "
)

# Custom HTML for the header
# NOTE(review): only text remains in this literal; the markup that the CSS
# selectors above target (#custom-header, h1, h2) is not present — confirm.
custom_header = """

Phi 3.5 Multimodal Assistant

Text and Vision AI at Your Service

"""

# Custom HTML for suggestions
# NOTE(review): same as above — the .suggestion markup appears to be missing.
custom_suggestions = """
💬

Chat with the Text Model

🖼️

Analyze Images with Vision Model

🤖

Get AI-generated responses

🔍

Explore advanced options

"""

# Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
    body_background_fill="#0b0f19",
    body_text_color="#e2e8f0",
    button_primary_background_fill="#3b82f6",
    button_primary_background_fill_hover="#2563eb",
    button_primary_text_color="white",
    block_title_text_color="#94a3b8",
    block_label_text_color="#94a3b8",
)) as demo:
    gr.HTML(custom_header)
    gr.HTML(custom_suggestions)

    # Tab 1: streaming chat with the text model.
    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(label="Message", placeholder="Type your message here...")
        with gr.Accordion("Advanced Options", open=False):
            system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
            max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")

        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")

        submit_btn.click(
            stream_text_chat,
            [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k],
            [chatbot],
        )
        # Resetting the chatbot value to None clears the conversation.
        clear_btn.click(lambda: None, None, chatbot, queue=False)

    # Tab 2: single-shot image question answering with the vision model.
    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column(scale=1):
                vision_input_img = gr.Image(label="Upload an Image", type="pil")
                vision_text_input = gr.Textbox(
                    label="Ask a question about the image",
                    placeholder="What do you see in this image?",
                )
                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column(scale=1):
                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)

        vision_submit_btn.click(
            process_vision_query,
            [vision_input_img, vision_text_input],
            [vision_output_text],
        )

    # Empty HTML block; presumably a footer placeholder — TODO confirm.
    gr.HTML("")

if __name__ == "__main__":
    demo.launch()