|
import os
import re
import subprocess
import time

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
|
# Install flash-attn at start-up (best effort: a failing install does not stop
# the app, exactly as before).
#
# `env=` REPLACES the child's environment rather than extending it, so the
# previous call ran pip without PATH, HOME, proxy or CUDA variables. Merge the
# extra flag into the current environment instead. The command is passed as an
# argument list so no shell is involved.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
|
|
|
|
|
# Hub checkpoint shared by the processor and the model.
MODEL_ID = "HuggingFaceM4/Idefics3-8B-Llama3"

# Processor belonging to the checkpoint above.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load the weights in bfloat16, then move the model to the GPU.
model = Idefics3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.to("cuda")

# Token ids of the image placeholder strings, tokenized without special
# tokens; handed to `generate(bad_words_ids=...)` by the inference function.
BAD_WORDS_IDS = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids

# End-of-sequence token id wrapped in a list (not referenced in the visible
# part of this file).
EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
|
|
|
@spaces.GPU
def model_inference(
    images, text, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    """Answer a prompt about optional image(s) with the module-level model.

    Args:
        images: A single PIL image, a sequence of images, or None/empty for a
            text-only query.
        text: The user prompt as a string (or an already prepared list of
            prompts, which is passed to the processor unchanged).
        decoding_strategy: Either "Greedy" or "Top P Sampling".
        temperature: Sampling temperature; only used for "Top P Sampling".
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Nucleus sampling threshold; only used for "Top P Sampling".

    Returns:
        The first decoded sequence, restricted to the newly generated tokens.

    Raises:
        gr.Error: If no text prompt was given, or if sampling is requested
            with a non-positive temperature.
        ValueError: If ``decoding_strategy`` is not a supported value.
    """
    # A single upload arrives as a bare PIL image; normalise it to a list.
    if isinstance(images, Image.Image):
        images = [images]
    has_images = bool(images)

    # gr.Error is an exception: it only reaches the user when it is raised.
    if not text:
        if has_images:
            raise gr.Error("Please input a text query along the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # Validate the strategy and assemble the generation arguments before any
    # GPU work. Explicit raises are used because `assert` is stripped under -O.
    generation_kwargs = {
        "bad_words_ids": BAD_WORDS_IDS,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if decoding_strategy == "Greedy":
        generation_kwargs["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        # Sampling divides the logits by the temperature, so zero (which the
        # UI slider allows) cannot be used; report it in a readable way.
        if temperature <= 0:
            raise gr.Error("Temperature must be greater than 0 for Top P Sampling.")
        # Sampling-only arguments are passed only when sampling is enabled.
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
        )
    else:
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")

    if isinstance(text, str):
        # One "<image>" placeholder per supplied image; no placeholder is
        # added when there is no image to match it.
        placeholders = "<image>" * len(images) if has_images else ""
        text = [placeholders + text]

    inputs = processor(
        text=text,
        images=images if has_images else None,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    print("inputs",inputs)

    # No trailing comma after the call: it used to wrap the result in a
    # 1-tuple, which broke the decoding step below.
    generated_ids = model.generate(**inputs, **generation_kwargs)

    # The returned sequences start with the prompt tokens; decode only what
    # was generated after them so the prompt is not echoed in the answer.
    prompt_length = inputs["input_ids"].shape[1]
    generated_texts = processor.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )

    print("INPUT:", text, "|OUTPUT:", generated_texts)
    return generated_texts[0]
|
|
|
|
|
# Gradio UI: image + prompt inputs, an output box, and an accordion holding
# the bundled examples and the generation parameters.
# NOTE(review): indentation was lost in the reviewed copy of this file; the
# accordion is assumed to be a sibling of the column inside the Blocks
# context - confirm against the deployed layout.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("## IDEFICS3-Llama 🐶")
    gr.Markdown("Play with [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) in this demo. To get started, upload an image and text or try one of the examples.")

    with gr.Column():
        image_input = gr.Image(label="Upload your Image", type="pil")
        query_input = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        output = gr.Textbox(label="Output")

    with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
        # Each row: image path, prompt, decoding strategy, temperature,
        # max new tokens, repetition penalty, top-p.
        examples = [
            ["example_images/travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/dummy_pdf.png", "How much percent is the order status?", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8],
        ]

        max_new_tokens = gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        )
        repetition_penalty = gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values will produce more diverse outputs.",
        )
        top_p = gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values is equivalent to sampling more low-probability tokens.",
        )
        decoding_strategy = gr.Radio(
            [
                "Greedy",
                "Top P Sampling",
            ],
            value="Greedy",
            label="Decoding strategy",
            interactive=True,
            # The previous help text was a copy of the Top P slider's and did
            # not describe this control.
            info="Greedy always picks the most likely token; Top P Sampling samples from the most probable tokens.",
        )

        def _sampling_only_visibility(selection):
            """Return visibility updates for the temperature and Top P sliders."""
            sampling = selection == "Top P Sampling"
            return gr.Slider(visible=sampling), gr.Slider(visible=sampling)

        # One handler replaces three near-identical ones (which also listed
        # strategy names the radio never offers). The repetition-penalty
        # slider is no longer hidden for "Greedy": its value is sent to the
        # inference function for both strategies, so it stays adjustable.
        decoding_strategy.change(
            fn=_sampling_only_visibility,
            inputs=decoding_strategy,
            outputs=[temperature, top_p],
        )

        # Same components, same order as the parameters of model_inference;
        # shared by the examples and the submit button.
        inference_inputs = [
            image_input, query_input, decoding_strategy, temperature,
            max_new_tokens, repetition_penalty, top_p,
        ]

        gr.Examples(
            examples=examples,
            inputs=inference_inputs,
            outputs=output,
            fn=model_inference,
        )

        submit_btn.click(model_inference, inputs=inference_inputs, outputs=output)


demo.launch(debug=True)