File size: 2,873 Bytes
be1ec96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import io

from PIL import Image
import numpy as np
import cv2 as cv2
import torch
import requests

import gradio as gr

import gem


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weights selectable from the "Pretrained Weights" dropdown:
# label -> (OpenCLIP architecture name, pretrained tag).
_MODELS = {
    "OpenAI": ('ViT-B-16', 'openai'),
    "MetaCLIP": ('ViT-B-16-quickgelu', 'metaclip_400m'),
    "OpenCLIP": ('ViT-B-16', 'laion400m_e32')
}

# OpenCLIP architecture / weights loaded at startup. Taken from _MODELS so the
# startup model cannot drift from the "MetaCLIP" dropdown entry.
model_name, pretrained = _MODELS["MetaCLIP"]
preprocess = gem.get_gem_img_transform()
# Module-level model used by viz_func(); rebound by change_weights().
gem_model = gem.create_gem_model(model_name=model_name, pretrained=pretrained, device=device)
# Not referenced by the code in this file; kept for any external importers.
image_source = "image"

def change_weights(pretrained_weights):
    """Reload the global GEM model with the weights chosen in the dropdown.

    Bound to the change event of the "Pretrained Weights" Dropdown component.

    Args:
        pretrained_weights: a key of ``_MODELS`` ("OpenAI", "MetaCLIP" or
            "OpenCLIP").

    Raises:
        KeyError: if ``pretrained_weights`` is not a key of ``_MODELS``.
    """
    global gem_model
    arch, tag = _MODELS[pretrained_weights]
    # Log the (architecture, pretrained tag) pair being loaded.
    print((arch, tag))
    gem_model = gem.create_gem_model(model_name=arch, pretrained=tag, device=device)

def change_to_url(url):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Bound to the URL textbox's change event; the result fills the image input.

    Args:
        url: http(s) URL of an image. ``None`` or a blank string means
            "no image".

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` when ``url`` is blank.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    url = (url or "").strip()
    if not url:
        # Clearing the textbox also fires the change event; nothing to fetch.
        return None
    # Bounded timeout so a stalled server cannot hang the handler forever.
    with requests.get(url, timeout=10) as response:
        response.raise_for_status()
        # .content is decoded according to Content-Encoding (gzip/deflate);
        # .raw is not, so reading from .raw can fail on compressed responses.
        return Image.open(io.BytesIO(response.content)).convert('RGB')

def viz_func(url, image, text, model_weights):
    """Run GEM for ``text`` on ``image`` and return the heatmap overlay.

    This is the ``fn`` of the gr.Interface, which passes every input component
    positionally. ``url`` and ``model_weights`` are accepted only for that
    reason; they are consumed by change_to_url() and change_weights().

    Args:
        url: contents of the URL textbox (unused here).
        image: PIL image from the image component, or None if empty.
        text: text prompt to ground in the image.
        model_weights: dropdown selection (unused here).

    Returns:
        uint8 RGB numpy array: 0.4 * image + 0.6 * JET-colormapped heatmap.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # An empty image component yields None; report it instead of
        # failing deep inside preprocess().
        raise gr.Error("Please provide an image (upload one or enter a URL).")
    # The overlay below needs exactly 3 channels (RGBA / greyscale would break
    # COLOR_RGB2BGR).
    image = image.convert('RGB')

    image_torch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = gem_model(image_torch, [text])
    # A single prompt is passed, so only the first map is needed.
    # NOTE(review): assumes logits[0][0] is a 2-D (H, W) map in [0, 1] — confirm against gem.
    heat = logits[0][0].detach().cpu().numpy()
    height, width = heat.shape[:2]

    # Resize the image to the heatmap's resolution so the blend shapes agree
    # (this was hard-coded to 448x448 before).
    img_cv = cv2.cvtColor(np.array(image.resize((width, height))), cv2.COLOR_RGB2BGR)
    # Clip so out-of-range values saturate instead of wrapping in uint8.
    heat_u8 = np.clip(heat * 255, 0, 255).astype('uint8')
    heat_map = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)

    viz = 0.4 * img_cv + 0.6 * heat_map
    return cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)

# Interface input components, in the positional order viz_func() receives
# them: URL textbox, image, text prompt, weights dropdown.
inputs = [
    gr.Textbox(label="url to the image"),
    gr.Image(type="pil"),
    gr.Textbox(label="Text Prompt"),
    # Choices come from _MODELS so the dropdown and change_weights() share
    # one source of truth.
    gr.Dropdown(list(_MODELS), label="Pretrained Weights", value="MetaCLIP",
                info='It can take a few seconds for the model to be updated.'),
]

with gr.Blocks() as demo:
    # Give the shared input components readable names instead of indices.
    url_box, image_box, prompt_box, weights_dropdown = inputs

    # Picking other weights rebuilds the global GEM model.
    weights_dropdown.change(fn=change_weights, inputs=[weights_dropdown])
    # Editing the URL fetches that image into the image component.
    url_box.change(fn=change_to_url, outputs=image_box, inputs=url_box)

    interact = gr.Interface(
        title="GEM: Grounding Everything Module (link to paper/code)",
        description="Grounding Everything: Emerging Localization Properties in Vision-Language Transformers",
        fn=viz_func,
        inputs=inputs,
        outputs=["image"],
    )

    # One [image path, prompt] example row per prompt, grouped by image.
    _example_prompts = (
        ("assets/cats_remote_control.jpeg", ("cat", "remote control")),
        ("assets/elon_jeff_mark.jpeg", ("elon musk", "mark zuckerberg", "jeff bezos")),
    )
    example_rows = [
        [asset_path, prompt]
        for asset_path, prompts in _example_prompts
        for prompt in prompts
    ]
    gr.Examples(example_rows, [image_box, prompt_box])

# Pass server_port=... to launch() to pin the port (5152 was used previously).
demo.launch()