|
"""PaliGemma demo gradio app.""" |
|
|
|
import datetime |
|
import functools |
|
import glob |
|
import json |
|
import logging |
|
import os |
|
import time |
|
|
|
import gradio as gr |
|
import jax |
|
import PIL.Image |
|
import gradio_helpers |
|
import models |
|
import paligemma_parse |
|
|
|
INTRO_TEXT = """🤲 PaliGemma demo\n\n |
|
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) |
|
| [HF blog post](https://huggingface.co/blog/paligemma) |
|
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024) |
|
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) |
|
| [Demo](https://huggingface.co/spaces/google/paligemma) |
|
|\n\n |
|
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, |
|
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and |
|
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) |
|
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile |
|
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
|
answering, text reading, object detection and object segmentation. |
|
\n\n |
|
This space includes models fine-tuned on a mix of downstream tasks. |
|
See the [blog post](https://huggingface.co/blog/paligemma) and |
|
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) |
|
for detailed information on how to use and fine-tune PaliGemma models.
|
\n\n |
|
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications. |
|
""" |
|
|
|
|
|
def make_image(value, visible):
  """Builds the (plain) image component.

  Args:
    value: image to display (a filepath, an image object, or None to clear).
    visible: whether the component is shown.

  Returns:
    A `gr.Image` labelled 'Image' whose value is delivered as a filepath.
  """
  return gr.Image(value, label='Image', type='filepath', visible=visible)


# Factories for the output components; extra kwargs (value, visible,
# color_map, ...) are forwarded to the Gradio constructors.
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')


# Palette cycled through (by index modulo its length) when assigning colors to
# the labels found in the model output.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
|
|
|
|
|
@gradio_helpers.synced
def compute(image, prompt, model_name, sampler):
  """Runs model inference.

  Args:
    image: input image, either an already-opened image or a filepath string
      (opened with PIL). None raises a `gr.Error`.
    prompt: text prompt passed to the model.
    model_name: name of the model to run; empty means no model is loaded yet.
    sampler: decoding strategy name (e.g. 'greedy'), forwarded to
      `models.generate`.

  Returns:
    A 3-tuple of Gradio components: the highlighted text output, the plain
    image (visible only when there is nothing to annotate) and the annotated
    image (visible only when at least one mask/box was found).

  Raises:
    gr.Error: if no image is given, or no model is selected when not mocking.
  """
  if image is None:
    raise gr.Error('Image required')

  logging.info('prompt="%s"', prompt)

  if isinstance(image, str):
    image = PIL.Image.open(image)
  if gradio_helpers.should_mock():
    logging.warning('Mocking response')
    time.sleep(2.)
    output = paligemma_parse.EXAMPLE_STRING
  else:
    if not model_name:
      raise gr.Error('Models not loaded yet')
    output = models.generate(model_name, sampler, image, prompt)
    logging.info('output="%s"', output)

  width, height = image.size
  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)

  # dict.fromkeys keeps first-appearance order. A set of strings iterates in
  # hash order, which varies between processes (hash randomization), so label
  # colors would not be stable across runs.
  labels = dict.fromkeys(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {
      label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)
  }

  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]

  annotations = []
  for obj in objs:
    # Prefer the segmentation mask, fall back to the bounding box; objects
    # with neither have nothing to draw on the image.
    region = obj.get('mask')
    if region is None:
      region = obj.get('xyxy')
    if region is None:
      continue
    annotations.append((region, obj.get('name') or ''))
  annotated_image = (image, annotations)

  has_annotations = bool(annotations)
  return (
      make_highlighted_text(
          highlighted_text, visible=True, color_map=color_map),
      make_image(image, visible=not has_annotations),
      make_annotated_image(
          annotated_image, visible=has_annotations, width=width, height=height,
          color_map=color_map),
  )
|
|
|
|
|
def warmup(model_name):
  """Runs a single throwaway inference on a blank 1x1 RGB image.

  Args:
    model_name: model to exercise; forwarded to `compute` with an empty
      prompt and greedy decoding. The result is discarded.
  """
  blank = PIL.Image.new('RGB', (1, 1))
  compute(blank, '', model_name, 'greedy')
|
|
|
|
|
def reset():
  """Produces the values that clear the UI.

  Returns:
    A 4-tuple matching the (prompt, highlighted text, image, annotated image)
    outputs: an empty prompt, hidden empty text, a visible empty image and a
    hidden empty annotated image.
  """
  cleared_prompt = ''
  hidden_text = make_highlighted_text('', visible=False)
  empty_image = make_image(None, visible=True)
  hidden_annotations = make_annotated_image(None, visible=False)
  return cleared_prompt, hidden_text, empty_image, hidden_annotations
