# Upstream header (web-scrape residue, preserved as a comment so the file parses):
# Thanaphit — "Add examples for damage severity" (commit 59632e2)
import os
import gdown
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from ultralytics import YOLO
from ultralytics.utils.ops import scale_image
def setup():
    """Ensure the three YOLOv8 weight files exist locally and collect sample images.

    Downloads each missing weight from Google Drive into ``weight/`` and walks
    the ``sample/`` directories for example inputs.

    Returns:
        tuple: (part-seg weight path, damage-det weight path,
                severity-det weight path,
                dict mapping task name -> list of single-element
                ``[image_path]`` lists, one per sample image).
    """
    CAR_PART_SEG_URL = "https://drive.google.com/uc?id=1I_LKds9obElNIZcW_DM8zyknrwRmrASj"
    CAR_DAM_DET_URL = "https://drive.google.com/uc?id=1AXDyFoEuNqXSaDNUHBp8H9AVjICvsUpz"
    CAR_SEV_DET_URL = "https://drive.google.com/uc?id=1An7QGjbL-UEu7LOT7Xh59itBE854vy4U"

    CAR_PART_SEG_OUT = "weight/yolov8-car-part-seg.pt"
    CAR_DAM_DET_OUT = "weight/yolov8-car-damage-detection.pt"
    # NOTE: 'serverity' typo kept on purpose -- the path is an external
    # contract (existing downloads / other code may reference it).
    CAR_SEV_DET_OUT = "weight/yolov8-car-damage-serverity-detection.pt"

    SAMPLE = {
        "car-parts-seg": [[f"{root}/{file}"]
                          for root, _, files in os.walk("sample/car-parts-seg", topdown=False)
                          for file in files],
        "car-dam-det": [[f"{root}/{file}"]
                        for root, _, files in os.walk("sample/car-damage-det", topdown=False)
                        for file in files],
        "car-dam-sev-det": [[f"{root}/{file}"]
                            for root, _, files in os.walk("sample/car-damage-sev-det", topdown=False)
                            for file in files],
    }

    os.makedirs("weight", exist_ok=True)
    # Download each weight only when it is missing locally.
    # BUG FIX: the original tested os.path.exists(CAR_SEV_DET_URL) -- a URL,
    # never a local path -- so the severity weight re-downloaded on every call.
    for url, out in ((CAR_PART_SEG_URL, CAR_PART_SEG_OUT),
                     (CAR_DAM_DET_URL, CAR_DAM_DET_OUT),
                     (CAR_SEV_DET_URL, CAR_SEV_DET_OUT)):
        if not os.path.exists(out):
            gdown.download(url, out, quiet=True)

    return CAR_PART_SEG_OUT, CAR_DAM_DET_OUT, CAR_SEV_DET_OUT, SAMPLE
