import json

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Model setup
model_id = "IDEA-Research/grounding-dino-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# Detection prompt and thresholds
text = "Container. Bottle. Fruit. Vegetable. Packet."
iou_threshold = 0.4
box_threshold = 0.3
score_threshold = 0.3


def iou(box1, box2):
    """Calculate IoU (Intersection over Union) of two boxes in (x1, y1, x2, y2) format."""
    x1, y1, x2, y2 = box1
    x1_2, y1_2, x2_2, y2_2 = box2

    # Intersection area
    inter_x1 = max(x1, x1_2)
    inter_y1 = max(y1, y1_2)
    inter_x2 = min(x2, x2_2)
    inter_y2 = min(y2, y2_2)
    if inter_x2 < inter_x1 or inter_y2 < inter_y1:
        return 0.0  # No intersection
    intersection_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)

    # Union area
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
    union_area = area1 + area2 - intersection_area

    return intersection_area / union_area


def detect_objects(image):
    """Detect objects in an image and return a JSON string with count, class, box, and score."""
    # Prepare inputs for the model
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=score_threshold,
        target_sizes=[image.size[::-1]],
    )

    # Filter out overlapping boxes with greedy NMS (Non-Maximum Suppression):
    # visit detections in descending score order and drop any box whose IoU
    # with an already-kept box exceeds iou_threshold.
    detections = sorted(
        zip(results[0]["boxes"], results[0]["labels"], results[0]["scores"]),
        key=lambda d: d[2].item(),
        reverse=True,
    )
    filtered_boxes = []
    filtered_labels = []
    filtered_scores = []
    for box, label, score in detections:
        box = box.tolist()
        if all(iou(box, kept) <= iou_threshold for kept in filtered_boxes):
            filtered_boxes.append(box)
            filtered_labels.append(label)
            filtered_scores.append(score.item())

    # Prepare the output in the requested format
    output = {
        "count": len(filtered_boxes),
        "class": filtered_labels,
        "box": filtered_boxes,
        "score": filtered_scores,
    }
    return json.dumps(output)


# Define Gradio input and output components
image_input = gr.Image(type="pil")

# Create the Gradio interface
demo = gr.Interface(
    fn=detect_objects,
    inputs=image_input,
    outputs="text",
    title="Freshness prediction",
    description="Upload an image; the model will detect objects and return a JSON string with the object count, class labels, bounding boxes, and confidence scores.",
)

demo.launch(share=True)
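
# A minimal sketch of calling detect_objects directly (outside the Gradio UI).
# The filename "sample.jpg" is a hypothetical local image used only for this example;
# run something like this instead of demo.launch(share=True) when testing from a script:
#
#   result = json.loads(detect_objects(Image.open("sample.jpg").convert("RGB")))
#   print(f"Detected {result['count']} objects:")
#   for label, score in zip(result["class"], result["score"]):
#       print(f"  {label}: {score:.2f}")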