BenjiELCA commited on
Commit
813fdb6
1 Parent(s): b0e8a9d

correct bug

Browse files
Files changed (1) hide show
  1. app.py +1 -235
app.py CHANGED
@@ -2,246 +2,12 @@ import streamlit as st
2
  from torchvision.transforms import functional as F
3
  import gc
4
  import numpy as np
5
- from modules.htlm_webpage import display_bpmn_xml
6
- from streamlit_cropper import st_cropper
7
- from streamlit_image_select import image_select
8
- from streamlit_js_eval import streamlit_js_eval
9
- from streamlit_drawable_canvas import st_canvas
10
- from modules.streamlit_utils import *
11
- from glob import glob
12
-
13
- from streamlit_image_annotation import detection
14
- from modules.toXML import create_XML
15
- from modules.eval import develop_prediction, generate_data
16
- from modules.utils import class_dict, object_dict
17
-
18
- def configure_page():
19
- st.set_page_config(layout="wide")
20
- screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
21
- is_mobile = screen_width is not None and screen_width < 800
22
- return is_mobile, screen_width
23
-
24
- def display_banner(is_mobile):
25
- if is_mobile:
26
- st.image("./images/banner_mobile.png", use_column_width=True)
27
- else:
28
- st.image("./images/banner_desktop.png", use_column_width=True)
29
-
30
- def display_title(is_mobile):
31
- title = "Welcome on the BPMN AI model recognition app"
32
- if is_mobile:
33
- title = "Welcome on the mobile version of BPMN AI model recognition app"
34
- st.title(title)
35
-
36
- def display_sidebar():
37
- sidebar()
38
-
39
- def initialize_session_state():
40
- if 'pool_bboxes' not in st.session_state:
41
- st.session_state.pool_bboxes = []
42
- if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
43
- clear_memory()
44
- load_models()
45
-
46
- def load_example_image():
47
- with st.expander("Use example images"):
48
- img_selected = image_select(
49
- "If you have no image and just want to test the demo, click on one of these images",
50
- ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
51
- captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"],
52
- index=0,
53
- use_container_width=False,
54
- return_value="original"
55
- )
56
- return img_selected
57
-
58
- def load_user_image(img_selected, is_mobile):
59
- if img_selected == './images/none.jpg':
60
- img_selected = None
61
-
62
- if img_selected is not None:
63
- uploaded_file = img_selected
64
- else:
65
- if is_mobile:
66
- uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"], accept_multiple_files=False)
67
- else:
68
- col1, col2 = st.columns(2)
69
- with col1:
70
- uploaded_file = st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])
71
-
72
- return uploaded_file
73
-
74
- def display_image(uploaded_file, screen_width, is_mobile):
75
-
76
- with st.spinner('Waiting for image display...'):
77
- original_image = get_image(uploaded_file)
78
- resized_image = original_image.resize((screen_width // 2, int(original_image.height * (screen_width // 2) / original_image.width)))
79
-
80
- if not is_mobile:
81
- cropped_image = crop_image(resized_image, original_image)
82
- else:
83
- st.image(resized_image, caption="Image", use_column_width=False, width=int(4/5 * screen_width))
84
- cropped_image = original_image
85
-
86
- return cropped_image
87
-
88
- def crop_image(resized_image, original_image):
89
- marge = 10
90
- cropped_box = st_cropper(
91
- resized_image,
92
- realtime_update=True,
93
- box_color='#0000FF',
94
- return_type='box',
95
- should_resize_image=False,
96
- default_coords=(marge, resized_image.width - marge, marge, resized_image.height - marge)
97
- )
98
- scale_x = original_image.width / resized_image.width
99
- scale_y = original_image.height / resized_image.height
100
- x0, y0, x1, y1 = int(cropped_box['left'] * scale_x), int(cropped_box['top'] * scale_y), int((cropped_box['left'] + cropped_box['width']) * scale_x), int((cropped_box['top'] + cropped_box['height']) * scale_y)
101
- cropped_image = original_image.crop((x0, y0, x1, y1))
102
- return cropped_image
103
-
104
- def get_score_threshold(is_mobile):
105
- col1, col2 = st.columns(2)
106
- with col1:
107
- st.session_state.score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5 if not is_mobile else 0.6, step=0.05)
108
-
109
- def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
110
- st.session_state.crop_image = cropped_image
111
- with st.spinner('Processing...'):
112
- perform_inference(
113
- st.session_state.model_object, st.session_state.model_arrow, st.session_state.crop_image,
114
- score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
115
- )
116
- st.balloons()
117
-
118
- def mix_new_pred(objects_pred, arrow_pred):
119
- # Initialize the list of lists for keypoints
120
- object_keypoints = []
121
-
122
- # Number of boxes
123
- num_boxes = len(objects_pred['boxes'])
124
-
125
- # Iterate over the number of boxes
126
- for _ in range(num_boxes):
127
- # Each box has 2 keypoints, both initialized to [0, 0, 0]
128
- keypoints = [[0, 0, 0], [0, 0, 0]]
129
- object_keypoints.append(keypoints)
130
 
131
- #concatenate the two predictions
132
- boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
133
- labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
134
-
135
- return boxes, labels, keypoints
136
-
137
-
138
- def modify_results(percentage_text_dist_thresh=0.5):
139
- with st.expander("Method and Style modification (beta version)"):
140
- label_list = list(object_dict.values())
141
- bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
142
- for i in range(len(bboxes)):
143
- bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
144
- bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
145
- labels = [int(label) for label in st.session_state.prediction['labels']]
146
-
147
-
148
- # Filter boxes and labels where label is less than 12
149
- ignore_labels = [6, 7]
150
- object_bboxes = []
151
- object_labels = []
152
- arrow_bboxes = []
153
- arrow_labels = []
154
- for i in range(len(bboxes)):
155
- if labels[i] <= 12:
156
- object_bboxes.append(bboxes[i])
157
- object_labels.append(labels[i])
158
- else:
159
- arrow_bboxes.append(bboxes[i])
160
- arrow_labels.append(labels[i])
161
-
162
- print('Object bboxes:', object_bboxes)
163
- print('Object labels:', object_labels)
164
- print('Arrow bboxes:', arrow_bboxes)
165
- print('Arrow labels:', arrow_labels)
166
-
167
- original_obj_len = len(object_bboxes)
168
-
169
-
170
- uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
171
- scale = 2000 / uploaded_image.size[0]
172
- new_labels = detection(
173
- image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
174
- label_list=label_list, line_width=3, width=2000, use_space=False
175
- )
176
-
177
- if new_labels is not None:
178
- new_lab = np.array([label['label_id'] for label in new_labels])
179
- # Convert back to original format
180
- bboxes = np.array([label['bbox'] for label in new_labels])
181
- for i in range(len(bboxes)):
182
- bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
183
- bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
184
- for i in range(len(arrow_bboxes)):
185
- arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
186
- arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
187
-
188
-
189
- new_bbox = np.concatenate((bboxes, arrow_bboxes))
190
- new_lab = np.concatenate((new_lab, arrow_labels))
191
-
192
- print('New labels:', new_lab)
193
-
194
- scores = st.session_state.prediction['scores']
195
- keypoints = st.session_state.prediction['keypoints']
196
-
197
- #delete element in keypoints to make it match the new number of boxes
198
- len_keypoints = len(keypoints)
199
- keypoints = keypoints.tolist()
200
- scores = scores.tolist()
201
-
202
-
203
- diff = original_obj_len-len(bboxes)
204
- if diff > 0:
205
- for i in range(diff):
206
- keypoints.pop(0)
207
- scores.pop(0)
208
- elif diff < 0:
209
- for i in range(-diff):
210
- keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
211
- scores.insert(0, 0.0)
212
-
213
- print('lenghts: ',len(bboxes), len(new_lab), len(scores), len(keypoints))
214
- keypoints = np.array(keypoints)
215
- scores = np.array(scores)
216
-
217
- #print('Old prediction:', st.session_state.prediction['keypoints'])
218
- boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
219
-
220
- st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
221
- st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
222
 
223
- #print('New prediction:', st.session_state.prediction['keypoints'])
224
- st.rerun()
225
 
226
 
227
- def display_bpmn_modeler(is_mobile, screen_width):
228
- with st.spinner('Waiting for BPMN modeler...'):
229
- st.session_state.bpmn_xml = create_XML(
230
- st.session_state.prediction.copy(), st.session_state.text_mapping,
231
- st.session_state.size_scale, st.session_state.scale
232
- )
233
- display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4/5 * screen_width))
234
 
235
- def modeler_options(is_mobile):
236
- if not is_mobile:
237
- with st.expander("Options for BPMN modeler"):
238
- col1, col2 = st.columns(2)
239
- with col1:
240
- st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
241
- st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
242
- else:
243
- st.session_state.scale = 1.0
244
- st.session_state.size_scale = 1.0
245
 
246
  def main():
247
  is_mobile, screen_width = configure_page()
 
2
  from torchvision.transforms import functional as F
3
  import gc
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ from modules.streamlit_utils import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
8
 
9
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def main():
13
  is_mobile, screen_width = configure_page()