import torch
import torch.nn as nn
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import transformers
from transformers import RobertaModel, RobertaTokenizer
import timm
import pandas as pd
import matplotlib.pyplot as plt
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from model import Model
from output import visualize_output
# Use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Initialize the pretrained backbones (ViT image encoder and RoBERTa text encoder)
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
roberta = RobertaModel.from_pretrained("roberta-base")
model = Model(vit, roberta, tokenizer, device).to(device)
model.eval()
# Load the trained model weights
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])
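# Note (assumption): the checkpoint is taken to store the best validation-time weights
# under the 'val_model_dict' key, as used above; other entries in `state`, if any, are ignored here.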
# Create transform for input image
config = resolve_data_config({}, model=vit)
config['no_aug'] = True
config['interpolation'] = 'bilinear'
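# crop_pct and crop_mode are deliberately not set here; query_image fills them in per
# request before calling create_transform(**config), so each request can choose its own
# cropping behaviour. A hedged one-off sketch of the equivalent transform, using the
# demo's default input values (crop_mode='squash', crop_pct=1.0), would be:
#   example_transform = create_transform(**{**config, 'crop_pct': 1.0, 'crop_mode': 'squash'})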
# Inference function
def query_image(input_img, query, binarize, eval_threshold, crop_mode, crop_pct):
    if crop_mode == 'center':
        crop_mode = None
    config['crop_pct'] = crop_pct
    config['crop_mode'] = crop_mode
    transform = create_transform(**config)
    PIL_image = Image.fromarray(input_img, "RGB")
    img = transform(PIL_image)
    img = torch.unsqueeze(img, 0).to(device)
    with torch.no_grad():
        output = model(img, query)
    img = visualize_output(img, output, binarize, eval_threshold)
    return img
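# Hedged usage sketch (not executed by the app): query_image can also be called directly
# for a quick local check. The image path comes from the examples list below;
# crop_mode='center' and crop_pct=0.9 match the training-time settings noted in the description.
#   sample = np.array(Image.open("examples/imga.jpeg").convert("RGB"))
#   result = query_image(sample, "Find a person.", True, 0.45, "center", 0.9)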
# Gradio interface
description = """
Gradio demo for an object detection architecture, introduced in my bachelor's thesis (link will be added).
\n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and enter any query you want.
The model is trained to recognize only 80 categories (classes) from the COCO Detection 2017 dataset.
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper to see the full list of categories.
\n\n
Best results are obtained using one of these sentences, which were used during training:
<div class="row">
<div class="column left">
<ul>
<li>Find a {class}.</li>
<li>Find me a {class}</li>
<li>Where is the {class}?</li>
<li>Mark a {class}?</li>
<li>Can you mark a {class}?</li>
<li>Could you mark a {class}?</li>
<li>Detect a {class}.</li>
</ul>
</div>
<div class="column right">
<ul>
<li>Could you detect a {class}?</li>
<li>Where is the {class} located?</li>
<li>Where is the {class} positioned?</li>
<li>Is there a {class}?</li>
<li>Look for a {class}.</li>
<li>Where can I find a {class}?</li>
<li>Could you pinpoint a {class}?</li>
</ul>
</div>
</div>
\n\n
When the binarize option is turned off, the model outputs the probability of the requested {class} for each image patch. When the binarize option is turned on,
the model binarizes each probability using the chosen eval_threshold.
\n\n
Each input image is resized to 224x224 so it can be processed by the ViT. During this transformation, different
crop_modes and crop percentages (crop_pct) can be selected. No part of the image is lost if crop_pct = 1.0 and crop_mode is 'squash' or 'border'. The model was trained with crop_mode='center' and crop_pct = 0.9.
For an explanation of the different crop_modes, please refer to
<a href="https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py">this</a> file, lines 155-172.
"""
demo = gr.Interface(
    query_image,
    #inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
    #inputs=[gr.Image(type='numpy', label='input_img').style(height=250, width=600), "text", "checkbox", gr.Slider(0, 1, value=0.25),
    #        gr.Radio(["center", "squash", "border"], value='center', label='crop_mode'), gr.Slider(0.7, 1, value=0.9, step=0.01)],
    inputs=["image", "text", "checkbox", gr.Slider(0, 1, value=0.25),
            gr.Radio(["center", "squash", "border"], value='squash', label='crop_mode'), gr.Slider(0.7, 1, value=1, step=0.01)],
    outputs="image",
    #outputs=gr.Image(type='numpy', label='output').style(height=600, width=600),
    title="Text-Based Object Detection",
    description=description,
    examples=[
        ["examples/imga.jpeg", "Find a person.", True, 0.45],
        ["examples/imgb.jpeg", "Could you mark a horse?", False, 0.25],
        ["examples/imgc.jpeg", "There should be a cat in this picture, where?", True, 0.25],
        ["examples/imgd.jpeg", "Mark a tv in this image.", False, 0.1],
        ["examples/imge.jpeg", "Is there a zebra in this picture?", True, 0.4],
        ["examples/imgf.jpeg", "Look for a stop sign.", True, 0.5],
    ],
    cache_examples=False,
    allow_flagging="never",
    css="""
    .column {
        float: left;
        padding: 10px;
    }
    .left {
        width: 25%;
    }
    .right {
        width: 75%;
    }
    """
)
demo.launch()