# Phi-3.5 Multimodal (Text + Vision) demo Space, deployed on ZeroGPU
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer
import gradio as gr
from threading import Thread
from PIL import Image
# Constants
TITLE = "<h1><center>Phi-3.5 Multimodal (Text + Vision)</center></h1>"
DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"
# Model configurations
TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load models and tokenizers
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
text_model = AutoModelForCausalLM.from_pretrained(
    TEXT_MODEL_ID,
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
    device_map="auto" if device == "cuda" else None,
    low_cpu_mem_usage=True,
)
if device == "cuda":
    text_model = text_model.half()  # already loaded in fp16 on GPU; kept as a no-op safeguard
vision_model = AutoModelForCausalLM.from_pretrained(
    VISION_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
    low_cpu_mem_usage=True,
).to(device).eval()
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
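
# Deployment note (assumption, not shown in this snippet): on a ZeroGPU Space,
# GPU-touching functions are typically decorated with @spaces.GPU from the
# `spaces` package so a GPU is allocated per call; that decorator may need to
# be added for the generation functions below when running on ZeroGPU.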
# Helper functions
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
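    """Stream a chat completion from Phi-3.5-mini, yielding the growing response text."""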
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = text_tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(text_model.device)
    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # eos handling is left to the model's generation config; the hard-coded IDs
        # 128001/128008/128009 in the original belong to Llama 3, not Phi-3.5
        streamer=streamer,
    )
    # generate() disables gradients internally, and a torch.no_grad() context here
    # would not propagate into the worker thread anyway, so the thread is started directly.
    thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
def process_vision_query(image, text_input):
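    """Run Phi-3.5-vision on a single image plus a text question and return the answer."""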
    if image is None:
        return "Please upload an image first."
    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
    image = Image.fromarray(image).convert("RGB")
    inputs = vision_processor(prompt, image, return_tensors="pt").to(device)
    with torch.no_grad():
        generate_ids = vision_model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=vision_processor.tokenizer.eos_token_id,
        )
    # Strip the prompt tokens so only the newly generated answer is decoded
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return response
# Gradio interface
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(DESCRIPTION)
    with gr.Tab("Text Model (Phi-3.5-mini)"):
        chatbot = gr.Chatbot(height=600)
        gr.ChatInterface(
            fn=stream_text_chat,
            chatbot=chatbot,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant", label="System Prompt"),
                gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature"),
                gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens"),
                gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p"),
                gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k"),
            ],
        )
    with gr.Tab("Vision Model (Phi-3.5-vision)"):
        with gr.Row():
            with gr.Column():
                vision_input_img = gr.Image(label="Input Picture")
                vision_text_input = gr.Textbox(label="Question")
                vision_submit_btn = gr.Button(value="Submit")
            with gr.Column():
                vision_output_text = gr.Textbox(label="Output Text")
        vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])
if __name__ == "__main__":
print(f"Running on device: {device}")
demo.launch() |
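
# A minimal client-side sketch for calling the running app with gradio_client.
# The URL and endpoint names are assumptions; list the real endpoints with view_api():
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.view_api())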