import spaces
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import torch
import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)


def generate_colors(labels):
    """Return a deterministic ``{label: "#rrggbb"}`` colour map for *labels*.

    A private ``random.Random(42)`` makes the colours reproducible across
    calls. It draws the same sequence that ``random.seed(42)`` would, but
    without reseeding (and thereby disturbing) the global ``random`` module.
    """
    import random

    rng = random.Random(42)
    return {label: "#{:06x}".format(rng.randint(0, 0xFFFFFF)) for label in labels}


def _normalize_queries(text_queries):
    """Strip, lowercase and de-duplicate candidate labels, dropping blanks.

    ``"zebra, warthog".split(",")`` yields ``" warthog"`` with a leading space.
    The labels Grounding DINO returns are decoded token spans without that
    padding. The model card also asks for lowercase queries. Without this
    normalization the ``label in queries`` filter in ``infer`` would silently
    drop valid detections.
    """
    normalized = []
    for query in text_queries:
        query = query.strip().lower()
        if query and query not in normalized:
            normalized.append(query)
    return normalized


@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    """Run zero-shot object detection on *img* for *text_queries*.

    Args:
        img: Input image. Only ``img.shape[:2]`` is read directly, as
            ``(height, width)``; presumably the numpy array ``gr.Image``
            delivers by default.
        text_queries: Iterable of candidate label strings (normalized here).
        score_threshold: Minimum box confidence for a detection to be kept.
        model: Detector to use; only ``"dino"`` is supported.

    Returns:
        List of ``([x0, y0, x1, y1], label)`` pairs with integer pixel
        coordinates clipped to the image. Only detections whose label
        exactly matches one of the normalized queries are kept.

    Raises:
        ValueError: If *model* is not ``"dino"``.
    """
    if model != "dino":
        raise ValueError(f"Unsupported model: {model!r}")

    queries = _normalize_queries(text_queries)
    if not queries:
        return []

    height, width = img.shape[:2]
    # Grounding DINO takes all classes as one prompt: "cat. dog. ".
    prompt = "".join(f"{query}. " for query in queries)
    inputs = dino_processor(text=prompt, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = dino_model(**inputs)

    # Do post-processing on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids.cpu(),
        box_threshold=score_threshold,
        target_sizes=[(height, width)],
    )
    result = results[0]

    allowed = set(queries)
    detections = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        # Drop merged/partial phrases that are not exactly one of the queries.
        if float(score) < score_threshold or label not in allowed:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        # Predicted boxes can spill past the frame; clip so slicing stays valid.
        clipped = [
            min(max(x0, 0), width),
            min(max(y0, 0), height),
            min(max(x1, 0), width),
            min(max(y1, 0), height),
        ]
        detections.append((clipped, label))
    return detections


def query_image(img, text_queries, dino_threshold):
    """Gradio callback: detect the comma-separated *text_queries* in *img*.

    Returns a ``gr.update`` for the ``gr.AnnotatedImage`` output. Its value is
    ``(img, [((x0, y0, x1, y1), label), ...])``, the image-plus-annotations
    format that component accepts. Its ``color_map`` gives every query a
    stable colour.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    queries = _normalize_queries((text_queries or "").split(","))
    detections = infer(img, queries, dino_threshold, "dino")
    annotations = [(tuple(box), label) for box, label in detections]
    return gr.update(value=(img, annotations), color_map=generate_colors(queries))
# Confidence slider shared with the examples below (their third column).
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")

demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Candidate Labels", placeholder="zebra, warthog"),
        dino_threshold,
    ],
    outputs=[dino_output],
    # The app only runs Grounding DINO, so the title no longer advertises OWLv2.
    title="Zero-Shot Object Detection with Grounding DINO",
    description=(
        "Evaluate the state-of-the-art "
        "[Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) "
        "zero-shot object detection model. Upload an image and enter the objects "
        "you want to find, separated by commas, or try one of the examples. "
        "Adjust the threshold to filter out low-confidence predictions."
    ),
    examples=[["./warthog.jpg", "zebra, warthog", 0.16], ["./zebra.jpg", "zebra, lion", 0.16]],
)
demo.launch(debug=True)