# croppie_coffee_ug / scripts/render_results.py
# Committed by rgautroncgiar — "Initial commit" (b6ad7e1)
import cv2
import matplotlib.pyplot as plt
from PIL import ImageColor
from pathlib import Path
import os
def _to_rgb(color):
    """Convert a HEX color string (e.g. '#FF0000') or an RGB sequence to an (R, G, B) int tuple."""
    if isinstance(color, str):
        return ImageColor.getcolor(color, 'RGB')
    # Already an RGB sequence (e.g. (255, 0, 0)); normalise to plain ints for cv2.
    return tuple(int(channel) for channel in color)


def annotate_image_prediction(
    image_path,
    yolo_boxes,
    class_dic,
    saving_folder,
    hex_class_colors=None,
    show=False,
    true_count=False,
    saving_image_name=None,
    put_title=True,
    box_thickness=3,
    font_scale=1,
    font_thickness=5,
):
    """
    Function to label individual images with YOLO predictions.

    Draws one rectangle per predicted box (colored by class), optionally writes
    a title with the predicted (and true) counts, and saves the result with
    ``plt.imsave`` into ``saving_folder``.

    Args:
        image_path (str or os.PathLike): path to the image to label
        yolo_boxes (sized iterable): YOLO predicted boxes; each item must expose
            ``xyxy[0]`` (x1, y1, x2, y2 corners) and ``cls[0]`` (class id).
            ``len(yolo_boxes)`` is used as the predicted count.
        class_dic (dict): dictionary with predicted class id as key and corresponding label as value
        saving_folder (str or os.PathLike): folder where to save the annotated image (created if missing)
        hex_class_colors (dict, optional): label -> HEX color code (e.g. '#FF0000') or RGB tuple.
            Labels missing from the dict (or every label, when None/empty) are drawn in red. Defaults to None.
        show (bool, optional): If you want a window of the image with its boxes to pop up
            (shown before the title is written). Defaults to False.
        true_count (int or bool, optional): ground-truth total count of cherries to display in the
            title, together with the delta (predicted - true). Pass False or None to omit it;
            a count of 0 is displayed. Defaults to False.
        saving_image_name (str, optional): Name of the annotated image to save.
            Defaults to None, meaning 'annotated_<image file name>'.
        put_title (bool, optional): If you want a title to show in the plot. Defaults to True.
        box_thickness (int, optional): Thickness of the bounding boxes to plot. Defaults to 3.
        font_scale (int, optional): Font scale of the text of counts to be displayed. Defaults to 1.
        font_thickness (int, optional): Font thickness of the text of counts to be displayed. Defaults to 5.
    Returns:
        str or None: saving path of the annotated image, or None (with a printed warning)
        if ``image_path`` is not a file or cannot be decoded as an image.
    """
    if not os.path.isfile(image_path):
        print(f'WARNING: {image_path} does not exists')
        return None

    # cv2.imread does not raise on unreadable/unsupported files: it returns None.
    img = cv2.imread(str(image_path))
    if img is None:
        print(f'WARNING: {image_path} could not be read as an image')
        return None
    # OpenCV loads BGR; matplotlib (imshow/imsave) expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    dh, dw = img.shape[:2]

    # Class id -> RGB color. Unknown labels fall back to red instead of raising KeyError.
    label_colors = hex_class_colors or {}
    color_map = {
        class_id: _to_rgb(label_colors.get(label, '#FF0000'))
        for class_id, label in class_dic.items()
    }

    for yolo_box in yolo_boxes:
        x1, y1, x2, y2 = map(int, yolo_box.xyxy[0])
        class_id = int(yolo_box.cls[0])
        cv2.rectangle(img, (x1, y1), (x2, y2), color_map[class_id], box_thickness)

    if show:
        plt.imshow(img)
        plt.show()

    if put_title:
        predicted_count = len(yolo_boxes)
        # Explicit identity checks so that a true count of 0 is still displayed.
        if true_count is not None and true_count is not False:
            title = f'Predicted count: {predicted_count}, true count: {true_count}, delta: {predicted_count - true_count}'
        else:
            title = f'Predicted count: {predicted_count}'
        cv2.putText(
            img=img,
            text=title,
            org=(int(0.1 * dw), int(0.1 * dh)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=font_scale,
            thickness=font_thickness,
            color=(255, 251, 5),
        )

    if not saving_image_name:
        # Path(...).name is portable (handles OS-specific separators and PathLike inputs).
        saving_image_name = f'annotated_{Path(image_path).name}'
    Path(saving_folder).mkdir(parents=True, exist_ok=True)
    full_saving_path = os.path.join(saving_folder, saving_image_name)
    plt.imsave(full_saving_path, img)
    return full_saving_path