BenjiELCA commited on
Commit
3a0ed7b
1 Parent(s): 6ceb9bd

correction of infinit loop bug

Browse files
Files changed (4) hide show
  1. app.py +1 -0
  2. modules/eval.py +28 -51
  3. modules/streamlit_utils.py +53 -12
  4. modules/toXML.py +8 -4
app.py CHANGED
@@ -7,6 +7,7 @@ from modules.streamlit_utils import *
7
 
8
 
9
  def main():
 
10
  is_mobile, screen_width = configure_page()
11
  display_banner(is_mobile)
12
  display_title(is_mobile)
 
7
 
8
 
9
  def main():
10
+ st.session_state.first_run = True
11
  is_mobile, screen_width = configure_page()
12
  display_banner(is_mobile)
13
  display_title(is_mobile)
modules/eval.py CHANGED
@@ -173,20 +173,6 @@ def mix_predictions(objects_pred, arrow_pred):
173
 
174
 
175
  def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.3):
176
- """
177
- Regroups elements by the pool they belong to, and creates a single new pool for elements that are not in any existing pool.
178
- Filters out pools that have an IoU greater than the specified threshold.
179
-
180
- Parameters:
181
- - boxes (list): List of bounding boxes.
182
- - labels (list): List of labels corresponding to each bounding box.
183
- - class_dict (dict): Dictionary mapping class indices to class names.
184
- - iou_threshold (float): IoU threshold for filtering pools.
185
-
186
- Returns:
187
- - dict: A dictionary where each key is a pool's index and the value is a list of elements within that pool.
188
- """
189
- # Initialize a dictionary to hold the elements in each pool
190
  pool_dict = {}
191
 
192
  # Filter out pools with IoU greater than the threshold
@@ -197,32 +183,25 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
197
  if iou(np.array(boxes[i]), np.array(boxes[j])) > iou_threshold:
198
  to_delete.append(j)
199
 
200
-
201
  boxes = np.delete(boxes, to_delete, axis=0)
202
  labels = np.delete(labels, to_delete)
203
  scores = np.delete(scores, to_delete)
204
  keypoints = np.delete(keypoints, to_delete, axis=0)
205
 
206
- # Identify the bounding boxes of the pools
207
- pool_indices = [i for i, label in enumerate(labels) if (class_dict[label.item()] == 'pool')]
208
  pool_boxes = [boxes[i] for i in pool_indices]
209
 
210
-
211
  if pool_indices:
212
- # Initialize each pool index with an empty list
213
  for pool_index in pool_indices:
214
  pool_dict[pool_index] = []
215
 
216
- # Initialize a list for elements not in any pool
217
  elements_not_in_pool = []
218
 
219
- # Iterate over all elements
220
  for i, box in enumerate(boxes):
221
- if i in pool_indices or class_dict[labels[i]] == 'messageFlow' or class_dict[labels[i]] == 'pool':
222
- continue # Skip pool boxes themselves and messageFlow elements
223
  assigned_to_pool = False
 
 
224
  for j, pool_box in enumerate(pool_boxes):
225
- # Check if the element is within the pool's bounding box
226
  if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
227
  box[2] <= pool_box[2] and box[3] <= pool_box[3]):
228
  pool_index = pool_indices[j]
@@ -230,44 +209,44 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
230
  assigned_to_pool = True
231
  break
232
  if not assigned_to_pool:
233
- if class_dict[labels[i]] != 'messageFlow' and class_dict[labels[i]] != 'lane' or class_dict[labels[i]] != 'pool':
234
  elements_not_in_pool.append(i)
235
 
236
- if elements_not_in_pool:
237
- elements_not_in_pool_to_delete = []
238
- #find the messageflow,pool and lane in the elements_not_in_pool
239
- for i in elements_not_in_pool:
240
- if class_dict[labels[i]] == 'messageFlow' or class_dict[labels[i]] == 'lane' or class_dict[labels[i]] == 'pool':
241
- elements_not_in_pool_to_delete.append(i)
242
- #delete the messageflow from the elements_not_in_pool
243
- new_elements_not_in_pool = [i for i in elements_not_in_pool if i not in elements_not_in_pool_to_delete]
244
- count = 0
245
- for i in elements_not_in_pool:
246
- if labels[i] != list(class_dict.values()).index('sequenceFlow') or labels[i] != list(class_dict.values()).index('messageFlow'):
247
- count += 1
248
- #check if there is only sequenceFlow or messageFlow in the new pool
249
- if all([labels[i] in [list(class_dict.values()).index('sequenceFlow'),
250
- list(class_dict.values()).index('messageFlow'),
251
- list(class_dict.values()).index('dataAssociation')] for i in new_elements_not_in_pool]):
252
- print('The new pool contains only sequenceFlow or messageFlow')
253
- elif len(new_elements_not_in_pool) > 1 and count > 1:
254
  new_pool_index = len(labels)
