# idefics2 / app.py
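"""Gradio demo for HuggingFaceM4/idefics2-8b, loaded in 4-bit (NF4) via bitsandbytes."""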
import os
import subprocess

import gradio as gr
import torch
from peft import LoraConfig  # used by the commented-out LoRA/QLoRA path below
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration

# Install flash-attn at startup (a common pattern on Hugging Face Spaces).
# Merging with os.environ keeps PATH intact, unlike replacing the whole environment.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)  # disable image splitting to keep the visual token count down
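# 4-bit NF4 quantization (QLoRA-style) so the 8B model fits on a single GPU.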
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
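# Alternative loading path with LoRA/QLoRA adapters, kept for reference: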
# if USE_QLORA or USE_LORA:
# lora_config = LoraConfig(
# r=8,
# lora_alpha=8,
# lora_dropout=0.1,
# target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*',
# use_dora=False if USE_QLORA else True,
# init_lora_weights="gaussian"
# )
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16
# ) if USE_QLORA else None
# model = Idefics2ForConditionalGeneration.from_pretrained(
# "HuggingFaceM4/idefics2-8b",
# torch_dtype=torch.float16,
# quantization_config=bnb_config,
# )
# model.add_adapter(lora_config)
# model.enable_adapters()
# else:
# model = Idefics2ForConditionalGeneration.from_pretrained(
# "HuggingFaceM4/idefics2-8b",
# torch_dtype=torch.float16,
# _attn_implementation="flash_attention_2"
# ).to(DEVICE)
def model_inference(image, text):
    resulting_messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=1024,    # room for extended answers
        temperature=0.3,        # low temperature: mostly focused, slightly creative
        do_sample=True,         # sample instead of greedy decoding
        top_p=0.7,              # nucleus sampling for focused yet diverse output
        # num_beams=5,          # optionally switch to beam search instead of sampling
        num_return_sequences=1  # return a single sequence
    )
    # Decode only the newly generated tokens, skipping the prompt.
    generated_text = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
    return generated_text[0]
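# Gradio UI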
with gr.Blocks(css=".gradio-container {background-color: lightgrey}") as demo:
gr.Markdown("## IDEFICS2 Demo")
with gr.Row():
with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Image", type="pil")
query_input = gr.Textbox(label="Enter Prompt", placeholder="Type your prompt here...")
with gr.Column(scale=1):
            output = gr.Textbox(label="Model Output", interactive=True, placeholder="Output will be displayed here...", lines=12)
submit_btn = gr.Button("Generate")
submit_btn.click(model_inference, inputs=[image_input, query_input], outputs=output)
examples = [
["american_football.png", "Explain in detail what is depicted in the picture"],
["bike.png", "Explore the image closely and describe in detail what you discover."],
["finance.png", "Provide a detailed description of everything you see in the image."],
["science.png", "Please perform optical character recognition (OCR) on the uploaded image. Extract all text visible in the image accurately. Ensure to capture the text in its entirety and maintain the formatting as closely as possible to how it appears in the image. After extracting the text, display it in a clear and readable format, making sure that any special characters or symbols are also accurately represented. Provide the extracted text as output."],
["spirituality.png", "Please perform optical character recognition (OCR) on the uploaded image. Extract all text visible in the image accurately. Ensure to capture the text in its entirety and maintain the formatting as closely as possible to how it appears in the image. After extracting the text, display it in a clear and readable format, making sure that any special characters or symbols are also accurately represented. Provide the extracted text as output."]
]
gr.Examples(examples=examples, inputs=[image_input, query_input], outputs=output)
demo.launch(debug=True)