BenjiELCA committed
Commit cc79c19 · 1 Parent(s): 5f3e9f4

ready for demo

Files changed (4)
  1. modules/display.py +10 -1
  2. modules/eval.py +29 -14
  3. modules/streamlit_utils.py +18 -32
  4. modules/toXML.py +32 -22
modules/display.py CHANGED
@@ -52,10 +52,19 @@ def draw_stream(image,
     - resize (bool): Whether to resize annotations to fit the image size.
     """
 
+    #delete the global pool if it is the only one to show
+    """if len(prediction['pool_dict'])==1 and prediction['labels'][-1]==6:
+        pool_index = list(prediction['pool_dict'])[0]
+        if len(prediction['pool_dict'][pool_index])==(len(prediction['boxes'])-1):
+            prediction['boxes'] = prediction['boxes'][:-1]
+            prediction['labels'] = prediction['labels'][:-1]
+            prediction['scores'] = prediction['scores'][:-1]
+            prediction['keypoints'] = prediction['keypoints'][:-1]
+            prediction['links'] = prediction['links'][:-1]"""
+
     # Convert image to RGB (if not already in that format)
     if prediction is None:
         image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
-
 
     image_copy = image.copy()
     scale = max(image.shape[0], image.shape[1]) / 1000
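Note that the block added above is wrapped in a string literal, so it is effectively disabled; if enabled it would strip a trailing global pool from the prediction before drawing. A minimal sketch of that filter as a plain helper (illustration only, not code from this commit; it assumes the prediction dict layout used in this diff and that label 6 denotes a pool):

def drop_trailing_global_pool(prediction, pool_label=6):
    # Remove the last entry if it is a single pool that wraps every other detected box.
    pool_dict = prediction['pool_dict']
    if len(pool_dict) == 1 and prediction['labels'][-1] == pool_label:
        pool_index = list(pool_dict)[0]
        if len(pool_dict[pool_index]) == len(prediction['boxes']) - 1:
            for key in ('boxes', 'labels', 'scores', 'keypoints', 'links'):
                prediction[key] = prediction[key][:-1]
    return prediction
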
modules/eval.py CHANGED
@@ -2,8 +2,9 @@ import numpy as np
 import torch
 from modules.utils import class_dict, object_dict, arrow_dict, find_closest_object, find_other_keypoint, filter_overlap_boxes, iou
 from tqdm import tqdm
-from modules.toXML import create_BPMN_id
+from modules.toXML import get_size_elements, calculate_pool_bounds, create_BPMN_id
 from modules.utils import is_vertical
+import streamlit as st
 
 
 def non_maximum_suppression(boxes, scores, labels=None, iou_threshold=0.5):
@@ -207,11 +208,7 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
     pool_boxes = [boxes[i] for i in pool_indices]
 
 
-    if not pool_indices:
-        # If no pools or lanes are detected, create a single pool with all elements
-        labels = np.append(labels, list(class_dict.values()).index('pool'))
-        pool_dict[len(labels) - 1] = list(range(len(boxes)))
-    else:
+    if pool_indices:
         # Initialize each pool index with an empty list
         for pool_index in pool_indices:
             pool_dict[pool_index] = []
@@ -250,10 +247,15 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_t
                 count += 1
     if len(elements_not_in_pool) > 1 and count > 1:
         new_pool_index = len(labels)
+        size_elements = get_size_elements(1)
+        box = calculate_pool_bounds(boxes,labels, elements_not_in_pool, size_elements)
+        boxes = np.append(boxes, [box], axis=0)
         labels = np.append(labels, list(class_dict.values()).index('pool'))
+        scores = np.append(scores, 1.0)
+        keypoints = np.append(keypoints, np.zeros((1, 2, 3)), axis=0)
         pool_dict[new_pool_index] = elements_not_in_pool
         print(f"Created a new pool index {new_pool_index} with elements: {elements_not_in_pool}")
-
+
     # Separate empty pools
     non_empty_pools = {k: v for k, v in pool_dict.items() if v}
     empty_pools = {k: v for k, v in pool_dict.items() if not v}
@@ -296,6 +298,8 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
     data_association_index = list(class_dict.values()).index('dataAssociation')
     data_object_index = list(class_dict.values()).index('dataObject')
     data_store_index = list(class_dict.values()).index('dataStore')
+    message_event_index = list(class_dict.values()).index('messageEvent')
+    senquence_flow_indexx = list(class_dict.values()).index('sequenceFlow')
 
     for pool_index, elements in pool_dict.items():
         print(f"Pool {pool_index} contains elements: {elements}")
@@ -326,9 +330,12 @@ def correction_labels(boxes, labels, class_dict, pool_dict, flow_links):
                 label2 = labels[id2]
                 if data_object_index in {label1, label2} or data_store_index in {label1, label2}:
                     continue
-                else:
+                elif message_event_index in {label1, label2}:
                     print('Change the link from dataAssociation to messageFlow')
                     labels[i] = message_flow_index
+                else:
+                    print('Change the link from dataAssociation to sequenceFlow')
+                    labels[i] = senquence_flow_indexx
 
     return labels, flow_links
 
@@ -355,6 +362,10 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_d
             delete_pool.append(pool_index)
             print(f"Pool {pool_index} is too small, deleting it")
 
+        if is_vertical(boxes[pool_index]):
+            delete_pool.append(pool_index)
+            print(f"Pool {pool_index} is vertical, deleting it")
+
 
     delete_elements = []
     # Check if there is an arrow that has the same links
@@ -374,7 +385,7 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_d
     delete_elements = delete_elements + delete_pool
     #delete double value in delete_elements
     delete_elements = list(set(delete_elements))
-
+
     boxes = np.delete(boxes, delete_elements, axis=0)
     labels = np.delete(labels, delete_elements)
     scores = np.delete(scores, delete_elements)
@@ -424,12 +435,16 @@ def generate_data(image, boxes, labels, scores, keypoints, flow_links, best_poin
     return data
 
 def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=True):
+    print("Lenghts 1", len(boxes), len(labels), len(scores), len(keypoints))
     pool_dict, boxes, labels, scores, keypoints = regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict)
+    print("Lenghts 2", len(boxes), len(labels), len(scores), len(keypoints))
     # Create links between elements
     flow_links, best_points = create_links(keypoints, boxes, labels, class_dict)
+
     #Correct the labels of some sequenceflow that cross multiple pool
     if correction:
         labels, flow_links = correction_labels(boxes, labels, class_dict, pool_dict, flow_links)
+
     #give a link to event to allow the creation of the BPMN id with start, indermediate and end event
     flow_links = give_link_to_element(flow_links, labels)
 
@@ -439,6 +454,7 @@ def develop_prediction(boxes, labels, scores, keypoints, class_dict, correction=
             labels[i] = list(class_dict.values()).index('dataObject')
 
     boxes,labels,scores,keypoints,flow_links,best_points,pool_dict = last_correction(boxes,labels,scores,keypoints,flow_links,best_points, pool_dict)
+    print("Lenghts 3", len(boxes), len(labels), len(scores), len(keypoints))
 
     return boxes, labels, scores, keypoints, flow_links, best_points, pool_dict
 
@@ -452,14 +468,13 @@ def full_prediction(model_object, model_arrow, image, score_threshold=0.5, iou_t
     with torch.no_grad(): # Disable gradient calculation for inference
         _, objects_pred = object_prediction(model_object, image, score_threshold=score_threshold, iou_threshold=0.1)
         _, arrow_pred = arrow_prediction(model_arrow, image, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
-
-    #print('Object prediction:', objects_pred)
-
 
+    st.session_state.arrow_pred = arrow_pred
+
     boxes, labels, scores, keypoints = mix_predictions(objects_pred, arrow_pred)
-
+
     boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(boxes, labels, scores, keypoints, class_dict)
-
+
     image = image.permute(1, 2, 0).cpu().numpy()
     image = (image * 255).astype(np.uint8)
 
 
modules/streamlit_utils.py CHANGED
@@ -310,20 +310,20 @@ def modify_results(percentage_text_dist_thresh=0.5):
         bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
     labels = [int(label) for label in st.session_state.prediction['labels']]
 
+    #st.session_state.arrow_predic = [boxes, labels, scores, keypoints]
+
+    arrow_bboxes = st.session_state.arrow_pred['boxes']
+    arrow_labels = st.session_state.arrow_pred['labels']
+    arrow_score = st.session_state.arrow_pred['scores']
+    arrow_keypoints = st.session_state.arrow_pred['keypoints']
+
     # Filter boxes and labels where label is less than 12
     object_bboxes = []
-    object_labels = []
-    arrow_bboxes = []
-    arrow_labels = []
+    object_labels = []
     for i in range(len(bboxes)):
         if labels[i] <= 12:
             object_bboxes.append(bboxes[i])
             object_labels.append(labels[i])
-        else:
-            arrow_bboxes.append(bboxes[i])
-            arrow_labels.append(labels[i])
-
-    original_obj_len = len(object_bboxes)
 
     uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
 
@@ -339,34 +339,20 @@ def modify_results(percentage_text_dist_thresh=0.5):
     for i in range(len(bboxes)):
         bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
         bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
-    for i in range(len(arrow_bboxes)):
-        arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
-        arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
+
+    object_scores = []
+    object_keypoints = []
+    for i in range(len(new_labels)):
+        object_scores.append(1.0)
+        object_keypoints.append([[0, 0, 0], [0, 0, 0]])
 
     new_bbox = np.concatenate((bboxes, arrow_bboxes))
     new_lab = np.concatenate((new_lab, arrow_labels))
+    new_scores = np.concatenate((object_scores, arrow_score))
+    new_keypoints = np.concatenate((object_keypoints, arrow_keypoints))
 
-    scores = st.session_state.prediction['scores']
-    keypoints = st.session_state.prediction['keypoints']
-
-    #delete element in keypoints to make it match the new number of boxes
-    keypoints = keypoints.tolist()
-    scores = scores.tolist()
-
-    diff = original_obj_len-len(bboxes)
-    if diff > 0:
-        for i in range(diff):
-            keypoints.pop(0)
-            scores.pop(0)
-    elif diff < 0:
-        for i in range(-diff):
-            keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
-            scores.insert(0, 0.0)
-
-    keypoints = np.array(keypoints)
-    scores = np.array(scores)
-
-    boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
+
+    boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, new_scores, new_keypoints, class_dict, correction=False)
 
     st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
     st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
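modify_results now reuses the arrow prediction that full_prediction caches in st.session_state.arrow_pred, pairing the user-edited object boxes with placeholder scores and keypoints before concatenating everything for develop_prediction. A simplified sketch of that merge (illustration only, not code from this commit; it assumes arrow_pred is a dict of numpy arrays as stored above):

import numpy as np

def merge_objects_with_arrows(object_boxes, object_labels, arrow_pred):
    # Edited object boxes have no model score or keypoints, so give them placeholders.
    object_scores = np.ones(len(object_labels))
    object_keypoints = np.zeros((len(object_labels), 2, 3))

    boxes = np.concatenate((object_boxes, arrow_pred['boxes']))
    labels = np.concatenate((object_labels, arrow_pred['labels']))
    scores = np.concatenate((object_scores, arrow_pred['scores']))
    keypoints = np.concatenate((object_keypoints, arrow_pred['keypoints']))
    return boxes, labels, scores, keypoints
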
modules/toXML.py CHANGED
@@ -79,32 +79,24 @@ def align_boxes(pred, size):
     ]
 
     # Step 3: Expand the pool bounding boxes to fit the aligned elements
-    for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
-        size_elements = get_size_elements(st.session_state.size_scale)
-        if len(keep_elements) != 0:
-            min_x, min_y, max_x, max_y = calculate_pool_bounds(modified_pred['boxes'],modified_pred['labels'], keep_elements, size_elements)
-            marge = 50
-        else:
-            if pool_index >= len(modified_pred['boxes']):
-                print("Problem with the index")
-                continue
+    if len(pred['pool_dict'])>1:
+        for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
+            size_elements = get_size_elements(st.session_state.size_scale)
+            if len(keep_elements) != 0:
+                marge = 50
+            else:
+                marge = 0
+
             min_x, min_y, max_x, max_y = modified_pred['boxes'][pool_index]
-            marge = 0
-
-        pool_width = max_x - min_x
-        pool_height = max_y - min_y
-        if pool_width < 300 or pool_height < 30:
-            error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
-            continue
+            pool_width = max_x - min_x
+            pool_height = max_y - min_y
+            if pool_width < 300 or pool_height < 30:
+                error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
+                continue
 
-        if pool_index >= len(modified_pred['boxes']):
-            new_box = np.array([min_x - marge, min_y - marge//2, min_x + pool_width + marge, min_y + pool_height + marge//2])
-            modified_pred['boxes'] = np.append(modified_pred['boxes'], [new_box], axis=0)
-        else:
             modified_pred['boxes'][pool_index] = [min_x -marge, min_y-marge//2, min_x+pool_width+marge, min_y+pool_height+marge//2]
 
-    min_left,max_right = 0, 0
-    if len(pred['pool_dict'])>1:
+    min_left,max_right = 0, 0
     for pool_index, element_indices in pred['pool_dict'].items():
         if pool_index >= len(modified_pred['boxes']):
             print(f"Problem with the index {pool_index} with a length of {len(modified_pred['boxes'])}")
@@ -152,6 +144,10 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
     })
 
     size_elements = get_size_elements(size_scale)
+
+    #if there is no pool or lane, create a pool with all elements
+    if len(full_pred['pool_dict'])==0 or (len(full_pred['pool_dict'])==1 and list(full_pred['pool_dict'])[0]==len(full_pred['labels'])):
+        full_pred, text_mapping = create_big_pool(full_pred, text_mapping)
 
     #modify the boxes positions
     old_boxes = copy.deepcopy(full_pred)
@@ -217,6 +213,20 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
 
     return pretty_xml_as_string
 
+def create_big_pool(full_pred, text_mapping):
+    # If no pools or lanes are detected, create a single pool with all elements
+    new_pool_index = len(full_pred['labels'])
+    size_elements = get_size_elements(st.session_state.size_scale)
+    elements_pool = list(range(len(full_pred['boxes'])))
+    min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred['boxes'],full_pred['labels'], elements_pool, size_elements)
+    box = [min_x-50, min_y-50, max_x+100, max_y+100]
+    full_pred['boxes'] = np.append(full_pred['boxes'], [box], axis=0)
+    full_pred['pool_dict'][new_pool_index] = elements_pool
+    full_pred['BPMN_id'].append('pool_1')
+    text_mapping['pool_1'] = 'Process'
+    print(f"Created a new pool index {new_pool_index} with elements: {elements_pool}")
+    return full_pred, text_mapping
+
 def get_size_elements(size_scale):
     size_elements = {
         'event': (size_scale*43.2, size_scale*43.2),
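create_XML now falls back to the new create_big_pool helper when the prediction has no usable pool. A condensed sketch of that guard (illustration only, mirroring the condition added in create_XML above): it treats an empty pool_dict, or a single entry whose key points one past the last label, as "no pool detected".

def needs_big_pool(full_pred):
    # No pool at all, or one placeholder pool whose key equals len(labels),
    # i.e. an entry that has no detected box behind it.
    pools = full_pred['pool_dict']
    return len(pools) == 0 or (len(pools) == 1 and list(pools)[0] == len(full_pred['labels']))

Used roughly as: if needs_big_pool(full_pred): full_pred, text_mapping = create_big_pool(full_pred, text_mapping).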