255
- size_elements = get_size_elements(1)
256
- box = calculate_pool_bounds(boxes,labels, new_elements_not_in_pool, size_elements)
257
  boxes = np.append(boxes, [box], axis=0)
258
  labels = np.append(labels, list(class_dict.values()).index('pool'))
259
  scores = np.append(scores, 1.0)
260
  keypoints = np.append(keypoints, np.zeros((1, 2, 3)), axis=0)
261
  pool_dict[new_pool_index] = new_elements_not_in_pool
262
  print(f"Created a new pool index {new_pool_index} with elements: {new_elements_not_in_pool}")
 
 
 
 
 
 
 
 
 
 
263
 
264
- # Separate empty pools
265
  non_empty_pools = {k: v for k, v in pool_dict.items() if v}
266
  empty_pools = {k: v for k, v in pool_dict.items() if not v}
267
-
268
- # Merge non-empty pools followed by empty pools
269
  pool_dict = {**non_empty_pools, **empty_pools}
270
-
271
  return pool_dict, boxes, labels, scores, keypoints
272
 
273
 
@@ -449,8 +428,6 @@ def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=
449
 
450
  bpmn_id, pool_dict = create_BPMN_id(labels,pool_dict)
451
 
452
- print('Pool dict:', pool_dict)
453
-
454
  # Create links between elements
455
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
456
 
 
173
 
174
 
175
  def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold=0.3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  pool_dict = {}
177
 
178
  # Filter out pools with IoU greater than the threshold
 
183
  if iou(np.array(boxes[i]), np.array(boxes[j])) > iou_threshold:
184
  to_delete.append(j)
185
 
 
186
  boxes = np.delete(boxes, to_delete, axis=0)
187
  labels = np.delete(labels, to_delete)
188
  scores = np.delete(scores, to_delete)
189
  keypoints = np.delete(keypoints, to_delete, axis=0)
190
 
191
+ pool_indices = [i for i, label in enumerate(labels) if class_dict[label.item()] == 'pool']
 
192
  pool_boxes = [boxes[i] for i in pool_indices]
193
 
 
194
  if pool_indices:
 
195
  for pool_index in pool_indices:
196
  pool_dict[pool_index] = []
197
 
 
198
  elements_not_in_pool = []
199
 
 
200
  for i, box in enumerate(boxes):
 
 
201
  assigned_to_pool = False
202
+ if i in pool_indices or class_dict[labels[i]] in ['messageFlow', 'pool']:
203
+ continue
204
  for j, pool_box in enumerate(pool_boxes):
 
205
  if (box[0] >= pool_box[0] and box[1] >= pool_box[1] and
206
  box[2] <= pool_box[2] and box[3] <= pool_box[3]):
207
  pool_index = pool_indices[j]
 
209
  assigned_to_pool = True
210
  break
211
  if not assigned_to_pool:
212
+ if class_dict[labels[i]] not in ['messageFlow', 'lane', 'pool']:
213
  elements_not_in_pool.append(i)
214
 
215
+ if len(elements_not_in_pool) > 1:
216
+ new_elements_not_in_pool = [i for i in elements_not_in_pool if class_dict[labels[i]] not in ['messageFlow', 'lane', 'pool']]
217
+
218
+ # Indices of relevant classes
219
+ sequence_flow_index = list(class_dict.values()).index('sequenceFlow')
220
+ message_flow_index = list(class_dict.values()).index('messageFlow')
221
+ data_association_index = list(class_dict.values()).index('dataAssociation')
222
+
223
+ if all(labels[i] in {sequence_flow_index, message_flow_index, data_association_index} for i in new_elements_not_in_pool):
224
+ print('The new pool contains only sequenceFlow, messageFlow, or dataAssociation')
225
+
226
+ elif len(new_elements_not_in_pool) > 1:
 
 
 
 
 
 
227
  new_pool_index = len(labels)
228
+ box = calculate_pool_bounds(boxes, labels, new_elements_not_in_pool, None)
 
229
  boxes = np.append(boxes, [box], axis=0)
