import json

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# Model setup
model_id = "IDEA-Research/grounding-dino-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
text="Container. Bottle. Fruit. Vegetable. Packet."
iou_threshold=0.4
box_threshold=0.3
score_threshold=0.3
# Function to detect objects in an image and return a JSON with count, class, box, and score
def detect_objects(image):
    # Prepare inputs for the model
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=score_threshold,
        target_sizes=[image.size[::-1]]
    )
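    # `results` is a list with one dict per image; results[0] contains
    # 'boxes' (xyxy tensors in pixel coordinates), 'scores', and 'labels'
    # (the phrases from `text` that each box was matched to).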
    # Function to calculate IoU (Intersection over Union) of two xyxy boxes
    def iou(box1, box2):
        x1, y1, x2, y2 = box1
        x1_2, y1_2, x2_2, y2_2 = box2

        # Calculate intersection area
        inter_x1 = max(x1, x1_2)
        inter_y1 = max(y1, y1_2)
        inter_x2 = min(x2, x2_2)
        inter_y2 = min(y2, y2_2)
        if inter_x2 < inter_x1 or inter_y2 < inter_y1:
            return 0.0  # No intersection
        intersection_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)

        # Calculate union area
        area1 = (x2 - x1) * (y2 - y1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = area1 + area2 - intersection_area
        return intersection_area / union_area
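    # Worked example (illustrative values): for boxes [0, 0, 10, 10] and
    # [5, 5, 15, 15], the intersection is 5 * 5 = 25 and the union is
    # 100 + 100 - 25 = 175, so iou(...) ≈ 0.14, below the 0.4 threshold,
    # meaning both boxes would be kept by the filtering step below.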
    # Filter out overlapping boxes with greedy NMS (Non-Maximum Suppression):
    # visit detections from highest to lowest score and keep a box only if it
    # does not overlap an already-kept box by more than iou_threshold
    detections = sorted(
        zip(results[0]['boxes'], results[0]['labels'], results[0]['scores']),
        key=lambda det: det[2].item(),
        reverse=True,
    )
    filtered_boxes = []
    filtered_labels = []
    filtered_scores = []
    for box, label, score in detections:
        keep = True
        for kept_box in filtered_boxes:
            # If IoU with an already-kept box is above the threshold, discard this box
            if iou(box.tolist(), kept_box) > iou_threshold:
                keep = False
                break
        if keep:
            filtered_boxes.append(box.tolist())
            filtered_labels.append(label)
            filtered_scores.append(score.item())
    # Prepare the output in the requested format
    output = {
        "count": len(filtered_boxes),
        "class": filtered_labels,
        "box": filtered_boxes,
        "score": filtered_scores
    }
    return json.dumps(output)
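# A quick local sanity check outside the Gradio UI (hypothetical file name,
# adjust to an image that actually exists on disk):
#
#   result = json.loads(detect_objects(Image.open("sample_produce.jpg")))
#   print(result["count"], result["class"])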
# Define Gradio input and output components
image_input = gr.Image(type="pil")
# Create the Gradio interface
demo = gr.Interface(
    fn=detect_objects,
    inputs=image_input,
    outputs="text",
    title="Freshness prediction",
    description="Upload an image; the model will detect the prompted object classes and return a JSON string with the detection count, class labels, bounding boxes, and scores."
)
demo.launch(share=True)