|
import torch |
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
from transformers import OwlViTProcessor, OwlViTForObjectDetection |
|
import os |
|
|
|
|
|
|
|
# Use the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the OWL-ViT detector and move its weights to the selected device, then
# switch it to inference mode. The processor is loaded from the same checkpoint
# so its text/image preprocessing matches the model.
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14")
model = model.to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
|
|
|
def query_image(img, text_queries, score_threshold):
    """Detect objects described by free-text queries and draw them on the image.

    Args:
        img: The image delivered by a ``gr.Image`` component (anything
            ``np.array`` accepts, e.g. an RGB array), or ``None`` when the
            user submitted without an image.
        text_queries: Comma-separated object descriptions, e.g. ``"cat, dog"``.
            Surrounding whitespace is stripped and empty entries are ignored.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        A copy of the image as a NumPy array with a red rectangle and a text
        label for every detection whose score is at least ``score_threshold``.

    Raises:
        gr.Error: If no image was provided or no non-empty query was given.
    """
    if img is None:
        raise gr.Error("Please provide an image.")

    # Strip whitespace around each query and drop empty entries (e.g. from a
    # trailing comma) so drawn labels are clean and the model never sees "".
    queries = [q.strip() for q in (text_queries or "").split(",") if q.strip()]
    if not queries:
        raise gr.Error("Please enter at least one text query.")

    # np.array copies the input, so drawing below never mutates the caller's data.
    img = np.array(img)
    img_height = img.shape[0]
    # post_process rescales the predicted boxes to these (height, width) sizes,
    # so box coordinates below are in this image's pixel space.
    target_sizes = torch.Tensor([img.shape[:2]])

    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Move predictions to the CPU so post-processing runs alongside target_sizes.
    # NOTE(review): newer transformers versions deprecate post_process in favour
    # of post_process_object_detection; kept here for compatibility.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 5)

        # Draw the label just below the box, or inside its bottom edge when it
        # would otherwise run past the bottom of *this* image (previously a
        # hard-coded 768 px, which was wrong for any other image height).
        text_y = y1 + 25 if y1 + 25 <= img_height else y1 - 10
        img = cv2.putText(
            img, queries[int(label)], (x0, text_y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img
|
|
|
|
|
# Help text shown above the inputs on every query tab.
_QUERY_HELP = (
    "Insert an image below and add text descriptions of what you are looking for.\n\n"
    "If you need help finding the right text queries, you can ask "
    "[ChatBRD](https://chatbrd.novonordisk.com/#/), but remember that you need to "
    "log on to Novo's VPN before you can use it."
)

# Help text shown above the score-threshold slider.
_THRESHOLD_HELP = (
    "You can also use the score threshold slider to set a threshold that filters "
    "out lower-probability predictions."
)


def _build_query_tab(source):
    """Lay out one input/output row whose image comes from ``source``.

    Must be called inside an open ``gr.Blocks`` context. ``source`` is passed
    straight to ``gr.Image`` ("upload" or "webcam" in this app).

    Returns:
        ``(inputs, submit_button, output_image)`` where ``inputs`` is the
        ``[image, queries, threshold]`` list expected by ``query_image``.
    """
    with gr.Row():
        with gr.Column():
            gr.Markdown(_QUERY_HELP)
            image = gr.Image(source=source)
            queries = gr.Textbox(label="Text queries", placeholder="e.g. cat, dog, red car")
            gr.Markdown(_THRESHOLD_HELP)
            threshold = gr.Slider(0, 1, value=0.1, label="Score threshold")
            submit = gr.Button("Submit")
        output = gr.Image()
    return [image, queries, threshold], submit, output


with gr.Blocks() as demo:
    with gr.Column():
        with gr.Tab("Upload image"):
            gr.Markdown(
                "[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) "
                "is a vision transformer architecture that can be used for image inputs "
                "with text queries. This is achieved by adding a text embedding layer to "
                "the model, which allows it to process both image and text inputs.\n\n"
                "You can use it to query images with text descriptions of any object. "
                "To use it, simply upload an image or capture one with the webcam and "
                "enter comma-separated text descriptions of the objects you want to "
                "find in the image."
            )
            inputs_file, submit_btn, im_output = _build_query_tab("upload")

        with gr.Tab("Capture image with webcam"):
            inputs_web, submit_btn_web, web_output = _build_query_tab("webcam")

    # Both tabs run the same detection function, each writing to its own output.
    submit_btn.click(fn=query_image, inputs=inputs_file, outputs=im_output, queue=True)
    submit_btn_web.click(fn=query_image, inputs=inputs_web, outputs=web_output, queue=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Login credentials come from the environment instead of being committed to
    # source control.
    # NOTE(review): the username/password previously hard-coded here remain in
    # version history and should be treated as leaked and rotated.
    username = os.environ.get("APP_USERNAME")
    password = os.environ.get("APP_PASSWORD")
    if not username or not password:
        # Fail closed: never start the server without authentication.
        raise SystemExit(
            "Set the APP_USERNAME and APP_PASSWORD environment variables "
            "before launching the app."
        )

    # Queue settings are unchanged: concurrency_count/max_size tune the worker
    # count and backlog cap. api_open=False presumably stops direct REST calls
    # from bypassing the queue — confirm against the Gradio version in use.
    demo.queue(
        concurrency_count=40,
        max_size=25,
        api_open=False,
    )
    demo.launch(auth=(username, password))
|
|