MikkoLipsanen committed on
Commit
f8a998a
·
verified ·
1 Parent(s): f731714

Create segment_image.py

Browse files
Files changed (1) hide show
  1. segment_image.py +340 -0
segment_image.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shapely.validation import make_valid
2
+ from shapely.geometry import Polygon
3
+ from ultralytics import YOLO
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ from reading_order import OrderPolygons
8
+
9
class SegmentImage:
    """Segment a document image into text regions and text lines.

    A YOLO line-detection model is always loaded; a YOLO region-detection
    model is loaded only when ``region_model_path`` is given. Detected
    polygons are merged when they overlap strongly, each text line is attached
    to one region, and lines (and optionally regions) are put into reading
    order with ``OrderPolygons``.
    """

    def __init__(self,
                 line_model_path,
                 device,
                 line_iou=0.5,
                 region_iou=0.5,
                 line_overlap=0.5,
                 line_nms_iou=0.7,
                 region_nms_iou=0.3,
                 line_conf_threshold=0.25,
                 region_conf_threshold=0.25,
                 region_model_path=None,
                 order_regions=True,
                 region_half_precision=False,
                 line_half_precision=False):
        """Store the configuration and load the segmentation model(s).

        Args:
            line_model_path: Path to the text line detection model.
            device: Device used for prediction ('cpu', gpu '0', gpu '1' etc.).
            line_iou: IoU threshold above which text line polygons are merged.
            region_iou: IoU threshold above which region polygons are merged.
            line_overlap: Share of the smaller line polygon that must be
                covered by the intersection for two lines to be merged.
            line_nms_iou: IoU threshold passed to the line model's
                non-maximum suppression (NMS).
            region_nms_iou: IoU threshold passed to the region model's NMS.
            line_conf_threshold: Confidence threshold for line detection.
            region_conf_threshold: Confidence threshold for region detection.
            region_model_path: Optional path to the region detection model.
            order_regions: Whether a reading order is also estimated for the
                detected regions.
            region_half_precision: Whether the region model uses FP16.
            line_half_precision: Whether the line model uses FP16.
        """
        # Model locations
        self.line_model_path = line_model_path
        self.region_model_path = region_model_path
        # IoU thresholds used by NMS to decide which prediction boxes are
        # suppressed based on their overlap with other boxes
        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        # IoU thresholds used when merging detected polygons
        self.line_iou = line_iou
        self.region_iou = region_iou
        # Extent of line polygon overlap used for merging the polygons
        self.line_overlap = line_overlap
        # Detection confidence thresholds
        self.line_conf_threshold = line_conf_threshold
        self.region_conf_threshold = region_conf_threshold
        self.device = device
        self.order_regions = order_regions
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision
        self.order_poly = OrderPolygons()
        # Initialize segmentation model(s). The region model attribute is
        # always defined so that later code can test it instead of failing
        # with AttributeError when no region model was requested.
        self.line_model = self.init_line_model()
        self.region_model = self.init_region_model() if self.region_model_path else None

    def init_line_model(self):
        """Load the line detection model.

        Returns:
            The loaded YOLO model, or None if loading failed (the error is
            printed, not raised).
        """
        try:
            return YOLO(self.line_model_path)
        except Exception as e:
            print('Failed to load the line detection model: %s' % e)
            return None

    def init_region_model(self):
        """Load the region detection model.

        Returns:
            The loaded YOLO model, or None if loading failed (the error is
            printed, not raised).
        """
        try:
            return YOLO(self.region_model_path)
        except Exception as e:
            print('Failed to load the region detection model: %s' % e)
            return None

    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Build one dict per detected region with an index-based id.

        Args:
            coords: Region polygons.
            max_min: Per-region [max x, min x, max y, min y] values.
            classes: Predicted class id of each region.
            names: Mapping from class id to class name.
            box_confs: Detection confidence of each region.
            img_shape: Shape of the original image.

        Returns:
            List of region dicts with the keys 'coords', 'max_min', 'class',
            'name', 'conf', 'id' and 'img_shape'.
        """
        n = min(len(classes), len(coords))
        regions = []
        for idx in range(n):
            regions.append({'coords': coords[idx],
                            'max_min': max_min[idx],
                            'class': str(classes[idx]),
                            # Region name corresponding to the class id
                            'name': names[classes[idx]],
                            'conf': box_confs[idx],
                            # Simple index-based id
                            'id': str(idx),
                            'img_shape': img_shape})
        return regions

    def get_max_min(self, polygons):
        """Collect the extreme x and y values of each polygon.

        Returns:
            Array of shape (len(polygons), 4) whose rows are
            [max x, min x, max y, min y]. Rows of empty polygons stay zero.
        """
        xy_array = np.zeros([len(polygons), 4])
        for row, poly in enumerate(polygons):
            xs = [point[0] for point in poly]
            ys = [point[1] for point in poly]
            if xs:
                xy_array[row, 0] = max(xs)
                xy_array[row, 1] = min(xs)
            if ys:
                xy_array[row, 2] = max(ys)
                xy_array[row, 3] = min(ys)
        return xy_array

    def validate_polygon(self, polygon):
        """Turn a point sequence into a valid shapely geometry.

        Returns:
            A shapely geometry (repaired with ``make_valid`` when needed), or
            None when the input has fewer than three points.
        """
        if len(polygon) <= 2:
            return None
        geom = Polygon(polygon)
        if not geom.is_valid:
            # make_valid may return a non-Polygon geometry
            # (e.g. MultiPolygon or GeometryCollection).
            geom = make_valid(geom)
        return geom

    def get_iou(self, poly1, poly2):
        """Calculate the Intersection over Union (IoU) of two polygons.

        Returns:
            The IoU value, or 0 when either polygon is unusable, the polygons
            do not intersect, or their union has no area.
        """
        poly1 = self.validate_polygon(poly1)
        poly2 = self.validate_polygon(poly2)
        if not (poly1 and poly2) or not poly1.intersects(poly2):
            return 0
        union_area = poly1.union(poly2).area
        # Degenerate geometries can intersect while covering no area.
        if union_area <= 0:
            return 0
        return poly1.intersection(poly2).area / union_area

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold=None, return_sources=False):
        """Merge polygons that overlap strongly.

        Polygon ``i`` is merged with every later polygon ``j`` whose IoU with
        it exceeds ``iou_threshold`` or, when ``overlap_threshold`` is truthy,
        whose intersection covers more than that share of the smaller polygon.
        A polygon that has been merged is not used again, so every input
        polygon contributes to at most one merged polygon.

        Args:
            polygons: Sequence of point sequences.
            iou_threshold: IoU value above which two polygons are merged.
            overlap_threshold: Optional overlap share used as an additional
                merge criterion; None or 0 disables it.
            return_sources: When True, also return the input indices that
                each output polygon was built from.

        Returns:
            The unmerged input polygons (in input order) followed by the
            merged polygons as lists of exterior coordinates. With
            ``return_sources=True`` a tuple ``(polygons, sources)`` is
            returned, where ``sources[k]`` lists the input indices behind
            output polygon ``k``.
        """
        # Validate every polygon once instead of once per compared pair.
        geoms = [self.validate_polygon(poly) for poly in polygons]
        dropped = set()
        merged_polygons = []
        merged_sources = []

        for i, poly1 in enumerate(geoms):
            if i in dropped or not poly1:
                continue
            merged = None
            members = [i]
            for j in range(i + 1, len(geoms)):
                poly2 = geoms[j]
                # Skipping already merged polygons prevents the same polygon
                # from ending up in several merged outputs.
                if j in dropped or not poly2:
                    continue
                if not poly1.intersects(poly2):
                    continue
                intersect = poly1.intersection(poly2)
                uni = poly1.union(poly2)
                # Guard against zero-area unions of degenerate geometries.
                iou = intersect.area / uni.area if uni.area > 0 else 0.0
                overlap = False
                if overlap_threshold:
                    overlap = intersect.area > (overlap_threshold * min(poly1.area, poly2.area))
                if (iou > iou_threshold) or overlap:
                    # All polygons that match polygon i are merged together
                    merged = uni if merged is None else uni.union(merged)
                    members.append(j)
            if merged is None:
                continue
            # Polygons that are merged together are dropped from the list
            dropped.update(members)
            if merged.geom_type in ('GeometryCollection', 'MultiPolygon'):
                parts = merged.geoms
            else:
                parts = [merged]
            for geom in parts:
                if geom.geom_type == 'Polygon':
                    merged_polygons.append(list(geom.exterior.coords))
                    merged_sources.append(list(members))

        kept = [idx for idx in range(len(polygons)) if idx not in dropped]
        res = [polygons[idx] for idx in kept] + merged_polygons
        if return_sources:
            return res, [[idx] for idx in kept] + merged_sources
        return res

    @staticmethod
    def _best_source(sources, confs):
        """Return the index in ``sources`` with the highest confidence, or None
        when no source index has a confidence value."""
        valid = [idx for idx in sources if idx < len(confs)]
        if not valid:
            return None
        return max(valid, key=lambda idx: confs[idx])

    def get_region_preds(self, img):
        """Predict text region polygons.

        Returns:
            A list of region dicts (see ``get_region_ids``), or None when no
            region model is available or nothing was detected. The class and
            confidence of a merged region are taken from its most confident
            source detection.
        """
        region_model = getattr(self, 'region_model', None)
        if region_model is None:
            return None
        results = region_model.predict(source=img,
                                       device=self.device,
                                       conf=self.region_conf_threshold,
                                       half=bool(self.region_half_precision),
                                       iou=self.region_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        # Merge overlapping region polygons and remember which detections
        # each resulting polygon came from
        merged, sources = self.merge_polygons(results.masks.xy, self.region_iou,
                                              return_sources=True)
        # Predicted class labels and confidences of the raw detections
        raw_classes = results.boxes.cls.tolist()
        raw_confs = results.boxes.conf.tolist()
        n_det = min(len(raw_classes), len(raw_confs))
        coords, classes, box_confs = [], [], []
        for polygon, src in zip(merged, sources):
            best = self._best_source(src, raw_confs[:n_det])
            if best is None:
                # No class/confidence is known for this polygon
                continue
            coords.append(polygon)
            classes.append(raw_classes[best])
            box_confs.append(raw_confs[best])
        # Maximum and minimum x and y values used for ordering the polygons
        max_min = self.get_max_min(coords).tolist()
        # results.names maps class ids to class names,
        # results.orig_shape is the shape of the original image
        return self.get_region_ids(coords, max_min, classes, results.names,
                                   box_confs, results.orig_shape)

    def get_line_preds(self, img):
        """Predict text line polygons.

        Returns:
            A dict with the keys 'coords', 'max_min' and 'confs' (all of the
            same length), or None when nothing was detected. The confidence
            of a merged line is that of its most confident source detection.
        """
        results = self.line_model.predict(source=img,
                                          device=self.device,
                                          conf=self.line_conf_threshold,
                                          half=bool(self.line_half_precision),
                                          iou=self.line_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        # Merge overlapping line polygons and remember which detections
        # each resulting polygon came from
        coords, sources = self.merge_polygons(results.masks.xy, self.line_iou,
                                              self.line_overlap, return_sources=True)
        raw_confs = results.boxes.conf.tolist()
        confs = []
        for src in sources:
            best = self._best_source(src, raw_confs)
            confs.append(raw_confs[best] if best is not None else 0.0)
        # Maximum and minimum x and y axis values for detected polygons
        max_min = self.get_max_min(coords).tolist()
        return {'coords': list(coords), 'max_min': max_min, 'confs': confs}

    def get_dist(self, line_polygon, regions):
        """Find the region closest to a text line.

        Returns:
            The 'id' of the nearest region, or None when the line or every
            region polygon is unusable.
        """
        line_geom = self.validate_polygon(line_polygon)
        if not line_geom:
            return None
        closest, reg_id = float('inf'), None
        for region in regions:
            region_geom = self.validate_polygon(region['coords'])
            if not region_geom:
                continue
            # Distance between the line polygon and the region polygon
            line_reg_dist = line_geom.distance(region_geom)
            if line_reg_dist < closest:
                closest = line_reg_dist
                reg_id = region['id']
        return reg_id

    def get_line_regions(self, lines, regions):
        """Connect each text line to one region.

        A line is assigned to the region with the highest IoU; when it
        intersects no region, the closest region is used instead.

        Returns:
            List of dicts with the keys 'polygon', 'reg_id', 'max_min' and
            'conf'.
        """
        max_mins = lines['max_min']
        confs = lines['confs']
        lines_list = []
        for idx, polygon in enumerate(lines['coords']):
            best_iou, reg_id = 0, ''
            for region in regions:
                line_reg_iou = self.get_iou(polygon, region['coords'])
                if line_reg_iou > best_iou:
                    best_iou = line_reg_iou
                    reg_id = region['id']
            # If the line polygon does not intersect with any region, a
            # distance metric defines the region that the line belongs to
            if best_iou == 0:
                reg_id = self.get_dist(polygon, regions)
            # Fall back to neutral values when a list is shorter than coords
            max_min = max_mins[idx] if idx < len(max_mins) else [0.0, 0.0, 0.0, 0.0]
            conf = confs[idx] if idx < len(confs) else 0.0
            lines_list.append({'polygon': polygon, 'reg_id': reg_id,
                               'max_min': max_min, 'conf': conf})
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Order the line predictions inside each region.

        Regions without any line are left out of the result.

        Returns:
            List of dicts with the keys 'region_coords', 'region_name',
            'lines', 'line_confs', 'region_conf' and 'img_shape'; the regions
            themselves are reordered when ``self.order_regions`` is set.
        """
        # Group the lines by region id once instead of rescanning all lines
        # for every region
        lines_by_region = {}
        for line in lines:
            lines_by_region.setdefault(line['reg_id'], []).append(line)

        regions_with_rows = []
        region_max_mins = []
        for region in regions:
            members = lines_by_region.get(region['id'])
            if not members:
                continue
            # Line order inside the region is defined and the ordered lines
            # are joined with the region data in the same python dict
            line_order = self.order_poly.order([line['max_min'] for line in members])
            regions_with_rows.append({
                'region_coords': region['coords'],
                'region_name': region['name'],
                'lines': [members[k]['polygon'] for k in line_order],
                'line_confs': [members[k]['conf'] for k in line_order],
                'region_conf': region['conf'],
                'img_shape': region['img_shape']})
            region_max_mins.append(region['max_min'])

        # Creates an ordering of the detected regions based on their polygon
        # coordinates; skipped when there is nothing to order
        if self.order_regions and regions_with_rows:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[k] for k in region_order]

        return regions_with_rows

    def get_default_region(self, image):
        """Create a single region covering the whole image.

        Args:
            image: Object whose ``size`` attribute is a (width, height) pair.

        Returns:
            A one-element list with a 'paragraph' region dict.
        """
        w, h = image.size
        region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
                  'max_min': [w, 0.0, h, 0.0],
                  'class': '0',
                  'name': "paragraph",
                  'conf': 0.0,
                  'id': '0',
                  'img_shape': (h, w)}
        return [region]

    def get_segmentation(self, image):
        """Segment the image into ordered text regions and text lines.

        When no region model is available, or it detects nothing, all lines
        are assigned to a default region covering the whole image.

        Returns:
            The ordered regions with their lines (see
            ``order_regions_lines``), or None when no text lines are detected.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None

        region_preds = None
        # Regions are predicted only when a region detection model is loaded
        if getattr(self, 'region_model', None) is not None:
            region_preds = self.get_region_preds(image)
            if not region_preds:
                print(f'No regions detected from image {image}')
        if not region_preds:
            region_preds = self.get_default_region(image)

        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        return self.order_regions_lines(lines_with_regions, region_preds)
339
+
340
+