import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import os

# Checkpoint shared by the model and its processor so the two cannot drift apart.
MODEL_ID = "google/owlvit-large-patch14"

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = OwlViTForObjectDetection.from_pretrained(MODEL_ID).to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained(MODEL_ID)


def _parse_queries(text_queries):
    """Split a comma-separated query string into stripped, non-empty queries."""
    return [q.strip() for q in text_queries.split(",") if q.strip()]


def query_image(img, text_queries, score_threshold):
    """Detect objects described by free-text queries and draw them on the image.

    Args:
        img: Input image as delivered by ``gr.Image``; it is converted with
            ``np.array`` (presumably RGB, shape (height, width, channels) —
            the drawing colour (255, 0, 0) assumes RGB).
        text_queries: Comma-separated object descriptions, e.g. "cat, dog".
            Surrounding whitespace and empty entries are ignored.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        The image with a red box and the matching query text drawn for every
        detection scoring at least ``score_threshold``. Returns ``None`` when
        no image was supplied, and the unmodified image when no non-empty
        query was given.
    """
    if img is None:
        return None
    img = np.array(img)

    queries = _parse_queries(text_queries or "")
    if not queries:
        # Nothing to search for; avoid calling the processor with no text.
        return img

    height = img.shape[0]
    # post_process rescales the predicted boxes using (height, width).
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU, alongside target_sizes.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes = results[0]["boxes"]
    scores = results[0]["scores"]
    labels = results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        if float(score) < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 5)

        # Place the label just below the box; if that would run past the
        # bottom of this image, draw it just inside the box's bottom edge.
        y = y1 - 10 if y1 + 25 > height else y1 + 25
        img = cv2.putText(
            img, queries[int(label)], (x0, y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img


description = """ \n\nYou can use OWL-ViT to query images with text descriptions of any object. To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You can also use the score threshold slider to set a threshold to filter out low probability predictions.
"""

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Column():
        with gr.Tab("Upload image"):
            with gr.Row():
                with gr.Column():
                    inputs_file = [
                        gr.Image(source="upload"),
                        gr.Textbox(label="Text queries"),
                        gr.Slider(0, 1, value=0.1, label="Score threshold"),
                    ]
                    submit_btn = gr.Button("Submit")
                im_output = gr.Image()
        with gr.Tab("Capture image with webcam"):
            with gr.Row():
                with gr.Column():
                    inputs_web = [
                        gr.Image(source="webcam"),
                        gr.Textbox(label="Text queries"),
                        gr.Slider(0, 1, value=0.1, label="Score threshold"),
                    ]
                    submit_btn_web = gr.Button("Submit")
                web_output = gr.Image()

        submit_btn.click(fn=query_image, inputs=inputs_file, outputs=im_output)
        submit_btn_web.click(fn=query_image, inputs=inputs_web, outputs=web_output)

        gr.Markdown("## Image Examples")
        gr.Examples(
            # With several input components Gradio expects one inner list per
            # example, holding a value for each input (image, queries,
            # threshold). A bare string is treated as an examples *directory*.
            examples=[
                [
                    os.path.join(os.path.dirname(__file__), "examples", "IMGP0178.jpg"),
                    "person",  # TODO(review): choose queries matching the example image
                    0.1,
                ]
            ],
            inputs=inputs_file,
            outputs=im_output,
            fn=query_image,
            cache_examples=True,
        )

demo.launch()