|
import torch
import numpy as np
from PIL import Image
from transformers import RobertaModel, RobertaTokenizer
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import gradio as gr

from model import Model
from output import visualize_output

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Backbones: a ViT image encoder (no classification head, no global pooling,
# so it outputs per-patch tokens) and a RoBERTa text encoder, combined by the
# custom Model.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")
model = Model(vit, roberta, tokenizer, device).to(device)

# Load the fine-tuned weights; the checkpoint stores the model state under
# the 'val_model_dict' key.
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])
model.eval()

|
# Evaluation-time preprocessing derived from the ViT's pretrained data config
# (resize, crop, normalize); crop_pct and crop_mode are filled in per request
# inside query_image.
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'


def query_image(input_img, query, binarize, eval_threshold, crop_mode, crop_pct):
    # Map 'center' to None so timm's create_transform falls back to its
    # default center crop.
    if crop_mode == 'center':
        crop_mode = None

    config['crop_pct'] = crop_pct
    config['crop_mode'] = crop_mode
    transform = create_transform(**config)

    pil_image = Image.fromarray(input_img, "RGB")
    img = transform(pil_image)
    img = torch.unsqueeze(img, 0).to(device)

    with torch.no_grad():
        output = model(img, query)

    img = visualize_output(img, output, binarize, eval_threshold)
    return img
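
# Usage sketch: query_image can also be called directly, outside the Gradio
# UI. The example below assumes one of the bundled example images; the crop
# settings mirror the ones the model was trained with.
#
#   arr = np.array(Image.open("examples/imga.jpeg").convert("RGB"))
#   out = query_image(arr, "Find a person.", binarize=True,
#                     eval_threshold=0.45, crop_mode='center', crop_pct=0.9)
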
description = """
Gradio demo for an object detection architecture, introduced in my bachelor thesis (link will be added).
\n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and enter any query you want.
The model is trained to recognize only the 80 categories (classes) of the COCO Detection 2017 dataset.
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper for the full list of categories.
\n\n
The best results are obtained with one of the sentences used during training:
<div class="row">
  <div class="column left">
    <ul>
      <li>Find a {class}.</li>
      <li>Find me a {class}</li>
      <li>Where is the {class}?</li>
      <li>Mark a {class}?</li>
      <li>Can you mark a {class}?</li>
      <li>Could you mark a {class}?</li>
      <li>Detect a {class}.</li>
    </ul>
  </div>
  <div class="column right">
    <ul>
      <li>Could you detect a {class}?</li>
      <li>Where is the {class} located?</li>
      <li>Where is the {class} positioned?</li>
      <li>Is there a {class}?</li>
      <li>Look for a {class}.</li>
      <li>Where can I find a {class}?</li>
      <li>Could you pinpoint a {class}?</li>
    </ul>
  </div>
</div>
\n\n
When the binarize option is turned off, the model outputs the probability of the requested {class} for each patch. When it is turned on,
the model binarizes each probability against the chosen eval_threshold.
\n\n
Each input image is resized to 224x224 so it can be processed by the ViT. During this transformation, different
crop modes and crop percentages can be selected. No part of the image is lost when crop_pct = 1.0 and crop_mode is 'squash' or 'border'. The model was trained with crop_mode = 'center' and crop_pct = 0.9.
For an explanation of the different crop modes, see
<a href="https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py">this</a> file, lines 155-172.
"""
|
demo = gr.Interface(
    query_image,
    inputs=[
        "image",
        "text",
        "checkbox",
        gr.Slider(0, 1, value=0.25, label='eval_threshold'),
        gr.Radio(["center", "squash", "border"], value='squash', label='crop_mode'),
        gr.Slider(0.7, 1, value=1, step=0.01, label='crop_pct'),
    ],
    outputs="image",
    title="Text-Based Object Detection",
    description=description,
    # Note: these examples supply values only for the first four inputs;
    # newer Gradio versions may require one value per input component.
    examples=[
        ["examples/imga.jpeg", "Find a person.", True, 0.45],
        ["examples/imgb.jpeg", "Could you mark a horse?", False, 0.25],
        ["examples/imgc.jpeg", "There should be a cat in this picture, where?", True, 0.25],
        ["examples/imgd.jpeg", "Mark a tv in this image.", False, 0.1],
        ["examples/imge.jpeg", "Is there a zebra in this picture?", True, 0.4],
        ["examples/imgf.jpeg", "Look for a stop sign.", True, 0.5],
    ],
    cache_examples=False,
    allow_flagging="never",
    css="""
    .column {
        float: left;
        padding: 10px;
    }

    .left {
        width: 25%;
    }

    .right {
        width: 75%;
    }
    """,
)
demo.launch()
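
# Tip: by default launch() serves on localhost only; passing share=True to
# demo.launch() above exposes a temporary public URL when running locally.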
|
|
|
|