# Multimodal_App / app.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer
import gradio as gr
from threading import Thread
from PIL import Image
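# Dependencies: torch, transformers, gradio, pillow. The flash-attn package is
# only needed when running on CUDA, where flash_attention_2 is requested below.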
# Constants
TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"
# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load models and tokenizers
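# The text model is dispatched automatically via device_map="auto"; the vision
# model is loaded with trust_remote_code (its modeling code lives in the model
# repo) and moved to `device` explicitly, using flash_attention_2 only on CUDA.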
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
)
vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    attn_implementation="flash_attention_2" if device == "cuda" else None,
    low_cpu_mem_usage=True,
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
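# A sketch of an alternative (assuming the repo's chat template matches the
# hand-written prompt used in process_vision_query below): build the vision
# prompt via the tokenizer's chat template instead of an f-string, e.g.
#   messages = [{"role": "user", "content": "<|image_1|>\n" + question}]  # `question` is hypothetical
#   prompt = vision_processor.tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)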
# Helper functions
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
    # Rebuild the full conversation (system prompt + prior turns + new message)
    # in the format expected by the tokenizer's chat template.
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # Stop generation at the tokenizer's own EOS token.
        eos_token_id=text_tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation in a background thread so the streamer can yield partial
    # text to the chat UI as it is produced; generate() disables grad internally.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
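# Example (hypothetical, outside Gradio) of consuming the generator directly:
#   for partial in stream_text_chat("Hello!", history=[], system_prompt="You are a helpful assistant"):
#       print(partial)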
def process_vision_query(image, text_input):
    if image is None:
        return "Please upload an image first."
    # Phi-3.5-vision expects an image placeholder token inside the user turn.
    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    # gr.Image passes a numpy array by default; convert it to a PIL RGB image.
    image = Image.fromarray(image).convert("RGB")
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)
    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )
    # Drop the prompt tokens so only the newly generated answer is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = vision_processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
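# Example (hypothetical, outside Gradio), passing the image as a numpy array:
#   import numpy as np
#   answer = process_vision_query(np.array(Image.open("photo.jpg")), "What is in this picture?")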
# Gradio interface
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)

    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=600)
        gr.ChatInterface(
            fn=stream_text_chat,
            chatbot=chatbot,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant", label="System Prompt"),
                gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature"),
                gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens"),
                gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p"),
                gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k"),
            ],
        )

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column():
                vision_input_img = gr.Image(label="Input Picture")
                vision_text_input = gr.Textbox(label="Question")
                vision_submit_btn = gr.Button(value="Submit")
            with gr.Column():
                vision_output_text = gr.Textbox(label="Output Text")
        vision_submit_btn.click(
            process_vision_query,
            inputs=[vision_input_img, vision_text_input],
            outputs=[vision_output_text],
        )

if __name__ == "__main__":
    demo.launch()
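# Note: recent Gradio releases queue events by default, which streamed chat
# output relies on; on older versions, demo.queue().launch() may be required
# instead of plain demo.launch().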