BenjiELCA commited on
Commit
ebef706
Β·
1 Parent(s): c3f2df0

add verification of length and ready for demo

Browse files
Files changed (4) hide show
  1. app.py +84 -10
  2. modules/display.py +2 -0
  3. modules/streamlit_utils.py +1 -1
  4. modules/toXML.py +10 -2
app.py CHANGED
@@ -13,7 +13,7 @@ from glob import glob
13
  from streamlit_image_annotation import detection
14
  from modules.toXML import create_XML
15
  from modules.eval import develop_prediction, generate_data
16
- from modules.utils import class_dict
17
 
18
  def configure_page():
19
  st.set_page_config(layout="wide")
@@ -114,41 +114,114 @@ def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
114
  score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
115
  )
116
  st.balloons()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  def modify_results(percentage_text_dist_thresh=0.5):
120
  with st.expander("Method and Style modification (beta version)"):
121
- label_list = list(class_dict.values())
122
  bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
123
  for i in range(len(bboxes)):
124
  bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
125
  bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
126
  labels = [int(label) for label in st.session_state.prediction['labels']]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
128
  scale = 2000 / uploaded_image.size[0]
129
  new_labels = detection(
130
- image=uploaded_image, bboxes=bboxes, labels=labels,
131
  label_list=label_list, line_width=3, width=2000, use_space=False
132
  )
133
 
134
  if new_labels is not None:
135
- new_lab = np.array([label['label_id'] for label in new_labels])
136
-
137
  # Convert back to original format
138
  bboxes = np.array([label['bbox'] for label in new_labels])
139
  for i in range(len(bboxes)):
140
  bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
141
  bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
 
 
 
 
 
 
 
 
 
142
 
143
  scores = st.session_state.prediction['scores']
144
  keypoints = st.session_state.prediction['keypoints']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  #print('Old prediction:', st.session_state.prediction['keypoints'])
146
- boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(bboxes, new_lab, scores, keypoints, class_dict, correction=False)
147
 
148
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
149
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
150
 
151
  #print('New prediction:', st.session_state.prediction['keypoints'])
 
152
 
153
 
154
  def display_bpmn_modeler(is_mobile, screen_width):
@@ -186,15 +259,16 @@ def main():
186
 
187
  if cropped_image is not None:
188
  get_score_threshold(is_mobile)
189
- if st.button("Launch Prediction"):
190
  launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
 
191
  st.rerun()
192
 
193
  if 'prediction' in st.session_state and uploaded_file:
194
- if st.button("πŸ”„ Refresh image"):
195
- st.rerun()
196
 
197
- with st.expander("Show result"):
198
  with st.spinner('Waiting for result display...'):
199
  display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
200
 
 
13
  from streamlit_image_annotation import detection
14
  from modules.toXML import create_XML
15
  from modules.eval import develop_prediction, generate_data
16
+ from modules.utils import class_dict, object_dict
17
 
18
  def configure_page():
19
  st.set_page_config(layout="wide")
 
114
  score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
115
  )
116
  st.balloons()
117
+
118
+ def mix_new_pred(objects_pred, arrow_pred):
119
+ # Initialize the list of lists for keypoints
120
+ object_keypoints = []
121
+
122
+ # Number of boxes
123
+ num_boxes = len(objects_pred['boxes'])
124
+
125
+ # Iterate over the number of boxes
126
+ for _ in range(num_boxes):
127
+ # Each box has 2 keypoints, both initialized to [0, 0, 0]
128
+ keypoints = [[0, 0, 0], [0, 0, 0]]
129
+ object_keypoints.append(keypoints)
130
+
131
+ #concatenate the two predictions
132
+ boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
133
+ labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
134
+
135
+ return boxes, labels, keypoints
136
 
137
 
138
  def modify_results(percentage_text_dist_thresh=0.5):
139
  with st.expander("Method and Style modification (beta version)"):
140
+ label_list = list(object_dict.values())
141
  bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
142
  for i in range(len(bboxes)):
143
  bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
144
  bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
145
  labels = [int(label) for label in st.session_state.prediction['labels']]
146
+
147
+
148
+ # Filter boxes and labels where label is less than 12
149
+ ignore_labels = [6, 7]
150
+ object_bboxes = []
151
+ object_labels = []
152
+ arrow_bboxes = []
153
+ arrow_labels = []
154
+ for i in range(len(bboxes)):
155
+ if labels[i] <= 12:
156
+ object_bboxes.append(bboxes[i])
157
+ object_labels.append(labels[i])
158
+ else:
159
+ arrow_bboxes.append(bboxes[i])
160
+ arrow_labels.append(labels[i])
161
+
162
+ print('Object bboxes:', object_bboxes)
163
+ print('Object labels:', object_labels)
164
+ print('Arrow bboxes:', arrow_bboxes)
165
+ print('Arrow labels:', arrow_labels)
166
+
167
+ original_obj_len = len(object_bboxes)
168
+
169
+
170
  uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
171
  scale = 2000 / uploaded_image.size[0]
172
  new_labels = detection(
173
+ image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
174
  label_list=label_list, line_width=3, width=2000, use_space=False
175
  )
176
 
177
  if new_labels is not None:
178
+ new_lab = np.array([label['label_id'] for label in new_labels])
 
179
  # Convert back to original format
180
  bboxes = np.array([label['bbox'] for label in new_labels])
181
  for i in range(len(bboxes)):
182
  bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
183
  bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
184
+ for i in range(len(arrow_bboxes)):
185
+ arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
186
+ arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
187
+
188
+
189
+ new_bbox = np.concatenate((bboxes, arrow_bboxes))
190
+ new_lab = np.concatenate((new_lab, arrow_labels))
191
+
192
+ print('New labels:', new_lab)
193
 
