# new-owl-vit / app.py
# Author: Mithu96 (Hugging Face Space) — commit 7acb9bb ("Create app.py"), verified
import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Run inference on the GPU when one is present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the OWL-ViT large checkpoint in eval mode, plus its matching processor.
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14").to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
def query_image(img_url, text_queries, score_threshold):
    """Run zero-shot object detection on an image fetched from a URL.

    Args:
        img_url: URL of the image to download.
        text_queries: comma-separated free-text labels, e.g. "cat, dog".
        score_threshold: minimum detection confidence (0-1) for a box to be drawn.

    Returns:
        The image as an RGB NumPy array with boxes and labels drawn on it.

    Raises:
        requests.HTTPError: if the image URL cannot be fetched.
    """
    # Strip whitespace so "cat, dog" yields the label "dog", not " dog".
    text_queries = [q.strip() for q in text_queries.split(",")]

    response = requests.get(img_url, timeout=30)
    response.raise_for_status()  # fail loudly on 404/500 instead of feeding PIL garbage
    # Force RGB: grayscale/RGBA/palette images would otherwise break cv2 drawing.
    img = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
    img_height = img.shape[0]

    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
        # Place the label below the box, or above it when it would run off
        # the bottom of the image (use the real height, not a fixed 768).
        if box[3] + 25 > img_height:
            y = box[3] - 10
        else:
            y = box[3] + 25
        img = cv2.putText(
            img, text_queries[label], (box[0], y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img
# Text shown under the demo title in the Gradio UI.
description = """
DEMO
"""

# Wire the detection function into a simple Gradio interface:
# image URL, comma-separated queries, and a confidence-threshold slider.
demo = gr.Interface(
    fn=query_image,
    inputs=["text", "text", gr.Slider(0, 1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=[],
)
demo.launch()