"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.

Features:

- Models are downloaded dynamically.
- Models are cached on local disk, and in RAM.
- Progress bars when downloading/reading/computing.
- Dynamic update of model controls.
- Dynamic generation of output sliders.
- Use of `gr.State()` for better use of progress bars.
"""

import dataclasses
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import PIL.Image

import big_vision_contrastive_models as models
import gradio_helpers


INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
MAX_ANSWERS = 10

# Cache limits in bytes (passed as `max_cache_size_bytes`).
MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM

# NOTE(review): not referenced in this file; `compute()` passes its own
# per-(family, variant) load-time estimates to the memory cache.
LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}

# family/variant/res -> name
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so440m_224',
            384: 'siglip_so400m14so440m_384',
        },
    },
}


def get_cache_status():
  """Returns a string summarizing memory and disk cache status."""
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  return (
      f'memory cache {mem_n} items [{mem_sz/1e9:.2f}G], '
      f'disk cache {disk_n} items [{disk_sz/1e9:.2f}G]'
  )


def compute(image_path, prompts, family, variant, res, bias,
            progress=gr.Progress()):
  """Loads the selected model and scores the prompts against the image.

  Args:
    image_path: Local path of the selected image (the image widget uses
      `type='filepath'`), or None if no image was selected.
    prompts: Newline-separated prompts; blank lines are ignored.
    family: Model family, a key of `MODEL_MAP`.
    variant: Model variant, a key of `MODEL_MAP[family]`.
    res: Resolution, a key of `MODEL_MAP[family][variant]`.
    bias: Bias forwarded to `get_probabilities()` for the 'siglip' family;
      unused for 'lit'.
    progress: Gradio progress tracker (injected by Gradio).

  Returns:
    `(status, state)`: a `gr.Markdown` with timing and cache information, and
    a list of `(prompt, probability)` pairs, probabilities rounded to 3
    decimals.

  Raises:
    gr.Error: If no image is selected, no non-blank prompt is given, or the
      family/variant/res combination is unknown.
  """
  if image_path is None:
    raise gr.Error('Must first select an image!')
  # Shift-ENTER adds lines; drop blank ones so they don't become empty prompts.
  prompts = [p for p in (prompts or '').splitlines() if p.strip()]
  if not prompts:
    raise gr.Error('Must enter at least one prompt!')

  t0 = time.monotonic()

  try:
    model_name = MODEL_MAP[family][variant][res]
  except KeyError:
    raise gr.Error(
        f'Unknown model: family={family!r} variant={variant!r} res={res!r}'
    ) from None
  config = models.MODEL_CONFIGS[model_name]
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      # Rough load-time estimates (seconds) used for the progress bar.
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant)),
  )
  model: models.ContrastiveModel = model

  # Three-step progress bar: open image, image features, text features.
  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)

  with gradio_helpers.timed('image features'):
    zimg, out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)

  with gradio_helpers.timed('text features'):
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)

  t = model.get_temperature(out)
  if family == 'lit':
    text_probs = model.get_probabilities(zimg, ztxt, t, axis=-1)[0]
  elif family == 'siglip':
    text_probs = model.get_probabilities(zimg, ztxt, t, bias=bias)[0]
  else:
    raise gr.Error(f'Unsupported model family: {family!r}')
  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ({get_cache_status()})')

  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state


def update_answers(state):
  """Generates visible sliders for answers.

  Args:
    state: List of `(prompt, probability)` pairs as produced by `compute()`.

  Returns:
    Exactly `MAX_ANSWERS` sliders. There is one visible slider (value in
    percent, labelled with the prompt) per pair for the first `MAX_ANSWERS`
    pairs; the remaining sliders are hidden.
  """
  answers = [
      gr.Slider(value=round(100 * prob, 2), label=prompt, visible=True)
      for prompt, prob in state[:MAX_ANSWERS]
  ]
  answers.extend(
      gr.Slider(visible=False) for _ in range(MAX_ANSWERS - len(answers)))
  return answers


def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')

    status = gr.Markdown(f'Ready ({get_cache_status()})')

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
        with gr.Row():

          family = gr.Dropdown(
              value='lit', choices=list(MODEL_MAP), label='Model family')

          # Reactive model controls. Every handler receives the *current*
          # values of the controls it depends on as event inputs, for two
          # reasons:
          # 1. Widget attributes like `family.value` are *not* updated when
          #    the widget changes in the browser (they keep the initial value).
          # 2. A Python-side copy of the values would be shared between all
          #    browser sessions, mixing up concurrent users.
          # `family_changed()` returns updates for *all* dependent controls, so
          # it does not rely on `variant.change()` being triggered in turn.

          def make_variant(family_value):
            choices = list(MODEL_MAP[family_value])
            return gr.Dropdown(
                value=choices[0], choices=choices, label='Variant')

          def make_res(family_value, variant_value):
            choices = list(MODEL_MAP[family_value][variant_value])
            return gr.Dropdown(
                value=choices[0], choices=choices, label='Resolution')

          def make_bias(family_value, variant_value, res_value):
            # Per-model default bias; the slider is only shown for SigLIP.
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'L/16', 256): -12.7,
                # NOTE(review): this entry used to repeat the key
                # ('siglip', 'L/16', 256), silently overriding -12.7. 384 is
                # the only other L/16 resolution in MODEL_MAP; confirm value.
                ('siglip', 'L/16', 384): -16.5,
                # ...
            }.get((family_value, variant_value, res_value), -10.0)
            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05,
                             label='Bias', visible=family_value == 'siglip')

          variant = make_variant(family.value)
          res = make_res(family.value, variant.value)
          bias = make_bias(family.value, variant.value, res.value)

          def family_changed(family_value):
            variant_value = list(MODEL_MAP[family_value])[0]
            res_value = list(MODEL_MAP[family_value][variant_value])[0]
            return [
                make_variant(family_value),
                make_res(family_value, variant_value),
                make_bias(family_value, variant_value, res_value),
            ]

          def variant_changed(family_value, variant_value):
            res_value = list(MODEL_MAP[family_value][variant_value])[0]
            return [
                make_res(family_value, variant_value),
                make_bias(family_value, variant_value, res_value),
            ]

          def res_changed(family_value, variant_value, res_value):
            return make_bias(family_value, variant_value, res_value)

          family.change(family_changed, family, [variant, res, bias])
          variant.change(variant_changed, [family, variant], [res, bias])
          res.change(res_changed, [family, variant, res], bias)
          # (end of code for reactive UI code)

        run = gr.Button('Run')
        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

    # We want to avoid showing multiple progress bars, so we only update
    # a single `status` widget here, and store the computed information in
    # `state`...
    # ... then, once `compute` succeeded, we use `state` to update the answer
    # sliders without showing a progress bar in their place. Chaining with
    # `.success()` (rather than `status.change`) also refreshes the answers
    # when two runs happen to produce an identical status text.
    run.click(
        fn=compute,
        inputs=[image, prompts, family, variant, res, bias],
        outputs=[status, state],
    ).success(fn=update_answers, inputs=state, outputs=answers)

    with urllib.request.urlopen(INFO_URL) as response:
      info = json.load(response)
    gr.Markdown('Note: below images have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        # One input component per example column (image, prompts, source).
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo


if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  # NOTE(review): this logs every environment variable verbatim, which may
  # include secrets such as access tokens; consider redacting.
  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  models.setup()

  create_app().queue().launch()