Spaces:
Running
Running
alignment of the element
Browse files
app.py
CHANGED
@@ -39,14 +39,76 @@ def read_xml_file(filepath):
|
|
39 |
with open(filepath, 'r', encoding='utf-8') as file:
|
40 |
return file.read()
|
41 |
|
42 |
-
# Function to modify bounding box positions based on the given sizes
|
43 |
def modif_box_pos(pred, size):
|
44 |
modified_pred = copy.deepcopy(pred) # Make a deep copy of the prediction
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
return modified_pred['boxes']
|
51 |
|
52 |
|
@@ -245,7 +307,7 @@ def display_options(image, score_threshold, is_mobile, screen_width):
|
|
245 |
)
|
246 |
|
247 |
# Function to perform inference on the uploaded image using the loaded models
|
248 |
-
def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width):
|
249 |
_, uploaded_image = prepare_image(image, pad=False)
|
250 |
|
251 |
img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
|
@@ -260,14 +322,14 @@ def perform_inference(model_object, model_arrow, image, score_threshold, is_mobi
|
|
260 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
261 |
|
262 |
# Prediction
|
263 |
-
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=
|
264 |
|
265 |
# Perform OCR on the uploaded image
|
266 |
ocr_results = text_prediction(uploaded_image)
|
267 |
|
268 |
# Filter and map OCR results to prediction results
|
269 |
st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
|
270 |
-
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=
|
271 |
|
272 |
# Remove the original image display
|
273 |
image_placeholder.empty()
|
@@ -419,7 +481,7 @@ def main():
|
|
419 |
if st.button("Launch Prediction"):
|
420 |
st.session_state.crop_image = cropped_image
|
421 |
with st.spinner('Processing...'):
|
422 |
-
perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width)
|
423 |
st.balloons()
|
424 |
|
425 |
if 'prediction' in st.session_state and uploaded_file is not None:
|
|
|
39 |
with open(filepath, 'r', encoding='utf-8') as file:
|
40 |
return file.read()
|
41 |
|
|
|
42 |
def modif_box_pos(pred, size):
|
43 |
modified_pred = copy.deepcopy(pred) # Make a deep copy of the prediction
|
44 |
+
|
45 |
+
# Step 1: Calculate the center of each bounding box and group them by pool
|
46 |
+
pool_groups = {}
|
47 |
+
for pool_index, element_indices in pred['pool_dict'].items():
|
48 |
+
pool_groups[pool_index] = []
|
49 |
+
for i in element_indices:
|
50 |
+
if class_dict[modified_pred['labels'][i]] != 'dataObject' or class_dict[modified_pred['labels'][i]] != 'dataStore':
|
51 |
+
x1, y1, x2, y2 = modified_pred['boxes'][i]
|
52 |
+
center = [(x1 + x2) / 2, (y1 + y2) / 2]
|
53 |
+
pool_groups[pool_index].append((center, i))
|
54 |
+
|
55 |
+
# Function to group centers within a specified range
|
56 |
+
def group_centers(centers, axis, range_=50):
|
57 |
+
groups = []
|
58 |
+
while centers:
|
59 |
+
center, idx = centers.pop(0)
|
60 |
+
group = [(center, idx)]
|
61 |
+
for other_center, other_idx in centers[:]:
|
62 |
+
if abs(center[axis] - other_center[axis]) <= range_:
|
63 |
+
group.append((other_center, other_idx))
|
64 |
+
centers.remove((other_center, other_idx))
|
65 |
+
groups.append(group)
|
66 |
+
return groups
|
67 |
+
|
68 |
+
# Step 2: Align the elements within each pool
|
69 |
+
for pool_index, centers in pool_groups.items():
|
70 |
+
# Group bounding boxes by checking if their centers are within ±50 pixels on the y-axis
|
71 |
+
y_groups = group_centers(centers.copy(), axis=1)
|
72 |
+
|
73 |
+
# Align the y-coordinates of the centers of grouped bounding boxes
|
74 |
+
for group in y_groups:
|
75 |
+
avg_y = sum([c[0][1] for c in group]) / len(group) # Calculate the average y-coordinate
|
76 |
+
for (center, idx) in group:
|
77 |
+
label = class_dict[modified_pred['labels'][idx]]
|
78 |
+
if label in size:
|
79 |
+
new_center = (center[0], avg_y) # Align the y-coordinate
|
80 |
+
modified_pred['boxes'][idx] = [
|
81 |
+
new_center[0] - size[label][0] / 2,
|
82 |
+
new_center[1] - size[label][1] / 2,
|
83 |
+
new_center[0] + size[label][0] / 2,
|
84 |
+
new_center[1] + size[label][1] / 2
|
85 |
+
]
|
86 |
+
|
87 |
+
# Recalculate centers after vertical alignment
|
88 |
+
centers = []
|
89 |
+
for group in y_groups:
|
90 |
+
for center, idx in group:
|
91 |
+
x1, y1, x2, y2 = modified_pred['boxes'][idx]
|
92 |
+
center = [(x1 + x2) / 2, (y1 + y2) / 2]
|
93 |
+
centers.append((center, idx))
|
94 |
+
|
95 |
+
# Group bounding boxes by checking if their centers are within ±50 pixels on the x-axis
|
96 |
+
x_groups = group_centers(centers.copy(), axis=0)
|
97 |
+
|
98 |
+
# Align the x-coordinates of the centers of grouped bounding boxes
|
99 |
+
for group in x_groups:
|
100 |
+
avg_x = sum([c[0][0] for c in group]) / len(group) # Calculate the average x-coordinate
|
101 |
+
for (center, idx) in group:
|
102 |
+
label = class_dict[modified_pred['labels'][idx]]
|
103 |
+
if label in size:
|
104 |
+
new_center = (avg_x, center[1]) # Align the x-coordinate
|
105 |
+
modified_pred['boxes'][idx] = [
|
106 |
+
new_center[0] - size[label][0] / 2,
|
107 |
+
modified_pred['boxes'][idx][1],
|
108 |
+
new_center[0] + size[label][0] / 2,
|
109 |
+
modified_pred['boxes'][idx][3]
|
110 |
+
]
|
111 |
+
|
112 |
return modified_pred['boxes']
|
113 |
|
114 |
|
|
|
307 |
)
|
308 |
|
309 |
# Function to perform inference on the uploaded image using the loaded models
|
310 |
+
def perform_inference(model_object, model_arrow, image, score_threshold, is_mobile, screen_width, iou_threshold=0.5, distance_treshold=30, percentage_text_dist_thresh=0.5):
|
311 |
_, uploaded_image = prepare_image(image, pad=False)
|
312 |
|
313 |
img_tensor = F.to_tensor(prepare_image(image.convert('RGB'))[1])
|
|
|
322 |
image_placeholder.image(uploaded_image, caption='Original Image', width=width)
|
323 |
|
324 |
# Prediction
|
325 |
+
_, st.session_state.prediction = full_prediction(model_object, model_arrow, img_tensor, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_treshold=distance_treshold)
|
326 |
|
327 |
# Perform OCR on the uploaded image
|
328 |
ocr_results = text_prediction(uploaded_image)
|
329 |
|
330 |
# Filter and map OCR results to prediction results
|
331 |
st.session_state.text_pred = filter_text(ocr_results, threshold=0.6)
|
332 |
+
st.session_state.text_mapping = mapping_text(st.session_state.prediction, st.session_state.text_pred, print_sentences=False, percentage_thresh=percentage_text_dist_thresh)
|
333 |
|
334 |
# Remove the original image display
|
335 |
image_placeholder.empty()
|
|
|
481 |
if st.button("Launch Prediction"):
|
482 |
st.session_state.crop_image = cropped_image
|
483 |
with st.spinner('Processing...'):
|
484 |
+
perform_inference(model_object, model_arrow, st.session_state.crop_image, score_threshold, is_mobile, screen_width, iou_threshold=0.3, distance_treshold=30, percentage_text_dist_thresh=0.5)
|
485 |
st.balloons()
|
486 |
|
487 |
if 'prediction' in st.session_state and uploaded_file is not None:
|