230
  labels = np.append(labels, list(class_dict.values()).index('pool'))
231
  scores = np.append(scores, 1.0)
232
  keypoints = np.append(keypoints, np.zeros((1, 2, 3)), axis=0)
233
  pool_dict[new_pool_index] = new_elements_not_in_pool
234
  print(f"Created a new pool index {new_pool_index} with elements: {new_elements_not_in_pool}")
235
+ else:
236
+ all_elements = [i for i in range(len(boxes))]
237
+ new_pool_index = len(labels)
238
+ box = calculate_pool_bounds(boxes, labels, all_elements, None)
239
+ boxes = np.append(boxes, [box], axis=0)
240
+ labels = np.append(labels, list(class_dict.values()).index('pool'))
241
+ scores = np.append(scores, 1.0)
242
+ keypoints = np.append(keypoints, np.zeros((1, 2, 3)), axis=0)
243
+ pool_dict[new_pool_index] = all_elements
244
+ print(f"Created a super pool {new_pool_index} with elements: {all_elements}")
245
 
 
246
  non_empty_pools = {k: v for k, v in pool_dict.items() if v}
247
  empty_pools = {k: v for k, v in pool_dict.items() if not v}
 
 
248
  pool_dict = {**non_empty_pools, **empty_pools}
249
+
250
  return pool_dict, boxes, labels, scores, keypoints
251
 
252
 
 
428
 
429
  bpmn_id, pool_dict = create_BPMN_id(labels,pool_dict)
430
 
 
 
431
  # Create links between elements
432
  flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
433
 
modules/streamlit_utils.py CHANGED
@@ -315,20 +315,22 @@ def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
315
  def modify_results(percentage_text_dist_thresh=0.5):
316
  with st.expander("Method and Style modification (beta version)"):
317
  label_list = list(object_dict.values())
318
- bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
 
 
 
 
 
319
  for i in range(len(bboxes)):
320
  bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
321
  bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
322
- labels = [int(label) for label in st.session_state.prediction['labels']]
323
-
324
- #st.session_state.arrow_predic = [boxes, labels, scores, keypoints]
325
 
326
  arrow_bboxes = st.session_state.arrow_pred['boxes']
327
  arrow_labels = st.session_state.arrow_pred['labels']
328
  arrow_score = st.session_state.arrow_pred['scores']
329
  arrow_keypoints = st.session_state.arrow_pred['keypoints']
330
 
331
- # Filter boxes and labels where label is less than 12
332
  object_bboxes = []
333
  object_labels = []
334
  for i in range(len(bboxes)):
@@ -338,22 +340,58 @@ def modify_results(percentage_text_dist_thresh=0.5):
338
 
339
  uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
340
 
341
- new_labels = detection(
342
  image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
343
  label_list=label_list, line_width=3, width=2000, use_space=False
344
  )
345
 
346
- if new_labels is not None:
347
- new_lab = np.array([label['label_id'] for label in new_labels])
 
348
  # Convert back to original format
349
- bboxes = np.array([label['bbox'] for label in new_labels])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  for i in range(len(bboxes)):
351
  bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
352
  bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
353
 
354
  object_scores = []
355
  object_keypoints = []
356
- for i in range(len(new_labels)):
357
  object_scores.append(1.0)
358
  object_keypoints.append([[0, 0, 0], [0, 0, 0]])
359
 
@@ -363,12 +401,15 @@ def modify_results(percentage_text_dist_thresh=0.5):
363
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
364
 
365
 
366
- boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict, correction=False)
367
 
368
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
369
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
370
 
371
- st.rerun()
 
 
 
372
 
373
 
374
  def display_bpmn_modeler(is_mobile, screen_width):
 
315
  def modify_results(percentage_text_dist_thresh=0.5):
316
  with st.expander("Method and Style modification (beta version)"):
317
  label_list = list(object_dict.values())
318
+ if st.session_state.prediction['labels'][-1] == 6:
319
+ bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes'][:-1]]
320
+ labels = [int(label) for label in st.session_state.prediction['labels'][:-1]]
321
+ else:
322
+ bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
323
+ labels = [int(label) for label in st.session_state.prediction['labels']]
324
  for i in range(len(bboxes)):
325
  bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
326
  bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
 
 
 
327
 
328
  arrow_bboxes = st.session_state.arrow_pred['boxes']
329
  arrow_labels = st.session_state.arrow_pred['labels']
330
  arrow_score = st.session_state.arrow_pred['scores']