|
|
|
|
|
def create_app():
  """Creates demo UI.

  Returns:
    The `gr.Blocks` app, with its layout, event handlers, examples and status
    widgets wired up (not yet queued or launched).
  """

  # Kept as lambdas: `make_prompt` is registered directly as an event handler
  # below, and Gradio derives default API endpoint names from function names.
  make_model = lambda choices: gr.Dropdown(
      value=(choices + [''])[0],
      choices=choices,
      label='Model',
      visible=bool(choices),
  )
  make_prompt = lambda value, visible=True: gr.Textbox(
      value, label='Prompt', visible=visible)

  with gr.Blocks() as demo:

    # Main UI structure.

    gr.Markdown(INTRO_TEXT)
    with gr.Row():
      image = make_image(None, visible=True)
      annotated_image = make_annotated_image(None, visible=False)
      with gr.Column():
        with gr.Row():
          prompt = make_prompt('', visible=True)
        model_info = gr.Markdown(label='Model Info')
        with gr.Row():
          # Starts empty and hidden; populated by the `demo.load` handler below.
          model = make_model([])
          samplers = [
              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
          sampler = gr.Dropdown(
              value=samplers[0], choices=samplers, label='Decoding'
          )
        with gr.Row():
          run = gr.Button('Run', variant='primary')
          clear = gr.Button('Clear')
        highlighted_text = make_highlighted_text('', visible=False)

    # UI logic.

    def update_ui(model, prompt):
      prompt = make_prompt(prompt, visible=True)
      model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
      return [prompt, model_info]

    gr.on(
        [model.change],
        update_ui,
        [model, prompt],
        [prompt, model_info],
    )

    gr.on(
        [run.click, prompt.submit],
        compute,
        [image, prompt, model, sampler],
        [highlighted_text, image, annotated_image],
    )
    clear.click(
        reset, None, [prompt, highlighted_text, image, annotated_image]
    )

    # Examples.

    gr.set_static_paths(['examples/'])
    # Sorted because glob order is arbitrary; each file is closed after
    # parsing instead of leaking its handle.
    all_examples = []
    for path in sorted(glob.glob('examples/*.json')):
      with open(path, encoding='utf-8') as f:
        all_examples.append(json.load(f))
    logging.info('loaded %d examples', len(all_examples))

    # Hidden proxy components: gr.Examples writes into these, and their change
    # handlers below forward the values to the visible widgets.
    example_image = gr.Image(
        label='Image', visible=False)
    example_model = gr.Text(
        label='Model', visible=False)
    example_prompt = gr.Text(
        label='Prompt', visible=False)
    example_license = gr.Markdown(
        label='Image License', visible=False)
    gr.Examples(
        examples=[
            [
                f'examples/{ex["name"]}.jpg',
                ex['prompt'],
                ex['model'],
                ex['license'],
            ]
            for ex in all_examples
            if ex['model'] in models.MODELS
        ],
        inputs=[example_image, example_prompt, example_model, example_license],
    )

    # Examples UI logic.

    example_image.change(
        lambda image_path: (
            make_image(image_path, visible=True),
            make_annotated_image(None, visible=False),
            make_highlighted_text('', visible=False),
        ),
        example_image,
        [image, annotated_image, highlighted_text],
    )

    def example_model_changed(model):
      if model not in gradio_helpers.get_paths():
        raise gr.Error(f'Model "{model}" not loaded!')
      return model

    example_model.change(example_model_changed, example_model, model)
    example_prompt.change(make_prompt, example_prompt, prompt)

    # Status.

    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
    gpu_kind = gr.Markdown('GPU=?')
    demo.load(
        lambda: [
            gradio_helpers.get_status(),
            make_model(list(gradio_helpers.get_paths())),
        ],
        None,
        [status, model],
    )

    def get_gpu_kind():
      device = jax.devices()[0]
      if not gradio_helpers.should_mock() and device.platform != 'gpu':
        raise gr.Error('GPU not visible to JAX!')
      return f'GPU={device.device_kind}'

    demo.load(get_gpu_kind, None, gpu_kind)

  return demo
|
|
|
|
|
if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  logging.info('JAX devices: %s', jax.devices())

  # The environment may carry credentials (access tokens, API keys); log the
  # names of sensitive-looking variables but never their values.
  sensitive_markers = ('TOKEN', 'SECRET', 'PASSWORD', 'KEY', 'CREDENTIAL')
  for k, v in os.environ.items():
    if any(marker in k.upper() for marker in sensitive_markers):
      v = '<redacted>'
    logging.info('environ["%s"] = %r', k, v)

  # Register the warmup routine and every known model download before the
  # app is built and launched.
  gradio_helpers.set_warmup_function(warmup)
  for name, (repo, filename, revision) in models.MODELS.items():
    gradio_helpers.register_download(name, repo, filename, revision)

  create_app().queue().launch()
|
|