Spaces:
Sleeping
Sleeping
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
# Model setup: zero-shot object detection with Grounding DINO from the HF Hub.
# NOTE(review): from_pretrained downloads weights on first run — needs network.
model_id = "IDEA-Research/grounding-dino-base"
# Prefer GPU when available, fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Processor handles image/text preprocessing and result post-processing;
# the model itself performs the grounded detection.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
# Text prompt listing the target classes as period-separated phrases
# (the format this zero-shot detector is queried with).
text="Container. Bottle. Fruit. Vegetable. Packet."
# Boxes overlapping a kept box by more than this IoU are suppressed (NMS).
iou_threshold=0.4
# Minimum box confidence retained during post-processing.
box_threshold=0.3
# Minimum text-match score retained (passed as text_threshold below).
score_threshold=0.3
20 |
+
|
21 |
+
# Function to detect objects in an image and return a JSON with count, class, box, and score
|
22 |
+
def detect_objects(image):
|
23 |
+
|
24 |
+
# Prepare inputs for the model
|
25 |
+
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
26 |
+
|
27 |
+
# Perform inference
|
28 |
+
with torch.no_grad():
|
29 |
+
outputs = model(**inputs)
|
30 |
+
|
31 |
+
# Post-process results
|
32 |
+
results = processor.post_process_grounded_object_detection(
|
33 |
+
outputs,
|
34 |
+
inputs.input_ids,
|
35 |
+
box_threshold=box_threshold,
|
36 |
+
text_threshold=score_threshold,
|
37 |
+
target_sizes=[image.size[::-1]]
|
38 |
+
)
|
39 |
+
|
40 |
+
# Function to calculate IoU (Intersection over Union)
|
41 |
+
def iou(box1, box2):
|
42 |
+
x1, y1, x2, y2 = box1
|
43 |
+
x1_2, y1_2, x2_2, y2_2 = box2
|
44 |
+
|
45 |
+
# Calculate intersection area
|
46 |
+
inter_x1 = max(x1, x1_2)
|
47 |
+
inter_y1 = max(y1, y1_2)
|
48 |
+
inter_x2 = min(x2, x2_2)
|
49 |
+
inter_y2 = min(y2, y2_2)
|
50 |
+
|
51 |
+
if inter_x2 < inter_x1 or inter_y2 < inter_y1:
|
52 |
+
return 0.0 # No intersection
|
53 |
+
|
54 |
+
intersection_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
|
55 |
+
|
56 |
+
# Calculate union area
|
57 |
+
area1 = (x2 - x1) * (y2 - y1)
|
58 |
+
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
|
59 |
+
union_area = area1 + area2 - intersection_area
|
60 |
+
|
61 |
+
return intersection_area / union_area
|
62 |
+
|
63 |
+
# Filter out overlapping boxes using NMS (Non-Maximum Suppression)
|
64 |
+
filtered_boxes = []
|
65 |
+
filtered_labels = []
|
66 |
+
filtered_scores = []
|
67 |
+
|
68 |
+
for i, (box, label, score) in enumerate(zip(results[0]['boxes'], results[0]['labels'], results[0]['scores'])):
|
69 |
+
keep = True
|
70 |
+
for j, (box2, label2, score2) in enumerate(zip(filtered_boxes, filtered_labels, filtered_scores)):
|
71 |
+
# If IoU is above the threshold, discard the box
|
72 |
+
if iou(box.tolist(), box2) > iou_threshold:
|
73 |
+
keep = False
|
74 |
+
break
|
75 |
+
if keep:
|
76 |
+
filtered_boxes.append(box.tolist())
|
77 |
+
filtered_labels.append(label)
|
78 |
+
filtered_scores.append(score.item())
|
79 |
+
|
80 |
+
# Prepare the output in the requested format
|
81 |
+
output = {
|
82 |
+
"count": len(filtered_boxes),
|
83 |
+
"class": filtered_labels,
|
84 |
+
"box": filtered_boxes,
|
85 |
+
"score": filtered_scores
|
86 |
+
}
|
87 |
+
|
88 |
+
return json.dumps(output)
|
89 |
+
|
90 |
+
# Define Gradio input component: accept an upload decoded to a PIL image,
# which is what detect_objects expects.
image_input = gr.Image(type="pil")

# Create the Gradio interface.
# BUG FIX: title typo ("Frshness" -> "Freshness"); description corrected —
# the app returns a JSON string of detections, it does not render an
# annotated image with bounding boxes.
demo = gr.Interface(
    fn=detect_objects,
    inputs=image_input,
    outputs='text',
    title="Freshness prediction",
    description="Upload an image; the model detects objects and returns a JSON string with the detection count, classes, bounding boxes, and scores."
)

# share=True requests a public tunnel link (ignored/warned on hosted Spaces,
# which already expose the app) — TODO confirm it is wanted here.
demo.launch(share=True)