import os

import gdown
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from ultralytics import YOLO
from ultralytics.utils.ops import scale_image
def setup():
    """Download any missing YOLOv8 weights and collect the bundled sample image paths."""
    # Google Drive links for the three fine-tuned YOLOv8 checkpoints.
    CAR_PART_SEG_URL = "https://drive.google.com/uc?id=1I_LKds9obElNIZcW_DM8zyknrwRmrASj"
    CAR_DAM_DET_URL = "https://drive.google.com/uc?id=1AXDyFoEuNqXSaDNUHBp8H9AVjICvsUpz"
    CAR_SEV_DET_URL = "https://drive.google.com/uc?id=1An7QGjbL-UEu7LOT7Xh59itBE854vy4U"

    # Local paths the downloaded weights are stored at.
    CAR_PART_SEG_OUT = "weight/yolov8-car-part-seg.pt"
    CAR_DAM_DET_OUT = "weight/yolov8-car-damage-detection.pt"
    CAR_SEV_DET_OUT = "weight/yolov8-car-damage-serverity-detection.pt"

    # One list of sample image paths per task, gathered from the sample/ folders.
    SAMPLE = {
        "car-parts-seg": [
            [f"{root}/{file}"]
            for root, _, files in os.walk("sample/car-parts-seg", topdown=False)
            for file in files
        ],
        "car-dam-det": [
            [f"{root}/{file}"]
            for root, _, files in os.walk("sample/car-damage-det", topdown=False)
            for file in files
        ],
        "car-dam-sev-det": [
            [f"{root}/{file}"]
            for root, _, files in os.walk("sample/car-damage-sev-det", topdown=False)
            for file in files
        ],
    }
    # Fetch any checkpoint that is not already cached locally.
    if not os.path.exists(CAR_PART_SEG_OUT):
        os.makedirs("weight", exist_ok=True)
        gdown.download(CAR_PART_SEG_URL, CAR_PART_SEG_OUT, quiet=True)
    if not os.path.exists(CAR_DAM_DET_OUT):
        os.makedirs("weight", exist_ok=True)
        gdown.download(CAR_DAM_DET_URL, CAR_DAM_DET_OUT, quiet=True)
    if not os.path.exists(CAR_SEV_DET_OUT):
        os.makedirs("weight", exist_ok=True)
        gdown.download(CAR_SEV_DET_URL, CAR_SEV_DET_OUT, quiet=True)

    return CAR_PART_SEG_OUT, CAR_DAM_DET_OUT, CAR_SEV_DET_OUT, SAMPLE
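
# Usage sketch for setup(): the returned weight paths are intended to feed the
# Predictor class below, and SAMPLE maps each task name to a list of
# single-element lists of sample paths (likely for a UI examples gallery).
# The variable names here are illustrative, not part of the original interface:
#
#   part_seg_w, dam_det_w, sev_det_w, samples = setup()
#   samples["car-parts-seg"]  # e.g. [["sample/car-parts-seg/<file>"], ...]
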
class Predictor:
    def __init__(self, model_weight):
        self.model = YOLO(model_weight)
        self.category_map = self.model.names  # {class_id: class_name}
        self.NCLS = len(self.category_map)

        # Build one distinct colour per class from the rainbow colormap.
        cmap = plt.cm.rainbow
        cmaplist = [cmap(i) for i in range(cmap.N)]
        self.cmap = cmap.from_list("my cmap", cmaplist, cmap.N)
        bounds = np.linspace(0, self.NCLS, self.NCLS + 1)
        norm = colors.BoundaryNorm(bounds, self.cmap.N)
        category_cmap = {k: cmap(norm(int(k))) for k in self.category_map}
        # Matplotlib returns RGBA in [0, 1]; OpenCV expects BGR in [0, 255].
        self.category_cmap = {k: (v[2] * 255, v[1] * 255, v[0] * 255)
                              for k, v in category_cmap.items()}
    def predict(self, image_path):
        image = cv2.imread(image_path)
        outputs = self.model.predict(source=image_path)
        results = outputs[0].cpu().numpy()
        # Fall back to empty lists when the model returns no detections or masks.
        boxes = results.boxes.xyxy if results.boxes is not None else []
        confs = results.boxes.conf if results.boxes is not None else []
        cls = results.boxes.cls if results.boxes is not None else []
        # probs = results.boxes.probs
        masks = results.masks.data if results.masks is not None else []
        return image, cls, confs, boxes, masks, results
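
    # Sketch of the predict() contract (weight and image paths are illustrative):
    #
    #   predictor = Predictor("weight/yolov8-car-part-seg.pt")
    #   image, cls, confs, boxes, masks, results = predictor.predict(some_sample_path)
    #
    # cls[i] is the class id, confs[i] the confidence, boxes[i] an (x1, y1, x2, y2)
    # box; masks is empty for the detection-only checkpoints and holds one binary
    # mask per instance for the segmentation checkpoint.
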
    def annotate_boxes(self, image, cls, confs, boxes, results):
        # image, cls, confs, boxes, _, results = self.predict(image_path)
        # Draw one rectangle plus a "label confidence" tag per detection.
        for box, cl, conf in zip(boxes, cls, confs):
            cl = int(cl)
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"
            x1, y1, x2, y2 = (int(p) for p in box)
            cv2.rectangle(image, (x1, y1), (x2, y2),
                          color=color,
                          thickness=2,
                          lineType=cv2.LINE_AA)
            # Filled background behind the label so the text stays readable.
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.3, 1)
            cv2.rectangle(image, (x1, y1 - 2 * h), (x1 + w, y1), color, -1)
            cv2.putText(image, text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_DUPLEX, 0.3, (255, 255, 255), 1)
        return image
    def annotate_masks(self, image, cls, confs, masks, results):
        # image, cls, confs, _, masks, results = self.predict(image_path)
        ori_shape = image.shape[:2]
        for mask, cl, conf in zip(masks, cls, confs):
            mask = mask.astype("uint8")
            cl = int(cl)
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"
            # Blend a class-coloured overlay onto the image where the mask is set.
            _mask = np.where(mask[..., None], color, (0, 0, 0))
            _mask = scale_image(_mask, ori_shape).astype("uint8")
            image = cv2.addWeighted(image, 1, _mask, 0.5, 0)
            # Trace the mask boundary so overlapping instances stay distinguishable.
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            boundary = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGBA).astype("float")
            cv2.drawContours(boundary, contours, -1, color, 2)
            boundary = scale_image(boundary, ori_shape)[:, :, :-1].astype("uint8")
            image = cv2.addWeighted(image, 1, boundary, 1, 0)
            # Put the label at the overlay centroid: white text over a dark outline.
            cy, cx = np.round(np.argwhere(_mask != [0, 0, 0]).mean(axis=0)[0:2]).astype(int)
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.5, 1)
            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (0, 0, 0), 2)
            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)
        return image
    def transform(self, image_path, annot_boxes=False, annot_masks=False):
        """Run inference on one image and return it with the requested annotations drawn."""
        image, cls, confs, boxes, masks, results = self.predict(image_path)
        if annot_masks:
            image = self.annotate_masks(image, cls, confs, masks, results)
        if annot_boxes:
            image = self.annotate_boxes(image, cls, confs, boxes, results)
        # OpenCV works in BGR; convert for display libraries that expect RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
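

# Minimal end-to-end sketch, assuming the sample/ folders shipped with the app
# contain at least one image; the choice of checkpoint and sample below is
# illustrative, not prescribed by the code above.
if __name__ == "__main__":
    part_seg_weight, dam_det_weight, sev_det_weight, sample = setup()
    predictor = Predictor(part_seg_weight)
    if sample["car-parts-seg"]:
        # Take the first bundled car-part-segmentation sample and annotate it.
        demo_path = sample["car-parts-seg"][0][0]
        annotated = predictor.transform(demo_path, annot_boxes=True, annot_masks=True)
        plt.imshow(annotated)
        plt.axis("off")
        plt.show()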