BenjiELCA commited on
Commit
ca71e96
·
1 Parent(s): 5950696

correct bug with Vizifile

Browse files
Files changed (2) hide show
  1. modules/toWizard.py +10 -2
  2. modules/train.py +3 -4
modules/toWizard.py CHANGED
@@ -1,6 +1,7 @@
1
  import xml.etree.ElementTree as ET
2
  from modules.utils import class_dict
3
  from xml.dom import minidom
 
4
 
5
  def rescale(scale, boxes):
6
  for i in range(len(boxes)):
@@ -52,6 +53,9 @@ def check_end(val):
52
 
53
  def connect(data, text_mapping, i):
54
  target_idx = data['links'][i][1]
 
 
 
55
  current_id = data['BPMN_id'][i]
56
  next_idx = data['links'][target_idx][1]
57
  next_id = data['BPMN_id'][next_idx]
@@ -116,11 +120,11 @@ def create_wizard_file(data, text_mapping):
116
 
117
  for idx, activity_name in enumerate(data['BPMN_id']):
118
  if activity_name.startswith('task'):
119
- target_text = ' FINISH'
120
  activity = ET.SubElement(activities, 'activity', attrib={'name': text_mapping.get(activity_name, activity_name), 'performer': ''})
121
  endStates = ET.SubElement(activity, 'endStates')
122
  current_text, next_text = connect(data, text_mapping, idx)
123
- ET.SubElement(endStates, 'endState', attrib={'name': next_text, 'isRegular': 'True'})
 
124
  ET.SubElement(activity, 'subActivities')
125
  ET.SubElement(activity, 'subActivityFlows')
126
  ET.SubElement(activity, 'messageFlows')
@@ -130,10 +134,14 @@ def create_wizard_file(data, text_mapping):
130
  for i, link in enumerate(data['links']):
131
  if link[0] is None and link[1] is not None and (data['BPMN_id'][i].split('_')[0] == 'event' or data['BPMN_id'][i].split('_')[0] == 'message'):
132
  current_text, next_text = connect(data, text_mapping, i)
 
 
133
  ET.SubElement(activityFlows, 'activityFlow', attrib={'startEvent': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
134
  i+=1
135
  if link[0] is not None and link[1] is not None and data['BPMN_id'][i].split('_')[0] == 'task':
136
  current_text, next_text = connect(data, text_mapping, i)
 
 
137
  ET.SubElement(activityFlows, 'activityFlow', attrib={'activity': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
138
  i+=1
139
 
 
1
  import xml.etree.ElementTree as ET
2
  from modules.utils import class_dict
3
  from xml.dom import minidom
4
+ from modules.utils import error
5
 
6
  def rescale(scale, boxes):
7
  for i in range(len(boxes)):
 
53
 
54
  def connect(data, text_mapping, i):
55
  target_idx = data['links'][i][1]
56
+ if target_idx >= len(data['links']):
57
+ error('There is an error with the Vizi file, care when you download it.')
58
+ return None, None
59
  current_id = data['BPMN_id'][i]
60
  next_idx = data['links'][target_idx][1]
61
  next_id = data['BPMN_id'][next_idx]
 
120
 
121
  for idx, activity_name in enumerate(data['BPMN_id']):
122
  if activity_name.startswith('task'):
 
123
  activity = ET.SubElement(activities, 'activity', attrib={'name': text_mapping.get(activity_name, activity_name), 'performer': ''})
124
  endStates = ET.SubElement(activity, 'endStates')
125
  current_text, next_text = connect(data, text_mapping, idx)
126
+ if next_text is not None:
127
+ ET.SubElement(endStates, 'endState', attrib={'name': next_text, 'isRegular': 'True'})
128
  ET.SubElement(activity, 'subActivities')
129
  ET.SubElement(activity, 'subActivityFlows')
130
  ET.SubElement(activity, 'messageFlows')
 
134
  for i, link in enumerate(data['links']):
135
  if link[0] is None and link[1] is not None and (data['BPMN_id'][i].split('_')[0] == 'event' or data['BPMN_id'][i].split('_')[0] == 'message'):
136
  current_text, next_text = connect(data, text_mapping, i)
137
+ if current_text is None or next_text is None:
138
+ continue
139
  ET.SubElement(activityFlows, 'activityFlow', attrib={'startEvent': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
140
  i+=1
141
  if link[0] is not None and link[1] is not None and data['BPMN_id'][i].split('_')[0] == 'task':
142
  current_text, next_text = connect(data, text_mapping, i)
143
+ if current_text is None or next_text is None:
144
+ continue
145
  ET.SubElement(activityFlows, 'activityFlow', attrib={'activity': current_text, 'endState': '---', 'target': next_text, 'isMerging': 'False', 'isPredefined': 'True'})
146
  i+=1
147
 
modules/train.py CHANGED
@@ -7,9 +7,10 @@ import matplotlib.pyplot as plt
7
 
8
  from modules.eval import main_evaluation
9
  from torch.optim import SGD, AdamW
10
- from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
11
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
12
  from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
 
13
  from tqdm import tqdm
14
  from modules.utils import write_results
15
 
@@ -36,7 +37,6 @@ def get_arrow_model(num_classes, num_keypoints=2):
36
  This is necessary to tailor the model to specific tasks that may have different keypoint structures.
37
  """
38
  # Load a model pre-trained on COCO, initialized without pre-trained weights
39
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
40
  model = keypointrcnn_resnet50_fpn(weights=None)
41
 
42
  # Get the number of input features for the classifier in the box predictor.
@@ -50,8 +50,7 @@ def get_arrow_model(num_classes, num_keypoints=2):
50
 
51
  return model
52
 
53
- from torchvision.models.detection import fasterrcnn_resnet50_fpn
54
- from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
55
  def get_faster_rcnn_model(num_classes):
56
  """
57
  Configures and returns a modified Faster R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes.
 
7
 
8
  from modules.eval import main_evaluation
9
  from torch.optim import SGD, AdamW
10
+ from torchvision.models.detection import keypointrcnn_resnet50_fpn
11
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
12
  from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
13
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
14
  from tqdm import tqdm
15
  from modules.utils import write_results
16
 
 
37
  This is necessary to tailor the model to specific tasks that may have different keypoint structures.
38
  """
39
  # Load a model pre-trained on COCO, initialized without pre-trained weights
 
40
  model = keypointrcnn_resnet50_fpn(weights=None)
41
 
42
  # Get the number of input features for the classifier in the box predictor.
 
50
 
51
  return model
52
 
53
+
 
54
  def get_faster_rcnn_model(num_classes):
55
  """
56
  Configures and returns a modified Faster R-CNN model based on ResNet-50 with FPN, adapted for a custom number of classes.