# XGen-MM / app.py
# Author: maxiw
# Commit 3858798: load models to GPU during use to fit all versions
# File size: 5.95 kB
import gradio as gr
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, StoppingCriteria
import spaces
import torch
from PIL import Image
# Hub IDs of every xGen-MM checkpoint served by this demo. Dict insertion order
# below follows this tuple, so it is also the order shown in the model dropdown.
# Kept in one place so the model, processor and tokenizer tables cannot drift
# apart (each ID previously had to be spelled out three times).
_MODEL_IDS = (
    "Salesforce/xgen-mm-phi3-mini-instruct-r-v1",
    "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5",
    "Salesforce/xgen-mm-phi3-mini-instruct-singleimg-r-v1.5",
    "Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5",
)

# NOTE(review): trust_remote_code=True runs Python code downloaded from each
# Hub repository; only list repositories you trust in _MODEL_IDS.

# Every checkpoint is loaded once at import time, keyed by Hub ID. No device
# argument is passed here, so weights start wherever from_pretrained places
# them by default. Load order is unchanged: all models, then all image
# processors, then all tokenizers.
models = {
    model_id: AutoModelForVision2Seq.from_pretrained(model_id, trust_remote_code=True)
    for model_id in _MODEL_IDS
}
processors = {
    model_id: AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
    for model_id in _MODEL_IDS
}
tokenizers = {
    model_id: AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True, use_fast=False, legacy=False
    )
    for model_id in _MODEL_IDS
}

# Markdown header rendered at the top of the demo page.
DESCRIPTION = "# [xGen-MM Demo](https://huggingface.co/collections/Salesforce/xgen-mm-1-models-662971d6cecbf3a7f80ecc2e)"
def apply_prompt_template(prompt):
    """Build the chat-formatted prompt fed to the xGen-MM tokenizer.

    The result has three parts:
    - a fixed system turn;
    - a user turn that begins with an ``<image>`` placeholder followed by *prompt*;
    - an opened assistant turn for the model to continue.

    *prompt* is interpolated with an f-string, so non-string values are
    rendered with ``str()``.
    """
    system_turn = (
        "<|system|>\nA chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
    )
    user_turn = f"<|user|>\n<image>\n{prompt}<|end|>\n"
    assistant_turn = "<|assistant|>\n"
    return "".join((system_turn, user_turn, assistant_turn))
class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once the generated ids end with a given token sequence.

    Args:
        eos_sequence: Token ids marking the end of a reply. Defaults to
            ``[32007]`` (presumably the id of ``<|end|>`` in the xGen-MM
            tokenizer -- TODO confirm). Any iterable of ints is accepted and
            copied into a new list.
    """

    def __init__(self, eos_sequence=None):
        # Build a fresh list per instance. The previous `[32007]` default was a
        # single list object shared by every instance (mutable-default pitfall).
        # A list, not a tuple, is required because __call__ compares it against
        # the lists produced by `tolist()`.
        self.eos_sequence = [32007] if eos_sequence is None else list(eos_sequence)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if, in any row of *input_ids*, the last
        ``len(eos_sequence)`` ids equal ``eos_sequence``; *scores* is unused.
        """
        # One Python list per batch row, each holding that row's trailing ids.
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
def _generate_r_v1(model, processor, tokenizer, image, prompt):
    """Generate token ids with the original ``r-v1`` checkpoint.

    Text and image are encoded separately, merged into one dict and moved to
    CUDA. Generation is greedy, capped at 768 new tokens, and stopped early by
    ``EosListStoppingCriteria``.
    """
    pil_image = Image.fromarray(image).convert("RGB")
    language_inputs = tokenizer([prompt], return_tensors="pt")
    inputs = processor([pil_image], return_tensors="pt", image_aspect_ratio='anyres')
    inputs.update(language_inputs)
    inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
    return model.generate(
        **inputs,
        image_size=[pil_image.size],
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        max_new_tokens=768,
        top_p=None,
        num_beams=1,
        stopping_criteria=[EosListStoppingCriteria()],
    )


def _generate_r_v1_5(model, processor, tokenizer, image, prompt):
    """Generate token ids with one of the ``r-v1.5`` checkpoints.

    ``pixel_values`` and ``image_size`` are passed as nested lists (presumably
    batch -> images per sample; TODO confirm against the remote model code).
    Generation is greedy and capped at 1024 new tokens.
    """
    pil_image = Image.fromarray(image).convert("RGB")
    pixel_values = processor([pil_image], image_aspect_ratio='anyres')["pixel_values"].cuda()
    inputs = {"pixel_values": [[pixel_values]]}
    inputs.update(tokenizer([prompt], return_tensors="pt"))
    # Only tensors are moved; the nested pixel_values list is already on CUDA.
    for name, value in inputs.items():
        if isinstance(value, torch.Tensor):
            inputs[name] = value.cuda()
    return model.generate(
        **inputs,
        image_size=[[pil_image.size]],
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        max_new_tokens=1024,
        top_p=None,
        num_beams=1,
    )


@spaces.GPU
def run_example(image, text_input=None, model_id="Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"):
    """Answer *text_input* about *image* with the selected xGen-MM checkpoint.

    Args:
        image: Image from the ``gr.Image`` input; anything ``Image.fromarray``
            accepts (presumably a numpy array, gr.Image's default type --
            TODO confirm).
        text_input: The user's question. ``None`` is treated as an empty question.
        model_id: Key into the module-level ``models`` / ``processors`` /
            ``tokenizers`` dicts.

    Returns:
        The decoded reply, cut at the first ``<|end|>`` marker.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
        KeyError: If *model_id* is not a known checkpoint.
    """
    if image is None:
        # Fail with a readable message instead of an AttributeError from
        # Image.fromarray(None), and before any weights are moved to the GPU.
        raise gr.Error("Please upload an image.")

    model = models[model_id].to("cuda").eval()
    try:
        processor = processors[model_id]
        tokenizer = model.update_special_tokens(tokenizers[model_id])
        prompt = apply_prompt_template("" if text_input is None else text_input)
        if model_id == "Salesforce/xgen-mm-phi3-mini-instruct-r-v1":
            generated_ids = _generate_r_v1(model, processor, tokenizer, image, prompt)
        else:
            generated_ids = _generate_r_v1_5(model, processor, tokenizer, image, prompt)
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True).split("<|end|>")[0]
    finally:
        # Move the weights back off the GPU so that trying several checkpoints
        # does not leave all of them resident on the device at once.
        # NOTE(review): assumes requests for the same model never overlap
        # (Gradio's default per-event concurrency limit is 1); revisit if raised.
        model.to("cpu")
# Fixed-height, scrollable, bordered answer box. The rule targets the component
# whose elem_id is "output".
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="xGen-MM Input"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5")
                text_input = gr.Textbox(label="Question")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id ties this box to the "#output" CSS rule; without it the
                # rule matched no element and had no effect.
                output_text = gr.Textbox(label="Output Text", elem_id="output")
    # Inputs are passed positionally as (image, text_input, model_id).
    submit_btn.click(run_example, [input_img, text_input, model_selector], [output_text])

if __name__ == "__main__":
    demo.launch(debug=True)