# Multimodal_App / app.py — Hugging Face Space by sagar007
# Last commit: 5904b1d "Update app.py" (verified), 4.49 kB
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer
import gradio as gr
from threading import Thread
from PIL import Image
# --- Page text ---------------------------------------------------------------
# HTML banner and Markdown heading shown at the top of the Gradio page.
TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

# --- Model configuration -------------------------------------------------------
# Hugging Face Hub repository ids for the chat model and the image+text model.
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load models and tokenizers.
# float16 halves weight memory on the GPU; the CPU path keeps float32.
MODEL_DTYPE = torch.float32 if device == "cpu" else torch.float16

text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
# On CUDA, device_map="auto" lets accelerate place the weights; on CPU the model
# stays where from_pretrained loads it. The weights are already materialised in
# MODEL_DTYPE, so no separate .half() cast is needed afterwards.
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=MODEL_DTYPE,
    device_map="auto" if device == "cuda" else None,
    low_cpu_mem_usage=True,
).eval()

# trust_remote_code=True permits executing the Hub repository's own model and
# processor code when loading this checkpoint.
# NOTE(review): the Phi-3.5-vision model card suggests passing
# _attn_implementation="eager" when flash-attn is not installed (e.g. CPU-only
# hosts) — confirm this loads without it.
vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=MODEL_DTYPE,
    low_cpu_mem_usage=True,
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
# Helper functions
def _build_conversation(message, history, system_prompt):
    """Turn Gradio chat history plus the new message into chat-template messages.

    Accepts history either as ``(user, assistant)`` pairs (Gradio "tuples"
    format) or as ``{"role": ..., "content": ...}`` dicts ("messages" format).
    """
    conversation = [{"role": "system", "content": system_prompt}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: already a role/content record.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        conversation.append({"role": "user", "content": prompt})
        # Skip the assistant message when the turn has no reply text.
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
    """Stream a Phi-3.5-mini reply for Gradio's ChatInterface.

    Args:
        message: The new user message.
        history: Prior turns, as ``(user, assistant)`` pairs or role/content dicts.
        system_prompt: Text used as the system message.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        The full response generated so far, growing with each streamed chunk.
    """
    conversation = _build_conversation(message, history, system_prompt)
    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)

    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # Take the stop-token ids from the loaded model's own generation config
    # instead of hard-coding them, so they always match this tokenizer's vocabulary.
    eos_token_id = text_model.generation_config.eos_token_id or text_tokenizer.eos_token_id

    generate_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_token_id,
        streamer=streamer,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # Greedy decoding: sampling knobs would be ignored, so don't pass them.
        generate_kwargs["do_sample"] = False

    # generate() runs in a worker thread so text can be yielded as it streams in.
    # torch.no_grad() is thread-local, so it is not wrapped around the thread
    # start; generate() manages grad mode itself.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
def process_vision_query(image, text_input, max_new_tokens=1000):
    """Answer a question about an image with Phi-3.5-vision.

    Args:
        image: The uploaded picture as a NumPy array (the default value type of
            ``gr.Image``) or a PIL image; ``None`` when nothing was uploaded.
        text_input: The question to ask about the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded answer, or a short notice when no image was provided.
    """
    if image is None:
        # gr.Image yields None when Submit is pressed before an upload.
        return "Please upload an image first."

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Phi-3.5-vision chat format: one image placeholder followed by the question.
    prompt = f"<|user|>\n<|image_1|>\n{text_input or ''}<|end|>\n<|assistant|>\n"
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)

    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated answer is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    return vision_processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
# Gradio interface: one tab for the text chat model, one for the vision model.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)

    with gr.Tab("Text Model (Phi-3.5-mini)"):
        # Create the chatbot unrendered: ChatInterface renders the component it
        # is given, and rendering it here too would place it twice (some Gradio
        # versions raise DuplicateBlockError for this).
        chatbot = gr.Chatbot(height=600, render=False)
        gr.ChatInterface(
            fn=stream_text_chat,
            chatbot=chatbot,
            # Passed after (message, history), in the order of stream_text_chat's
            # remaining parameters: system_prompt, temperature, max_new_tokens,
            # top_p, top_k. A temperature of 0 selects greedy decoding.
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant", label="System Prompt"),
                gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature"),
                gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens"),
                gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p"),
                gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k"),
            ],
        )

    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column():
                vision_input_img = gr.Image(label="Input Picture")
                vision_text_input = gr.Textbox(label="Question")
                vision_submit_btn = gr.Button(value="Submit")
            with gr.Column():
                vision_output_text = gr.Textbox(label="Output Text")

        vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])
if __name__ == "__main__":
    # Report the selected compute device, then start the Gradio server.
    print("Running on device: {}".format(device))
    demo.launch()