import contextlib
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import open_clip  # works on open-clip-torch>=2.23.0, timm>=0.9.8
import PIL.Image
import torch
import torch.nn.functional as F

INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'


@contextlib.contextmanager
def timed(name):
    """Logs the wall-clock duration of the wrapped block."""
    t0 = time.monotonic()
    try:
        yield
    finally:
        logging.info('Timed %s: %.1f secs', name, time.monotonic() - t0)


@functools.cache
def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
    """Loads model, preprocess, and tokenizer once; cached for later calls."""
    with timed('loading model, preprocess, tokenizer'):
        model, preprocess = open_clip.create_model_from_pretrained(name)
        tokenizer = open_clip.get_tokenizer(name)
    return model, preprocess, tokenizer


def generate_answers(image_path, prompts):
    model, preprocess, tokenizer = load_model()
    # Inference only; autocast uses mixed precision when CUDA is available.
    with torch.no_grad(), torch.cuda.amp.autocast():
        with timed(f'opening image "{image_path}"'):
            image = PIL.Image.open(image_path)
        with timed('image features'):
            image = preprocess(image).unsqueeze(0)
            image_features = model.encode_image(image)
        with timed('text features'):
            prompts = [p.strip() for p in prompts.split(',')]
            text = tokenizer(prompts, context_length=model.context_length)
            text_features = model.encode_text(text)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        # SigLIP scores every (image, prompt) pair independently with a
        # sigmoid (not a softmax over prompts), using the learned logit
        # scale and bias.
        exp, bias = model.logit_scale.exp(), model.logit_bias
        text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)
    return list(zip(prompts, [round(p.item(), 3) for p in text_probs[0]]))


def create_app():
    info = json.load(urllib.request.urlopen(INFO_URL))

    with gr.Blocks() as demo:
        gr.Markdown('Minimal gradio clone of [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/)')
        gr.Markdown('Using `open_clip` implementation of SigLIP model `timm/ViT-SO400M-14-SigLIP-384`')
        with gr.Row():
            image = gr.Image(label='input_image', type='filepath')
            with gr.Column():
                prompts = gr.Textbox(label='prompts')
                answer = gr.Textbox(label='answer')
                run = gr.Button('Run')
        gr.Examples(
            examples=[
                [IMG_URL_FMT.format(ex['id']), ex['prompts']]
                for ex in info
            ],
            inputs=[image, prompts],
            outputs=[answer],
        )
        run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])

    return demo


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    # Log the environment for debugging the hosting platform.
    for k, v in os.environ.items():
        logging.info('environ["%s"] = %r', k, v)
    # Warm the model cache before the first request arrives.
    _ = load_model()
    create_app().queue().launch()
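
# Example local run (a sketch; assumes this file is saved as app.py and the
# versions noted in the import comment above are installed):
#
#   pip install 'open-clip-torch>=2.23.0' 'timm>=0.9.8' gradio torch pillow
#   python app.py
#
# Gradio then prints a local URL; supply an image, enter comma-separated
# prompts such as 'a photo of a cat, a photo of a dog', and click Run to get
# a per-prompt sigmoid probability for each prompt.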