BenjiELCA committed
Commit e108fc3 · 1 Parent(s): 1b7de55

add some files to reduce code in app.py

Files changed (3):
  1. app.py +11 -304
  2. modules/streamlit_utils.py +262 -0
  3. modules/toXML.py +36 -4
app.py CHANGED
@@ -1,155 +1,25 @@
  import streamlit as st
- from PIL import Image, ImageEnhance
- import torch
  from torchvision.transforms import functional as F
  import gc
- import psutil
  import copy
  import xml.etree.ElementTree as ET
  import numpy as np
  from xml.dom import minidom
- from pathlib import Path
- import gdown
 
  from modules.htlm_webpage import display_bpmn_xml
- from modules.OCR import text_prediction, filter_text, mapping_text, rescale
- from modules.utils import class_dict, arrow_dict, object_dict
- from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element
- from modules.display import draw_stream
- from modules.eval import full_prediction
- from modules.train import get_faster_rcnn_model, get_arrow_model
- from streamlit_image_comparison import image_comparison
+ from modules.utils import class_dict, rescale_boxes
+ from modules.toXML import calculate_pool_bounds, add_diagram_elements, create_bpmn_object, create_flow_element, get_size_elements, definitions
  from streamlit_cropper import st_cropper
- from streamlit_drawable_canvas import st_canvas
  from streamlit_image_select import image_select
  from streamlit_js_eval import streamlit_js_eval
+ from modules.streamlit_utils import get_memory_usage, clear_memory, get_image, load_models, perform_inference, display_options, align_boxes, sidebar
-
- def get_memory_usage():
-     process = psutil.Process()
-     mem_info = process.memory_info()
-     return mem_info.rss / (1024 ** 2)  # Return memory usage in MB
-
- def clear_memory():
-     st.session_state.clear()
-     gc.collect()
-
- # Function to read XML content from a file
- def read_xml_file(filepath):
-     """Read XML content from a file."""
-     with open(filepath, 'r', encoding='utf-8') as file:
-         return file.read()
-
- def modif_box_pos(pred, size):
-     modified_pred = copy.deepcopy(pred)  # Make a deep copy of the prediction
-
-     # Step 1: Calculate the center of each bounding box and group them by pool
-     pool_groups = {}
-     for pool_index, element_indices in pred['pool_dict'].items():
-         pool_groups[pool_index] = []
-         for i in element_indices:
-             if class_dict[modified_pred['labels'][i]] not in ('dataObject', 'dataStore'):
-                 x1, y1, x2, y2 = modified_pred['boxes'][i]
-                 center = [(x1 + x2) / 2, (y1 + y2) / 2]
-                 pool_groups[pool_index].append((center, i))
-
-     # Function to group centers within a specified range
-     def group_centers(centers, axis, range_=50):
-         groups = []
-         while centers:
-             center, idx = centers.pop(0)
-             group = [(center, idx)]
-             for other_center, other_idx in centers[:]:
-                 if abs(center[axis] - other_center[axis]) <= range_:
-                     group.append((other_center, other_idx))
-                     centers.remove((other_center, other_idx))
-             groups.append(group)
-         return groups
-
-     # Step 2: Align the elements within each pool
-     for pool_index, centers in pool_groups.items():
-         # Group bounding boxes by checking if their centers are within ±50 pixels on the y-axis
-         y_groups = group_centers(centers.copy(), axis=1)
-
-         # Align the y-coordinates of the centers of grouped bounding boxes
-         for group in y_groups:
-             avg_y = sum([c[0][1] for c in group]) / len(group)  # Calculate the average y-coordinate
-             for (center, idx) in group:
-                 label = class_dict[modified_pred['labels'][idx]]
-                 if label in size:
-                     new_center = (center[0], avg_y)  # Align the y-coordinate
-                     modified_pred['boxes'][idx] = [
-                         new_center[0] - size[label][0] / 2,
-                         new_center[1] - size[label][1] / 2,
-                         new_center[0] + size[label][0] / 2,
-                         new_center[1] + size[label][1] / 2
-                     ]
-
-         # Recalculate centers after vertical alignment
-         centers = []
-         for group in y_groups:
-             for center, idx in group:
-                 x1, y1, x2, y2 = modified_pred['boxes'][idx]
-                 center = [(x1 + x2) / 2, (y1 + y2) / 2]
-                 centers.append((center, idx))
-
-         # Group bounding boxes by checking if their centers are within ±50 pixels on the x-axis
-         x_groups = group_centers(centers.copy(), axis=0)
-
-         # Align the x-coordinates of the centers of grouped bounding boxes
-         for group in x_groups:
-             avg_x = sum([c[0][0] for c in group]) / len(group)  # Calculate the average x-coordinate
-             for (center, idx) in group:
-                 label = class_dict[modified_pred['labels'][idx]]
-                 if label in size:
-                     new_center = (avg_x, center[1])  # Align the x-coordinate
-                     modified_pred['boxes'][idx] = [
-                         new_center[0] - size[label][0] / 2,
-                         modified_pred['boxes'][idx][1],
-                         new_center[0] + size[label][0] / 2,
-                         modified_pred['boxes'][idx][3]
-                     ]
-
-     return modified_pred['boxes']
-
 
 
  # Function to create a BPMN XML file from prediction results
- def create_XML(full_pred, text_mapping, scale):
-     namespaces = {
-         'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
-         'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
-         'di': 'http://www.omg.org/spec/DD/20100524/DI',
-         'dc': 'http://www.omg.org/spec/DD/20100524/DC',
-         'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
-     }
-
-     size_elements = {
-         'event': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
-         'task': (st.session_state.size_scale*120, st.session_state.size_scale*96),
-         'message': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
-         'messageEvent': (st.session_state.size_scale*43.2, st.session_state.size_scale*43.2),
-         'exclusiveGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
-         'parallelGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
-         'dataObject': (st.session_state.size_scale*48, st.session_state.size_scale*72),
-         'dataStore': (st.session_state.size_scale*72, st.session_state.size_scale*72),
-         'subProcess': (st.session_state.size_scale*144, st.session_state.size_scale*108),
-         'eventBasedGateway': (st.session_state.size_scale*60, st.session_state.size_scale*60),
-         'timerEvent': (st.session_state.size_scale*48, st.session_state.size_scale*48),
-     }
-
-     definitions = ET.Element('bpmn:definitions', {
-         'xmlns:xsi': namespaces['xsi'],
-         'xmlns:bpmn': namespaces['bpmn'],
-         'xmlns:bpmndi': namespaces['bpmndi'],
-         'xmlns:di': namespaces['di'],
-         'xmlns:dc': namespaces['dc'],
-         'targetNamespace': "http://example.bpmn.com",
-         'id': "simpleExample"
-     })
-
+ def create_XML(full_pred, text_mapping, size_scale, scale):
+
+     size_elements = get_size_elements(size_scale)
+
      # modify the boxes positions
      old_boxes = copy.deepcopy(full_pred)
 
@@ -165,8 +35,8 @@ def create_XML(full_pred, text_mapping, scale):
      bpmndi = ET.SubElement(definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
      bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')
 
-     full_pred['boxes'] = rescale(scale, old_boxes['boxes'])
-     full_pred['boxes'] = modif_box_pos(full_pred, size_elements)
+     full_pred['boxes'] = rescale_boxes(scale, old_boxes['boxes'])
+     full_pred['boxes'] = align_boxes(full_pred, size_elements)
 
      # Add diagram elements for each pool
      for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
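
Two renames carry this hunk: rescale becomes rescale_boxes (now imported from modules.utils) and modif_box_pos becomes align_boxes (now in modules.streamlit_utils). Since create_XML now receives size_scale as an argument instead of reading st.session_state internally, it can in principle be exercised outside a running Streamlit app; a hypothetical call (argument values made up):

    # sketch: size_scale feeds get_size_elements, scale rescales the detected boxes
    xml_string = create_XML(prediction.copy(), text_mapping, size_scale=1.0, scale=1.0)
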
@@ -185,6 +55,7 @@ def create_XML(full_pred, text_mapping, scale):
 
      add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)
 
+
      # Create BPMN elements for each pool
      for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
          create_bpmn_object(process[idx], bpmnplane, text_mapping, definitions, size_elements, full_pred, keep_elements)
@@ -206,163 +77,15 @@ def create_XML(full_pred, text_mapping, scale):
      reparsed = minidom.parseString(rough_string)
      pretty_xml_as_string = reparsed.toprettyxml(indent=" ")
 
-     full_pred['boxes'] = rescale(1/scale, full_pred['boxes'])
+     full_pred['boxes'] = rescale_boxes(1/scale, full_pred['boxes'])
      full_pred['boxes'] = old_boxes['boxes']  # restore the original boxes (old_boxes is a dict copy of the prediction)
 
      return pretty_xml_as_string
 
 
- # Function to load the models only once and use session state to keep track of it
- def load_models():
-     with st.spinner('Loading model...'):
-         model_object = get_faster_rcnn_model(len(object_dict))
-         model_arrow = get_arrow_model(len(arrow_dict), 2)
-
-         url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
-         url_object = 'https://drive.google.com/uc?id=1b1bqogxqdPS-SnvaOfWJGV1I1qOrTKh5'
-
-         # Define paths to save models
-         output_arrow = 'model_arrow.pth'
-         output_object = 'model_object.pth'
-
-         # Download models using gdown if they are not cached locally
-         if not Path(output_arrow).exists():
-             gdown.download(url_arrow, output_arrow, quiet=False)
-         else:
-             print('Arrow model found locally, skipping download')
-         if not Path(output_object).exists():
-             gdown.download(url_object, output_object, quiet=False)
-         else:
-             print('Object model found locally, skipping download')
-
-         # Load models
-         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
-         model_object.load_state_dict(torch.load(output_object, map_location=device))
-
-         st.session_state.model_loaded = True
-         st.session_state.model_arrow = model_arrow
-         st.session_state.model_object = model_object
-
-     return model_object, model_arrow
-
- # Function to prepare the image for processing
- def prepare_image(image, pad=True, new_size=(1333, 1333)):
-     original_size = image.size
-     # Calculate scale to fit the new size while maintaining aspect ratio
-     scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
-     new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
-     # Resize image to the new scaled size
-     image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
-
-     if pad:
-         enhancer = ImageEnhance.Brightness(image)
-         image = enhancer.enhance(1.0)  # Adjust the brightness if necessary
-         # Pad the resized image to make it exactly the desired size
-         padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
-         image = F.pad(image, padding, fill=200, padding_mode='edge')
-
-     return new_scaled_size, image
-
- # Function to display various options for image annotation
- def display_options(image, score_threshold, is_mobile, screen_width):
-     col1, col2, col3, col4, col5 = st.columns(5)
-     with col1:
-         write_class = st.toggle("Write Class", value=True)
-         draw_keypoints = st.toggle("Draw Keypoints", value=True)
-         draw_boxes = st.toggle("Draw Boxes", value=True)
-     with col2:
-         draw_text = st.toggle("Draw Text", value=False)
-         write_text = st.toggle("Write Text", value=False)
-         draw_links = st.toggle("Draw Links", value=False)
-     with col3:
-         write_score = st.toggle("Write Score", value=True)
-         write_idx = st.toggle("Write Index", value=False)
-     with col4:
-         # Define options for the dropdown menu
-         dropdown_options = [list(class_dict.values())[i] for i in range(len(class_dict))]
-         dropdown_options[0] = 'all'
-         selected_option = st.selectbox("Show class", dropdown_options)
-
-     # Draw the annotated image with selected options
-     annotated_image = draw_stream(
-         np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
-         draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
-         write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
-         score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
-     )
-
-     if is_mobile:
-         width = screen_width
-     else:
-         width = screen_width // 2
-
-     # Display the original and annotated images side by side
-     image_comparison(
-         img1=annotated_image,
-         img2=image,
-         label1="Annotated Image",
-         label2="Original Image",
-         starting_position=99,
-         width=width,
-     )
-
- # Function to perform inference on the uploaded image using the loaded models
- def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
-     _, uploaded_image = prepare_image(image, pad=False)
-
-     img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
-
-     # Display original image
-     if 'image_placeholder' not in st.session_state:
-         image_placeholder = st.empty()  # Create an empty placeholder
-         if is_mobile:
-             width = screen_width
-         else:
-             width = screen_width // 3
-         image_placeholder.image(uploaded_image, caption='Original Image', width=width)
-
-     # Prediction
-     _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
-
-     # Perform OCR on the uploaded image
-     ocr_results = text_prediction(uploaded_image)
-
-     # Filter and map OCR results to prediction results
-     st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
-     st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
-
-     # Remove the original image display
-     image_placeholder.empty()
-
-     # Force garbage collection
-     gc.collect()
-
- @st.cache_data
- def get_image(uploaded_file):
-     return Image.open(uploaded_file).convert('RGB')
-
-
-
  def main():
      st.set_page_config(layout="wide")
 
-     # Apply CSS to change the sidebar width
-     st.markdown(
-         """
-         <style>
-         [data-testid="stSidebar"] {
-             width: 350px;
-         }
-         [data-testid="stSidebar"][aria-expanded="true"] {
-             width: 350px;
-         }
-         </style>
-         """,
-         unsafe_allow_html=True,
-     )
-
      screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
      print("Screen width:", screen_width)
 
 
@@ -387,23 +110,7 @@ def main():
387
 
388
 
389
 
390
- # Sidebar content
391
- st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
392
- st.sidebar.subheader("Instructions:")
393
- st.sidebar.text("1. Upload you image")
394
- st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
395
- st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.5)")
396
- st.sidebar.text("4. Click on 'Launch Prediction'")
397
- st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
398
- st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
399
- st.sidebar.text("7. You can modify and download \n the result in right format")
400
-
401
- st.sidebar.subheader("If there is an error, try to:")
402
- st.sidebar.text("1. Change the score threshold")
403
- st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the center\n of the image")
404
- st.sidebar.text("3. Re-Launch the prediction")
405
-
406
- st.sidebar.subheader("You can close this sidebar")
407
 
408
 
409
  # Display current memory usage
@@ -497,7 +204,7 @@ def main():
          st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
      else:
          st.session_state.size_scale = 1.0
-     st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.scale)
+     st.session_state.bpmn_xml = create_XML(st.session_state.prediction.copy(), st.session_state.text_mapping, st.session_state.size_scale, st.session_state.scale)
      display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5*screen_width))
 
      gc.collect()

modules/streamlit_utils.py ADDED
@@ -0,0 +1,262 @@
+ import streamlit as st
+ from PIL import Image, ImageEnhance
+ import torch
+ from torchvision.transforms import functional as F
+ import gc
+ import psutil
+ import copy
+ import xml.etree.ElementTree as ET
+ import numpy as np
+ from pathlib import Path
+ import gdown
+
+
+ from modules.htlm_webpage import display_bpmn_xml
+ from modules.OCR import text_prediction, filter_text, mapping_text
+ from modules.utils import class_dict, arrow_dict, object_dict, rescale_boxes
+ from modules.display import draw_stream
+ from modules.eval import full_prediction
+ from modules.train import get_faster_rcnn_model, get_arrow_model
+ from streamlit_image_comparison import image_comparison
+
+
+
+ def get_memory_usage():
+     process = psutil.Process()
+     mem_info = process.memory_info()
+     return mem_info.rss / (1024 ** 2)  # Return memory usage in MB
+
+ def clear_memory():
+     st.session_state.clear()
+     gc.collect()
+
+ def sidebar():
+     # Sidebar content
+     st.sidebar.header("This BPMN AI model recognition is proposed by: \n ELCA in collaboration with EPFL.")
+     st.sidebar.subheader("Instructions:")
+     st.sidebar.text("1. Upload your image")
+     st.sidebar.text("2. Crop the image \n (try to put the BPMN diagram \n in the center of the image)")
+     st.sidebar.text("3. Set the score threshold \n for prediction (default is 0.5)")
+     st.sidebar.text("4. Click on 'Launch Prediction'")
+     st.sidebar.text("5. You can now see the annotation \n and the BPMN XML result")
+     st.sidebar.text("6. You can change the scale for \n the XML file (default is 1.0)")
+     st.sidebar.text("7. You can modify and download \n the result in the right format")
+
+     st.sidebar.subheader("If there is an error, try to:")
+     st.sidebar.text("1. Change the score threshold")
+     st.sidebar.text("2. Re-crop the image by placing\n the BPMN diagram in the center\n of the image")
+     st.sidebar.text("3. Re-launch the prediction")
+
+     st.sidebar.subheader("You can close this sidebar")
+
+
+ # Function to read XML content from a file
+ def read_xml_file(filepath):
+     """Read XML content from a file."""
+     with open(filepath, 'r', encoding='utf-8') as file:
+         return file.read()
+
+ def align_boxes(pred, size):
+     modified_pred = copy.deepcopy(pred)  # Make a deep copy of the prediction
+
+     # Step 1: Calculate the center of each bounding box and group them by pool
+     pool_groups = {}
+     for pool_index, element_indices in pred['pool_dict'].items():
+         pool_groups[pool_index] = []
+         for i in element_indices:
+             if i >= len(modified_pred['labels']):  # skip out-of-range indices
+                 continue
+             if class_dict[modified_pred['labels'][i]] not in ('dataObject', 'dataStore'):
+                 x1, y1, x2, y2 = modified_pred['boxes'][i]
+                 center = [(x1 + x2) / 2, (y1 + y2) / 2]
+                 pool_groups[pool_index].append((center, i))
+
+     # Function to group centers within a specified range
+     def group_centers(centers, axis, range_=50):
+         groups = []
+         while centers:
+             center, idx = centers.pop(0)
+             group = [(center, idx)]
+             for other_center, other_idx in centers[:]:
+                 if abs(center[axis] - other_center[axis]) <= range_:
+                     group.append((other_center, other_idx))
+                     centers.remove((other_center, other_idx))
+             groups.append(group)
+         return groups
+
+     # Step 2: Align the elements within each pool
+     for pool_index, centers in pool_groups.items():
+         # Group bounding boxes by checking if their centers are within ±50 pixels on the y-axis
+         y_groups = group_centers(centers.copy(), axis=1)
+
+         # Align the y-coordinates of the centers of grouped bounding boxes
+         for group in y_groups:
+             avg_y = sum([c[0][1] for c in group]) / len(group)  # Calculate the average y-coordinate
+             for (center, idx) in group:
+                 label = class_dict[modified_pred['labels'][idx]]
+                 if label in size:
+                     new_center = (center[0], avg_y)  # Align the y-coordinate
+                     modified_pred['boxes'][idx] = [
+                         new_center[0] - size[label][0] / 2,
+                         new_center[1] - size[label][1] / 2,
+                         new_center[0] + size[label][0] / 2,
+                         new_center[1] + size[label][1] / 2
+                     ]
+
+         # Recalculate centers after vertical alignment
+         centers = []
+         for group in y_groups:
+             for center, idx in group:
+                 x1, y1, x2, y2 = modified_pred['boxes'][idx]
+                 center = [(x1 + x2) / 2, (y1 + y2) / 2]
+                 centers.append((center, idx))
+
+         # Group bounding boxes by checking if their centers are within ±50 pixels on the x-axis
+         x_groups = group_centers(centers.copy(), axis=0)
+
+         # Align the x-coordinates of the centers of grouped bounding boxes
+         for group in x_groups:
+             avg_x = sum([c[0][0] for c in group]) / len(group)  # Calculate the average x-coordinate
+             for (center, idx) in group:
+                 label = class_dict[modified_pred['labels'][idx]]
+                 if label in size:
+                     new_center = (avg_x, center[1])  # Align the x-coordinate
+                     modified_pred['boxes'][idx] = [
+                         new_center[0] - size[label][0] / 2,
+                         modified_pred['boxes'][idx][1],
+                         new_center[0] + size[label][0] / 2,
+                         modified_pred['boxes'][idx][3]
+                     ]
+     return modified_pred['boxes']
+
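
The ±50-pixel clustering that drives this alignment is easiest to see in isolation. A self-contained toy run of the same group_centers logic, with made-up coordinates:

    def group_centers(centers, axis, range_=50):
        # identical logic to the helper above, reproduced standalone
        groups = []
        while centers:
            center, idx = centers.pop(0)
            group = [(center, idx)]
            for other_center, other_idx in centers[:]:
                if abs(center[axis] - other_center[axis]) <= range_:
                    group.append((other_center, other_idx))
                    centers.remove((other_center, other_idx))
            groups.append(group)
        return groups

    centers = [([100, 200], 0), ([400, 230], 1), ([700, 500], 2)]
    print(group_centers(centers, axis=1))
    # [[([100, 200], 0), ([400, 230], 1)], [([700, 500], 2)]]

Boxes 0 and 1 fall in the same y-group (|200 - 230| <= 50), so align_boxes would snap both to the average y of 215; box 2 keeps its own row.
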
+
+
+ # Function to load the models only once and use session state to keep track of it
+ def load_models():
+     with st.spinner('Loading model...'):
+         model_object = get_faster_rcnn_model(len(object_dict))
+         model_arrow = get_arrow_model(len(arrow_dict), 2)
+
+         url_arrow = 'https://drive.google.com/uc?id=1vv1X_r_lZ8gnzMAIKxcVEb_T_Qb-NkyA'
+         url_object = 'https://drive.google.com/uc?id=1b1bqogxqdPS-SnvaOfWJGV1I1qOrTKh5'
+
+         # Define paths to save models
+         output_arrow = 'model_arrow.pth'
+         output_object = 'model_object.pth'
+
+         # Download models using gdown if they are not cached locally
+         if not Path(output_arrow).exists():
+             gdown.download(url_arrow, output_arrow, quiet=False)
+         else:
+             print('Arrow model found locally, skipping download')
+         if not Path(output_object).exists():
+             gdown.download(url_object, output_object, quiet=False)
+         else:
+             print('Object model found locally, skipping download')
+
+         # Load models
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         model_arrow.load_state_dict(torch.load(output_arrow, map_location=device))
+         model_object.load_state_dict(torch.load(output_object, map_location=device))
+
+         st.session_state.model_loaded = True
+         st.session_state.model_arrow = model_arrow
+         st.session_state.model_object = model_object
+
+     return model_object, model_arrow
+
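
The session-state flags written at the end suggest a call site that loads once per session; a sketch (the exact guard in app.py is not shown in this diff):

    # hypothetical call site: reuse models across Streamlit reruns
    if 'model_loaded' not in st.session_state:
        model_object, model_arrow = load_models()
    else:
        model_object = st.session_state.model_object
        model_arrow = st.session_state.model_arrow
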
+ # Function to prepare the image for processing
+ def prepare_image(image, pad=True, new_size=(1333, 1333)):
+     original_size = image.size
+     # Calculate scale to fit the new size while maintaining aspect ratio
+     scale = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
+     new_scaled_size = (int(original_size[0] * scale), int(original_size[1] * scale))
+     # Resize image to the new scaled size
+     image = F.resize(image, (new_scaled_size[1], new_scaled_size[0]))
+
+     if pad:
+         enhancer = ImageEnhance.Brightness(image)
+         image = enhancer.enhance(1.0)  # Adjust the brightness if necessary
+         # Pad the resized image to make it exactly the desired size
+         padding = [0, 0, new_size[0] - new_scaled_size[0], new_size[1] - new_scaled_size[1]]
+         image = F.pad(image, padding, fill=200, padding_mode='edge')
+
+     return new_scaled_size, image
+
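
Worked numbers for the aspect-ratio math, assuming a hypothetical 2000x1000 px upload:

    # scale = min(1333/2000, 1333/1000) = 0.6665
    # new_scaled_size = (int(2000*0.6665), int(1000*0.6665)) = (1333, 666)
    # pad=True then fills the 1333 - 666 = 667 missing rows at the bottom edge,
    # so the detector always sees a 1333x1333 input.
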
+ # Function to display various options for image annotation
+ def display_options(image, score_threshold, is_mobile, screen_width):
+     col1, col2, col3, col4, col5 = st.columns(5)
+     with col1:
+         write_class = st.toggle("Write Class", value=True)
+         draw_keypoints = st.toggle("Draw Keypoints", value=True)
+         draw_boxes = st.toggle("Draw Boxes", value=True)
+     with col2:
+         draw_text = st.toggle("Draw Text", value=False)
+         write_text = st.toggle("Write Text", value=False)
+         draw_links = st.toggle("Draw Links", value=False)
+     with col3:
+         write_score = st.toggle("Write Score", value=True)
+         write_idx = st.toggle("Write Index", value=False)
+     with col4:
+         # Define options for the dropdown menu
+         dropdown_options = [list(class_dict.values())[i] for i in range(len(class_dict))]
+         dropdown_options[0] = 'all'
+         selected_option = st.selectbox("Show class", dropdown_options)
+
+     # Draw the annotated image with selected options
+     annotated_image = draw_stream(
+         np.array(image), prediction=st.session_state.prediction, text_predictions=st.session_state.text_pred,
+         draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
+         write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
+         score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
+     )
+
+     if is_mobile:
+         width = screen_width
+     else:
+         width = screen_width // 2
+
+     # Display the original and annotated images side by side
+     image_comparison(
+         img1=annotated_image,
+         img2=image,
+         label1="Annotated Image",
+         label2="Original Image",
+         starting_position=99,
+         width=width,
+     )
+
+ # Function to perform inference on the uploaded image using the loaded models
+ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
+     _, uploaded_image = prepare_image(image, pad=False)
+
+     img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
+
+     # Display original image
+     if 'image_placeholder' not in st.session_state:
+         image_placeholder = st.empty()  # Create an empty placeholder
+         if is_mobile:
+             width = screen_width
+         else:
+             width = screen_width // 3
+         image_placeholder.image(uploaded_image, caption='Original Image', width=width)
+
+     # Prediction
+     _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
+
+     # Perform OCR on the uploaded image
+     ocr_results = text_prediction(uploaded_image)
+
+     # Filter and map OCR results to prediction results
+     st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
+     st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
+
+     # Remove the original image display
+     image_placeholder.empty()
+
+     # Force garbage collection
+     gc.collect()
+
+ @st.cache_data
+ def get_image(uploaded_file):
+     return Image.open(uploaded_file).convert('RGB')
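
Taken together, the helpers extracted here let main() in app.py shrink to roughly the following flow. This is a sketch, not the literal main(): the widget wiring is omitted, and uploaded_file, score_threshold, is_mobile and screen_width are assumed to come from the surrounding UI code:

    sidebar()                                    # instructions panel
    model_object, model_arrow = load_models()    # cached via st.session_state
    image = get_image(uploaded_file)             # @st.cache_data keyed on the upload
    perform_inference(model_object, model_arrow, image,
                      score_threshold, is_mobile, screen_width)
    display_options(image, score_threshold, is_mobile, screen_width)
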
modules/toXML.py CHANGED
@@ -1,5 +1,41 @@
  import xml.etree.ElementTree as ET
  from modules.utils import class_dict, error, warning
+ import streamlit as st
+
+ namespaces = {
+     'bpmn': 'http://www.omg.org/spec/BPMN/20100524/MODEL',
+     'bpmndi': 'http://www.omg.org/spec/BPMN/20100524/DI',
+     'di': 'http://www.omg.org/spec/DD/20100524/DI',
+     'dc': 'http://www.omg.org/spec/DD/20100524/DC',
+     'xsi': 'http://www.w3.org/2001/XMLSchema-instance'
+ }
+
+
+ definitions = ET.Element('bpmn:definitions', {
+     'xmlns:xsi': namespaces['xsi'],
+     'xmlns:bpmn': namespaces['bpmn'],
+     'xmlns:bpmndi': namespaces['bpmndi'],
+     'xmlns:di': namespaces['di'],
+     'xmlns:dc': namespaces['dc'],
+     'targetNamespace': "http://example.bpmn.com",
+     'id': "simpleExample"
+ })
+
+ def get_size_elements(size_scale):
+     size_elements = {
+         'event': (size_scale*43.2, size_scale*43.2),
+         'task': (size_scale*120, size_scale*96),
+         'message': (size_scale*43.2, size_scale*43.2),
+         'messageEvent': (size_scale*43.2, size_scale*43.2),
+         'exclusiveGateway': (size_scale*60, size_scale*60),
+         'parallelGateway': (size_scale*60, size_scale*60),
+         'dataObject': (size_scale*48, size_scale*72),
+         'dataStore': (size_scale*72, size_scale*72),
+         'subProcess': (size_scale*144, size_scale*108),
+         'eventBasedGateway': (size_scale*60, size_scale*60),
+         'timerEvent': (size_scale*48, size_scale*48),
+     }
+     return size_elements
 
  def rescale(scale, boxes):
      for i in range(len(boxes)):
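
get_size_elements just multiplies the default BPMN DI dimensions by the slider value, for example:

    sizes = get_size_elements(1.5)
    print(sizes['task'])   # (180.0, 144.0), i.e. (1.5*120, 1.5*96)
    print(sizes['event'])  # ~(64.8, 64.8)

Note that definitions is built once at import time and then mutated by create_XML, so every conversion in the same process appends to the same element tree; callers that convert repeatedly would need to account for that.
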
@@ -184,13 +220,10 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
          element = ET.SubElement(process, 'bpmn:endEvent', id=element_id, name=text_mapping[element_id])
 
          status, datasAssociation_idx = check_data_association(i, data['links'], data['labels'], keep_elements)
-         print('status', status, datasAssociation_idx, element_id)
          if len(status) != 0:
-             print('ici')
              for state, dataAssociation_idx in zip(status, datasAssociation_idx):
                  # Handle Data Input Association
                  if state == 'input':
-                     print('input')
                      dataObject_idx = links[dataAssociation_idx][0]
                      dataObject_name = elements[dataObject_idx]
                      dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
@@ -200,7 +233,6 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
 
                  # Handle Data Output Association
                  elif state == 'output':
-                     print('output')
                      dataObject_idx = links[dataAssociation_idx][1]
                      dataObject_name = elements[dataObject_idx]
                      dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
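
For reference, the dataObject_ref naming visible in these hunks derives the DI reference id from the numeric suffix of the detected element's name; with a hypothetical element name:

    dataObject_name = 'DataObject_3'
    dataObject_ref = f'DataObjectReference_{dataObject_name.split("_")[1]}'
    print(dataObject_ref)  # DataObjectReference_3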
 