|
""" |
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
from __future__ import unicode_literals |
|
|
|
import sys |
|
import six |
|
import cv2 |
|
import numpy as np |
|
|
|
|
|
class GenTableMask(object):
    """Paint a mask over the text found inside each table cell.

    For every entry of ``data['cells']`` that has a ``'bbox'`` the cell is
    cropped from ``data['image']``, split into text lines by projection
    profiles, and each (slightly shrunk) line box is filled in a mask.

    Args:
        shrink_h_max: upper bound, in pixels, of the vertical shrink applied
            to each side of a text-line box.
        shrink_w_max: upper bound, in pixels, of the horizontal shrink
            applied to each side of a text-line box.
        mask_type: ``1`` produces a single-channel float32 mask (filled with
            1.0) stored in ``data['mask_img']``; any other value produces a
            3-channel float32 mask (filled with 255) that replaces
            ``data['image']``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        # Use the configured limits instead of silently forcing both to 5.
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Find runs of non-empty rows in a binarised image.

        Args:
            erosion: 2-D array whose foreground pixels are exactly 255.
            h: number of leading rows to analyse.
            w: number of leading columns counted in each row.
            spilt_threshold: a row belongs to a run when it holds more than
                this many foreground pixels.

        Returns:
            ``(box_list, projection_map)``. ``box_list`` holds one
            ``(start, stop)`` pair per run: ``stop`` is the index of the
            first row after the run plus one, and runs spanning at most two
            rows are dropped. A run still open at the last row is always kept
            and reported as ``(start, h - 1)``. ``projection_map`` has the
            shape and dtype of ``erosion``; it is filled with ones, except
            that the first ``k`` entries of each analysed row are zero, where
            ``k`` is that row's foreground count.
        """
        # Foreground pixel count of every analysed row.
        counts = (erosion[:h, :w] == 255).sum(axis=1)

        box_list = []
        start_idx = 0
        in_text = False
        for idx, count in enumerate(counts):
            if not in_text and count > spilt_threshold:
                in_text = True
                start_idx = idx
            elif in_text and count <= spilt_threshold:
                in_text = False
                # Runs of at most two rows are treated as noise.
                if idx - start_idx <= 2:
                    continue
                box_list.append((start_idx, idx + 1))

        if in_text:
            box_list.append((start_idx, h - 1))

        # Row j gets its first counts[j] cells cleared (a bar-chart view of
        # the profile).
        projection_map = np.ones_like(erosion)
        columns = np.arange(projection_map.shape[1])
        bars = columns[np.newaxis, :] < counts[:, np.newaxis]
        projection_map[:h][bars] = 0
        return box_list, projection_map

    def _split_text_lines(self, erosion, h, w):
        """Turn a binarised ``h`` x ``w`` cell crop into text-line boxes.

        Returns a list of ``[left, top, right, bottom]`` lists relative to
        the crop. When fewer than two text lines are found the whole crop
        ``[0, 0, w, h]`` is returned as a single box.
        """
        box_list, _ = self.projection(erosion, h, w)
        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        last = len(box_list) - 1
        split_bbox_list = []
        for i, (h_start, h_end) in enumerate(box_list):
            # The outermost lines are stretched to the crop's top / bottom.
            if i == 0:
                h_start = 0
            if i == last:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # Projecting the transposed strip yields its column extent.
            w_split_list, _ = self.projection(word_img.T, word_w, word_h)
            if w_split_list:
                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            else:
                # No column run survived the noise filter: keep full width.
                w_start, w_end = 0, w
            # One pixel of vertical margin, without leaving the crop.
            if h_start > 0:
                h_start -= 1
            if h_end < h:
                h_end += 1
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def projection_cx(self, box_img):
        """Split one cell crop into text-line boxes.

        Args:
            box_img: 3-channel crop of a table cell (converted with
                ``cv2.COLOR_BGR2GRAY``).

        Returns:
            List of ``[left, top, right, bottom]`` lists relative to the
            crop; a single ``[0, 0, w, h]`` entry when the crop holds fewer
            than two detected text lines.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape

        # Inverse threshold: gray values <= 200 become foreground (255).
        _, thresh1 = cv2.threshold(box_gray_img, 200, 255,
                                   cv2.THRESH_BINARY_INV)

        if h < w:
            # NOTE(review): presumably removes 1-px-high horizontal rules in
            # wide cells before projecting; confirm the intent.
            erode = cv2.erode(thresh1, np.ones((2, 1), np.uint8), iterations=1)
        else:
            erode = thresh1

        # Horizontal dilation joins neighbouring strokes on the same line.
        erosion = cv2.dilate(erode, np.ones((1, 5), np.uint8), iterations=1)
        return self._split_text_lines(erosion, h, w)

    def shrink_bbox(self, bbox):
        """Shrink a ``(left, top, right, bottom)`` box on every side.

        The shrink is 10% of the box extent on each axis, at least 1 pixel
        and at most ``shrink_w_max`` / ``shrink_h_max``. An axis that would
        collapse (start >= end after shrinking) is left unchanged.

        Returns:
            ``[left, top, right, bottom]`` of the shrunk box.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new, right_new = left + sh_w, right - sh_w
        top_new, bottom_new = top + sh_h, bottom - sh_h
        if left_new >= right_new:
            left_new, right_new = left, right
        if top_new >= bottom_new:
            top_new, bottom_new = top, bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Build the mask for one sample and store it in ``data``.

        Reads ``data['image']`` (indexed as H x W x C) and ``data['cells']``
        (a sequence of dicts; only entries with a ``'bbox'`` of
        ``left, top, right, bottom`` are used). The mask is always written,
        even when no cell contributes to it: to ``data['mask_img']`` when
        ``mask_type == 1``, otherwise over ``data['image']``.

        Returns:
            The same ``data`` dict.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        if self.mask_type == 1:
            mask_img = np.zeros((height, width), dtype=np.float32)
            fill_value = 1.0
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)
            fill_value = 255

        for cell in cells:
            if "bbox" not in cell:
                continue
            cell_left, cell_top, cell_right, cell_bottom = cell['bbox']
            box_img = img[cell_top:cell_bottom, cell_left:cell_right, :].copy()
            if box_img.size == 0:
                # Degenerate box: nothing to analyse, and OpenCV would raise.
                continue
            for rel_box in self.projection_cx(box_img):
                # Move the crop-relative box into image coordinates.
                left, top, right, bottom = self.shrink_bbox([
                    rel_box[0] + cell_left, rel_box[1] + cell_top,
                    rel_box[2] + cell_left, rel_box[3] + cell_top
                ])
                mask_img[top:bottom, left:right] = fill_value

        # Publish the mask once, regardless of how many boxes were painted.
        if self.mask_type == 1:
            data['mask_img'] = mask_img
        else:
            data['image'] = mask_img
        return data
|
|
|
|
|
class ResizeTableImage(object):
    """Rescale an image so that its longer side becomes ``max_len`` pixels.

    Args:
        max_len: target length of the longer image side.
        resize_bboxes: when true (and not in inference mode)
            ``data['bboxes']`` is multiplied by the same scale factor.
        infer_mode: when true ``data['bboxes']`` is never touched.
        **kwargs: accepted and ignored.
    """

    def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
                 **kwargs):
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len
        self.resize_bboxes = resize_bboxes
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Resize ``data['image']`` and record the geometry in ``data``.

        Besides the resized ``'image'`` the dict receives ``'src_img'`` (the
        input image), ``'shape'`` (``[height, width, ratio, ratio]`` of the
        input) and ``'max_len'``.

        Returns:
            The same ``data`` dict.
        """
        src = data['image']
        src_h, src_w = src.shape[0:2]
        # One factor for both axes keeps the aspect ratio.
        scale = self.max_len / float(max(src_h, src_w))
        # cv2.resize takes the target size as (width, height).
        target_size = (int(src_w * scale), int(src_h * scale))
        resized = cv2.resize(src, target_size)

        scale_boxes = self.resize_bboxes and not self.infer_mode
        if scale_boxes:
            data['bboxes'] = data['bboxes'] * scale

        data['image'] = resized
        data['src_img'] = src
        data['shape'] = np.array([src_h, src_w, scale, scale])
        data['max_len'] = self.max_len
        return data
|
|
|
|
|
class PaddingTableImage(object):
    """Place the image in the top-left corner of a zero-filled canvas.

    Args:
        size: ``(pad_h, pad_w)`` of the output canvas.
        **kwargs: accepted and ignored.
    """

    def __init__(self, size, **kwargs):
        super(PaddingTableImage, self).__init__()
        self.size = size

    def __call__(self, data):
        """Pad ``data['image']`` and append the canvas size to ``data['shape']``.

        The output image is a float32 array of shape ``(pad_h, pad_w, 3)``.
        ``data['shape']`` must already be an array; ``pad_h`` and ``pad_w``
        are appended to its values.

        Returns:
            The same ``data`` dict.
        """
        src = data['image']
        canvas_h, canvas_w = self.size
        src_h, src_w = src.shape[0:2]

        canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
        # Slice assignment copies the pixel values into the canvas.
        canvas[:src_h, :src_w, :] = src
        data['image'] = canvas

        padded_shape = data['shape'].tolist()
        padded_shape += [canvas_h, canvas_w]
        data['shape'] = np.array(padded_shape)
        return data
|
|