194
  scores = st.session_state.prediction['scores']
195
  keypoints = st.session_state.prediction['keypoints']
196
+
197
+ #delete element in keypoints to make it match the new number of boxes
198
+ len_keypoints = len(keypoints)
199
+ keypoints = keypoints.tolist()
200
+ scores = scores.tolist()
201
+
202
+
203
+ diff = original_obj_len-len(bboxes)
204
+ if diff > 0:
205
+ for i in range(diff):
206
+ keypoints.pop(0)
207
+ scores.pop(0)
208
+ elif diff < 0:
209
+ for i in range(-diff):
210
+ keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
211
+ scores.insert(0, 0.0)
212
+
213
+ print('lenghts: ',len(bboxes), len(new_lab), len(scores), len(keypoints))
214
+ keypoints = np.array(keypoints)
215
+ scores = np.array(scores)
216
+
217
  #print('Old prediction:', st.session_state.prediction['keypoints'])
218
+ boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
219
 
220
  st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
221
  st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
222
 
223
  #print('New prediction:', st.session_state.prediction['keypoints'])
224
+ st.rerun()
225
 
226
 
227
  def display_bpmn_modeler(is_mobile, screen_width):
 
259
 
260
  if cropped_image is not None:
261
  get_score_threshold(is_mobile)
262
+ if st.button("πŸš€ Launch Prediction"):
263
  launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
264
+ st.session_state.original_prediction = st.session_state.prediction.copy()
265
  st.rerun()
266
 
267
  if 'prediction' in st.session_state and uploaded_file:
268
+ #if st.button("πŸ”„ Refresh image"):
269
+ #st.rerun()
270
 
271
+ with st.expander("Show result of prediction"):
272
  with st.spinner('Waiting for result display...'):
273
  display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
274
 
modules/display.py CHANGED
@@ -94,6 +94,8 @@ def draw_stream(image,
94
  # Draw keypoints if available
95
  if draw_keypoints and 'keypoints' in prediction:
96
  for i in range(len(prediction['keypoints'])):
 
 
97
  kp = prediction['keypoints'][i]
98
  for j in range(kp.shape[0]):
99
  if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
 
94
  # Draw keypoints if available
95
  if draw_keypoints and 'keypoints' in prediction:
96
  for i in range(len(prediction['keypoints'])):
97
+ if i >= len(prediction['keypoints']):
98
+ continue
99
  kp = prediction['keypoints'][i]
100
  for j in range(kp.shape[0]):
101
  if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
modules/streamlit_utils.py CHANGED
@@ -130,7 +130,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
130
 
131
  # Draw the annotated image with selected options
132
  annotated_image = draw_stream(
133
- np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
134
  draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
135
  write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
136
  score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
 
130
 
131
  # Draw the annotated image with selected options
132
  annotated_image = draw_stream(
133
+ np.array(image), prediction=st.session_state.original_prediction, text_predictions=st.session_state.text_pred,
134
  draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
135
  write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
136
  score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
modules/toXML.py CHANGED
@@ -13,7 +13,7 @@ def align_boxes(pred, size):
13
  for pool_index, element_indices in pred['pool_dict'].items():
14
  pool_groups[pool_index] = []
15
  for i in element_indices:
16
- if i > len(modified_pred['labels']):
17
  continue
18
  if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
19
  x1, y1, x2, y2 = modified_pred['boxes'][i]
@@ -138,7 +138,7 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
138
  pool_width = max_x - min_x + 100 # Adding padding
139
  pool_height = max_y - min_y + 100 # Adding padding
140
  #check area
141
- if pool_width < 400 or pool_height < 30:
142
  print("The pool is too small, please add more elements or increase the scale")
143
  continue
144
 
@@ -157,6 +157,9 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
157
  # Create sequence flow elements
158
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
159
  for i in keep_elements:
 
 
 
160
  if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
161
  create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
162
 
@@ -259,6 +262,8 @@ def add_diagram_edge(parent, element_id, waypoints):
259
  'id': element_id + '_di'
260
  })
261
  for x, y in waypoints:
 
 
262
  ET.SubElement(edge, 'di:waypoint', attrib={
263
  'x': str(x),
264
  'y': str(y)
@@ -312,6 +317,9 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
312
  links = data['links']
313
 
314
  for i in keep_elements:
 
 
 
315
  element_id = elements[i]
316
 
317
  if element_id is None:
 
13
  for pool_index, element_indices in pred['pool_dict'].items():
14
  pool_groups[pool_index] = []
15
  for i in element_indices:
16
+ if i >= len(modified_pred['labels']):
17
  continue
18
  if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
19
  x1, y1, x2, y2 = modified_pred['boxes'][i]
 
138
  pool_width = max_x - min_x + 100 # Adding padding
139
  pool_height = max_y - min_y + 100 # Adding padding
140
  #check area
141
+ if pool_width < 300 or pool_height < 30:
142
  print("The pool is too small, please add more elements or increase the scale")
143
  continue
144
 
 
157
  # Create sequence flow elements
158
  for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
159
  for i in keep_elements:
160
+ if i >= len(full_pred['labels']):
161
+ print("Problem with the index")
162
+ continue
163
  if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
164
  create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
165
 
 
262
  'id': element_id + '_di'
263
  })
264
  for x, y in waypoints:
265
+ if x is None or y is None:
266
+ return
267
  ET.SubElement(edge, 'di:waypoint', attrib={
268
  'x': str(x),
269
  'y': str(y)
 
317
  links = data['links']
318
 
319
  for i in keep_elements:
320
+ if i >= len(elements):
321
+ print("Problem with the index")
322
+ continue
323
  element_id = elements[i]
324
 
325
  if element_id is None: