import torch
import gradio as gr
import numpy as np
from PIL import Image
from transformers import RobertaModel, RobertaTokenizer
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from model import Model
from output import visualize_output


# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize the pretrained backbone models
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")
model = Model(vit, roberta, tokenizer, device).to(device)

# Load the trained weights (map_location lets a GPU-trained checkpoint load
# on a CPU-only machine) and switch to inference mode
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])
model.eval()

# Create transform for input image
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
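# resolve_data_config recovers the preprocessing settings the pretrained ViT
# expects (input_size, interpolation, mean, std, crop_pct); the overrides
# above disable train-time augmentation and force bilinear resizing, while
# the crop settings are filled in per request inside query_image below.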

# Inference function
def query_image(input_img, query, binarize, eval_threshold, crop_mode, crop_pct):
    # timm's create_transform falls back to its default center crop when
    # crop_mode is None
    if crop_mode == 'center':
        crop_mode = None

    config['crop_pct'] = crop_pct
    config['crop_mode'] = crop_mode
    transform = create_transform(**config)

    # Convert the numpy array supplied by Gradio to a PIL image, apply the
    # ViT preprocessing and add a batch dimension
    PIL_image = Image.fromarray(input_img, "RGB")
    img = transform(PIL_image)
    img = torch.unsqueeze(img, 0).to(device)

    with torch.no_grad():
        output = model(img, query)

    # Render the model's patch predictions on top of the input image
    img = visualize_output(img, output, binarize, eval_threshold)
    return img
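
# Minimal sketch of calling the inference function directly, without the UI
# (assumes the example image shipped with the demo, as listed below):
#
#   img = np.asarray(Image.open('examples/imga.jpeg').convert('RGB'))
#   result = query_image(img, 'Find a person.', True, 0.45, 'squash', 1.0)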

# Gradio interface
description = """
Gradio demo for a text-based object detection architecture introduced in my bachelor thesis (link will be added).
\n\n
You can use this architecture to detect objects with textual queries: simply upload an image and enter any query you want.
The model is trained to recognize only the 80 categories (classes) of the COCO Detection 2017 dataset.
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website 
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper to see the full list of categories.
\n\n
Best results are obtained using one of these sentences, which were used during training:
<div class="row">
    <div class="column left">
        <ul>
            <li>Find a {class}.</li>
            <li>Find me a {class}</li>
            <li>Where is the {class}?</li>
            <li>Mark a {class}?</li>
            <li>Can you mark a {class}?</li>
            <li>Could you mark a {class}?</li>
            <li>Detect a {class}.</li>
        </ul>
    </div>
    <div class="column right">
        <ul>
            <li>Could you detect a {class}?</li>
            <li>Where is the {class} located?</li>
            <li>Where is the {class} positioned?</li>
            <li>Is there a {class}?</li>
            <li>Look for a {class}.</li>
            <li>Where can I find a {class}?</li>
            <li>Could you pinpoint a {class}?</li>
        </ul>
    </div>
</div>
\n\n
When the binarize option is turned off, the model outputs the probability of the requested {class} for each patch. When it is turned on,
each per-patch probability is binarized against the chosen eval_threshold.
\n\n
Each input image is resized to 224x224 so that it can be processed by the ViT. During this transformation, different
crop_modes and crop_pct values can be selected. No image content is lost when crop_pct = 1.0 and crop_mode = 'squash' or 'border'. The model was trained with crop_mode = 'center' and crop_pct = 0.9.
For an explanation of the different crop_modes, please refer to
<a href="https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py">this</a> file, lines 155-172.
"""
demo = gr.Interface(
    query_image, 
    inputs=["image", "text", "checkbox", gr.Slider(0, 1, value=0.25), 
            gr.Radio(["center", "squash", "border"], value='squash', label='crop_mode'), gr.Slider(0.7, 1, value=1, step=0.01)],  
    outputs="image",
    title="Text-Based Object Detection",
    description=description,
    examples=[
        # Each example row must supply a value for every input component;
        # crop_mode and crop_pct use the interface defaults
        ["examples/imga.jpeg", "Find a person.", True, 0.45, "squash", 1.0],
        ["examples/imgb.jpeg", "Could you mark a horse?", False, 0.25, "squash", 1.0],
        ["examples/imgc.jpeg", "There should be a cat in this picture, where?", True, 0.25, "squash", 1.0],
        ["examples/imgd.jpeg", "Mark a tv in this image.", False, 0.1, "squash", 1.0],
        ["examples/imge.jpeg", "Is there a zebra in this picture?", True, 0.4, "squash", 1.0],
        ["examples/imgf.jpeg", "Look for a stop sign.", True, 0.5, "squash", 1.0],
    ],
    cache_examples=False,
    allow_flagging="never",
    css="""
    .column {
      float: left;
      padding: 10px;
    }

    .left {
      width: 25%;
    }

    .right {
      width: 75%;
    }
    """
)
demo.launch()