# Aastha
# Add filtering step for display of bounding box
# a21b606
import gradio as gr
import argparse
from pathlib import Path
import torch
from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import check_img_size, scale_coords, xyxy2xywh, non_max_suppression
from utils.plots import plot_one_box
from PIL import Image
from huggingface_hub import hf_hub_download
import cv2
import json
import numpy as np
def load_model(model_name):
    """Fetch the ``<model_name>.pt`` checkpoint from the ``Yolov7/<model_name>`` Hub repo.

    Despite the name, no model is constructed here: the return value is whatever
    ``hf_hub_download`` returns, which ``detect`` later passes to ``attempt_load``
    as the weights path.
    """
    return hf_hub_download(
        filename=f"{model_name}.pt",
        repo_id=f"Yolov7/{model_name}",
    )
# Checkpoints offered in the "Model" dropdown; each is resolved once at import
# time into the value returned by load_model (used by detect as the weights path).
model_names = ["yolov7", "yolov7-e6e", "yolov7-e6"]
models = dict(zip(model_names, map(load_model, model_names)))
def parse_rois(roi_json, width=None, height=None):
    """
    Parse the ROI JSON structure and return a list of ROI polygons.

    ``roi_json`` is a JSON array of objects, each holding a ``"coordinates"``
    list of ``{"x": ..., "y": ...}`` points. Each ROI becomes a list of integer
    ``(x, y)`` tuples.

    The ROI examples in this app store coordinates as fractions of the image
    size (values in [0, 1]); converting those straight to ``int`` collapses
    every point to 0. Pass ``width`` / ``height`` to scale ``x`` / ``y`` to
    pixels before truncation. When they are omitted each value is converted
    with plain ``int()``, exactly as before.

    Args:
        roi_json: JSON string. ``None`` or empty/whitespace-only input means
            "no ROIs" and returns ``[]``.
        width: optional multiplier applied to every ``x`` (image width in px).
        height: optional multiplier applied to every ``y`` (image height in px).

    Returns:
        list[list[tuple[int, int]]]: one point list per ROI.

    Raises:
        json.JSONDecodeError: if ``roi_json`` is non-blank but not valid JSON.
        KeyError: if an ROI lacks ``"coordinates"`` or a point lacks ``"x"``/``"y"``.
    """
    # A blank textbox means "no ROIs", not malformed JSON.
    if roi_json is None or (isinstance(roi_json, str) and not roi_json.strip()):
        return []

    def to_pixel(value, scale):
        # No scale given: keep the original plain int() conversion.
        return int(value) if scale is None else int(float(value) * scale)

    return [
        [(to_pixel(point["x"], width), to_pixel(point["y"], height)) for point in roi["coordinates"]]
        for roi in json.loads(roi_json)
    ]
# Not read by anything visible in this file (detect() receives ROIs as a
# parameter); kept so the module-level name still exists.
rois = []

# The 80 class names. desired_indices below assumes this order matches the
# class ids the loaded model predicts (detect compares int(cls) against it).
names = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# Only these classes are drawn by detect(): people and road vehicles.
desired_classes = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'train', 'truck']
# Index of each wanted class within `names`; labels missing from `names` are skipped.
desired_indices = [names.index(label) for label in filter(names.__contains__, desired_classes)]
print(desired_indices)
def _roi_polygons_px(rois_data, width, height):
    """Scale ROI dicts to pixel polygons usable for drawing and hit-testing.

    Each ROI is ``{"coordinates": [{"x": fx, "y": fy}, ...]}``; ``fx`` is
    multiplied by ``width`` and ``fy`` by ``height``, then truncated with ``int``.
    ROIs with no points are skipped because OpenCV rejects empty contours.
    Returns a list of ``np.int32`` arrays of shape ``(k, 2)``: ``cv2.polylines``
    needs 32-bit integer points, while NumPy's default integer may be int64.
    """
    polygons = []
    for roi in rois_data:
        points = [(int(pt["x"] * width), int(pt["y"] * height)) for pt in roi["coordinates"]]
        if points:
            polygons.append(np.array(points, dtype=np.int32))
    return polygons


def _center_in_any_roi(xyxy, polygons):
    """Return True if the centre of box ``xyxy`` lies inside or on any polygon."""
    cx = float((xyxy[0] + xyxy[2]) / 2)
    cy = float((xyxy[1] + xyxy[3]) / 2)
    return any(
        cv2.pointPolygonTest(poly.astype(np.float32), (cx, cy), False) >= 0
        for poly in polygons
    )


