add verification of length and ready for demo
- app.py +84 -10
- modules/display.py +2 -0
- modules/streamlit_utils.py +1 -1
- modules/toXML.py +10 -2
app.py
CHANGED

@@ -13,7 +13,7 @@ from glob import glob
 from streamlit_image_annotation import detection
 from modules.toXML import create_XML
 from modules.eval import develop_prediction, generate_data
-from modules.utils import class_dict
+from modules.utils import class_dict, object_dict
 
 def configure_page():
     st.set_page_config(layout="wide")

@@ -114,41 +114,114 @@ def launch_prediction(cropped_image, score_threshold, is_mobile, screen_width):
         score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5
     )
     st.balloons()
+
+def mix_new_pred(objects_pred, arrow_pred):
+    # Initialize the list of lists for keypoints
+    object_keypoints = []
+
+    # Number of boxes
+    num_boxes = len(objects_pred['boxes'])
+
+    # Iterate over the number of boxes
+    for _ in range(num_boxes):
+        # Each box has 2 keypoints, both initialized to [0, 0, 0]
+        keypoints = [[0, 0, 0], [0, 0, 0]]
+        object_keypoints.append(keypoints)
+
+    #concatenate the two predictions
+    boxes = np.concatenate((objects_pred['boxes'], arrow_pred['boxes']))
+    labels = np.concatenate((objects_pred['labels'], arrow_pred['labels']))
+
+    return boxes, labels, keypoints
 
 
 def modify_results(percentage_text_dist_thresh=0.5):
     with st.expander("Method and Style modification (beta version)"):
-        label_list = list(
+        label_list = list(object_dict.values())
         bboxes = [[int(coord) for coord in box] for box in st.session_state.prediction['boxes']]
         for i in range(len(bboxes)):
             bboxes[i][2] = bboxes[i][2] - bboxes[i][0]
             bboxes[i][3] = bboxes[i][3] - bboxes[i][1]
         labels = [int(label) for label in st.session_state.prediction['labels']]
+
+
+        # Filter boxes and labels where label is less than 12
+        ignore_labels = [6, 7]
+        object_bboxes = []
+        object_labels = []
+        arrow_bboxes = []
+        arrow_labels = []
+        for i in range(len(bboxes)):
+            if labels[i] <= 12:
+                object_bboxes.append(bboxes[i])
+                object_labels.append(labels[i])
+            else:
+                arrow_bboxes.append(bboxes[i])
+                arrow_labels.append(labels[i])
+
+        print('Object bboxes:', object_bboxes)
+        print('Object labels:', object_labels)
+        print('Arrow bboxes:', arrow_bboxes)
+        print('Arrow labels:', arrow_labels)
+
+        original_obj_len = len(object_bboxes)
+
+
         uploaded_image = prepare_image(st.session_state.crop_image, new_size=(1333, 1333), pad=False)
         scale = 2000 / uploaded_image.size[0]
         new_labels = detection(
-            image=uploaded_image, bboxes=
+            image=uploaded_image, bboxes=object_bboxes, labels=object_labels,
             label_list=label_list, line_width=3, width=2000, use_space=False
         )
 
         if new_labels is not None:
-            new_lab = np.array([label['label_id'] for label in new_labels])
-
+            new_lab = np.array([label['label_id'] for label in new_labels])
             # Convert back to original format
             bboxes = np.array([label['bbox'] for label in new_labels])
             for i in range(len(bboxes)):
                 bboxes[i][2] = bboxes[i][2] + bboxes[i][0]
                 bboxes[i][3] = bboxes[i][3] + bboxes[i][1]
+            for i in range(len(arrow_bboxes)):
+                arrow_bboxes[i][2] = arrow_bboxes[i][2] + arrow_bboxes[i][0]
+                arrow_bboxes[i][3] = arrow_bboxes[i][3] + arrow_bboxes[i][1]
+
+
+            new_bbox = np.concatenate((bboxes, arrow_bboxes))
+            new_lab = np.concatenate((new_lab, arrow_labels))
+
+            print('New labels:', new_lab)
 
             scores = st.session_state.prediction['scores']
             keypoints = st.session_state.prediction['keypoints']
+
+            #delete element in keypoints to make it match the new number of boxes
+            len_keypoints = len(keypoints)
+            keypoints = keypoints.tolist()
+            scores = scores.tolist()
+
+
+            diff = original_obj_len-len(bboxes)
+            if diff > 0:
+                for i in range(diff):
+                    keypoints.pop(0)
+                    scores.pop(0)
+            elif diff < 0:
+                for i in range(-diff):
+                    keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
+                    scores.insert(0, 0.0)
+
+            print('lenghts: ', len(bboxes), len(new_lab), len(scores), len(keypoints))
+            keypoints = np.array(keypoints)
+            scores = np.array(scores)
+
             #print('Old prediction:', st.session_state.prediction['keypoints'])
-            boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(
+            boxes, labels, scores, keypoints, flow_links, best_points, pool_dict = develop_prediction(new_bbox, new_lab, scores, keypoints, class_dict, correction=False)
 
             st.session_state.prediction = generate_data(st.session_state.prediction['image'], boxes, labels, scores, keypoints, flow_links, best_points, pool_dict, class_dict)
             st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
 
             #print('New prediction:', st.session_state.prediction['keypoints'])
+            st.rerun()
 
 
 def display_bpmn_modeler(is_mobile, screen_width):

@@ -186,15 +259,16 @@ def main():
 
     if cropped_image is not None:
         get_score_threshold(is_mobile)
-        if st.button("Launch Prediction"):
+        if st.button("π Launch Prediction"):
             launch_prediction(cropped_image, st.session_state.score_threshold, is_mobile, screen_width)
+            st.session_state.original_prediction = st.session_state.prediction.copy()
             st.rerun()
 
     if 'prediction' in st.session_state and uploaded_file:
-        if st.button("π Refresh image"):
-            st.rerun()
+        #if st.button("π Refresh image"):
+            #st.rerun()
 
-        with st.expander("Show result"):
+        with st.expander("Show result of prediction"):
            with st.spinner('Waiting for result display...'):
                display_options(st.session_state.crop_image, st.session_state.score_threshold, is_mobile, int(5/6 * screen_width))
 
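Note on the modify_results() changes above: detections are now split into object boxes (labels <= 12) and arrow boxes, only the objects are passed to the annotator, and the keypoint/score arrays are padded or trimmed so their length matches the recombined box list before develop_prediction() is called. Below is a minimal standalone sketch of that reconciliation step; the name reconcile_lengths is illustrative and not part of app.py.

import numpy as np

def reconcile_lengths(original_obj_len, new_obj_boxes, keypoints, scores):
    # Pad or trim keypoints/scores so they match the new number of object boxes
    # (a sketch of the length guard added in modify_results(), not the app's API).
    keypoints = list(keypoints)
    scores = list(scores)
    diff = original_obj_len - len(new_obj_boxes)
    if diff > 0:
        # boxes were removed in the annotator: drop the leading entries
        for _ in range(diff):
            keypoints.pop(0)
            scores.pop(0)
    elif diff < 0:
        # boxes were added: insert placeholder keypoints and neutral scores
        for _ in range(-diff):
            keypoints.insert(0, [[0, 0, 0], [0, 0, 0]])
            scores.insert(0, 0.0)
    return np.array(keypoints), np.array(scores)

# Example: the user deleted two of the five original object boxes.
kps, scs = reconcile_lengths(
    original_obj_len=5,
    new_obj_boxes=[[0, 0, 10, 10]] * 3,
    keypoints=[[[0, 0, 0], [0, 0, 0]]] * 5,
    scores=[0.9, 0.8, 0.7, 0.6, 0.5],
)
assert len(kps) == len(scs) == 3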
modules/display.py
CHANGED

@@ -94,6 +94,8 @@ def draw_stream(image,
     # Draw keypoints if available
     if draw_keypoints and 'keypoints' in prediction:
         for i in range(len(prediction['keypoints'])):
+            if i >= len(prediction['keypoints']):
+                continue
             kp = prediction['keypoints'][i]
             for j in range(kp.shape[0]):
                 if prediction['labels'][i] != list(class_dict.values()).index('sequenceFlow') and prediction['labels'][i] != list(class_dict.values()).index('messageFlow') and prediction['labels'][i] != list(class_dict.values()).index('dataAssociation'):
modules/streamlit_utils.py
CHANGED

@@ -130,7 +130,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
 
     # Draw the annotated image with selected options
     annotated_image = draw_stream(
-        np.array(image), prediction=st.session_state.
+        np.array(image), prediction=st.session_state.original_prediction, text_predictions=st.session_state.text_pred,
         draw_keypoints=draw_keypoints, draw_boxes=draw_boxes, draw_links=draw_links, draw_twins=False, draw_grouped_text=draw_text,
         write_class=write_class, write_text=write_text, keypoints_correction=True, write_idx=write_idx, only_show=selected_option,
         score_threshold=score_threshold, write_score=write_score, resize=True, return_image=True, axis=True
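The draw_stream() call above now reads st.session_state.original_prediction, the snapshot that app.py stores right after launch_prediction(), so the "Show result of prediction" view keeps showing the unmodified prediction even after the user edits boxes in modify_results(). A minimal sketch of that copy-on-launch pattern, using placeholder data rather than the app's real prediction dict:

import streamlit as st

if st.button("Launch Prediction"):
    # Placeholder result standing in for the real detection output.
    prediction = {"boxes": [[0, 0, 10, 10]], "labels": [1]}
    st.session_state.prediction = prediction                   # may be modified later by the user
    st.session_state.original_prediction = prediction.copy()   # shallow snapshot, as in app.py, read by draw_stream()
    st.rerun()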
modules/toXML.py
CHANGED

@@ -13,7 +13,7 @@ def align_boxes(pred, size):
     for pool_index, element_indices in pred['pool_dict'].items():
         pool_groups[pool_index] = []
         for i in element_indices:
-            if i
+            if i >= len(modified_pred['labels']):
                 continue
             if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
                 x1, y1, x2, y2 = modified_pred['boxes'][i]

@@ -138,7 +138,7 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
         pool_width = max_x - min_x + 100  # Adding padding
         pool_height = max_y - min_y + 100  # Adding padding
         #check area
-        if pool_width <
+        if pool_width < 300 or pool_height < 30:
             print("The pool is too small, please add more elements or increase the scale")
             continue
 

@@ -157,6 +157,9 @@ def create_XML(full_pred, text_mapping, size_scale, scale):
     # Create sequence flow elements
     for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
         for i in keep_elements:
+            if i >= len(full_pred['labels']):
+                print("Problem with the index")
+                continue
             if full_pred['labels'][i] == list(class_dict.values()).index('sequenceFlow'):
                 create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process[idx], message=False)
 

@@ -259,6 +262,8 @@ def add_diagram_edge(parent, element_id, waypoints):
         'id': element_id + '_di'
     })
     for x, y in waypoints:
+        if x is None or y is None:
+            return
         ET.SubElement(edge, 'di:waypoint', attrib={
             'x': str(x),
             'y': str(y)

@@ -312,6 +317,9 @@ def create_bpmn_object(process, bpmnplane, text_mapping, definitions, size, data
     links = data['links']
 
     for i in keep_elements:
+        if i >= len(elements):
+            print("Problem with the index")
+            continue
        element_id = elements[i]
 
        if element_id is None:
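The toXML.py hunks all add the same kind of guard: verify that an index is within range (or that waypoint coordinates exist) before touching the corresponding list, and skip or bail out otherwise. A small self-contained sketch of that pattern; safe_elements is an illustrative helper, not a function in modules/toXML.py:

def safe_elements(keep_elements, elements):
    # Yield only the ids whose index is actually within range,
    # mirroring the bounds checks added in create_XML() / create_bpmn_object().
    for i in keep_elements:
        if i >= len(elements):
            print("Problem with the index")
            continue
        yield elements[i]

# Example: index 5 is out of range and gets skipped after the warning.
print(list(safe_elements([0, 2, 5], ["start", "task", "end"])))  # ['start', 'end']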