class Predictor:
    """Wrapper around an ultralytics YOLO model that runs inference and
    renders annotated bounding boxes and/or segmentation masks.

    ``transform`` is the main entry point; it returns an RGB image with the
    requested overlays drawn on it.
    """

    def __init__(self, model_weight):
        """Load YOLO weights and precompute one BGR colour per class id.

        Args:
            model_weight: path to a ``.pt`` ultralytics YOLO checkpoint.
        """
        self.model = YOLO(model_weight)
        self.category_map = self.model.names  # {class_id: class_name}
        self.NCLS = len(self.category_map)

        # Slice the rainbow colormap into NCLS bins so every class id maps
        # to a stable, visually distinct colour.
        cmap = plt.cm.rainbow
        cmaplist = [cmap(i) for i in range(cmap.N)]
        self.cmap = cmap.from_list('my cmap', cmaplist, cmap.N)
        bounds = np.linspace(0, self.NCLS, self.NCLS + 1)
        norm = colors.BoundaryNorm(bounds, self.cmap.N)
        category_cmap = {k: cmap(norm(int(k))) for k in self.category_map}
        # Matplotlib yields RGBA floats in [0, 1]; OpenCV drawing wants
        # BGR in [0, 255], hence the channel swap and scaling.
        self.category_cmap = {k: (v[2] * 255, v[1] * 255, v[0] * 255)
                              for k, v in category_cmap.items()}

    def predict(self, image_path):
        """Run the model on an image file.

        Args:
            image_path: path to the input image on disk.

        Returns:
            tuple: ``(image, cls, confs, boxes, masks, results)`` where
            ``image`` is the BGR array read from disk and the rest come from
            the first ultralytics result, moved to CPU/numpy. Empty lists are
            substituted when the model produced no boxes or masks.
        """
        image = cv2.imread(image_path)
        outputs = self.model.predict(source=image_path)
        results = outputs[0].cpu().numpy()
        boxes = results.boxes.xyxy if results.boxes is not None else []
        confs = results.boxes.conf if results.boxes is not None else []
        cls = results.boxes.cls if results.boxes is not None else []
        masks = results.masks.data if results.masks is not None else []
        return image, cls, confs, boxes, masks, results

    def annotate_boxes(self, image, cls, confs, boxes, results):
        """Draw one labelled rectangle per detection onto ``image`` (in place).

        Returns the same ``image`` array for chaining.
        """
        for box, cl, conf in zip(boxes, cls, confs):
            # Class ids arrive as numpy floats; cast so dict lookups are
            # robust rather than relying on numeric-hash equality.
            cl = int(cl)
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"
            x1, y1, x2, y2 = (int(p) for p in box)
            cv2.rectangle(image, (x1, y1), (x2, y2),
                          color=color,
                          thickness=2,
                          lineType=cv2.LINE_AA
                          )
            # Filled strip behind the text keeps the label readable.
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.3, 1)
            cv2.rectangle(image, (x1, y1 - 2 * h), (x1 + w, y1), color, -1)
            cv2.putText(image, text, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_DUPLEX, 0.3, (255, 255, 255), 1)
        return image

    def annotate_masks(self, image, cls, confs, masks, results):
        """Blend each predicted mask onto ``image`` and write its label at
        the mask centroid.

        Returns a new annotated image array (``cv2.addWeighted`` copies).
        """
        ori_shape = image.shape[:2]
        for mask, cl, conf in zip(masks, cls, confs):
            mask = mask.astype("uint8")
            cl = int(cl)  # see annotate_boxes: float ids -> int dict keys
            label = results.names[cl]
            color = self.category_cmap[cl]
            text = label + f" {conf:.2f}"

            # Semi-transparent, class-coloured fill over the mask area.
            _mask = np.where(mask[..., None], color, (0, 0, 0))
            _mask = scale_image(_mask, ori_shape).astype("uint8")
            image = cv2.addWeighted(image, 1, _mask, 0.5, 0)

            # Solid outline along the mask contour.
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            boundary = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGBA).astype("float")
            cv2.drawContours(boundary, contours, -1, color, 2)
            boundary = scale_image(boundary, ori_shape)[:, :, :-1].astype("uint8")
            image = cv2.addWeighted(image, 1, boundary, 1, 0)

            # Label at the mask centroid, drawn twice (dark halo + white text).
            coords = np.argwhere(_mask != [0, 0, 0])
            if coords.size == 0:
                # Empty mask: centroid is undefined (mean of nothing is NaN,
                # which crashed the original astype(int)); skip the label.
                continue
            cy, cx = np.round(coords.mean(axis=0)[0:2]).astype(int)
            (w, h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.5, 1)
            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (0, 0, 0), 2)
            cv2.putText(image, text, (cx - int(0.5 * w), cy),
                        cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1)
        return image

    def transform(self, image_path, annot_boxes=False, annot_masks=False):
        """Predict on ``image_path`` and return an RGB image with the
        requested annotations (masks are drawn first, boxes on top).

        Args:
            image_path: path to the input image.
            annot_boxes: draw labelled bounding boxes when True.
            annot_masks: draw blended segmentation masks when True.
        """
        image, cls, confs, boxes, masks, results = self.predict(image_path)
        if annot_masks:
            image = self.annotate_masks(image, cls, confs, masks, results)
        if annot_boxes:
            image = self.annotate_boxes(image, cls, confs, boxes, results)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)