def detect(img, model, rois):
    """Run YOLOv7 on ``img`` and draw the ROIs plus class-filtered detections.

    Args:
        img: PIL image from the Gradio input; ``None`` raises ``ValueError``.
        model: key into the module-level ``models`` dict (checkpoint to load).
        rois: JSON string of ROIs, each ``{"coordinates": [{"x": .., "y": ..}, ...]}``
            where x/y are multiplied by the image width/height (i.e. normalized).
            An empty/blank string means "no ROIs".

    Returns:
        PIL.Image: the annotated image in RGB. ROI outlines are drawn in green.
        Only classes in ``desired_indices`` are boxed: green if the box centre
        lies in any ROI, red otherwise.

    Raises:
        ValueError: if no image is given or the saved image cannot be read back.
    """
    if img is None:
        raise ValueError("No image provided!")

    # LoadImages reads from disk, so persist the upload first. Create the folder
    # if needed, and convert to RGB because JPEG cannot store an alpha channel
    # (an RGBA upload such as a PNG would otherwise make save() fail).
    source_dir = Path("Inference")
    source_dir.mkdir(parents=True, exist_ok=True)
    source = source_dir / "test.jpg"
    img.convert("RGB").save(source)

    # A blank ROI textbox means "no ROIs" rather than a JSONDecodeError.
    rois_data = json.loads(rois) if rois and rois.strip() else []

    weights = models[model]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = attempt_load(weights, map_location=device)
    stride = int(net.stride.max())
    imgsz = check_img_size(640, s=stride)
    # Point LoadImages at the saved file itself, not the folder: a folder source
    # would also process any other images lying in Inference/.
    dataset = LoadImages(str(source), img_size=imgsz, stride=stride)
    class_names = net.module.names if hasattr(net, 'module') else net.names
    wanted = set(desired_indices)

    annotated = None
    with torch.no_grad():  # inference only; don't record autograd graphs
        for _path, tensor, im0, _cap in dataset:
            batch = torch.from_numpy(tensor).to(device).float() / 255.0
            if batch.ndimension() == 3:
                batch = batch.unsqueeze(0)
            # 0.25 / 0.45 are presumably the confidence / IoU thresholds of NMS.
            pred = non_max_suppression(net(batch)[0], 0.25, 0.45)

            # Outline every ROI on the original-resolution (BGR) frame.
            height, width = im0.shape[:2]
            polygons = _roi_polygons_px(rois_data, width, height)
            for poly in polygons:
                cv2.polylines(im0, [poly], isClosed=True, color=(0, 255, 0), thickness=2)

            for det in pred:
                if not len(det):
                    continue
                # Map boxes from the letterboxed network input back to im0 pixels.
                det[:, :4] = scale_coords(batch.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    if int(cls) not in wanted:
                        continue
                    inside = _center_in_any_roi(xyxy, polygons)
                    # BGR colours: green inside an ROI, red outside.
                    color = (0, 255, 0) if inside else (0, 0, 255)
                    label = f'{class_names[int(cls)]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=color, line_thickness=1)
            annotated = im0

    if annotated is None:
        raise ValueError(f"Could not read the saved image at {source}")
    # im0 is BGR; flip channels to RGB for PIL.
    return Image.fromarray(annotated[:, :, ::-1])
# Example inputs for the Gradio interface: each row is
# [image file, model name, ROI JSON string].
# The ROI JSON is a list of {"coordinates": [{"x": fx, "y": fy}, ...]} polygons.
# detect() multiplies x by the image width and y by the image height, so these
# values are fractions of the image size (all lie in [0, 1]).
# NOTE(review): the image files are presumably shipped next to this script — confirm.
roi_example = [
["8-2.jpg", "yolov7", json.dumps([
{"coordinates": [{ "x": 0.005, "y": 0.644 },{ "x": 0.047, "y": 0.572 },{ "x": 0.961, "y": 0.834 },{ "x": 0.695, "y": 0.919 }]},
{"coordinates": [{"x": 0.237, "y": 0.505}, {"x": 0.283, "y": 0.460}, {"x": 0.921, "y": 0.578}, {"x": 0.912, "y": 0.654}]}
])],
["9-1.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.124, "y": 0.143}, {"x": 0.328, "y": 0.369}, {"x": 0.322, "y": 0.510}, {"x": 0.006, "y": 0.668}]},
{"coordinates": [{"x": 0.226, "y": 0.129}, {"x": 0.498, "y": 0.275}, {"x": 0.773, "y": 0.234}, {"x": 0.889, "y": 0.086}]}
])],
["5-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.369, "y": 0.444}, {"x": 0.454, "y": 0.433}, {"x": 0.988, "y": 0.870}, {"x": 0.006, "y": 0.856}]},
{"coordinates": [{"x": 0.242, "y": 0.323}, {"x": 0.250, "y": 0.490}, {"x": 0.004, "y": 0.776}, {"x": 0.006, "y": 0.613}]},
{"coordinates": [{"x": 0.466, "y": 0.356}, {"x": 0.512, "y": 0.323}, {"x": 0.989, "y": 0.524}, {"x": 0.979, "y": 0.630}]}
])],
["11-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.751, "y": 0.514}, {"x": 0.824, "y": 0.499}, {"x": 0.996, "y": 0.768}, {"x": 0.923, "y": 0.881}]},
{"coordinates": [{"x": 0.566, "y": 0.494}, {"x": 0.609, "y": 0.511}, {"x": 0.305, "y": 0.917}, {"x": 0.002, "y": 0.722}]},
{"coordinates": [{"x": 0.430, "y": 0.388}, {"x": 0.455, "y": 0.419}, {"x": 0.005, "y": 0.582}, {"x": 0.004, "y": 0.514}]}
])],
["8-1.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.027, "y": 0.434}, {"x": 0.069, "y": 0.380}, {"x": 0.993, "y": 0.556}, {"x": 0.985, "y": 0.654}]}
])],
# Same ROI polygons as 9-1 (presumably the same camera view — verify).
["9-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.124, "y": 0.143}, {"x": 0.328, "y": 0.369}, {"x": 0.322, "y": 0.510}, {"x": 0.006, "y": 0.668}]},
{"coordinates": [{"x": 0.226, "y": 0.129}, {"x": 0.498, "y": 0.275}, {"x": 0.773, "y": 0.234}, {"x": 0.889, "y": 0.086}]}
])],
["2-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.302, "y": 0.475}, {"x": 0.264, "y": 0.523}, {"x": 0.366, "y": 0.625}, {"x": 0.994, "y": 0.907}, {"x": 0.270, "y": 0.909}, {"x": 0.153, "y": 0.556}, {"x": 0.245, "y": 0.470}]},
{"coordinates": [{"x": 0.445, "y": 0.507}, {"x": 0.380, "y": 0.547}, {"x": 0.993, "y": 0.824}, {"x": 0.997, "y": 0.681}]},
{"coordinates": [{"x": 0.401, "y": 0.464}, {"x": 0.434, "y": 0.470}, {"x": 0.428, "y": 0.497}, {"x": 0.373, "y": 0.508}]},
{"coordinates": [{"x": 0.599, "y": 0.542}, {"x": 0.623, "y": 0.492}, {"x": 0.996, "y": 0.564}, {"x": 0.995, "y": 0.646}]},
{"coordinates": [{"x": 0.556, "y": 0.451}, {"x": 0.517, "y": 0.488}, {"x": 0.580, "y": 0.527}, {"x": 0.572, "y": 0.475}]}
])],
["2-1.jpg", "yolov7", json.dumps([{"coordinates": [{"x": 0.263, "y": 0.612}, {"x": 0.297, "y": 0.607}, {"x": 0.403, "y": 0.913}, {"x": 0.039, "y": 0.724}]},
{"coordinates": [{"x": 0.356, "y": 0.494}, {"x": 0.419, "y": 0.477}, {"x": 0.914, "y": 0.893}, {"x": 0.746, "y": 0.912}]},
{"coordinates": [{"x": 0.226, "y": 0.481}, {"x": 0.252, "y": 0.478}, {"x": 0.249, "y": 0.566}, {"x": 0.033, "y": 0.708}]}
])],
["6-1.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.604, "y": 0.391}, {"x": 0.723, "y": 0.385}, {"x": 0.973, "y": 0.594}, {"x": 0.497, "y": 0.744}]},
{"coordinates": [{"x": 0.482, "y": 0.198}, {"x": 0.558, "y": 0.202}, {"x": 0.515, "y": 0.556}, {"x": 0.339, "y": 0.501}]}
])],
["11-5.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.871, "y": 0.461}, {"x": 0.982, "y": 0.485}, {"x": 0.934, "y": 0.891}, {"x": 0.599, "y": 0.892}]},
{"coordinates": [{"x": 0.512, "y": 0.441}, {"x": 0.572, "y": 0.454}, {"x": 0.003, "y": 0.761}, {"x": 0.003, "y": 0.637}]}
])],
["5-1.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.277, "y": 0.217}, {"x": 0.300, "y": 0.220}, {"x": 0.263, "y": 0.906}, {"x": 0.003, "y": 0.912}]},
{"coordinates": [{"x": 0.376, "y": 0.236}, {"x": 0.429, "y": 0.223}, {"x": 0.923, "y": 0.833}, {"x": 0.623, "y": 0.912}]},
{"coordinates": [{"x": 0.497, "y": 0.232}, {"x": 0.536, "y": 0.217}, {"x": 0.995, "y": 0.519}, {"x": 0.978, "y": 0.652}]}
])],
["6-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.369, "y": 0.533}, {"x": 0.410, "y": 0.527}, {"x": 0.553, "y": 0.907}, {"x": 0.404, "y": 0.909}]},
{"coordinates": [{"x": 0.551, "y": 0.461}, {"x": 0.595, "y": 0.440}, {"x": 0.993, "y": 0.734}, {"x": 0.943, "y": 0.821}]}
])],
["10-2.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.619, "y": 0.256}, {"x": 0.649, "y": 0.283}, {"x": 0.329, "y": 0.549}, {"x": 0.126, "y": 0.502}]},
{"coordinates": [{"x": 0.792, "y": 0.261}, {"x": 0.901, "y": 0.265}, {"x": 0.996, "y": 0.482}, {"x": 0.613, "y": 0.596}]}
])],
["1-1.jpg", "yolov7", json.dumps([
{"coordinates": [{"x": 0.941, "y": 0.231}, {"x": 0.890, "y": 0.244}, {"x": 0.910, "y": 0.393}, {"x": 0.886, "y": 0.706}, {"x": 0.953, "y": 0.904}, {"x": 0.992, "y": 0.732}, {"x": 0.972, "y": 0.396}]},
{"coordinates": [{"x": 0.629, "y": 0.255}, {"x": 0.661, "y": 0.281}, {"x": 0.006, "y": 0.602}, {"x": 0.005, "y": 0.408}]}
])]
]
# Debug aid: report how many example rows are defined.
print(len(roi_example))
# Gradio example rows: a fresh [image, model, rois] list per entry of roi_example.
gr_examples = [list(example[:3]) for example in roi_example]
# HTML blurb rendered under the title of the Gradio interface.
# NOTE(review): the text says the tool is tailored to vehicles, but the class
# filter also keeps 'person' and 'bicycle' — consider aligning the wording.
description_html = """<b>Demo for YOLOv7 Object Detection</b>: This interface is specifically tailored for <b>detecting vehicles</b> in images. The primary focus is on <b>accident-prone regions</b> on public roads. By leveraging state-of-the-art object detection techniques, this tool aims to provide insights into areas where vehicles are most at risk, helping in <b>road safety analysis</b> and <b>preventive measures</b>. Users can also define <b>Regions of Interest (ROIs)</b> to narrow down the detection area, ensuring that the analysis is focused on the most critical parts of the image."""
# Build and launch the UI: image + model choice + ROI JSON in, annotated image out.
gr.Interface(
    detect,
    [
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(choices=model_names, label="Model"),
        gr.Textbox(label="ROIs (JSON format)"),
    ],
    gr.Image(type="pil", label="Detected Objects"),
    title="YOLOv7 Object Detection for Accident-Prone Regions",
    examples=gr_examples,
    description=description_html,
    live=False,  # run the model only when 'Submit' is clicked, not on every input change
    # `example_ceiling` is not a gr.Interface parameter; `examples_per_page`
    # (Gradio default: 10) is the option that controls how many examples are
    # listed per page, so size it to show every example on a single page.
    examples_per_page=len(gr_examples),
).launch()