from torchvision.models.detection import keypointrcnn_resnet50_fpn from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights import random import torch from torch.utils.data import Dataset import torchvision.transforms.functional as F import numpy as np from torch.utils.data.dataloader import default_collate import cv2 import matplotlib.pyplot as plt from torch.utils.data import DataLoader, Subset, ConcatDataset import streamlit as st """object_dict = { 0: 'background', 1: 'task', 2: 'exclusiveGateway', 3: 'eventBasedGateway', 4: 'event', 5: 'messageEvent', 6: 'timerEvent', 7: 'dataObject', 8: 'dataStore', 9: 'pool', 10: 'lane', } arrow_dict = { 0: 'background', 1: 'sequenceFlow', 2: 'dataAssociation', 3: 'messageFlow', } class_dict = { 0: 'background', 1: 'task', 2: 'exclusiveGateway', 3: 'eventBasedGateway', 4: 'event', 5: 'messageEvent', 6: 'timerEvent', 7: 'dataObject', 8: 'dataStore', 9: 'pool', 10: 'lane', 11: 'sequenceFlow', 12: 'dataAssociation', 13: 'messageFlow', }""" object_dict = { 0: 'background', 1: 'task', 2: 'exclusiveGateway', 3: 'event', 4: 'parallelGateway', 5: 'messageEvent', 6: 'pool', 7: 'lane', 8: 'dataObject', 9: 'dataStore', 10: 'subProcess', 11: 'eventBasedGateway', 12: 'timerEvent', } arrow_dict = { 0: 'background', 1: 'sequenceFlow', 2: 'dataAssociation', 3: 'messageFlow', } class_dict = { 0: 'background', 1: 'task', 2: 'exclusiveGateway', 3: 'event', 4: 'parallelGateway', 5: 'messageEvent', 6: 'pool', 7: 'lane', 8: 'dataObject', 9: 'dataStore', 10: 'subProcess', 11: 'eventBasedGateway', 12: 'timerEvent', 13: 'sequenceFlow', 14: 'dataAssociation', 15: 'messageFlow', } def is_inside(box1, box2): """Check if the center of box1 is inside box2.""" x_center = (box1[0] + box1[2]) / 2 y_center = (box1[1] + box1[3]) / 2 return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3] def is_vertical(box): """Determine if the text in the bounding box is vertically aligned.""" width = box[2] - box[0] height = box[3] - box[1] return (height > 2*width) def rescale_boxes(scale, boxes): for i in range(len(boxes)): boxes[i] = [boxes[i][0]*scale, boxes[i][1]*scale, boxes[i][2]*scale, boxes[i][3]*scale] return boxes def iou(box1, box2): # Calcule l'intersection des deux boîtes englobantes inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])] inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1]) # Calcule l'union des deux boîtes englobantes box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) union_area = box1_area + box2_area - inter_area return inter_area / union_area def proportion_inside(box1, box2): # Calculate the areas of both boxes box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) # Determine the bigger and smaller boxes if box1_area > box2_area: big_box = box1 small_box = box2 else: big_box = box2 small_box = box1 # Calculate the intersection of the two bounding boxes inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])] inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1]) # Calculate the proportion of the smaller box inside the bigger box if (small_box[2] - small_box[0]) * (small_box[3] - small_box[1]) == 0: return 0 proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1])) # Ensure the proportion is at most 100% return min(proportion, 1.0) def resize_boxes(boxes, original_size, target_size): """ Resizes bounding boxes according to a new image size. Parameters: - boxes (np.array): The original bounding boxes as a numpy array of shape [N, 4]. - original_size (tuple): The original size of the image as (width, height). - target_size (tuple): The desired size to resize the image to as (width, height). Returns: - np.array: The resized bounding boxes as a numpy array of shape [N, 4]. """ orig_width, orig_height = original_size target_width, target_height = target_size # Calculate the ratios for width and height width_ratio = target_width / orig_width height_ratio = target_height / orig_height # Apply the ratios to the bounding boxes boxes[:, 0] *= width_ratio boxes[:, 1] *= height_ratio boxes[:, 2] *= width_ratio boxes[:, 3] *= height_ratio return boxes def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray: """ Resize keypoints based on the original and target dimensions of an image. Parameters: - keypoints (np.ndarray): The array of keypoints, where each keypoint is represented by its (x, y) coordinates. - original_size (tuple): The width and height of the original image (width, height). - target_size (tuple): The width and height of the target image (width, height). Returns: - np.ndarray: The resized keypoints. Explanation: The function calculates the ratio of the target dimensions to the original dimensions. It then applies these ratios to the x and y coordinates of each keypoint to scale them appropriately to the target image size. """ orig_width, orig_height = original_size target_width, target_height = target_size # Calculate the ratios for width and height scaling width_ratio = target_width / orig_width height_ratio = target_height / orig_height # Apply the scaling ratios to the x and y coordinates of each keypoint keypoints[:, 0] *= width_ratio # Scale x coordinates keypoints[:, 1] *= height_ratio # Scale y coordinates return keypoints def write_results(name_model,metrics_list,start_epoch): with open('./results/'+ name_model+ '.txt', 'w') as f: for i in range(len(metrics_list[0])): f.write(f"{i+1+start_epoch},{metrics_list[0][i]},{metrics_list[1][i]},{metrics_list[2][i]},{metrics_list[3][i]},{metrics_list[4][i]},{metrics_list[5][i]},{metrics_list[6][i]},{metrics_list[7][i]},{metrics_list[8][i]},{metrics_list[9][i]} \n") def find_other_keypoint(idx, keypoints, boxes): box = boxes[idx] key1,key2 = keypoints[idx] x1, y1, x2, y2 = box center = ((x1 + x2) // 2, (y1 + y2) // 2) average_keypoint = (key1 + key2) // 2 #find the opposite keypoint to the center if average_keypoint[0] < center[0]: x = center[0] + abs(center[0] - average_keypoint[0]) else: x = center[0] - abs(center[0] - average_keypoint[0]) if average_keypoint[1] < center[1]: y = center[1] + abs(center[1] - average_keypoint[1]) else: y = center[1] - abs(center[1] - average_keypoint[1]) return x, y, average_keypoint[0], average_keypoint[1] def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5): """ Filters overlapping boxes based on the Intersection over Union (IoU) metric, keeping only the boxes with the highest scores. Parameters: - boxes (np.ndarray): Array of bounding boxes with shape (N, 4), where each row contains [x_min, y_min, x_max, y_max]. - scores (np.ndarray): Array of scores for each box, reflecting the confidence of detection. - labels (np.ndarray): Array of labels corresponding to each box. - keypoints (np.ndarray): Array of keypoints associated with each box. - iou_threshold (float): Threshold for IoU above which a box is considered overlapping. Returns: - tuple: Filtered boxes, scores, labels, and keypoints. """ # Calculate the area of each bounding box to use in IoU calculation. areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) # Sort the indices of the boxes based on their scores in descending order. order = scores.argsort()[::-1] keep = [] # List to store indices of boxes to keep. while order.size > 0: # Take the first index (highest score) from the sorted list. i = order[0] keep.append(i) # Add this index to 'keep' list. # Compute the coordinates of the intersection rectangle. xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0]) yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1]) xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2]) yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3]) # Compute the area of the intersection rectangle. w = np.maximum(0.0, xx2 - xx1) h = np.maximum(0.0, yy2 - yy1) inter = w * h # Calculate IoU and find boxes with IoU less than the threshold to keep. iou = inter / (areas[i] + areas[order[1:]] - inter) inds = np.where(iou <= iou_threshold)[0] # Update the list of box indices to consider in the next iteration. order = order[inds + 1] # Skip the first element since it's already included in 'keep'. # Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return. boxes = boxes[keep] scores = scores[keep] labels = labels[keep] keypoints = keypoints[keep] return boxes, scores, labels, keypoints def draw_annotations(image, target=None, prediction=None, full_prediction=None, text_predictions=None, model_dict=class_dict, draw_keypoints=False, draw_boxes=False, draw_text=False, draw_links=False, draw_twins=False, write_class=False, write_score=False, write_text=False, write_idx=False, score_threshold=0.4, keypoints_correction=False, only_print=None, axis=False, return_image=False, new_size=(1333,800), resize=False): """ Draws annotations on images including bounding boxes, keypoints, links, and text. Parameters: - image (np.array): The image on which annotations will be drawn. - target (dict): Ground truth data containing boxes, labels, etc. - prediction (dict): Prediction data from a model. - full_prediction (dict): Additional detailed prediction data, potentially including relationships. - text_predictions (tuple): OCR text predictions containing bounding boxes and texts. - model_dict (dict): Mapping from class IDs to class names. - draw_keypoints (bool): Flag to draw keypoints. - draw_boxes (bool): Flag to draw bounding boxes. - draw_text (bool): Flag to draw text annotations. - draw_links (bool): Flag to draw links between annotations. - draw_twins (bool): Flag to draw twins keypoints. - write_class (bool): Flag to write class names near the annotations. - write_score (bool): Flag to write scores near the annotations. - write_text (bool): Flag to write OCR recognized text. - score_threshold (float): Threshold for scores above which annotations will be drawn. - only_print (str): Specific class name to filter annotations by. - resize (bool): Whether to resize annotations to fit the image size. """ # Convert image to RGB (if not already in that format) if prediction is None: image = image.squeeze(0).permute(1, 2, 0).cpu().numpy() image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image_copy = image.copy() scale = max(image.shape[0], image.shape[1]) / 1000 # Function to draw bounding boxes and keypoints def draw(data,is_prediction=False): """ Helper function to draw annotations based on provided data. """ for i in range(len(data['boxes'])): if is_prediction: box = data['boxes'][i].tolist() x1, y1, x2, y2 = box if resize: x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0] score = data['scores'][i].item() if score < score_threshold: continue else: box = data['boxes'][i].tolist() x1, y1, x2, y2 = box if draw_boxes: if only_print is not None: if data['labels'][i] != list(model_dict.values()).index(only_print): continue cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0) if is_prediction else (0, 0, 0), int(2*scale)) if is_prediction and write_score: cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2) if write_class and 'labels' in data: class_id = data['labels'][i].item() cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2) if write_idx: cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2) # Draw keypoints if available if draw_keypoints and 'keypoints' in data: if is_prediction and keypoints_correction: for idx, (key1, key2) in enumerate(data['keypoints']): if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'), list(model_dict.values()).index('messageFlow'), list(model_dict.values()).index('dataAssociation')]: continue # Calculate the Euclidean distance between the two keypoints distance = np.linalg.norm(key1[:2] - key2[:2]) if distance < 5: x_new,y_new, x,y = find_other_keypoint(idx, data['keypoints'], data['boxes']) data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1]) data['keypoints'][idx][1] = torch.tensor([x, y,1]) print("keypoint has been changed") for i in range(len(data['keypoints'])): kp = data['keypoints'][i] for j in range(kp.shape[0]): if is_prediction and data['labels'][i] != list(model_dict.values()).index('sequenceFlow') and data['labels'][i] != list(model_dict.values()).index('messageFlow') and data['labels'][i] != list(model_dict.values()).index('dataAssociation'): continue if is_prediction: score = data['scores'][i] if score < score_threshold: continue x,y,v = np.array(kp[j]) if resize: x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0] if j == 0: cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1) else: cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1) # Draw text predictions if available if (draw_text or write_text) and text_predictions is not None: for i in range(len(text_predictions[0])): x1, y1, x2, y2 = text_predictions[0][i] text = text_predictions[1][i] if resize: x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0] if draw_text: cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale)) if write_text: cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2) def draw_with_links(full_prediction): '''Draws links between objects based on the full prediction data.''' #check if keypoints detected are the same if draw_twins and full_prediction is not None: # Pre-calculate indices for performance circle_color = (0, 255, 0) # Green color for the circle circle_radius = int(10 * scale) # Circle radius scaled by image scale for idx, (key1, key2) in enumerate(full_prediction['keypoints']): if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'), list(model_dict.values()).index('messageFlow'), list(model_dict.values()).index('dataAssociation')]: continue # Calculate the Euclidean distance between the two keypoints distance = np.linalg.norm(key1[:2] - key2[:2]) if distance < 10: x_new,y_new, x,y = find_other_keypoint(idx,full_prediction) cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1) cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1) # Draw links between objects if draw_links==True and full_prediction is not None: for i, (start_idx, end_idx) in enumerate(full_prediction['links']): if start_idx is None or end_idx is None: continue start_box = full_prediction['boxes'][start_idx] end_box = full_prediction['boxes'][end_idx] current_box = full_prediction['boxes'][i] # Calculate the center of each bounding box start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2) end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2) current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2) # Draw a line between the centers of the connected objects cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale)) cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale)) i+=1 # Draw GT annotations if target is not None: draw(target, is_prediction=False) # Draw predictions if prediction is not None: #prediction = prediction[0] draw(prediction, is_prediction=True) # Draw links with full predictions if full_prediction is not None: draw_with_links(full_prediction) # Display the image image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB) plt.figure(figsize=(12, 12)) plt.imshow(image_copy) if axis==False: plt.axis('off') plt.show() if return_image: return image_copy def find_closest_object(keypoint, boxes, labels): """ Find the closest object to a keypoint based on their proximity. Parameters: - keypoint (numpy.ndarray): The coordinates of the keypoint. - boxes (numpy.ndarray): The bounding boxes of the objects. Returns: - int or None: The index of the closest object to the keypoint, or None if no object is found. """ closest_object_idx = None best_point = None min_distance = float('inf') # Iterate over each bounding box for i, box in enumerate(boxes): if labels[i] in [list(class_dict.values()).index('sequenceFlow'), list(class_dict.values()).index('messageFlow'), list(class_dict.values()).index('dataAssociation'), #list(class_dict.values()).index('pool'), list(class_dict.values()).index('lane')]: continue x1, y1, x2, y2 = box top = ((x1+x2)/2, y1) bottom = ((x1+x2)/2, y2) left = (x1, (y1+y2)/2) right = (x2, (y1+y2)/2) points = [left, top , right, bottom] pos_dict = {0:'left', 1:'top', 2:'right', 3:'bottom'} # Calculate the distance between the keypoint and the center of the bounding box for pos, (point) in enumerate(points): distance = np.linalg.norm(keypoint[:2] - point) # Update the closest object index if this object is closer if distance < min_distance: min_distance = distance closest_object_idx = i best_point = pos_dict[pos] return closest_object_idx, best_point def error(text='There is an error in the detection'): st.error(text, icon="🚨") def warning(text='Some element are maybe not detected, verify the results, try to modify the parameters or try to add it in the method and style step.'): st.warning(text, icon="⚠️")