import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchvision.transforms.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from torch.utils.data.dataloader import default_collate
from torchvision.models.detection import (KeypointRCNN_ResNet50_FPN_Weights,
                                          keypointrcnn_resnet50_fpn)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
"""object_dict = {
0: 'background',
1: 'task',
2: 'exclusiveGateway',
3: 'eventBasedGateway',
4: 'event',
5: 'messageEvent',
6: 'timerEvent',
7: 'dataObject',
8: 'dataStore',
9: 'pool',
10: 'lane',
}
arrow_dict = {
0: 'background',
1: 'sequenceFlow',
2: 'dataAssociation',
3: 'messageFlow',
}
class_dict = {
0: 'background',
1: 'task',
2: 'exclusiveGateway',
3: 'eventBasedGateway',
4: 'event',
5: 'messageEvent',
6: 'timerEvent',
7: 'dataObject',
8: 'dataStore',
9: 'pool',
10: 'lane',
11: 'sequenceFlow',
12: 'dataAssociation',
13: 'messageFlow',
}"""
object_dict = {
0: 'background',
1: 'task',
2: 'exclusiveGateway',
3: 'event',
4: 'parallelGateway',
5: 'messageEvent',
6: 'pool',
7: 'lane',
8: 'dataObject',
9: 'dataStore',
10: 'subProcess',
11: 'eventBasedGateway',
12: 'timerEvent',
}
arrow_dict = {
0: 'background',
1: 'sequenceFlow',
2: 'dataAssociation',
3: 'messageFlow',
}
class_dict = {
0: 'background',
1: 'task',
2: 'exclusiveGateway',
3: 'event',
4: 'parallelGateway',
5: 'messageEvent',
6: 'pool',
7: 'lane',
8: 'dataObject',
9: 'dataStore',
10: 'subProcess',
11: 'eventBasedGateway',
12: 'timerEvent',
13: 'sequenceFlow',
14: 'dataAssociation',
15: 'messageFlow',
}
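# The drawing code below repeatedly recovers a class id with
# list(model_dict.values()).index(name). A small helper makes that pattern
# explicit (an illustrative convenience added here, not part of the original API):
def label_id(name, mapping=class_dict):
    """Return the integer id of a class name in an {id: name} mapping."""
    return list(mapping.values()).index(name)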
def is_inside(box1, box2):
"""Check if the center of box1 is inside box2."""
x_center = (box1[0] + box1[2]) / 2
y_center = (box1[1] + box1[3]) / 2
return box2[0] <= x_center <= box2[2] and box2[1] <= y_center <= box2[3]
def is_vertical(box):
"""Determine if the text in the bounding box is vertically aligned."""
width = box[2] - box[0]
height = box[3] - box[1]
return (height > 2*width)
def rescale_boxes(scale, boxes):
    """Multiply every coordinate of each [x1, y1, x2, y2] box by `scale`."""
    for i in range(len(boxes)):
        boxes[i] = [boxes[i][0] * scale,
                    boxes[i][1] * scale,
                    boxes[i][2] * scale,
                    boxes[i][3] * scale]
    return boxes
def iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two [x1, y1, x2, y2] boxes."""
    # Compute the intersection of the two bounding boxes
    inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
    inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
    # Compute the union of the two bounding boxes
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area
    # Guard against degenerate (zero-area) boxes
    if union_area == 0:
        return 0
    return inter_area / union_area
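# Hedged worked example: two unit squares offset by half a width intersect in a
# 0.5 x 1.0 strip, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.
# >>> iou([0, 0, 1, 1], [0.5, 0, 1.5, 1])
# 0.3333333333333333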
def proportion_inside(box1, box2):
# Calculate the areas of both boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
# Determine the bigger and smaller boxes
if box1_area > box2_area:
big_box = box1
small_box = box2
else:
big_box = box2
small_box = box1
# Calculate the intersection of the two bounding boxes
inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
    # Calculate the proportion of the smaller box inside the bigger box
    small_area = (small_box[2] - small_box[0]) * (small_box[3] - small_box[1])
    if small_area == 0:
        return 0
    proportion = inter_area / small_area
# Ensure the proportion is at most 100%
return min(proportion, 1.0)
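# Illustrative check (values are hypothetical): a 1x1 box whose right half lies
# inside a larger box is 50% contained.
# >>> proportion_inside([0.5, 0, 1.5, 1], [1, 0, 3, 2])
# 0.5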
def resize_boxes(boxes, original_size, target_size):
"""
Resizes bounding boxes according to a new image size.
Parameters:
- boxes (np.array): The original bounding boxes as a numpy array of shape [N, 4].
- original_size (tuple): The original size of the image as (width, height).
- target_size (tuple): The desired size to resize the image to as (width, height).
Returns:
- np.array: The resized bounding boxes as a numpy array of shape [N, 4].
"""
orig_width, orig_height = original_size
target_width, target_height = target_size
# Calculate the ratios for width and height
width_ratio = target_width / orig_width
height_ratio = target_height / orig_height
# Apply the ratios to the bounding boxes
boxes[:, 0] *= width_ratio
boxes[:, 1] *= height_ratio
boxes[:, 2] *= width_ratio
boxes[:, 3] *= height_ratio
return boxes
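# Usage sketch (note: the in-place `*=` above requires a float array; an
# integer dtype would raise a numpy casting error):
# >>> b = np.array([[10., 20., 30., 40.]])
# >>> resize_boxes(b, original_size=(100, 100), target_size=(200, 50))
# array([[20., 10., 60., 20.]])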
def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: tuple) -> np.ndarray:
"""
Resize keypoints based on the original and target dimensions of an image.
Parameters:
- keypoints (np.ndarray): The array of keypoints, where each keypoint is represented by its (x, y) coordinates.
- original_size (tuple): The width and height of the original image (width, height).
- target_size (tuple): The width and height of the target image (width, height).
Returns:
- np.ndarray: The resized keypoints.
Explanation:
The function calculates the ratio of the target dimensions to the original dimensions.
It then applies these ratios to the x and y coordinates of each keypoint to scale them
appropriately to the target image size.
"""
orig_width, orig_height = original_size
target_width, target_height = target_size
# Calculate the ratios for width and height scaling
width_ratio = target_width / orig_width
height_ratio = target_height / orig_height
# Apply the scaling ratios to the x and y coordinates of each keypoint
keypoints[:, 0] *= width_ratio # Scale x coordinates
keypoints[:, 1] *= height_ratio # Scale y coordinates
return keypoints
def write_results(name_model, metrics_list, start_epoch):
    """Write per-epoch metrics to ./results/<name_model>.txt, one comma-separated line per epoch."""
    with open('./results/' + name_model + '.txt', 'w') as f:
        for i in range(len(metrics_list[0])):
            values = ','.join(str(metric[i]) for metric in metrics_list)
            f.write(f"{i + 1 + start_epoch},{values}\n")
def find_other_keypoint(idx, keypoints, boxes):
    """Mirror the midpoint of box `idx`'s keypoint pair through the box centre,
    returning (x_mirrored, y_mirrored, x_avg, y_avg)."""
    box = boxes[idx]
    key1, key2 = keypoints[idx]
x1, y1, x2, y2 = box
center = ((x1 + x2) // 2, (y1 + y2) // 2)
average_keypoint = (key1 + key2) // 2
#find the opposite keypoint to the center
if average_keypoint[0] < center[0]:
x = center[0] + abs(center[0] - average_keypoint[0])
else:
x = center[0] - abs(center[0] - average_keypoint[0])
if average_keypoint[1] < center[1]:
y = center[1] + abs(center[1] - average_keypoint[1])
else:
y = center[1] - abs(center[1] - average_keypoint[1])
return x, y, average_keypoint[0], average_keypoint[1]
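# Geometric intuition (hypothetical values): when an arrow's two keypoints
# collapse onto one spot, their midpoint is mirrored through the box centre to
# synthesize the missing endpoint on the opposite side. For example, with box
# [0, 0, 100, 100] and both keypoints near (11, 10), the mirrored point lands
# near (89, 90).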
def filter_overlap_boxes(boxes, scores, labels, keypoints, iou_threshold=0.5):
"""
Filters overlapping boxes based on the Intersection over Union (IoU) metric, keeping only the boxes with the highest scores.
Parameters:
- boxes (np.ndarray): Array of bounding boxes with shape (N, 4), where each row contains [x_min, y_min, x_max, y_max].
- scores (np.ndarray): Array of scores for each box, reflecting the confidence of detection.
- labels (np.ndarray): Array of labels corresponding to each box.
- keypoints (np.ndarray): Array of keypoints associated with each box.
- iou_threshold (float): Threshold for IoU above which a box is considered overlapping.
Returns:
- tuple: Filtered boxes, scores, labels, and keypoints.
"""
# Calculate the area of each bounding box to use in IoU calculation.
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# Sort the indices of the boxes based on their scores in descending order.
order = scores.argsort()[::-1]
keep = [] # List to store indices of boxes to keep.
while order.size > 0:
# Take the first index (highest score) from the sorted list.
i = order[0]
keep.append(i) # Add this index to 'keep' list.
# Compute the coordinates of the intersection rectangle.
xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
# Compute the area of the intersection rectangle.
w = np.maximum(0.0, xx2 - xx1)
h = np.maximum(0.0, yy2 - yy1)
inter = w * h
# Calculate IoU and find boxes with IoU less than the threshold to keep.
iou = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(iou <= iou_threshold)[0]
# Update the list of box indices to consider in the next iteration.
order = order[inds + 1] # Skip the first element since it's already included in 'keep'.
# Use the indices in 'keep' to select the boxes, scores, labels, and keypoints to return.
boxes = boxes[keep]
scores = scores[keep]
labels = labels[keep]
keypoints = keypoints[keep]
return boxes, scores, labels, keypoints
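# Minimal sanity check (hypothetical arrays): two heavily overlapping boxes
# collapse to the higher-scoring one, while the far-away box survives.
# >>> boxes = np.array([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
# >>> scores = np.array([0.9, 0.8, 0.7])
# >>> labels = np.array([1, 1, 2])
# >>> kps = np.zeros((3, 2, 3))
# >>> filter_overlap_boxes(boxes, scores, labels, kps, iou_threshold=0.5)[0].shape
# (2, 4)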
def draw_annotations(image,
target=None,
prediction=None,
full_prediction=None,
text_predictions=None,
model_dict=class_dict,
draw_keypoints=False,
draw_boxes=False,
draw_text=False,
draw_links=False,
draw_twins=False,
write_class=False,
write_score=False,
write_text=False,
write_idx=False,
score_threshold=0.4,
keypoints_correction=False,
only_print=None,
axis=False,
return_image=False,
new_size=(1333,800),
resize=False):
"""
Draws annotations on images including bounding boxes, keypoints, links, and text.
Parameters:
- image (np.array): The image on which annotations will be drawn.
- target (dict): Ground truth data containing boxes, labels, etc.
- prediction (dict): Prediction data from a model.
- full_prediction (dict): Additional detailed prediction data, potentially including relationships.
- text_predictions (tuple): OCR text predictions containing bounding boxes and texts.
- model_dict (dict): Mapping from class IDs to class names.
- draw_keypoints (bool): Flag to draw keypoints.
- draw_boxes (bool): Flag to draw bounding boxes.
- draw_text (bool): Flag to draw text annotations.
- draw_links (bool): Flag to draw links between annotations.
    - draw_twins (bool): Flag to highlight nearly coincident ("twin") keypoints.
    - write_class (bool): Flag to write class names near the annotations.
    - write_score (bool): Flag to write scores near the annotations.
    - write_text (bool): Flag to write OCR-recognized text.
    - write_idx (bool): Flag to write each box's index next to it.
    - score_threshold (float): Minimum score for a prediction to be drawn.
    - keypoints_correction (bool): Flag to mirror nearly coincident keypoints through the box centre.
    - only_print (str): Specific class name to filter annotations by.
    - axis (bool): Flag to keep the matplotlib axes visible.
    - return_image (bool): Flag to return the annotated image.
    - new_size (tuple): The (width, height) the predictions were made at, used when resizing.
    - resize (bool): Whether to rescale annotations from new_size to the image size.
    """
    # When only ground truth is drawn, `image` arrives as a [1, 3, H, W] tensor:
    # convert it to an HxWx3 numpy array and swap the channel order
    if prediction is None:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_copy = image.copy()
scale = max(image.shape[0], image.shape[1]) / 1000
# Function to draw bounding boxes and keypoints
    def draw(data, is_prediction=False):
""" Helper function to draw annotations based on provided data. """
        for i in range(len(data['boxes'])):
            # Both branches extracted the box identically, so do it once
            box = data['boxes'][i].tolist()
            x1, y1, x2, y2 = box
            if is_prediction:
                if resize:
                    x1, y1, x2, y2 = resize_boxes(np.array([box]), new_size, (image_copy.shape[1], image_copy.shape[0]))[0]
                score = data['scores'][i].item()
                if score < score_threshold:
                    continue
if draw_boxes:
if only_print is not None:
if data['labels'][i] != list(model_dict.values()).index(only_print):
continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 0), int(2*scale))
if is_prediction and write_score:
cv2.putText(image_copy, str(round(score, 2)), (int(x1), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (100,100, 255), 2)
if write_class and 'labels' in data:
class_id = data['labels'][i].item()
cv2.putText(image_copy, model_dict[class_id], (int(x1), int(y1) - int(2*scale)), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (255, 100, 100), 2)
if write_idx:
cv2.putText(image_copy, str(i), (int(x1) + int(15*scale), int(y1) + int(15*scale)), cv2.FONT_HERSHEY_SIMPLEX, 2*scale, (0,0, 0), 2)
# Draw keypoints if available
if draw_keypoints and 'keypoints' in data:
if is_prediction and keypoints_correction:
for idx, (key1, key2) in enumerate(data['keypoints']):
if data['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
list(model_dict.values()).index('messageFlow'),
list(model_dict.values()).index('dataAssociation')]:
continue
# Calculate the Euclidean distance between the two keypoints
distance = np.linalg.norm(key1[:2] - key2[:2])
if distance < 5:
x_new,y_new, x,y = find_other_keypoint(idx, data['keypoints'], data['boxes'])
data['keypoints'][idx][0] = torch.tensor([x_new, y_new,1])
data['keypoints'][idx][1] = torch.tensor([x, y,1])
print("keypoint has been changed")
for i in range(len(data['keypoints'])):
kp = data['keypoints'][i]
for j in range(kp.shape[0]):
                    if is_prediction and data['labels'][i] not in [list(model_dict.values()).index('sequenceFlow'),
                                                                   list(model_dict.values()).index('messageFlow'),
                                                                   list(model_dict.values()).index('dataAssociation')]:
continue
if is_prediction:
score = data['scores'][i]
if score < score_threshold:
continue
x,y,v = np.array(kp[j])
if resize:
x, y, v = resize_keypoints(np.array([kp[j]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
if j == 0:
cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (0, 0, 255), -1)
else:
cv2.circle(image_copy, (int(x), int(y)), int(5*scale), (255, 0, 0), -1)
# Draw text predictions if available
if (draw_text or write_text) and text_predictions is not None:
for i in range(len(text_predictions[0])):
x1, y1, x2, y2 = text_predictions[0][i]
text = text_predictions[1][i]
if resize:
x1, y1, x2, y2 = resize_boxes(np.array([[float(x1), float(y1), float(x2), float(y2)]]), new_size, (image_copy.shape[1],image_copy.shape[0]))[0]
if draw_text:
cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), int(2*scale))
if write_text:
cv2.putText(image_copy, text, (int(x1 + int(2*scale)), int((y1+y2)/2) ), cv2.FONT_HERSHEY_SIMPLEX, scale/2, (0,0, 0), 2)
    def draw_with_links(full_prediction):
        """Draws links between objects based on the full prediction data."""
        # Highlight arrows whose two keypoints were detected at (almost) the same spot
if draw_twins and full_prediction is not None:
# Pre-calculate indices for performance
circle_color = (0, 255, 0) # Green color for the circle
circle_radius = int(10 * scale) # Circle radius scaled by image scale
for idx, (key1, key2) in enumerate(full_prediction['keypoints']):
if full_prediction['labels'][idx] not in [list(model_dict.values()).index('sequenceFlow'),
list(model_dict.values()).index('messageFlow'),
list(model_dict.values()).index('dataAssociation')]:
continue
# Calculate the Euclidean distance between the two keypoints
distance = np.linalg.norm(key1[:2] - key2[:2])
if distance < 10:
                    x_new, y_new, x, y = find_other_keypoint(idx, full_prediction['keypoints'], full_prediction['boxes'])
cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0,0,0), -1)
# Draw links between objects
        if draw_links and full_prediction is not None:
for i, (start_idx, end_idx) in enumerate(full_prediction['links']):
if start_idx is None or end_idx is None:
continue
start_box = full_prediction['boxes'][start_idx]
end_box = full_prediction['boxes'][end_idx]
current_box = full_prediction['boxes'][i]
# Calculate the center of each bounding box
start_center = ((start_box[0] + start_box[2]) // 2, (start_box[1] + start_box[3]) // 2)
end_center = ((end_box[0] + end_box[2]) // 2, (end_box[1] + end_box[3]) // 2)
current_center = ((current_box[0] + current_box[2]) // 2, (current_box[1] + current_box[3]) // 2)
# Draw a line between the centers of the connected objects
cv2.line(image_copy, (int(start_center[0]), int(start_center[1])), (int(current_center[0]), int(current_center[1])), (0, 0, 255), int(2*scale))
cv2.line(image_copy, (int(current_center[0]), int(current_center[1])), (int(end_center[0]), int(end_center[1])), (255, 0, 0), int(2*scale))
# Draw GT annotations
if target is not None:
draw(target, is_prediction=False)
# Draw predictions
if prediction is not None:
#prediction = prediction[0]
draw(prediction, is_prediction=True)
# Draw links with full predictions
if full_prediction is not None:
draw_with_links(full_prediction)
# Display the image
image_copy = cv2.cvtColor(image_copy, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(12, 12))
plt.imshow(image_copy)
    if not axis:
plt.axis('off')
plt.show()
if return_image:
return image_copy
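# Typical call (hedged sketch): when `prediction` is given, `image` is expected
# to already be an HxWx3 numpy array, since the tensor-to-numpy conversion
# above only runs for ground-truth-only calls.
# annotated = draw_annotations(image_np, prediction=pred, draw_boxes=True,
#                              draw_keypoints=True, write_class=True,
#                              score_threshold=0.5, return_image=True)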
def find_closest_object(keypoint, boxes, labels):
"""
Find the closest object to a keypoint based on their proximity.
    Parameters:
    - keypoint (numpy.ndarray): The coordinates of the keypoint.
    - boxes (numpy.ndarray): The bounding boxes of the objects.
    - labels (numpy.ndarray): The class labels of the objects; flow and lane boxes are skipped.
    Returns:
    - tuple: (index of the closest object or None, name of the closest side or None).
    """
closest_object_idx = None
best_point = None
min_distance = float('inf')
# Iterate over each bounding box
for i, box in enumerate(boxes):
if labels[i] in [list(class_dict.values()).index('sequenceFlow'),
list(class_dict.values()).index('messageFlow'),
list(class_dict.values()).index('dataAssociation'),
#list(class_dict.values()).index('pool'),
list(class_dict.values()).index('lane')]:
continue
x1, y1, x2, y2 = box
top = ((x1+x2)/2, y1)
bottom = ((x1+x2)/2, y2)
left = (x1, (y1+y2)/2)
right = (x2, (y1+y2)/2)
points = [left, top , right, bottom]
pos_dict = {0:'left', 1:'top', 2:'right', 3:'bottom'}
        # Measure the distance from the keypoint to each side midpoint of the box
        for pos, point in enumerate(points):
distance = np.linalg.norm(keypoint[:2] - point)
# Update the closest object index if this object is closer
if distance < min_distance:
min_distance = distance
closest_object_idx = i
best_point = pos_dict[pos]
return closest_object_idx, best_point
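# Usage sketch (hypothetical values): snap an arrow endpoint to the nearest
# side of a detected shape; here the keypoint sits just above the box.
# >>> find_closest_object(np.array([52., 10., 1.]), np.array([[40., 20., 60., 40.]]), np.array([1]))
# (0, 'top')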
def error(text='There is an error in the detection'):
st.error(text, icon="🚨")
def warning(text='Some elements may not have been detected. Verify the results, try adjusting the parameters, or add them in the method and style step.'):
st.warning(text, icon="⚠️")