BenjiELCA committed on
Commit ca37b38 · 1 Parent(s): 9467fbe

correct several bugs and allow an automatic resize value

modules/dataset_loader.py ADDED
@@ -0,0 +1,500 @@
1
+ from torchvision.models.detection import keypointrcnn_resnet50_fpn
2
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3
+ from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
4
+ from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
5
+ import random
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import torchvision.transforms.functional as F
9
+ import numpy as np
10
+ from torch.utils.data.dataloader import default_collate
11
+ import cv2
12
+ import matplotlib.pyplot as plt
13
+ from torch.utils.data import DataLoader, Subset, ConcatDataset
14
+ import streamlit as st
15
+ from modules.utils import object_dict, arrow_dict, resize_boxes, resize_keypoints
16
+
17
+ class RandomCrop:
18
+ def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
19
+ self.crop_fraction = crop_fraction
20
+ self.min_objects = min_objects
21
+ self.new_size = new_size
22
+
23
+ def __call__(self, image, target):
24
+ new_w1, new_h1 = self.new_size
25
+ w, h = image.size
26
+ new_w = int(w * self.crop_fraction)
27
+ new_h = int(new_w*new_h1/new_w1)
28
+
29
+ shrink = 0.0  # progressively reduce the crop fraction if the crop would be taller than the image
30
+ for _ in range(4):
31
+ if new_h >= h:
32
+ shrink += 0.05
33
+ new_w = int(w * (self.crop_fraction - shrink))
34
+ new_h = int(new_w*new_h1/new_w1)
35
+ if new_h < h:
36
+ break
37
+
38
+ if new_h >= h:
39
+ return image, target
40
+
41
+ boxes = target["boxes"]
42
+ if 'keypoints' in target:
43
+ keypoints = target["keypoints"]
44
+ else:
45
+ keypoints = []
46
+ for i in range(len(boxes)):
47
+ keypoints.append(torch.zeros((2,3)))
48
+
49
+
50
+ # Attempt to find a suitable crop region
51
+ success = False
52
+ for _ in range(100): # Max 100 attempts to find a valid crop
53
+ top = random.randint(0, h - new_h)
54
+ left = random.randint(0, w - new_w)
55
+ crop_region = [left, top, left + new_w, top + new_h]
56
+
57
+ # Check how many objects are fully contained in this region
58
+ contained_boxes = []
59
+ contained_keypoints = []
60
+ for box, kp in zip(boxes, keypoints):
61
+ if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
62
+ # Adjust box and keypoints coordinates
63
+ new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
64
+ new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
65
+ contained_boxes.append(new_box)
66
+ contained_keypoints.append(new_kp)
67
+
68
+ if len(contained_boxes) >= self.min_objects:
69
+ success = True
70
+ break
71
+
72
+ if success:
73
+ # Perform the actual crop
74
+ image = F.crop(image, top, left, new_h, new_w)
75
+ target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
76
+ if 'keypoints' in target:
77
+ target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 3))
78
+
79
+ return image, target
80
+
81
+
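For illustration, a minimal sketch of how the crop transform above could be exercised on a synthetic sample (the blank image, boxes and keypoints are made-up placeholders, not values from the BPMN dataset):

from PIL import Image
import torch

img = Image.new("RGB", (2000, 1200), color=(255, 255, 255))
target = {
    "boxes": torch.tensor([[100., 100., 300., 200.],
                           [900., 600., 1100., 700.]]),
    "keypoints": torch.zeros((2, 2, 3)),
}
crop = RandomCrop(new_size=(1333, 800), crop_fraction=0.5, min_objects=1)
cropped_img, cropped_target = crop(img, target)
print(cropped_img.size, cropped_target["boxes"].shape)  # the crop keeps only fully contained boxes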
82
+ class RandomFlip:
83
+ def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
84
+ """
85
+ Initializes the RandomFlip with probabilities for flipping.
86
+
87
+ Parameters:
88
+ - h_flip_prob (float): Probability of applying a horizontal flip to the image.
89
+ - v_flip_prob (float): Probability of applying a vertical flip to the image.
90
+ """
91
+ self.h_flip_prob = h_flip_prob
92
+ self.v_flip_prob = v_flip_prob
93
+
94
+ def __call__(self, image, target):
95
+ """
96
+ Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
97
+
98
+ Parameters:
99
+ - image (PIL Image): The image to be flipped.
100
+ - target (dict): The target dictionary containing 'boxes' and 'keypoints'.
101
+
102
+ Returns:
103
+ - PIL Image, dict: The flipped image and its updated target dictionary.
104
+ """
105
+ if random.random() < self.h_flip_prob:
106
+ image = F.hflip(image)
107
+ w, _ = image.size # Get the new width of the image after flip for bounding box adjustment
108
+ # Adjust bounding boxes for horizontal flip
109
+ for i, box in enumerate(target['boxes']):
110
+ xmin, ymin, xmax, ymax = box
111
+ target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)
112
+
113
+ # Adjust keypoints for horizontal flip
114
+ if 'keypoints' in target:
115
+ new_keypoints = []
116
+ for keypoints_for_object in target['keypoints']:
117
+ flipped_keypoints_for_object = []
118
+ for kp in keypoints_for_object:
119
+ x, y = kp[:2]
120
+ new_x = w - x
121
+ flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
122
+ new_keypoints.append(torch.stack(flipped_keypoints_for_object))
123
+ target['keypoints'] = torch.stack(new_keypoints)
124
+
125
+ if random.random() < self.v_flip_prob:
126
+ image = F.vflip(image)
127
+ _, h = image.size # Get the new height of the image after flip for bounding box adjustment
128
+ # Adjust bounding boxes for vertical flip
129
+ for i, box in enumerate(target['boxes']):
130
+ xmin, ymin, xmax, ymax = box
131
+ target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)
132
+
133
+ # Adjust keypoints for vertical flip
134
+ if 'keypoints' in target:
135
+ new_keypoints = []
136
+ for keypoints_for_object in target['keypoints']:
137
+ flipped_keypoints_for_object = []
138
+ for kp in keypoints_for_object:
139
+ x, y = kp[:2]
140
+ new_y = h - y
141
+ flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
142
+ new_keypoints.append(torch.stack(flipped_keypoints_for_object))
143
+ target['keypoints'] = torch.stack(new_keypoints)
144
+
145
+ return image, target
146
+
147
+
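As a quick sanity check on the flip arithmetic above (illustrative numbers): a box [xmin, ymin, xmax, ymax] in an image of width w maps to [w - xmax, ymin, w - xmin, ymax] under a horizontal flip, so a box hugging the left edge ends up hugging the right edge with the same size.

w = 1000
xmin, ymin, xmax, ymax = 10, 50, 110, 150
flipped = [w - xmax, ymin, w - xmin, ymax]
assert flipped == [890, 50, 990, 150]  # still a 100x100 box, mirrored across the vertical axis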
148
+ class RandomRotate:
149
+ def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
150
+ """
151
+ Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
152
+
153
+ Parameters:
154
+ - max_rotate_deg (int): Maximum degree to rotate the image.
155
+ - rotate_proba (float): Probability of applying rotation to the image.
156
+ """
157
+ self.max_rotate_deg = max_rotate_deg
158
+ self.rotate_proba = rotate_proba
159
+
160
+ def __call__(self, image, target):
161
+ """
162
+ Randomly rotates the image and updates the target data accordingly.
163
+
164
+ Parameters:
165
+ - image (PIL Image): The image to be rotated.
166
+ - target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.
167
+
168
+ Returns:
169
+ - PIL Image, dict: The rotated image and its updated target dictionary.
170
+ """
171
+ if random.random() < self.rotate_proba:
172
+ angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
173
+ image = F.rotate(image, angle, expand=False, fill=200)
174
+
175
+ # Rotate bounding boxes
176
+ w, h = image.size
177
+ cx, cy = w / 2, h / 2
178
+ boxes = target["boxes"]
179
+ new_boxes = []
180
+ for box in boxes:
181
+ new_box = self.rotate_box(box, angle, cx, cy)
182
+ new_boxes.append(new_box)
183
+ target["boxes"] = torch.stack(new_boxes)
184
+
185
+ # Rotate keypoints
186
+ if 'keypoints' in target:
187
+ new_keypoints = []
188
+ for keypoints in target["keypoints"]:
189
+ new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
190
+ new_keypoints.append(new_kp)
191
+ target["keypoints"] = torch.stack(new_keypoints)
192
+
193
+ return image, target
194
+
195
+ def rotate_box(self, box, angle, cx, cy):
196
+ """
197
+ Rotates a bounding box by a given angle around the center of the image.
198
+ """
199
+ x1, y1, x2, y2 = box
200
+ corners = torch.tensor([
201
+ [x1, y1],
202
+ [x2, y1],
203
+ [x2, y2],
204
+ [x1, y2]
205
+ ])
206
+ corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
207
+ M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
208
+ corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
209
+ x_ = corners[:, 0]
210
+ y_ = corners[:, 1]
211
+ x_min, x_max = torch.min(x_), torch.max(x_)
212
+ y_min, y_max = torch.min(y_), torch.max(y_)
213
+ return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)
214
+
215
+ def rotate_keypoints(self, keypoints, angle, cx, cy):
216
+ """
217
+ Rotates keypoints by a given angle around the center of the image.
218
+ """
219
+ new_keypoints = []
220
+ for kp in keypoints:
221
+ x, y, v = kp
222
+ point = torch.tensor([x, y, 1])
223
+ M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
224
+ new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
225
+ new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
226
+ return torch.stack(new_keypoints)
227
+
228
+ def rotate_90_box(box, angle, w, h):
229
+ x1, y1, x2, y2 = box
230
+ if angle == 90:
231
+ return torch.tensor([y1,h-x2,y2,h-x1])
232
+ elif angle == 270 or angle == -90:
233
+ return torch.tensor([w-y2,x1,w-y1,x2])
234
+ else:
235
+ print("angle not supported")
236
+
237
+ def rotate_90_keypoints(kp, angle, w, h):
238
+ # Extract coordinates and visibility from each keypoint tensor
239
+ x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
240
+ x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
241
+ # Swap x and y coordinates for each keypoint
242
+ if angle == 90:
243
+ new = [[y1, h-x1, v1], [y2, h-x2, v2]]
244
+ elif angle == 270 or angle == -90:
245
+ new = [[w-y1, x1, v1], [w-y2, x2, v2]]
246
+
247
+ return torch.tensor(new, dtype=torch.float32)
248
+
249
+
250
+ def rotate_vertical(image, target):
251
+ # Rotate the image and target if the image is vertical
252
+ new_boxes = []
253
+ angle = random.choice([-90,90])
254
+ image = F.rotate(image, angle, expand=True, fill=200)
255
+ for box in target["boxes"]:
256
+ new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
257
+ new_boxes.append(new_box)
258
+ target["boxes"] = torch.stack(new_boxes)
259
+
260
+ if 'keypoints' in target:
261
+ new_kp = []
262
+ for kp in target['keypoints']:
263
+ new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
264
+ new_kp.append(new_key)
265
+ target['keypoints'] = torch.stack(new_kp)
266
+ return image, target
267
+
268
+
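A small worked check of the 90-degree box mapping above (illustrative values; since rotate_vertical rotates with expand=True before calling rotate_90_box, the h argument equals the original image width):

import torch

box = torch.tensor([100., 200., 300., 400.])             # [x1, y1, x2, y2] in a 1000-pixel-wide image
rotated = rotate_90_box(box, 90, w=600, h=1000)          # h is the original width after expand=True
assert rotated.tolist() == [200.0, 700.0, 400.0, 900.0]  # [y1, h - x2, y2, h - x1]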
269
+ import torchvision.transforms.functional as F
270
+ import torch
271
+
272
+ def resize_and_pad(image, target, new_size=(1333, 800)):
273
+ original_size = image.size
274
+ # Calculate scale to fit the new size while maintaining aspect ratio
275
+ scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
276
+ new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
277
+
278
+ # Resize image to new scaled size
279
+ image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
280
+
281
+ # Calculate padding to center the image
282
+ pad_left = (new_size[0] - new_scaled_size[0]) // 2
283
+ pad_top = (new_size[1] - new_scaled_size[1]) // 2
284
+ pad_right = new_size[0] - new_scaled_size[0] - pad_left
285
+ pad_bottom = new_size[1] - new_scaled_size[1] - pad_top
286
+
287
+ # Pad the resized image to make it exactly the desired size
288
+ image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom), fill=0, padding_mode='constant')
289
+
290
+ # Adjust bounding boxes
291
+ target['boxes'] = resize_boxes(target['boxes'], original_size, new_scaled_size)
292
+ target['boxes'][:, 0::2] += pad_left
293
+ target['boxes'][:, 1::2] += pad_top
294
+
295
+ # Adjust keypoints if they exist in the target
296
+ if 'keypoints' in target:
297
+ for i in range(len(target['keypoints'])):
298
+ target['keypoints'][i] = resize_keypoints(target['keypoints'][i], original_size, new_scaled_size)
299
+ target['keypoints'][i][:, 0] += pad_left
300
+ target['keypoints'][i][:, 1] += pad_top
301
+
302
+ return image, target
303
+
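To make the letterboxing arithmetic above concrete (illustrative sizes; resize_boxes and resize_keypoints from modules.utils are assumed to simply rescale coordinates by the width and height ratios):

original_size, new_size = (2000, 1000), (1333, 800)
scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])       # 0.6665
new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))  # (1333, 666)
pad_left = (new_size[0] - new_scaled_size[0]) // 2                                # 0
pad_top = (new_size[1] - new_scaled_size[1]) // 2                                 # 67 pixels of padding above and below
print(scale, new_scaled_size, pad_left, pad_top)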
304
+ class BPMN_Dataset(Dataset):
305
+ def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2,
306
+ flip_transform=None, rotate_transform=None, new_size=(1333,1333), keep_ratio=0.1, resize=True, model_type='object'):
307
+ self.annotations = annotations
308
+ print(f"Loaded {len(self.annotations)} annotations.")
309
+ self.transform = transform
310
+ self.crop_transform = crop_transform
311
+ self.crop_prob = crop_prob
312
+ self.flip_transform = flip_transform
313
+ self.rotate_transform = rotate_transform
314
+ self.resize = resize
315
+ self.new_size = new_size
316
+ self.keep_ratio = keep_ratio
317
+ self.model_type = model_type
318
+ if model_type == 'object':
319
+ self.dict = object_dict
320
+ elif model_type == 'arrow':
321
+ self.dict = arrow_dict
322
+ self.rotate_90_proba = rotate_90_proba
323
+
324
+ def __len__(self):
325
+ return len(self.annotations)
326
+
327
+ def __getitem__(self, idx):
328
+ annotation = self.annotations[idx]
329
+ image = annotation.img.convert("RGB")
330
+ boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
331
+ labels_names = [ann for ann in annotation.categories]
332
+
333
+ # Only keep the labels, boxes and keypoints that are in the class_dict
334
+ kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
335
+ boxes = boxes[kept_indices]
336
+ labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
337
+ # Replace any subprocess by task
338
+ labels_names = ['task' if ann == 'subProcess' else ann for ann in labels_names]
339
+
340
+ labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)
341
+
342
+ # Initialize keypoints tensor
343
+ max_keypoints = 2
344
+ keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
345
+
346
+ ii = 0
347
+ for i, ann in enumerate(annotation.annotations):
348
+ # Only keep the keypoints that are in the kept indices
349
+ if i not in kept_indices:
350
+ continue
351
+ if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
352
+ # Fill the keypoints tensor for this annotation, mark as visible (1)
353
+ kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
354
+ kp = kp[:,:2]
355
+ visible = np.ones((kp.shape[0], 1), dtype=np.float32)
356
+ kp = np.hstack([kp, visible])
357
+ keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
358
+ ii += 1
359
+
360
+ area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
361
+
362
+ if self.model_type == 'object':
363
+ target = {
364
+ "boxes": boxes,
365
+ "labels": labels_id,
366
+ #"area": area,
367
+ }
368
+ elif self.model_type == 'arrow':
369
+ target = {
370
+ "boxes": boxes,
371
+ "labels": labels_id,
372
+ #"area": area,
373
+ "keypoints": keypoints,
374
+ }
375
+
376
+ # Randomly apply flip transform
377
+ if self.flip_transform:
378
+ image, target = self.flip_transform(image, target)
379
+
380
+ # Randomly apply rotate transform
381
+ if self.rotate_transform:
382
+ image, target = self.rotate_transform(image, target)
383
+
384
+ # Randomly apply the custom cropping transform
385
+ if self.crop_transform and random.random() < self.crop_prob:
386
+ image, target = self.crop_transform(image, target)
387
+
388
+ # Rotate vertical image
389
+ if random.random() < self.rotate_90_proba:
390
+ image, target = rotate_vertical(image, target)
391
+
392
+ if self.resize:
393
+ if random.random() < self.keep_ratio:
394
+ # Center and pad the image while keeping the aspect ratio
395
+ image, target = resize_and_pad(image, target, self.new_size)
396
+ else:
397
+ target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
398
+ if 'area' in target:
399
+ target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
400
+ if 'keypoints' in target:
401
+ for i in range(len(target['keypoints'])):
402
+ target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
403
+ image = F.resize(image, (self.new_size[1], self.new_size[0]))
404
+
405
+ return self.transform(image), target
406
+
407
+
408
+ def collate_fn(batch):
409
+ """
410
+ Custom collation function for DataLoader that handles batches of images and targets.
411
+
412
+ This function ensures that images are properly batched together using PyTorch's default collation,
413
+ while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
414
+ as each image might have a different number of objects detected.
415
+
416
+ Parameters:
417
+ - batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.
418
+
419
+ Returns:
420
+ - Tuple containing:
421
+ - Tensor: Batched images.
422
+ - List of dicts: Targets corresponding to each image in the batch.
423
+ """
424
+ images, targets = zip(*batch) # Unzip the batch into separate lists for images and targets.
425
+
426
+ # Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
427
+ images = default_collate(images)
428
+
429
+ return images, targets
430
+
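A minimal illustration of why the targets are returned as a plain tuple instead of being collated: two samples with different numbers of boxes cannot be stacked into a single tensor (toy shapes below, not real data):

import torch

batch = [
    (torch.zeros(3, 800, 1333), {"boxes": torch.zeros(2, 4), "labels": torch.zeros(2, dtype=torch.int64)}),
    (torch.zeros(3, 800, 1333), {"boxes": torch.zeros(5, 4), "labels": torch.zeros(5, dtype=torch.int64)}),
]
images, targets = collate_fn(batch)
print(images.shape)                         # torch.Size([2, 3, 800, 1333])
print([t["boxes"].shape for t in targets])  # [torch.Size([2, 4]), torch.Size([5, 4])]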
431
+
432
+
433
+ def create_loader(new_size,transformation, annotations1, annotations2=None,
434
+ batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
435
+ h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
436
+ seed=42, resize=True, keep_ratio=0.1, model_type = 'object'):
437
+ """
438
+ Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
439
+
440
+ Parameters:
441
+ - transformation (callable): Transformation function to apply to each image (e.g., normalization).
442
+ - annotations1 (list): Primary list of annotations.
443
+ - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
444
+ - batch_size (int): Number of images per batch.
445
+ - crop_prob (float): Probability of applying the crop transformation.
446
+ - crop_fraction (float): Fraction of the original width to use when cropping.
447
+ - min_objects (int): Minimum number of objects required to be within the crop.
448
+ - h_flip_prob (float): Probability of applying horizontal flip.
449
+ - v_flip_prob (float): Probability of applying vertical flip.
450
+ - seed (int): Seed for random number generators for reproducibility.
451
+ - resize (bool): Flag indicating whether to resize images after transformations.
452
+
453
+ Returns:
454
+ - DataLoader: Configured data loader for the dataset.
455
+ """
456
+
457
+ # Initialize custom transformations for cropping and flipping
458
+ custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
459
+ custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
460
+ custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
461
+
462
+ # Create the primary dataset
463
+ dataset = BPMN_Dataset(
464
+ annotations=annotations1,
465
+ transform=transformation,
466
+ crop_transform=custom_crop_transform,
467
+ crop_prob=crop_prob,
468
+ rotate_90_proba=rotate_90_proba,
469
+ flip_transform=custom_flip_transform,
470
+ rotate_transform=custom_rotate_transform,
471
+ new_size=new_size,
472
+ keep_ratio=keep_ratio,
473
+ model_type=model_type,
474
+ resize=resize
475
+ )
476
+
477
+ # Optionally concatenate a second dataset
478
+ if annotations2:
479
+ dataset2 = BPMN_Dataset(
480
+ annotations=annotations2,
481
+ transform=transformation,
482
+ crop_transform=custom_crop_transform,
483
+ crop_prob=crop_prob,
484
+ rotate_90_proba=rotate_90_proba,
485
+ flip_transform=custom_flip_transform,
486
+ new_size=new_size,
487
+ keep_ratio=keep_ratio,
488
+ model_type=model_type,
489
+ resize=resize
490
+ )
491
+ dataset = ConcatDataset([dataset, dataset2]) # Concatenate the two datasets
492
+
493
+ # Set the seed for reproducibility in random operations within transformations and data loading
494
+ random.seed(seed)
495
+ torch.manual_seed(seed)
496
+
497
+ # Create the DataLoader with the dataset
498
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
499
+
500
+ return data_loader
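A hedged usage sketch of the loader factory above; load_annotations is a hypothetical placeholder for however the project parses its annotation objects, and the transform may differ from the one used in training:

import torchvision.transforms as T

transformation = T.Compose([T.ToTensor()])    # placeholder transform
annotations = load_annotations("data/train")  # hypothetical helper returning parsed annotation objects
loader = create_loader((1333, 1333), transformation, annotations,
                       batch_size=4, crop_prob=0.2, keep_ratio=0.1, model_type='arrow')
images, targets = next(iter(loader))
print(images.shape, len(targets))             # e.g. torch.Size([4, 3, 1333, 1333]) and 4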
modules/display.py CHANGED
@@ -1,4 +1,4 @@
1
- from modules.utils import draw_annotations, create_loader, class_dict, resize_boxes, resize_keypoints, find_other_keypoint
2
  import cv2
3
  import numpy as np
4
  import torch
 
1
+ from modules.utils import class_dict, resize_boxes, resize_keypoints, find_other_keypoint
2
  import cv2
3
  import numpy as np
4
  import torch
modules/eval.py CHANGED
@@ -3,8 +3,9 @@ import torch
3
  from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
4
  from tqdm import tqdm
5
  from modules.toXML import get_size_elements, calculate_pool_bounds, create_BPMN_id
6
- from modules.utils import is_vertical
7
  import streamlit as st
 
8
 
9
 
10
  def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
@@ -101,10 +102,27 @@ def object_prediction(model, image, score_threshold=0.5, iou_threshold=0.5):
101
  scores = scores[selected_boxes]
102
  labels = labels[selected_boxes]
103
 
104
  #modify the label of the sub-process to task
105
  for i in range(len(labels)):
106
  if labels[i] == list(object_dict.values()).index('subProcess'):
107
108
 
109
  prediction = {
110
  'boxes': boxes,
@@ -180,7 +198,7 @@ def mix_predictions(objects_pred, arrow_pred):
180
  return boxes, labels, scores, keypoints
181
 
182
 
183
- def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.3):
184
  pool_dict = {}
185
 
186
  # Filter out pools with IoU greater than the threshold
@@ -188,7 +206,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
188
  for i in range(len(boxes)):
189
  for j in range(i + 1, len(boxes)):
190
  if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'):
191
- if iou(np.array(boxes[i]), np.array(boxes[j])) > iou_threshold:
192
  to_delete.append(j)
193
 
194
  boxes = np.delete(boxes, to_delete, axis=0)
@@ -210,8 +228,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
210
  if i in pool_indices or class_dict[labels[i]] in ['messageFlow', 'pool']:
211
  continue
212
  for j, pool_box in enumerate(pool_boxes):
213
- if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
214
- box[2] <= pool_box[2] and box[3] <= pool_box[3]):
215
  pool_index = pool_indices[j]
216
  pool_dict[pool_index].append(i)
217
  assigned_to_pool = True
@@ -322,6 +339,53 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
322
  return labels, flow_links
323
 
324
325
 
326
  def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
327
 
@@ -368,6 +432,16 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
368
  print('delete element', i)
369
  delete_elements.append(i)
370
 
371
  #concatenate the delete_elements and the delete_pool
372
  delete_elements = delete_elements + delete_pool
373
  #delete double value in delete_elements
@@ -377,13 +451,21 @@ def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_point
377
  labels = np.delete(labels, delete_elements)
378
  scores = np.delete(scores, delete_elements)
379
  keypoints = np.delete(keypoints, delete_elements, axis=0)
380
- bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
381
  links = np.delete(links, delete_elements, axis=0)
382
  best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
383
 
 
384
  #also delete the element in the pool_dict
385
  for pool_index, elements in pool_dict.items():
386
- pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
387
 
388
  return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
389
 
@@ -420,7 +502,7 @@ def generate_data(image, boxes, labels, scores, keypoints, bpmn_id, flow_links,
420
 
421
  return data
422
 
423
- def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=True):
424
 
425
  pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
426
 
@@ -430,8 +512,7 @@ def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=
430
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
431
 
432
  #Correct the labels of some sequenceflow that cross multiple pool
433
- if correction:
434
- labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
435
 
436
  #give a link to event to allow the creation of the BPMN id with start, intermediate and end event
437
  flow_links = give_link_to_element(flow_links, labels)
 
3
  from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
4
  from tqdm import tqdm
5
  from modules.toXML import get_size_elements, calculate_pool_bounds, create_BPMN_id
6
+ from modules.utils import is_vertical, proportion_inside
7
  import streamlit as st
8
+ from builtins import dict
9
 
10
 
11
  def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
 
102
  scores = scores[selected_boxes]
103
  labels = labels[selected_boxes]
104
 
105
+ #find the outlier object that are too small by the area
106
+ obj_not_too_small = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower")
107
+ obj_not_too_big = find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=2, element_ref = ['task'], mode = "upper")
108
+
109
+ selected_object = [i for i in range(len(labels)) if i in obj_not_too_small and i in obj_not_too_big]
110
+
111
+ #selected_object = obj_not_too_small
112
+
113
+ boxes = boxes[selected_object]
114
+ scores = scores[selected_object]
115
+ labels = labels[selected_object]
116
+
117
  #modify the label of the sub-process to task
118
  for i in range(len(labels)):
119
  if labels[i] == list(object_dict.values()).index('subProcess'):
120
  labels[i] = list(object_dict.values()).index('task')
121
+ #delete all lanes and their corresponding labels and scores
122
+ lane_index = [i for i in range(len(labels)) if labels[i] == list(object_dict.values()).index('lane')]
123
+ boxes = np.delete(boxes, lane_index, axis=0)
124
+ labels = np.delete(labels, lane_index)
125
+ scores = np.delete(scores, lane_index)
126
 
127
  prediction = {
128
  'boxes': boxes,
 
198
  return boxes, labels, scores, keypoints
199
 
200
 
201
+ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.6):
202
  pool_dict = {}
203
 
204
  # Filter out pools with IoU greater than the threshold
 
206
  for i in range(len(boxes)):
207
  for j in range(i + 1, len(boxes)):
208
  if labels[i] == labels[j] and labels[i] == list(class_dict.values()).index('pool'):
209
+ if proportion_inside(boxes[i], boxes[j]) > iou_threshold:
210
  to_delete.append(j)
211
 
212
  boxes = np.delete(boxes, to_delete, axis=0)
 
228
  if i in pool_indices or class_dict[labels[i]] in ['messageFlow', 'pool']:
229
  continue
230
  for j, pool_box in enumerate(pool_boxes):
231
+ if proportion_inside(box, pool_box) > iou_threshold:
 
232
  pool_index = pool_indices[j]
233
  pool_dict[pool_index].append(i)
234
  assigned_to_pool = True
 
339
  return labels, flow_links
340
 
341
 
342
+ def find_outlier_objects_by_area(boxes, labels, class_dict, std_factor=1.5, element_ref = ['event', 'messageEvent'], mode = "lower"):
343
+ # Filter out the sizes of events, data objects, and message events
344
+ event_indices = [i for i, label in enumerate(labels) if class_dict[label] in element_ref]
345
+ event_boxes = [boxes[i] for i in event_indices]
346
+
347
+ # Calculate the areas of these typical objects
348
+ event_areas = np.array([(box[2] - box[0]) * (box[3] - box[1]) for box in event_boxes])
349
+
350
+ # Compute the mean and standard deviation for areas
351
+ mean_area = np.mean(event_areas)
352
+ std_area = np.std(event_areas)
353
+
354
+ # Define thresholds for outliers
355
+ area_lower_threshold = mean_area - std_factor * std_area
356
+ area_upper_threshold = mean_area + std_factor * std_area
357
+
358
+ # Identify indices of outliers and the ones to keep
359
+ outlier_indices = []
360
+ kept_indices = []
361
+
362
+ if mode == "lower" or mode == 'both':
363
+ #check for object that could be too small
364
+ for idx, (box, label) in enumerate(zip(boxes, labels)):
365
+ area = (box[2] - box[0]) * (box[3] - box[1])
366
+ if not (area_lower_threshold <= area):
367
+ outlier_indices.append(idx)
368
+ print(f"Element {idx} is an outlier with area {area} that is too small")
369
+ else:
370
+ kept_indices.append(idx)
371
+
372
+ if mode == "upper" or mode == 'both':
373
+ #check for object that could be too big
374
+ for idx, (box, label) in enumerate(zip(boxes, labels)):
375
+ if label == list(class_dict.values()).index('pool') or label == list(class_dict.values()).index('lane'):
376
+ kept_indices.append(idx)
377
+ continue
378
+ area = (box[2] - box[0]) * (box[3] - box[1])
379
+ if not (area_upper_threshold >= area):
380
+ outlier_indices.append(idx)
381
+ print(f"Element {idx} is an outlier with area {area} that is too big")
382
+ else:
383
+ kept_indices.append(idx)
384
+
385
+
386
+ return kept_indices
387
+
388
+
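For intuition on the area filter above (illustrative numbers): with reference areas of 1800, 2000 and 2200 px², the mean is 2000 and the standard deviation is roughly 163, so with std_factor=1.5 any element smaller than about 1755 px² would be flagged as too small.

import numpy as np

areas = np.array([1800., 2000., 2200.])    # illustrative areas of the reference elements
mean_area, std_area = areas.mean(), areas.std()
lower = mean_area - 1.5 * std_area         # ~1755: boxes below this area are treated as outliers
print(round(mean_area, 1), round(std_area, 1), round(lower, 1))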
389
 
390
  def last_correction(boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict, limit_area=10000):
391
 
 
432
  print('delete element', i)
433
  delete_elements.append(i)
434
 
435
+ #filter box that are inside a text box
436
+ """tex_pred = st.session_state.text_pred
437
+ for i in range(len(boxes)):
438
+ for j in range(len(tex_pred[0])):
439
+ #check if the box is inside the text box but if the text box is inside the box then it is not a problem
440
+ if proportion_inside(boxes[i], tex_pred[0][j]) > 0.1:
441
+ #delete_elements.append(i)
442
+ print('delete element', i)"""
443
+
444
+
445
  #concatenate the delete_elements and the delete_pool
446
  delete_elements = delete_elements + delete_pool
447
  #delete double value in delete_elements
 
451
  labels = np.delete(labels, delete_elements)
452
  scores = np.delete(scores, delete_elements)
453
  keypoints = np.delete(keypoints, delete_elements, axis=0)
454
+
455
  links = np.delete(links, delete_elements, axis=0)
456
  best_points = [point for i, point in enumerate(best_points) if i not in delete_elements]
457
 
458
+ for i in range(len(delete_pool)):
459
+ #find the bpmn_id of the pool
460
+ pool_index = bpmn_id[delete_pool[i]]
461
+ #delete the pool_index in pool_dict
462
+ del pool_dict[pool_index]
463
+
464
+ bpmn_id = [point for i, point in enumerate(bpmn_id) if i not in delete_elements]
465
+
466
  #also delete the element in the pool_dict
467
  for pool_index, elements in pool_dict.items():
468
+ pool_dict[pool_index] = [i for i in elements if i not in delete_elements]
469
 
470
  return boxes, labels, scores, keypoints, bpmn_id, links, best_points, pool_dict
471
 
 
502
 
503
  return data
504
 
505
+ def develop_prediction(boxes, labels, scores, keypoints, class_dict):
506
 
507
  pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
508
 
 
512
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
513
 
514
  #Correct the labels of some sequenceflow that cross multiple pool
515
+ labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
 
516
 
517
  #give a link to event to allow the creation of the BPMN id with start, intermediate and end event
518
  flow_links = give_link_to_element(flow_links, labels)
modules/streamlit_utils.py CHANGED
@@ -30,6 +30,8 @@ from modules.toWizard import create_wizard_file
30
  from huggingface_hub import hf_hub_download
31
  import time
32
 
33
 
34
 
35
 
@@ -440,12 +442,13 @@ def modify_results(percentage_text_dist_thresh=0.5):
440
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
441
 
442
 
443
- boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict, correction=True)
444
 
445
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
446
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
447
 
448
  if changes:
 
449
  st.rerun()
450
 
451
  return True
@@ -460,14 +463,49 @@ def display_bpmn_modeler(is_mobile, screen_width):
460
  st.session_state.size_scale, st.session_state.scale
461
  )
462
  st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
463
- display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
464
 
465
  def modeler_options(is_mobile):
466
  if not is_mobile:
467
  with st.expander("Options for BPMN modeler"):
468
  col1, col2 = st.columns(2)
469
  with col1:
470
- st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
471
  st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
472
  else:
473
  st.session_state.scale = 1.0
 
30
  from huggingface_hub import hf_hub_download
31
  import time
32
 
33
+ from modules.toXML import get_size_elements
34
+
35
 
36
 
37
 
 
442
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
443
 
444
 
445
+ boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict)
446
 
447
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
448
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
449
 
450
  if changes:
451
+ changes = False
452
  st.rerun()
453
 
454
  return True
 
463
  st.session_state.size_scale, st.session_state.scale
464
  )
465
  st.session_state.vizi_file = create_wizard_file(st.session_state.prediction.copy(), st.session_state.text_mapping)
466
+
467
+ display_bpmn_xml(st.session_state.bpmn_xml, st.session_state.vizi_file, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
468
+
469
+
470
+ def find_best_scale(pred, size_elements):
471
+ boxes = pred['boxes']
472
+ labels = pred['labels']
473
+
474
+ # Find average size of the tasks in pred
475
+ avg_size = 0
476
+ count = 0
477
+ for i in range(len(boxes)):
478
+ if class_dict[labels[i]] == 'task':
479
+ avg_size += (boxes[i][2] - boxes[i][0]) * (boxes[i][3] - boxes[i][1])
480
+ count += 1
481
+
482
+ if count == 0:
483
+ raise ValueError("No tasks found in the provided prediction.")
484
+
485
+ avg_size /= count
486
+
487
+ # Get the size of a task element from size_elements dictionary
488
+ task_size = size_elements['task']
489
+ task_area = task_size[0] * task_size[1]
490
+
491
+ # Find the best scale
492
+ best_scale = (avg_size / task_area) ** 0.5
493
+
494
+ if best_scale < 0.5:
495
+ best_scale = 0.5
496
+ elif best_scale > 1:
497
+ best_scale = 1
498
+
499
+ return best_scale
500
 
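A quick numeric sketch of the scale heuristic above (illustrative values): if the detected tasks average 28,800 px² and the reference task is 120 x 96 = 11,520 px², the raw scale is sqrt(28800 / 11520) ≈ 1.58, which the clamping then caps at 1.

avg_size, task_area = 28800, 120 * 96       # illustrative detected area vs. reference task area
best_scale = (avg_size / task_area) ** 0.5  # ~1.58
best_scale = min(max(best_scale, 0.5), 1)   # same clamping as above, so the result is 1
print(best_scale)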
501
  def modeler_options(is_mobile):
502
  if not is_mobile:
503
  with st.expander("Options for BPMN modeler"):
504
  col1, col2 = st.columns(2)
505
  with col1:
506
+ st.session_state.best_scale = find_best_scale(st.session_state.prediction, get_size_elements())
507
+ print(f"Best scale: {st.session_state.best_scale}")
508
+ st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1/st.session_state.best_scale, step=0.1)
509
  st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
510
  else:
511
  st.session_state.scale = 1.0
modules/toWizard.py CHANGED
@@ -131,7 +131,7 @@ def create_wizard_file(data, text_mapping):
131
  ET.SubElement(activity, 'subActivityFlows')
132
  ET.SubElement(activity, 'messageFlows')
133
 
134
- activityFlows = ET.SubElement(root, 'activityFlows')
135
  i=0
136
  for i, link in enumerate(data['links']):
137
  if link[0] is None and link[1] is not None and (data['BPMN_id'][i].split('_')[0] == 'event' or data['BPMN_id'][i].split('_')[0] == 'message'):
@@ -145,7 +145,7 @@ def create_wizard_file(data, text_mapping):
145
  if current_text is None or next_text is None:
146
  continue
147
  ET.SubElement(activityFlows, 'activityFlow', attrib={'activity': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
148
- i+=1
149
 
150
  ET.SubElement(root, 'participants')
151
 
 
131
  ET.SubElement(activity, 'subActivityFlows')
132
  ET.SubElement(activity, 'messageFlows')
133
 
134
+ """activityFlows = ET.SubElement(root, 'activityFlows')
135
  i=0
136
  for i, link in enumerate(data['links']):
137
  if link[0] is None and link[1] is not None and (data['BPMN_id'][i].split('_')[0] == 'event' or data['BPMN_id'][i].split('_')[0] == 'message'):
 
145
  if current_text is None or next_text is None:
146
  continue
147
  ET.SubElement(activityFlows, 'activityFlow', attrib={'activity': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
148
+ i+=1"""
149
 
150
  ET.SubElement(root, 'participants')
151
 
modules/toXML.py CHANGED
@@ -113,8 +113,8 @@ def expand_pool_bounding_boxes(modified_pred, pred, size_elements):
113
  if pool_width < 300 or pool_height < 30:
114
  error("The pool may be too small; please add more elements or increase the scale by zooming in on the image.")
115
  continue
116
-
117
- modified_pred['boxes'][position] = [min_x - marge, min_y - marge // 2, min_x + pool_width + marge, min_y + pool_height + marge // 2]
118
 
119
  # Adjust left and right boundaries of all pools
120
  def adjust_pool_boundaries(modified_pred, pred):
@@ -148,9 +148,9 @@ def align_boxes(pred, size, class_dict):
148
  pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
149
  align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
150
 
151
- #if len(pred['pool_dict']) > 1:
152
- #expand_pool_bounding_boxes(modified_pred, pred, size)
153
- #adjust_pool_boundaries(modified_pred, pred)
154
 
155
  return modified_pred['boxes']
156
 
@@ -176,10 +176,11 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
176
  'id': "simpleExample"
177
  })
178
 
 
179
  size_elements = get_size_elements(size_scale)
180
 
181
  #if there is no pool or lane, create a pool with all elements
182
- if len(full_pred['pool_dict'])==0 or (len(full_pred['pool_dict'])==1 and len(full_pred['pool_dict']['pool_1'])==len(full_pred['labels'])):
183
  full_pred, text_mapping = create_big_pool(full_pred, text_mapping)
184
 
185
  #modify the boxes positions
@@ -249,13 +250,13 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
249
  return pretty_xml_as_string
250
 
251
  # Function that creates a single pool with all elements
252
- def create_big_pool(full_pred, text_mapping):
253
  # If no pools or lanes are detected, create a single pool with all elements
254
  new_pool_index = 'pool_1'
255
  size_elements = get_size_elements(st.session_state.size_scale)
256
  elements_pool = list(range(len(full_pred['boxes'])))
257
  min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
258
- box = [min_x, min_y, max_x, max_y]
259
  full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
260
  full_pred['pool_dict'][new_pool_index] = elements_pool
261
  full_pred['BPMN_id'].append('pool_1')
@@ -264,7 +265,7 @@ def create_big_pool(full_pred, text_mapping):
264
  return full_pred, text_mapping
265
 
266
  # Function that gives the size of the elements
267
- def get_size_elements(size_scale):
268
  size_elements = {
269
  'event': (size_scale*43.2, size_scale*43.2),
270
  'task': (size_scale*120, size_scale*96),
@@ -400,8 +401,9 @@ def check_data_association(i, links, labels, keep_elements):
400
 
401
  def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
402
  waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
403
- add_diagram_edge(bpmn, element_id, waypoints)
404
-
 
405
  def check_eventBasedGateway(i, links, labels):
406
  status, links_idx = [], []
407
  for j, (k,l) in enumerate(links):
@@ -582,7 +584,7 @@ def calculate_pool_bounds(boxes, labels, keep_elements, size):
582
  max_x = max(max_x, x + element_width)
583
  max_y = max(max_y, y + element_height)
584
 
585
- return min_x-50, min_y-50, max_x+50, max_y+50
586
 
587
 
588
 
@@ -680,10 +682,16 @@ def calculate_waypoints(data, size, current_idx, source_id, target_id):
680
  if source_idx is None or target_idx is None:
681
  warning()
682
  return None
 
683
 
684
  name_source = source_id.split('_')[0]
685
  name_target = target_id.split('_')[0]
686
 
687
  # Get the position of the source and target
688
  source_x, source_y = data['boxes'][source_idx][:2]
689
  target_x, target_y = data['boxes'][target_idx][:2]
 
113
  if pool_width < 300 or pool_height < 30:
114
  error("The pool may be too small; please add more elements or increase the scale by zooming in on the image.")
115
  continue
116
+
117
+ modified_pred['boxes'][position] = [min_x - marge//20, min_y - marge, min_x + pool_width + marge//20, min_y + pool_height + marge]
118
 
119
  # Adjust left and right boundaries of all pools
120
  def adjust_pool_boundaries(modified_pred, pred):
 
148
  pool_groups = calculate_centers_and_group_by_pool(pred, class_dict)
149
  align_elements_within_pool(modified_pred, pool_groups, class_dict, size)
150
 
151
+ if len(pred['pool_dict']) > 1:
152
+ expand_pool_bounding_boxes(modified_pred, pred, size)
153
+ adjust_pool_boundaries(modified_pred, pred)
154
 
155
  return modified_pred['boxes']
156
 
 
176
  'id': "simpleExample"
177
  })
178
 
179
+
180
  size_elements = get_size_elements(size_scale)
181
 
182
  #if there is no pool or lane, create a pool with all elements
183
+ if len(full_pred['pool_dict']) == 0 or (len(full_pred['pool_dict']) == 1 and len(next(iter(full_pred['pool_dict'].values()))) == len(full_pred['labels'])):
184
  full_pred, text_mapping = create_big_pool(full_pred, text_mapping)
185
 
186
  #modify the boxes positions
 
250
  return pretty_xml_as_string
251
 
252
  # Function that creates a single pool with all elements
253
+ def create_big_pool(full_pred, text_mapping, marge=50):
254
  # If no pools or lanes are detected, create a single pool with all elements
255
  new_pool_index = 'pool_1'
256
  size_elements = get_size_elements(st.session_state.size_scale)
257
  elements_pool = list(range(len(full_pred['boxes'])))
258
  min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
259
+ box = [min_x-marge, min_y-marge, max_x+marge, max_y+marge]
260
  full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
261
  full_pred['pool_dict'][new_pool_index] = elements_pool
262
  full_pred['BPMN_id'].append('pool_1')
 
265
  return full_pred, text_mapping
266
 
267
  # Function that gives the size of the elements
268
+ def get_size_elements(size_scale=1):
269
  size_elements = {
270
  'event': (size_scale*43.2, size_scale*43.2),
271
  'task': (size_scale*120, size_scale*96),
 
401
 
402
  def create_data_Association(bpmn,data,size,element_id,current_idx,source_id,target_id):
403
  waypoints = calculate_waypoints(data, size, current_idx, source_id, target_id)
404
+ if waypoints is not None:
405
+ add_diagram_edge(bpmn, element_id, waypoints)
406
+
407
  def check_eventBasedGateway(i, links, labels):
408
  status, links_idx = [], []
409
  for j, (k,l) in enumerate(links):
 
584
  max_x = max(max_x, x + element_width)
585
  max_y = max(max_y, y + element_height)
586
 
587
+ return min_x, min_y, max_x, max_y
588
 
589
 
590
 
 
682
  if source_idx is None or target_idx is None:
683
  warning()
684
  return None
685
+
686
 
687
  name_source = source_id.split('_')[0]
688
  name_target = target_id.split('_')[0]
689
 
690
+ avoid_element = ['pool', 'sequenceFlow', 'messageFlow', 'dataAssociation']
691
+ if name_target in avoid_element or name_source in avoid_element:
692
+ warning()
693
+ return None
694
+
695
  # Get the position of the source and target
696
  source_x, source_y = data['boxes'][source_idx][:2]
697
  target_x, target_y = data['boxes'][target_idx][:2]
modules/train.py CHANGED
@@ -100,7 +100,14 @@ def prepare_model(dict,opti,learning_rate= 0.0003,model_to_load=None, model_type
100
  return model, optimizer, device
101
 
102
 
103
 
104
 
105
  def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
106
  model.train() # Set the model to evaluation mode
@@ -178,13 +185,14 @@ def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=Fal
178
 
179
 
180
  def training_model(num_epochs, model, data_loader, subset_test_loader,
181
- optimizer, model_to_load=None, change_learning_rate=5, start_key=30,
182
- batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
183
- max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
184
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
185
  information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
186
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
187
 
 
 
189
  if loss_config is None:
190
  print('No loss config found, all losses will be used.')
@@ -219,14 +227,20 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
219
  bad_test_loss = 0
220
  previous_test_loss = 1000
221
 
  print(f"Let's go training {model_type} model with {num_epochs} epochs!")
223
- print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, Flip prob: {h_flip_prob}, Rotate prob: {rotate_proba}, Blur prob: {blur_prob}")
 
224
 
225
  for epoch in range(num_epochs):
226
 
227
- if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>1:
228
  learning_rate = 0.7*learning_rate
229
  optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
230
  print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
231
  bad_test_loss = 0
232
  if epoch>0 and (epoch)==start_key:
@@ -315,24 +329,19 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
315
 
316
 
317
  # Evaluate the model on the test set
318
- if eval_metric != 'loss':
319
- avg_test_loss = 0
320
- labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
321
- print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
322
- if eval_metric == 'all':
323
- avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
324
- print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
325
  if eval_metric == 'loss':
326
  labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
327
  avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
328
  print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
329
 
330
  print(f"Time: {time.time() - start:.2f} [s]")
331
 
332
-
333
- if epoch>0 and (epoch)%start_key == 0:
334
- print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")
335
-
336
  if eval_metric == 'f1_score':
337
  metric_used = f1_score
338
  elif eval_metric == 'precision':
@@ -357,15 +366,14 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
357
  epoch_test_loss.append(avg_test_loss)
358
 
359
  name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
 
360
 
361
  if same >=1 :
362
- metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
363
  torch.save(best_model_state, './models/'+ name_model +'.pth')
364
  write_results(name_model,metrics_list,start_epoch)
365
  break
366
 
367
  if (epoch+1+start_epoch) % 5 == 0:
368
- metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
369
  torch.save(best_model_state, './models/'+ name_model +'.pth')
370
  model.load_state_dict(best_model_state)
371
  write_results(name_model,metrics_list,start_epoch)
@@ -375,12 +383,11 @@ def training_model(num_epochs, model, data_loader, subset_test_loader,
375
  previous_test_loss = avg_test_loss
376
 
377
 
378
- print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an f1_score of {best_metrics:.4f}")
379
  if best_model_state:
380
- metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
381
  torch.save(best_model_state, './models/'+ name_model +'.pth')
382
  model.load_state_dict(best_model_state)
383
  write_results(name_model,metrics_list,start_epoch)
384
  print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
385
 
386
- return model, metrics_list
 
100
  return model, optimizer, device
101
 
102
 
103
+ import copy
104
+ from torch.optim import AdamW
105
+ import time
106
+ from modules.train import write_results
107
 
108
+ import torch
109
+ import numpy as np
110
+ from tqdm import tqdm
111
 
112
  def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
113
  model.train() # Keep the model in training mode so that the loss dictionary is returned during evaluation
 
185
 
186
 
187
  def training_model(num_epochs, model, data_loader, subset_test_loader,
188
+ optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
189
+ parameters=None, blur_prob=0.02,
 
190
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
191
  information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
192
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
193
 
194
+ # Set the model to training mode
195
+ model.train()
196
 
197
  if loss_config is None:
198
  print('No loss config found, all losses will be used.')
 
227
  bad_test_loss = 0
228
  previous_test_loss = 1000
229
 
230
+ if parameters is not None:
231
+ batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
232
+
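Since the unpacking above relies on dictionary insertion order, the parameters dict must be built with its keys in exactly this order; a sketch of such a dict (the values are illustrative defaults, not the ones used for the released models):

parameters = {
    'batch_size': 4,
    'crop_prob': 0.2,
    'rotate_90_proba': 0.2,
    'h_flip_prob': 0.3,
    'v_flip_prob': 0.3,
    'max_rotate_deg': 20,
    'rotate_proba': 0.3,
    'keep_ratio': 0.1,
}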
233
  print(f"Let's go training {model_type} model with {num_epochs} epochs!")
234
+ if parameters is not None:
235
+ print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
236
 
237
  for epoch in range(num_epochs):
238
 
239
+ if (epoch>0 and (epoch)%change_learning_rate == 0) or bad_test_loss>=3:
240
  learning_rate = 0.7*learning_rate
241
  optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
242
+ if best_model_state is not None:
243
+ model.load_state_dict(best_model_state)
244
  print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
245
  bad_test_loss = 0
246
  if epoch>0 and (epoch)==start_key:
 
329
 
330
 
331
  # Evaluate the model on the test set
332
  if eval_metric == 'loss':
333
  labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
334
  avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
335
  print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
336
+ else:
337
+ avg_test_loss = 0
338
+ labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
339
+ print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
340
+ avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
341
+ print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
342
 
343
  print(f"Time: {time.time() - start:.2f} [s]")
344
 
345
  if eval_metric == 'f1_score':
346
  metric_used = f1_score
347
  elif eval_metric == 'precision':
 
366
  epoch_test_loss.append(avg_test_loss)
367
 
368
  name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
369
+ metrics_list = [epoch_avg_losses,epoch_avg_loss_classifier,epoch_avg_loss_box_reg,epoch_avg_loss_objectness,epoch_avg_loss_rpn_box_reg,epoch_avg_loss_keypoints,epoch_precision,epoch_recall,epoch_f1_score,epoch_test_loss]
370
 
371
  if same >=1 :
 
372
  torch.save(best_model_state, './models/'+ name_model +'.pth')
373
  write_results(name_model,metrics_list,start_epoch)
374
  break
375
 
376
  if (epoch+1+start_epoch) % 5 == 0:
 
377
  torch.save(best_model_state, './models/'+ name_model +'.pth')
378
  model.load_state_dict(best_model_state)
379
  write_results(name_model,metrics_list,start_epoch)
 
383
  previous_test_loss = avg_test_loss
384
 
385
 
386
+ print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
387
  if best_model_state:
 
388
  torch.save(best_model_state, './models/'+ name_model +'.pth')
389
  model.load_state_dict(best_model_state)
390
  write_results(name_model,metrics_list,start_epoch)
391
  print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}")
392
 
393
+ return model
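Since `best_model_state` is a plain state dict written with `torch.save`, reloading a finished run later presumably looks like the following sketch (the file name is a placeholder for whatever `name_model` resolved to during training):

# Hedged sketch, not part of the commit: reload a checkpoint saved by the loop above.
checkpoint_path = './models/<name_model>.pth'   # placeholder file name
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()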
modules/utils.py CHANGED
@@ -14,6 +14,46 @@ from torch.utils.data import DataLoader, Subset, ConcatDataset
14
  import streamlit as st
15
 
16
 
17
  object_dict = {
18
  0: 'background',
19
  1: 'task',
@@ -90,17 +130,26 @@ def iou(box1, box2):
90
  return inter_area / union_area
91
 
92
  def proportion_inside(box1, box2):
93
  # Calculate the intersection of the two bounding boxes
94
- inter_box = [max(box1[0], box2[0]), max(box1[1], box2[1]), min(box1[2], box2[2]), min(box1[3], box2[3])]
95
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
96
 
97
- # Calculate the area of box1
98
- box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
99
-
100
- # Calculate the proportion of box1 inside box2
101
- if box1_area == 0:
102
  return 0
103
- proportion = inter_area / box1_area
104
 
105
  # Ensure the proportion is at most 100%
106
  return min(proportion, 1.0)
@@ -164,472 +213,6 @@ def resize_keypoints(keypoints: np.ndarray, original_size: tuple, target_size: t
164
  return keypoints
165
 
166
 
167
-
168
- class RandomCrop:
169
- def __init__(self, new_size=(1333,800),crop_fraction=0.5, min_objects=4):
170
- self.crop_fraction = crop_fraction
171
- self.min_objects = min_objects
172
- self.new_size = new_size
173
-
174
- def __call__(self, image, target):
175
- new_w1, new_h1 = self.new_size
176
- w, h = image.size
177
- new_w = int(w * self.crop_fraction)
178
- new_h = int(new_w*new_h1/new_w1)
179
-
180
- i=0
181
- for i in range(4):
182
- if new_h >= h:
183
- i += 0.05
184
- new_w = int(w * (self.crop_fraction - i))
185
- new_h = int(new_w*new_h1/new_w1)
186
- if new_h < h:
187
- continue
188
-
189
- if new_h >= h:
190
- return image, target
191
-
192
- boxes = target["boxes"]
193
- if 'keypoints' in target:
194
- keypoints = target["keypoints"]
195
- else:
196
- keypoints = []
197
- for i in range(len(boxes)):
198
- keypoints.append(torch.zeros((2,3)))
199
-
200
-
201
- # Attempt to find a suitable crop region
202
- success = False
203
- for _ in range(100): # Max 100 attempts to find a valid crop
204
- top = random.randint(0, h - new_h)
205
- left = random.randint(0, w - new_w)
206
- crop_region = [left, top, left + new_w, top + new_h]
207
-
208
- # Check how many objects are fully contained in this region
209
- contained_boxes = []
210
- contained_keypoints = []
211
- for box, kp in zip(boxes, keypoints):
212
- if box[0] >= crop_region[0] and box[1] >= crop_region[1] and box[2] <= crop_region[2] and box[3] <= crop_region[3]:
213
- # Adjust box and keypoints coordinates
214
- new_box = box - torch.tensor([crop_region[0], crop_region[1], crop_region[0], crop_region[1]])
215
- new_kp = kp - torch.tensor([crop_region[0], crop_region[1], 0])
216
- contained_boxes.append(new_box)
217
- contained_keypoints.append(new_kp)
218
-
219
- if len(contained_boxes) >= self.min_objects:
220
- success = True
221
- break
222
-
223
- if success:
224
- # Perform the actual crop
225
- image = F.crop(image, top, left, new_h, new_w)
226
- target["boxes"] = torch.stack(contained_boxes) if contained_boxes else torch.zeros((0, 4))
227
- if 'keypoints' in target:
228
- target["keypoints"] = torch.stack(contained_keypoints) if contained_keypoints else torch.zeros((0, 2, 4))
229
-
230
- return image, target
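(This class moves to modules/dataset_loader.py in this commit.) A quick numeric check of the crop-window size it computes, assuming a 2000x1400 source image and the defaults above:

# Illustration only: crop window for a 2000x1400 image with new_size=(1333, 800)
# and crop_fraction=0.5.
new_w1, new_h1 = 1333, 800
w, h, crop_fraction = 2000, 1400, 0.5
new_w = int(w * crop_fraction)           # 1000
new_h = int(new_w * new_h1 / new_w1)     # 600 -> keeps the 1333:800 aspect ratio
print(new_w, new_h)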
231
-
232
-
233
- class RandomFlip:
234
- def __init__(self, h_flip_prob=0.5, v_flip_prob=0.5):
235
- """
236
- Initializes the RandomFlip with probabilities for flipping.
237
-
238
- Parameters:
239
- - h_flip_prob (float): Probability of applying a horizontal flip to the image.
240
- - v_flip_prob (float): Probability of applying a vertical flip to the image.
241
- """
242
- self.h_flip_prob = h_flip_prob
243
- self.v_flip_prob = v_flip_prob
244
-
245
- def __call__(self, image, target):
246
- """
247
- Applies random horizontal and/or vertical flip to the image and updates target data accordingly.
248
-
249
- Parameters:
250
- - image (PIL Image): The image to be flipped.
251
- - target (dict): The target dictionary containing 'boxes' and 'keypoints'.
252
-
253
- Returns:
254
- - PIL Image, dict: The flipped image and its updated target dictionary.
255
- """
256
- if random.random() < self.h_flip_prob:
257
- image = F.hflip(image)
258
- w, _ = image.size # Get the new width of the image after flip for bounding box adjustment
259
- # Adjust bounding boxes for horizontal flip
260
- for i, box in enumerate(target['boxes']):
261
- xmin, ymin, xmax, ymax = box
262
- target['boxes'][i] = torch.tensor([w - xmax, ymin, w - xmin, ymax], dtype=torch.float32)
263
-
264
- # Adjust keypoints for horizontal flip
265
- if 'keypoints' in target:
266
- new_keypoints = []
267
- for keypoints_for_object in target['keypoints']:
268
- flipped_keypoints_for_object = []
269
- for kp in keypoints_for_object:
270
- x, y = kp[:2]
271
- new_x = w - x
272
- flipped_keypoints_for_object.append(torch.tensor([new_x, y] + list(kp[2:])))
273
- new_keypoints.append(torch.stack(flipped_keypoints_for_object))
274
- target['keypoints'] = torch.stack(new_keypoints)
275
-
276
- if random.random() < self.v_flip_prob:
277
- image = F.vflip(image)
278
- _, h = image.size # Get the new height of the image after flip for bounding box adjustment
279
- # Adjust bounding boxes for vertical flip
280
- for i, box in enumerate(target['boxes']):
281
- xmin, ymin, xmax, ymax = box
282
- target['boxes'][i] = torch.tensor([xmin, h - ymax, xmax, h - ymin], dtype=torch.float32)
283
-
284
- # Adjust keypoints for vertical flip
285
- if 'keypoints' in target:
286
- new_keypoints = []
287
- for keypoints_for_object in target['keypoints']:
288
- flipped_keypoints_for_object = []
289
- for kp in keypoints_for_object:
290
- x, y = kp[:2]
291
- new_y = h - y
292
- flipped_keypoints_for_object.append(torch.tensor([x, new_y] + list(kp[2:])))
293
- new_keypoints.append(torch.stack(flipped_keypoints_for_object))
294
- target['keypoints'] = torch.stack(new_keypoints)
295
-
296
- return image, target
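(Also moved to modules/dataset_loader.py.) The box arithmetic for a horizontal flip can be sanity-checked in isolation:

# Illustration only: horizontal flip of one box on a 100-pixel-wide image.
w = 100
xmin, ymin, xmax, ymax = 10, 20, 30, 40
print([w - xmax, ymin, w - xmin, ymax])   # [70, 20, 90, 40]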
297
-
298
-
299
- class RandomRotate:
300
- def __init__(self, max_rotate_deg=20, rotate_proba=0.3):
301
- """
302
- Initializes the RandomRotate with a maximum rotation angle and probability of rotating.
303
-
304
- Parameters:
305
- - max_rotate_deg (int): Maximum degree to rotate the image.
306
- - rotate_proba (float): Probability of applying rotation to the image.
307
- """
308
- self.max_rotate_deg = max_rotate_deg
309
- self.rotate_proba = rotate_proba
310
-
311
- def __call__(self, image, target):
312
- """
313
- Randomly rotates the image and updates the target data accordingly.
314
-
315
- Parameters:
316
- - image (PIL Image): The image to be rotated.
317
- - target (dict): The target dictionary containing 'boxes', 'labels', and 'keypoints'.
318
-
319
- Returns:
320
- - PIL Image, dict: The rotated image and its updated target dictionary.
321
- """
322
- if random.random() < self.rotate_proba:
323
- angle = random.uniform(-self.max_rotate_deg, self.max_rotate_deg)
324
- image = F.rotate(image, angle, expand=False, fill=200)
325
-
326
- # Rotate bounding boxes
327
- w, h = image.size
328
- cx, cy = w / 2, h / 2
329
- boxes = target["boxes"]
330
- new_boxes = []
331
- for box in boxes:
332
- new_box = self.rotate_box(box, angle, cx, cy)
333
- new_boxes.append(new_box)
334
- target["boxes"] = torch.stack(new_boxes)
335
-
336
- # Rotate keypoints
337
- if 'keypoints' in target:
338
- new_keypoints = []
339
- for keypoints in target["keypoints"]:
340
- new_kp = self.rotate_keypoints(keypoints, angle, cx, cy)
341
- new_keypoints.append(new_kp)
342
- target["keypoints"] = torch.stack(new_keypoints)
343
-
344
- return image, target
345
-
346
- def rotate_box(self, box, angle, cx, cy):
347
- """
348
- Rotates a bounding box by a given angle around the center of the image.
349
- """
350
- x1, y1, x2, y2 = box
351
- corners = torch.tensor([
352
- [x1, y1],
353
- [x2, y1],
354
- [x2, y2],
355
- [x1, y2]
356
- ])
357
- corners = torch.cat((corners, torch.ones(corners.shape[0], 1)), dim=1)
358
- M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
359
- corners = torch.matmul(torch.tensor(M, dtype=torch.float32), corners.T).T
360
- x_ = corners[:, 0]
361
- y_ = corners[:, 1]
362
- x_min, x_max = torch.min(x_), torch.max(x_)
363
- y_min, y_max = torch.min(y_), torch.max(y_)
364
- return torch.tensor([x_min, y_min, x_max, y_max], dtype=torch.float32)
365
-
366
- def rotate_keypoints(self, keypoints, angle, cx, cy):
367
- """
368
- Rotates keypoints by a given angle around the center of the image.
369
- """
370
- new_keypoints = []
371
- for kp in keypoints:
372
- x, y, v = kp
373
- point = torch.tensor([x, y, 1])
374
- M = cv2.getRotationMatrix2D((cx, cy), angle, 1)
375
- new_point = torch.matmul(torch.tensor(M, dtype=torch.float32), point)
376
- new_keypoints.append(torch.tensor([new_point[0], new_point[1], v], dtype=torch.float32))
377
- return torch.stack(new_keypoints)
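(Also moved to modules/dataset_loader.py.) A minimal standalone check of the rotation math used in rotate_keypoints, rotating one point about the image centre with OpenCV's affine matrix:

# Illustration only: rotate the point (80, 50) by 20 degrees around (50, 50).
import cv2
import torch

M = cv2.getRotationMatrix2D((50.0, 50.0), 20.0, 1)        # 2x3 affine matrix
point = torch.tensor([80.0, 50.0, 1.0])                    # homogeneous coordinates
print(torch.matmul(torch.tensor(M, dtype=torch.float32), point))  # rotated (x, y)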
378
-
379
- def rotate_90_box(box, angle, w, h):
380
- x1, y1, x2, y2 = box
381
- if angle == 90:
382
- return torch.tensor([y1,h-x2,y2,h-x1])
383
- elif angle == 270 or angle == -90:
384
- return torch.tensor([w-y2,x1,w-y1,x2])
385
- else:
386
- print("angle not supported")
387
-
388
- def rotate_90_keypoints(kp, angle, w, h):
389
- # Extract coordinates and visibility from each keypoint tensor
390
- x1, y1, v1 = kp[0][0], kp[0][1], kp[0][2]
391
- x2, y2, v2 = kp[1][0], kp[1][1], kp[1][2]
392
- # Swap x and y coordinates for each keypoint
393
- if angle == 90:
394
- new = [[y1, h-x1, v1], [y2, h-x2, v2]]
395
- elif angle == 270 or angle == -90:
396
- new = [[w-y1, x1, v1], [w-y2, x2, v2]]
397
-
398
- return torch.tensor(new, dtype=torch.float32)
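A quick numeric check of the 90-degree box mapping defined just above (illustrative values only):

# Illustration only: what rotate_90_box computes for angle == 90.
x1, y1, x2, y2 = 10.0, 20.0, 30.0, 40.0
w, h = 100, 50   # size of the rotated image, as passed from rotate_vertical
print([y1, h - x2, y2, h - x1])   # [20.0, 20.0, 40.0, 40.0]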
399
-
400
-
401
- def rotate_vertical(image, target):
402
- # Rotate the image and target if the image is vertical
403
- new_boxes = []
404
- angle = random.choice([-90,90])
405
- image = F.rotate(image, angle, expand=True, fill=200)
406
- for box in target["boxes"]:
407
- new_box = rotate_90_box(box, angle, image.size[0], image.size[1])
408
- new_boxes.append(new_box)
409
- target["boxes"] = torch.stack(new_boxes)
410
-
411
- if 'keypoints' in target:
412
- new_kp = []
413
- for kp in target['keypoints']:
414
- new_key = rotate_90_keypoints(kp, angle, image.size[0], image.size[1])
415
- new_kp.append(new_key)
416
- target['keypoints'] = torch.stack(new_kp)
417
- return image, target
418
-
419
- class BPMN_Dataset(Dataset):
420
- def __init__(self, annotations, transform=None, crop_transform=None, crop_prob=0.3, rotate_90_proba=0.2, flip_transform=None, rotate_transform=None, new_size=(1333,800),keep_ratio=0.1,resize=True, model_type='object'):
421
- self.annotations = annotations
422
- print(f"Loaded {len(self.annotations)} annotations.")
423
- self.transform = transform
424
- self.crop_transform = crop_transform
425
- self.crop_prob = crop_prob
426
- self.flip_transform = flip_transform
427
- self.rotate_transform = rotate_transform
428
- self.resize = resize
429
- self.new_size = new_size
430
- self.keep_ratio = keep_ratio
431
- self.model_type = model_type
432
- if model_type == 'object':
433
- self.dict = object_dict
434
- elif model_type == 'arrow':
435
- self.dict = arrow_dict
436
- self.rotate_90_proba = rotate_90_proba
437
-
438
- def __len__(self):
439
- return len(self.annotations)
440
-
441
- def __getitem__(self, idx):
442
- annotation = self.annotations[idx]
443
- image = annotation.img.convert("RGB")
444
- boxes = torch.tensor(np.array(annotation.boxes_ltrb), dtype=torch.float32)
445
- labels_names = [ann for ann in annotation.categories]
446
-
447
- #only keep the labels, boxes and keypoints that are in the class_dict
448
- kept_indices = [i for i, ann in enumerate(annotation.categories) if ann in self.dict.values()]
449
- boxes = boxes[kept_indices]
450
- labels_names = [ann for i, ann in enumerate(labels_names) if i in kept_indices]
451
-
452
- labels_id = torch.tensor([(list(self.dict.values()).index(ann)) for ann in labels_names], dtype=torch.int64)
453
-
454
- # Initialize keypoints tensor
455
- max_keypoints = 2
456
- keypoints = torch.zeros((len(labels_id), max_keypoints, 3), dtype=torch.float32)
457
-
458
- ii=0
459
- for i, ann in enumerate(annotation.annotations):
460
- #only keep the keypoints that are in the kept indices
461
- if i not in kept_indices:
462
- continue
463
- if ann.category in ["sequenceFlow", "messageFlow", "dataAssociation"]:
464
- # Fill the keypoints tensor for this annotation, mark as visible (1)
465
- kp = np.array(ann.keypoints, dtype=np.float32).reshape(-1, 3)
466
- kp = kp[:,:2]
467
- visible = np.ones((kp.shape[0], 1), dtype=np.float32)
468
- kp = np.hstack([kp, visible])
469
- keypoints[ii, :kp.shape[0], :] = torch.tensor(kp, dtype=torch.float32)
470
- ii += 1
471
-
472
- area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
473
-
474
- if self.model_type == 'object':
475
- target = {
476
- "boxes": boxes,
477
- "labels": labels_id,
478
- #"area": area,
479
- #"keypoints": keypoints,
480
- }
481
- elif self.model_type == 'arrow':
482
- target = {
483
- "boxes": boxes,
484
- "labels": labels_id,
485
- #"area": area,
486
- "keypoints": keypoints,
487
- }
488
-
489
- # Randomly apply flip transform
490
- if self.flip_transform:
491
- image, target = self.flip_transform(image, target)
492
-
493
- # Randomly apply rotate transform
494
- if self.rotate_transform:
495
- image, target = self.rotate_transform(image, target)
496
-
497
- # Randomly apply the custom cropping transform
498
- if self.crop_transform and random.random() < self.crop_prob:
499
- image, target = self.crop_transform(image, target)
500
-
501
- # Rotate vertical image
502
- if random.random() < self.rotate_90_proba:
503
- image, target = rotate_vertical(image, target)
504
-
505
- if self.resize:
506
- if random.random() < self.keep_ratio:
507
- original_size = image.size
508
- # Calculate scale to fit the new size while maintaining aspect ratio
509
- scale = min(self.new_size[0] / original_size[0], self.new_size[1] / original_size[1])
510
- new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
511
-
512
- target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), (new_scaled_size))
513
- if 'area' in target:
514
- target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
515
-
516
- if 'keypoints' in target:
517
- for i in range(len(target['keypoints'])):
518
- target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), (new_scaled_size))
519
-
520
- # Resize image to new scaled size
521
- image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
522
-
523
- # Pad the resized image to make it exactly the desired size
524
- padding = [0, 0, self.new_size[0] - new_scaled_size[0], self.new_size[1] - new_scaled_size[1]]
525
- image = F.pad(image, padding, fill=200, padding_mode='constant')
526
- else:
527
- target['boxes'] = resize_boxes(target['boxes'], (image.size[0],image.size[1]), self.new_size)
528
- if 'area' in target:
529
- target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0])
530
- if 'keypoints' in target:
531
- for i in range(len(target['keypoints'])):
532
- target['keypoints'][i] = resize_keypoints(target['keypoints'][i], (image.size[0],image.size[1]), self.new_size)
533
- image = F.resize(image, (self.new_size[1], self.new_size[0]))
534
-
535
- return self.transform(image), target
536
-
537
- def collate_fn(batch):
538
- """
539
- Custom collation function for DataLoader that handles batches of images and targets.
540
-
541
- This function ensures that images are properly batched together using PyTorch's default collation,
542
- while keeping the targets (such as bounding boxes and labels) in a list of dictionaries,
543
- as each image might have a different number of objects detected.
544
-
545
- Parameters:
546
- - batch (list): A list of tuples, where each tuple contains an image and its corresponding target dictionary.
547
-
548
- Returns:
549
- - Tuple containing:
550
- - Tensor: Batched images.
551
- - List of dicts: Targets corresponding to each image in the batch.
552
- """
553
- images, targets = zip(*batch) # Unzip the batch into separate lists for images and targets.
554
-
555
- # Batch images using the default collate function which handles tensors, numpy arrays, numbers, etc.
556
- images = default_collate(images)
557
-
558
- return images, targets
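A hedged sketch of what one batch looks like when this collate_fn is plugged into a DataLoader (the shapes assume the default 1333x800 resize; `dataset` stands in for a BPMN_Dataset instance):

# Hedged sketch, not part of the commit. `dataset` is assumed to be a BPMN_Dataset.
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
images, targets = next(iter(loader))
print(images.shape)                      # e.g. torch.Size([4, 3, 800, 1333])
print(len(targets), targets[0].keys())   # 4 dict_keys(['boxes', 'labels', ...])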
559
-
560
-
561
-
562
- def create_loader(new_size,transformation, annotations1, annotations2=None,
563
- batch_size=4, crop_prob=0.2, crop_fraction=0.7, min_objects=3,
564
- h_flip_prob=0.3, v_flip_prob=0.3, max_rotate_deg=20, rotate_90_proba=0.2, rotate_proba=0.3,
565
- seed=42, resize=True, keep_ratio=0.1, model_type = 'object'):
566
- """
567
- Creates a DataLoader for BPMN datasets with optional transformations and concatenation of two datasets.
568
-
569
- Parameters:
570
- - transformation (callable): Transformation function to apply to each image (e.g., normalization).
571
- - annotations1 (list): Primary list of annotations.
572
- - annotations2 (list, optional): Secondary list of annotations to concatenate with the first.
573
- - batch_size (int): Number of images per batch.
574
- - crop_prob (float): Probability of applying the crop transformation.
575
- - crop_fraction (float): Fraction of the original width to use when cropping.
576
- - min_objects (int): Minimum number of objects required to be within the crop.
577
- - h_flip_prob (float): Probability of applying horizontal flip.
578
- - v_flip_prob (float): Probability of applying vertical flip.
579
- - seed (int): Seed for random number generators for reproducibility.
580
- - resize (bool): Flag indicating whether to resize images after transformations.
581
-
582
- Returns:
583
- - DataLoader: Configured data loader for the dataset.
584
- """
585
-
586
- # Initialize custom transformations for cropping and flipping
587
- custom_crop_transform = RandomCrop(new_size,crop_fraction, min_objects)
588
- custom_flip_transform = RandomFlip(h_flip_prob, v_flip_prob)
589
- custom_rotate_transform = RandomRotate(max_rotate_deg, rotate_proba)
590
-
591
- # Create the primary dataset
592
- dataset = BPMN_Dataset(
593
- annotations=annotations1,
594
- transform=transformation,
595
- crop_transform=custom_crop_transform,
596
- crop_prob=crop_prob,
597
- rotate_90_proba=rotate_90_proba,
598
- flip_transform=custom_flip_transform,
599
- rotate_transform=custom_rotate_transform,
600
- new_size=new_size,
601
- keep_ratio=keep_ratio,
602
- model_type=model_type,
603
- resize=resize
604
- )
605
-
606
- # Optionally concatenate a second dataset
607
- if annotations2:
608
- dataset2 = BPMN_Dataset(
609
- annotations=annotations2,
610
- transform=transformation,
611
- crop_transform=custom_crop_transform,
612
- crop_prob=crop_prob,
613
- rotate_90_proba=rotate_90_proba,
614
- flip_transform=custom_flip_transform,
615
- new_size=new_size,
616
- keep_ratio=keep_ratio,
617
- model_type=model_type,
618
- resize=resize
619
- )
620
- dataset = ConcatDataset([dataset, dataset2]) # Concatenate the two datasets
621
-
622
- # Set the seed for reproducibility in random operations within transformations and data loading
623
- random.seed(seed)
624
- torch.manual_seed(seed)
625
-
626
- # Create the DataLoader with the dataset
627
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
628
-
629
- return data_loader
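A hedged example of a typical call (the function also moves to modules/dataset_loader.py); `annotations_train` stands in for whatever annotation list the project loads elsewhere, and the transformation is assumed to be a simple ToTensor:

# Hedged sketch, not part of the commit.
import torchvision.transforms as T

transformation = T.ToTensor()
train_loader = create_loader((1333, 800), transformation, annotations_train,
                             batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
                             max_rotate_deg=20, rotate_proba=0.3, rotate_90_proba=0.2,
                             keep_ratio=0.1, resize=True, model_type='object')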
630
-
631
-
632
-
633
  def write_results(name_model,metrics_list,start_epoch):
634
  with open('./results/'+ name_model+ '.txt', 'w') as f:
635
  for i in range(len(metrics_list[0])):
 
14
  import streamlit as st
15
 
16
 
17
+ """object_dict = {
18
+ 0: 'background',
19
+ 1: 'task',
20
+ 2: 'exclusiveGateway',
21
+ 3: 'eventBasedGateway',
22
+ 4: 'event',
23
+ 5: 'messageEvent',
24
+ 6: 'timerEvent',
25
+ 7: 'dataObject',
26
+ 8: 'dataStore',
27
+ 9: 'pool',
28
+ 10: 'lane',
29
+ }
30
+
31
+
32
+ arrow_dict = {
33
+ 0: 'background',
34
+ 1: 'sequenceFlow',
35
+ 2: 'dataAssociation',
36
+ 3: 'messageFlow',
37
+ }
38
+
39
+ class_dict = {
40
+ 0: 'background',
41
+ 1: 'task',
42
+ 2: 'exclusiveGateway',
43
+ 3: 'eventBasedGateway',
44
+ 4: 'event',
45
+ 5: 'messageEvent',
46
+ 6: 'timerEvent',
47
+ 7: 'dataObject',
48
+ 8: 'dataStore',
49
+ 9: 'pool',
50
+ 10: 'lane',
51
+ 11: 'sequenceFlow',
52
+ 12: 'dataAssociation',
53
+ 13: 'messageFlow',
54
+ }"""
55
+
56
+
57
  object_dict = {
58
  0: 'background',
59
  1: 'task',
 
130
  return inter_area / union_area
131
 
132
  def proportion_inside(box1, box2):
133
+ # Calculate the areas of both boxes
134
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
135
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
136
+
137
+ # Determine the bigger and smaller boxes
138
+ if box1_area > box2_area:
139
+ big_box = box1
140
+ small_box = box2
141
+ else:
142
+ big_box = box2
143
+ small_box = box1
144
+
145
  # Calculate the intersection of the two bounding boxes
146
+ inter_box = [max(small_box[0], big_box[0]), max(small_box[1], big_box[1]), min(small_box[2], big_box[2]), min(small_box[3], big_box[3])]
147
  inter_area = max(0, inter_box[2] - inter_box[0]) * max(0, inter_box[3] - inter_box[1])
148
 
149
+ # Calculate the proportion of the smaller box inside the bigger box
150
+ if (small_box[2] - small_box[0]) * (small_box[3] - small_box[1]) == 0:
151
  return 0
152
+ proportion = inter_area / ((small_box[2] - small_box[0]) * (small_box[3] - small_box[1]))
153
 
154
  # Ensure the proportion is at most 100%
155
  return min(proportion, 1.0)
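A small worked example of the rewritten function: the result is now symmetric in its arguments, because the proportion is always measured for the smaller box.

# Illustration: a 20x50 box that is half inside a 100x100 box.
big = [0, 0, 100, 100]
small = [90, 0, 110, 50]
print(proportion_inside(small, big))   # 0.5
print(proportion_inside(big, small))   # 0.5 -- argument order no longer matters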
 
213
  return keypoints
214
 
215
 
216
  def write_results(name_model,metrics_list,start_epoch):
217
  with open('./results/'+ name_model+ '.txt', 'w') as f:
218
  for i in range(len(metrics_list[0])):