331
  arrow_keypoints = st.session_state.arrow_pred['keypoints']
332
 
333
+ # Filter boxes and labels where label is less than 12 to only have objects
334
  object_bboxes = []
335
  object_labels = []
336
  for i in range(len(bboxes)):
 
340
 
341
  uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
342
 
343
+ new_data = detection(
344
  image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
345
  label_list=label_list, line_width=3, width=2000, use_space=False
346
  )
347
 
348
+ if new_data is not None:
349
+ changes = False
350
+ new_lab = np.array([data['label_id'] for data in new_data])
351
  # Convert back to original format
352
+ bboxes = np.array([data['bbox'] for data in new_data])
353
+ object_bboxes = np.array(object_bboxes)
354
+
355
+ # Order bboxes and labels
356
+ order = np.argsort(bboxes[:, 0])
357
+ bboxes = bboxes[order]
358
+ new_lab = new_lab[order]
359
+
360
+ order2 = np.argsort(object_bboxes[:, 0])
361
+ object_bboxes = object_bboxes[order2]
362
+ object_labels = np.array(object_labels)[order2]
363
+
364
+ # Make all values of bboxes integers
365
+ bboxes = bboxes.astype(int)
366
+
367
+ tolerance = 1
368
+
369
+ object_labels = np.array(object_labels)
370
+
371
+
372
+ if len(object_bboxes) == len(bboxes):
373
+ # Calculate absolute differences
374
+ abs_diff = np.abs(object_bboxes - bboxes)
375
+
376
+ for i in range(len(object_bboxes)):
377
+ for j in range(len(object_bboxes[i])):
378
+ if abs_diff[i][j] > tolerance:
379
+ changes = True
380
+ break
381
+
382
+ #check if labels are the same
383
+ if not np.array_equal(object_labels, new_lab):
384
+ changes = True
385
+ else:
386
+ changes = True
387
+
388
  for i in range(len(bboxes)):
389
  bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
390
  bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
391
 
392
  object_scores = []
393
  object_keypoints = []
394
+ for i in range(len(new_data)):
395
  object_scores.append(1.0)
396
  object_keypoints.append([[0, 0, 0], [0, 0, 0]])
397
 
 
401
  new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
402
 
403
 
404
+ boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict, correction=True)
405
 
406
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, bpmn_id, flow_links, best_points, pool_dict)
407
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
408
 
409
+ if changes:
410
+ st.rerun()
411
+
412
+
413
 
414
 
415
  def display_bpmn_modeler(is_mobile, screen_width):
modules/toXML.py CHANGED
@@ -176,7 +176,6 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
176
 
177
  # Create BPMN process elements
178
  process = []
179
- print(full_pred['pool_dict'])
180
  for idx in range (len(full_pred['pool_dict'].items())):
181
  process_id = f'process_{idx+1}'
182
  process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
@@ -544,15 +543,20 @@ def calculate_pool_bounds(boxes, labels, keep_elements, size):
544
  if element in {None, 7, 13, 14, 15}:
545
  continue
546
 
547
- x, y = boxes[i][:2]
548
- element_width, element_height = size[class_dict[labels[i]]]
 
 
 
 
549
 
 
550
  min_x = min(min_x, x)
551
  min_y = min(min_y, y)
552
  max_x = max(max_x, x + element_width)
553
  max_y = max(max_y, y + element_height)
554
 
555
- return min_x-50, min_y-50, max_x+100, max_y+50
556
 
557
 
558
 
 
176
 
177
  # Create BPMN process elements
178
  process = []
 
179
  for idx in range (len(full_pred['pool_dict'].items())):
180
  process_id = f'process_{idx+1}'
181
  process.append(ET.SubElement(definitions, 'bpmn:process', id=process_id, isExecutable='false'))
 
543
  if element in {None, 7, 13, 14, 15}:
544
  continue
545
 
546
+
547
+ if size == None:
548
+ element_width = boxes[i][2] - boxes[i][0]
549
+ element_height = boxes[i][3] - boxes[i][1]
550
+ else:
551
+ element_width, element_height = size[class_dict[labels[i]]]
552
 
553
+ x, y = boxes[i][:2]
554
  min_x = min(min_x, x)
555
  min_y = min(min_y, y)
556
  max_x = max(max_x, x + element_width)
557
  max_y = max(max_y, y + element_height)
558
 
559
+ return min_x-50, min_y-50, max_x+50, max_y+50
560
 
561
 
562