import numpy as np import torch from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou from tqdm import tqdm from modules.toXML import get_size_elements, calculate_pool_bounds, create_BPMN_id from modules.utils import is_vertical, proportion_inside import streamlit as st from builtins import dict def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5): """ Perform non-maximum suppression to filter out overlapping bounding boxes. Parameters: - boxes (array): Array of bounding boxes. - scores (array): Array of confidence scores for each bounding box. - labels (array, optional): Array of labels for each bounding box. - iou_threshold (float): Intersection-over-Union threshold to use for filtering. Returns: - list: Indices of selected boxes after suppression. """ exception = ['pool', 'lane'] idxs = np.argsort(scores) # Sort the boxes according to their scores in ascending order selected_boxes = [] while len(idxs) > 0: last = len(idxs) - 1 i = idxs[last] # Skip if the label is a lane if labels is not None and (class_dict[labels[i]] in exception): selected_boxes.append(i) idxs = np.delete(idxs, last) continue selected_boxes.append(i) # Find the intersection of the box with the rest suppress = [last] for pos in range(0, last): j = idxs[pos] if iou(boxes[i], boxes[j]) > iou_threshold: suppress.append(pos) idxs = np.delete(idxs, suppress) # Return only the boxes that were selected return selected_boxes def keypoint_correction(keypoints, boxes, labels, model_dict=arrow_dict, distance_treshold=15): """ Correct keypoints that are too close together by adjusting their positions. Parameters: - keypoints (array): Array of keypoints. - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - model_dict (dict): Dictionary mapping model labels to indices. - distance_treshold (int): Distance threshold below which keypoints are considered too close. Returns: - array: Corrected keypoints. """ for idx, (key1, key2) in enumerate(keypoints): if labels[idx] not in [list(model_dict.values()).index('sequenceFlow'), list(model_dict.values()).index('messageFlow'), list(model_dict.values()).index('dataAssociation')]: continue # Calculate the Euclidean distance between the two keypoints distance = np.linalg.norm(key1[:2] - key2[:2]) if distance < distance_treshold: print('Key modified for index:', idx) x_new, y_new, x, y = find_other_keypoint(idx, keypoints, boxes) keypoints[idx][0][:2] = [x_new, y_new] keypoints[idx][1][:2] = [x, y] return keypoints def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5): """ Perform object detection prediction using the model. Parameters: - model (torch.nn.Module): The object detection model. - image (torch.Tensor): The input image. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for non-maximum suppression. Returns: - numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', and 'labels'. """ model.eval() with torch.no_grad(): image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) predictions = model(image_tensor) boxes = predictions[0]['boxes'].cpu().numpy() labels = predictions[0]['labels'].cpu().numpy() scores = predictions[0]['scores'].cpu().numpy() idx = np.where(scores > score_threshold)[0] boxes = boxes[idx] scores = scores[idx] labels = labels[idx] selected_boxes = non_maximum_suppression(boxes, scores, labels=labels, iou_threshold=iou_threshold) # Find orientation of the task by checking the size of all the boxes and delete the ones that are not in the same orientation vertical = 0 for i in range(len(labels)): if labels[i] != list(object_dict.values()).index('task'): continue if is_vertical(boxes[i]): vertical += 1 horizontal = len(labels) - vertical for i in range(len(labels)): if labels[i] != list(object_dict.values()).index('task'): continue if vertical < horizontal: if is_vertical(boxes[i]): # Find the element in the list and remove it if i in selected_boxes: selected_boxes.remove(i) elif vertical > horizontal: if is_vertical(boxes[i]) == False: # Find the element in the list and remove it if i in selected_boxes: selected_boxes.remove(i) else: pass boxes = boxes[selected_boxes] scores = scores[selected_boxes] labels = labels[selected_boxes] # Find the outlier objects that are too small by the area obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower") obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref=['task'], mode="upper") selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big] boxes = boxes[selected_object] scores = scores[selected_object] labels = labels[selected_object] # Modify the label of the sub-process to task for i in range(len(labels)): if labels[i] == list(object_dict.values()).index('subProcess'): labels[i] = list(object_dict.values()).index('task') # Delete all lane and also the value in the labels and scores lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')] boxes = np.delete(boxes, lane_index, axis=0) labels = np.delete(labels, lane_index) scores = np.delete(scores, lane_index) prediction = { 'boxes': boxes, 'scores': scores, 'labels': labels, } image = image.permute(1, 2, 0).cpu().numpy() image = (image * 255).astype(np.uint8) return image, prediction def arrow_prediction(model, image, score_threshold=0.5, iou_threshold=0.5, distance_treshold=15): """ Perform arrow detection prediction using the model. Parameters: - model (torch.nn.Module): The arrow detection model. - image (torch.Tensor): The input image. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for non-maximum suppression. - distance_treshold (int): Distance threshold for keypoint correction. Returns: - numpy.array, dict: The processed image and the prediction dictionary containing 'boxes', 'scores', 'labels', and 'keypoints'. """ model.eval() with torch.no_grad(): image_tensor = image.unsqueeze(0).to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) predictions = model(image_tensor) boxes = predictions[0]['boxes'].cpu().numpy() labels = predictions[0]['labels'].cpu().numpy() + (len(object_dict) - 1) scores = predictions[0]['scores'].cpu().numpy() keypoints = predictions[0]['keypoints'].cpu().numpy() idx = np.where(scores > score_threshold)[0] boxes = boxes[idx] scores = scores[idx] labels = labels[idx] keypoints = keypoints[idx] selected_boxes = non_maximum_suppression(boxes, scores, iou_threshold=iou_threshold) boxes = boxes[selected_boxes] scores = scores[selected_boxes] labels = labels[selected_boxes] keypoints = keypoints[selected_boxes] keypoints = keypoint_correction(keypoints, boxes, labels, class_dict, distance_treshold=distance_treshold) prediction = { 'boxes': boxes, 'scores': scores, 'labels': labels, 'keypoints': keypoints, } image = image.permute(1, 2, 0).cpu().numpy() image = (image * 255).astype(np.uint8) return image, prediction def mix_predictions(objects_pred, arrow_pred): """ Combine object and arrow predictions into a single set of predictions. Parameters: - objects_pred (dict): Object predictions dictionary. - arrow_pred (dict): Arrow predictions dictionary. Returns: - tuple: Combined boxes, labels, scores, and keypoints. """ # Initialize the list of lists for keypoints object_keypoints = [] # Number of boxes num_boxes = len(objects_pred['boxes']) # Iterate over the number of boxes for _ in range(num_boxes): # Each box has 2 keypoints, both initialized to [0, 0, 0] keypoints = [[0, 0, 0], [0, 0, 0]] object_keypoints.append(keypoints) # Concatenate the two predictions if len(arrow_pred['boxes']) == 0: return objects_pred['boxes'], objects_pred['labels'], objects_pred['scores'], object_keypoints boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes'])) labels = np.concatenate((objects_pred['labels'], arrow_pred['labels'])) scores = np.concatenate((objects_pred['scores'], arrow_pred['scores'])) keypoints = np.concatenate((object_keypoints, arrow_pred['keypoints'])) return boxes, labels, scores, keypoints def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6): """ Regroup elements by pool based on IoU and proximity. Parameters: - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - scores (array): Array of confidence scores for each bounding box. - keypoints (array): Array of keypoints. - class_dict (dict): Dictionary mapping class names to indices. - iou_threshold (float): IoU threshold for grouping. Returns: - dict: Dictionary grouping elements by pool. - array: Updated arrays of boxes, labels, scores, and keypoints. """ pool_dict = {} # Filter out pools with IoU greater than the threshold to_delete = [] for i in range(len(boxes)): for j in range(i + 1, len(boxes)): if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'): if proportion_inside(boxes[i], boxes[j]) > iou_threshold: to_delete.append(j) boxes = np.delete(boxes, to_delete, axis=0) labels = np.delete(labels, to_delete) scores = np.delete(scores, to_delete) keypoints = np.delete(keypoints, to_delete, axis=0) pool_indices = [i for i, label in enumerate(labels) if class_dict[label.item()] == 'pool'] pool_boxes = [boxes[i] for i in pool_indices] if pool_indices: for pool_index in pool_indices: pool_dict[pool_index] = [] elements_not_in_pool = [] for i, box in enumerate(boxes): assigned_to_pool = False if i in pool_indices or class_dict[labels[i]] in ['messageFlow', 'pool']: continue for j, pool_box in enumerate(pool_boxes): if proportion_inside(box, pool_box) > iou_threshold: pool_index = pool_indices[j] pool_dict[pool_index].append(i) assigned_to_pool = True break if not assigned_to_pool: if class_dict[labels[i]] not in ['messageFlow', 'lane', 'pool']: elements_not_in_pool.append(i) if len(elements_not_in_pool) > 1: new_elements_not_in_pool = [i for i in elements_not_in_pool if class_dict[labels[i]] not in ['messageFlow', 'lane', 'pool']] # Indices of relevant classes sequence_flow_index = list(class_dict.values()).index('sequenceFlow') message_flow_index = list(class_dict.values()).index('messageFlow') data_association_index = list(class_dict.values()).index('dataAssociation') if all(labels[i] in {sequence_flow_index, message_flow_index, data_association_index} for i in new_elements_not_in_pool): print('The new pool contains only sequenceFlow, messageFlow, or dataAssociation') elif len(new_elements_not_in_pool) > 1: new_pool_index = len(labels) box = calculate_pool_bounds(boxes, labels, new_elements_not_in_pool, None) boxes = np.append(boxes, [box], axis=0) labels = np.append(labels, list(class_dict.values()).index('pool')) scores = np.append(scores, 1.0) keypoints = np.append(keypoints, np.zeros((1, 2, 3)), axis=0) pool_dict[new_pool_index] = new_elements_not_in_pool print(f"Created a new pool index {new_pool_index} with elements: {new_elements_not_in_pool}") non_empty_pools = {k: v for k, v in pool_dict.items() if v} empty_pools = {k: v for k, v in pool_dict.items() if not v} pool_dict = {**non_empty_pools, **empty_pools} return pool_dict, boxes, labels, scores, keypoints def create_links(keypoints, boxes, labels, class_dict): """ Create links between elements based on keypoints. Parameters: - keypoints (array): Array of keypoints. - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - class_dict (dict): Dictionary mapping class names to indices. Returns: - list: List of links between elements. - list: List of best points for each link. """ best_points = [] links = [] for i in range(len(labels)): if labels[i] == list(class_dict.values()).index('sequenceFlow') or labels[i] == list(class_dict.values()).index('messageFlow'): closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels) closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels) if closest1 is not None and closest2 is not None: best_points.append([point_start, point_end]) links.append([closest1, closest2]) else: best_points.append([None, None]) links.append([None, None]) for i in range(len(labels)): if labels[i] == list(class_dict.values()).index('dataAssociation'): closest1, point_start = find_closest_object(keypoints[i][0], boxes, labels) closest2, point_end = find_closest_object(keypoints[i][1], boxes, labels) if closest1 is not None and closest2 is not None: best_points[i] = ([point_start, point_end]) links[i] = ([closest1, closest2]) return links, best_points def correction_labels(boxes, labels, class_dict, pool_dict, flow_links): """ Correct labels based on the relationships between elements and pools. Parameters: - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - class_dict (dict): Dictionary mapping class names to indices. - pool_dict (dict): Dictionary grouping elements by pool. - flow_links (list): List of links between elements. Returns: - array: Corrected labels. - list: Updated flow links. """ sequence_flow_index = list(class_dict.values()).index('sequenceFlow') message_flow_index = list(class_dict.values()).index('messageFlow') data_association_index = list(class_dict.values()).index('dataAssociation') data_object_index = list(class_dict.values()).index('dataObject') data_store_index = list(class_dict.values()).index('dataStore') message_event_index = list(class_dict.values()).index('messageEvent') senquence_flow_indexx = list(class_dict.values()).index('sequenceFlow') for pool_index, elements in pool_dict.items(): print(f"Pool {pool_index} contains elements: {elements}") # Check if the label sequenceFlow or messageFlow is good for i, (id1, id2) in enumerate(flow_links): if labels[i] in {sequence_flow_index, message_flow_index}: if id1 is not None and id2 is not None: # Check if each link is in the same pool if id1 in elements and id2 in elements: # Check if the link is between a dataObject or a dataStore if labels[id1] in {data_object_index, data_store_index} or labels[id2] in {data_object_index, data_store_index}: print('Change the link from sequenceFlow/messageFlow to dataAssociation') labels[i] = data_association_index else: continue elif id1 not in elements and id2 not in elements: continue else: print('Change the link from sequenceFlow to messageFlow') labels[i] = message_flow_index # Check if dataAssociation is connected to a dataObject for i, (id1, id2) in enumerate(flow_links): if labels[i] == data_association_index: if id1 is not None and id2 is not None: label1 = labels[id1] label2 = labels[id2] if data_object_index in {label1, label2} or data_store_index in {label1, label2}: continue elif message_event_index in {label1, label2}: print('Change the link from dataAssociation to messageFlow') labels[i] = message_flow_index else: print('Change the link from dataAssociation to sequenceFlow') labels[i] = senquence_flow_indexx return labels, flow_links def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref=['event', 'messageEvent'], mode="lower"): """ Identify outlier objects based on their area. Parameters: - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - class_dict (dict): Dictionary mapping class names to indices. - std_factor (float): Standard deviation factor for determining outliers. - element_ref (list): List of reference elements for calculating area statistics. - mode (str): Mode to identify outliers ('lower', 'upper', or 'both'). Returns: - list: Indices of kept objects that are not outliers. """ # Filter out the sizes of events, data objects, and message events event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref] event_boxes = [boxes[i] for i in event_indices] # Calculate the areas of these typical objects event_areas = np.array([(box[2] - box[0]) * (box[3] - box[1]) for box in event_boxes]) # Compute the mean and standard deviation for areas mean_area = np.mean(event_areas) std_area = np.std(event_areas) # Define thresholds for outliers area_lower_threshold = mean_area - std_factor * std_area area_upper_threshold = mean_area + std_factor * std_area # Identify indices of outliers and the ones to keep outlier_indices = [] kept_indices = [] if mode == "lower" or mode == 'both': # Check for objects that could be too small for idx, (box, label) in enumerate(zip(boxes, labels)): area = (box[2] - box[0]) * (box[3] - box[1]) if not (area_lower_threshold <= area): outlier_indices.append(idx) print(f"Element {idx} is an outlier with area {area} that is too small") else: kept_indices.append(idx) if mode == "upper" or mode == 'both': # Check for objects that could be too big for idx, (box, label) in enumerate(zip(boxes, labels)): if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'): kept_indices.append(idx) continue area = (box[2] - box[0]) * (box[3] - box[1]) if not (area_upper_threshold >= area): outlier_indices.append(idx) print(f"Element {idx} is an outlier with area {area} that is too big") else: kept_indices.append(idx) return kept_indices def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000): """ Perform final corrections on the predictions by deleting irrelevant or small pools and duplicate elements. Parameters: - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - scores (array): Array of confidence scores for each bounding box. - keypoints (array): Array of keypoints. - bpmn_id (list): List of BPMN IDs. - links (list): List of links between elements. - best_points (list): List of best points for each link. - pool_dict (dict): Dictionary grouping elements by pool. - limit_area (int): Minimum area threshold for pools. Returns: - tuple: Corrected arrays of boxes, labels, scores, keypoints, BPMN IDs, links, best points, and pool dictionary. """ # Delete pools that have only messageFlow on it delete_pool = [] for pool_index, elements in pool_dict.items(): # Find the position of the pool_index in the bpmn_id if pool_index in bpmn_id: position = bpmn_id.index(pool_index) else: continue if all([labels[i] in [list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('dataAssociation'), list(class_dict.values()).index('lane')] for i in elements]): if len(elements) > 0: delete_pool.append(position) print(f"Pool {pool_index} contains only arrow elements, deleting it") # Calculate the area of the pool if position < len(boxes): pool = boxes[position] area = (pool[2] - pool[0]) * (pool[3] - pool[1]) if len(pool_dict) > 1 and area < limit_area: delete_pool.append(position) print(f"Pool {pool_index} is too small, deleting it") if is_vertical(boxes[position]): delete_pool.append(position) print(f"Pool {position} is vertical, deleting it") delete_elements = [] # Check if there is an arrow that has the same links for i in range(len(labels)): for j in range(i + 1, len(labels)): if labels[i] == list(class_dict.values()).index('sequenceFlow') and labels[j] == list(class_dict.values()).index('sequenceFlow'): if links[i] == links[j]: print(f'Element {i} and {j} have the same links') if scores[i] > scores[j]: print('Delete element', j) delete_elements.append(j) else: print('Delete element', i) delete_elements.append(i) # Concatenate the delete_elements and the delete_pool delete_elements = delete_elements + delete_pool # Delete double value in delete_elements delete_elements = list(set(delete_elements)) boxes = np.delete(boxes, delete_elements, axis=0) labels = np.delete(labels, delete_elements) scores = np.delete(scores, delete_elements) keypoints = np.delete(keypoints, delete_elements, axis=0) links = np.delete(links, delete_elements, axis=0) best_points = [point for i, point in enumerate(best_points) if i not in delete_elements] for i in range(len(delete_pool)): # Find the bpmn_id of the pool pool_index = bpmn_id[delete_pool[i]] # Delete the pool_index in pool_dict del pool_dict[pool_index] bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements] # Also delete the element in the pool_dict for pool_index, elements in pool_dict.items(): pool_dict[pool_index] = [i for i in elements if i not in delete_elements] return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict def give_link_to_element(links, labels): """ Assign links to elements to create BPMN IDs for events. Parameters: - links (list): List of links between elements. - labels (array): Array of labels for each bounding box. Returns: - list: Updated list of links with assigned links for events. """ # Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event for i in range(len(links)): if labels[i] == list(class_dict.values()).index('sequenceFlow'): id1, id2 = links[i] if (id1 and id2) is not None: links[id1][1] = i links[id2][0] = i return links def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict): """ Generate a data dictionary containing image and prediction information. Parameters: - image (numpy.array): The input image. - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - scores (array): Array of confidence scores for each bounding box. - keypoints (array): Array of keypoints. - bpmn_id (list): List of BPMN IDs. - flow_links (list): List of links between elements. - best_points (list): List of best points for each link. - pool_dict (dict): Dictionary grouping elements by pool. Returns: - dict: Data dictionary containing all prediction information. """ idx = [] for i in range(len(labels)): idx.append(i) data = { 'image': image, 'idx': idx, 'boxes': boxes, 'labels': labels, 'scores': scores, 'keypoints': keypoints, 'links': flow_links, 'best_points': best_points, 'pool_dict': pool_dict, 'BPMN_id': bpmn_id, } return data def develop_prediction(boxes, labels, scores, keypoints, class_dict): """ Develop predictions by regrouping elements, creating links, and correcting labels. Parameters: - boxes (array): Array of bounding boxes. - labels (array): Array of labels for each bounding box. - scores (array): Array of confidence scores for each bounding box. - keypoints (array): Array of keypoints. - class_dict (dict): Dictionary mapping class names to indices. Returns: - tuple: Developed prediction components including boxes, labels, scores, keypoints, BPMN IDs, flow links, best points, and pool dictionary. """ pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict) bpmn_id, pool_dict = create_BPMN_id(labels, pool_dict) # Create links between elements flow_links, best_points = create_links(keypoints, boxes, labels, class_dict) # Correct the labels of some sequenceFlow that cross multiple pools labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links) # Give a link to event to allow the creation of the BPMN ID with start, intermediate, and end event flow_links = give_link_to_element(flow_links, labels) boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = last_correction( boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict ) return boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_threshold=0.5, resize=True, distance_treshold=15): """ Perform a full prediction by combining object and arrow models and generating data. Parameters: - model_object (torch.nn.Module): The object detection model. - model_arrow (torch.nn.Module): The arrow detection model. - image (torch.Tensor): The input image. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for non-maximum suppression. - resize (bool): Flag indicating whether to resize the image. - distance_treshold (int): Distance threshold for keypoint correction. Returns: - numpy.array, dict: The processed image and the data dictionary containing prediction information. """ model_object.eval() # Set the model to evaluation mode model_arrow.eval() # Set the model to evaluation mode # Load an image with torch.no_grad(): # Disable gradient calculation for inference _, objects_pred = object_prediction(model_object, image, score_threshold=score_threshold, iou_threshold=0.1) _, arrow_pred = arrow_prediction(model_arrow, image, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold) st.session_state.arrow_pred = arrow_pred boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred) boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction( boxes, labels, scores, keypoints, class_dict ) image = image.permute(1, 2, 0).cpu().numpy() image = (image * 255).astype(np.uint8) data = generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict) return image, data def evaluate_model_by_class(pred_boxes, true_boxes, pred_labels, true_labels, model_dict, iou_threshold=0.5): """ Evaluate the model's performance on a per-class basis. Parameters: - pred_boxes (array): Predicted bounding boxes. - true_boxes (array): Ground truth bounding boxes. - pred_labels (array): Predicted labels. - true_labels (array): Ground truth labels. - model_dict (dict): Dictionary mapping model labels to indices. - iou_threshold (float): IoU threshold for determining matches. Returns: - tuple: Precision, recall, and F1-score per class. """ # Initialize dictionaries to hold per-class counts class_tp = {cls: 0 for cls in model_dict.values()} class_fp = {cls: 0 for cls in model_dict.values()} class_fn = {cls: 0 for cls in model_dict.values()} # Track which true boxes have been matched matched = [False] * len(true_boxes) # Check each prediction against true boxes for pred_box, pred_label in zip(pred_boxes, pred_labels): match_found = False for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)): if not matched[idx] and pred_label == true_label: if iou(np.array(pred_box), np.array(true_box)) >= iou_threshold: class_tp[model_dict[pred_label]] += 1 matched[idx] = True match_found = True break if not match_found: class_fp[model_dict[pred_label]] += 1 # Count false negatives for idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)): if not matched[idx]: class_fn[model_dict[true_label]] += 1 # Calculate precision, recall, and F1-score per class class_precision = {} class_recall = {} class_f1_score = {} for cls in model_dict.values(): precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0 recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0 f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 class_precision[cls] = precision class_recall[cls] = recall class_f1_score[cls] = f1_score return class_precision, class_recall, class_f1_score def keypoints_measure(pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold=5): """ Measure the accuracy of predicted keypoints compared to true keypoints. Parameters: - pred_boxes (array): Predicted bounding boxes. - pred_box (array): Single predicted bounding box. - true_boxes (array): Ground truth bounding boxes. - true_box (array): Single ground truth bounding box. - pred_keypoints (array): Predicted keypoints. - true_keypoints (array): Ground truth keypoints. - distance_threshold (int): Distance threshold for considering a keypoint match. Returns: - tuple: Number of correct keypoints and whether the keypoints are reverted. """ result = 0 reverted = False # Find the position of keypoints in the list idx = np.where(pred_boxes == pred_box)[0][0] idx2 = np.where(true_boxes == true_box)[0][0] keypoint1_pred = pred_keypoints[idx][0] keypoint1_true = true_keypoints[idx2][0] keypoint2_pred = pred_keypoints[idx][1] keypoint2_true = true_keypoints[idx2][1] distance1 = np.linalg.norm(keypoint1_pred[:2] - keypoint1_true[:2]) distance2 = np.linalg.norm(keypoint2_pred[:2] - keypoint2_true[:2]) distance3 = np.linalg.norm(keypoint1_pred[:2] - keypoint2_true[:2]) distance4 = np.linalg.norm(keypoint2_pred[:2] - keypoint1_true[:2]) if distance1 < distance_threshold: result += 1 if distance2 < distance_threshold: result += 1 if distance3 < distance_threshold or distance4 < distance_threshold: reverted = True return result, reverted def evaluate_single_image(pred_boxes, true_boxes, pred_labels, true_labels, pred_keypoints, true_keypoints, iou_threshold=0.5, distance_threshold=5): """ Evaluate a single image's predictions against the ground truth. Parameters: - pred_boxes (array): Predicted bounding boxes. - true_boxes (array): Ground truth bounding boxes. - pred_labels (array): Predicted labels. - true_labels (array): Ground truth labels. - pred_keypoints (array): Predicted keypoints. - true_keypoints (array): Ground truth keypoints. - iou_threshold (float): IoU threshold for determining matches. - distance_threshold (int): Distance threshold for considering a keypoint match. Returns: - tuple: True positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count. """ tp, fp, fn = 0, 0, 0 key_t, key_f = 0, 0 labels_t, labels_f = 0, 0 reverted_tot = 0 matched_true_boxes = set() for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)): match_found = False for true_idx, true_box in enumerate(true_boxes): if true_idx in matched_true_boxes: continue iou_val = iou(pred_box, true_box) if iou_val >= iou_threshold: if true_keypoints is not None and pred_keypoints is not None: key_result, reverted = keypoints_measure( pred_boxes, pred_box, true_boxes, true_box, pred_keypoints, true_keypoints, distance_threshold ) key_t += key_result key_f += 2 - key_result if reverted: reverted_tot += 1 match_found = True matched_true_boxes.add(true_idx) if pred_label == true_labels[true_idx]: labels_t += 1 else: labels_f += 1 tp += 1 break if not match_found: fp += 1 fn = len(true_boxes) - tp return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted_tot def pred_4_evaluation(model, loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'): """ Evaluate the model on a dataset using predictions for evaluation. Parameters: - model (torch.nn.Module): The model to evaluate. - loader (torch.utils.data.DataLoader): DataLoader for the dataset. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for determining matches. - distance_threshold (int): Distance threshold for considering a keypoint match. - key_correction (bool): Whether to apply keypoint correction. - model_type (str): Type of model ('object' or 'arrow'). Returns: - tuple: Evaluation results including true positives, false positives, false negatives, correct labels, incorrect labels, correct keypoints, incorrect keypoints, and reverted keypoints count. """ model.eval() tp, fp, fn = 0, 0, 0 labels_t, labels_f = 0, 0 key_t, key_f = 0, 0 reverted = 0 with torch.no_grad(): for images, targets_im in tqdm(loader, desc="Testing... "): # Wrap the loader with tqdm devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') images = [image.to(devices) for image in images] targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im] predictions = model(images) for target, prediction in zip(targets, predictions): true_boxes = target['boxes'].cpu().numpy() true_labels = target['labels'].cpu().numpy() if 'keypoints' in target: true_keypoints = target['keypoints'].cpu().numpy() pred_boxes = prediction['boxes'].cpu().numpy() scores = prediction['scores'].cpu().numpy() pred_labels = prediction['labels'].cpu().numpy() if 'keypoints' in prediction: pred_keypoints = prediction['keypoints'].cpu().numpy() selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold) pred_boxes = pred_boxes[selected_boxes] scores = scores[selected_boxes] pred_labels = pred_labels[selected_boxes] if 'keypoints' in prediction: pred_keypoints = pred_keypoints[selected_boxes] filtered_boxes = [] filtered_labels = [] filtered_keypoints = [] if 'keypoints' not in prediction: # Create a list of zeros of length equal to the number of boxes pred_keypoints = [np.zeros((2, 3)) for _ in range(len(pred_boxes))] for box, score, label, keypoints in zip(pred_boxes, scores, pred_labels, pred_keypoints): if score >= score_threshold: filtered_boxes.append(box) filtered_labels.append(label) if 'keypoints' in prediction: filtered_keypoints.append(keypoints) if key_correction and ('keypoints' in prediction): filtered_keypoints = keypoint_correction(filtered_keypoints, filtered_boxes, filtered_labels) if 'keypoints' not in target: filtered_keypoints = None true_keypoints = None tp_img, fp_img, fn_img, labels_t_img, labels_f_img, key_t_img, key_f_img, reverted_img = evaluate_single_image( filtered_boxes, true_boxes, filtered_labels, true_labels, filtered_keypoints, true_keypoints, iou_threshold, distance_threshold ) tp += tp_img fp += fp_img fn += fn_img labels_t += labels_t_img labels_f += labels_f_img key_t += key_t_img key_f += key_f_img reverted += reverted_img return tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted def main_evaluation(model, test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=5, key_correction=True, model_type='object'): """ Main function to evaluate the model on the test dataset. Parameters: - model (torch.nn.Module): The model to evaluate. - test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for determining matches. - distance_threshold (int): Distance threshold for considering a keypoint match. - key_correction (bool): Whether to apply keypoint correction. - model_type (str): Type of model ('object' or 'arrow'). Returns: - tuple: Precision, recall, F1-score, key accuracy, and reverted accuracy. """ tp, fp, fn, labels_t, labels_f, key_t, key_f, reverted = pred_4_evaluation( model, test_loader, score_threshold, iou_threshold, distance_threshold, key_correction, model_type ) labels_precision = labels_t / (labels_t + labels_f) if (labels_t + labels_f) > 0 else 0 precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 if model_type == 'arrow': key_accuracy = key_t / (key_t + key_f) if (key_t + key_f) > 0 else 0 reverted_accuracy = reverted / (key_t + key_f) if (key_t + key_f) > 0 else 0 else: key_accuracy = 0 reverted_accuracy = 0 return labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy def evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold=0.5): """ Evaluate a single image's predictions on a per-class basis. Parameters: - pred_boxes (array): Predicted bounding boxes. - true_boxes (array): Ground truth bounding boxes. - pred_labels (array): Predicted labels. - true_labels (array): Ground truth labels. - class_tp (dict): Dictionary of true positive counts per class. - class_fp (dict): Dictionary of false positive counts per class. - class_fn (dict): Dictionary of false negative counts per class. - model_dict (dict): Dictionary mapping model labels to indices. - iou_threshold (float): IoU threshold for determining matches. """ matched_true_boxes = set() for pred_idx, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)): match_found = False for true_idx, (true_box, true_label) in enumerate(zip(true_boxes, true_labels)): if true_idx in matched_true_boxes: continue if pred_label == true_label and iou(np.array(pred_box), np.array(true_box)) >= iou_threshold: class_tp[model_dict[pred_label]] += 1 matched_true_boxes.add(true_idx) match_found = True break if not match_found: class_fp[model_dict[pred_label]] += 1 for idx, true_label in enumerate(true_labels): if idx not in matched_true_boxes: class_fn[model_dict[true_label]] += 1 def pred_4_evaluation_per_class(model, loader, score_threshold=0.5, iou_threshold=0.5): """ Generate predictions for evaluation on a per-class basis. Parameters: - model (torch.nn.Module): The model to evaluate. - loader (torch.utils.data.DataLoader): DataLoader for the dataset. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for determining matches. Yields: - tuple: Predicted and true boxes and labels for each batch. """ model.eval() with torch.no_grad(): for images, targets_im in tqdm(loader, desc="Testing... "): devices = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') images = [image.to(devices) for image in images] targets = [{k: v.clone().detach().to(devices) for k, v in t.items()} for t in targets_im] predictions = model(images) for target, prediction in zip(targets, predictions): true_boxes = target['boxes'].cpu().numpy() true_labels = target['labels'].cpu().numpy() pred_boxes = prediction['boxes'].cpu().numpy() scores = prediction['scores'].cpu().numpy() pred_labels = prediction['labels'].cpu().numpy() idx = np.where(scores > score_threshold)[0] pred_boxes = pred_boxes[idx] scores = scores[idx] pred_labels = pred_labels[idx] selected_boxes = non_maximum_suppression(pred_boxes, scores, iou_threshold=iou_threshold) pred_boxes = pred_boxes[selected_boxes] scores = scores[selected_boxes] pred_labels = pred_labels[selected_boxes] yield pred_boxes, true_boxes, pred_labels, true_labels def evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5): """ Evaluate the model's performance on a per-class basis for the entire dataset. Parameters: - model (torch.nn.Module): The model to evaluate. - test_loader (torch.utils.data.DataLoader): DataLoader for the test dataset. - model_dict (dict): Dictionary mapping model labels to indices. - score_threshold (float): Score threshold for filtering predictions. - iou_threshold (float): IoU threshold for determining matches. Returns: - tuple: Precision, recall, and F1-score per class. """ class_tp = {cls: 0 for cls in model_dict.values()} class_fp = {cls: 0 for cls in model_dict.values()} class_fn = {cls: 0 for cls in model_dict.values()} for pred_boxes, true_boxes, pred_labels, true_labels in pred_4_evaluation_per_class(model, test_loader, score_threshold, iou_threshold): evaluate_model_by_class_single_image(pred_boxes, true_boxes, pred_labels, true_labels, class_tp, class_fp, class_fn, model_dict, iou_threshold) class_precision = {} class_recall = {} class_f1_score = {} for cls in model_dict.values(): precision = class_tp[cls] / (class_tp[cls] + class_fp[cls]) if class_tp[cls] + class_fp[cls] > 0 else 0 recall = class_tp[cls] / (class_tp[cls] + class_fn[cls]) if class_tp[cls] + class_fn[cls] > 0 else 0 f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 class_precision[cls] = precision class_recall[cls] = recall class_f1_score[cls] = f1_score return class_precision, class_recall, class_f1_score