BenjiELCA committed on
Commit
bbe2b18
·
1 Parent(s): 27b8abe

alignment of the element

Browse files
Files changed (1) hide show
  1. app.py +72 -10
app.py CHANGED
@@ -39,14 +39,76 @@ def read_xml_file(filepath):
39
  with open(filepath, 'r', encoding='utf-8') as file:
40
  return file.read()
41
 
42
# Resize every predicted bounding box to its class's canonical size.
def modif_box_pos(pred, size):
    """Return boxes resized to per-class canonical dimensions, centers preserved.

    Args:
        pred: prediction dict with at least 'boxes' (list of [x1, y1, x2, y2])
            and 'labels' (class indices resolved through the module-level
            ``class_dict``).
        size: mapping from class label to (width, height).

    Returns:
        The 'boxes' list of a deep copy of ``pred`` — the caller's
        prediction is never mutated. Boxes whose label has no entry in
        ``size`` are left unchanged.
    """
    updated = copy.deepcopy(pred)  # work on a copy; keep the input intact
    for idx, (x1, y1, x2, y2) in enumerate(updated['boxes']):
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        label = class_dict[updated['labels'][idx]]
        if label in size:
            half_w = size[label][0] / 2
            half_h = size[label][1] / 2
            # Rebuild the box around the original center with the canonical size.
            updated['boxes'][idx] = [cx - half_w, cy - half_h,
                                     cx + half_w, cy + half_h]
    return updated['boxes']
51
 
52
 
@@ -245,7 +307,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
245
  )
246
 
247
  # Function to perform inference on the uploaded image using the loaded models
248
- def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width):
249
  _, uploaded_image = prepare_image(image, pad=False)
250
 
251
  img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
@@ -260,14 +322,14 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
260
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
261
 
262
  # Prediction
263
- _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=0.5, distance_treshold=30)
264
 
265
  # Perform OCR on the uploaded image
266
  ocr_results = text_prediction(uploaded_image)
267
 
268
  # Filter and map OCR results to prediction results
269
  st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
270
- st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=0.5)
271
 
272
  # Remove the original image display
273
  image_placeholder.empty()
@@ -419,7 +481,7 @@ def main():
419
  if st.button("Launch Prediction"):
420
  st.session_state.crop_image = cropped_image
421
  with st.spinner('Processing...'):
422
- perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width)
423
  st.balloons()
424
 
425
  if 'prediction' in st.session_state and uploaded_file is not None:
 
39
  with open(filepath, 'r', encoding='utf-8') as file:
40
  return file.read()
41
 
 
42
def modif_box_pos(pred, size):
    """Align predicted bounding boxes within each pool and snap them to canonical sizes.

    Boxes belonging to the same pool whose centers fall within ±50 px of
    each other are snapped to a shared axis — first vertically (common y),
    then horizontally (common x) — and resized to the per-class width/height
    given in ``size``.

    Args:
        pred: prediction dict with 'boxes' (lists of [x1, y1, x2, y2]),
            'labels' (class indices resolved through the module-level
            ``class_dict``) and 'pool_dict' (pool index -> element indices).
        size: mapping from class label to (width, height).

    Returns:
        The aligned 'boxes' of a deep copy of ``pred`` — the caller's
        prediction is never mutated.
    """
    modified_pred = copy.deepcopy(pred)  # never mutate the caller's prediction

    # Step 1: gather (center, index) pairs per pool. Data objects/stores keep
    # their free placement and are excluded from alignment.
    # BUG FIX: the original condition used `or`
    # (`label != 'dataObject' or label != 'dataStore'`), which is always True,
    # so dataObject/dataStore were never actually excluded.
    pool_groups = {}
    for pool_index, element_indices in pred['pool_dict'].items():
        pool_groups[pool_index] = []
        for i in element_indices:
            if class_dict[modified_pred['labels'][i]] not in ('dataObject', 'dataStore'):
                x1, y1, x2, y2 = modified_pred['boxes'][i]
                center = [(x1 + x2) / 2, (y1 + y2) / 2]
                pool_groups[pool_index].append((center, i))

    def group_centers(centers, axis, range_=50):
        # Greedily cluster centers whose coordinate on `axis` lies within
        # ±range_ of the cluster seed. Consumes the given list.
        groups = []
        while centers:
            center, idx = centers.pop(0)
            group = [(center, idx)]
            for other_center, other_idx in centers[:]:
                if abs(center[axis] - other_center[axis]) <= range_:
                    group.append((other_center, other_idx))
                    centers.remove((other_center, other_idx))
            groups.append(group)
        return groups

    # Step 2: align the elements within each pool.
    for pool_index, centers in pool_groups.items():
        # 2a: snap rows — group boxes whose centers are within ±50 px on y,
        # then move each group member to the group's average y.
        y_groups = group_centers(centers.copy(), axis=1)
        for group in y_groups:
            avg_y = sum(c[0][1] for c in group) / len(group)
            for center, idx in group:
                label = class_dict[modified_pred['labels'][idx]]
                if label in size:  # only classes with a canonical size are moved
                    modified_pred['boxes'][idx] = [
                        center[0] - size[label][0] / 2,
                        avg_y - size[label][1] / 2,
                        center[0] + size[label][0] / 2,
                        avg_y + size[label][1] / 2,
                    ]

        # Recompute centers from the (possibly moved) boxes before the
        # horizontal pass.
        centers = []
        for group in y_groups:
            for _, idx in group:
                x1, y1, x2, y2 = modified_pred['boxes'][idx]
                centers.append(([(x1 + x2) / 2, (y1 + y2) / 2], idx))

        # 2b: snap columns — same grouping on x; shift only horizontally and
        # keep the y extents produced by pass 2a.
        x_groups = group_centers(centers.copy(), axis=0)
        for group in x_groups:
            avg_x = sum(c[0][0] for c in group) / len(group)
            for center, idx in group:
                label = class_dict[modified_pred['labels'][idx]]
                if label in size:
                    modified_pred['boxes'][idx] = [
                        avg_x - size[label][0] / 2,
                        modified_pred['boxes'][idx][1],
                        avg_x + size[label][0] / 2,
                        modified_pred['boxes'][idx][3],
                    ]

    return modified_pred['boxes']
113
 
114
 
 
307
  )
308
 
309
  # Function to perform inference on the uploaded image using the loaded models
310
+ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
311
  _, uploaded_image = prepare_image(image, pad=False)
312
 
313
  img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
 
322
  image_placeholder.image(uploaded_image, caption='Original Image', width=width)
323
 
324
  # Prediction
325
+ _, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
326
 
327
  # Perform OCR on the uploaded image
328
  ocr_results = text_prediction(uploaded_image)
329
 
330
  # Filter and map OCR results to prediction results
331
  st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
332
+ st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
333
 
334
  # Remove the original image display
335
  image_placeholder.empty()
 
481
  if st.button("Launch Prediction"):
482
  st.session_state.crop_image = cropped_image
483
  with st.spinner('Processing...'):
484
+ perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5)
485
  st.balloons()
486
 
487
  if 'prediction' in st.session_state and uploaded_file is not None: