import streamlit as st
from torchvision.transforms import functional as F
import gc
import copy
import xml.etree.ElementTree as ET
import numpy as np
from xml.dom import minidom

from modules.htlm_webpage import display_bpmn_xml
from modules.utils import class_dict, rescale_boxes
from modules.toXML import (
    calculate_pool_bounds,
    add_diagram_elements,
    create_bpmn_object,
    create_flow_element,
    get_size_elements,
    definitions,
)
from streamlit_cropper import st_cropper
from streamlit_image_select import image_select
from streamlit_js_eval import streamlit_js_eval
from modules.streamlit_utils import (
    get_memory_usage,
    clear_memory,
    get_image,
    load_models,
    perform_inference,
    display_options,
    align_boxes,
    sidebar,
)

# Screen width (as reported by the browser's `screen.width`) below which the
# mobile layout is used.
MOBILE_SCREEN_WIDTH = 800


def _pool_name(full_pred, text_mapping, pool_index):
    """Return the text label of the pool whose prediction index is *pool_index*."""
    return text_mapping[full_pred['BPMN_id'][pool_index]]


def create_XML(full_pred, text_mapping, size_scale, scale):
    """Build a BPMN XML document from prediction results.

    Args:
        full_pred: Prediction dict. This function reads the keys
            ``'pool_dict'`` (pool index -> indices of the elements it keeps),
            ``'BPMN_id'``, ``'boxes'`` and ``'labels'``. While the XML is built,
            ``full_pred['boxes']`` is temporarily replaced by rescaled and
            aligned boxes. The original boxes object is always put back
            before returning, including when an exception is raised.
        text_mapping: Mapping from BPMN ids to the element's text label.
        size_scale: Scale factor forwarded to ``get_size_elements``.
        scale: Factor applied to box positions through ``rescale_boxes``.

    Returns:
        The pretty-printed XML document as a string.
    """
    size_elements = get_size_elements(size_scale)

    # Build into a private copy of the shared ``definitions`` template. The
    # template lives in an imported (and therefore cached) module, so
    # appending to it directly would make every call, i.e. every Streamlit
    # rerun, add another collaboration/process/diagram with the same ids.
    bpmn_definitions = copy.deepcopy(definitions)

    # Keep a reference to the caller's boxes and work on a deep copy, so the
    # caller's box data is never modified even if a helper rescales in place.
    original_boxes = full_pred['boxes']

    collaboration = ET.SubElement(bpmn_definitions, 'bpmn:collaboration', id='collaboration_1')

    # One <bpmn:process> per detected pool, named after the pool's text.
    process = []
    for idx, pool_index in enumerate(full_pred['pool_dict']):
        process.append(
            ET.SubElement(
                bpmn_definitions,
                'bpmn:process',
                id=f'process_{idx + 1}',
                isExecutable='false',
                name=_pool_name(full_pred, text_mapping, pool_index),
            )
        )

    bpmndi = ET.SubElement(bpmn_definitions, 'bpmndi:BPMNDiagram', id='BPMNDiagram_1')
    bpmnplane = ET.SubElement(bpmndi, 'bpmndi:BPMNPlane', id='BPMNPlane_1', bpmnElement='collaboration_1')

    try:
        full_pred['boxes'] = rescale_boxes(scale, copy.deepcopy(original_boxes))
        full_pred['boxes'] = align_boxes(full_pred, size_elements)

        # Participants (pools) and their shapes in the diagram.
        for idx, (pool_index, keep_elements) in enumerate(full_pred['pool_dict'].items()):
            pool_id = f'participant_{idx + 1}'
            ET.SubElement(
                collaboration,
                'bpmn:participant',
                id=pool_id,
                processRef=f'process_{idx + 1}',
                name=_pool_name(full_pred, text_mapping, pool_index),
            )
            # len() rather than truthiness: keep_elements may be an array.
            if len(keep_elements) == 0:
                # Empty pool: use the pool's own detected box.
                min_x, min_y, max_x, max_y = full_pred['boxes'][pool_index]
                pool_width = max_x - min_x
                pool_height = max_y - min_y
            else:
                min_x, min_y, max_x, max_y = calculate_pool_bounds(full_pred, keep_elements, size_elements)
                pool_width = max_x - min_x + 100  # Adding padding
                pool_height = max_y - min_y + 100  # Adding padding
            # NOTE(review): the -50 offset is also applied to empty pools,
            # which get no matching +100 padding; confirm that is intended.
            add_diagram_elements(bpmnplane, pool_id, min_x - 50, min_y - 50, pool_width, pool_height)

        # BPMN objects belonging to each pool.
        for process_element, keep_elements in zip(process, full_pred['pool_dict'].values()):
            create_bpmn_object(process_element, bpmnplane, text_mapping, bpmn_definitions, size_elements, full_pred, keep_elements)

        # Message flows cross pools, so they hang off the collaboration.
        message_flows = [i for i, label in enumerate(full_pred['labels']) if class_dict[label] == 'messageFlow']
        for idx in message_flows:
            create_flow_element(bpmnplane, text_mapping, idx, size_elements, full_pred, collaboration, message=True)

        # Sequence flows stay inside their own pool's process.
        sequence_flow_label = list(class_dict.values()).index('sequenceFlow')
        for process_element, keep_elements in zip(process, full_pred['pool_dict'].values()):
            for i in keep_elements:
                if full_pred['labels'][i] == sequence_flow_label:
                    create_flow_element(bpmnplane, text_mapping, i, size_elements, full_pred, process_element, message=False)

        rough_string = ET.tostring(bpmn_definitions, 'utf-8')
        pretty_xml_as_string = minidom.parseString(rough_string).toprettyxml(indent=" ")
    finally:
        # Put the caller's original boxes back, even if building failed.
        full_pred['boxes'] = original_boxes

    return pretty_xml_as_string


def _select_input_file(img_selected):
    """Return the chosen example image path, or show the file uploader and return its value."""
    if img_selected is not None:
        return img_selected
    return st.file_uploader("Choose an image from my computer...", type=["jpg", "jpeg", "png"])


def main():
    """Render one run of the Streamlit BPMN recognition app."""
    st.set_page_config(layout="wide")

    screen_width = streamlit_js_eval(js_expressions='screen.width', want_output=True, key='SCR')
    print("Screen width:", screen_width)

    # streamlit_js_eval may return None; in that case fall back to the desktop layout.
    is_mobile = screen_width is not None and screen_width < MOBILE_SCREEN_WIDTH
    print('Mobile version' if is_mobile else 'Desktop version')

    # Add your company logo banner
    if is_mobile:
        st.image("./images/banner_mobile.png", use_column_width=True)
        st.title("Welcome on the mobile version of BPMN AI model recognition app")
    else:
        st.image("./images/banner_desktop.png", use_column_width=True)
        st.title("Welcome on BPMN AI model recognition app")

    sidebar()  # Display the sidebar

    memory_usage = get_memory_usage()
    print(f"Current memory usage: {memory_usage:.2f} MB")

    # Initialize the session state for storing pool bounding boxes
    if 'pool_bboxes' not in st.session_state:
        st.session_state.pool_bboxes = []

    # load_models() is expected to place both models in st.session_state.
    if 'model_object' not in st.session_state or 'model_arrow' not in st.session_state:
        clear_memory()
        _, _ = load_models()
    model_arrow = st.session_state.model_arrow
    model_object = st.session_state.model_object

    with st.expander("Use example images"):
        img_selected = image_select(
            "If you have no image and just want to test the demo, click on one of these images",
            ["./images/none.jpg", "./images/example1.jpg", "./images/example2.jpg", "./images/example3.jpg", "./images/example4.jpg"],
            captions=["None", "Example 1", "Example 2", "Example 3", "Example 4"],
            index=0,
            use_container_width=False,
            return_value="original",
        )
    if img_selected == './images/none.jpg':
        print('No example image selected')
        img_selected = None

    if not is_mobile:
        # Create the layout for the app
        col1, col2 = st.columns(2)
        with col1:
            uploaded_file = _select_input_file(img_selected)
    else:
        uploaded_file = _select_input_file(img_selected)

    if uploaded_file is not None:
        with st.spinner('Waiting for image display...'):
            original_image = get_image(uploaded_file)
            # NOTE(review): this raises TypeError if screen_width is None
            # while a file is loaded; confirm the JS value is always present here.
            target_width = screen_width // 3
            resized_image = original_image.resize(
                (target_width, int(original_image.height * target_width / original_image.width))
            )

        if not is_mobile:
            col1, col2 = st.columns(2)
            with col1:
                marge = 10
                cropped_box = st_cropper(
                    resized_image,
                    realtime_update=True,
                    box_color='#0000FF',
                    return_type='box',
                    should_resize_image=False,
                    default_coords=(marge, resized_image.width - marge, marge, resized_image.height - marge),
                )
                # Map the crop box from the displayed (resized) image back to the original.
                scale_x = original_image.width / resized_image.width
                scale_y = original_image.height / resized_image.height
                x0 = int(cropped_box['left'] * scale_x)
                y0 = int(cropped_box['top'] * scale_y)
                x1 = int((cropped_box['left'] + cropped_box['width']) * scale_x)
                y1 = int((cropped_box['top'] + cropped_box['height']) * scale_y)
                cropped_image = original_image.crop((x0, y0, x1, y1))
            with col2:
                st.image(cropped_image, caption="Cropped Image", use_column_width=False, width=int(screen_width // 4))
        else:
            st.image(resized_image, caption="Image", use_column_width=False, width=int(4 / 5 * screen_width))
            cropped_image = original_image

        if cropped_image is not None:
            if not is_mobile:
                col1, col2 = st.columns(2)
                with col1:
                    score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
            else:
                score_threshold = st.slider("Set score threshold for prediction", min_value=0.0, max_value=1.0, value=0.6, step=0.05)

            if st.button("Launch Prediction"):
                st.session_state.crop_image = cropped_image
                with st.spinner('Processing...'):
                    perform_inference(
                        model_object,
                        model_arrow,
                        st.session_state.crop_image,
                        score_threshold,
                        is_mobile,
                        screen_width,
                        iou_threshold=0.3,
                        distance_treshold=30,
                        percentage_text_dist_thresh=0.5,
                    )
                st.balloons()

    if 'prediction' in st.session_state and uploaded_file is not None:
        with st.spinner('Waiting for result display...'):
            display_options(st.session_state.crop_image, score_threshold, is_mobile, int(5 / 6 * screen_width))

        with st.spinner('Waiting for BPMN modeler...'):
            col1, col2 = st.columns(2)
            with col1:
                st.session_state.scale = st.slider("Set distance scale for XML file", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
                if not is_mobile:
                    st.session_state.size_scale = st.slider("Set size object scale for XML file", min_value=0.5, max_value=2.0, value=1.0, step=0.1)
                else:
                    st.session_state.size_scale = 1.0
            # Shallow copy so create_XML's temporary 'boxes' swap never touches session state.
            st.session_state.bpmn_xml = create_XML(
                st.session_state.prediction.copy(),
                st.session_state.text_mapping,
                st.session_state.size_scale,
                st.session_state.scale,
            )
            display_bpmn_xml(st.session_state.bpmn_xml, is_mobile=is_mobile, screen_width=int(4 / 5 * screen_width))

    gc.collect()


if __name__ == "__main__":
    print('Starting the app...')
    main()