pool align width and modification works better
- modules/eval.py +10 -3
- modules/streamlit_utils.py +207 -3
- modules/toXML.py +59 -31
modules/eval.py
CHANGED
```diff
@@ -239,8 +239,15 @@ def regroup_elements_by_pool(boxes, labels, scores, keypoints, class_dict, iou_threshold
             elements_not_in_pool.append(i)
 
     if elements_not_in_pool:
-        new_pool_index = len(labels)
+        elements_not_in_pool_to_delete = []
+        new_pool_index = len(labels)
         labels = np.append(labels, list(class_dict.values()).index('pool'))
+        # find the messageFlow, pool and lane elements in elements_not_in_pool
+        for i in elements_not_in_pool:
+            if class_dict[labels[i]] == 'messageFlow' or class_dict[labels[i]] == 'lane' or class_dict[labels[i]] == 'pool':
+                elements_not_in_pool_to_delete.append(i)
+        # delete the messageFlow, pool and lane elements from elements_not_in_pool
+        elements_not_in_pool = [i for i in elements_not_in_pool if i not in elements_not_in_pool_to_delete]
         pool_dict[new_pool_index] = elements_not_in_pool
 
     # Separate empty pools
@@ -330,7 +337,8 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict
     for pool_index, elements in pool_dict.items():
         if all([labels[i] in [list(class_dict.values()).index('messageFlow'),
                               list(class_dict.values()).index('sequenceFlow'),
-                              list(class_dict.values()).index('dataAssociation')] for i in elements]):
+                              list(class_dict.values()).index('dataAssociation'),
+                              list(class_dict.values()).index('lane')] for i in elements]):
             if len(elements) > 0:
                 delete_pool.append(pool_index)
                 print(f"Pool {pool_index} contains only arrow elements, deleting it")
@@ -339,7 +347,6 @@ def last_correction(boxes, labels, scores, keypoints, links, best_points, pool_dict
         if pool_index < len(boxes):
             pool = boxes[pool_index]
             area = (pool[2] - pool[0]) * (pool[3] - pool[1])
-            print("area: ", area)
             if len(pool_dict) > 1 and area < limit_area:
                 delete_pool.append(pool_index)
                 print(f"Pool {pool_index} is too small, deleting it")
```
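In short: regroup_elements_by_pool now filters structural elements (messageFlow, lane, pool) out of the catch-all pool it creates for orphaned elements. A minimal, runnable sketch of that filtering step with a toy class_dict (the label values below are illustrative, not the app's real mapping):

```python
import numpy as np

# Toy label mapping — illustrative only, not the app's real class_dict.
class_dict = {0: 'task', 1: 'pool', 2: 'lane', 3: 'messageFlow'}

labels = np.array([0, 3, 2, 0])      # two tasks, one messageFlow, one lane
elements_not_in_pool = [0, 1, 2, 3]  # indices that fell outside every detected pool

# Structural labels must not be swallowed by the catch-all pool,
# mirroring the diff's filtering step.
to_delete = [i for i in elements_not_in_pool
             if class_dict[labels[i]] in ('messageFlow', 'lane', 'pool')]
elements_not_in_pool = [i for i in elements_not_in_pool if i not in to_delete]

# Create the catch-all pool for what remains.
new_pool_index = len(labels)
labels = np.append(labels, list(class_dict.values()).index('pool'))
pool_dict = {new_pool_index: elements_not_in_pool}
print(pool_dict)  # {4: [0, 3]} — only the two tasks join the new pool
```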
modules/streamlit_utils.py
CHANGED
```diff
@@ -4,8 +4,6 @@ import torch
 from torchvision.transforms import functional as F
 import gc
 import psutil
-import copy
-import xml.etree.ElementTree as ET
 import numpy as np
 from pathlib import Path
 import gdown
@@ -17,6 +15,17 @@ from modules.eval import full_prediction
 from modules.train import get_faster_rcnn_model, get_arrow_model
 from streamlit_image_comparison import image_comparison
 
+from streamlit_image_annotation import detection
+from modules.toXML import create_XML
+from modules.eval import develop_prediction, generate_data
+from modules.utils import class_dict, object_dict
+
+from modules.htlm_webpage import display_bpmn_xml
+from streamlit_cropper import st_cropper
+from streamlit_image_select import image_select
+from streamlit_js_eval import streamlit_js_eval
+
+
 
 
 def get_memory_usage():
@@ -186,4 +195,199 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile
 
 @st.cache_data
 def get_image(uploaded_file):
-    return Image.open(uploaded_file).convert('RGB')
+    return Image.open(uploaded_file).convert('RGB')
+
+
+
+def configure_page():
+    st.set_page_config(layout="wide")
+    screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
+    is_mobile = screen_width is not None and screen_width < 800
+    return is_mobile, screen_width
+
+def display_banner(is_mobile):
+    if is_mobile:
+        st.image("./images/banner_mobile.png", use_column_width=True)
+    else:
+        st.image("./images/banner_desktop.png", use_column_width=True)
+
+def display_title(is_mobile):
+    title = "Welcome on the BPMN AI model recognition app"
+    if is_mobile:
+        title = "Welcome on the mobile version of BPMN AI model recognition app"
+    st.title(title)
+
+def display_sidebar():
+    sidebar()
+
+def initialize_session_state():
+    if 'pool_bboxes' not in st.session_state:
+        st.session_state.pool_bboxes = []
+    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
+        clear_memory()
+        load_models()
+
+def load_example_image():
+    with st.expander("Use example images"):
+        img_selected = image_select(
+            "If you have no image and just want to test the demo, click on one of these images",
+            ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
+            captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"],
+            index=0,
+            use_container_width=False,
+            return_value="original"
+        )
+    return img_selected
+
+def load_user_image(img_selected, is_mobile):
+    if img_selected == './images/none.jpg':
+        img_selected = None
+
+    if img_selected is not None:
+        uploaded_file = img_selected
+    else:
+        if is_mobile:
+            uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"], accept_multiple_files=False)
+        else:
+            col1, col2 = st.columns(2)
+            with col1:
+                uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
+
+    return uploaded_file
+
+def display_image(uploaded_file, screen_width, is_mobile):
+
+    with st.spinner('Waiting for image display...'):
+        original_image = get_image(uploaded_file)
+        resized_image = original_image.resize((screen_width // 2, int(original_image.height * (screen_width // 2) / original_image.width)))
+
+        if not is_mobile:
+            cropped_image = crop_image(resized_image, original_image)
+        else:
+            st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
+            cropped_image = original_image
+
+    return cropped_image
+
+def crop_image(resized_image, original_image):
+    marge = 10
+    cropped_box = st_cropper(
+        resized_image,
+        realtime_update=True,
+        box_color='#0000FF',
+        return_type='box',
+        should_resize_image=False,
+        default_coords=(marge, resized_image.width - marge, marge, resized_image.height - marge)
+    )
+    scale_x = original_image.width / resized_image.width
+    scale_y = original_image.height / resized_image.height
+    x0, y0, x1, y1 = int(cropped_box['left'] * scale_x), int(cropped_box['top'] * scale_y), int((cropped_box['left'] + cropped_box['width']) * scale_x), int((cropped_box['top'] + cropped_box['height']) * scale_y)
+    cropped_image = original_image.crop((x0, y0, x1, y1))
+    return cropped_image
+
+def get_score_threshold(is_mobile):
+    col1, col2 = st.columns(2)
+    with col1:
+        st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5 if not is_mobile else 0.6, step=0.05)
+
+def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
+    st.session_state.crop_image = cropped_image
+    with st.spinner('Processing...'):
+        perform_inference(
+            st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
+            score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
+        )
+    st.balloons()
+
+
+def modify_results(percentage_text_dist_thresh=0.5):
+    with st.expander("Method and Style modification (beta version)"):
+        label_list = list(object_dict.values())
+        bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
+        for i in range(len(bboxes)):
+            bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
+            bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
+        labels = [int(label) for label in st.session_state.prediction['labels']]
+
+        # Filter boxes and labels where label is less than 12
+        object_bboxes = []
+        object_labels = []
+        arrow_bboxes = []
+        arrow_labels = []
+        for i in range(len(bboxes)):
+            if labels[i] <= 12:
+                object_bboxes.append(bboxes[i])
+                object_labels.append(labels[i])
+            else:
+                arrow_bboxes.append(bboxes[i])
+                arrow_labels.append(labels[i])
+
+        original_obj_len = len(object_bboxes)
+
+        uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
+
+        new_labels = detection(
+            image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
+            label_list=label_list, line_width=3, width=2000, use_space=False
+        )
+
+        if new_labels is not None:
+            new_lab = np.array([label['label_id'] for label in new_labels])
+            # Convert back to original format
+            bboxes = np.array([label['bbox'] for label in new_labels])
+            for i in range(len(bboxes)):
+                bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
+                bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
+            for i in range(len(arrow_bboxes)):
+                arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
+                arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
+
+            new_bbox = np.concatenate((bboxes, arrow_bboxes))
+            new_lab = np.concatenate((new_lab, arrow_labels))
+
+            scores = st.session_state.prediction['scores']
+            keypoints = st.session_state.prediction['keypoints']
+
+            # delete element in keypoints to make it match the new number of boxes
+            keypoints = keypoints.tolist()
+            scores = scores.tolist()
+
+            diff = original_obj_len - len(bboxes)
+            if diff > 0:
+                for i in range(diff):
+                    keypoints.pop(0)
+                    scores.pop(0)
+            elif diff < 0:
+                for i in range(-diff):
+                    keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
+                    scores.insert(0, 0.0)
+
+            keypoints = np.array(keypoints)
+            scores = np.array(scores)
+
+            boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
+
+            st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
+            st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
+
+            st.rerun()
+
+
+def display_bpmn_modeler(is_mobile, screen_width):
+    with st.spinner('Waiting for BPMN modeler...'):
+        st.session_state.bpmn_xml = create_XML(
+            st.session_state.prediction.copy(), st.session_state.text_mapping,
+            st.session_state.size_scale, st.session_state.scale
+        )
+        display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
+
+def modeler_options(is_mobile):
+    if not is_mobile:
+        with st.expander("Options for BPMN modeler"):
+            col1, col2 = st.columns(2)
+            with col1:
+                st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
+                st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
+    else:
+        st.session_state.scale = 1.0
+        st.session_state.size_scale = 1.0
```
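The bulk of this change breaks the page logic into small helpers (configure_page, display_banner, load_user_image, launch_prediction, ...). How they chain together is not shown in this commit; the following is a hypothetical entry-point sketch, with the call order assumed:

```python
# Hypothetical wiring of the new helpers — a sketch of the call order an entry
# script might use; the app's real main script is not part of this commit.
import streamlit as st
from modules.streamlit_utils import (
    configure_page, display_banner, display_title, initialize_session_state,
    load_example_image, load_user_image, display_image, get_score_threshold,
    launch_prediction, modeler_options, display_bpmn_modeler,
)

is_mobile, screen_width = configure_page()   # wide layout + mobile detection (< 800 px)
display_banner(is_mobile)
display_title(is_mobile)
initialize_session_state()                   # clears memory and loads the models once

img_selected = load_example_image()
uploaded_file = load_user_image(img_selected, is_mobile)
if uploaded_file is not None:
    cropped_image = display_image(uploaded_file, screen_width, is_mobile)
    get_score_threshold(is_mobile)           # writes the slider value to session state
    launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
    modeler_options(is_mobile)
    display_bpmn_modeler(is_mobile, screen_width)
```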
modules/toXML.py
CHANGED
```diff
@@ -4,6 +4,7 @@ import streamlit as st
 from modules.utils import class_dict, rescale_boxes
 import copy
 from xml.dom import minidom
+import numpy as np
 
 def align_boxes(pred, size):
     modified_pred = copy.deepcopy(pred)  # Make a deep copy of the prediction
@@ -76,6 +77,44 @@ def align_boxes(pred, size):
                 new_center[0] + size[label][0] / 2,
                 modified_pred['boxes'][idx][3]
             ]
+
+    # Step 3: Expand the pool bounding boxes to fit the aligned elements
+    for idx, (pool_index, keep_elements) in enumerate(modified_pred['pool_dict'].items()):
+        size_elements = get_size_elements(st.session_state.size_scale)
+        if len(keep_elements) != 0:
+            min_x, min_y, max_x, max_y = calculate_pool_bounds(modified_pred['boxes'], modified_pred['labels'], keep_elements, size_elements)
+        else:
+            min_x, min_y, max_x, max_y = modified_pred['boxes'][pool_index]
+        pool_width = max_x - min_x
+        pool_height = max_y - min_y
+        if pool_width < 300 or pool_height < 30:
+            error("The pool is maybe too small, please add more elements or increase the scale by zooming on the image.")
+            continue
+        if pool_index >= len(modified_pred['boxes']):
+            new_box = np.array([min_x - 50, min_y - 50, min_x + pool_width + 50, min_y + pool_height + 50])
+            modified_pred['boxes'] = np.append(modified_pred['boxes'], [new_box], axis=0)
+        else:
+            modified_pred['boxes'][pool_index] = [min_x - 50, min_y - 50, min_x + pool_width + 50, min_y + pool_height + 50]
+
+    min_left, max_right = 0, 0
+    if len(pred['pool_dict']) > 1:
+        for pool_index, element_indices in pred['pool_dict'].items():
+            x1, y1, x2, y2 = modified_pred['boxes'][pool_index]
+            left = x1
+            right = x2
+            if left < min_left:
+                min_left = left
+            if right > max_right:
+                max_right = right
+
+        for pool_index, element_indices in pred['pool_dict'].items():
+            x1, y1, x2, y2 = modified_pred['boxes'][pool_index]
+            if x1 > min_left:
+                x1 = min_left
+            if x2 < max_right:
+                x2 = max_right
+            modified_pred['boxes'][pool_index] = [x1, y1, x2, y2]
+
     return modified_pred['boxes']
 
 # Function to create a BPMN XML file from prediction results
@@ -125,24 +164,12 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
         pool = ET.SubElement(collaboration, 'bpmn:participant', id=pool_id, processRef=f'process_{idx+1}', name=text_mapping[full_pred['BPMN_id'][list(full_pred['pool_dict'].keys())[idx]]])
 
         # Calculate the bounding box for the pool
-        if len(keep_elements) == 0:
-            min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
-            pool_width = max_x - min_x
-            pool_height = max_y - min_y
-            #check area
-            if pool_width < 400 or pool_height < 30:
-                print("The pool is too small, please add more elements or increase the scale")
-                continue
-        else:
-            min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
-            pool_width = max_x - min_x + 100  # Adding padding
-            pool_height = max_y - min_y + 100  # Adding padding
-            #check area
-            if pool_width < 300 or pool_height < 30:
-                print("The pool is too small, please add more elements or increase the scale")
-                continue
+        #if len(keep_elements) == 0:
+        min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
+        pool_width = max_x - min_x
+        pool_height = max_y - min_y
 
-        add_diagram_elements(bpmnplane, pool_id, min_x
+        add_diagram_elements(bpmnplane, pool_id, min_x, min_y, pool_width, pool_height)
 
 
         # Create BPMN elements for each pool
@@ -446,21 +473,21 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
         add_diagram_elements(bpmnplane, element_id, x, y, size['timerEvent'][0], size['timerEvent'][1])
 
 
-def calculate_pool_bounds(data, keep_elements, size):
-    min_x, min_y = float('inf'), float('inf')
-    max_x, max_y = float('-inf'), float('-inf')
-
+def calculate_pool_bounds(boxes, labels, keep_elements, size):
+    min_x, min_y = float('inf'), float('inf')
+    max_x, max_y = float('-inf'), float('-inf')
+
     for i in keep_elements:
-        if i >= len(data['labels']):
+        if i >= len(labels):
             print("Problem with the index")
             continue
-        element = data['labels'][i]
-        if element in {None, 7, 13, 14, 15}:
+
+        element = labels[i]
+        if element in {None, 7, 13, 14, 15}:
             continue
 
-        x, y = data['boxes'][i][:2]
-        element_type = class_dict[data['labels'][i]]
-        element_width, element_height = size[element_type]
+        x, y = boxes[i][:2]
+        element_width, element_height = size[class_dict[labels[i]]]
 
         min_x = min(min_x, x)
         min_y = min(min_y, y)
@@ -469,6 +496,7 @@ def calculate_pool_bounds(data, keep_elements, size):
 
     return min_x, min_y, max_x, max_y
 
+
 
 def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element):
     # Get the bounding boxes of the source and target elements
@@ -489,9 +517,9 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element)
         element_mid_y = (element_box[1] + element_box[3]) / 2
         # Connect the pool's bottom or top side to the target element's top or bottom center
         if pool_box[3] < element_box[1]:  # Pool is above the target element
-            waypoints = [(element_mid_x, pool_box[3]
+            waypoints = [(element_mid_x, pool_box[3]), (element_mid_x, element_box[1])]
         else:  # Pool is below the target element
-            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1]
+            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
     else:
         pool_box = target_box
         element_box = (source_box[0], source_box[1], source_box[0]+size[source_element][0], source_box[1]+size[source_element][1])
@@ -500,9 +528,9 @@ def calculate_pool_waypoints(idx, data, size, source_idx, target_idx, source_element, target_element)
 
         # Connect the element's bottom or top center to the pool's top or bottom side
         if pool_box[3] < element_box[1]:  # Pool is above the target element
-            waypoints = [(element_mid_x, element_box[1]), (element_mid_x, pool_box[3]
+            waypoints = [(element_mid_x, element_box[1]), (element_mid_x, pool_box[3])]
        else:  # Pool is below the target element
-            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1]
+            waypoints = [(element_mid_x, element_box[3]), (element_mid_x, pool_box[1])]
 
     return waypoints
 
```
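calculate_pool_bounds now receives boxes and labels directly instead of the whole data dict, and skips non-shape labels when computing a pool's bounding box. A standalone sketch with toy data — the skipped label ids {None, 7, 13, 14, 15} come from the diff, while the toy class_dict, the size table, and the max-side update (x + width) are assumptions, since those lines fall outside the hunk context:

```python
# Toy inputs — illustrative only, not the app's real label mapping or sizes.
class_dict = {0: 'task', 1: 'event'}
size = {'task': (160, 80), 'event': (36, 36)}

boxes = [[100, 50, 260, 130], [400, 200, 436, 236]]
labels = [0, 1]

def calculate_pool_bounds(boxes, labels, keep_elements, size):
    min_x, min_y = float('inf'), float('inf')
    max_x, max_y = float('-inf'), float('-inf')
    for i in keep_elements:
        if i >= len(labels):
            continue  # out-of-range index: skip (the diff also prints a warning)
        if labels[i] in {None, 7, 13, 14, 15}:
            continue  # pools/lanes/flows do not constrain the bounds
        x, y = boxes[i][:2]
        w, h = size[class_dict[labels[i]]]
        min_x, min_y = min(min_x, x), min(min_y, y)
        max_x, max_y = max(max_x, x + w), max(max_y, y + h)  # assumed max-side update
    return min_x, min_y, max_x, max_y

print(calculate_pool_bounds(boxes, labels, [0, 1], size))
# (100, 50, 436, 236) — align_boxes then pads the pool box by 50 px on each side
```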