Correct a lot of bugs and allow an automatic resize value
- modules/dataset_loader.py +500 -0
- modules/display.py +1 -1
- modules/eval.py +91 -10
- modules/streamlit_utils.py +41 -3
- modules/toWizard.py +2 -2
- modules/toXML.py +20 -12
- modules/train.py +28 -21
- modules/utils.py +56 -473
modules/dataset_loader.py
ADDED
@@ -0,0 +1,500 @@
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
import random
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import numpy as np
from torch.utils.data.dataloader import default_collate
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset, ConcatDataset
import streamlit as st
from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints


class RandomCrop:
    def __init__(self, new_size=(1333, 800), crop_fraction=0.5, min_objects=4):
        self.crop_fraction = crop_fraction
        self.min_objects = min_objects
        self.new_size = new_size

    def __call__(self, image, target):
        new_w1, new_h1 = self.new_size
        w, h = image.size
        new_w = int(w * self.crop_fraction)
        new_h = int(new_w * new_h1 / new_w1)

        i = 0
        for i in range(4):
            if new_h >= h:
                i += 0.05
                new_w = int(w * (self.crop_fraction - i))
                new_h = int(new_w * new_h1 / new_w1)
            if new_h < h:
                continue

        if new_h >= h:
            return image, target

        boxes = target["boxes"]
        if 'keypoints' in target:
            keypoints = target["keypoints"]
        else:
            keypoints = []
            for i in range(len(boxes)):
                keypoints.append(torch.zeros((2, 3)))

        # Attempt to find a suitable crop region
        success = False
        for _ in range(100):  # Max 100 attempts to find a valid crop
            top = random.randint(0, h - new_h)
            left = random.randint(0, w - new_w)
            crop_region = [left, top, left + new_w, top + new_h]

            # Check how many objects are fully contained in this region
            contained_boxes = []
            contained_keypoints = []
            for box, kp in zip(boxes, keypoints):
                if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
                    # Adjust box and keypoints coordinates
                    new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
                    new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
                    contained_boxes.append(new_box)
                    contained_keypoints.append(new_kp)

            if len(contained_boxes) >= self.min_objects:
                success = True
                break

        if success:
            # Perform the actual crop
            image = F.crop(image, top, left, new_h, new_w)
            target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
            if 'keypoints' in target:
                target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 4))

        return image, target


class RandomFlip:
    def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
        """
        Initializes the RandomFlip with probabilities for flipping.

        Parameters:
        - h_flip_prob (float): Probability of applying a horizontal flip to the image.
        - v_flip_prob (float): Probability of applying a vertical flip to the image.
        """
        self.h_flip_prob = h_flip_prob
        self.v_flip_prob = v_flip_prob

    def __call__(self, image, target):
        """
        Applies random horizontal and/or vertical flip to the image and updates target data accordingly.

        Parameters:
        - image (PIL Image): The image to be flipped.
        - target (dict): The target dictionary containing 'boxes' and 'keypoints'.

        Returns:
        - PIL Image, dict: The flipped image and its updated target dictionary.
        """
        if random.random() < self.h_flip_prob:
            image = F.hflip(image)
            w, _ = image.size  # Get the new width of the image after flip for bounding box adjustment
            # Adjust bounding boxes for horizontal flip
            for i, box in enumerate(target['boxes']):
                xmin, ymin, xmax, ymax = box
                target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)

            # Adjust keypoints for horizontal flip
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints_for_object in target['keypoints']:
                    flipped_keypoints_for_object = []
                    for kp in keypoints_for_object:
                        x, y = kp[:2]
                        new_x = w - x
                        flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
                    new_keypoints.append(torch.stack(flipped_keypoints_for_object))
                target['keypoints'] = torch.stack(new_keypoints)

        if random.random() < self.v_flip_prob:
            image = F.vflip(image)
            _, h = image.size  # Get the new height of the image after flip for bounding box adjustment
            # Adjust bounding boxes for vertical flip
            for i, box in enumerate(target['boxes']):
                xmin, ymin, xmax, ymax = box
                target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)

            # Adjust keypoints for vertical flip
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints_for_object in target['keypoints']:
                    flipped_keypoints_for_object = []
                    for kp in keypoints_for_object:
                        x, y = kp[:2]
                        new_y = h - y
                        flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
                    new_keypoints.append(torch.stack(flipped_keypoints_for_object))
                target['keypoints'] = torch.stack(new_keypoints)

        return image, target


class RandomRotate:
    def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
        """
        Initializes the RandomRotate with a maximum rotation angle and probability of rotating.

        Parameters:
        - max_rotate_deg (int): Maximum degree to rotate the image.
        - rotate_proba (float): Probability of applying rotation to the image.
        """
        self.max_rotate_deg = max_rotate_deg
        self.rotate_proba = rotate_proba

    def __call__(self, image, target):
        """
        Randomly rotates the image and updates the target data accordingly.

        Parameters:
        - image (PIL Image): The image to be rotated.
        - target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.

        Returns:
        - PIL Image, dict: The rotated image and its updated target dictionary.
        """
        if random.random() < self.rotate_proba:
            angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
            image = F.rotate(image, angle, expand=False, fill=200)

            # Rotate bounding boxes
            w, h = image.size
            cx, cy = w / 2, h / 2
            boxes = target["boxes"]
            new_boxes = []
            for box in boxes:
                new_box = self.rotate_box(box, angle, cx, cy)
                new_boxes.append(new_box)
            target["boxes"] = torch.stack(new_boxes)

            # Rotate keypoints
            if 'keypoints' in target:
                new_keypoints = []
                for keypoints in target["keypoints"]:
                    new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
                    new_keypoints.append(new_kp)
                target["keypoints"] = torch.stack(new_keypoints)

        return image, target

    def rotate_box(self, box, angle, cx, cy):
        """
        Rotates a bounding box by a given angle around the center of the image.
        """
        x1, y1, x2, y2 = box
        corners = torch.tensor([
            [x1, y1],
            [x2, y1],
            [x2, y2],
            [x1, y2]
        ])
        corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
        M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
        corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
        x_ = corners[:, 0]
        y_ = corners[:, 1]
        x_min, x_max = torch.min(x_), torch.max(x_)
        y_min, y_max = torch.min(y_), torch.max(y_)
        return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)

    def rotate_keypoints(self, keypoints, angle, cx, cy):
        """
        Rotates keypoints by a given angle around the center of the image.
        """
        new_keypoints = []
        for kp in keypoints:
            x, y, v = kp
            point = torch.tensor([x, y, 1])
            M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
            new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
            new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
        return torch.stack(new_keypoints)


def rotate_90_box(box, angle, w, h):
    x1, y1, x2, y2 = box
    if angle == 90:
        return torch.tensor([y1, h - x2, y2, h - x1])
    elif angle == 270 or angle == -90:
        return torch.tensor([w - y2, x1, w - y1, x2])
    else:
        print("angle not supported")


def rotate_90_keypoints(kp, angle, w, h):
    # Extract coordinates and visibility from each keypoint tensor
    x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
    x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
    # Swap x and y coordinates for each keypoint
    if angle == 90:
        new = [[y1, h - x1, v1], [y2, h - x2, v2]]
    elif angle == 270 or angle == -90:
        new = [[w - y1, x1, v1], [w - y2, x2, v2]]

    return torch.tensor(new, dtype=torch.float32)


def rotate_vertical(image, target):
    # Rotate the image and target if the image is vertical
    new_boxes = []
    angle = random.choice([-90, 90])
    image = F.rotate(image, angle, expand=True, fill=200)
    for box in target["boxes"]:
        new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
        new_boxes.append(new_box)
    target["boxes"] = torch.stack(new_boxes)

    if 'keypoints' in target:
        new_kp = []
        for kp in target['keypoints']:
            new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
            new_kp.append(new_key)
        target['keypoints'] = torch.stack(new_kp)
    return image, target


import torchvision.transforms.functional as F
import torch

def resize_and_pad(image, target, new_size=(1333, 800)):
    original_size = image.size
    # Calculate scale to fit the new size while maintaining aspect ratio
    scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))

    # Resize image to new scaled size
    image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))

    # Calculate padding to center the image
    pad_left = (new_size[0] - new_scaled_size[0]) // 2
    pad_top = (new_size[1] - new_scaled_size[1]) // 2
    pad_right = new_size[0] - new_scaled_size[0] - pad_left
    pad_bottom = new_size[1] - new_scaled_size[1] - pad_top

    # Pad the resized image to make it exactly the desired size
    image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom), fill=0, padding_mode='constant')

    # Adjust bounding boxes
    target['boxes'] = resize_boxes(target['boxes'], original_size, new_scaled_size)
    target['boxes'][:, 0::2] += pad_left
    target['boxes'][:, 1::2] += pad_top

    # Adjust keypoints if they exist in the target
    if 'keypoints' in target:
        for i in range(len(target['keypoints'])):
            target['keypoints'][i] = resize_keypoints(target['keypoints'][i], original_size, new_scaled_size)
            target['keypoints'][i][:, 0] += pad_left
            target['keypoints'][i][:, 1] += pad_top

    return image, target


class BPMN_Dataset(Dataset):
    def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
                 flip_transform=None, rotate_transform=None, new_size=(1333, 1333), keep_ratio=0.1, resize=True, model_type='object'):
        self.annotations = annotations
        print(f"Loaded {len(self.annotations)} annotations.")
        self.transform = transform
        self.crop_transform = crop_transform
        self.crop_prob = crop_prob
        self.flip_transform = flip_transform
        self.rotate_transform = rotate_transform
        self.resize = resize
        self.new_size = new_size
        self.keep_ratio = keep_ratio
        self.model_type = model_type
        if model_type == 'object':
            self.dict = object_dict
        elif model_type == 'arrow':
            self.dict = arrow_dict
        self.rotate_90_proba = rotate_90_proba

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image = annotation.img.convert("RGB")
        boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
        labels_names = [ann for ann in annotation.categories]

        # Only keep the labels, boxes and keypoints that are in the class_dict
        kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
        boxes = boxes[kept_indices]
        labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
        # Replace any subprocess by task
        labels_names = ['task' if ann == 'subProcess' else ann for ann in labels_names]

        labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)

        # Initialize keypoints tensor
        max_keypoints = 2
        keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)

        ii = 0
        for i, ann in enumerate(annotation.annotations):
            # Only keep the keypoints that are in the kept indices
            if i not in kept_indices:
                continue
            if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
                # Fill the keypoints tensor for this annotation, mark as visible (1)
                kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
                kp = kp[:, :2]
                visible = np.ones((kp.shape[0], 1), dtype=np.float32)
                kp = np.hstack([kp, visible])
                keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
            ii += 1

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        if self.model_type == 'object':
            target = {
                "boxes": boxes,
                "labels": labels_id,
                #"area": area,
            }
        elif self.model_type == 'arrow':
            target = {
                "boxes": boxes,
                "labels": labels_id,
                #"area": area,
                "keypoints": keypoints,
            }

        # Randomly apply flip transform
        if self.flip_transform:
            image, target = self.flip_transform(image, target)

        # Randomly apply rotate transform
        if self.rotate_transform:
            image, target = self.rotate_transform(image, target)

        # Randomly apply the custom cropping transform
        if self.crop_transform and random.random() < self.crop_prob:
            image, target = self.crop_transform(image, target)

        # Rotate vertical image
        if random.random() < self.rotate_90_proba:
            image, target = rotate_vertical(image, target)

        if self.resize:
            if random.random() < self.keep_ratio:
                # Center and pad the image while keeping the aspect ratio
                image, target = resize_and_pad(image, target, self.new_size)
            else:
                target['boxes'] = resize_boxes(target['boxes'], (image.size[0], image.size[1]), self.new_size)
                if 'area' in target:
                    target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
                if 'keypoints' in target:
                    for i in range(len(target['keypoints'])):
                        target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0], image.size[1]), self.new_size)
                image = F.resize(image, (self.new_size[1], self.new_size[0]))

        return self.transform(image), target


def collate_fn(batch):
    """
    Custom collation function for DataLoader that handles batches of images and targets.

    This function ensures that images are properly batched together using PyTorch's default collation,
    while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
    as each image might have a different number of objects detected.

    Parameters:
    - batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.

    Returns:
    - Tuple containing:
      - Tensor: Batched images.
      - List of dicts: Targets corresponding to each image in the batch.
    """
    images, targets = zip(*batch)  # Unzip the batch into separate lists for images and targets.

    # Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
    images = default_collate(images)

    return images, targets


def create_loader(new_size, transformation, annotations1, annotations2=None,
                  batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
                  h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
                  seed=42, resize=True, keep_ratio=0.1, model_type='object'):
    """
    Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.

    Parameters:
    - transformation (callable): Transformation function to apply to each image (e.g., normalization).
    - annotations1 (list): Primary list of annotations.
    - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
    - batch_size (int): Number of images per batch.
    - crop_prob (float): Probability of applying the crop transformation.
    - crop_fraction (float): Fraction of the original width to use when cropping.
    - min_objects (int): Minimum number of objects required to be within the crop.
    - h_flip_prob (float): Probability of applying horizontal flip.
    - v_flip_prob (float): Probability of applying vertical flip.
    - seed (int): Seed for random number generators for reproducibility.
    - resize (bool): Flag indicating whether to resize images after transformations.

    Returns:
    - DataLoader: Configured data loader for the dataset.
    """

    # Initialize custom transformations for cropping and flipping
    custom_crop_transform = RandomCrop(new_size, crop_fraction, min_objects)
    custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
    custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)

    # Create the primary dataset
    dataset = BPMN_Dataset(
        annotations=annotations1,
        transform=transformation,
        crop_transform=custom_crop_transform,
        crop_prob=crop_prob,
        rotate_90_proba=rotate_90_proba,
        flip_transform=custom_flip_transform,
        rotate_transform=custom_rotate_transform,
        new_size=new_size,
        keep_ratio=keep_ratio,
        model_type=model_type,
        resize=resize
    )

    # Optionally concatenate a second dataset
    if annotations2:
        dataset2 = BPMN_Dataset(
            annotations=annotations2,
            transform=transformation,
            crop_transform=custom_crop_transform,
            crop_prob=crop_prob,
            rotate_90_proba=rotate_90_proba,
            flip_transform=custom_flip_transform,
            new_size=new_size,
            keep_ratio=keep_ratio,
            model_type=model_type,
            resize=resize
        )
        dataset = ConcatDataset([dataset, dataset2])  # Concatenate the two datasets

    # Set the seed for reproducibility in random operations within transformations and data loading
    random.seed(seed)
    torch.manual_seed(seed)

    # Create the DataLoader with the dataset
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    return data_loader
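For context, a minimal sketch of how the new loader might be driven from a training script is shown below. It assumes a torchvision ToTensor transform and a list of annotation objects exposing .img, .boxes_ltrb, .categories and .annotations; the name annotations_train is a placeholder and not part of this commit.

# Hypothetical usage sketch (not part of this commit).
import torchvision.transforms as T
from modules.dataset_loader import create_loader

transformation = T.Compose([T.ToTensor()])
train_loader = create_loader(
    new_size=(1333, 1333),
    transformation=transformation,
    annotations1=annotations_train,   # placeholder list of annotation objects
    batch_size=4,
    crop_prob=0.2,
    h_flip_prob=0.3,
    v_flip_prob=0.3,
    model_type='arrow',               # 'object' or 'arrow'
)
for images, targets in train_loader:
    # images: batched tensor; targets: tuple of per-image dicts (boxes, labels, keypoints)
    break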
modules/display.py
CHANGED
@@ -1,4 +1,4 @@
-from modules.utils import
+from modules.utils import class_dict, resize_boxes, resize_keypoints, find_other_keypoint
 import cv2
 import numpy as np
 import torch
modules/eval.py
CHANGED
@@ -3,8 +3,9 @@ import torch
 from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
 from tqdm import tqdm
 from modules.toXML import get_size_elements, calculate_pool_bounds, create_BPMN_id
-from modules.utils import is_vertical
+from modules.utils import is_vertical, proportion_inside
 import streamlit as st
+from builtins import dict


 def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
@@ -101,10 +102,27 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
     scores = scores[selected_boxes]
     labels = labels[selected_boxes]

+    #find the outlier object that are too small by the area
+    obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower")
+    obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref = ['task'], mode = "upper")
+
+    selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
+
+    #selected_object = obj_not_too_small
+
+    boxes = boxes[selected_object]
+    scores = scores[selected_object]
+    labels = labels[selected_object]
+
     #modify the label of the sub-process to task
     for i in range(len(labels)):
         if labels[i] == list(object_dict.values()).index('subProcess'):
             labels[i] = list(object_dict.values()).index('task')
+    #delete all lane and also the value in the labels and scores
+    lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
+    boxes = np.delete(boxes, lane_index, axis=0)
+    labels = np.delete(labels, lane_index)
+    scores = np.delete(scores, lane_index)

     prediction = {
         'boxes': boxes,
@@ -180,7 +198,7 @@ def mix_predictions(objects_pred, arrow_pred):
     return boxes, labels, scores, keypoints


-def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.
+def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
     pool_dict = {}

     # Filter out pools with IoU greater than the threshold
@@ -188,7 +206,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
     for i in range(len(boxes)):
         for j in range(i + 1, len(boxes)):
             if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'):
-                if
+                if proportion_inside(boxes[i], boxes[j]) > iou_threshold:
                     to_delete.append(j)

     boxes = np.delete(boxes, to_delete, axis=0)
@@ -210,8 +228,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
         if i in pool_indices or class_dict[labels[i]] in ['messageFlow', 'pool']:
             continue
         for j, pool_box in enumerate(pool_boxes):
-            if (box
-                box[2] <= pool_box[2] and box[3] <= pool_box[3]):
+            if proportion_inside(box, pool_box) > iou_threshold:
                 pool_index = pool_indices[j]
                 pool_dict[pool_index].append(i)
                 assigned_to_pool = True
@@ -322,6 +339,53 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
     return labels, flow_links


+def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower"):
+    # Filter out the sizes of events, data objects, and message events
+    event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
+    event_boxes = [boxes[i] for i in event_indices]
+
+    # Calculate the areas of these typical objects
+    event_areas = np.array([(box[2] - box[0]) * (box[3] - box[1]) for box in event_boxes])
+
+    # Compute the mean and standard deviation for areas
+    mean_area = np.mean(event_areas)
+    std_area = np.std(event_areas)
+
+    # Define thresholds for outliers
+    area_lower_threshold = mean_area - std_factor * std_area
+    area_upper_threshold = mean_area + std_factor * std_area
+
+    # Identify indices of outliers and the ones to keep
+    outlier_indices = []
+    kept_indices = []
+
+    if mode == "lower" or mode == 'both':
+        #check for object that could be too small
+        for idx, (box, label) in enumerate(zip(boxes, labels)):
+            area = (box[2] - box[0]) * (box[3] - box[1])
+            if not (area_lower_threshold <= area):
+                outlier_indices.append(idx)
+                print(f"Element {idx} is an outlier with area {area} that is too small")
+            else:
+                kept_indices.append(idx)
+
+    if mode == "upper" or mode == 'both':
+        #check for object that could be too big
+        for idx, (box, label) in enumerate(zip(boxes, labels)):
+            if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
+                kept_indices.append(idx)
+                continue
+            area = (box[2] - box[0]) * (box[3] - box[1])
+            if not (area_upper_threshold >= area):
+                outlier_indices.append(idx)
+                print(f"Element {idx} is an outlier with area {area} that is too big")
+            else:
+                kept_indices.append(idx)
+
+    return kept_indices
+
+

 def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):

@@ -368,6 +432,16 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
             print('delete element', i)
             delete_elements.append(i)

+    #filter box that are inside a text box
+    """tex_pred = st.session_state.text_pred
+    for i in range(len(boxes)):
+        for j in range(len(tex_pred[0])):
+            #check if the box is inside the text box but if the text box is inside the box then it is not a problem
+            if proportion_inside(boxes[i], tex_pred[0][j]) > 0.1:
+                #delete_elements.append(i)
+                print('delete element', i)"""
+
+
     #concatenate the delete_elements and the delete_pool
     delete_elements = delete_elements + delete_pool
     #delete double value in delete_elements
@@ -377,13 +451,21 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
     labels = np.delete(labels, delete_elements)
     scores = np.delete(scores, delete_elements)
     keypoints = np.delete(keypoints, delete_elements, axis=0)
-
+
     links = np.delete(links, delete_elements, axis=0)
     best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]

+    for i in range(len(delete_pool)):
+        #find the bpmn_id of the pool
+        pool_index = bpmn_id[delete_pool[i]]
+        #delete the pool_index in pool_dict
+        del pool_dict[pool_index]
+
+    bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
+
     #also delete the element in the pool_dict
     for pool_index, elements in pool_dict.items():
-        pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
+        pool_dict[pool_index] = [i for i in elements if i not in delete_elements]

     return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict

@@ -420,7 +502,7 @@ def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links,

     return data

-def develop_prediction(boxes, labels, scores, keypoints, class_dict
+def develop_prediction(boxes, labels, scores, keypoints, class_dict):

     pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)

@@ -430,8 +512,7 @@ def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=
     flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)

     #Correct the labels of some sequenceflow that cross multiple pool
-
-    labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
+    labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)

     #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
     flow_links = give_link_to_element(flow_links, labels)
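The new find_outlier_objects_by_area helper keeps a detection only if its area lies within mean ± std_factor·std of the areas of a reference class. A small standalone sketch of the same thresholding rule, with made-up boxes rather than real model predictions:

# Toy illustration of the area-based outlier rule introduced above (assumed data, not app output).
import numpy as np

# five detected boxes in [x1, y1, x2, y2] format; the last one is suspiciously small
boxes = np.array([[0, 0, 40, 40],
                  [0, 0, 42, 42],
                  [50, 50, 90, 90],
                  [10, 10, 51, 51],
                  [0, 0, 5, 5]], dtype=float)

areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
lower = areas.mean() - 1.5 * areas.std()      # same rule as std_factor=1.5, mode="lower"
kept = [i for i, a in enumerate(areas) if a >= lower]
print(kept)  # [0, 1, 2, 3] -- the 5x5 box falls below the threshold and is dropped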
modules/streamlit_utils.py
CHANGED
@@ -30,6 +30,8 @@ from modules.toWizard import create_wizard_file
 from huggingface_hub import hf_hub_download
 import time

+from modules.toXML import get_size_elements
+


@@ -440,12 +442,13 @@ def modify_results(percentage_text_dist_thresh=0.5):
     new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))


-    boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict
+    boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)

     st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
     st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)

     if changes:
+        changes = False
         st.rerun()

     return True
@@ -460,14 +463,49 @@ def display_bpmn_modeler(is_mobile, screen_width):
         st.session_state.size_scale, st.session_state.scale
     )
     st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
-
+
+    display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
+
+
+def find_best_scale(pred, size_elements):
+    boxes = pred['boxes']
+    labels = pred['labels']
+
+    # Find average size of the tasks in pred
+    avg_size = 0
+    count = 0
+    for i in range(len(boxes)):
+        if class_dict[labels[i]] == 'task':
+            avg_size += (boxes[i][2] - boxes[i][0]) * (boxes[i][3] - boxes[i][1])
+            count += 1
+
+    if count == 0:
+        raise ValueError("No tasks found in the provided prediction.")
+
+    avg_size /= count
+
+    # Get the size of a task element from size_elements dictionary
+    task_size = size_elements['task']
+    task_area = task_size[0] * task_size[1]
+
+    # Find the best scale
+    best_scale = (avg_size / task_area) ** 0.5
+
+    if best_scale < 0.5:
+        best_scale = 0.5
+    elif best_scale > 1:
+        best_scale = 1
+
+    return best_scale

 def modeler_options(is_mobile):
     if not is_mobile:
         with st.expander("Options for BPMN modeler"):
             col1, col2 = st.columns(2)
             with col1:
-                st.session_state.
+                st.session_state.best_scale = find_best_scale(st.session_state.prediction, get_size_elements())
+                print(f"Best scale: {st.session_state.best_scale}")
+                st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1/st.session_state.best_scale, step=0.1)
                 st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
     else:
         st.session_state.scale = 1.0
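The "automatic resize value" mentioned in the commit title comes from the new find_best_scale: it compares the average detected task area against the nominal task size returned by get_size_elements and clamps the square root of that ratio to [0.5, 1], which then seeds the distance-scale slider as 1/best_scale. A worked example with toy numbers (not real app data):

# Worked example of the automatic scale computation added above (assumed numbers).
nominal_task = (120, 96)                 # matches get_size_elements(size_scale=1)
avg_detected_task_area = 60 * 48         # suppose detected tasks are roughly 60x48 px

best_scale = (avg_detected_task_area / (nominal_task[0] * nominal_task[1])) ** 0.5
best_scale = min(max(best_scale, 0.5), 1.0)
print(best_scale)                        # 0.5, so the slider default becomes 1/best_scale = 2.0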
modules/toWizard.py
CHANGED
@@ -131,7 +131,7 @@ def create_wizard_file(data, text_mapping):
     ET.SubElement(activity, 'subActivityFlows')
     ET.SubElement(activity, 'messageFlows')

-    activityFlows = ET.SubElement(root, 'activityFlows')
+    """activityFlows = ET.SubElement(root, 'activityFlows')
     i=0
     for i, link in enumerate(data['links']):
         if link[0] is None and link[1] is not None and (data['BPMN_id'][i].split('_')[0] == 'event' or data['BPMN_id'][i].split('_')[0] == 'message'):
@@ -145,7 +145,7 @@ def create_wizard_file(data, text_mapping):
         if current_text is None or next_text is None:
             continue
         ET.SubElement(activityFlows, 'activityFlow', attrib={'activity': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
-        i+=1
+        i+=1"""

     ET.SubElement(root, 'participants')

modules/toXML.py
CHANGED
@@ -113,8 +113,8 @@ def expand_pool_bounding_boxes(modified_pred, pred, size_elements):
         if pool_width < 300 or pool_height < 30:
             error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
             continue
-
-        modified_pred['boxes'][position] = [min_x - marge, min_y - marge
+
+        modified_pred['boxes'][position] = [min_x - marge//20, min_y - marge, min_x + pool_width + marge//20, min_y + pool_height + marge]

 # Adjust left and right boundaries of all pools
 def adjust_pool_boundaries(modified_pred, pred):
@@ -148,9 +148,9 @@ def align_boxes(pred, size, class_dict):
     pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
     align_elements_within_pool(modified_pred, pool_groups, class_dict, size)

-
-
-
+    if len(pred['pool_dict']) > 1:
+        expand_pool_bounding_boxes(modified_pred, pred, size)
+    adjust_pool_boundaries(modified_pred, pred)

     return modified_pred['boxes']

@@ -176,10 +176,11 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
         'id': "simpleExample"
     })

+
     size_elements = get_size_elements(size_scale)

     #if there is no pool or lane, create a pool with all elements
-    if len(full_pred['pool_dict'])==0 or (len(full_pred['pool_dict'])==1 and len(full_pred['pool_dict']
+    if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
         full_pred, text_mapping = create_big_pool(full_pred, text_mapping)

     #modify the boxes positions
@@ -249,13 +250,13 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
     return pretty_xml_as_string

 # Function that creates a single pool with all elements
-def create_big_pool(full_pred, text_mapping):
+def create_big_pool(full_pred, text_mapping, marge=50):
     # If no pools or lanes are detected, create a single pool with all elements
     new_pool_index = 'pool_1'
     size_elements = get_size_elements(st.session_state.size_scale)
     elements_pool = list(range(len(full_pred['boxes'])))
     min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
-    box = [min_x, min_y, max_x, max_y]
+    box = [min_x-marge, min_y-marge, max_x+marge, max_y+marge]
     full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
     full_pred['pool_dict'][new_pool_index] = elements_pool
     full_pred['BPMN_id'].append('pool_1')
@@ -264,7 +265,7 @@ def create_big_pool(full_pred, text_mapping):
     return full_pred, text_mapping

 # Function that gives the size of the elements
-def get_size_elements(size_scale):
+def get_size_elements(size_scale=1):
     size_elements = {
         'event': (size_scale*43.2, size_scale*43.2),
         'task': (size_scale*120, size_scale*96),
@@ -400,8 +401,9 @@ def check_data_association(i, links, labels, keep_elements):

 def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
     waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
-
-
+    if waypoints is not None:
+        add_diagram_edge(bpmn, element_id, waypoints)
+
 def check_eventBasedGateway(i, links, labels):
     status, links_idx = [], []
     for j, (k,l) in enumerate(links):
@@ -582,7 +584,7 @@ def calculate_pool_bounds(boxes, labels, keep_elements, size):
         max_x = max(max_x, x + element_width)
         max_y = max(max_y, y + element_height)

-    return min_x
+    return min_x, min_y, max_x, max_y


@@ -680,10 +682,16 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
     if source_idx is None or target_idx is None:
         warning()
         return None
+

     name_source = source_id.split('_')[0]
     name_target = target_id.split('_')[0]

+    avoid_element = ['pool', 'sequenceFlow', 'messageFlow', 'dataAssociation']
+    if name_target in avoid_element or name_source in avoid_element:
+        warning()
+        return None
+
     # Get the position of the source and target
     source_x, source_y = data['boxes'][source_idx][:2]
     target_x, target_y = data['boxes'][target_idx][:2]
modules/train.py
CHANGED
@@ -100,7 +100,14 @@ def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type
|
|
100 |
return model, optimizer, device
|
101 |
|
102 |
|
|
|
|
|
|
|
|
|
103 |
|
|
|
|
|
|
|
104 |
|
105 |
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
|
106 |
model.train() # Set the model to evaluation mode
|
@@ -178,13 +185,14 @@ def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=Fal
|
|
178 |
|
179 |
|
180 |
def training_model(num_epochs, model, data_loader, subset_test_loader,
|
181 |
-
optimizer, model_to_load=None, change_learning_rate=
|
182 |
-
|
183 |
-
max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
|
184 |
score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
|
185 |
information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
|
186 |
eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
|
187 |
|
|
|
|
|
188 |
|
189 |
if loss_config is None:
|
190 |
print('No loss config found, all losses will be used.')
|
@@ -219,14 +227,20 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
|
|
219 |
bad_test_loss = 0
|
220 |
previous_test_loss = 1000
|
221 |
|
|
|
|
|
|
|
222 |
print(f"Let's go training {model_type} model with {num_epochs} epochs!")
|
223 |
-
|
|
|
224 |
|
225 |
for epoch in range(num_epochs):
|
226 |
|
227 |
-
if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss
|
228 |
learning_rate = 0.7*learning_rate
|
229 |
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
|
|
|
|
|
230 |
print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
|
231 |
bad_test_loss = 0
|
232 |
if epoch>0 and (epoch)==start_key:
|
@@ -315,24 +329,19 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
|
|
315 |
|
316 |
|
317 |
# Evaluate the model on the test set
|
318 |
-
if eval_metric != 'loss':
|
319 |
-
avg_test_loss = 0
|
320 |
-
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
|
321 |
-
print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
|
322 |
-
if eval_metric == 'all':
|
323 |
-
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
324 |
-
print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
325 |
if eval_metric == 'loss':
|
326 |
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
|
327 |
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
328 |
print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
print(f"Time: {time.time() - start:.2f} [s]")
|
331 |
|
332 |
-
|
333 |
-
if epoch>0 and (epoch)%start_key == 0:
|
334 |
-
print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")
|
335 |
-
|
336 |
if eval_metric == 'f1_score':
|
337 |
metric_used = f1_score
|
338 |
elif eval_metric == 'precision':
|
@@ -357,15 +366,14 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
|
|
357 |
epoch_test_loss.append(avg_test_loss)
|
358 |
|
359 |
name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
|
|
|
360 |
|
361 |
if same >=1 :
|
362 |
-
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
363 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
364 |
write_results(name_model,metrics_list,start_epoch)
|
365 |
break
|
366 |
|
367 |
if (epoch+1+start_epoch) % 5 == 0:
|
368 |
-
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
369 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
370 |
model.load_state_dict(best_model_state)
|
371 |
write_results(name_model,metrics_list,start_epoch)
|
@@ -375,12 +383,11 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
|
|
375 |
previous_test_loss = avg_test_loss
|
376 |
|
377 |
|
378 |
-
print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an
|
379 |
if best_model_state:
|
380 |
-
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
381 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
382 |
model.load_state_dict(best_model_state)
|
383 |
write_results(name_model,metrics_list,start_epoch)
|
384 |
print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
|
385 |
|
386 |
-
return model
|
|
|
100 |
return model, optimizer, device
|
101 |
|
102 |
|
103 |
+
import copy
|
104 |
+
from torch.optim import AdamW
|
105 |
+
import time
|
106 |
+
from modules.train import write_results
|
107 |
|
108 |
+
import torch
|
109 |
+
import numpy as np
|
110 |
+
from tqdm import tqdm
|
111 |
|
112 |
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
|
113 |
model.train() # Set the model to evaluation mode
|
|
|
185 |
|
186 |
|
187 |
def training_model(num_epochs, model, data_loader, subset_test_loader,
|
188 |
+
optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
|
189 |
+
parameters=None, blur_prob=0.02,
|
|
|
190 |
score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
|
191 |
information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
|
192 |
eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
|
193 |
|
194 |
+
# Set the model to training mode
|
195 |
+
model.train()
|
196 |
|
197 |
if loss_config is None:
|
198 |
print('No loss config found, all losses will be used.')
|
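The new signature is easiest to read from a call site. The sketch below uses placeholder values, and the dictionary key names are hypothetical; the only hard requirement, visible further down in the function, is that parameters.values() is unpacked positionally in the order batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio.

```python
# Placeholder values for illustration; key names are hypothetical, but the order must match
# the positional unpacking of parameters.values() inside training_model.
params = {
    'batch_size': 4, 'crop_prob': 0.2, 'rotate_90_proba': 0.2, 'h_flip_prob': 0.3,
    'v_flip_prob': 0.3, 'max_rotate_deg': 20, 'rotate_proba': 0.3, 'keep_ratio': 0.1,
}

# model, optimizer, train_loader and test_loader are assumed to have been built earlier
# (e.g. with prepare_model and create_loader).
model = training_model(
    num_epochs=50, model=model, data_loader=train_loader, subset_test_loader=test_loader,
    optimizer=optimizer,
    change_learning_rate=15,   # decay the learning rate every 15 epochs (or after repeated bad test losses)
    start_key=30,              # epoch from which keypoint accuracy starts being reported
    parameters=params,
    blur_prob=0.02,
    eval_metric='f1_score',
    model_type='arrow',
)
```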
|
|
227 |
bad_test_loss = 0
|
228 |
previous_test_loss = 1000
|
229 |
|
230 |
+
if parameters is not None:
|
231 |
+
batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
|
232 |
+
|
233 |
print(f"Let's go training {model_type} model with {num_epochs} epochs!")
|
234 |
+
if parameters is not None:
|
235 |
+
print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
|
236 |
|
237 |
for epoch in range(num_epochs):
|
238 |
|
239 |
+
if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>=3:
|
240 |
learning_rate = 0.7*learning_rate
|
241 |
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
|
242 |
+
if best_model_state is not None:
|
243 |
+
model.load_state_dict(best_model_state)
|
244 |
print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
|
245 |
bad_test_loss = 0
|
246 |
if epoch>0 and (epoch)==start_key:
|
|
|
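With the schedule above, the learning rate shrinks geometrically: each trigger (every change_learning_rate epochs, or after three consecutive epochs where the test loss gets worse) multiplies it by 0.7, rebuilds AdamW with weight_decay tied to the new rate, and restarts from the best weights seen so far. For instance, assuming an initial rate of 1e-3:

```python
# Illustrative decay of the learning rate after successive triggers (initial value assumed).
learning_rate = 1e-3
for n_triggers in range(4):
    print(n_triggers, round(learning_rate, 6))
    learning_rate = 0.7 * learning_rate
# 0 0.001
# 1 0.0007
# 2 0.00049
# 3 0.000343
```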
329 |
|
330 |
|
331 |
# Evaluate the model on the test set
|
332 |
if eval_metric == 'loss':
|
333 |
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
|
334 |
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
335 |
print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
336 |
+
else:
|
337 |
+
avg_test_loss = 0
|
338 |
+
labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
|
339 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
|
340 |
+
avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
|
341 |
+
print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
|
342 |
|
343 |
print(f"Time: {time.time() - start:.2f} [s]")
|
344 |
|
345 |
if eval_metric == 'f1_score':
|
346 |
metric_used = f1_score
|
347 |
elif eval_metric == 'precision':
|
|
|
366 |
epoch_test_loss.append(avg_test_loss)
|
367 |
|
368 |
name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
|
369 |
+
metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
|
370 |
|
371 |
if same >=1 :
|
|
|
372 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
373 |
write_results(name_model,metrics_list,start_epoch)
|
374 |
break
|
375 |
|
376 |
if (epoch+1+start_epoch) % 5 == 0:
|
|
|
377 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
378 |
model.load_state_dict(best_model_state)
|
379 |
write_results(name_model,metrics_list,start_epoch)
|
|
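The checkpoint name built a few lines above encodes the optimizer, the epoch count and the augmentation probabilities (each probability is multiplied by 10 and truncated to an integer, so the default blur_prob of 0.02 shows up as "blur00"). A worked example with illustrative values:

```python
# Illustrative values only (not the Space's actual training configuration).
optimizer_name = "AdamW"          # stands in for type(optimizer).__name__
epoch, start_epoch, batch_size = 49, 0, 4
blur_prob, crop_prob, h_flip_prob, rotate_proba = 0.02, 0.2, 0.3, 0.3
information_training = 'training'

name_model = (f"model_{optimizer_name}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval"
              f"_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}"
              f"_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
print(name_model)
# model_AdamW_50ep_4batch_trainval_blur00_crop02_flip03_rotate03_training
```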
|
383 |
previous_test_loss = avg_test_loss
|
384 |
|
385 |
|
386 |
+
print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
|
387 |
if best_model_state:
|
|
|
388 |
torch.save(best_model_state, './models/'+ name_model +'.pth')
|
389 |
model.load_state_dict(best_model_state)
|
390 |
write_results(name_model,metrics_list,start_epoch)
|
391 |
print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
|
392 |
|
393 |
+
return model
|
modules/utils.py
CHANGED
@@ -14,6 +14,46 @@ from torch.utils.data import DataLoader, Subset, ConcatDataset
|
|
14 |
import streamlit as st
|
15 |
|
16 |
|
17 |
object_dict = {
|
18 |
0: 'background',
|
19 |
1: 'task',
|
@@ -90,17 +130,26 @@ def iou(box1, box2):
|
|
90 |
return inter_area / union_area
|
91 |
|
92 |
def proportion_inside(box1, box2):
|
|
93 |
# Calculate the intersection of the two bounding boxes
|
94 |
-
inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
|
95 |
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
96 |
|
97 |
-
# Calculate the
|
98 |
-
|
99 |
-
|
100 |
-
# Calculate the proportion of box1 inside box2
|
101 |
-
if box1_area == 0:
|
102 |
return 0
|
103 |
-
proportion = inter_area / box1_area
|
104 |
|
105 |
# Ensure the proportion is at most 100%
|
106 |
return min(proportion, 1.0)
|
@@ -164,472 +213,6 @@ def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: t
|
|
164 |
return keypoints
|
165 |
|
166 |
|
167 |
-
|
168 |
-
class RandomCrop:
|
169 |
-
def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
|
170 |
-
self.crop_fraction = crop_fraction
|
171 |
-
self.min_objects = min_objects
|
172 |
-
self.new_size = new_size
|
173 |
-
|
174 |
-
def __call__(self, image, target):
|
175 |
-
new_w1, new_h1 = self.new_size
|
176 |
-
w, h = image.size
|
177 |
-
new_w = int(w * self.crop_fraction)
|
178 |
-
new_h = int(new_w*new_h1/new_w1)
i=0
|
181 |
-
for i in range(4):
|
182 |
-
if new_h >= h:
|
183 |
-
i += 0.05
|
184 |
-
new_w = int(w * (self.crop_fraction - i))
|
185 |
-
new_h = int(new_w*new_h1/new_w1)
|
186 |
-
if new_h < h:
|
187 |
-
continue
|
188 |
-
|
189 |
-
if new_h >= h:
|
190 |
-
return image, target
|
191 |
-
|
192 |
-
boxes = target["boxes"]
|
193 |
-
if 'keypoints' in target:
|
194 |
-
keypoints = target["keypoints"]
|
195 |
-
else:
|
196 |
-
keypoints = []
|
197 |
-
for i in range(len(boxes)):
|
198 |
-
keypoints.append(torch.zeros((2,3)))
|
199 |
-
|
200 |
-
|
201 |
-
# Attempt to find a suitable crop region
|
202 |
-
success = False
|
203 |
-
for _ in range(100): # Max 100 attempts to find a valid crop
|
204 |
-
top = random.randint(0, h - new_h)
|
205 |
-
left = random.randint(0, w - new_w)
|
206 |
-
crop_region = [left, top, left + new_w, top + new_h]
|
207 |
-
|
208 |
-
# Check how many objects are fully contained in this region
|
209 |
-
contained_boxes = []
|
210 |
-
contained_keypoints = []
|
211 |
-
for box, kp in zip(boxes, keypoints):
|
212 |
-
if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
|
213 |
-
# Adjust box and keypoints coordinates
|
214 |
-
new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
|
215 |
-
new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
|
216 |
-
contained_boxes.append(new_box)
|
217 |
-
contained_keypoints.append(new_kp)
|
218 |
-
|
219 |
-
if len(contained_boxes) >= self.min_objects:
|
220 |
-
success = True
|
221 |
-
break
|
222 |
-
|
223 |
-
if success:
|
224 |
-
# Perform the actual crop
|
225 |
-
image = F.crop(image, top, left, new_h, new_w)
|
226 |
-
target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
|
227 |
-
if 'keypoints' in target:
|
228 |
-
target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 4))
|
229 |
-
|
230 |
-
return image, target
|
231 |
-
|
232 |
-
|
233 |
-
class RandomFlip:
|
234 |
-
def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
|
235 |
-
"""
|
236 |
-
Initializes the RandomFlip with probabilities for flipping.
|
237 |
-
|
238 |
-
Parameters:
|
239 |
-
- h_flip_prob (float): Probability of applying a horizontal flip to the image.
|
240 |
-
- v_flip_prob (float): Probability of applying a vertical flip to the image.
|
241 |
-
"""
|
242 |
-
self.h_flip_prob = h_flip_prob
|
243 |
-
self.v_flip_prob = v_flip_prob
|
244 |
-
|
245 |
-
def __call__(self, image, target):
|
246 |
-
"""
|
247 |
-
Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
|
248 |
-
|
249 |
-
Parameters:
|
250 |
-
- image (PIL Image): The image to be flipped.
|
251 |
-
- target (dict): The target dictionary containing 'boxes' and 'keypoints'.
|
252 |
-
|
253 |
-
Returns:
|
254 |
-
- PIL Image, dict: The flipped image and its updated target dictionary.
|
255 |
-
"""
|
256 |
-
if random.random() < self.h_flip_prob:
|
257 |
-
image = F.hflip(image)
|
258 |
-
w, _ = image.size # Get the new width of the image after flip for bounding box adjustment
|
259 |
-
# Adjust bounding boxes for horizontal flip
|
260 |
-
for i, box in enumerate(target['boxes']):
|
261 |
-
xmin, ymin, xmax, ymax = box
|
262 |
-
target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)
|
263 |
-
|
264 |
-
# Adjust keypoints for horizontal flip
|
265 |
-
if 'keypoints' in target:
|
266 |
-
new_keypoints = []
|
267 |
-
for keypoints_for_object in target['keypoints']:
|
268 |
-
flipped_keypoints_for_object = []
|
269 |
-
for kp in keypoints_for_object:
|
270 |
-
x, y = kp[:2]
|
271 |
-
new_x = w - x
|
272 |
-
flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
|
273 |
-
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
274 |
-
target['keypoints'] = torch.stack(new_keypoints)
|
275 |
-
|
276 |
-
if random.random() < self.v_flip_prob:
|
277 |
-
image = F.vflip(image)
|
278 |
-
_, h = image.size # Get the new height of the image after flip for bounding box adjustment
|
279 |
-
# Adjust bounding boxes for vertical flip
|
280 |
-
for i, box in enumerate(target['boxes']):
|
281 |
-
xmin, ymin, xmax, ymax = box
|
282 |
-
target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)
|
283 |
-
|
284 |
-
# Adjust keypoints for vertical flip
|
285 |
-
if 'keypoints' in target:
|
286 |
-
new_keypoints = []
|
287 |
-
for keypoints_for_object in target['keypoints']:
|
288 |
-
flipped_keypoints_for_object = []
|
289 |
-
for kp in keypoints_for_object:
|
290 |
-
x, y = kp[:2]
|
291 |
-
new_y = h - y
|
292 |
-
flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
|
293 |
-
new_keypoints.append(torch.stack(flipped_keypoints_for_object))
|
294 |
-
target['keypoints'] = torch.stack(new_keypoints)
|
295 |
-
|
296 |
-
return image, target
|
297 |
-
|
298 |
-
|
299 |
-
class RandomRotate:
|
300 |
-
def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
|
301 |
-
"""
|
302 |
-
Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
|
303 |
-
|
304 |
-
Parameters:
|
305 |
-
- max_rotate_deg (int): Maximum degree to rotate the image.
|
306 |
-
- rotate_proba (float): Probability of applying rotation to the image.
|
307 |
-
"""
|
308 |
-
self.max_rotate_deg = max_rotate_deg
|
309 |
-
self.rotate_proba = rotate_proba
|
310 |
-
|
311 |
-
def __call__(self, image, target):
|
312 |
-
"""
|
313 |
-
Randomly rotates the image and updates the target data accordingly.
|
314 |
-
|
315 |
-
Parameters:
|
316 |
-
- image (PIL Image): The image to be rotated.
|
317 |
-
- target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.
|
318 |
-
|
319 |
-
Returns:
|
320 |
-
- PIL Image, dict: The rotated image and its updated target dictionary.
|
321 |
-
"""
|
322 |
-
if random.random() < self.rotate_proba:
|
323 |
-
angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
|
324 |
-
image = F.rotate(image, angle, expand=False, fill=200)
|
325 |
-
|
326 |
-
# Rotate bounding boxes
|
327 |
-
w, h = image.size
|
328 |
-
cx, cy = w / 2, h / 2
|
329 |
-
boxes = target["boxes"]
|
330 |
-
new_boxes = []
|
331 |
-
for box in boxes:
|
332 |
-
new_box = self.rotate_box(box, angle, cx, cy)
|
333 |
-
new_boxes.append(new_box)
|
334 |
-
target["boxes"] = torch.stack(new_boxes)
|
335 |
-
|
336 |
-
# Rotate keypoints
|
337 |
-
if 'keypoints' in target:
|
338 |
-
new_keypoints = []
|
339 |
-
for keypoints in target["keypoints"]:
|
340 |
-
new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
|
341 |
-
new_keypoints.append(new_kp)
|
342 |
-
target["keypoints"] = torch.stack(new_keypoints)
|
343 |
-
|
344 |
-
return image, target
|
345 |
-
|
346 |
-
def rotate_box(self, box, angle, cx, cy):
|
347 |
-
"""
|
348 |
-
Rotates a bounding box by a given angle around the center of the image.
|
349 |
-
"""
|
350 |
-
x1, y1, x2, y2 = box
|
351 |
-
corners = torch.tensor([
|
352 |
-
[x1, y1],
|
353 |
-
[x2, y1],
|
354 |
-
[x2, y2],
|
355 |
-
[x1, y2]
|
356 |
-
])
|
357 |
-
corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
|
358 |
-
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
359 |
-
corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
|
360 |
-
x_ = corners[:, 0]
|
361 |
-
y_ = corners[:, 1]
|
362 |
-
x_min, x_max = torch.min(x_), torch.max(x_)
|
363 |
-
y_min, y_max = torch.min(y_), torch.max(y_)
|
364 |
-
return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)
|
365 |
-
|
366 |
-
def rotate_keypoints(self, keypoints, angle, cx, cy):
|
367 |
-
"""
|
368 |
-
Rotates keypoints by a given angle around the center of the image.
|
369 |
-
"""
|
370 |
-
new_keypoints = []
|
371 |
-
for kp in keypoints:
|
372 |
-
x, y, v = kp
|
373 |
-
point = torch.tensor([x, y, 1])
|
374 |
-
M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
|
375 |
-
new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
|
376 |
-
new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
|
377 |
-
return torch.stack(new_keypoints)
|
378 |
-
|
379 |
-
def rotate_90_box(box, angle, w, h):
|
380 |
-
x1, y1, x2, y2 = box
|
381 |
-
if angle == 90:
|
382 |
-
return torch.tensor([y1,h-x2,y2,h-x1])
|
383 |
-
elif angle == 270 or angle == -90:
|
384 |
-
return torch.tensor([w-y2,x1,w-y1,x2])
|
385 |
-
else:
|
386 |
-
print("angle not supported")
|
387 |
-
|
388 |
-
def rotate_90_keypoints(kp, angle, w, h):
|
389 |
-
# Extract coordinates and visibility from each keypoint tensor
|
390 |
-
x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
|
391 |
-
x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
|
392 |
-
# Swap x and y coordinates for each keypoint
|
393 |
-
if angle == 90:
|
394 |
-
new = [[y1, h-x1, v1], [y2, h-x2, v2]]
|
395 |
-
elif angle == 270 or angle == -90:
|
396 |
-
new = [[w-y1, x1, v1], [w-y2, x2, v2]]
|
397 |
-
|
398 |
-
return torch.tensor(new, dtype=torch.float32)
|
399 |
-
|
400 |
-
|
401 |
-
def rotate_vertical(image, target):
|
402 |
-
# Rotate the image and target if the image is vertical
|
403 |
-
new_boxes = []
|
404 |
-
angle = random.choice([-90,90])
|
405 |
-
image = F.rotate(image, angle, expand=True, fill=200)
|
406 |
-
for box in target["boxes"]:
|
407 |
-
new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
|
408 |
-
new_boxes.append(new_box)
|
409 |
-
target["boxes"] = torch.stack(new_boxes)
|
410 |
-
|
411 |
-
if 'keypoints' in target:
|
412 |
-
new_kp = []
|
413 |
-
for kp in target['keypoints']:
|
414 |
-
new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
|
415 |
-
new_kp.append(new_key)
|
416 |
-
target['keypoints'] = torch.stack(new_kp)
|
417 |
-
return image, target
|
418 |
-
|
419 |
-
class BPMN_Dataset(Dataset):
|
420 |
-
def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2, flip_transform=None, rotate_transform=None, new_size=(1333,800),keep_ratio=0.1,resize=True, model_type='object'):
|
421 |
-
self.annotations = annotations
|
422 |
-
print(f"Loaded {len(self.annotations)} annotations.")
|
423 |
-
self.transform = transform
|
424 |
-
self.crop_transform = crop_transform
|
425 |
-
self.crop_prob = crop_prob
|
426 |
-
self.flip_transform = flip_transform
|
427 |
-
self.rotate_transform = rotate_transform
|
428 |
-
self.resize = resize
|
429 |
-
self.new_size = new_size
|
430 |
-
self.keep_ratio = keep_ratio
|
431 |
-
self.model_type = model_type
|
432 |
-
if model_type == 'object':
|
433 |
-
self.dict = object_dict
|
434 |
-
elif model_type == 'arrow':
|
435 |
-
self.dict = arrow_dict
|
436 |
-
self.rotate_90_proba = rotate_90_proba
|
437 |
-
|
438 |
-
def __len__(self):
|
439 |
-
return len(self.annotations)
|
440 |
-
|
441 |
-
def __getitem__(self, idx):
|
442 |
-
annotation = self.annotations[idx]
|
443 |
-
image = annotation.img.convert("RGB")
|
444 |
-
boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
|
445 |
-
labels_names = [ann for ann in annotation.categories]
|
446 |
-
|
447 |
-
#only keep the labels, boxes and keypoints that are in the class_dict
|
448 |
-
kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
|
449 |
-
boxes = boxes[kept_indices]
|
450 |
-
labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
|
451 |
-
|
452 |
-
labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)
|
453 |
-
|
454 |
-
# Initialize keypoints tensor
|
455 |
-
max_keypoints = 2
|
456 |
-
keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
|
457 |
-
|
458 |
-
ii=0
|
459 |
-
for i, ann in enumerate(annotation.annotations):
|
460 |
-
#only keep the keypoints that are in the kept indices
|
461 |
-
if i not in kept_indices:
|
462 |
-
continue
|
463 |
-
if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
|
464 |
-
# Fill the keypoints tensor for this annotation, mark as visible (1)
|
465 |
-
kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
|
466 |
-
kp = kp[:,:2]
|
467 |
-
visible = np.ones((kp.shape[0], 1), dtype=np.float32)
|
468 |
-
kp = np.hstack([kp, visible])
|
469 |
-
keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
|
470 |
-
ii += 1
|
471 |
-
|
472 |
-
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
473 |
-
|
474 |
-
if self.model_type == 'object':
|
475 |
-
target = {
|
476 |
-
"boxes": boxes,
|
477 |
-
"labels": labels_id,
|
478 |
-
#"area": area,
|
479 |
-
#"keypoints": keypoints,
|
480 |
-
}
|
481 |
-
elif self.model_type == 'arrow':
|
482 |
-
target = {
|
483 |
-
"boxes": boxes,
|
484 |
-
"labels": labels_id,
|
485 |
-
#"area": area,
|
486 |
-
"keypoints": keypoints,
|
487 |
-
}
|
488 |
-
|
489 |
-
# Randomly apply flip transform
|
490 |
-
if self.flip_transform:
|
491 |
-
image, target = self.flip_transform(image, target)
|
492 |
-
|
493 |
-
# Randomly apply rotate transform
|
494 |
-
if self.rotate_transform:
|
495 |
-
image, target = self.rotate_transform(image, target)
|
496 |
-
|
497 |
-
# Randomly apply the custom cropping transform
|
498 |
-
if self.crop_transform and random.random() < self.crop_prob:
|
499 |
-
image, target = self.crop_transform(image, target)
|
500 |
-
|
501 |
-
# Rotate vertical image
|
502 |
-
if random.random() < self.rotate_90_proba:
|
503 |
-
image, target = rotate_vertical(image, target)
|
504 |
-
|
505 |
-
if self.resize:
|
506 |
-
if random.random() < self.keep_ratio:
|
507 |
-
original_size = image.size
|
508 |
-
# Calculate scale to fit the new size while maintaining aspect ratio
|
509 |
-
scale = min(self.new_size[0] / original_size[0], self.new_size[1] / original_size[1])
|
510 |
-
new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
|
511 |
-
|
512 |
-
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), (new_scaled_size))
|
513 |
-
if 'area' in target:
|
514 |
-
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
515 |
-
|
516 |
-
if 'keypoints' in target:
|
517 |
-
for i in range(len(target['keypoints'])):
|
518 |
-
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), (new_scaled_size))
|
519 |
-
|
520 |
-
# Resize image to new scaled size
|
521 |
-
image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
|
522 |
-
|
523 |
-
# Pad the resized image to make it exactly the desired size
|
524 |
-
padding = [0, 0, self.new_size[0] - new_scaled_size[0], self.new_size[1] - new_scaled_size[1]]
|
525 |
-
image = F.pad(image, padding, fill=200, padding_mode='constant')
|
526 |
-
else:
|
527 |
-
target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
|
528 |
-
if 'area' in target:
|
529 |
-
target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
|
530 |
-
if 'keypoints' in target:
|
531 |
-
for i in range(len(target['keypoints'])):
|
532 |
-
target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
|
533 |
-
image = F.resize(image, (self.new_size[1], self.new_size[0]))
|
534 |
-
|
535 |
-
return self.transform(image), target
|
536 |
-
|
537 |
-
def collate_fn(batch):
|
538 |
-
"""
|
539 |
-
Custom collation function for DataLoader that handles batches of images and targets.
|
540 |
-
|
541 |
-
This function ensures that images are properly batched together using PyTorch's default collation,
|
542 |
-
while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
|
543 |
-
as each image might have a different number of objects detected.
|
544 |
-
|
545 |
-
Parameters:
|
546 |
-
- batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.
|
547 |
-
|
548 |
-
Returns:
|
549 |
-
- Tuple containing:
|
550 |
-
- Tensor: Batched images.
|
551 |
-
- List of dicts: Targets corresponding to each image in the batch.
|
552 |
-
"""
|
553 |
-
images, targets = zip(*batch) # Unzip the batch into separate lists for images and targets.
|
554 |
-
|
555 |
-
# Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
|
556 |
-
images = default_collate(images)
|
557 |
-
|
558 |
-
return images, targets
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
def create_loader(new_size,transformation, annotations1, annotations2=None,
|
563 |
-
batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
|
564 |
-
h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
|
565 |
-
seed=42, resize=True, keep_ratio=0.1, model_type = 'object'):
|
566 |
-
"""
|
567 |
-
Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
|
568 |
-
|
569 |
-
Parameters:
|
570 |
-
- transformation (callable): Transformation function to apply to each image (e.g., normalization).
|
571 |
-
- annotations1 (list): Primary list of annotations.
|
572 |
-
- annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
|
573 |
-
- batch_size (int): Number of images per batch.
|
574 |
-
- crop_prob (float): Probability of applying the crop transformation.
|
575 |
-
- crop_fraction (float): Fraction of the original width to use when cropping.
|
576 |
-
- min_objects (int): Minimum number of objects required to be within the crop.
|
577 |
-
- h_flip_prob (float): Probability of applying horizontal flip.
|
578 |
-
- v_flip_prob (float): Probability of applying vertical flip.
|
579 |
-
- seed (int): Seed for random number generators for reproducibility.
|
580 |
-
- resize (bool): Flag indicating whether to resize images after transformations.
|
581 |
-
|
582 |
-
Returns:
|
583 |
-
- DataLoader: Configured data loader for the dataset.
|
584 |
-
"""
|
585 |
-
|
586 |
-
# Initialize custom transformations for cropping and flipping
|
587 |
-
custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
|
588 |
-
custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
|
589 |
-
custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
|
590 |
-
|
591 |
-
# Create the primary dataset
|
592 |
-
dataset = BPMN_Dataset(
|
593 |
-
annotations=annotations1,
|
594 |
-
transform=transformation,
|
595 |
-
crop_transform=custom_crop_transform,
|
596 |
-
crop_prob=crop_prob,
|
597 |
-
rotate_90_proba=rotate_90_proba,
|
598 |
-
flip_transform=custom_flip_transform,
|
599 |
-
rotate_transform=custom_rotate_transform,
|
600 |
-
new_size=new_size,
|
601 |
-
keep_ratio=keep_ratio,
|
602 |
-
model_type=model_type,
|
603 |
-
resize=resize
|
604 |
-
)
|
605 |
-
|
606 |
-
# Optionally concatenate a second dataset
|
607 |
-
if annotations2:
|
608 |
-
dataset2 = BPMN_Dataset(
|
609 |
-
annotations=annotations2,
|
610 |
-
transform=transformation,
|
611 |
-
crop_transform=custom_crop_transform,
|
612 |
-
crop_prob=crop_prob,
|
613 |
-
rotate_90_proba=rotate_90_proba,
|
614 |
-
flip_transform=custom_flip_transform,
|
615 |
-
new_size=new_size,
|
616 |
-
keep_ratio=keep_ratio,
|
617 |
-
model_type=model_type,
|
618 |
-
resize=resize
|
619 |
-
)
|
620 |
-
dataset = ConcatDataset([dataset, dataset2]) # Concatenate the two datasets
|
621 |
-
|
622 |
-
# Set the seed for reproducibility in random operations within transformations and data loading
|
623 |
-
random.seed(seed)
|
624 |
-
torch.manual_seed(seed)
|
625 |
-
|
626 |
-
# Create the DataLoader with the dataset
|
627 |
-
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
|
628 |
-
|
629 |
-
return data_loader
|
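The augmentation classes, BPMN_Dataset and create_loader removed here now live in modules/dataset_loader.py. The piece most relevant to this commit is the keep-ratio branch of the resize step: the image is scaled so it fits inside new_size without distortion, then padded on the right and bottom with the light-grey background value 200 (the real dataset also rescales boxes and keypoints with resize_boxes / resize_keypoints). A standalone sketch of that computation, assuming a PIL image with a (width, height) size attribute:

```python
import torchvision.transforms.functional as F

def resize_keep_ratio_sketch(image, new_size=(1333, 800), fill=200):
    # Largest scale at which the image still fits entirely inside new_size.
    w, h = image.size
    scale = min(new_size[0] / w, new_size[1] / h)
    scaled = (int(w * scale), int(h * scale))
    # torchvision expects (height, width) for resize.
    image = F.resize(image, (scaled[1], scaled[0]))
    # Pad the right and bottom so the output is exactly new_size.
    padding = [0, 0, new_size[0] - scaled[0], new_size[1] - scaled[1]]
    return F.pad(image, padding, fill=fill, padding_mode='constant')
```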
630 |
-
|
631 |
-
|
632 |
-
|
633 |
def write_results(name_model,metrics_list,start_epoch):
|
634 |
with open('./results/'+ name_model+ '.txt', 'w') as f:
|
635 |
for i in range(len(metrics_list[0])):
|
|
|
14 |
import streamlit as st
|
15 |
|
16 |
|
17 |
+
"""object_dict = {
|
18 |
+
0: 'background',
|
19 |
+
1: 'task',
|
20 |
+
2: 'exclusiveGateway',
|
21 |
+
3: 'eventBasedGateway',
|
22 |
+
4: 'event',
|
23 |
+
5: 'messageEvent',
|
24 |
+
6: 'timerEvent',
|
25 |
+
7: 'dataObject',
|
26 |
+
8: 'dataStore',
|
27 |
+
9: 'pool',
|
28 |
+
10: 'lane',
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
arrow_dict = {
|
33 |
+
0: 'background',
|
34 |
+
1: 'sequenceFlow',
|
35 |
+
2: 'dataAssociation',
|
36 |
+
3: 'messageFlow',
|
37 |
+
}
|
38 |
+
|
39 |
+
class_dict = {
|
40 |
+
0: 'background',
|
41 |
+
1: 'task',
|
42 |
+
2: 'exclusiveGateway',
|
43 |
+
3: 'eventBasedGateway',
|
44 |
+
4: 'event',
|
45 |
+
5: 'messageEvent',
|
46 |
+
6: 'timerEvent',
|
47 |
+
7: 'dataObject',
|
48 |
+
8: 'dataStore',
|
49 |
+
9: 'pool',
|
50 |
+
10: 'lane',
|
51 |
+
11: 'sequenceFlow',
|
52 |
+
12: 'dataAssociation',
|
53 |
+
13: 'messageFlow',
|
54 |
+
}"""
|
55 |
+
|
56 |
+
|
57 |
object_dict = {
|
58 |
0: 'background',
|
59 |
1: 'task',
|
|
|
130 |
return inter_area / union_area
|
131 |
|
132 |
def proportion_inside(box1, box2):
|
133 |
+
# Calculate the areas of both boxes
|
134 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
135 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
136 |
+
|
137 |
+
# Determine the bigger and smaller boxes
|
138 |
+
if box1_area > box2_area:
|
139 |
+
big_box = box1
|
140 |
+
small_box = box2
|
141 |
+
else:
|
142 |
+
big_box = box2
|
143 |
+
small_box = box1
|
144 |
+
|
145 |
# Calculate the intersection of the two bounding boxes
|
146 |
+
inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
|
147 |
inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
|
148 |
|
149 |
+
# Calculate the proportion of the smaller box inside the bigger box
|
150 |
+
if (small_box[2] - small_box[0]) * (small_box[3] - small_box[1]) == 0:
|
|
|
|
|
|
|
151 |
return 0
|
152 |
+
proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
|
153 |
|
154 |
# Ensure the proportion is at most 100%
|
155 |
return min(proportion, 1.0)
|
|
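Compared with the old version, the rewritten proportion_inside no longer assumes that box1 is the inner box: it always measures how much of the smaller box lies inside the bigger one, whichever order the arguments come in. A quick illustration (boxes are [x1, y1, x2, y2]):

```python
big   = [0, 0, 100, 100]
small = [10, 10, 20, 20]          # 10x10 box fully inside the big one
half  = [95, 0, 105, 10]          # 10x10 box that sticks halfway out of the big one

print(proportion_inside(small, big))   # 1.0
print(proportion_inside(big, small))   # 1.0 as well, the argument order no longer matters
print(proportion_inside(half, big))    # 0.5
```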
|
213 |
return keypoints
|
214 |
|
215 |
|
|
|
216 |
def write_results(name_model,metrics_list,start_epoch):
|
217 |
with open('./results/'+ name_model+ '.txt', 'w') as f:
|
218 |
for i in range(len(metrics_list[0])):
|