merve's picture
merve HF staff
change layout a bit
bbb97d0 verified
raw
history blame
4.04 kB
import os
os.system('pip install ./transformers-4.47.0.dev0-py3-none-any.whl')
import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import string
import functools
import re
import numpy as np
import spaces
# Weights come from the VQAv2 fine-tuned repo; the processor (tokenizer +
# image preprocessing) comes from the base pretrained checkpoint.
# NOTE(review): if `adapter_id` holds only a LoRA adapter, loading it directly
# relies on transformers' PEFT integration resolving the base model — verify.
adapter_id = "merve/paligemma2-3b-vqav2"
model_id = "gv-hf/paligemma2-3b-pt-448"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = PaliGemmaForConditionalGeneration.from_pretrained(adapter_id)
model.eval()  # inference only: put the model in eval mode before moving it
model = model.to(device)

processor = PaliGemmaProcessor.from_pretrained(model_id)
###### Transformers Inference
@spaces.GPU
def infer(
    text: str,
    image: PIL.Image.Image,
    max_new_tokens: int,
) -> str:
    """Answer a free-form question about an image with the fine-tuned model.

    Args:
        text: The user's question. It is prefixed with the ``"answer en "``
            task prompt before being passed to the processor.
        image: The input image as a PIL image (the UI's ``gr.Image`` uses
            ``type="pil"``); ``None`` when nothing has been uploaded.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded answer: only the newly generated tokens, with special
        tokens removed and any leading newlines stripped.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    prompt = "answer en " + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    # Length of the tokenized prompt as fed to the model; generate() returns
    # these ids followed by the new ones, so this marks where the answer starts.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            # Defensive cast in case the UI slider delivers a float.
            max_new_tokens=int(max_new_tokens),
            do_sample=False,  # no sampling: same input, same answer
        )

    # Drop the prompt at the token level rather than trimming len(prompt)
    # characters off the decoded string: decoding does not always reproduce
    # the prompt verbatim (e.g. special-token strings typed into the question
    # are removed by skip_special_tokens), which would corrupt the answer.
    answer_ids = generated_ids[:, prompt_len:]
    result = processor.batch_decode(answer_ids, skip_special_tokens=True)
    return result[0].lstrip("\n")
######## Demo
# Markdown shown at the top of the demo page (rendered via gr.Markdown).
# The header links and the body both point at the PaliGemma 2 blog post so the
# page is consistent with its title.
INTRO_TEXT = """## PaliGemma 2 demo\n\n
| [Github](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
| [Blogpost](https://huggingface.co/blog/paligemma2)
| [Fine-tuning notebook](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb)
|\n\n
PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question
answering, text reading, object detection and object segmentation.
\n\n
This space includes a model LoRA fine-tuned by the team at Hugging Face on VQAv2, inferred using transformers.
See the [Blogpost](https://huggingface.co/blog/paligemma2), the project
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) and the
[fine-tuning notebook](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb)
for detailed information about how to use and fine-tune PaliGemma and PaliGemma 2 models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)

    # Widgets are laid out top-to-bottom in the order they are created.
    with gr.Column():
        image = gr.Image(type="pil", label="Input Image", height=400)
        question = gr.Text(label="Question")
        tokens = gr.Slider(
            minimum=20,
            maximum=160,
            step=10,
            value=80,
            label="Max New Tokens",
            info="Set to larger for longer generation.",
        )
        caption_btn = gr.Button(value="Submit")
        text_output = gr.Text(label="Text Output")

    # Positional order must line up with infer(text, image, max_new_tokens).
    caption_inputs = [question, image, tokens]
    caption_outputs = [text_output]
    caption_btn.click(fn=infer, inputs=caption_inputs, outputs=caption_outputs)

    # Each row is [question, image path, max new tokens], matching caption_inputs.
    examples = [
        ["What is the graphic about?", "./howto.jpg", 60],
        ["What is the password", "./password.jpg", 20],
        ["Who is in this image?", "./examples_bowie.jpg", 80],
    ]
    gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
    gr.Examples(examples=examples, inputs=caption_inputs)
#########
if __name__ == "__main__":
    # Enable the request queue (at most 10 pending events), then start the
    # server in debug mode.
    demo.queue(max_size=10)
    demo.launch(debug=True)