Danieldu
commited on
Commit
·
a89d9fd
1
Parent(s):
2e90087
add code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- app.py +56 -0
- ppocr/__init__.py +13 -0
- ppocr/data/__init__.py +110 -0
- ppocr/data/collate_fn.py +118 -0
- ppocr/data/imaug/ColorJitter.py +26 -0
- ppocr/data/imaug/__init__.py +80 -0
- ppocr/data/imaug/abinet_aug.py +458 -0
- ppocr/data/imaug/copy_paste.py +174 -0
- ppocr/data/imaug/ct_process.py +355 -0
- ppocr/data/imaug/drrg_targets.py +696 -0
- ppocr/data/imaug/east_process.py +436 -0
- ppocr/data/imaug/fce_aug.py +564 -0
- ppocr/data/imaug/fce_targets.py +666 -0
- ppocr/data/imaug/iaa_augment.py +105 -0
- ppocr/data/imaug/label_ops.py +1505 -0
- ppocr/data/imaug/make_border_map.py +173 -0
- ppocr/data/imaug/make_pse_gt.py +106 -0
- ppocr/data/imaug/make_shrink_map.py +123 -0
- ppocr/data/imaug/operators.py +524 -0
- ppocr/data/imaug/pg_process.py +1034 -0
- ppocr/data/imaug/randaugment.py +143 -0
- ppocr/data/imaug/random_crop_data.py +234 -0
- ppocr/data/imaug/rec_img_aug.py +825 -0
- ppocr/data/imaug/sast_process.py +777 -0
- ppocr/data/imaug/ssl_img_aug.py +60 -0
- ppocr/data/imaug/table_ops.py +229 -0
- ppocr/data/imaug/text_image_aug/__init__.py +17 -0
- ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-37.pyc +0 -0
- ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc +0 -0
- ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-37.pyc +0 -0
- ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc +0 -0
- ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-37.pyc +0 -0
- ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc +0 -0
- ppocr/data/imaug/text_image_aug/augment.py +120 -0
- ppocr/data/imaug/text_image_aug/warp_mls.py +168 -0
- ppocr/data/imaug/vqa/__init__.py +20 -0
- ppocr/data/imaug/vqa/augment.py +33 -0
- ppocr/data/imaug/vqa/token/__init__.py +18 -0
- ppocr/data/imaug/vqa/token/vqa_re_convert.py +51 -0
- ppocr/data/imaug/vqa/token/vqa_token_chunk.py +122 -0
- ppocr/data/imaug/vqa/token/vqa_token_pad.py +104 -0
- ppocr/data/imaug/vqa/token/vqa_token_relation.py +67 -0
- ppocr/data/lmdb_dataset.py +205 -0
- ppocr/data/pgnet_dataset.py +106 -0
- ppocr/data/pubtab_dataset.py +133 -0
- ppocr/data/simple_dataset.py +151 -0
- ppocr/ext_op/__init__.py +1 -0
- ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc +528 -0
- ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu +380 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
simfang.ttf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
ppocr/utils/dict/confuse.pkl filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, io
|
2 |
+
from paddleocr import PaddleOCR, draw_ocr
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
# 設定 Hugging Face Hub 的 Access Token
|
8 |
+
os.environ["HF_TOKEN"] = "TWOCR"
|
9 |
+
|
10 |
+
def inference(img_path):
|
11 |
+
|
12 |
+
ocr = PaddleOCR(
|
13 |
+
rec_char_dict_path='zhtw_common_dict.txt',
|
14 |
+
use_gpu=False,
|
15 |
+
rec_image_shape="3, 48, 320"
|
16 |
+
)
|
17 |
+
|
18 |
+
result = ocr.ocr(img_path)
|
19 |
+
|
20 |
+
for idx in range(len(result)):
|
21 |
+
res = result[idx]
|
22 |
+
for line in res:
|
23 |
+
print(line)
|
24 |
+
|
25 |
+
result = result[0]
|
26 |
+
image = Image.open(img_path).convert('RGB')
|
27 |
+
boxes = [line[0] for line in result]
|
28 |
+
txts = [line[1][0] if line[1] else '' for line in result] # 確保在無文字時 txts 還是個空字串
|
29 |
+
scores = [line[1][1] for line in result]
|
30 |
+
im_show_pil = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
|
31 |
+
|
32 |
+
return im_show_pil, "\n".join(txts)
|
33 |
+
|
34 |
+
title = "<p style='text-align: center'><a href='https://www.twman.org/AI/CV' target='_blank'>繁體中文醫療診斷書和收據OCR:PaddleOCR</a></p>"
|
35 |
+
|
36 |
+
description = """
|
37 |
+
<p style='text-align: center'><a href="https://blog.twman.org/2023/07/wsl.html" target='_blank'>用PaddleOCR的PPOCRLabel來微調醫療診斷書和收據</a></p><br>
|
38 |
+
<p style='text-align: center'><a href="https://github.com/Deep-Learning-101" target='_blank'>https://github.com/Deep-Learning-101</a></p><br>
|
39 |
+
<p style='text-align: center'><a href="https://github.com/Deep-Learning-101/Computer-Vision-Paper" target='_blank'>https://github.com/Deep-Learning-101/Computer-Vision-Paper</a></p><br>
|
40 |
+
"""
|
41 |
+
|
42 |
+
|
43 |
+
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
|
44 |
+
|
45 |
+
gr.Interface(
|
46 |
+
inference,
|
47 |
+
[gr.inputs.Image(type='filepath', label='圖片上傳')],
|
48 |
+
outputs=[
|
49 |
+
gr.outputs.Image(type="pil", label="識別結果"),
|
50 |
+
"text"
|
51 |
+
],
|
52 |
+
title=title,
|
53 |
+
description=description,
|
54 |
+
css=css,
|
55 |
+
enable_queue=True
|
56 |
+
).launch(debug=True)
|
ppocr/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
ppocr/data/__init__.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
from __future__ import unicode_literals
|
19 |
+
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import numpy as np
|
23 |
+
import skimage
|
24 |
+
import paddle
|
25 |
+
import signal
|
26 |
+
import random
|
27 |
+
|
28 |
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
29 |
+
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
30 |
+
|
31 |
+
import copy
|
32 |
+
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
|
33 |
+
import paddle.distributed as dist
|
34 |
+
|
35 |
+
from ppocr.data.imaug import transform, create_operators
|
36 |
+
from ppocr.data.simple_dataset import SimpleDataSet
|
37 |
+
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
|
38 |
+
from ppocr.data.pgnet_dataset import PGDataSet
|
39 |
+
from ppocr.data.pubtab_dataset import PubTabDataSet
|
40 |
+
|
41 |
+
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
42 |
+
|
43 |
+
|
44 |
+
def term_mp(sig_num, frame):
|
45 |
+
""" kill all child processes
|
46 |
+
"""
|
47 |
+
pid = os.getpid()
|
48 |
+
pgid = os.getpgid(os.getpid())
|
49 |
+
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
|
50 |
+
os.killpg(pgid, signal.SIGKILL)
|
51 |
+
|
52 |
+
|
53 |
+
def build_dataloader(config, mode, device, logger, seed=None):
|
54 |
+
config = copy.deepcopy(config)
|
55 |
+
|
56 |
+
support_dict = [
|
57 |
+
'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
|
58 |
+
'LMDBDataSetSR'
|
59 |
+
]
|
60 |
+
module_name = config[mode]['dataset']['name']
|
61 |
+
assert module_name in support_dict, Exception(
|
62 |
+
'DataSet only support {}'.format(support_dict))
|
63 |
+
assert mode in ['Train', 'Eval', 'Test'
|
64 |
+
], "Mode should be Train, Eval or Test."
|
65 |
+
|
66 |
+
dataset = eval(module_name)(config, mode, logger, seed)
|
67 |
+
loader_config = config[mode]['loader']
|
68 |
+
batch_size = loader_config['batch_size_per_card']
|
69 |
+
drop_last = loader_config['drop_last']
|
70 |
+
shuffle = loader_config['shuffle']
|
71 |
+
num_workers = loader_config['num_workers']
|
72 |
+
if 'use_shared_memory' in loader_config.keys():
|
73 |
+
use_shared_memory = loader_config['use_shared_memory']
|
74 |
+
else:
|
75 |
+
use_shared_memory = True
|
76 |
+
|
77 |
+
if mode == "Train":
|
78 |
+
# Distribute data to multiple cards
|
79 |
+
batch_sampler = DistributedBatchSampler(
|
80 |
+
dataset=dataset,
|
81 |
+
batch_size=batch_size,
|
82 |
+
shuffle=shuffle,
|
83 |
+
drop_last=drop_last)
|
84 |
+
else:
|
85 |
+
# Distribute data to single card
|
86 |
+
batch_sampler = BatchSampler(
|
87 |
+
dataset=dataset,
|
88 |
+
batch_size=batch_size,
|
89 |
+
shuffle=shuffle,
|
90 |
+
drop_last=drop_last)
|
91 |
+
|
92 |
+
if 'collate_fn' in loader_config:
|
93 |
+
from . import collate_fn
|
94 |
+
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
|
95 |
+
else:
|
96 |
+
collate_fn = None
|
97 |
+
data_loader = DataLoader(
|
98 |
+
dataset=dataset,
|
99 |
+
batch_sampler=batch_sampler,
|
100 |
+
places=device,
|
101 |
+
num_workers=num_workers,
|
102 |
+
return_list=True,
|
103 |
+
use_shared_memory=use_shared_memory,
|
104 |
+
collate_fn=collate_fn)
|
105 |
+
|
106 |
+
# support exit using ctrl+c
|
107 |
+
signal.signal(signal.SIGINT, term_mp)
|
108 |
+
signal.signal(signal.SIGTERM, term_mp)
|
109 |
+
|
110 |
+
return data_loader
|
ppocr/data/collate_fn.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import paddle
|
16 |
+
import numbers
|
17 |
+
import numpy as np
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
|
21 |
+
class DictCollator(object):
|
22 |
+
"""
|
23 |
+
data batch
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __call__(self, batch):
|
27 |
+
# todo:support batch operators
|
28 |
+
data_dict = defaultdict(list)
|
29 |
+
to_tensor_keys = []
|
30 |
+
for sample in batch:
|
31 |
+
for k, v in sample.items():
|
32 |
+
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
33 |
+
if k not in to_tensor_keys:
|
34 |
+
to_tensor_keys.append(k)
|
35 |
+
data_dict[k].append(v)
|
36 |
+
for k in to_tensor_keys:
|
37 |
+
data_dict[k] = paddle.to_tensor(data_dict[k])
|
38 |
+
return data_dict
|
39 |
+
|
40 |
+
|
41 |
+
class ListCollator(object):
|
42 |
+
"""
|
43 |
+
data batch
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __call__(self, batch):
|
47 |
+
# todo:support batch operators
|
48 |
+
data_dict = defaultdict(list)
|
49 |
+
to_tensor_idxs = []
|
50 |
+
for sample in batch:
|
51 |
+
for idx, v in enumerate(sample):
|
52 |
+
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
|
53 |
+
if idx not in to_tensor_idxs:
|
54 |
+
to_tensor_idxs.append(idx)
|
55 |
+
data_dict[idx].append(v)
|
56 |
+
for idx in to_tensor_idxs:
|
57 |
+
data_dict[idx] = paddle.to_tensor(data_dict[idx])
|
58 |
+
return list(data_dict.values())
|
59 |
+
|
60 |
+
|
61 |
+
class SSLRotateCollate(object):
|
62 |
+
"""
|
63 |
+
bach: [
|
64 |
+
[(4*3xH*W), (4,)]
|
65 |
+
[(4*3xH*W), (4,)]
|
66 |
+
...
|
67 |
+
]
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __call__(self, batch):
|
71 |
+
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
|
72 |
+
return output
|
73 |
+
|
74 |
+
|
75 |
+
class DyMaskCollator(object):
|
76 |
+
"""
|
77 |
+
batch: [
|
78 |
+
image [batch_size, channel, maxHinbatch, maxWinbatch]
|
79 |
+
image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
|
80 |
+
label [batch_size, maxLabelLen]
|
81 |
+
label_mask [batch_size, maxLabelLen]
|
82 |
+
...
|
83 |
+
]
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __call__(self, batch):
|
87 |
+
max_width, max_height, max_length = 0, 0, 0
|
88 |
+
bs, channel = len(batch), batch[0][0].shape[0]
|
89 |
+
proper_items = []
|
90 |
+
for item in batch:
|
91 |
+
if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
|
92 |
+
2] * max_height > 1600 * 320:
|
93 |
+
continue
|
94 |
+
max_height = item[0].shape[1] if item[0].shape[
|
95 |
+
1] > max_height else max_height
|
96 |
+
max_width = item[0].shape[2] if item[0].shape[
|
97 |
+
2] > max_width else max_width
|
98 |
+
max_length = len(item[1]) if len(item[
|
99 |
+
1]) > max_length else max_length
|
100 |
+
proper_items.append(item)
|
101 |
+
|
102 |
+
images, image_masks = np.zeros(
|
103 |
+
(len(proper_items), channel, max_height, max_width),
|
104 |
+
dtype='float32'), np.zeros(
|
105 |
+
(len(proper_items), 1, max_height, max_width), dtype='float32')
|
106 |
+
labels, label_masks = np.zeros(
|
107 |
+
(len(proper_items), max_length), dtype='int64'), np.zeros(
|
108 |
+
(len(proper_items), max_length), dtype='int64')
|
109 |
+
|
110 |
+
for i in range(len(proper_items)):
|
111 |
+
_, h, w = proper_items[i][0].shape
|
112 |
+
images[i][:, :h, :w] = proper_items[i][0]
|
113 |
+
image_masks[i][:, :h, :w] = 1
|
114 |
+
l = len(proper_items[i][1])
|
115 |
+
labels[i][:l] = proper_items[i][1]
|
116 |
+
label_masks[i][:l] = 1
|
117 |
+
|
118 |
+
return images, image_masks, labels, label_masks
|
ppocr/data/imaug/ColorJitter.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
|
15 |
+
|
16 |
+
__all__ = ['ColorJitter']
|
17 |
+
|
18 |
+
class ColorJitter(object):
|
19 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs):
|
20 |
+
self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
|
21 |
+
|
22 |
+
def __call__(self, data):
|
23 |
+
image = data['image']
|
24 |
+
image = self.aug(image)
|
25 |
+
data['image'] = image
|
26 |
+
return data
|
ppocr/data/imaug/__init__.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from __future__ import absolute_import
|
15 |
+
from __future__ import division
|
16 |
+
from __future__ import print_function
|
17 |
+
from __future__ import unicode_literals
|
18 |
+
|
19 |
+
from .iaa_augment import IaaAugment
|
20 |
+
from .make_border_map import MakeBorderMap
|
21 |
+
from .make_shrink_map import MakeShrinkMap
|
22 |
+
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
23 |
+
from .make_pse_gt import MakePseGt
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
28 |
+
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
29 |
+
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
|
30 |
+
RFLRecResizeImg, SVTRRecAug
|
31 |
+
from .ssl_img_aug import SSLRotateResize
|
32 |
+
from .randaugment import RandAugment
|
33 |
+
from .copy_paste import CopyPaste
|
34 |
+
from .ColorJitter import ColorJitter
|
35 |
+
from .operators import *
|
36 |
+
from .label_ops import *
|
37 |
+
|
38 |
+
from .east_process import *
|
39 |
+
from .sast_process import *
|
40 |
+
from .pg_process import *
|
41 |
+
from .table_ops import *
|
42 |
+
|
43 |
+
from .vqa import *
|
44 |
+
|
45 |
+
from .fce_aug import *
|
46 |
+
from .fce_targets import FCENetTargets
|
47 |
+
from .ct_process import *
|
48 |
+
from .drrg_targets import DRRGTargets
|
49 |
+
|
50 |
+
|
51 |
+
def transform(data, ops=None):
|
52 |
+
""" transform """
|
53 |
+
if ops is None:
|
54 |
+
ops = []
|
55 |
+
for op in ops:
|
56 |
+
data = op(data)
|
57 |
+
if data is None:
|
58 |
+
return None
|
59 |
+
return data
|
60 |
+
|
61 |
+
|
62 |
+
def create_operators(op_param_list, global_config=None):
|
63 |
+
"""
|
64 |
+
create operators based on the config
|
65 |
+
|
66 |
+
Args:
|
67 |
+
params(list): a dict list, used to create some operators
|
68 |
+
"""
|
69 |
+
assert isinstance(op_param_list, list), ('operator config should be a list')
|
70 |
+
ops = []
|
71 |
+
for operator in op_param_list:
|
72 |
+
assert isinstance(operator,
|
73 |
+
dict) and len(operator) == 1, "yaml format error"
|
74 |
+
op_name = list(operator)[0]
|
75 |
+
param = {} if operator[op_name] is None else operator[op_name]
|
76 |
+
if global_config is not None:
|
77 |
+
param.update(global_config)
|
78 |
+
op = eval(op_name)(**param)
|
79 |
+
ops.append(op)
|
80 |
+
return ops
|
ppocr/data/imaug/abinet_aug.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/FangShancheng/ABINet/blob/main/transforms.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
import numbers
|
20 |
+
import random
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import numpy as np
|
24 |
+
from paddle.vision.transforms import Compose, ColorJitter
|
25 |
+
|
26 |
+
|
27 |
+
def sample_asym(magnitude, size=None):
|
28 |
+
return np.random.beta(1, 4, size) * magnitude
|
29 |
+
|
30 |
+
|
31 |
+
def sample_sym(magnitude, size=None):
|
32 |
+
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
|
33 |
+
|
34 |
+
|
35 |
+
def sample_uniform(low, high, size=None):
|
36 |
+
return np.random.uniform(low, high, size=size)
|
37 |
+
|
38 |
+
|
39 |
+
def get_interpolation(type='random'):
|
40 |
+
if type == 'random':
|
41 |
+
choice = [
|
42 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
|
43 |
+
]
|
44 |
+
interpolation = choice[random.randint(0, len(choice) - 1)]
|
45 |
+
elif type == 'nearest':
|
46 |
+
interpolation = cv2.INTER_NEAREST
|
47 |
+
elif type == 'linear':
|
48 |
+
interpolation = cv2.INTER_LINEAR
|
49 |
+
elif type == 'cubic':
|
50 |
+
interpolation = cv2.INTER_CUBIC
|
51 |
+
elif type == 'area':
|
52 |
+
interpolation = cv2.INTER_AREA
|
53 |
+
else:
|
54 |
+
raise TypeError(
|
55 |
+
'Interpolation types only nearest, linear, cubic, area are supported!'
|
56 |
+
)
|
57 |
+
return interpolation
|
58 |
+
|
59 |
+
|
60 |
+
class CVRandomRotation(object):
|
61 |
+
def __init__(self, degrees=15):
|
62 |
+
assert isinstance(degrees,
|
63 |
+
numbers.Number), "degree should be a single number."
|
64 |
+
assert degrees >= 0, "degree must be positive."
|
65 |
+
self.degrees = degrees
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def get_params(degrees):
|
69 |
+
return sample_sym(degrees)
|
70 |
+
|
71 |
+
def __call__(self, img):
|
72 |
+
angle = self.get_params(self.degrees)
|
73 |
+
src_h, src_w = img.shape[:2]
|
74 |
+
M = cv2.getRotationMatrix2D(
|
75 |
+
center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
|
76 |
+
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
|
77 |
+
dst_w = int(src_h * abs_sin + src_w * abs_cos)
|
78 |
+
dst_h = int(src_h * abs_cos + src_w * abs_sin)
|
79 |
+
M[0, 2] += (dst_w - src_w) / 2
|
80 |
+
M[1, 2] += (dst_h - src_h) / 2
|
81 |
+
|
82 |
+
flags = get_interpolation()
|
83 |
+
return cv2.warpAffine(
|
84 |
+
img,
|
85 |
+
M, (dst_w, dst_h),
|
86 |
+
flags=flags,
|
87 |
+
borderMode=cv2.BORDER_REPLICATE)
|
88 |
+
|
89 |
+
|
90 |
+
class CVRandomAffine(object):
|
91 |
+
def __init__(self, degrees, translate=None, scale=None, shear=None):
|
92 |
+
assert isinstance(degrees,
|
93 |
+
numbers.Number), "degree should be a single number."
|
94 |
+
assert degrees >= 0, "degree must be positive."
|
95 |
+
self.degrees = degrees
|
96 |
+
|
97 |
+
if translate is not None:
|
98 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
99 |
+
"translate should be a list or tuple and it must be of length 2."
|
100 |
+
for t in translate:
|
101 |
+
if not (0.0 <= t <= 1.0):
|
102 |
+
raise ValueError(
|
103 |
+
"translation values should be between 0 and 1")
|
104 |
+
self.translate = translate
|
105 |
+
|
106 |
+
if scale is not None:
|
107 |
+
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
108 |
+
"scale should be a list or tuple and it must be of length 2."
|
109 |
+
for s in scale:
|
110 |
+
if s <= 0:
|
111 |
+
raise ValueError("scale values should be positive")
|
112 |
+
self.scale = scale
|
113 |
+
|
114 |
+
if shear is not None:
|
115 |
+
if isinstance(shear, numbers.Number):
|
116 |
+
if shear < 0:
|
117 |
+
raise ValueError(
|
118 |
+
"If shear is a single number, it must be positive.")
|
119 |
+
self.shear = [shear]
|
120 |
+
else:
|
121 |
+
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
|
122 |
+
"shear should be a list or tuple and it must be of length 2."
|
123 |
+
self.shear = shear
|
124 |
+
else:
|
125 |
+
self.shear = shear
|
126 |
+
|
127 |
+
def _get_inverse_affine_matrix(self, center, angle, translate, scale,
|
128 |
+
shear):
|
129 |
+
# https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
|
130 |
+
from numpy import sin, cos, tan
|
131 |
+
|
132 |
+
if isinstance(shear, numbers.Number):
|
133 |
+
shear = [shear, 0]
|
134 |
+
|
135 |
+
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
|
136 |
+
raise ValueError(
|
137 |
+
"Shear should be a single value or a tuple/list containing " +
|
138 |
+
"two values. Got {}".format(shear))
|
139 |
+
|
140 |
+
rot = math.radians(angle)
|
141 |
+
sx, sy = [math.radians(s) for s in shear]
|
142 |
+
|
143 |
+
cx, cy = center
|
144 |
+
tx, ty = translate
|
145 |
+
|
146 |
+
# RSS without scaling
|
147 |
+
a = cos(rot - sy) / cos(sy)
|
148 |
+
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
|
149 |
+
c = sin(rot - sy) / cos(sy)
|
150 |
+
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
|
151 |
+
|
152 |
+
# Inverted rotation matrix with scale and shear
|
153 |
+
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
|
154 |
+
M = [d, -b, 0, -c, a, 0]
|
155 |
+
M = [x / scale for x in M]
|
156 |
+
|
157 |
+
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
158 |
+
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
|
159 |
+
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
|
160 |
+
|
161 |
+
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
162 |
+
M[2] += cx
|
163 |
+
M[5] += cy
|
164 |
+
return M
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def get_params(degrees, translate, scale_ranges, shears, height):
|
168 |
+
angle = sample_sym(degrees)
|
169 |
+
if translate is not None:
|
170 |
+
max_dx = translate[0] * height
|
171 |
+
max_dy = translate[1] * height
|
172 |
+
translations = (np.round(sample_sym(max_dx)),
|
173 |
+
np.round(sample_sym(max_dy)))
|
174 |
+
else:
|
175 |
+
translations = (0, 0)
|
176 |
+
|
177 |
+
if scale_ranges is not None:
|
178 |
+
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
|
179 |
+
else:
|
180 |
+
scale = 1.0
|
181 |
+
|
182 |
+
if shears is not None:
|
183 |
+
if len(shears) == 1:
|
184 |
+
shear = [sample_sym(shears[0]), 0.]
|
185 |
+
elif len(shears) == 2:
|
186 |
+
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
|
187 |
+
else:
|
188 |
+
shear = 0.0
|
189 |
+
|
190 |
+
return angle, translations, scale, shear
|
191 |
+
|
192 |
+
def __call__(self, img):
|
193 |
+
src_h, src_w = img.shape[:2]
|
194 |
+
angle, translate, scale, shear = self.get_params(
|
195 |
+
self.degrees, self.translate, self.scale, self.shear, src_h)
|
196 |
+
|
197 |
+
M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
|
198 |
+
(0, 0), scale, shear)
|
199 |
+
M = np.array(M).reshape(2, 3)
|
200 |
+
|
201 |
+
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
|
202 |
+
(0, src_h - 1)]
|
203 |
+
project = lambda x, y, a, b, c: int(a * x + b * y + c)
|
204 |
+
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
|
205 |
+
for x, y in startpoints]
|
206 |
+
|
207 |
+
rect = cv2.minAreaRect(np.array(endpoints))
|
208 |
+
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
209 |
+
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
210 |
+
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
211 |
+
|
212 |
+
dst_w = int(max_x - min_x)
|
213 |
+
dst_h = int(max_y - min_y)
|
214 |
+
M[0, 2] += (dst_w - src_w) / 2
|
215 |
+
M[1, 2] += (dst_h - src_h) / 2
|
216 |
+
|
217 |
+
# add translate
|
218 |
+
dst_w += int(abs(translate[0]))
|
219 |
+
dst_h += int(abs(translate[1]))
|
220 |
+
if translate[0] < 0: M[0, 2] += abs(translate[0])
|
221 |
+
if translate[1] < 0: M[1, 2] += abs(translate[1])
|
222 |
+
|
223 |
+
flags = get_interpolation()
|
224 |
+
return cv2.warpAffine(
|
225 |
+
img,
|
226 |
+
M, (dst_w, dst_h),
|
227 |
+
flags=flags,
|
228 |
+
borderMode=cv2.BORDER_REPLICATE)
|
229 |
+
|
230 |
+
|
231 |
+
class CVRandomPerspective(object):
|
232 |
+
def __init__(self, distortion=0.5):
|
233 |
+
self.distortion = distortion
|
234 |
+
|
235 |
+
def get_params(self, width, height, distortion):
|
236 |
+
offset_h = sample_asym(
|
237 |
+
distortion * height / 2, size=4).astype(dtype=np.int)
|
238 |
+
offset_w = sample_asym(
|
239 |
+
distortion * width / 2, size=4).astype(dtype=np.int)
|
240 |
+
topleft = (offset_w[0], offset_h[0])
|
241 |
+
topright = (width - 1 - offset_w[1], offset_h[1])
|
242 |
+
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
|
243 |
+
botleft = (offset_w[3], height - 1 - offset_h[3])
|
244 |
+
|
245 |
+
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
|
246 |
+
(0, height - 1)]
|
247 |
+
endpoints = [topleft, topright, botright, botleft]
|
248 |
+
return np.array(
|
249 |
+
startpoints, dtype=np.float32), np.array(
|
250 |
+
endpoints, dtype=np.float32)
|
251 |
+
|
252 |
+
def __call__(self, img):
|
253 |
+
height, width = img.shape[:2]
|
254 |
+
startpoints, endpoints = self.get_params(width, height, self.distortion)
|
255 |
+
M = cv2.getPerspectiveTransform(startpoints, endpoints)
|
256 |
+
|
257 |
+
# TODO: more robust way to crop image
|
258 |
+
rect = cv2.minAreaRect(endpoints)
|
259 |
+
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
260 |
+
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
261 |
+
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
262 |
+
min_x, min_y = max(min_x, 0), max(min_y, 0)
|
263 |
+
|
264 |
+
flags = get_interpolation()
|
265 |
+
img = cv2.warpPerspective(
|
266 |
+
img,
|
267 |
+
M, (max_x, max_y),
|
268 |
+
flags=flags,
|
269 |
+
borderMode=cv2.BORDER_REPLICATE)
|
270 |
+
img = img[min_y:, min_x:]
|
271 |
+
return img
|
272 |
+
|
273 |
+
|
274 |
+
class CVRescale(object):
|
275 |
+
def __init__(self, factor=4, base_size=(128, 512)):
|
276 |
+
""" Define image scales using gaussian pyramid and rescale image to target scale.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
factor: the decayed factor from base size, factor=4 keeps target scale by default.
|
280 |
+
base_size: base size the build the bottom layer of pyramid
|
281 |
+
"""
|
282 |
+
if isinstance(factor, numbers.Number):
|
283 |
+
self.factor = round(sample_uniform(0, factor))
|
284 |
+
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
|
285 |
+
self.factor = round(sample_uniform(factor[0], factor[1]))
|
286 |
+
else:
|
287 |
+
raise Exception('factor must be number or list with length 2')
|
288 |
+
# assert factor is valid
|
289 |
+
self.base_h, self.base_w = base_size[:2]
|
290 |
+
|
291 |
+
def __call__(self, img):
|
292 |
+
if self.factor == 0: return img
|
293 |
+
src_h, src_w = img.shape[:2]
|
294 |
+
cur_w, cur_h = self.base_w, self.base_h
|
295 |
+
scale_img = cv2.resize(
|
296 |
+
img, (cur_w, cur_h), interpolation=get_interpolation())
|
297 |
+
for _ in range(self.factor):
|
298 |
+
scale_img = cv2.pyrDown(scale_img)
|
299 |
+
scale_img = cv2.resize(
|
300 |
+
scale_img, (src_w, src_h), interpolation=get_interpolation())
|
301 |
+
return scale_img
|
302 |
+
|
303 |
+
|
304 |
+
class CVGaussianNoise(object):
|
305 |
+
def __init__(self, mean=0, var=20):
|
306 |
+
self.mean = mean
|
307 |
+
if isinstance(var, numbers.Number):
|
308 |
+
self.var = max(int(sample_asym(var)), 1)
|
309 |
+
elif isinstance(var, (tuple, list)) and len(var) == 2:
|
310 |
+
self.var = int(sample_uniform(var[0], var[1]))
|
311 |
+
else:
|
312 |
+
raise Exception('degree must be number or list with length 2')
|
313 |
+
|
314 |
+
def __call__(self, img):
|
315 |
+
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
|
316 |
+
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
317 |
+
return img
|
318 |
+
|
319 |
+
|
320 |
+
class CVMotionBlur(object):
|
321 |
+
def __init__(self, degrees=12, angle=90):
|
322 |
+
if isinstance(degrees, numbers.Number):
|
323 |
+
self.degree = max(int(sample_asym(degrees)), 1)
|
324 |
+
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
325 |
+
self.degree = int(sample_uniform(degrees[0], degrees[1]))
|
326 |
+
else:
|
327 |
+
raise Exception('degree must be number or list with length 2')
|
328 |
+
self.angle = sample_uniform(-angle, angle)
|
329 |
+
|
330 |
+
def __call__(self, img):
|
331 |
+
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
|
332 |
+
self.angle, 1)
|
333 |
+
motion_blur_kernel = np.zeros((self.degree, self.degree))
|
334 |
+
motion_blur_kernel[self.degree // 2, :] = 1
|
335 |
+
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
|
336 |
+
(self.degree, self.degree))
|
337 |
+
motion_blur_kernel = motion_blur_kernel / self.degree
|
338 |
+
img = cv2.filter2D(img, -1, motion_blur_kernel)
|
339 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
340 |
+
return img
|
341 |
+
|
342 |
+
|
343 |
+
class CVGeometry(object):
|
344 |
+
def __init__(self,
|
345 |
+
degrees=15,
|
346 |
+
translate=(0.3, 0.3),
|
347 |
+
scale=(0.5, 2.),
|
348 |
+
shear=(45, 15),
|
349 |
+
distortion=0.5,
|
350 |
+
p=0.5):
|
351 |
+
self.p = p
|
352 |
+
type_p = random.random()
|
353 |
+
if type_p < 0.33:
|
354 |
+
self.transforms = CVRandomRotation(degrees=degrees)
|
355 |
+
elif type_p < 0.66:
|
356 |
+
self.transforms = CVRandomAffine(
|
357 |
+
degrees=degrees, translate=translate, scale=scale, shear=shear)
|
358 |
+
else:
|
359 |
+
self.transforms = CVRandomPerspective(distortion=distortion)
|
360 |
+
|
361 |
+
def __call__(self, img):
|
362 |
+
if random.random() < self.p:
|
363 |
+
return self.transforms(img)
|
364 |
+
else:
|
365 |
+
return img
|
366 |
+
|
367 |
+
|
368 |
+
class CVDeterioration(object):
|
369 |
+
def __init__(self, var, degrees, factor, p=0.5):
|
370 |
+
self.p = p
|
371 |
+
transforms = []
|
372 |
+
if var is not None:
|
373 |
+
transforms.append(CVGaussianNoise(var=var))
|
374 |
+
if degrees is not None:
|
375 |
+
transforms.append(CVMotionBlur(degrees=degrees))
|
376 |
+
if factor is not None:
|
377 |
+
transforms.append(CVRescale(factor=factor))
|
378 |
+
|
379 |
+
random.shuffle(transforms)
|
380 |
+
transforms = Compose(transforms)
|
381 |
+
self.transforms = transforms
|
382 |
+
|
383 |
+
def __call__(self, img):
|
384 |
+
if random.random() < self.p:
|
385 |
+
|
386 |
+
return self.transforms(img)
|
387 |
+
else:
|
388 |
+
return img
|
389 |
+
|
390 |
+
|
391 |
+
class CVColorJitter(object):
|
392 |
+
def __init__(self,
|
393 |
+
brightness=0.5,
|
394 |
+
contrast=0.5,
|
395 |
+
saturation=0.5,
|
396 |
+
hue=0.1,
|
397 |
+
p=0.5):
|
398 |
+
self.p = p
|
399 |
+
self.transforms = ColorJitter(
|
400 |
+
brightness=brightness,
|
401 |
+
contrast=contrast,
|
402 |
+
saturation=saturation,
|
403 |
+
hue=hue)
|
404 |
+
|
405 |
+
def __call__(self, img):
|
406 |
+
if random.random() < self.p: return self.transforms(img)
|
407 |
+
else: return img
|
408 |
+
|
409 |
+
|
410 |
+
class SVTRDeterioration(object):
|
411 |
+
def __init__(self, var, degrees, factor, p=0.5):
|
412 |
+
self.p = p
|
413 |
+
transforms = []
|
414 |
+
if var is not None:
|
415 |
+
transforms.append(CVGaussianNoise(var=var))
|
416 |
+
if degrees is not None:
|
417 |
+
transforms.append(CVMotionBlur(degrees=degrees))
|
418 |
+
if factor is not None:
|
419 |
+
transforms.append(CVRescale(factor=factor))
|
420 |
+
self.transforms = transforms
|
421 |
+
|
422 |
+
def __call__(self, img):
|
423 |
+
if random.random() < self.p:
|
424 |
+
random.shuffle(self.transforms)
|
425 |
+
transforms = Compose(self.transforms)
|
426 |
+
return transforms(img)
|
427 |
+
else:
|
428 |
+
return img
|
429 |
+
|
430 |
+
|
431 |
+
class SVTRGeometry(object):
|
432 |
+
def __init__(self,
|
433 |
+
aug_type=0,
|
434 |
+
degrees=15,
|
435 |
+
translate=(0.3, 0.3),
|
436 |
+
scale=(0.5, 2.),
|
437 |
+
shear=(45, 15),
|
438 |
+
distortion=0.5,
|
439 |
+
p=0.5):
|
440 |
+
self.aug_type = aug_type
|
441 |
+
self.p = p
|
442 |
+
self.transforms = []
|
443 |
+
self.transforms.append(CVRandomRotation(degrees=degrees))
|
444 |
+
self.transforms.append(CVRandomAffine(
|
445 |
+
degrees=degrees, translate=translate, scale=scale, shear=shear))
|
446 |
+
self.transforms.append(CVRandomPerspective(distortion=distortion))
|
447 |
+
|
448 |
+
def __call__(self, img):
|
449 |
+
if random.random() < self.p:
|
450 |
+
if self.aug_type:
|
451 |
+
random.shuffle(self.transforms)
|
452 |
+
transforms = Compose(self.transforms[:random.randint(1, 3)])
|
453 |
+
img = transforms(img)
|
454 |
+
else:
|
455 |
+
img = self.transforms[random.randint(0, 2)](img)
|
456 |
+
return img
|
457 |
+
else:
|
458 |
+
return img
|
ppocr/data/imaug/copy_paste.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import copy
|
15 |
+
import cv2
|
16 |
+
import random
|
17 |
+
import numpy as np
|
18 |
+
from PIL import Image
|
19 |
+
from shapely.geometry import Polygon
|
20 |
+
|
21 |
+
from ppocr.data.imaug.iaa_augment import IaaAugment
|
22 |
+
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
|
23 |
+
from tools.infer.utility import get_rotate_crop_image
|
24 |
+
|
25 |
+
|
26 |
+
class CopyPaste(object):
|
27 |
+
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
|
28 |
+
self.ext_data_num = 1
|
29 |
+
self.objects_paste_ratio = objects_paste_ratio
|
30 |
+
self.limit_paste = limit_paste
|
31 |
+
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
|
32 |
+
self.aug = IaaAugment(augmenter_args)
|
33 |
+
|
34 |
+
def __call__(self, data):
|
35 |
+
point_num = data['polys'].shape[1]
|
36 |
+
src_img = data['image']
|
37 |
+
src_polys = data['polys'].tolist()
|
38 |
+
src_texts = data['texts']
|
39 |
+
src_ignores = data['ignore_tags'].tolist()
|
40 |
+
ext_data = data['ext_data'][0]
|
41 |
+
ext_image = ext_data['image']
|
42 |
+
ext_polys = ext_data['polys']
|
43 |
+
ext_texts = ext_data['texts']
|
44 |
+
ext_ignores = ext_data['ignore_tags']
|
45 |
+
|
46 |
+
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
47 |
+
select_num = max(
|
48 |
+
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
|
49 |
+
|
50 |
+
random.shuffle(indexs)
|
51 |
+
select_idxs = indexs[:select_num]
|
52 |
+
select_polys = ext_polys[select_idxs]
|
53 |
+
select_ignores = ext_ignores[select_idxs]
|
54 |
+
|
55 |
+
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
56 |
+
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
57 |
+
src_img = Image.fromarray(src_img).convert('RGBA')
|
58 |
+
for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
|
59 |
+
box_img = get_rotate_crop_image(ext_image, poly)
|
60 |
+
|
61 |
+
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
62 |
+
if box is not None:
|
63 |
+
box = box.tolist()
|
64 |
+
for _ in range(len(box), point_num):
|
65 |
+
box.append(box[-1])
|
66 |
+
src_polys.append(box)
|
67 |
+
src_texts.append(ext_texts[idx])
|
68 |
+
src_ignores.append(tag)
|
69 |
+
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
70 |
+
h, w = src_img.shape[:2]
|
71 |
+
src_polys = np.array(src_polys)
|
72 |
+
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
|
73 |
+
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
74 |
+
data['image'] = src_img
|
75 |
+
data['polys'] = src_polys
|
76 |
+
data['texts'] = src_texts
|
77 |
+
data['ignore_tags'] = np.array(src_ignores)
|
78 |
+
return data
|
79 |
+
|
80 |
+
def paste_img(self, src_img, box_img, src_polys):
|
81 |
+
box_img_pil = Image.fromarray(box_img).convert('RGBA')
|
82 |
+
src_w, src_h = src_img.size
|
83 |
+
box_w, box_h = box_img_pil.size
|
84 |
+
|
85 |
+
angle = np.random.randint(0, 360)
|
86 |
+
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
|
87 |
+
box = rotate_bbox(box_img, box, angle)[0]
|
88 |
+
box_img_pil = box_img_pil.rotate(angle, expand=1)
|
89 |
+
box_w, box_h = box_img_pil.width, box_img_pil.height
|
90 |
+
if src_w - box_w < 0 or src_h - box_h < 0:
|
91 |
+
return src_img, None
|
92 |
+
|
93 |
+
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
|
94 |
+
src_h - box_h)
|
95 |
+
if paste_x is None:
|
96 |
+
return src_img, None
|
97 |
+
box[:, 0] += paste_x
|
98 |
+
box[:, 1] += paste_y
|
99 |
+
r, g, b, A = box_img_pil.split()
|
100 |
+
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
|
101 |
+
|
102 |
+
return src_img, box
|
103 |
+
|
104 |
+
def select_coord(self, src_polys, box, endx, endy):
|
105 |
+
if self.limit_paste:
|
106 |
+
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
|
107 |
+
), box[:, 0].max(), box[:, 1].max()
|
108 |
+
for _ in range(50):
|
109 |
+
paste_x = random.randint(0, endx)
|
110 |
+
paste_y = random.randint(0, endy)
|
111 |
+
xmin1 = xmin + paste_x
|
112 |
+
xmax1 = xmax + paste_x
|
113 |
+
ymin1 = ymin + paste_y
|
114 |
+
ymax1 = ymax + paste_y
|
115 |
+
|
116 |
+
num_poly_in_rect = 0
|
117 |
+
for poly in src_polys:
|
118 |
+
if not is_poly_outside_rect(poly, xmin1, ymin1,
|
119 |
+
xmax1 - xmin1, ymax1 - ymin1):
|
120 |
+
num_poly_in_rect += 1
|
121 |
+
break
|
122 |
+
if num_poly_in_rect == 0:
|
123 |
+
return paste_x, paste_y
|
124 |
+
return None, None
|
125 |
+
else:
|
126 |
+
paste_x = random.randint(0, endx)
|
127 |
+
paste_y = random.randint(0, endy)
|
128 |
+
return paste_x, paste_y
|
129 |
+
|
130 |
+
|
131 |
+
def get_union(pD, pG):
|
132 |
+
return Polygon(pD).union(Polygon(pG)).area
|
133 |
+
|
134 |
+
|
135 |
+
def get_intersection_over_union(pD, pG):
|
136 |
+
return get_intersection(pD, pG) / get_union(pD, pG)
|
137 |
+
|
138 |
+
|
139 |
+
def get_intersection(pD, pG):
|
140 |
+
return Polygon(pD).intersection(Polygon(pG)).area
|
141 |
+
|
142 |
+
|
143 |
+
def rotate_bbox(img, text_polys, angle, scale=1):
|
144 |
+
"""
|
145 |
+
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
|
146 |
+
Args:
|
147 |
+
img: np.ndarray
|
148 |
+
text_polys: np.ndarray N*4*2
|
149 |
+
angle: int
|
150 |
+
scale: int
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
|
154 |
+
"""
|
155 |
+
w = img.shape[1]
|
156 |
+
h = img.shape[0]
|
157 |
+
|
158 |
+
rangle = np.deg2rad(angle)
|
159 |
+
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
|
160 |
+
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
|
161 |
+
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
|
162 |
+
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
|
163 |
+
rot_mat[0, 2] += rot_move[0]
|
164 |
+
rot_mat[1, 2] += rot_move[1]
|
165 |
+
|
166 |
+
# ---------------------- rotate box ----------------------
|
167 |
+
rot_text_polys = list()
|
168 |
+
for bbox in text_polys:
|
169 |
+
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
|
170 |
+
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
|
171 |
+
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
|
172 |
+
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
|
173 |
+
rot_text_polys.append([point1, point2, point3, point4])
|
174 |
+
return np.array(rot_text_polys, dtype=np.float32)
|
ppocr/data/imaug/ct_process.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import cv2
|
17 |
+
import random
|
18 |
+
import pyclipper
|
19 |
+
import paddle
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import Polygon as plg
|
23 |
+
import scipy.io as scio
|
24 |
+
|
25 |
+
from PIL import Image
|
26 |
+
import paddle.vision.transforms as transforms
|
27 |
+
|
28 |
+
|
29 |
+
class RandomScale():
|
30 |
+
def __init__(self, short_size=640, **kwargs):
|
31 |
+
self.short_size = short_size
|
32 |
+
|
33 |
+
def scale_aligned(self, img, scale):
|
34 |
+
oh, ow = img.shape[0:2]
|
35 |
+
h = int(oh * scale + 0.5)
|
36 |
+
w = int(ow * scale + 0.5)
|
37 |
+
if h % 32 != 0:
|
38 |
+
h = h + (32 - h % 32)
|
39 |
+
if w % 32 != 0:
|
40 |
+
w = w + (32 - w % 32)
|
41 |
+
img = cv2.resize(img, dsize=(w, h))
|
42 |
+
factor_h = h / oh
|
43 |
+
factor_w = w / ow
|
44 |
+
return img, factor_h, factor_w
|
45 |
+
|
46 |
+
def __call__(self, data):
|
47 |
+
img = data['image']
|
48 |
+
|
49 |
+
h, w = img.shape[0:2]
|
50 |
+
random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
|
51 |
+
scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
|
52 |
+
img, factor_h, factor_w = self.scale_aligned(img, scale)
|
53 |
+
|
54 |
+
data['scale_factor'] = (factor_w, factor_h)
|
55 |
+
data['image'] = img
|
56 |
+
return data
|
57 |
+
|
58 |
+
|
59 |
+
class MakeShrink():
|
60 |
+
def __init__(self, kernel_scale=0.7, **kwargs):
|
61 |
+
self.kernel_scale = kernel_scale
|
62 |
+
|
63 |
+
def dist(self, a, b):
|
64 |
+
return np.linalg.norm((a - b), ord=2, axis=0)
|
65 |
+
|
66 |
+
def perimeter(self, bbox):
|
67 |
+
peri = 0.0
|
68 |
+
for i in range(bbox.shape[0]):
|
69 |
+
peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
|
70 |
+
return peri
|
71 |
+
|
72 |
+
def shrink(self, bboxes, rate, max_shr=20):
|
73 |
+
rate = rate * rate
|
74 |
+
shrinked_bboxes = []
|
75 |
+
for bbox in bboxes:
|
76 |
+
area = plg.Polygon(bbox).area()
|
77 |
+
peri = self.perimeter(bbox)
|
78 |
+
|
79 |
+
try:
|
80 |
+
pco = pyclipper.PyclipperOffset()
|
81 |
+
pco.AddPath(bbox, pyclipper.JT_ROUND,
|
82 |
+
pyclipper.ET_CLOSEDPOLYGON)
|
83 |
+
offset = min(
|
84 |
+
int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
|
85 |
+
|
86 |
+
shrinked_bbox = pco.Execute(-offset)
|
87 |
+
if len(shrinked_bbox) == 0:
|
88 |
+
shrinked_bboxes.append(bbox)
|
89 |
+
continue
|
90 |
+
|
91 |
+
shrinked_bbox = np.array(shrinked_bbox[0])
|
92 |
+
if shrinked_bbox.shape[0] <= 2:
|
93 |
+
shrinked_bboxes.append(bbox)
|
94 |
+
continue
|
95 |
+
|
96 |
+
shrinked_bboxes.append(shrinked_bbox)
|
97 |
+
except Exception as e:
|
98 |
+
shrinked_bboxes.append(bbox)
|
99 |
+
|
100 |
+
return shrinked_bboxes
|
101 |
+
|
102 |
+
def __call__(self, data):
|
103 |
+
img = data['image']
|
104 |
+
bboxes = data['polys']
|
105 |
+
words = data['texts']
|
106 |
+
scale_factor = data['scale_factor']
|
107 |
+
|
108 |
+
gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w
|
109 |
+
training_mask = np.ones(img.shape[0:2], dtype='uint8')
|
110 |
+
training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
|
111 |
+
|
112 |
+
for i in range(len(bboxes)):
|
113 |
+
bboxes[i] = np.reshape(bboxes[i] * (
|
114 |
+
[scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
|
115 |
+
(bboxes[i].shape[0] // 2, 2)).astype('int32')
|
116 |
+
|
117 |
+
for i in range(len(bboxes)):
|
118 |
+
#different value for different bbox
|
119 |
+
cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
|
120 |
+
|
121 |
+
# set training mask to 0
|
122 |
+
cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
|
123 |
+
|
124 |
+
# for not accurate annotation, use training_mask_distance
|
125 |
+
if words[i] == '###' or words[i] == '???':
|
126 |
+
cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
|
127 |
+
|
128 |
+
# make shrink
|
129 |
+
gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
|
130 |
+
kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
|
131 |
+
for i in range(len(bboxes)):
|
132 |
+
cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
|
133 |
+
-1)
|
134 |
+
|
135 |
+
# for training mask, kernel and background= 1, box region=0
|
136 |
+
if words[i] != '###' and words[i] != '???':
|
137 |
+
cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
|
138 |
+
|
139 |
+
gt_kernel = gt_kernel_instance.copy()
|
140 |
+
# for gt_kernel, kernel = 1
|
141 |
+
gt_kernel[gt_kernel > 0] = 1
|
142 |
+
|
143 |
+
# shrink 2 times
|
144 |
+
tmp1 = gt_kernel_instance.copy()
|
145 |
+
erode_kernel = np.ones((3, 3), np.uint8)
|
146 |
+
tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1)
|
147 |
+
tmp2 = tmp1.copy()
|
148 |
+
tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1)
|
149 |
+
|
150 |
+
# compute text region
|
151 |
+
gt_kernel_inner = tmp1 - tmp2
|
152 |
+
|
153 |
+
# gt_instance: text instance, bg=0, diff word use diff value
|
154 |
+
# training_mask: text instance mask, word=0,kernel and bg=1
|
155 |
+
# gt_kernel_instance: text kernel instance, bg=0, diff word use diff value
|
156 |
+
# gt_kernel: text_kernel, bg=0,diff word use same value
|
157 |
+
# gt_kernel_inner: text kernel reference
|
158 |
+
# training_mask_distance: word without anno = 0, else 1
|
159 |
+
|
160 |
+
data['image'] = [
|
161 |
+
img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
|
162 |
+
gt_kernel_inner, training_mask_distance
|
163 |
+
]
|
164 |
+
return data
|
165 |
+
|
166 |
+
|
167 |
+
class GroupRandomHorizontalFlip():
|
168 |
+
def __init__(self, p=0.5, **kwargs):
|
169 |
+
self.p = p
|
170 |
+
|
171 |
+
def __call__(self, data):
|
172 |
+
imgs = data['image']
|
173 |
+
|
174 |
+
if random.random() < self.p:
|
175 |
+
for i in range(len(imgs)):
|
176 |
+
imgs[i] = np.flip(imgs[i], axis=1).copy()
|
177 |
+
data['image'] = imgs
|
178 |
+
return data
|
179 |
+
|
180 |
+
|
181 |
+
class GroupRandomRotate():
|
182 |
+
def __init__(self, **kwargs):
|
183 |
+
pass
|
184 |
+
|
185 |
+
def __call__(self, data):
|
186 |
+
imgs = data['image']
|
187 |
+
|
188 |
+
max_angle = 10
|
189 |
+
angle = random.random() * 2 * max_angle - max_angle
|
190 |
+
for i in range(len(imgs)):
|
191 |
+
img = imgs[i]
|
192 |
+
w, h = img.shape[:2]
|
193 |
+
rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
|
194 |
+
img_rotation = cv2.warpAffine(
|
195 |
+
img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
|
196 |
+
imgs[i] = img_rotation
|
197 |
+
|
198 |
+
data['image'] = imgs
|
199 |
+
return data
|
200 |
+
|
201 |
+
|
202 |
+
class GroupRandomCropPadding():
|
203 |
+
def __init__(self, target_size=(640, 640), **kwargs):
|
204 |
+
self.target_size = target_size
|
205 |
+
|
206 |
+
def __call__(self, data):
|
207 |
+
imgs = data['image']
|
208 |
+
|
209 |
+
h, w = imgs[0].shape[0:2]
|
210 |
+
t_w, t_h = self.target_size
|
211 |
+
p_w, p_h = self.target_size
|
212 |
+
if w == t_w and h == t_h:
|
213 |
+
return data
|
214 |
+
|
215 |
+
t_h = t_h if t_h < h else h
|
216 |
+
t_w = t_w if t_w < w else w
|
217 |
+
|
218 |
+
if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
|
219 |
+
# make sure to crop the text region
|
220 |
+
tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
|
221 |
+
tl[tl < 0] = 0
|
222 |
+
br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
|
223 |
+
br[br < 0] = 0
|
224 |
+
br[0] = min(br[0], h - t_h)
|
225 |
+
br[1] = min(br[1], w - t_w)
|
226 |
+
|
227 |
+
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
|
228 |
+
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
|
229 |
+
else:
|
230 |
+
i = random.randint(0, h - t_h) if h - t_h > 0 else 0
|
231 |
+
j = random.randint(0, w - t_w) if w - t_w > 0 else 0
|
232 |
+
|
233 |
+
n_imgs = []
|
234 |
+
for idx in range(len(imgs)):
|
235 |
+
if len(imgs[idx].shape) == 3:
|
236 |
+
s3_length = int(imgs[idx].shape[-1])
|
237 |
+
img = imgs[idx][i:i + t_h, j:j + t_w, :]
|
238 |
+
img_p = cv2.copyMakeBorder(
|
239 |
+
img,
|
240 |
+
0,
|
241 |
+
p_h - t_h,
|
242 |
+
0,
|
243 |
+
p_w - t_w,
|
244 |
+
borderType=cv2.BORDER_CONSTANT,
|
245 |
+
value=tuple(0 for i in range(s3_length)))
|
246 |
+
else:
|
247 |
+
img = imgs[idx][i:i + t_h, j:j + t_w]
|
248 |
+
img_p = cv2.copyMakeBorder(
|
249 |
+
img,
|
250 |
+
0,
|
251 |
+
p_h - t_h,
|
252 |
+
0,
|
253 |
+
p_w - t_w,
|
254 |
+
borderType=cv2.BORDER_CONSTANT,
|
255 |
+
value=(0, ))
|
256 |
+
n_imgs.append(img_p)
|
257 |
+
|
258 |
+
data['image'] = n_imgs
|
259 |
+
return data
|
260 |
+
|
261 |
+
|
262 |
+
class MakeCentripetalShift():
|
263 |
+
def __init__(self, **kwargs):
|
264 |
+
pass
|
265 |
+
|
266 |
+
def jaccard(self, As, Bs):
|
267 |
+
A = As.shape[0] # small
|
268 |
+
B = Bs.shape[0] # large
|
269 |
+
|
270 |
+
dis = np.sqrt(
|
271 |
+
np.sum((As[:, np.newaxis, :].repeat(
|
272 |
+
B, axis=1) - Bs[np.newaxis, :, :].repeat(
|
273 |
+
A, axis=0))**2,
|
274 |
+
axis=-1))
|
275 |
+
|
276 |
+
ind = np.argmin(dis, axis=-1)
|
277 |
+
|
278 |
+
return ind
|
279 |
+
|
280 |
+
def __call__(self, data):
|
281 |
+
imgs = data['image']
|
282 |
+
|
283 |
+
img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
|
284 |
+
imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
|
285 |
+
|
286 |
+
max_instance = np.max(gt_instance) # num bbox
|
287 |
+
|
288 |
+
# make centripetal shift
|
289 |
+
gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
|
290 |
+
for i in range(1, max_instance + 1):
|
291 |
+
# kernel_reference
|
292 |
+
ind = (gt_kernel_inner == i)
|
293 |
+
|
294 |
+
if np.sum(ind) == 0:
|
295 |
+
training_mask[gt_instance == i] = 0
|
296 |
+
training_mask_distance[gt_instance == i] = 0
|
297 |
+
continue
|
298 |
+
|
299 |
+
kpoints = np.array(np.where(ind)).transpose(
|
300 |
+
(1, 0))[:, ::-1].astype('float32')
|
301 |
+
|
302 |
+
ind = (gt_instance == i) * (gt_kernel_instance == 0)
|
303 |
+
if np.sum(ind) == 0:
|
304 |
+
continue
|
305 |
+
pixels = np.where(ind)
|
306 |
+
|
307 |
+
points = np.array(pixels).transpose(
|
308 |
+
(1, 0))[:, ::-1].astype('float32')
|
309 |
+
|
310 |
+
bbox_ind = self.jaccard(points, kpoints)
|
311 |
+
|
312 |
+
offset_gt = kpoints[bbox_ind] - points
|
313 |
+
|
314 |
+
gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
|
315 |
+
|
316 |
+
img = Image.fromarray(img)
|
317 |
+
img = img.convert('RGB')
|
318 |
+
|
319 |
+
data["image"] = img
|
320 |
+
data["gt_kernel"] = gt_kernel.astype("int64")
|
321 |
+
data["training_mask"] = training_mask.astype("int64")
|
322 |
+
data["gt_instance"] = gt_instance.astype("int64")
|
323 |
+
data["gt_kernel_instance"] = gt_kernel_instance.astype("int64")
|
324 |
+
data["training_mask_distance"] = training_mask_distance.astype("int64")
|
325 |
+
data["gt_distance"] = gt_distance.astype("float32")
|
326 |
+
|
327 |
+
return data
|
328 |
+
|
329 |
+
|
330 |
+
class ScaleAlignedShort():
|
331 |
+
def __init__(self, short_size=640, **kwargs):
|
332 |
+
self.short_size = short_size
|
333 |
+
|
334 |
+
def __call__(self, data):
|
335 |
+
img = data['image']
|
336 |
+
|
337 |
+
org_img_shape = img.shape
|
338 |
+
|
339 |
+
h, w = img.shape[0:2]
|
340 |
+
scale = self.short_size * 1.0 / min(h, w)
|
341 |
+
h = int(h * scale + 0.5)
|
342 |
+
w = int(w * scale + 0.5)
|
343 |
+
if h % 32 != 0:
|
344 |
+
h = h + (32 - h % 32)
|
345 |
+
if w % 32 != 0:
|
346 |
+
w = w + (32 - w % 32)
|
347 |
+
img = cv2.resize(img, dsize=(w, h))
|
348 |
+
|
349 |
+
new_img_shape = img.shape
|
350 |
+
img_shape = np.array(org_img_shape + new_img_shape)
|
351 |
+
|
352 |
+
data['shape'] = img_shape
|
353 |
+
data['image'] = img
|
354 |
+
|
355 |
+
return data
|
ppocr/data/imaug/drrg_targets.py
ADDED
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
from lanms import merge_quadrangle_n9 as la_nms
|
22 |
+
from numpy.linalg import norm
|
23 |
+
|
24 |
+
|
25 |
+
class DRRGTargets(object):
|
26 |
+
def __init__(self,
|
27 |
+
orientation_thr=2.0,
|
28 |
+
resample_step=8.0,
|
29 |
+
num_min_comps=9,
|
30 |
+
num_max_comps=600,
|
31 |
+
min_width=8.0,
|
32 |
+
max_width=24.0,
|
33 |
+
center_region_shrink_ratio=0.3,
|
34 |
+
comp_shrink_ratio=1.0,
|
35 |
+
comp_w_h_ratio=0.3,
|
36 |
+
text_comp_nms_thr=0.25,
|
37 |
+
min_rand_half_height=8.0,
|
38 |
+
max_rand_half_height=24.0,
|
39 |
+
jitter_level=0.2,
|
40 |
+
**kwargs):
|
41 |
+
|
42 |
+
super().__init__()
|
43 |
+
self.orientation_thr = orientation_thr
|
44 |
+
self.resample_step = resample_step
|
45 |
+
self.num_max_comps = num_max_comps
|
46 |
+
self.num_min_comps = num_min_comps
|
47 |
+
self.min_width = min_width
|
48 |
+
self.max_width = max_width
|
49 |
+
self.center_region_shrink_ratio = center_region_shrink_ratio
|
50 |
+
self.comp_shrink_ratio = comp_shrink_ratio
|
51 |
+
self.comp_w_h_ratio = comp_w_h_ratio
|
52 |
+
self.text_comp_nms_thr = text_comp_nms_thr
|
53 |
+
self.min_rand_half_height = min_rand_half_height
|
54 |
+
self.max_rand_half_height = max_rand_half_height
|
55 |
+
self.jitter_level = jitter_level
|
56 |
+
self.eps = 1e-8
|
57 |
+
|
58 |
+
def vector_angle(self, vec1, vec2):
|
59 |
+
if vec1.ndim > 1:
|
60 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps).reshape((-1, 1))
|
61 |
+
else:
|
62 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps)
|
63 |
+
if vec2.ndim > 1:
|
64 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
|
65 |
+
else:
|
66 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
|
67 |
+
return np.arccos(
|
68 |
+
np.clip(
|
69 |
+
np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
|
70 |
+
|
71 |
+
def vector_slope(self, vec):
|
72 |
+
assert len(vec) == 2
|
73 |
+
return abs(vec[1] / (vec[0] + self.eps))
|
74 |
+
|
75 |
+
def vector_sin(self, vec):
|
76 |
+
assert len(vec) == 2
|
77 |
+
return vec[1] / (norm(vec) + self.eps)
|
78 |
+
|
79 |
+
def vector_cos(self, vec):
|
80 |
+
assert len(vec) == 2
|
81 |
+
return vec[0] / (norm(vec) + self.eps)
|
82 |
+
|
83 |
+
def find_head_tail(self, points, orientation_thr):
|
84 |
+
|
85 |
+
assert points.ndim == 2
|
86 |
+
assert points.shape[0] >= 4
|
87 |
+
assert points.shape[1] == 2
|
88 |
+
assert isinstance(orientation_thr, float)
|
89 |
+
|
90 |
+
if len(points) > 4:
|
91 |
+
pad_points = np.vstack([points, points[0]])
|
92 |
+
edge_vec = pad_points[1:] - pad_points[:-1]
|
93 |
+
|
94 |
+
theta_sum = []
|
95 |
+
adjacent_vec_theta = []
|
96 |
+
for i, edge_vec1 in enumerate(edge_vec):
|
97 |
+
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
98 |
+
adjacent_edge_vec = edge_vec[adjacent_ind]
|
99 |
+
temp_theta_sum = np.sum(
|
100 |
+
self.vector_angle(edge_vec1, adjacent_edge_vec))
|
101 |
+
temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
|
102 |
+
adjacent_edge_vec[1])
|
103 |
+
theta_sum.append(temp_theta_sum)
|
104 |
+
adjacent_vec_theta.append(temp_adjacent_theta)
|
105 |
+
theta_sum_score = np.array(theta_sum) / np.pi
|
106 |
+
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
|
107 |
+
poly_center = np.mean(points, axis=0)
|
108 |
+
edge_dist = np.maximum(
|
109 |
+
norm(
|
110 |
+
pad_points[1:] - poly_center, axis=-1),
|
111 |
+
norm(
|
112 |
+
pad_points[:-1] - poly_center, axis=-1))
|
113 |
+
dist_score = edge_dist / (np.max(edge_dist) + self.eps)
|
114 |
+
position_score = np.zeros(len(edge_vec))
|
115 |
+
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
|
116 |
+
score += 0.35 * dist_score
|
117 |
+
if len(points) % 2 == 0:
|
118 |
+
position_score[(len(score) // 2 - 1)] += 1
|
119 |
+
position_score[-1] += 1
|
120 |
+
score += 0.1 * position_score
|
121 |
+
pad_score = np.concatenate([score, score])
|
122 |
+
score_matrix = np.zeros((len(score), len(score) - 3))
|
123 |
+
x = np.arange(len(score) - 3) / float(len(score) - 4)
|
124 |
+
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
|
125 |
+
(x - 0.5) / 0.5, 2.) / 2)
|
126 |
+
gaussian = gaussian / np.max(gaussian)
|
127 |
+
for i in range(len(score)):
|
128 |
+
score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
|
129 |
+
score) - 1)] * gaussian * 0.3
|
130 |
+
|
131 |
+
head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
|
132 |
+
score_matrix.shape)
|
133 |
+
tail_start = (head_start + tail_increment + 2) % len(points)
|
134 |
+
head_end = (head_start + 1) % len(points)
|
135 |
+
tail_end = (tail_start + 1) % len(points)
|
136 |
+
|
137 |
+
if head_end > tail_end:
|
138 |
+
head_start, tail_start = tail_start, head_start
|
139 |
+
head_end, tail_end = tail_end, head_end
|
140 |
+
head_inds = [head_start, head_end]
|
141 |
+
tail_inds = [tail_start, tail_end]
|
142 |
+
else:
|
143 |
+
if self.vector_slope(points[1] - points[0]) + self.vector_slope(
|
144 |
+
points[3] - points[2]) < self.vector_slope(points[
|
145 |
+
2] - points[1]) + self.vector_slope(points[0] - points[
|
146 |
+
3]):
|
147 |
+
horizontal_edge_inds = [[0, 1], [2, 3]]
|
148 |
+
vertical_edge_inds = [[3, 0], [1, 2]]
|
149 |
+
else:
|
150 |
+
horizontal_edge_inds = [[3, 0], [1, 2]]
|
151 |
+
vertical_edge_inds = [[0, 1], [2, 3]]
|
152 |
+
|
153 |
+
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
|
154 |
+
vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
|
155 |
+
0]] - points[vertical_edge_inds[1][1]])
|
156 |
+
horizontal_len_sum = norm(points[horizontal_edge_inds[0][
|
157 |
+
0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
|
158 |
+
horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
|
159 |
+
[1]])
|
160 |
+
|
161 |
+
if vertical_len_sum > horizontal_len_sum * orientation_thr:
|
162 |
+
head_inds = horizontal_edge_inds[0]
|
163 |
+
tail_inds = horizontal_edge_inds[1]
|
164 |
+
else:
|
165 |
+
head_inds = vertical_edge_inds[0]
|
166 |
+
tail_inds = vertical_edge_inds[1]
|
167 |
+
|
168 |
+
return head_inds, tail_inds
|
169 |
+
|
170 |
+
def reorder_poly_edge(self, points):
|
171 |
+
|
172 |
+
assert points.ndim == 2
|
173 |
+
assert points.shape[0] >= 4
|
174 |
+
assert points.shape[1] == 2
|
175 |
+
|
176 |
+
head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
|
177 |
+
head_edge, tail_edge = points[head_inds], points[tail_inds]
|
178 |
+
|
179 |
+
pad_points = np.vstack([points, points])
|
180 |
+
if tail_inds[1] < 1:
|
181 |
+
tail_inds[1] = len(points)
|
182 |
+
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
|
183 |
+
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
|
184 |
+
sideline_mean_shift = np.mean(
|
185 |
+
sideline1, axis=0) - np.mean(
|
186 |
+
sideline2, axis=0)
|
187 |
+
|
188 |
+
if sideline_mean_shift[1] > 0:
|
189 |
+
top_sideline, bot_sideline = sideline2, sideline1
|
190 |
+
else:
|
191 |
+
top_sideline, bot_sideline = sideline1, sideline2
|
192 |
+
|
193 |
+
return head_edge, tail_edge, top_sideline, bot_sideline
|
194 |
+
|
195 |
+
def cal_curve_length(self, line):
|
196 |
+
|
197 |
+
assert line.ndim == 2
|
198 |
+
assert len(line) >= 2
|
199 |
+
|
200 |
+
edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 + (line[
|
201 |
+
1:, 1] - line[:-1, 1])**2)
|
202 |
+
total_length = np.sum(edges_length)
|
203 |
+
return edges_length, total_length
|
204 |
+
|
205 |
+
def resample_line(self, line, n):
|
206 |
+
|
207 |
+
assert line.ndim == 2
|
208 |
+
assert line.shape[0] >= 2
|
209 |
+
assert line.shape[1] == 2
|
210 |
+
assert isinstance(n, int)
|
211 |
+
assert n > 2
|
212 |
+
|
213 |
+
edges_length, total_length = self.cal_curve_length(line)
|
214 |
+
t_org = np.insert(np.cumsum(edges_length), 0, 0)
|
215 |
+
unit_t = total_length / (n - 1)
|
216 |
+
t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t
|
217 |
+
edge_ind = 0
|
218 |
+
points = [line[0]]
|
219 |
+
for t in t_equidistant:
|
220 |
+
while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
|
221 |
+
edge_ind += 1
|
222 |
+
t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
|
223 |
+
weight = np.array(
|
224 |
+
[t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
|
225 |
+
p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
|
226 |
+
points.append(p_coords)
|
227 |
+
points.append(line[-1])
|
228 |
+
resampled_line = np.vstack(points)
|
229 |
+
|
230 |
+
return resampled_line
|
231 |
+
|
232 |
+
def resample_sidelines(self, sideline1, sideline2, resample_step):
|
233 |
+
|
234 |
+
assert sideline1.ndim == sideline2.ndim == 2
|
235 |
+
assert sideline1.shape[1] == sideline2.shape[1] == 2
|
236 |
+
assert sideline1.shape[0] >= 2
|
237 |
+
assert sideline2.shape[0] >= 2
|
238 |
+
assert isinstance(resample_step, float)
|
239 |
+
|
240 |
+
_, length1 = self.cal_curve_length(sideline1)
|
241 |
+
_, length2 = self.cal_curve_length(sideline2)
|
242 |
+
|
243 |
+
avg_length = (length1 + length2) / 2
|
244 |
+
resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)
|
245 |
+
|
246 |
+
resampled_line1 = self.resample_line(sideline1, resample_point_num)
|
247 |
+
resampled_line2 = self.resample_line(sideline2, resample_point_num)
|
248 |
+
|
249 |
+
return resampled_line1, resampled_line2
|
250 |
+
|
251 |
+
def dist_point2line(self, point, line):
|
252 |
+
|
253 |
+
assert isinstance(line, tuple)
|
254 |
+
point1, point2 = line
|
255 |
+
d = abs(np.cross(point2 - point1, point - point1)) / (
|
256 |
+
norm(point2 - point1) + 1e-8)
|
257 |
+
return d
|
258 |
+
|
259 |
+
def draw_center_region_maps(self, top_line, bot_line, center_line,
|
260 |
+
center_region_mask, top_height_map,
|
261 |
+
bot_height_map, sin_map, cos_map,
|
262 |
+
region_shrink_ratio):
|
263 |
+
|
264 |
+
assert top_line.shape == bot_line.shape == center_line.shape
|
265 |
+
assert (center_region_mask.shape == top_height_map.shape ==
|
266 |
+
bot_height_map.shape == sin_map.shape == cos_map.shape)
|
267 |
+
assert isinstance(region_shrink_ratio, float)
|
268 |
+
|
269 |
+
h, w = center_region_mask.shape
|
270 |
+
for i in range(0, len(center_line) - 1):
|
271 |
+
|
272 |
+
top_mid_point = (top_line[i] + top_line[i + 1]) / 2
|
273 |
+
bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
|
274 |
+
|
275 |
+
sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
|
276 |
+
cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
|
277 |
+
|
278 |
+
tl = center_line[i] + (top_line[i] - center_line[i]
|
279 |
+
) * region_shrink_ratio
|
280 |
+
tr = center_line[i + 1] + (top_line[i + 1] - center_line[i + 1]
|
281 |
+
) * region_shrink_ratio
|
282 |
+
br = center_line[i + 1] + (bot_line[i + 1] - center_line[i + 1]
|
283 |
+
) * region_shrink_ratio
|
284 |
+
bl = center_line[i] + (bot_line[i] - center_line[i]
|
285 |
+
) * region_shrink_ratio
|
286 |
+
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
|
287 |
+
|
288 |
+
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
|
289 |
+
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
|
290 |
+
cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
|
291 |
+
|
292 |
+
current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
|
293 |
+
w - 1)
|
294 |
+
current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
|
295 |
+
h - 1)
|
296 |
+
min_coord = np.min(current_center_box, axis=0).astype(np.int32)
|
297 |
+
max_coord = np.max(current_center_box, axis=0).astype(np.int32)
|
298 |
+
current_center_box = current_center_box - min_coord
|
299 |
+
box_sz = (max_coord - min_coord + 1)
|
300 |
+
|
301 |
+
center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
|
302 |
+
cv2.fillPoly(center_box_mask, [current_center_box], color=1)
|
303 |
+
|
304 |
+
inds = np.argwhere(center_box_mask > 0)
|
305 |
+
inds = inds + (min_coord[1], min_coord[0])
|
306 |
+
inds_xy = np.fliplr(inds)
|
307 |
+
top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
|
308 |
+
inds_xy, (top_line[i], top_line[i + 1]))
|
309 |
+
bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
|
310 |
+
inds_xy, (bot_line[i], bot_line[i + 1]))
|
311 |
+
|
312 |
+
def generate_center_mask_attrib_maps(self, img_size, text_polys):
|
313 |
+
|
314 |
+
assert isinstance(img_size, tuple)
|
315 |
+
|
316 |
+
h, w = img_size
|
317 |
+
|
318 |
+
center_lines = []
|
319 |
+
center_region_mask = np.zeros((h, w), np.uint8)
|
320 |
+
top_height_map = np.zeros((h, w), dtype=np.float32)
|
321 |
+
bot_height_map = np.zeros((h, w), dtype=np.float32)
|
322 |
+
sin_map = np.zeros((h, w), dtype=np.float32)
|
323 |
+
cos_map = np.zeros((h, w), dtype=np.float32)
|
324 |
+
|
325 |
+
for poly in text_polys:
|
326 |
+
polygon_points = poly
|
327 |
+
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
328 |
+
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
329 |
+
top_line, bot_line, self.resample_step)
|
330 |
+
resampled_bot_line = resampled_bot_line[::-1]
|
331 |
+
center_line = (resampled_top_line + resampled_bot_line) / 2
|
332 |
+
|
333 |
+
if self.vector_slope(center_line[-1] - center_line[0]) > 2:
|
334 |
+
if (center_line[-1] - center_line[0])[1] < 0:
|
335 |
+
center_line = center_line[::-1]
|
336 |
+
resampled_top_line = resampled_top_line[::-1]
|
337 |
+
resampled_bot_line = resampled_bot_line[::-1]
|
338 |
+
else:
|
339 |
+
if (center_line[-1] - center_line[0])[0] < 0:
|
340 |
+
center_line = center_line[::-1]
|
341 |
+
resampled_top_line = resampled_top_line[::-1]
|
342 |
+
resampled_bot_line = resampled_bot_line[::-1]
|
343 |
+
|
344 |
+
line_head_shrink_len = np.clip(
|
345 |
+
(norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
|
346 |
+
self.min_width, self.max_width) / 2
|
347 |
+
line_tail_shrink_len = np.clip(
|
348 |
+
(norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
|
349 |
+
self.min_width, self.max_width) / 2
|
350 |
+
num_head_shrink = int(line_head_shrink_len // self.resample_step)
|
351 |
+
num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
|
352 |
+
if len(center_line) > num_head_shrink + num_tail_shrink + 2:
|
353 |
+
center_line = center_line[num_head_shrink:len(center_line) -
|
354 |
+
num_tail_shrink]
|
355 |
+
resampled_top_line = resampled_top_line[num_head_shrink:len(
|
356 |
+
resampled_top_line) - num_tail_shrink]
|
357 |
+
resampled_bot_line = resampled_bot_line[num_head_shrink:len(
|
358 |
+
resampled_bot_line) - num_tail_shrink]
|
359 |
+
center_lines.append(center_line.astype(np.int32))
|
360 |
+
|
361 |
+
self.draw_center_region_maps(
|
362 |
+
resampled_top_line, resampled_bot_line, center_line,
|
363 |
+
center_region_mask, top_height_map, bot_height_map, sin_map,
|
364 |
+
cos_map, self.center_region_shrink_ratio)
|
365 |
+
|
366 |
+
return (center_lines, center_region_mask, top_height_map,
|
367 |
+
bot_height_map, sin_map, cos_map)
|
368 |
+
|
369 |
+
def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
|
370 |
+
|
371 |
+
assert isinstance(num_rand_comps, int)
|
372 |
+
assert num_rand_comps > 0
|
373 |
+
assert center_sample_mask.ndim == 2
|
374 |
+
|
375 |
+
h, w = center_sample_mask.shape
|
376 |
+
|
377 |
+
max_rand_half_height = self.max_rand_half_height
|
378 |
+
min_rand_half_height = self.min_rand_half_height
|
379 |
+
max_rand_height = max_rand_half_height * 2
|
380 |
+
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
|
381 |
+
self.min_width, self.max_width)
|
382 |
+
margin = int(
|
383 |
+
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
|
384 |
+
|
385 |
+
if 2 * margin + 1 > min(h, w):
|
386 |
+
|
387 |
+
assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
|
388 |
+
max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
|
389 |
+
min_rand_half_height = max(max_rand_half_height / 4,
|
390 |
+
self.min_width / 2)
|
391 |
+
|
392 |
+
max_rand_height = max_rand_half_height * 2
|
393 |
+
max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
|
394 |
+
self.min_width, self.max_width)
|
395 |
+
margin = int(
|
396 |
+
np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
|
397 |
+
|
398 |
+
inner_center_sample_mask = np.zeros_like(center_sample_mask)
|
399 |
+
inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
|
400 |
+
center_sample_mask[margin:h - margin, margin:w - margin]
|
401 |
+
kernel_size = int(np.clip(max_rand_half_height, 7, 21))
|
402 |
+
inner_center_sample_mask = cv2.erode(
|
403 |
+
inner_center_sample_mask,
|
404 |
+
np.ones((kernel_size, kernel_size), np.uint8))
|
405 |
+
|
406 |
+
center_candidates = np.argwhere(inner_center_sample_mask > 0)
|
407 |
+
num_center_candidates = len(center_candidates)
|
408 |
+
sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
|
409 |
+
rand_centers = center_candidates[sample_inds]
|
410 |
+
|
411 |
+
rand_top_height = np.random.randint(
|
412 |
+
min_rand_half_height,
|
413 |
+
max_rand_half_height,
|
414 |
+
size=(len(rand_centers), 1))
|
415 |
+
rand_bot_height = np.random.randint(
|
416 |
+
min_rand_half_height,
|
417 |
+
max_rand_half_height,
|
418 |
+
size=(len(rand_centers), 1))
|
419 |
+
|
420 |
+
rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
421 |
+
rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
|
422 |
+
scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
|
423 |
+
rand_cos = rand_cos * scale
|
424 |
+
rand_sin = rand_sin * scale
|
425 |
+
|
426 |
+
height = (rand_top_height + rand_bot_height)
|
427 |
+
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
|
428 |
+
self.max_width)
|
429 |
+
|
430 |
+
rand_comp_attribs = np.hstack([
|
431 |
+
rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
|
432 |
+
np.zeros_like(rand_sin)
|
433 |
+
]).astype(np.float32)
|
434 |
+
|
435 |
+
return rand_comp_attribs
|
436 |
+
|
437 |
+
def jitter_comp_attribs(self, comp_attribs, jitter_level):
|
438 |
+
"""Jitter text components attributes.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
comp_attribs (ndarray): The text component attributes.
|
442 |
+
jitter_level (float): The jitter level of text components
|
443 |
+
attributes.
|
444 |
+
|
445 |
+
Returns:
|
446 |
+
jittered_comp_attribs (ndarray): The jittered text component
|
447 |
+
attributes (x, y, h, w, cos, sin, comp_label).
|
448 |
+
"""
|
449 |
+
|
450 |
+
assert comp_attribs.shape[1] == 7
|
451 |
+
assert comp_attribs.shape[0] > 0
|
452 |
+
assert isinstance(jitter_level, float)
|
453 |
+
|
454 |
+
x = comp_attribs[:, 0].reshape((-1, 1))
|
455 |
+
y = comp_attribs[:, 1].reshape((-1, 1))
|
456 |
+
h = comp_attribs[:, 2].reshape((-1, 1))
|
457 |
+
w = comp_attribs[:, 3].reshape((-1, 1))
|
458 |
+
cos = comp_attribs[:, 4].reshape((-1, 1))
|
459 |
+
sin = comp_attribs[:, 5].reshape((-1, 1))
|
460 |
+
comp_labels = comp_attribs[:, 6].reshape((-1, 1))
|
461 |
+
|
462 |
+
x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
|
463 |
+
h * np.abs(cos) + w * np.abs(sin)) * jitter_level
|
464 |
+
y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
|
465 |
+
h * np.abs(sin) + w * np.abs(cos)) * jitter_level
|
466 |
+
|
467 |
+
h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
|
468 |
+
) * h * jitter_level
|
469 |
+
w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
|
470 |
+
) * w * jitter_level
|
471 |
+
|
472 |
+
cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
|
473 |
+
) * 2 * jitter_level
|
474 |
+
sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
|
475 |
+
) * 2 * jitter_level
|
476 |
+
|
477 |
+
scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
|
478 |
+
cos = cos * scale
|
479 |
+
sin = sin * scale
|
480 |
+
|
481 |
+
jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
|
482 |
+
|
483 |
+
return jittered_comp_attribs
|
484 |
+
|
485 |
+
def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
|
486 |
+
top_height_map, bot_height_map, sin_map, cos_map):
|
487 |
+
"""Generate text component attributes.
|
488 |
+
|
489 |
+
Args:
|
490 |
+
center_lines (list[ndarray]): The list of text center lines .
|
491 |
+
text_mask (ndarray): The text region mask.
|
492 |
+
center_region_mask (ndarray): The text center region mask.
|
493 |
+
top_height_map (ndarray): The map on which the distance from points
|
494 |
+
to top side lines will be drawn for each pixel in text center
|
495 |
+
regions.
|
496 |
+
bot_height_map (ndarray): The map on which the distance from points
|
497 |
+
to bottom side lines will be drawn for each pixel in text
|
498 |
+
center regions.
|
499 |
+
sin_map (ndarray): The sin(theta) map where theta is the angle
|
500 |
+
between vector (top point - bottom point) and vector (1, 0).
|
501 |
+
cos_map (ndarray): The cos(theta) map where theta is the angle
|
502 |
+
between vector (top point - bottom point) and vector (1, 0).
|
503 |
+
|
504 |
+
Returns:
|
505 |
+
pad_comp_attribs (ndarray): The padded text component attributes
|
506 |
+
of a fixed size.
|
507 |
+
"""
|
508 |
+
|
509 |
+
assert isinstance(center_lines, list)
|
510 |
+
assert (
|
511 |
+
text_mask.shape == center_region_mask.shape == top_height_map.shape
|
512 |
+
== bot_height_map.shape == sin_map.shape == cos_map.shape)
|
513 |
+
|
514 |
+
center_lines_mask = np.zeros_like(center_region_mask)
|
515 |
+
cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
|
516 |
+
center_lines_mask = center_lines_mask * center_region_mask
|
517 |
+
comp_centers = np.argwhere(center_lines_mask > 0)
|
518 |
+
|
519 |
+
y = comp_centers[:, 0]
|
520 |
+
x = comp_centers[:, 1]
|
521 |
+
|
522 |
+
top_height = top_height_map[y, x].reshape(
|
523 |
+
(-1, 1)) * self.comp_shrink_ratio
|
524 |
+
bot_height = bot_height_map[y, x].reshape(
|
525 |
+
(-1, 1)) * self.comp_shrink_ratio
|
526 |
+
sin = sin_map[y, x].reshape((-1, 1))
|
527 |
+
cos = cos_map[y, x].reshape((-1, 1))
|
528 |
+
|
529 |
+
top_mid_points = comp_centers + np.hstack(
|
530 |
+
[top_height * sin, top_height * cos])
|
531 |
+
bot_mid_points = comp_centers - np.hstack(
|
532 |
+
[bot_height * sin, bot_height * cos])
|
533 |
+
|
534 |
+
width = (top_height + bot_height) * self.comp_w_h_ratio
|
535 |
+
width = np.clip(width, self.min_width, self.max_width)
|
536 |
+
r = width / 2
|
537 |
+
|
538 |
+
tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
|
539 |
+
tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
|
540 |
+
br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
|
541 |
+
bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
|
542 |
+
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
|
543 |
+
|
544 |
+
score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
|
545 |
+
text_comps = np.hstack([text_comps, score])
|
546 |
+
text_comps = la_nms(text_comps, self.text_comp_nms_thr)
|
547 |
+
|
548 |
+
if text_comps.shape[0] >= 1:
|
549 |
+
img_h, img_w = center_region_mask.shape
|
550 |
+
text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
|
551 |
+
text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
|
552 |
+
|
553 |
+
comp_centers = np.mean(
|
554 |
+
text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
|
555 |
+
x = comp_centers[:, 0]
|
556 |
+
y = comp_centers[:, 1]
|
557 |
+
|
558 |
+
height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
|
559 |
+
(-1, 1))
|
560 |
+
width = np.clip(height * self.comp_w_h_ratio, self.min_width,
|
561 |
+
self.max_width)
|
562 |
+
|
563 |
+
cos = cos_map[y, x].reshape((-1, 1))
|
564 |
+
sin = sin_map[y, x].reshape((-1, 1))
|
565 |
+
|
566 |
+
_, comp_label_mask = cv2.connectedComponents(
|
567 |
+
center_region_mask, connectivity=8)
|
568 |
+
comp_labels = comp_label_mask[y, x].reshape(
|
569 |
+
(-1, 1)).astype(np.float32)
|
570 |
+
|
571 |
+
x = x.reshape((-1, 1)).astype(np.float32)
|
572 |
+
y = y.reshape((-1, 1)).astype(np.float32)
|
573 |
+
comp_attribs = np.hstack(
|
574 |
+
[x, y, height, width, cos, sin, comp_labels])
|
575 |
+
comp_attribs = self.jitter_comp_attribs(comp_attribs,
|
576 |
+
self.jitter_level)
|
577 |
+
|
578 |
+
if comp_attribs.shape[0] < self.num_min_comps:
|
579 |
+
num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
|
580 |
+
rand_comp_attribs = self.generate_rand_comp_attribs(
|
581 |
+
num_rand_comps, 1 - text_mask)
|
582 |
+
comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
|
583 |
+
else:
|
584 |
+
comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
|
585 |
+
1 - text_mask)
|
586 |
+
|
587 |
+
num_comps = (np.ones(
|
588 |
+
(comp_attribs.shape[0], 1),
|
589 |
+
dtype=np.float32) * comp_attribs.shape[0])
|
590 |
+
comp_attribs = np.hstack([num_comps, comp_attribs])
|
591 |
+
|
592 |
+
if comp_attribs.shape[0] > self.num_max_comps:
|
593 |
+
comp_attribs = comp_attribs[:self.num_max_comps, :]
|
594 |
+
comp_attribs[:, 0] = self.num_max_comps
|
595 |
+
|
596 |
+
pad_comp_attribs = np.zeros(
|
597 |
+
(self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
|
598 |
+
pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
|
599 |
+
|
600 |
+
return pad_comp_attribs
|
601 |
+
|
602 |
+
def generate_text_region_mask(self, img_size, text_polys):
|
603 |
+
"""Generate text center region mask and geometry attribute maps.
|
604 |
+
|
605 |
+
Args:
|
606 |
+
img_size (tuple): The image size (height, width).
|
607 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
608 |
+
|
609 |
+
Returns:
|
610 |
+
text_region_mask (ndarray): The text region mask.
|
611 |
+
"""
|
612 |
+
|
613 |
+
assert isinstance(img_size, tuple)
|
614 |
+
|
615 |
+
h, w = img_size
|
616 |
+
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
617 |
+
|
618 |
+
for poly in text_polys:
|
619 |
+
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
|
620 |
+
cv2.fillPoly(text_region_mask, polygon, 1)
|
621 |
+
|
622 |
+
return text_region_mask
|
623 |
+
|
624 |
+
def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
|
625 |
+
"""Generate effective mask by setting the ineffective regions to 0 and
|
626 |
+
effective regions to 1.
|
627 |
+
|
628 |
+
Args:
|
629 |
+
mask_size (tuple): The mask size.
|
630 |
+
polygons_ignore (list[[ndarray]]: The list of ignored text
|
631 |
+
polygons.
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
mask (ndarray): The effective mask of (height, width).
|
635 |
+
"""
|
636 |
+
mask = np.ones(mask_size, dtype=np.uint8)
|
637 |
+
|
638 |
+
for poly in polygons_ignore:
|
639 |
+
instance = poly.astype(np.int32).reshape(1, -1, 2)
|
640 |
+
cv2.fillPoly(mask, instance, 0)
|
641 |
+
|
642 |
+
return mask
|
643 |
+
|
644 |
+
def generate_targets(self, data):
|
645 |
+
"""Generate the gt targets for DRRG.
|
646 |
+
|
647 |
+
Args:
|
648 |
+
data (dict): The input result dictionary.
|
649 |
+
|
650 |
+
Returns:
|
651 |
+
data (dict): The output result dictionary.
|
652 |
+
"""
|
653 |
+
|
654 |
+
assert isinstance(data, dict)
|
655 |
+
|
656 |
+
image = data['image']
|
657 |
+
polygons = data['polys']
|
658 |
+
ignore_tags = data['ignore_tags']
|
659 |
+
h, w, _ = image.shape
|
660 |
+
|
661 |
+
polygon_masks = []
|
662 |
+
polygon_masks_ignore = []
|
663 |
+
for tag, polygon in zip(ignore_tags, polygons):
|
664 |
+
if tag is True:
|
665 |
+
polygon_masks_ignore.append(polygon)
|
666 |
+
else:
|
667 |
+
polygon_masks.append(polygon)
|
668 |
+
|
669 |
+
gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
|
670 |
+
gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
|
671 |
+
(center_lines, gt_center_region_mask, gt_top_height_map,
|
672 |
+
gt_bot_height_map, gt_sin_map,
|
673 |
+
gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
|
674 |
+
polygon_masks)
|
675 |
+
|
676 |
+
gt_comp_attribs = self.generate_comp_attribs(
|
677 |
+
center_lines, gt_text_mask, gt_center_region_mask,
|
678 |
+
gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)
|
679 |
+
|
680 |
+
mapping = {
|
681 |
+
'gt_text_mask': gt_text_mask,
|
682 |
+
'gt_center_region_mask': gt_center_region_mask,
|
683 |
+
'gt_mask': gt_mask,
|
684 |
+
'gt_top_height_map': gt_top_height_map,
|
685 |
+
'gt_bot_height_map': gt_bot_height_map,
|
686 |
+
'gt_sin_map': gt_sin_map,
|
687 |
+
'gt_cos_map': gt_cos_map
|
688 |
+
}
|
689 |
+
|
690 |
+
data.update(mapping)
|
691 |
+
data['gt_comp_attribs'] = gt_comp_attribs
|
692 |
+
return data
|
693 |
+
|
694 |
+
def __call__(self, data):
|
695 |
+
data = self.generate_targets(data)
|
696 |
+
return data
|
ppocr/data/imaug/east_process.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
#Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
#you may not use this file except in compliance with the License.
|
5 |
+
#You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
#Unless required by applicable law or agreed to in writing, software
|
10 |
+
#distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
#See the License for the specific language governing permissions and
|
13 |
+
#limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refered from:
|
16 |
+
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import json
|
22 |
+
import sys
|
23 |
+
import os
|
24 |
+
|
25 |
+
__all__ = ['EASTProcessTrain']
|
26 |
+
|
27 |
+
|
28 |
+
class EASTProcessTrain(object):
|
29 |
+
def __init__(self,
|
30 |
+
image_shape=[512, 512],
|
31 |
+
background_ratio=0.125,
|
32 |
+
min_crop_side_ratio=0.1,
|
33 |
+
min_text_size=10,
|
34 |
+
**kwargs):
|
35 |
+
self.input_size = image_shape[1]
|
36 |
+
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
37 |
+
self.background_ratio = background_ratio
|
38 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
39 |
+
self.min_text_size = min_text_size
|
40 |
+
|
41 |
+
def preprocess(self, im):
|
42 |
+
input_size = self.input_size
|
43 |
+
im_shape = im.shape
|
44 |
+
im_size_min = np.min(im_shape[0:2])
|
45 |
+
im_size_max = np.max(im_shape[0:2])
|
46 |
+
im_scale = float(input_size) / float(im_size_max)
|
47 |
+
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
|
48 |
+
img_mean = [0.485, 0.456, 0.406]
|
49 |
+
img_std = [0.229, 0.224, 0.225]
|
50 |
+
# im = im[:, :, ::-1].astype(np.float32)
|
51 |
+
im = im / 255
|
52 |
+
im -= img_mean
|
53 |
+
im /= img_std
|
54 |
+
new_h, new_w, _ = im.shape
|
55 |
+
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
|
56 |
+
im_padded[:new_h, :new_w, :] = im
|
57 |
+
im_padded = im_padded.transpose((2, 0, 1))
|
58 |
+
im_padded = im_padded[np.newaxis, :]
|
59 |
+
return im_padded, im_scale
|
60 |
+
|
61 |
+
def rotate_im_poly(self, im, text_polys):
|
62 |
+
"""
|
63 |
+
rotate image with 90 / 180 / 270 degre
|
64 |
+
"""
|
65 |
+
im_w, im_h = im.shape[1], im.shape[0]
|
66 |
+
dst_im = im.copy()
|
67 |
+
dst_polys = []
|
68 |
+
rand_degree_ratio = np.random.rand()
|
69 |
+
rand_degree_cnt = 1
|
70 |
+
if 0.333 < rand_degree_ratio < 0.666:
|
71 |
+
rand_degree_cnt = 2
|
72 |
+
elif rand_degree_ratio > 0.666:
|
73 |
+
rand_degree_cnt = 3
|
74 |
+
for i in range(rand_degree_cnt):
|
75 |
+
dst_im = np.rot90(dst_im)
|
76 |
+
rot_degree = -90 * rand_degree_cnt
|
77 |
+
rot_angle = rot_degree * math.pi / 180.0
|
78 |
+
n_poly = text_polys.shape[0]
|
79 |
+
cx, cy = 0.5 * im_w, 0.5 * im_h
|
80 |
+
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
81 |
+
for i in range(n_poly):
|
82 |
+
wordBB = text_polys[i]
|
83 |
+
poly = []
|
84 |
+
for j in range(4):
|
85 |
+
sx, sy = wordBB[j][0], wordBB[j][1]
|
86 |
+
dx = math.cos(rot_angle) * (sx - cx)\
|
87 |
+
- math.sin(rot_angle) * (sy - cy) + ncx
|
88 |
+
dy = math.sin(rot_angle) * (sx - cx)\
|
89 |
+
+ math.cos(rot_angle) * (sy - cy) + ncy
|
90 |
+
poly.append([dx, dy])
|
91 |
+
dst_polys.append(poly)
|
92 |
+
dst_polys = np.array(dst_polys, dtype=np.float32)
|
93 |
+
return dst_im, dst_polys
|
94 |
+
|
95 |
+
def polygon_area(self, poly):
|
96 |
+
"""
|
97 |
+
compute area of a polygon
|
98 |
+
:param poly:
|
99 |
+
:return:
|
100 |
+
"""
|
101 |
+
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
102 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
103 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
104 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
105 |
+
return np.sum(edge) / 2.
|
106 |
+
|
107 |
+
def check_and_validate_polys(self, polys, tags, img_height, img_width):
|
108 |
+
"""
|
109 |
+
check so that the text poly is in the same direction,
|
110 |
+
and also filter some invalid polygons
|
111 |
+
:param polys:
|
112 |
+
:param tags:
|
113 |
+
:return:
|
114 |
+
"""
|
115 |
+
h, w = img_height, img_width
|
116 |
+
if polys.shape[0] == 0:
|
117 |
+
return polys
|
118 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
119 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
120 |
+
|
121 |
+
validated_polys = []
|
122 |
+
validated_tags = []
|
123 |
+
for poly, tag in zip(polys, tags):
|
124 |
+
p_area = self.polygon_area(poly)
|
125 |
+
#invalid poly
|
126 |
+
if abs(p_area) < 1:
|
127 |
+
continue
|
128 |
+
if p_area > 0:
|
129 |
+
#'poly in wrong direction'
|
130 |
+
if not tag:
|
131 |
+
tag = True #reversed cases should be ignore
|
132 |
+
poly = poly[(0, 3, 2, 1), :]
|
133 |
+
validated_polys.append(poly)
|
134 |
+
validated_tags.append(tag)
|
135 |
+
return np.array(validated_polys), np.array(validated_tags)
|
136 |
+
|
137 |
+
def draw_img_polys(self, img, polys):
|
138 |
+
if len(img.shape) == 4:
|
139 |
+
img = np.squeeze(img, axis=0)
|
140 |
+
if img.shape[0] == 3:
|
141 |
+
img = img.transpose((1, 2, 0))
|
142 |
+
img[:, :, 2] += 123.68
|
143 |
+
img[:, :, 1] += 116.78
|
144 |
+
img[:, :, 0] += 103.94
|
145 |
+
cv2.imwrite("tmp.jpg", img)
|
146 |
+
img = cv2.imread("tmp.jpg")
|
147 |
+
for box in polys:
|
148 |
+
box = box.astype(np.int32).reshape((-1, 1, 2))
|
149 |
+
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
|
150 |
+
import random
|
151 |
+
ino = random.randint(0, 100)
|
152 |
+
cv2.imwrite("tmp_%d.jpg" % ino, img)
|
153 |
+
return
|
154 |
+
|
155 |
+
def shrink_poly(self, poly, r):
|
156 |
+
"""
|
157 |
+
fit a poly inside the origin poly, maybe bugs here...
|
158 |
+
used for generate the score map
|
159 |
+
:param poly: the text poly
|
160 |
+
:param r: r in the paper
|
161 |
+
:return: the shrinked poly
|
162 |
+
"""
|
163 |
+
# shrink ratio
|
164 |
+
R = 0.3
|
165 |
+
# find the longer pair
|
166 |
+
dist0 = np.linalg.norm(poly[0] - poly[1])
|
167 |
+
dist1 = np.linalg.norm(poly[2] - poly[3])
|
168 |
+
dist2 = np.linalg.norm(poly[0] - poly[3])
|
169 |
+
dist3 = np.linalg.norm(poly[1] - poly[2])
|
170 |
+
if dist0 + dist1 > dist2 + dist3:
|
171 |
+
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
|
172 |
+
## p0, p1
|
173 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
174 |
+
(poly[1][0] - poly[0][0]))
|
175 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
176 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
177 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
178 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
179 |
+
## p2, p3
|
180 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
181 |
+
(poly[2][0] - poly[3][0]))
|
182 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
183 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
184 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
185 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
186 |
+
## p0, p3
|
187 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
188 |
+
(poly[3][1] - poly[0][1]))
|
189 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
190 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
191 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
192 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
193 |
+
## p1, p2
|
194 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
195 |
+
(poly[2][1] - poly[1][1]))
|
196 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
197 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
198 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
199 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
200 |
+
else:
|
201 |
+
## p0, p3
|
202 |
+
# print poly
|
203 |
+
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
204 |
+
(poly[3][1] - poly[0][1]))
|
205 |
+
poly[0][0] += R * r[0] * np.sin(theta)
|
206 |
+
poly[0][1] += R * r[0] * np.cos(theta)
|
207 |
+
poly[3][0] -= R * r[3] * np.sin(theta)
|
208 |
+
poly[3][1] -= R * r[3] * np.cos(theta)
|
209 |
+
## p1, p2
|
210 |
+
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
211 |
+
(poly[2][1] - poly[1][1]))
|
212 |
+
poly[1][0] += R * r[1] * np.sin(theta)
|
213 |
+
poly[1][1] += R * r[1] * np.cos(theta)
|
214 |
+
poly[2][0] -= R * r[2] * np.sin(theta)
|
215 |
+
poly[2][1] -= R * r[2] * np.cos(theta)
|
216 |
+
## p0, p1
|
217 |
+
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
218 |
+
(poly[1][0] - poly[0][0]))
|
219 |
+
poly[0][0] += R * r[0] * np.cos(theta)
|
220 |
+
poly[0][1] += R * r[0] * np.sin(theta)
|
221 |
+
poly[1][0] -= R * r[1] * np.cos(theta)
|
222 |
+
poly[1][1] -= R * r[1] * np.sin(theta)
|
223 |
+
## p2, p3
|
224 |
+
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
225 |
+
(poly[2][0] - poly[3][0]))
|
226 |
+
poly[3][0] += R * r[3] * np.cos(theta)
|
227 |
+
poly[3][1] += R * r[3] * np.sin(theta)
|
228 |
+
poly[2][0] -= R * r[2] * np.cos(theta)
|
229 |
+
poly[2][1] -= R * r[2] * np.sin(theta)
|
230 |
+
return poly
|
231 |
+
|
232 |
+
def generate_quad(self, im_size, polys, tags):
|
233 |
+
"""
|
234 |
+
Generate quadrangle.
|
235 |
+
"""
|
236 |
+
h, w = im_size
|
237 |
+
poly_mask = np.zeros((h, w), dtype=np.uint8)
|
238 |
+
score_map = np.zeros((h, w), dtype=np.uint8)
|
239 |
+
# (x1, y1, ..., x4, y4, short_edge_norm)
|
240 |
+
geo_map = np.zeros((h, w, 9), dtype=np.float32)
|
241 |
+
# mask used during traning, to ignore some hard areas
|
242 |
+
training_mask = np.ones((h, w), dtype=np.uint8)
|
243 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
244 |
+
poly = poly_tag[0]
|
245 |
+
tag = poly_tag[1]
|
246 |
+
|
247 |
+
r = [None, None, None, None]
|
248 |
+
for i in range(4):
|
249 |
+
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
|
250 |
+
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
|
251 |
+
r[i] = min(dist1, dist2)
|
252 |
+
# score map
|
253 |
+
shrinked_poly = self.shrink_poly(
|
254 |
+
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
|
255 |
+
cv2.fillPoly(score_map, shrinked_poly, 1)
|
256 |
+
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
|
257 |
+
# if the poly is too small, then ignore it during training
|
258 |
+
poly_h = min(
|
259 |
+
np.linalg.norm(poly[0] - poly[3]),
|
260 |
+
np.linalg.norm(poly[1] - poly[2]))
|
261 |
+
poly_w = min(
|
262 |
+
np.linalg.norm(poly[0] - poly[1]),
|
263 |
+
np.linalg.norm(poly[2] - poly[3]))
|
264 |
+
if min(poly_h, poly_w) < self.min_text_size:
|
265 |
+
cv2.fillPoly(training_mask,
|
266 |
+
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
267 |
+
|
268 |
+
if tag:
|
269 |
+
cv2.fillPoly(training_mask,
|
270 |
+
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
271 |
+
|
272 |
+
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
|
273 |
+
# geo map.
|
274 |
+
y_in_poly = xy_in_poly[:, 0]
|
275 |
+
x_in_poly = xy_in_poly[:, 1]
|
276 |
+
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
|
277 |
+
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
|
278 |
+
for pno in range(4):
|
279 |
+
geo_channel_beg = pno * 2
|
280 |
+
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
|
281 |
+
x_in_poly - poly[pno, 0]
|
282 |
+
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
|
283 |
+
y_in_poly - poly[pno, 1]
|
284 |
+
geo_map[y_in_poly, x_in_poly, 8] = \
|
285 |
+
1.0 / max(min(poly_h, poly_w), 1.0)
|
286 |
+
return score_map, geo_map, training_mask
|
287 |
+
|
288 |
+
def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
|
289 |
+
"""
|
290 |
+
make random crop from the input image
|
291 |
+
:param im:
|
292 |
+
:param polys:
|
293 |
+
:param tags:
|
294 |
+
:param crop_background:
|
295 |
+
:param max_tries:
|
296 |
+
:return:
|
297 |
+
"""
|
298 |
+
h, w, _ = im.shape
|
299 |
+
pad_h = h // 10
|
300 |
+
pad_w = w // 10
|
301 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
302 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
303 |
+
for poly in polys:
|
304 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
305 |
+
minx = np.min(poly[:, 0])
|
306 |
+
maxx = np.max(poly[:, 0])
|
307 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
308 |
+
miny = np.min(poly[:, 1])
|
309 |
+
maxy = np.max(poly[:, 1])
|
310 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
311 |
+
# ensure the cropped area not across a text
|
312 |
+
h_axis = np.where(h_array == 0)[0]
|
313 |
+
w_axis = np.where(w_array == 0)[0]
|
314 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
315 |
+
return im, polys, tags
|
316 |
+
|
317 |
+
for i in range(max_tries):
|
318 |
+
xx = np.random.choice(w_axis, size=2)
|
319 |
+
xmin = np.min(xx) - pad_w
|
320 |
+
xmax = np.max(xx) - pad_w
|
321 |
+
xmin = np.clip(xmin, 0, w - 1)
|
322 |
+
xmax = np.clip(xmax, 0, w - 1)
|
323 |
+
yy = np.random.choice(h_axis, size=2)
|
324 |
+
ymin = np.min(yy) - pad_h
|
325 |
+
ymax = np.max(yy) - pad_h
|
326 |
+
ymin = np.clip(ymin, 0, h - 1)
|
327 |
+
ymax = np.clip(ymax, 0, h - 1)
|
328 |
+
if xmax - xmin < self.min_crop_side_ratio * w or \
|
329 |
+
ymax - ymin < self.min_crop_side_ratio * h:
|
330 |
+
# area too small
|
331 |
+
continue
|
332 |
+
if polys.shape[0] != 0:
|
333 |
+
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
|
334 |
+
& (polys[:, :, 0] <= xmax)\
|
335 |
+
& (polys[:, :, 1] >= ymin)\
|
336 |
+
& (polys[:, :, 1] <= ymax)
|
337 |
+
selected_polys = np.where(
|
338 |
+
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
339 |
+
else:
|
340 |
+
selected_polys = []
|
341 |
+
|
342 |
+
if len(selected_polys) == 0:
|
343 |
+
# no text in this area
|
344 |
+
if crop_background:
|
345 |
+
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
346 |
+
polys = []
|
347 |
+
tags = []
|
348 |
+
return im, polys, tags
|
349 |
+
else:
|
350 |
+
continue
|
351 |
+
|
352 |
+
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
353 |
+
polys = polys[selected_polys]
|
354 |
+
tags = tags[selected_polys]
|
355 |
+
polys[:, :, 0] -= xmin
|
356 |
+
polys[:, :, 1] -= ymin
|
357 |
+
return im, polys, tags
|
358 |
+
return im, polys, tags
|
359 |
+
|
360 |
+
def crop_background_infor(self, im, text_polys, text_tags):
|
361 |
+
im, text_polys, text_tags = self.crop_area(
|
362 |
+
im, text_polys, text_tags, crop_background=True)
|
363 |
+
|
364 |
+
if len(text_polys) > 0:
|
365 |
+
return None
|
366 |
+
# pad and resize image
|
367 |
+
input_size = self.input_size
|
368 |
+
im, ratio = self.preprocess(im)
|
369 |
+
score_map = np.zeros((input_size, input_size), dtype=np.float32)
|
370 |
+
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
|
371 |
+
training_mask = np.ones((input_size, input_size), dtype=np.float32)
|
372 |
+
return im, score_map, geo_map, training_mask
|
373 |
+
|
374 |
+
def crop_foreground_infor(self, im, text_polys, text_tags):
|
375 |
+
im, text_polys, text_tags = self.crop_area(
|
376 |
+
im, text_polys, text_tags, crop_background=False)
|
377 |
+
|
378 |
+
if text_polys.shape[0] == 0:
|
379 |
+
return None
|
380 |
+
#continue for all ignore case
|
381 |
+
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
382 |
+
return None
|
383 |
+
# pad and resize image
|
384 |
+
input_size = self.input_size
|
385 |
+
im, ratio = self.preprocess(im)
|
386 |
+
text_polys[:, :, 0] *= ratio
|
387 |
+
text_polys[:, :, 1] *= ratio
|
388 |
+
_, _, new_h, new_w = im.shape
|
389 |
+
# print(im.shape)
|
390 |
+
# self.draw_img_polys(im, text_polys)
|
391 |
+
score_map, geo_map, training_mask = self.generate_quad(
|
392 |
+
(new_h, new_w), text_polys, text_tags)
|
393 |
+
return im, score_map, geo_map, training_mask
|
394 |
+
|
395 |
+
def __call__(self, data):
|
396 |
+
im = data['image']
|
397 |
+
text_polys = data['polys']
|
398 |
+
text_tags = data['ignore_tags']
|
399 |
+
if im is None:
|
400 |
+
return None
|
401 |
+
if text_polys.shape[0] == 0:
|
402 |
+
return None
|
403 |
+
|
404 |
+
#add rotate cases
|
405 |
+
if np.random.rand() < 0.5:
|
406 |
+
im, text_polys = self.rotate_im_poly(im, text_polys)
|
407 |
+
h, w, _ = im.shape
|
408 |
+
text_polys, text_tags = self.check_and_validate_polys(text_polys,
|
409 |
+
text_tags, h, w)
|
410 |
+
if text_polys.shape[0] == 0:
|
411 |
+
return None
|
412 |
+
|
413 |
+
# random scale this image
|
414 |
+
rd_scale = np.random.choice(self.random_scale)
|
415 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
416 |
+
text_polys *= rd_scale
|
417 |
+
if np.random.rand() < self.background_ratio:
|
418 |
+
outs = self.crop_background_infor(im, text_polys, text_tags)
|
419 |
+
else:
|
420 |
+
outs = self.crop_foreground_infor(im, text_polys, text_tags)
|
421 |
+
|
422 |
+
if outs is None:
|
423 |
+
return None
|
424 |
+
im, score_map, geo_map, training_mask = outs
|
425 |
+
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
|
426 |
+
geo_map = np.swapaxes(geo_map, 1, 2)
|
427 |
+
geo_map = np.swapaxes(geo_map, 1, 0)
|
428 |
+
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
|
429 |
+
training_mask = training_mask[np.newaxis, ::4, ::4]
|
430 |
+
training_mask = training_mask.astype(np.float32)
|
431 |
+
|
432 |
+
data['image'] = im[0]
|
433 |
+
data['score_map'] = score_map
|
434 |
+
data['geo_map'] = geo_map
|
435 |
+
data['training_mask'] = training_mask
|
436 |
+
return data
|
ppocr/data/imaug/fce_aug.py
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
|
17 |
+
"""
|
18 |
+
import numpy as np
|
19 |
+
from PIL import Image, ImageDraw
|
20 |
+
import cv2
|
21 |
+
from shapely.geometry import Polygon
|
22 |
+
import math
|
23 |
+
from ppocr.utils.poly_nms import poly_intersection
|
24 |
+
|
25 |
+
|
26 |
+
class RandomScaling:
|
27 |
+
def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
|
28 |
+
"""Random scale the image while keeping aspect.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
size (int) : Base size before scaling.
|
32 |
+
scale (tuple(float)) : The range of scaling.
|
33 |
+
"""
|
34 |
+
assert isinstance(size, int)
|
35 |
+
assert isinstance(scale, float) or isinstance(scale, tuple)
|
36 |
+
self.size = size
|
37 |
+
self.scale = scale if isinstance(scale, tuple) \
|
38 |
+
else (1 - scale, 1 + scale)
|
39 |
+
|
40 |
+
def __call__(self, data):
|
41 |
+
image = data['image']
|
42 |
+
text_polys = data['polys']
|
43 |
+
h, w, _ = image.shape
|
44 |
+
|
45 |
+
aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
|
46 |
+
scales = self.size * 1.0 / max(h, w) * aspect_ratio
|
47 |
+
scales = np.array([scales, scales])
|
48 |
+
out_size = (int(h * scales[1]), int(w * scales[0]))
|
49 |
+
image = cv2.resize(image, out_size[::-1])
|
50 |
+
|
51 |
+
data['image'] = image
|
52 |
+
text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
|
53 |
+
text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
|
54 |
+
data['polys'] = text_polys
|
55 |
+
|
56 |
+
return data
|
57 |
+
|
58 |
+
|
59 |
+
class RandomCropFlip:
|
60 |
+
def __init__(self,
|
61 |
+
pad_ratio=0.1,
|
62 |
+
crop_ratio=0.5,
|
63 |
+
iter_num=1,
|
64 |
+
min_area_ratio=0.2,
|
65 |
+
**kwargs):
|
66 |
+
"""Random crop and flip a patch of the image.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
crop_ratio (float): The ratio of cropping.
|
70 |
+
iter_num (int): Number of operations.
|
71 |
+
min_area_ratio (float): Minimal area ratio between cropped patch
|
72 |
+
and original image.
|
73 |
+
"""
|
74 |
+
assert isinstance(crop_ratio, float)
|
75 |
+
assert isinstance(iter_num, int)
|
76 |
+
assert isinstance(min_area_ratio, float)
|
77 |
+
|
78 |
+
self.pad_ratio = pad_ratio
|
79 |
+
self.epsilon = 1e-2
|
80 |
+
self.crop_ratio = crop_ratio
|
81 |
+
self.iter_num = iter_num
|
82 |
+
self.min_area_ratio = min_area_ratio
|
83 |
+
|
84 |
+
def __call__(self, results):
|
85 |
+
for i in range(self.iter_num):
|
86 |
+
results = self.random_crop_flip(results)
|
87 |
+
|
88 |
+
return results
|
89 |
+
|
90 |
+
def random_crop_flip(self, results):
|
91 |
+
image = results['image']
|
92 |
+
polygons = results['polys']
|
93 |
+
ignore_tags = results['ignore_tags']
|
94 |
+
if len(polygons) == 0:
|
95 |
+
return results
|
96 |
+
|
97 |
+
if np.random.random() >= self.crop_ratio:
|
98 |
+
return results
|
99 |
+
|
100 |
+
h, w, _ = image.shape
|
101 |
+
area = h * w
|
102 |
+
pad_h = int(h * self.pad_ratio)
|
103 |
+
pad_w = int(w * self.pad_ratio)
|
104 |
+
h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
|
105 |
+
pad_w)
|
106 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
107 |
+
return results
|
108 |
+
|
109 |
+
attempt = 0
|
110 |
+
while attempt < 50:
|
111 |
+
attempt += 1
|
112 |
+
polys_keep = []
|
113 |
+
polys_new = []
|
114 |
+
ignore_tags_keep = []
|
115 |
+
ignore_tags_new = []
|
116 |
+
xx = np.random.choice(w_axis, size=2)
|
117 |
+
xmin = np.min(xx) - pad_w
|
118 |
+
xmax = np.max(xx) - pad_w
|
119 |
+
xmin = np.clip(xmin, 0, w - 1)
|
120 |
+
xmax = np.clip(xmax, 0, w - 1)
|
121 |
+
yy = np.random.choice(h_axis, size=2)
|
122 |
+
ymin = np.min(yy) - pad_h
|
123 |
+
ymax = np.max(yy) - pad_h
|
124 |
+
ymin = np.clip(ymin, 0, h - 1)
|
125 |
+
ymax = np.clip(ymax, 0, h - 1)
|
126 |
+
if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
|
127 |
+
# area too small
|
128 |
+
continue
|
129 |
+
|
130 |
+
pts = np.stack([[xmin, xmax, xmax, xmin],
|
131 |
+
[ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
132 |
+
pp = Polygon(pts)
|
133 |
+
fail_flag = False
|
134 |
+
for polygon, ignore_tag in zip(polygons, ignore_tags):
|
135 |
+
ppi = Polygon(polygon.reshape(-1, 2))
|
136 |
+
ppiou, _ = poly_intersection(ppi, pp, buffer=0)
|
137 |
+
if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
|
138 |
+
np.abs(ppiou) > self.epsilon:
|
139 |
+
fail_flag = True
|
140 |
+
break
|
141 |
+
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
142 |
+
polys_new.append(polygon)
|
143 |
+
ignore_tags_new.append(ignore_tag)
|
144 |
+
else:
|
145 |
+
polys_keep.append(polygon)
|
146 |
+
ignore_tags_keep.append(ignore_tag)
|
147 |
+
|
148 |
+
if fail_flag:
|
149 |
+
continue
|
150 |
+
else:
|
151 |
+
break
|
152 |
+
|
153 |
+
cropped = image[ymin:ymax, xmin:xmax, :]
|
154 |
+
select_type = np.random.randint(3)
|
155 |
+
if select_type == 0:
|
156 |
+
img = np.ascontiguousarray(cropped[:, ::-1])
|
157 |
+
elif select_type == 1:
|
158 |
+
img = np.ascontiguousarray(cropped[::-1, :])
|
159 |
+
else:
|
160 |
+
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
161 |
+
image[ymin:ymax, xmin:xmax, :] = img
|
162 |
+
results['img'] = image
|
163 |
+
|
164 |
+
if len(polys_new) != 0:
|
165 |
+
height, width, _ = cropped.shape
|
166 |
+
if select_type == 0:
|
167 |
+
for idx, polygon in enumerate(polys_new):
|
168 |
+
poly = polygon.reshape(-1, 2)
|
169 |
+
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
170 |
+
polys_new[idx] = poly
|
171 |
+
elif select_type == 1:
|
172 |
+
for idx, polygon in enumerate(polys_new):
|
173 |
+
poly = polygon.reshape(-1, 2)
|
174 |
+
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
175 |
+
polys_new[idx] = poly
|
176 |
+
else:
|
177 |
+
for idx, polygon in enumerate(polys_new):
|
178 |
+
poly = polygon.reshape(-1, 2)
|
179 |
+
poly[:, 0] = width - poly[:, 0] + 2 * xmin
|
180 |
+
poly[:, 1] = height - poly[:, 1] + 2 * ymin
|
181 |
+
polys_new[idx] = poly
|
182 |
+
polygons = polys_keep + polys_new
|
183 |
+
ignore_tags = ignore_tags_keep + ignore_tags_new
|
184 |
+
results['polys'] = np.array(polygons)
|
185 |
+
results['ignore_tags'] = ignore_tags
|
186 |
+
|
187 |
+
return results
|
188 |
+
|
189 |
+
def generate_crop_target(self, image, all_polys, pad_h, pad_w):
|
190 |
+
"""Generate crop target and make sure not to crop the polygon
|
191 |
+
instances.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
image (ndarray): The image waited to be crop.
|
195 |
+
all_polys (list[list[ndarray]]): All polygons including ground
|
196 |
+
truth polygons and ground truth ignored polygons.
|
197 |
+
pad_h (int): Padding length of height.
|
198 |
+
pad_w (int): Padding length of width.
|
199 |
+
Returns:
|
200 |
+
h_axis (ndarray): Vertical cropping range.
|
201 |
+
w_axis (ndarray): Horizontal cropping range.
|
202 |
+
"""
|
203 |
+
h, w, _ = image.shape
|
204 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
205 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
206 |
+
|
207 |
+
text_polys = []
|
208 |
+
for polygon in all_polys:
|
209 |
+
rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
|
210 |
+
box = cv2.boxPoints(rect)
|
211 |
+
box = np.int0(box)
|
212 |
+
text_polys.append([box[0], box[1], box[2], box[3]])
|
213 |
+
|
214 |
+
polys = np.array(text_polys, dtype=np.int32)
|
215 |
+
for poly in polys:
|
216 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
217 |
+
minx = np.min(poly[:, 0])
|
218 |
+
maxx = np.max(poly[:, 0])
|
219 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
220 |
+
miny = np.min(poly[:, 1])
|
221 |
+
maxy = np.max(poly[:, 1])
|
222 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
223 |
+
|
224 |
+
h_axis = np.where(h_array == 0)[0]
|
225 |
+
w_axis = np.where(w_array == 0)[0]
|
226 |
+
return h_axis, w_axis
|
227 |
+
|
228 |
+
|
229 |
+
class RandomCropPolyInstances:
|
230 |
+
"""Randomly crop images and make sure to contain at least one intact
|
231 |
+
instance."""
|
232 |
+
|
233 |
+
def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
|
234 |
+
super().__init__()
|
235 |
+
self.crop_ratio = crop_ratio
|
236 |
+
self.min_side_ratio = min_side_ratio
|
237 |
+
|
238 |
+
def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
|
239 |
+
|
240 |
+
assert isinstance(min_len, int)
|
241 |
+
assert len(valid_array) > min_len
|
242 |
+
|
243 |
+
start_array = valid_array.copy()
|
244 |
+
max_start = min(len(start_array) - min_len, max_start)
|
245 |
+
start_array[max_start:] = 0
|
246 |
+
start_array[0] = 1
|
247 |
+
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
|
248 |
+
region_starts = np.where(diff_array < 0)[0]
|
249 |
+
region_ends = np.where(diff_array > 0)[0]
|
250 |
+
region_ind = np.random.randint(0, len(region_starts))
|
251 |
+
start = np.random.randint(region_starts[region_ind],
|
252 |
+
region_ends[region_ind])
|
253 |
+
|
254 |
+
end_array = valid_array.copy()
|
255 |
+
min_end = max(start + min_len, min_end)
|
256 |
+
end_array[:min_end] = 0
|
257 |
+
end_array[-1] = 1
|
258 |
+
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
|
259 |
+
region_starts = np.where(diff_array < 0)[0]
|
260 |
+
region_ends = np.where(diff_array > 0)[0]
|
261 |
+
region_ind = np.random.randint(0, len(region_starts))
|
262 |
+
end = np.random.randint(region_starts[region_ind],
|
263 |
+
region_ends[region_ind])
|
264 |
+
return start, end
|
265 |
+
|
266 |
+
def sample_crop_box(self, img_size, results):
|
267 |
+
"""Generate crop box and make sure not to crop the polygon instances.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
img_size (tuple(int)): The image size (h, w).
|
271 |
+
results (dict): The results dict.
|
272 |
+
"""
|
273 |
+
|
274 |
+
assert isinstance(img_size, tuple)
|
275 |
+
h, w = img_size[:2]
|
276 |
+
|
277 |
+
key_masks = results['polys']
|
278 |
+
|
279 |
+
x_valid_array = np.ones(w, dtype=np.int32)
|
280 |
+
y_valid_array = np.ones(h, dtype=np.int32)
|
281 |
+
|
282 |
+
selected_mask = key_masks[np.random.randint(0, len(key_masks))]
|
283 |
+
selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
|
284 |
+
max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
|
285 |
+
min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
|
286 |
+
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
|
287 |
+
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
|
288 |
+
|
289 |
+
for mask in key_masks:
|
290 |
+
mask = mask.reshape((-1, 2)).astype(np.int32)
|
291 |
+
clip_x = np.clip(mask[:, 0], 0, w - 1)
|
292 |
+
clip_y = np.clip(mask[:, 1], 0, h - 1)
|
293 |
+
min_x, max_x = np.min(clip_x), np.max(clip_x)
|
294 |
+
min_y, max_y = np.min(clip_y), np.max(clip_y)
|
295 |
+
|
296 |
+
x_valid_array[min_x - 2:max_x + 3] = 0
|
297 |
+
y_valid_array[min_y - 2:max_y + 3] = 0
|
298 |
+
|
299 |
+
min_w = int(w * self.min_side_ratio)
|
300 |
+
min_h = int(h * self.min_side_ratio)
|
301 |
+
|
302 |
+
x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
|
303 |
+
min_x_end)
|
304 |
+
y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
|
305 |
+
min_y_end)
|
306 |
+
|
307 |
+
return np.array([x1, y1, x2, y2])
|
308 |
+
|
309 |
+
def crop_img(self, img, bbox):
|
310 |
+
assert img.ndim == 3
|
311 |
+
h, w, _ = img.shape
|
312 |
+
assert 0 <= bbox[1] < bbox[3] <= h
|
313 |
+
assert 0 <= bbox[0] < bbox[2] <= w
|
314 |
+
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
|
315 |
+
|
316 |
+
def __call__(self, results):
|
317 |
+
image = results['image']
|
318 |
+
polygons = results['polys']
|
319 |
+
ignore_tags = results['ignore_tags']
|
320 |
+
if len(polygons) < 1:
|
321 |
+
return results
|
322 |
+
|
323 |
+
if np.random.random_sample() < self.crop_ratio:
|
324 |
+
|
325 |
+
crop_box = self.sample_crop_box(image.shape, results)
|
326 |
+
img = self.crop_img(image, crop_box)
|
327 |
+
results['image'] = img
|
328 |
+
# crop and filter masks
|
329 |
+
x1, y1, x2, y2 = crop_box
|
330 |
+
w = max(x2 - x1, 1)
|
331 |
+
h = max(y2 - y1, 1)
|
332 |
+
polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
|
333 |
+
polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
|
334 |
+
|
335 |
+
valid_masks_list = []
|
336 |
+
valid_tags_list = []
|
337 |
+
for ind, polygon in enumerate(polygons):
|
338 |
+
if (polygon[:, ::2] > -4).all() and (
|
339 |
+
polygon[:, ::2] < w + 4).all() and (
|
340 |
+
polygon[:, 1::2] > -4).all() and (
|
341 |
+
polygon[:, 1::2] < h + 4).all():
|
342 |
+
polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
|
343 |
+
polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
|
344 |
+
valid_masks_list.append(polygon)
|
345 |
+
valid_tags_list.append(ignore_tags[ind])
|
346 |
+
|
347 |
+
results['polys'] = np.array(valid_masks_list)
|
348 |
+
results['ignore_tags'] = valid_tags_list
|
349 |
+
|
350 |
+
return results
|
351 |
+
|
352 |
+
def __repr__(self):
|
353 |
+
repr_str = self.__class__.__name__
|
354 |
+
return repr_str
|
355 |
+
|
356 |
+
|
357 |
+
class RandomRotatePolyInstances:
|
358 |
+
def __init__(self,
|
359 |
+
rotate_ratio=0.5,
|
360 |
+
max_angle=10,
|
361 |
+
pad_with_fixed_color=False,
|
362 |
+
pad_value=(0, 0, 0),
|
363 |
+
**kwargs):
|
364 |
+
"""Randomly rotate images and polygon masks.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
rotate_ratio (float): The ratio of samples to operate rotation.
|
368 |
+
max_angle (int): The maximum rotation angle.
|
369 |
+
pad_with_fixed_color (bool): The flag for whether to pad rotated
|
370 |
+
image with fixed value. If set to False, the rotated image will
|
371 |
+
be padded onto cropped image.
|
372 |
+
pad_value (tuple(int)): The color value for padding rotated image.
|
373 |
+
"""
|
374 |
+
self.rotate_ratio = rotate_ratio
|
375 |
+
self.max_angle = max_angle
|
376 |
+
self.pad_with_fixed_color = pad_with_fixed_color
|
377 |
+
self.pad_value = pad_value
|
378 |
+
|
379 |
+
def rotate(self, center, points, theta, center_shift=(0, 0)):
|
380 |
+
# rotate points.
|
381 |
+
(center_x, center_y) = center
|
382 |
+
center_y = -center_y
|
383 |
+
x, y = points[:, ::2], points[:, 1::2]
|
384 |
+
y = -y
|
385 |
+
|
386 |
+
theta = theta / 180 * math.pi
|
387 |
+
cos = math.cos(theta)
|
388 |
+
sin = math.sin(theta)
|
389 |
+
|
390 |
+
x = (x - center_x)
|
391 |
+
y = (y - center_y)
|
392 |
+
|
393 |
+
_x = center_x + x * cos - y * sin + center_shift[0]
|
394 |
+
_y = -(center_y + x * sin + y * cos) + center_shift[1]
|
395 |
+
|
396 |
+
points[:, ::2], points[:, 1::2] = _x, _y
|
397 |
+
return points
|
398 |
+
|
399 |
+
def cal_canvas_size(self, ori_size, degree):
|
400 |
+
assert isinstance(ori_size, tuple)
|
401 |
+
angle = degree * math.pi / 180.0
|
402 |
+
h, w = ori_size[:2]
|
403 |
+
|
404 |
+
cos = math.cos(angle)
|
405 |
+
sin = math.sin(angle)
|
406 |
+
canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
|
407 |
+
canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
|
408 |
+
|
409 |
+
canvas_size = (canvas_h, canvas_w)
|
410 |
+
return canvas_size
|
411 |
+
|
412 |
+
def sample_angle(self, max_angle):
|
413 |
+
angle = np.random.random_sample() * 2 * max_angle - max_angle
|
414 |
+
return angle
|
415 |
+
|
416 |
+
def rotate_img(self, img, angle, canvas_size):
|
417 |
+
h, w = img.shape[:2]
|
418 |
+
rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
|
419 |
+
rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
|
420 |
+
rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
|
421 |
+
|
422 |
+
if self.pad_with_fixed_color:
|
423 |
+
target_img = cv2.warpAffine(
|
424 |
+
img,
|
425 |
+
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
426 |
+
flags=cv2.INTER_NEAREST,
|
427 |
+
borderValue=self.pad_value)
|
428 |
+
else:
|
429 |
+
mask = np.zeros_like(img)
|
430 |
+
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
|
431 |
+
np.random.randint(0, w * 7 // 8))
|
432 |
+
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
|
433 |
+
img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
|
434 |
+
|
435 |
+
mask = cv2.warpAffine(
|
436 |
+
mask,
|
437 |
+
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
438 |
+
borderValue=[1, 1, 1])
|
439 |
+
target_img = cv2.warpAffine(
|
440 |
+
img,
|
441 |
+
rotation_matrix, (canvas_size[1], canvas_size[0]),
|
442 |
+
borderValue=[0, 0, 0])
|
443 |
+
target_img = target_img + img_cut * mask
|
444 |
+
|
445 |
+
return target_img
|
446 |
+
|
447 |
+
def __call__(self, results):
|
448 |
+
if np.random.random_sample() < self.rotate_ratio:
|
449 |
+
image = results['image']
|
450 |
+
polygons = results['polys']
|
451 |
+
h, w = image.shape[:2]
|
452 |
+
|
453 |
+
angle = self.sample_angle(self.max_angle)
|
454 |
+
canvas_size = self.cal_canvas_size((h, w), angle)
|
455 |
+
center_shift = (int((canvas_size[1] - w) / 2), int(
|
456 |
+
(canvas_size[0] - h) / 2))
|
457 |
+
image = self.rotate_img(image, angle, canvas_size)
|
458 |
+
results['image'] = image
|
459 |
+
# rotate polygons
|
460 |
+
rotated_masks = []
|
461 |
+
for mask in polygons:
|
462 |
+
rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
|
463 |
+
center_shift)
|
464 |
+
rotated_masks.append(rotated_mask)
|
465 |
+
results['polys'] = np.array(rotated_masks)
|
466 |
+
|
467 |
+
return results
|
468 |
+
|
469 |
+
def __repr__(self):
|
470 |
+
repr_str = self.__class__.__name__
|
471 |
+
return repr_str
|
472 |
+
|
473 |
+
|
474 |
+
class SquareResizePad:
|
475 |
+
def __init__(self,
|
476 |
+
target_size,
|
477 |
+
pad_ratio=0.6,
|
478 |
+
pad_with_fixed_color=False,
|
479 |
+
pad_value=(0, 0, 0),
|
480 |
+
**kwargs):
|
481 |
+
"""Resize or pad images to be square shape.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
target_size (int): The target size of square shaped image.
|
485 |
+
pad_with_fixed_color (bool): The flag for whether to pad rotated
|
486 |
+
image with fixed value. If set to False, the rescales image will
|
487 |
+
be padded onto cropped image.
|
488 |
+
pad_value (tuple(int)): The color value for padding rotated image.
|
489 |
+
"""
|
490 |
+
assert isinstance(target_size, int)
|
491 |
+
assert isinstance(pad_ratio, float)
|
492 |
+
assert isinstance(pad_with_fixed_color, bool)
|
493 |
+
assert isinstance(pad_value, tuple)
|
494 |
+
|
495 |
+
self.target_size = target_size
|
496 |
+
self.pad_ratio = pad_ratio
|
497 |
+
self.pad_with_fixed_color = pad_with_fixed_color
|
498 |
+
self.pad_value = pad_value
|
499 |
+
|
500 |
+
def resize_img(self, img, keep_ratio=True):
|
501 |
+
h, w, _ = img.shape
|
502 |
+
if keep_ratio:
|
503 |
+
t_h = self.target_size if h >= w else int(h * self.target_size / w)
|
504 |
+
t_w = self.target_size if h <= w else int(w * self.target_size / h)
|
505 |
+
else:
|
506 |
+
t_h = t_w = self.target_size
|
507 |
+
img = cv2.resize(img, (t_w, t_h))
|
508 |
+
return img, (t_h, t_w)
|
509 |
+
|
510 |
+
def square_pad(self, img):
|
511 |
+
h, w = img.shape[:2]
|
512 |
+
if h == w:
|
513 |
+
return img, (0, 0)
|
514 |
+
pad_size = max(h, w)
|
515 |
+
if self.pad_with_fixed_color:
|
516 |
+
expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
|
517 |
+
expand_img[:] = self.pad_value
|
518 |
+
else:
|
519 |
+
(h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
|
520 |
+
np.random.randint(0, w * 7 // 8))
|
521 |
+
img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
|
522 |
+
expand_img = cv2.resize(img_cut, (pad_size, pad_size))
|
523 |
+
if h > w:
|
524 |
+
y0, x0 = 0, (h - w) // 2
|
525 |
+
else:
|
526 |
+
y0, x0 = (w - h) // 2, 0
|
527 |
+
expand_img[y0:y0 + h, x0:x0 + w] = img
|
528 |
+
offset = (x0, y0)
|
529 |
+
|
530 |
+
return expand_img, offset
|
531 |
+
|
532 |
+
def square_pad_mask(self, points, offset):
|
533 |
+
x0, y0 = offset
|
534 |
+
pad_points = points.copy()
|
535 |
+
pad_points[::2] = pad_points[::2] + x0
|
536 |
+
pad_points[1::2] = pad_points[1::2] + y0
|
537 |
+
return pad_points
|
538 |
+
|
539 |
+
def __call__(self, results):
|
540 |
+
image = results['image']
|
541 |
+
polygons = results['polys']
|
542 |
+
h, w = image.shape[:2]
|
543 |
+
|
544 |
+
if np.random.random_sample() < self.pad_ratio:
|
545 |
+
image, out_size = self.resize_img(image, keep_ratio=True)
|
546 |
+
image, offset = self.square_pad(image)
|
547 |
+
else:
|
548 |
+
image, out_size = self.resize_img(image, keep_ratio=False)
|
549 |
+
offset = (0, 0)
|
550 |
+
results['image'] = image
|
551 |
+
try:
|
552 |
+
polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
|
553 |
+
1] / w + offset[0]
|
554 |
+
polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
|
555 |
+
0] / h + offset[1]
|
556 |
+
except:
|
557 |
+
pass
|
558 |
+
results['polys'] = polygons
|
559 |
+
|
560 |
+
return results
|
561 |
+
|
562 |
+
def __repr__(self):
|
563 |
+
repr_str = self.__class__.__name__
|
564 |
+
return repr_str
|
ppocr/data/imaug/fce_targets.py
ADDED
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
from numpy.fft import fft
|
22 |
+
from numpy.linalg import norm
|
23 |
+
import sys
|
24 |
+
|
25 |
+
def vector_slope(vec):
|
26 |
+
assert len(vec) == 2
|
27 |
+
return abs(vec[1] / (vec[0] + 1e-8))
|
28 |
+
|
29 |
+
class FCENetTargets:
|
30 |
+
"""Generate the ground truth targets of FCENet: Fourier Contour Embedding
|
31 |
+
for Arbitrary-Shaped Text Detection.
|
32 |
+
|
33 |
+
[https://arxiv.org/abs/2104.10442]
|
34 |
+
|
35 |
+
Args:
|
36 |
+
fourier_degree (int): The maximum Fourier transform degree k.
|
37 |
+
resample_step (float): The step size for resampling the text center
|
38 |
+
line (TCL). It's better not to exceed half of the minimum width.
|
39 |
+
center_region_shrink_ratio (float): The shrink ratio of text center
|
40 |
+
region.
|
41 |
+
level_size_divisors (tuple(int)): The downsample ratio on each level.
|
42 |
+
level_proportion_range (tuple(tuple(int))): The range of text sizes
|
43 |
+
assigned to each level.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self,
|
47 |
+
fourier_degree=5,
|
48 |
+
resample_step=4.0,
|
49 |
+
center_region_shrink_ratio=0.3,
|
50 |
+
level_size_divisors=(8, 16, 32),
|
51 |
+
level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
|
52 |
+
orientation_thr=2.0,
|
53 |
+
**kwargs):
|
54 |
+
|
55 |
+
super().__init__()
|
56 |
+
assert isinstance(level_size_divisors, tuple)
|
57 |
+
assert isinstance(level_proportion_range, tuple)
|
58 |
+
assert len(level_size_divisors) == len(level_proportion_range)
|
59 |
+
self.fourier_degree = fourier_degree
|
60 |
+
self.resample_step = resample_step
|
61 |
+
self.center_region_shrink_ratio = center_region_shrink_ratio
|
62 |
+
self.level_size_divisors = level_size_divisors
|
63 |
+
self.level_proportion_range = level_proportion_range
|
64 |
+
|
65 |
+
self.orientation_thr = orientation_thr
|
66 |
+
|
67 |
+
def vector_angle(self, vec1, vec2):
|
68 |
+
if vec1.ndim > 1:
|
69 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
|
70 |
+
else:
|
71 |
+
unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
|
72 |
+
if vec2.ndim > 1:
|
73 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
|
74 |
+
else:
|
75 |
+
unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
|
76 |
+
return np.arccos(
|
77 |
+
np.clip(
|
78 |
+
np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
|
79 |
+
|
80 |
+
def resample_line(self, line, n):
|
81 |
+
"""Resample n points on a line.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
line (ndarray): The points composing a line.
|
85 |
+
n (int): The resampled points number.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
resampled_line (ndarray): The points composing the resampled line.
|
89 |
+
"""
|
90 |
+
|
91 |
+
assert line.ndim == 2
|
92 |
+
assert line.shape[0] >= 2
|
93 |
+
assert line.shape[1] == 2
|
94 |
+
assert isinstance(n, int)
|
95 |
+
assert n > 0
|
96 |
+
|
97 |
+
length_list = [
|
98 |
+
norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
|
99 |
+
]
|
100 |
+
total_length = sum(length_list)
|
101 |
+
length_cumsum = np.cumsum([0.0] + length_list)
|
102 |
+
delta_length = total_length / (float(n) + 1e-8)
|
103 |
+
|
104 |
+
current_edge_ind = 0
|
105 |
+
resampled_line = [line[0]]
|
106 |
+
|
107 |
+
for i in range(1, n):
|
108 |
+
current_line_len = i * delta_length
|
109 |
+
|
110 |
+
while current_edge_ind + 1 < len(length_cumsum) and current_line_len >= length_cumsum[current_edge_ind + 1]:
|
111 |
+
current_edge_ind += 1
|
112 |
+
|
113 |
+
current_edge_end_shift = current_line_len - length_cumsum[
|
114 |
+
current_edge_ind]
|
115 |
+
|
116 |
+
if current_edge_ind >= len(length_list):
|
117 |
+
break
|
118 |
+
end_shift_ratio = current_edge_end_shift / length_list[
|
119 |
+
current_edge_ind]
|
120 |
+
current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
|
121 |
+
- line[current_edge_ind]
|
122 |
+
) * end_shift_ratio
|
123 |
+
resampled_line.append(current_point)
|
124 |
+
resampled_line.append(line[-1])
|
125 |
+
resampled_line = np.array(resampled_line)
|
126 |
+
|
127 |
+
return resampled_line
|
128 |
+
|
129 |
+
def reorder_poly_edge(self, points):
|
130 |
+
"""Get the respective points composing head edge, tail edge, top
|
131 |
+
sideline and bottom sideline.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
points (ndarray): The points composing a text polygon.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
head_edge (ndarray): The two points composing the head edge of text
|
138 |
+
polygon.
|
139 |
+
tail_edge (ndarray): The two points composing the tail edge of text
|
140 |
+
polygon.
|
141 |
+
top_sideline (ndarray): The points composing top curved sideline of
|
142 |
+
text polygon.
|
143 |
+
bot_sideline (ndarray): The points composing bottom curved sideline
|
144 |
+
of text polygon.
|
145 |
+
"""
|
146 |
+
|
147 |
+
assert points.ndim == 2
|
148 |
+
assert points.shape[0] >= 4
|
149 |
+
assert points.shape[1] == 2
|
150 |
+
|
151 |
+
head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
|
152 |
+
head_edge, tail_edge = points[head_inds], points[tail_inds]
|
153 |
+
|
154 |
+
pad_points = np.vstack([points, points])
|
155 |
+
if tail_inds[1] < 1:
|
156 |
+
tail_inds[1] = len(points)
|
157 |
+
sideline1 = pad_points[head_inds[1]:tail_inds[1]]
|
158 |
+
sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
|
159 |
+
sideline_mean_shift = np.mean(
|
160 |
+
sideline1, axis=0) - np.mean(
|
161 |
+
sideline2, axis=0)
|
162 |
+
|
163 |
+
if sideline_mean_shift[1] > 0:
|
164 |
+
top_sideline, bot_sideline = sideline2, sideline1
|
165 |
+
else:
|
166 |
+
top_sideline, bot_sideline = sideline1, sideline2
|
167 |
+
|
168 |
+
return head_edge, tail_edge, top_sideline, bot_sideline
|
169 |
+
|
170 |
+
def find_head_tail(self, points, orientation_thr):
|
171 |
+
"""Find the head edge and tail edge of a text polygon.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
points (ndarray): The points composing a text polygon.
|
175 |
+
orientation_thr (float): The threshold for distinguishing between
|
176 |
+
head edge and tail edge among the horizontal and vertical edges
|
177 |
+
of a quadrangle.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
head_inds (list): The indexes of two points composing head edge.
|
181 |
+
tail_inds (list): The indexes of two points composing tail edge.
|
182 |
+
"""
|
183 |
+
|
184 |
+
assert points.ndim == 2
|
185 |
+
assert points.shape[0] >= 4
|
186 |
+
assert points.shape[1] == 2
|
187 |
+
assert isinstance(orientation_thr, float)
|
188 |
+
|
189 |
+
if len(points) > 4:
|
190 |
+
pad_points = np.vstack([points, points[0]])
|
191 |
+
edge_vec = pad_points[1:] - pad_points[:-1]
|
192 |
+
|
193 |
+
theta_sum = []
|
194 |
+
adjacent_vec_theta = []
|
195 |
+
for i, edge_vec1 in enumerate(edge_vec):
|
196 |
+
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
197 |
+
adjacent_edge_vec = edge_vec[adjacent_ind]
|
198 |
+
temp_theta_sum = np.sum(
|
199 |
+
self.vector_angle(edge_vec1, adjacent_edge_vec))
|
200 |
+
temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
|
201 |
+
adjacent_edge_vec[1])
|
202 |
+
theta_sum.append(temp_theta_sum)
|
203 |
+
adjacent_vec_theta.append(temp_adjacent_theta)
|
204 |
+
theta_sum_score = np.array(theta_sum) / np.pi
|
205 |
+
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
|
206 |
+
poly_center = np.mean(points, axis=0)
|
207 |
+
edge_dist = np.maximum(
|
208 |
+
norm(
|
209 |
+
pad_points[1:] - poly_center, axis=-1),
|
210 |
+
norm(
|
211 |
+
pad_points[:-1] - poly_center, axis=-1))
|
212 |
+
dist_score = edge_dist / np.max(edge_dist)
|
213 |
+
position_score = np.zeros(len(edge_vec))
|
214 |
+
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
|
215 |
+
score += 0.35 * dist_score
|
216 |
+
if len(points) % 2 == 0:
|
217 |
+
position_score[(len(score) // 2 - 1)] += 1
|
218 |
+
position_score[-1] += 1
|
219 |
+
score += 0.1 * position_score
|
220 |
+
pad_score = np.concatenate([score, score])
|
221 |
+
score_matrix = np.zeros((len(score), len(score) - 3))
|
222 |
+
x = np.arange(len(score) - 3) / float(len(score) - 4)
|
223 |
+
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
|
224 |
+
(x - 0.5) / 0.5, 2.) / 2)
|
225 |
+
gaussian = gaussian / np.max(gaussian)
|
226 |
+
for i in range(len(score)):
|
227 |
+
score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
|
228 |
+
score) - 1)] * gaussian * 0.3
|
229 |
+
|
230 |
+
head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
|
231 |
+
score_matrix.shape)
|
232 |
+
tail_start = (head_start + tail_increment + 2) % len(points)
|
233 |
+
head_end = (head_start + 1) % len(points)
|
234 |
+
tail_end = (tail_start + 1) % len(points)
|
235 |
+
|
236 |
+
if head_end > tail_end:
|
237 |
+
head_start, tail_start = tail_start, head_start
|
238 |
+
head_end, tail_end = tail_end, head_end
|
239 |
+
head_inds = [head_start, head_end]
|
240 |
+
tail_inds = [tail_start, tail_end]
|
241 |
+
else:
|
242 |
+
if vector_slope(points[1] - points[0]) + vector_slope(
|
243 |
+
points[3] - points[2]) < vector_slope(points[
|
244 |
+
2] - points[1]) + vector_slope(points[0] - points[
|
245 |
+
3]):
|
246 |
+
horizontal_edge_inds = [[0, 1], [2, 3]]
|
247 |
+
vertical_edge_inds = [[3, 0], [1, 2]]
|
248 |
+
else:
|
249 |
+
horizontal_edge_inds = [[3, 0], [1, 2]]
|
250 |
+
vertical_edge_inds = [[0, 1], [2, 3]]
|
251 |
+
|
252 |
+
vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
|
253 |
+
vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
|
254 |
+
0]] - points[vertical_edge_inds[1][1]])
|
255 |
+
horizontal_len_sum = norm(points[horizontal_edge_inds[0][
|
256 |
+
0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
|
257 |
+
horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
|
258 |
+
[1]])
|
259 |
+
|
260 |
+
if vertical_len_sum > horizontal_len_sum * orientation_thr:
|
261 |
+
head_inds = horizontal_edge_inds[0]
|
262 |
+
tail_inds = horizontal_edge_inds[1]
|
263 |
+
else:
|
264 |
+
head_inds = vertical_edge_inds[0]
|
265 |
+
tail_inds = vertical_edge_inds[1]
|
266 |
+
|
267 |
+
return head_inds, tail_inds
|
268 |
+
|
269 |
+
def resample_sidelines(self, sideline1, sideline2, resample_step):
|
270 |
+
"""Resample two sidelines to be of the same points number according to
|
271 |
+
step size.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
sideline1 (ndarray): The points composing a sideline of a text
|
275 |
+
polygon.
|
276 |
+
sideline2 (ndarray): The points composing another sideline of a
|
277 |
+
text polygon.
|
278 |
+
resample_step (float): The resampled step size.
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
resampled_line1 (ndarray): The resampled line 1.
|
282 |
+
resampled_line2 (ndarray): The resampled line 2.
|
283 |
+
"""
|
284 |
+
|
285 |
+
assert sideline1.ndim == sideline2.ndim == 2
|
286 |
+
assert sideline1.shape[1] == sideline2.shape[1] == 2
|
287 |
+
assert sideline1.shape[0] >= 2
|
288 |
+
assert sideline2.shape[0] >= 2
|
289 |
+
assert isinstance(resample_step, float)
|
290 |
+
|
291 |
+
length1 = sum([
|
292 |
+
norm(sideline1[i + 1] - sideline1[i])
|
293 |
+
for i in range(len(sideline1) - 1)
|
294 |
+
])
|
295 |
+
length2 = sum([
|
296 |
+
norm(sideline2[i + 1] - sideline2[i])
|
297 |
+
for i in range(len(sideline2) - 1)
|
298 |
+
])
|
299 |
+
|
300 |
+
total_length = (length1 + length2) / 2
|
301 |
+
resample_point_num = max(int(float(total_length) / resample_step), 1)
|
302 |
+
|
303 |
+
resampled_line1 = self.resample_line(sideline1, resample_point_num)
|
304 |
+
resampled_line2 = self.resample_line(sideline2, resample_point_num)
|
305 |
+
|
306 |
+
return resampled_line1, resampled_line2
|
307 |
+
|
308 |
+
def generate_center_region_mask(self, img_size, text_polys):
|
309 |
+
"""Generate text center region mask.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
img_size (tuple): The image size of (height, width).
|
313 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
center_region_mask (ndarray): The text center region mask.
|
317 |
+
"""
|
318 |
+
|
319 |
+
assert isinstance(img_size, tuple)
|
320 |
+
# assert check_argument.is_2dlist(text_polys)
|
321 |
+
|
322 |
+
h, w = img_size
|
323 |
+
|
324 |
+
center_region_mask = np.zeros((h, w), np.uint8)
|
325 |
+
|
326 |
+
center_region_boxes = []
|
327 |
+
for poly in text_polys:
|
328 |
+
# assert len(poly) == 1
|
329 |
+
polygon_points = poly.reshape(-1, 2)
|
330 |
+
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
331 |
+
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
332 |
+
top_line, bot_line, self.resample_step)
|
333 |
+
resampled_bot_line = resampled_bot_line[::-1]
|
334 |
+
if len(resampled_top_line) != len(resampled_bot_line):
|
335 |
+
continue
|
336 |
+
center_line = (resampled_top_line + resampled_bot_line) / 2
|
337 |
+
|
338 |
+
line_head_shrink_len = norm(resampled_top_line[0] -
|
339 |
+
resampled_bot_line[0]) / 4.0
|
340 |
+
line_tail_shrink_len = norm(resampled_top_line[-1] -
|
341 |
+
resampled_bot_line[-1]) / 4.0
|
342 |
+
head_shrink_num = int(line_head_shrink_len // self.resample_step)
|
343 |
+
tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
|
344 |
+
if len(center_line) > head_shrink_num + tail_shrink_num + 2:
|
345 |
+
center_line = center_line[head_shrink_num:len(center_line) -
|
346 |
+
tail_shrink_num]
|
347 |
+
resampled_top_line = resampled_top_line[head_shrink_num:len(
|
348 |
+
resampled_top_line) - tail_shrink_num]
|
349 |
+
resampled_bot_line = resampled_bot_line[head_shrink_num:len(
|
350 |
+
resampled_bot_line) - tail_shrink_num]
|
351 |
+
|
352 |
+
for i in range(0, len(center_line) - 1):
|
353 |
+
tl = center_line[i] + (resampled_top_line[i] - center_line[i]
|
354 |
+
) * self.center_region_shrink_ratio
|
355 |
+
tr = center_line[i + 1] + (resampled_top_line[i + 1] -
|
356 |
+
center_line[i + 1]
|
357 |
+
) * self.center_region_shrink_ratio
|
358 |
+
br = center_line[i + 1] + (resampled_bot_line[i + 1] -
|
359 |
+
center_line[i + 1]
|
360 |
+
) * self.center_region_shrink_ratio
|
361 |
+
bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
|
362 |
+
) * self.center_region_shrink_ratio
|
363 |
+
current_center_box = np.vstack([tl, tr, br,
|
364 |
+
bl]).astype(np.int32)
|
365 |
+
center_region_boxes.append(current_center_box)
|
366 |
+
|
367 |
+
cv2.fillPoly(center_region_mask, center_region_boxes, 1)
|
368 |
+
return center_region_mask
|
369 |
+
|
370 |
+
def resample_polygon(self, polygon, n=400):
|
371 |
+
"""Resample one polygon with n points on its boundary.
|
372 |
+
|
373 |
+
Args:
|
374 |
+
polygon (list[float]): The input polygon.
|
375 |
+
n (int): The number of resampled points.
|
376 |
+
Returns:
|
377 |
+
resampled_polygon (list[float]): The resampled polygon.
|
378 |
+
"""
|
379 |
+
length = []
|
380 |
+
|
381 |
+
for i in range(len(polygon)):
|
382 |
+
p1 = polygon[i]
|
383 |
+
if i == len(polygon) - 1:
|
384 |
+
p2 = polygon[0]
|
385 |
+
else:
|
386 |
+
p2 = polygon[i + 1]
|
387 |
+
length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
|
388 |
+
|
389 |
+
total_length = sum(length)
|
390 |
+
n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
|
391 |
+
n_on_each_line = n_on_each_line.astype(np.int32)
|
392 |
+
new_polygon = []
|
393 |
+
|
394 |
+
for i in range(len(polygon)):
|
395 |
+
num = n_on_each_line[i]
|
396 |
+
p1 = polygon[i]
|
397 |
+
if i == len(polygon) - 1:
|
398 |
+
p2 = polygon[0]
|
399 |
+
else:
|
400 |
+
p2 = polygon[i + 1]
|
401 |
+
|
402 |
+
if num == 0:
|
403 |
+
continue
|
404 |
+
|
405 |
+
dxdy = (p2 - p1) / num
|
406 |
+
for j in range(num):
|
407 |
+
point = p1 + dxdy * j
|
408 |
+
new_polygon.append(point)
|
409 |
+
|
410 |
+
return np.array(new_polygon)
|
411 |
+
|
412 |
+
def normalize_polygon(self, polygon):
|
413 |
+
"""Normalize one polygon so that its start point is at right most.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
polygon (list[float]): The origin polygon.
|
417 |
+
Returns:
|
418 |
+
new_polygon (lost[float]): The polygon with start point at right.
|
419 |
+
"""
|
420 |
+
temp_polygon = polygon - polygon.mean(axis=0)
|
421 |
+
x = np.abs(temp_polygon[:, 0])
|
422 |
+
y = temp_polygon[:, 1]
|
423 |
+
index_x = np.argsort(x)
|
424 |
+
index_y = np.argmin(y[index_x[:8]])
|
425 |
+
index = index_x[index_y]
|
426 |
+
new_polygon = np.concatenate([polygon[index:], polygon[:index]])
|
427 |
+
return new_polygon
|
428 |
+
|
429 |
+
def poly2fourier(self, polygon, fourier_degree):
|
430 |
+
"""Perform Fourier transformation to generate Fourier coefficients ck
|
431 |
+
from polygon.
|
432 |
+
|
433 |
+
Args:
|
434 |
+
polygon (ndarray): An input polygon.
|
435 |
+
fourier_degree (int): The maximum Fourier degree K.
|
436 |
+
Returns:
|
437 |
+
c (ndarray(complex)): Fourier coefficients.
|
438 |
+
"""
|
439 |
+
points = polygon[:, 0] + polygon[:, 1] * 1j
|
440 |
+
c_fft = fft(points) / len(points)
|
441 |
+
c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
|
442 |
+
return c
|
443 |
+
|
444 |
+
def clockwise(self, c, fourier_degree):
|
445 |
+
"""Make sure the polygon reconstructed from Fourier coefficients c in
|
446 |
+
the clockwise direction.
|
447 |
+
|
448 |
+
Args:
|
449 |
+
polygon (list[float]): The origin polygon.
|
450 |
+
Returns:
|
451 |
+
new_polygon (lost[float]): The polygon in clockwise point order.
|
452 |
+
"""
|
453 |
+
if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
|
454 |
+
return c
|
455 |
+
elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
|
456 |
+
return c[::-1]
|
457 |
+
else:
|
458 |
+
if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
|
459 |
+
return c
|
460 |
+
else:
|
461 |
+
return c[::-1]
|
462 |
+
|
463 |
+
def cal_fourier_signature(self, polygon, fourier_degree):
|
464 |
+
"""Calculate Fourier signature from input polygon.
|
465 |
+
|
466 |
+
Args:
|
467 |
+
polygon (ndarray): The input polygon.
|
468 |
+
fourier_degree (int): The maximum Fourier degree K.
|
469 |
+
Returns:
|
470 |
+
fourier_signature (ndarray): An array shaped (2k+1, 2) containing
|
471 |
+
real part and image part of 2k+1 Fourier coefficients.
|
472 |
+
"""
|
473 |
+
resampled_polygon = self.resample_polygon(polygon)
|
474 |
+
resampled_polygon = self.normalize_polygon(resampled_polygon)
|
475 |
+
|
476 |
+
fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
|
477 |
+
fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
|
478 |
+
|
479 |
+
real_part = np.real(fourier_coeff).reshape((-1, 1))
|
480 |
+
image_part = np.imag(fourier_coeff).reshape((-1, 1))
|
481 |
+
fourier_signature = np.hstack([real_part, image_part])
|
482 |
+
|
483 |
+
return fourier_signature
|
484 |
+
|
485 |
+
def generate_fourier_maps(self, img_size, text_polys):
|
486 |
+
"""Generate Fourier coefficient maps.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
img_size (tuple): The image size of (height, width).
|
490 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
491 |
+
|
492 |
+
Returns:
|
493 |
+
fourier_real_map (ndarray): The Fourier coefficient real part maps.
|
494 |
+
fourier_image_map (ndarray): The Fourier coefficient image part
|
495 |
+
maps.
|
496 |
+
"""
|
497 |
+
|
498 |
+
assert isinstance(img_size, tuple)
|
499 |
+
|
500 |
+
h, w = img_size
|
501 |
+
k = self.fourier_degree
|
502 |
+
real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
503 |
+
imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
|
504 |
+
|
505 |
+
for poly in text_polys:
|
506 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
507 |
+
polygon = np.array(poly).reshape((1, -1, 2))
|
508 |
+
cv2.fillPoly(mask, polygon.astype(np.int32), 1)
|
509 |
+
fourier_coeff = self.cal_fourier_signature(polygon[0], k)
|
510 |
+
for i in range(-k, k + 1):
|
511 |
+
if i != 0:
|
512 |
+
real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
|
513 |
+
1 - mask) * real_map[i + k, :, :]
|
514 |
+
imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
|
515 |
+
1 - mask) * imag_map[i + k, :, :]
|
516 |
+
else:
|
517 |
+
yx = np.argwhere(mask > 0.5)
|
518 |
+
k_ind = np.ones((len(yx)), dtype=np.int64) * k
|
519 |
+
y, x = yx[:, 0], yx[:, 1]
|
520 |
+
real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
|
521 |
+
imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
|
522 |
+
|
523 |
+
return real_map, imag_map
|
524 |
+
|
525 |
+
def generate_text_region_mask(self, img_size, text_polys):
|
526 |
+
"""Generate text center region mask and geometry attribute maps.
|
527 |
+
|
528 |
+
Args:
|
529 |
+
img_size (tuple): The image size (height, width).
|
530 |
+
text_polys (list[list[ndarray]]): The list of text polygons.
|
531 |
+
|
532 |
+
Returns:
|
533 |
+
text_region_mask (ndarray): The text region mask.
|
534 |
+
"""
|
535 |
+
|
536 |
+
assert isinstance(img_size, tuple)
|
537 |
+
|
538 |
+
h, w = img_size
|
539 |
+
text_region_mask = np.zeros((h, w), dtype=np.uint8)
|
540 |
+
|
541 |
+
for poly in text_polys:
|
542 |
+
polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
|
543 |
+
cv2.fillPoly(text_region_mask, polygon, 1)
|
544 |
+
|
545 |
+
return text_region_mask
|
546 |
+
|
547 |
+
def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
|
548 |
+
"""Generate effective mask by setting the ineffective regions to 0 and
|
549 |
+
effective regions to 1.
|
550 |
+
|
551 |
+
Args:
|
552 |
+
mask_size (tuple): The mask size.
|
553 |
+
polygons_ignore (list[[ndarray]]: The list of ignored text
|
554 |
+
polygons.
|
555 |
+
|
556 |
+
Returns:
|
557 |
+
mask (ndarray): The effective mask of (height, width).
|
558 |
+
"""
|
559 |
+
|
560 |
+
mask = np.ones(mask_size, dtype=np.uint8)
|
561 |
+
|
562 |
+
for poly in polygons_ignore:
|
563 |
+
instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
|
564 |
+
cv2.fillPoly(mask, instance, 0)
|
565 |
+
|
566 |
+
return mask
|
567 |
+
|
568 |
+
def generate_level_targets(self, img_size, text_polys, ignore_polys):
|
569 |
+
"""Generate ground truth target on each level.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
img_size (list[int]): Shape of input image.
|
573 |
+
text_polys (list[list[ndarray]]): A list of ground truth polygons.
|
574 |
+
ignore_polys (list[list[ndarray]]): A list of ignored polygons.
|
575 |
+
Returns:
|
576 |
+
level_maps (list(ndarray)): A list of ground target on each level.
|
577 |
+
"""
|
578 |
+
h, w = img_size
|
579 |
+
lv_size_divs = self.level_size_divisors
|
580 |
+
lv_proportion_range = self.level_proportion_range
|
581 |
+
lv_text_polys = [[] for i in range(len(lv_size_divs))]
|
582 |
+
lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
|
583 |
+
level_maps = []
|
584 |
+
for poly in text_polys:
|
585 |
+
polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
|
586 |
+
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
587 |
+
proportion = max(box_h, box_w) / (h + 1e-8)
|
588 |
+
|
589 |
+
for ind, proportion_range in enumerate(lv_proportion_range):
|
590 |
+
if proportion_range[0] < proportion < proportion_range[1]:
|
591 |
+
lv_text_polys[ind].append(poly / lv_size_divs[ind])
|
592 |
+
|
593 |
+
for ignore_poly in ignore_polys:
|
594 |
+
polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
|
595 |
+
_, _, box_w, box_h = cv2.boundingRect(polygon)
|
596 |
+
proportion = max(box_h, box_w) / (h + 1e-8)
|
597 |
+
|
598 |
+
for ind, proportion_range in enumerate(lv_proportion_range):
|
599 |
+
if proportion_range[0] < proportion < proportion_range[1]:
|
600 |
+
lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
|
601 |
+
|
602 |
+
for ind, size_divisor in enumerate(lv_size_divs):
|
603 |
+
current_level_maps = []
|
604 |
+
level_img_size = (h // size_divisor, w // size_divisor)
|
605 |
+
|
606 |
+
text_region = self.generate_text_region_mask(
|
607 |
+
level_img_size, lv_text_polys[ind])[None]
|
608 |
+
current_level_maps.append(text_region)
|
609 |
+
|
610 |
+
center_region = self.generate_center_region_mask(
|
611 |
+
level_img_size, lv_text_polys[ind])[None]
|
612 |
+
current_level_maps.append(center_region)
|
613 |
+
|
614 |
+
effective_mask = self.generate_effective_mask(
|
615 |
+
level_img_size, lv_ignore_polys[ind])[None]
|
616 |
+
current_level_maps.append(effective_mask)
|
617 |
+
|
618 |
+
fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
|
619 |
+
level_img_size, lv_text_polys[ind])
|
620 |
+
current_level_maps.append(fourier_real_map)
|
621 |
+
current_level_maps.append(fourier_image_maps)
|
622 |
+
|
623 |
+
level_maps.append(np.concatenate(current_level_maps))
|
624 |
+
|
625 |
+
return level_maps
|
626 |
+
|
627 |
+
def generate_targets(self, results):
|
628 |
+
"""Generate the ground truth targets for FCENet.
|
629 |
+
|
630 |
+
Args:
|
631 |
+
results (dict): The input result dictionary.
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
results (dict): The output result dictionary.
|
635 |
+
"""
|
636 |
+
|
637 |
+
assert isinstance(results, dict)
|
638 |
+
image = results['image']
|
639 |
+
polygons = results['polys']
|
640 |
+
ignore_tags = results['ignore_tags']
|
641 |
+
h, w, _ = image.shape
|
642 |
+
|
643 |
+
polygon_masks = []
|
644 |
+
polygon_masks_ignore = []
|
645 |
+
for tag, polygon in zip(ignore_tags, polygons):
|
646 |
+
if tag is True:
|
647 |
+
polygon_masks_ignore.append(polygon)
|
648 |
+
else:
|
649 |
+
polygon_masks.append(polygon)
|
650 |
+
|
651 |
+
level_maps = self.generate_level_targets((h, w), polygon_masks,
|
652 |
+
polygon_masks_ignore)
|
653 |
+
|
654 |
+
mapping = {
|
655 |
+
'p3_maps': level_maps[0],
|
656 |
+
'p4_maps': level_maps[1],
|
657 |
+
'p5_maps': level_maps[2]
|
658 |
+
}
|
659 |
+
for key, value in mapping.items():
|
660 |
+
results[key] = value
|
661 |
+
|
662 |
+
return results
|
663 |
+
|
664 |
+
def __call__(self, results):
|
665 |
+
results = self.generate_targets(results)
|
666 |
+
return results
|
ppocr/data/imaug/iaa_augment.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
|
17 |
+
"""
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
from __future__ import unicode_literals
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import imgaug
|
25 |
+
import imgaug.augmenters as iaa
|
26 |
+
|
27 |
+
|
28 |
+
class AugmenterBuilder(object):
|
29 |
+
def __init__(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def build(self, args, root=True):
|
33 |
+
if args is None or len(args) == 0:
|
34 |
+
return None
|
35 |
+
elif isinstance(args, list):
|
36 |
+
if root:
|
37 |
+
sequence = [self.build(value, root=False) for value in args]
|
38 |
+
return iaa.Sequential(sequence)
|
39 |
+
else:
|
40 |
+
return getattr(iaa, args[0])(
|
41 |
+
*[self.to_tuple_if_list(a) for a in args[1:]])
|
42 |
+
elif isinstance(args, dict):
|
43 |
+
cls = getattr(iaa, args['type'])
|
44 |
+
return cls(**{
|
45 |
+
k: self.to_tuple_if_list(v)
|
46 |
+
for k, v in args['args'].items()
|
47 |
+
})
|
48 |
+
else:
|
49 |
+
raise RuntimeError('unknown augmenter arg: ' + str(args))
|
50 |
+
|
51 |
+
def to_tuple_if_list(self, obj):
|
52 |
+
if isinstance(obj, list):
|
53 |
+
return tuple(obj)
|
54 |
+
return obj
|
55 |
+
|
56 |
+
|
57 |
+
class IaaAugment():
|
58 |
+
def __init__(self, augmenter_args=None, **kwargs):
|
59 |
+
if augmenter_args is None:
|
60 |
+
augmenter_args = [{
|
61 |
+
'type': 'Fliplr',
|
62 |
+
'args': {
|
63 |
+
'p': 0.5
|
64 |
+
}
|
65 |
+
}, {
|
66 |
+
'type': 'Affine',
|
67 |
+
'args': {
|
68 |
+
'rotate': [-10, 10]
|
69 |
+
}
|
70 |
+
}, {
|
71 |
+
'type': 'Resize',
|
72 |
+
'args': {
|
73 |
+
'size': [0.5, 3]
|
74 |
+
}
|
75 |
+
}]
|
76 |
+
self.augmenter = AugmenterBuilder().build(augmenter_args)
|
77 |
+
|
78 |
+
def __call__(self, data):
|
79 |
+
image = data['image']
|
80 |
+
shape = image.shape
|
81 |
+
|
82 |
+
if self.augmenter:
|
83 |
+
aug = self.augmenter.to_deterministic()
|
84 |
+
data['image'] = aug.augment_image(image)
|
85 |
+
data = self.may_augment_annotation(aug, data, shape)
|
86 |
+
return data
|
87 |
+
|
88 |
+
def may_augment_annotation(self, aug, data, shape):
|
89 |
+
if aug is None:
|
90 |
+
return data
|
91 |
+
|
92 |
+
line_polys = []
|
93 |
+
for poly in data['polys']:
|
94 |
+
new_poly = self.may_augment_poly(aug, shape, poly)
|
95 |
+
line_polys.append(new_poly)
|
96 |
+
data['polys'] = np.array(line_polys)
|
97 |
+
return data
|
98 |
+
|
99 |
+
def may_augment_poly(self, aug, img_shape, poly):
|
100 |
+
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
|
101 |
+
keypoints = aug.augment_keypoints(
|
102 |
+
[imgaug.KeypointsOnImage(
|
103 |
+
keypoints, shape=img_shape)])[0].keypoints
|
104 |
+
poly = [(p.x, p.y) for p in keypoints]
|
105 |
+
return poly
|
ppocr/data/imaug/label_ops.py
ADDED
@@ -0,0 +1,1505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
from __future__ import unicode_literals
|
19 |
+
|
20 |
+
import copy
|
21 |
+
import numpy as np
|
22 |
+
import string
|
23 |
+
from shapely.geometry import LineString, Point, Polygon
|
24 |
+
import json
|
25 |
+
import copy
|
26 |
+
from random import sample
|
27 |
+
|
28 |
+
from ppocr.utils.logging import get_logger
|
29 |
+
from ppocr.data.imaug.vqa.augment import order_by_tbyx
|
30 |
+
|
31 |
+
|
32 |
+
class ClsLabelEncode(object):
|
33 |
+
def __init__(self, label_list, **kwargs):
|
34 |
+
self.label_list = label_list
|
35 |
+
|
36 |
+
def __call__(self, data):
|
37 |
+
label = data['label']
|
38 |
+
if label not in self.label_list:
|
39 |
+
return None
|
40 |
+
label = self.label_list.index(label)
|
41 |
+
data['label'] = label
|
42 |
+
return data
|
43 |
+
|
44 |
+
|
45 |
+
class DetLabelEncode(object):
|
46 |
+
def __init__(self, **kwargs):
|
47 |
+
pass
|
48 |
+
|
49 |
+
def __call__(self, data):
|
50 |
+
label = data['label']
|
51 |
+
label = json.loads(label)
|
52 |
+
nBox = len(label)
|
53 |
+
boxes, txts, txt_tags = [], [], []
|
54 |
+
for bno in range(0, nBox):
|
55 |
+
box = label[bno]['points']
|
56 |
+
txt = label[bno]['transcription']
|
57 |
+
boxes.append(box)
|
58 |
+
txts.append(txt)
|
59 |
+
if txt in ['*', '###']:
|
60 |
+
txt_tags.append(True)
|
61 |
+
else:
|
62 |
+
txt_tags.append(False)
|
63 |
+
if len(boxes) == 0:
|
64 |
+
return None
|
65 |
+
boxes = self.expand_points_num(boxes)
|
66 |
+
boxes = np.array(boxes, dtype=np.float32)
|
67 |
+
txt_tags = np.array(txt_tags, dtype=bool)
|
68 |
+
|
69 |
+
data['polys'] = boxes
|
70 |
+
data['texts'] = txts
|
71 |
+
data['ignore_tags'] = txt_tags
|
72 |
+
return data
|
73 |
+
|
74 |
+
def order_points_clockwise(self, pts):
|
75 |
+
rect = np.zeros((4, 2), dtype="float32")
|
76 |
+
s = pts.sum(axis=1)
|
77 |
+
rect[0] = pts[np.argmin(s)]
|
78 |
+
rect[2] = pts[np.argmax(s)]
|
79 |
+
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
|
80 |
+
diff = np.diff(np.array(tmp), axis=1)
|
81 |
+
rect[1] = tmp[np.argmin(diff)]
|
82 |
+
rect[3] = tmp[np.argmax(diff)]
|
83 |
+
return rect
|
84 |
+
|
85 |
+
def expand_points_num(self, boxes):
|
86 |
+
max_points_num = 0
|
87 |
+
for box in boxes:
|
88 |
+
if len(box) > max_points_num:
|
89 |
+
max_points_num = len(box)
|
90 |
+
ex_boxes = []
|
91 |
+
for box in boxes:
|
92 |
+
ex_box = box + [box[-1]] * (max_points_num - len(box))
|
93 |
+
ex_boxes.append(ex_box)
|
94 |
+
return ex_boxes
|
95 |
+
|
96 |
+
|
97 |
+
class BaseRecLabelEncode(object):
|
98 |
+
""" Convert between text-label and text-index """
|
99 |
+
|
100 |
+
def __init__(self,
|
101 |
+
max_text_length,
|
102 |
+
character_dict_path=None,
|
103 |
+
use_space_char=False,
|
104 |
+
lower=False):
|
105 |
+
|
106 |
+
self.max_text_len = max_text_length
|
107 |
+
self.beg_str = "sos"
|
108 |
+
self.end_str = "eos"
|
109 |
+
self.lower = lower
|
110 |
+
|
111 |
+
if character_dict_path is None:
|
112 |
+
logger = get_logger()
|
113 |
+
logger.warning(
|
114 |
+
"The character_dict_path is None, model can only recognize number and lower letters"
|
115 |
+
)
|
116 |
+
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
117 |
+
dict_character = list(self.character_str)
|
118 |
+
self.lower = True
|
119 |
+
else:
|
120 |
+
self.character_str = []
|
121 |
+
with open(character_dict_path, "rb") as fin:
|
122 |
+
lines = fin.readlines()
|
123 |
+
for line in lines:
|
124 |
+
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
125 |
+
self.character_str.append(line)
|
126 |
+
if use_space_char:
|
127 |
+
self.character_str.append(" ")
|
128 |
+
dict_character = list(self.character_str)
|
129 |
+
dict_character = self.add_special_char(dict_character)
|
130 |
+
self.dict = {}
|
131 |
+
for i, char in enumerate(dict_character):
|
132 |
+
self.dict[char] = i
|
133 |
+
self.character = dict_character
|
134 |
+
|
135 |
+
def add_special_char(self, dict_character):
|
136 |
+
return dict_character
|
137 |
+
|
138 |
+
def encode(self, text):
|
139 |
+
"""convert text-label into text-index.
|
140 |
+
input:
|
141 |
+
text: text labels of each image. [batch_size]
|
142 |
+
|
143 |
+
output:
|
144 |
+
text: concatenated text index for CTCLoss.
|
145 |
+
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
146 |
+
length: length of each text. [batch_size]
|
147 |
+
"""
|
148 |
+
if len(text) == 0 or len(text) > self.max_text_len:
|
149 |
+
return None
|
150 |
+
if self.lower:
|
151 |
+
text = text.lower()
|
152 |
+
text_list = []
|
153 |
+
for char in text:
|
154 |
+
if char not in self.dict:
|
155 |
+
# logger = get_logger()
|
156 |
+
# logger.warning('{} is not in dict'.format(char))
|
157 |
+
continue
|
158 |
+
text_list.append(self.dict[char])
|
159 |
+
if len(text_list) == 0:
|
160 |
+
return None
|
161 |
+
return text_list
|
162 |
+
|
163 |
+
|
164 |
+
class CTCLabelEncode(BaseRecLabelEncode):
|
165 |
+
""" Convert between text-label and text-index """
|
166 |
+
|
167 |
+
def __init__(self,
|
168 |
+
max_text_length,
|
169 |
+
character_dict_path=None,
|
170 |
+
use_space_char=False,
|
171 |
+
**kwargs):
|
172 |
+
super(CTCLabelEncode, self).__init__(
|
173 |
+
max_text_length, character_dict_path, use_space_char)
|
174 |
+
|
175 |
+
def __call__(self, data):
|
176 |
+
text = data['label']
|
177 |
+
text = self.encode(text)
|
178 |
+
if text is None:
|
179 |
+
return None
|
180 |
+
data['length'] = np.array(len(text))
|
181 |
+
text = text + [0] * (self.max_text_len - len(text))
|
182 |
+
data['label'] = np.array(text)
|
183 |
+
|
184 |
+
label = [0] * len(self.character)
|
185 |
+
for x in text:
|
186 |
+
label[x] += 1
|
187 |
+
data['label_ace'] = np.array(label)
|
188 |
+
return data
|
189 |
+
|
190 |
+
def add_special_char(self, dict_character):
|
191 |
+
dict_character = ['blank'] + dict_character
|
192 |
+
return dict_character
|
193 |
+
|
194 |
+
|
195 |
+
class E2ELabelEncodeTest(BaseRecLabelEncode):
|
196 |
+
def __init__(self,
|
197 |
+
max_text_length,
|
198 |
+
character_dict_path=None,
|
199 |
+
use_space_char=False,
|
200 |
+
**kwargs):
|
201 |
+
super(E2ELabelEncodeTest, self).__init__(
|
202 |
+
max_text_length, character_dict_path, use_space_char)
|
203 |
+
|
204 |
+
def __call__(self, data):
|
205 |
+
import json
|
206 |
+
padnum = len(self.dict)
|
207 |
+
label = data['label']
|
208 |
+
label = json.loads(label)
|
209 |
+
nBox = len(label)
|
210 |
+
boxes, txts, txt_tags = [], [], []
|
211 |
+
for bno in range(0, nBox):
|
212 |
+
box = label[bno]['points']
|
213 |
+
txt = label[bno]['transcription']
|
214 |
+
boxes.append(box)
|
215 |
+
txts.append(txt)
|
216 |
+
if txt in ['*', '###']:
|
217 |
+
txt_tags.append(True)
|
218 |
+
else:
|
219 |
+
txt_tags.append(False)
|
220 |
+
boxes = np.array(boxes, dtype=np.float32)
|
221 |
+
txt_tags = np.array(txt_tags, dtype=bool)
|
222 |
+
data['polys'] = boxes
|
223 |
+
data['ignore_tags'] = txt_tags
|
224 |
+
temp_texts = []
|
225 |
+
for text in txts:
|
226 |
+
text = text.lower()
|
227 |
+
text = self.encode(text)
|
228 |
+
if text is None:
|
229 |
+
return None
|
230 |
+
text = text + [padnum] * (self.max_text_len - len(text)
|
231 |
+
) # use 36 to pad
|
232 |
+
temp_texts.append(text)
|
233 |
+
data['texts'] = np.array(temp_texts)
|
234 |
+
return data
|
235 |
+
|
236 |
+
|
237 |
+
class E2ELabelEncodeTrain(object):
|
238 |
+
def __init__(self, **kwargs):
|
239 |
+
pass
|
240 |
+
|
241 |
+
def __call__(self, data):
|
242 |
+
import json
|
243 |
+
label = data['label']
|
244 |
+
label = json.loads(label)
|
245 |
+
nBox = len(label)
|
246 |
+
boxes, txts, txt_tags = [], [], []
|
247 |
+
for bno in range(0, nBox):
|
248 |
+
box = label[bno]['points']
|
249 |
+
txt = label[bno]['transcription']
|
250 |
+
boxes.append(box)
|
251 |
+
txts.append(txt)
|
252 |
+
if txt in ['*', '###']:
|
253 |
+
txt_tags.append(True)
|
254 |
+
else:
|
255 |
+
txt_tags.append(False)
|
256 |
+
boxes = np.array(boxes, dtype=np.float32)
|
257 |
+
txt_tags = np.array(txt_tags, dtype=bool)
|
258 |
+
|
259 |
+
data['polys'] = boxes
|
260 |
+
data['texts'] = txts
|
261 |
+
data['ignore_tags'] = txt_tags
|
262 |
+
return data
|
263 |
+
|
264 |
+
|
265 |
+
class KieLabelEncode(object):
|
266 |
+
def __init__(self,
|
267 |
+
character_dict_path,
|
268 |
+
class_path,
|
269 |
+
norm=10,
|
270 |
+
directed=False,
|
271 |
+
**kwargs):
|
272 |
+
super(KieLabelEncode, self).__init__()
|
273 |
+
self.dict = dict({'': 0})
|
274 |
+
self.label2classid_map = dict()
|
275 |
+
with open(character_dict_path, 'r', encoding='utf-8') as fr:
|
276 |
+
idx = 1
|
277 |
+
for line in fr:
|
278 |
+
char = line.strip()
|
279 |
+
self.dict[char] = idx
|
280 |
+
idx += 1
|
281 |
+
with open(class_path, "r") as fin:
|
282 |
+
lines = fin.readlines()
|
283 |
+
for idx, line in enumerate(lines):
|
284 |
+
line = line.strip("\n")
|
285 |
+
self.label2classid_map[line] = idx
|
286 |
+
self.norm = norm
|
287 |
+
self.directed = directed
|
288 |
+
|
289 |
+
def compute_relation(self, boxes):
|
290 |
+
"""Compute relation between every two boxes."""
|
291 |
+
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
|
292 |
+
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
|
293 |
+
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
|
294 |
+
dxs = (x1s[:, 0][None] - x1s) / self.norm
|
295 |
+
dys = (y1s[:, 0][None] - y1s) / self.norm
|
296 |
+
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
|
297 |
+
whs = ws / hs + np.zeros_like(xhhs)
|
298 |
+
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
|
299 |
+
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
|
300 |
+
return relations, bboxes
|
301 |
+
|
302 |
+
def pad_text_indices(self, text_inds):
|
303 |
+
"""Pad text index to same length."""
|
304 |
+
max_len = 300
|
305 |
+
recoder_len = max([len(text_ind) for text_ind in text_inds])
|
306 |
+
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
|
307 |
+
for idx, text_ind in enumerate(text_inds):
|
308 |
+
padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
|
309 |
+
return padded_text_inds, recoder_len
|
310 |
+
|
311 |
+
def list_to_numpy(self, ann_infos):
|
312 |
+
"""Convert bboxes, relations, texts and labels to ndarray."""
|
313 |
+
boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
|
314 |
+
boxes = np.array(boxes, np.int32)
|
315 |
+
relations, bboxes = self.compute_relation(boxes)
|
316 |
+
|
317 |
+
labels = ann_infos.get('labels', None)
|
318 |
+
if labels is not None:
|
319 |
+
labels = np.array(labels, np.int32)
|
320 |
+
edges = ann_infos.get('edges', None)
|
321 |
+
if edges is not None:
|
322 |
+
labels = labels[:, None]
|
323 |
+
edges = np.array(edges)
|
324 |
+
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
|
325 |
+
if self.directed:
|
326 |
+
edges = (edges & labels == 1).astype(np.int32)
|
327 |
+
np.fill_diagonal(edges, -1)
|
328 |
+
labels = np.concatenate([labels, edges], -1)
|
329 |
+
padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
|
330 |
+
max_num = 300
|
331 |
+
temp_bboxes = np.zeros([max_num, 4])
|
332 |
+
h, _ = bboxes.shape
|
333 |
+
temp_bboxes[:h, :] = bboxes
|
334 |
+
|
335 |
+
temp_relations = np.zeros([max_num, max_num, 5])
|
336 |
+
temp_relations[:h, :h, :] = relations
|
337 |
+
|
338 |
+
temp_padded_text_inds = np.zeros([max_num, max_num])
|
339 |
+
temp_padded_text_inds[:h, :] = padded_text_inds
|
340 |
+
|
341 |
+
temp_labels = np.zeros([max_num, max_num])
|
342 |
+
temp_labels[:h, :h + 1] = labels
|
343 |
+
|
344 |
+
tag = np.array([h, recoder_len])
|
345 |
+
return dict(
|
346 |
+
image=ann_infos['image'],
|
347 |
+
points=temp_bboxes,
|
348 |
+
relations=temp_relations,
|
349 |
+
texts=temp_padded_text_inds,
|
350 |
+
labels=temp_labels,
|
351 |
+
tag=tag)
|
352 |
+
|
353 |
+
def convert_canonical(self, points_x, points_y):
|
354 |
+
|
355 |
+
assert len(points_x) == 4
|
356 |
+
assert len(points_y) == 4
|
357 |
+
|
358 |
+
points = [Point(points_x[i], points_y[i]) for i in range(4)]
|
359 |
+
|
360 |
+
polygon = Polygon([(p.x, p.y) for p in points])
|
361 |
+
min_x, min_y, _, _ = polygon.bounds
|
362 |
+
points_to_lefttop = [
|
363 |
+
LineString([points[i], Point(min_x, min_y)]) for i in range(4)
|
364 |
+
]
|
365 |
+
distances = np.array([line.length for line in points_to_lefttop])
|
366 |
+
sort_dist_idx = np.argsort(distances)
|
367 |
+
lefttop_idx = sort_dist_idx[0]
|
368 |
+
|
369 |
+
if lefttop_idx == 0:
|
370 |
+
point_orders = [0, 1, 2, 3]
|
371 |
+
elif lefttop_idx == 1:
|
372 |
+
point_orders = [1, 2, 3, 0]
|
373 |
+
elif lefttop_idx == 2:
|
374 |
+
point_orders = [2, 3, 0, 1]
|
375 |
+
else:
|
376 |
+
point_orders = [3, 0, 1, 2]
|
377 |
+
|
378 |
+
sorted_points_x = [points_x[i] for i in point_orders]
|
379 |
+
sorted_points_y = [points_y[j] for j in point_orders]
|
380 |
+
|
381 |
+
return sorted_points_x, sorted_points_y
|
382 |
+
|
383 |
+
def sort_vertex(self, points_x, points_y):
|
384 |
+
|
385 |
+
assert len(points_x) == 4
|
386 |
+
assert len(points_y) == 4
|
387 |
+
|
388 |
+
x = np.array(points_x)
|
389 |
+
y = np.array(points_y)
|
390 |
+
center_x = np.sum(x) * 0.25
|
391 |
+
center_y = np.sum(y) * 0.25
|
392 |
+
|
393 |
+
x_arr = np.array(x - center_x)
|
394 |
+
y_arr = np.array(y - center_y)
|
395 |
+
|
396 |
+
angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
|
397 |
+
sort_idx = np.argsort(angle)
|
398 |
+
|
399 |
+
sorted_points_x, sorted_points_y = [], []
|
400 |
+
for i in range(4):
|
401 |
+
sorted_points_x.append(points_x[sort_idx[i]])
|
402 |
+
sorted_points_y.append(points_y[sort_idx[i]])
|
403 |
+
|
404 |
+
return self.convert_canonical(sorted_points_x, sorted_points_y)
|
405 |
+
|
406 |
+
def __call__(self, data):
|
407 |
+
import json
|
408 |
+
label = data['label']
|
409 |
+
annotations = json.loads(label)
|
410 |
+
boxes, texts, text_inds, labels, edges = [], [], [], [], []
|
411 |
+
for ann in annotations:
|
412 |
+
box = ann['points']
|
413 |
+
x_list = [box[i][0] for i in range(4)]
|
414 |
+
y_list = [box[i][1] for i in range(4)]
|
415 |
+
sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
|
416 |
+
sorted_box = []
|
417 |
+
for x, y in zip(sorted_x_list, sorted_y_list):
|
418 |
+
sorted_box.append(x)
|
419 |
+
sorted_box.append(y)
|
420 |
+
boxes.append(sorted_box)
|
421 |
+
text = ann['transcription']
|
422 |
+
texts.append(ann['transcription'])
|
423 |
+
text_ind = [self.dict[c] for c in text if c in self.dict]
|
424 |
+
text_inds.append(text_ind)
|
425 |
+
if 'label' in ann.keys():
|
426 |
+
labels.append(self.label2classid_map[ann['label']])
|
427 |
+
elif 'key_cls' in ann.keys():
|
428 |
+
labels.append(ann['key_cls'])
|
429 |
+
else:
|
430 |
+
raise ValueError(
|
431 |
+
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
|
432 |
+
)
|
433 |
+
edges.append(ann.get('edge', 0))
|
434 |
+
ann_infos = dict(
|
435 |
+
image=data['image'],
|
436 |
+
points=boxes,
|
437 |
+
texts=texts,
|
438 |
+
text_inds=text_inds,
|
439 |
+
edges=edges,
|
440 |
+
labels=labels)
|
441 |
+
|
442 |
+
return self.list_to_numpy(ann_infos)
|
443 |
+
|
444 |
+
|
445 |
+
class AttnLabelEncode(BaseRecLabelEncode):
|
446 |
+
""" Convert between text-label and text-index """
|
447 |
+
|
448 |
+
def __init__(self,
|
449 |
+
max_text_length,
|
450 |
+
character_dict_path=None,
|
451 |
+
use_space_char=False,
|
452 |
+
**kwargs):
|
453 |
+
super(AttnLabelEncode, self).__init__(
|
454 |
+
max_text_length, character_dict_path, use_space_char)
|
455 |
+
|
456 |
+
def add_special_char(self, dict_character):
|
457 |
+
self.beg_str = "sos"
|
458 |
+
self.end_str = "eos"
|
459 |
+
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
460 |
+
return dict_character
|
461 |
+
|
462 |
+
def __call__(self, data):
|
463 |
+
text = data['label']
|
464 |
+
text = self.encode(text)
|
465 |
+
if text is None:
|
466 |
+
return None
|
467 |
+
if len(text) >= self.max_text_len:
|
468 |
+
return None
|
469 |
+
data['length'] = np.array(len(text))
|
470 |
+
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
471 |
+
- len(text) - 2)
|
472 |
+
data['label'] = np.array(text)
|
473 |
+
return data
|
474 |
+
|
475 |
+
def get_ignored_tokens(self):
|
476 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
477 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
478 |
+
return [beg_idx, end_idx]
|
479 |
+
|
480 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
481 |
+
if beg_or_end == "beg":
|
482 |
+
idx = np.array(self.dict[self.beg_str])
|
483 |
+
elif beg_or_end == "end":
|
484 |
+
idx = np.array(self.dict[self.end_str])
|
485 |
+
else:
|
486 |
+
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
487 |
+
% beg_or_end
|
488 |
+
return idx
|
489 |
+
|
490 |
+
|
491 |
+
class RFLLabelEncode(BaseRecLabelEncode):
|
492 |
+
""" Convert between text-label and text-index """
|
493 |
+
|
494 |
+
def __init__(self,
|
495 |
+
max_text_length,
|
496 |
+
character_dict_path=None,
|
497 |
+
use_space_char=False,
|
498 |
+
**kwargs):
|
499 |
+
super(RFLLabelEncode, self).__init__(
|
500 |
+
max_text_length, character_dict_path, use_space_char)
|
501 |
+
|
502 |
+
def add_special_char(self, dict_character):
|
503 |
+
self.beg_str = "sos"
|
504 |
+
self.end_str = "eos"
|
505 |
+
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
506 |
+
return dict_character
|
507 |
+
|
508 |
+
def encode_cnt(self, text):
|
509 |
+
cnt_label = [0.0] * len(self.character)
|
510 |
+
for char_ in text:
|
511 |
+
cnt_label[char_] += 1
|
512 |
+
return np.array(cnt_label)
|
513 |
+
|
514 |
+
def __call__(self, data):
|
515 |
+
text = data['label']
|
516 |
+
text = self.encode(text)
|
517 |
+
if text is None:
|
518 |
+
return None
|
519 |
+
if len(text) >= self.max_text_len:
|
520 |
+
return None
|
521 |
+
cnt_label = self.encode_cnt(text)
|
522 |
+
data['length'] = np.array(len(text))
|
523 |
+
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
524 |
+
- len(text) - 2)
|
525 |
+
if len(text) != self.max_text_len:
|
526 |
+
return None
|
527 |
+
data['label'] = np.array(text)
|
528 |
+
data['cnt_label'] = cnt_label
|
529 |
+
return data
|
530 |
+
|
531 |
+
def get_ignored_tokens(self):
|
532 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
533 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
534 |
+
return [beg_idx, end_idx]
|
535 |
+
|
536 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
537 |
+
if beg_or_end == "beg":
|
538 |
+
idx = np.array(self.dict[self.beg_str])
|
539 |
+
elif beg_or_end == "end":
|
540 |
+
idx = np.array(self.dict[self.end_str])
|
541 |
+
else:
|
542 |
+
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
543 |
+
% beg_or_end
|
544 |
+
return idx
|
545 |
+
|
546 |
+
|
547 |
+
class SEEDLabelEncode(BaseRecLabelEncode):
|
548 |
+
""" Convert between text-label and text-index """
|
549 |
+
|
550 |
+
def __init__(self,
|
551 |
+
max_text_length,
|
552 |
+
character_dict_path=None,
|
553 |
+
use_space_char=False,
|
554 |
+
**kwargs):
|
555 |
+
super(SEEDLabelEncode, self).__init__(
|
556 |
+
max_text_length, character_dict_path, use_space_char)
|
557 |
+
|
558 |
+
def add_special_char(self, dict_character):
|
559 |
+
self.padding = "padding"
|
560 |
+
self.end_str = "eos"
|
561 |
+
self.unknown = "unknown"
|
562 |
+
dict_character = dict_character + [
|
563 |
+
self.end_str, self.padding, self.unknown
|
564 |
+
]
|
565 |
+
return dict_character
|
566 |
+
|
567 |
+
def __call__(self, data):
|
568 |
+
text = data['label']
|
569 |
+
text = self.encode(text)
|
570 |
+
if text is None:
|
571 |
+
return None
|
572 |
+
if len(text) >= self.max_text_len:
|
573 |
+
return None
|
574 |
+
data['length'] = np.array(len(text)) + 1 # conclude eos
|
575 |
+
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
|
576 |
+
self.max_text_len - len(text) - 1)
|
577 |
+
data['label'] = np.array(text)
|
578 |
+
return data
|
579 |
+
|
580 |
+
|
581 |
+
class SRNLabelEncode(BaseRecLabelEncode):
|
582 |
+
""" Convert between text-label and text-index """
|
583 |
+
|
584 |
+
def __init__(self,
|
585 |
+
max_text_length=25,
|
586 |
+
character_dict_path=None,
|
587 |
+
use_space_char=False,
|
588 |
+
**kwargs):
|
589 |
+
super(SRNLabelEncode, self).__init__(
|
590 |
+
max_text_length, character_dict_path, use_space_char)
|
591 |
+
|
592 |
+
def add_special_char(self, dict_character):
|
593 |
+
dict_character = dict_character + [self.beg_str, self.end_str]
|
594 |
+
return dict_character
|
595 |
+
|
596 |
+
def __call__(self, data):
|
597 |
+
text = data['label']
|
598 |
+
text = self.encode(text)
|
599 |
+
char_num = len(self.character)
|
600 |
+
if text is None:
|
601 |
+
return None
|
602 |
+
if len(text) > self.max_text_len:
|
603 |
+
return None
|
604 |
+
data['length'] = np.array(len(text))
|
605 |
+
text = text + [char_num - 1] * (self.max_text_len - len(text))
|
606 |
+
data['label'] = np.array(text)
|
607 |
+
return data
|
608 |
+
|
609 |
+
def get_ignored_tokens(self):
|
610 |
+
beg_idx = self.get_beg_end_flag_idx("beg")
|
611 |
+
end_idx = self.get_beg_end_flag_idx("end")
|
612 |
+
return [beg_idx, end_idx]
|
613 |
+
|
614 |
+
def get_beg_end_flag_idx(self, beg_or_end):
|
615 |
+
if beg_or_end == "beg":
|
616 |
+
idx = np.array(self.dict[self.beg_str])
|
617 |
+
elif beg_or_end == "end":
|
618 |
+
idx = np.array(self.dict[self.end_str])
|
619 |
+
else:
|
620 |
+
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
621 |
+
% beg_or_end
|
622 |
+
return idx
|
623 |
+
|
624 |
+
|
625 |
+
class TableLabelEncode(AttnLabelEncode):
|
626 |
+
""" Convert between text-label and text-index """
|
627 |
+
|
628 |
+
def __init__(self,
|
629 |
+
max_text_length,
|
630 |
+
character_dict_path,
|
631 |
+
replace_empty_cell_token=False,
|
632 |
+
merge_no_span_structure=False,
|
633 |
+
learn_empty_box=False,
|
634 |
+
loc_reg_num=4,
|
635 |
+
**kwargs):
|
636 |
+
self.max_text_len = max_text_length
|
637 |
+
self.lower = False
|
638 |
+
self.learn_empty_box = learn_empty_box
|
639 |
+
self.merge_no_span_structure = merge_no_span_structure
|
640 |
+
self.replace_empty_cell_token = replace_empty_cell_token
|
641 |
+
|
642 |
+
dict_character = []
|
643 |
+
with open(character_dict_path, "rb") as fin:
|
644 |
+
lines = fin.readlines()
|
645 |
+
for line in lines:
|
646 |
+
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
647 |
+
dict_character.append(line)
|
648 |
+
|
649 |
+
if self.merge_no_span_structure:
|
650 |
+
if "<td></td>" not in dict_character:
|
651 |
+
dict_character.append("<td></td>")
|
652 |
+
if "<td>" in dict_character:
|
653 |
+
dict_character.remove("<td>")
|
654 |
+
|
655 |
+
dict_character = self.add_special_char(dict_character)
|
656 |
+
self.dict = {}
|
657 |
+
for i, char in enumerate(dict_character):
|
658 |
+
self.dict[char] = i
|
659 |
+
self.idx2char = {v: k for k, v in self.dict.items()}
|
660 |
+
|
661 |
+
self.character = dict_character
|
662 |
+
self.loc_reg_num = loc_reg_num
|
663 |
+
self.pad_idx = self.dict[self.beg_str]
|
664 |
+
self.start_idx = self.dict[self.beg_str]
|
665 |
+
self.end_idx = self.dict[self.end_str]
|
666 |
+
|
667 |
+
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
|
668 |
+
self.empty_bbox_token_dict = {
|
669 |
+
"[]": '<eb></eb>',
|
670 |
+
"[' ']": '<eb1></eb1>',
|
671 |
+
"['<b>', ' ', '</b>']": '<eb2></eb2>',
|
672 |
+
"['\\u2028', '\\u2028']": '<eb3></eb3>',
|
673 |
+
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
|
674 |
+
"['<b>', '</b>']": '<eb5></eb5>',
|
675 |
+
"['<i>', ' ', '</i>']": '<eb6></eb6>',
|
676 |
+
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
|
677 |
+
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
|
678 |
+
"['<i>', '</i>']": '<eb9></eb9>',
|
679 |
+
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
|
680 |
+
'<eb10></eb10>',
|
681 |
+
}
|
682 |
+
|
683 |
+
@property
|
684 |
+
def _max_text_len(self):
|
685 |
+
return self.max_text_len + 2
|
686 |
+
|
687 |
+
def __call__(self, data):
|
688 |
+
cells = data['cells']
|
689 |
+
structure = data['structure']
|
690 |
+
if self.merge_no_span_structure:
|
691 |
+
structure = self._merge_no_span_structure(structure)
|
692 |
+
if self.replace_empty_cell_token:
|
693 |
+
structure = self._replace_empty_cell_token(structure, cells)
|
694 |
+
# remove empty token and add " " to span token
|
695 |
+
new_structure = []
|
696 |
+
for token in structure:
|
697 |
+
if token != '':
|
698 |
+
if 'span' in token and token[0] != ' ':
|
699 |
+
token = ' ' + token
|
700 |
+
new_structure.append(token)
|
701 |
+
# encode structure
|
702 |
+
structure = self.encode(new_structure)
|
703 |
+
if structure is None:
|
704 |
+
return None
|
705 |
+
|
706 |
+
structure = [self.start_idx] + structure + [self.end_idx
|
707 |
+
] # add sos abd eos
|
708 |
+
structure = structure + [self.pad_idx] * (self._max_text_len -
|
709 |
+
len(structure)) # pad
|
710 |
+
structure = np.array(structure)
|
711 |
+
data['structure'] = structure
|
712 |
+
|
713 |
+
if len(structure) > self._max_text_len:
|
714 |
+
return None
|
715 |
+
|
716 |
+
# encode box
|
717 |
+
bboxes = np.zeros(
|
718 |
+
(self._max_text_len, self.loc_reg_num), dtype=np.float32)
|
719 |
+
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
|
720 |
+
|
721 |
+
bbox_idx = 0
|
722 |
+
|
723 |
+
for i, token in enumerate(structure):
|
724 |
+
if self.idx2char[token] in self.td_token:
|
725 |
+
if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
|
726 |
+
'tokens']) > 0:
|
727 |
+
bbox = cells[bbox_idx]['bbox'].copy()
|
728 |
+
bbox = np.array(bbox, dtype=np.float32).reshape(-1)
|
729 |
+
bboxes[i] = bbox
|
730 |
+
bbox_masks[i] = 1.0
|
731 |
+
if self.learn_empty_box:
|
732 |
+
bbox_masks[i] = 1.0
|
733 |
+
bbox_idx += 1
|
734 |
+
data['bboxes'] = bboxes
|
735 |
+
data['bbox_masks'] = bbox_masks
|
736 |
+
return data
|
737 |
+
|
738 |
+
def _merge_no_span_structure(self, structure):
|
739 |
+
"""
|
740 |
+
This code is refer from:
|
741 |
+
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
|
742 |
+
"""
|
743 |
+
new_structure = []
|
744 |
+
i = 0
|
745 |
+
while i < len(structure):
|
746 |
+
token = structure[i]
|
747 |
+
if token == '<td>':
|
748 |
+
token = '<td></td>'
|
749 |
+
i += 1
|
750 |
+
new_structure.append(token)
|
751 |
+
i += 1
|
752 |
+
return new_structure
|
753 |
+
|
754 |
+
def _replace_empty_cell_token(self, token_list, cells):
|
755 |
+
"""
|
756 |
+
This fun code is refer from:
|
757 |
+
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
|
758 |
+
"""
|
759 |
+
|
760 |
+
bbox_idx = 0
|
761 |
+
add_empty_bbox_token_list = []
|
762 |
+
for token in token_list:
|
763 |
+
if token in ['<td></td>', '<td', '<td>']:
|
764 |
+
if 'bbox' not in cells[bbox_idx].keys():
|
765 |
+
content = str(cells[bbox_idx]['tokens'])
|
766 |
+
token = self.empty_bbox_token_dict[content]
|
767 |
+
add_empty_bbox_token_list.append(token)
|
768 |
+
bbox_idx += 1
|
769 |
+
else:
|
770 |
+
add_empty_bbox_token_list.append(token)
|
771 |
+
return add_empty_bbox_token_list
|
772 |
+
|
773 |
+
|
774 |
+
class TableMasterLabelEncode(TableLabelEncode):
|
775 |
+
""" Convert between text-label and text-index """
|
776 |
+
|
777 |
+
def __init__(self,
|
778 |
+
max_text_length,
|
779 |
+
character_dict_path,
|
780 |
+
replace_empty_cell_token=False,
|
781 |
+
merge_no_span_structure=False,
|
782 |
+
learn_empty_box=False,
|
783 |
+
loc_reg_num=4,
|
784 |
+
**kwargs):
|
785 |
+
super(TableMasterLabelEncode, self).__init__(
|
786 |
+
max_text_length, character_dict_path, replace_empty_cell_token,
|
787 |
+
merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
|
788 |
+
self.pad_idx = self.dict[self.pad_str]
|
789 |
+
self.unknown_idx = self.dict[self.unknown_str]
|
790 |
+
|
791 |
+
@property
|
792 |
+
def _max_text_len(self):
|
793 |
+
return self.max_text_len
|
794 |
+
|
795 |
+
def add_special_char(self, dict_character):
|
796 |
+
self.beg_str = '<SOS>'
|
797 |
+
self.end_str = '<EOS>'
|
798 |
+
self.unknown_str = '<UKN>'
|
799 |
+
self.pad_str = '<PAD>'
|
800 |
+
dict_character = dict_character
|
801 |
+
dict_character = dict_character + [
|
802 |
+
self.unknown_str, self.beg_str, self.end_str, self.pad_str
|
803 |
+
]
|
804 |
+
return dict_character
|
805 |
+
|
806 |
+
|
807 |
+
class TableBoxEncode(object):
|
808 |
+
def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
|
809 |
+
assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
|
810 |
+
self.in_box_format = in_box_format
|
811 |
+
self.out_box_format = out_box_format
|
812 |
+
|
813 |
+
def __call__(self, data):
|
814 |
+
img_height, img_width = data['image'].shape[:2]
|
815 |
+
bboxes = data['bboxes']
|
816 |
+
if self.in_box_format != self.out_box_format:
|
817 |
+
if self.out_box_format == 'xywh':
|
818 |
+
if self.in_box_format == 'xyxyxyxy':
|
819 |
+
bboxes = self.xyxyxyxy2xywh(bboxes)
|
820 |
+
elif self.in_box_format == 'xyxy':
|
821 |
+
bboxes = self.xyxy2xywh(bboxes)
|
822 |
+
|
823 |
+
bboxes[:, 0::2] /= img_width
|
824 |
+
bboxes[:, 1::2] /= img_height
|
825 |
+
data['bboxes'] = bboxes
|
826 |
+
return data
|
827 |
+
|
828 |
+
def xyxyxyxy2xywh(self, boxes):
|
829 |
+
new_bboxes = np.zeros([len(bboxes), 4])
|
830 |
+
new_bboxes[:, 0] = bboxes[:, 0::2].min() # x1
|
831 |
+
new_bboxes[:, 1] = bboxes[:, 1::2].min() # y1
|
832 |
+
new_bboxes[:, 2] = bboxes[:, 0::2].max() - new_bboxes[:, 0] # w
|
833 |
+
new_bboxes[:, 3] = bboxes[:, 1::2].max() - new_bboxes[:, 1] # h
|
834 |
+
return new_bboxes
|
835 |
+
|
836 |
+
def xyxy2xywh(self, bboxes):
|
837 |
+
new_bboxes = np.empty_like(bboxes)
|
838 |
+
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
|
839 |
+
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
|
840 |
+
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
|
841 |
+
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
|
842 |
+
return new_bboxes
|
843 |
+
|
844 |
+
|
845 |
+
class SARLabelEncode(BaseRecLabelEncode):
|
846 |
+
""" Convert between text-label and text-index """
|
847 |
+
|
848 |
+
def __init__(self,
|
849 |
+
max_text_length,
|
850 |
+
character_dict_path=None,
|
851 |
+
use_space_char=False,
|
852 |
+
**kwargs):
|
853 |
+
super(SARLabelEncode, self).__init__(
|
854 |
+
max_text_length, character_dict_path, use_space_char)
|
855 |
+
|
856 |
+
def add_special_char(self, dict_character):
|
857 |
+
beg_end_str = "<BOS/EOS>"
|
858 |
+
unknown_str = "<UKN>"
|
859 |
+
padding_str = "<PAD>"
|
860 |
+
dict_character = dict_character + [unknown_str]
|
861 |
+
self.unknown_idx = len(dict_character) - 1
|
862 |
+
dict_character = dict_character + [beg_end_str]
|
863 |
+
self.start_idx = len(dict_character) - 1
|
864 |
+
self.end_idx = len(dict_character) - 1
|
865 |
+
dict_character = dict_character + [padding_str]
|
866 |
+
self.padding_idx = len(dict_character) - 1
|
867 |
+
|
868 |
+
return dict_character
|
869 |
+
|
870 |
+
def __call__(self, data):
|
871 |
+
text = data['label']
|
872 |
+
text = self.encode(text)
|
873 |
+
if text is None:
|
874 |
+
return None
|
875 |
+
if len(text) >= self.max_text_len - 1:
|
876 |
+
return None
|
877 |
+
data['length'] = np.array(len(text))
|
878 |
+
target = [self.start_idx] + text + [self.end_idx]
|
879 |
+
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
880 |
+
|
881 |
+
padded_text[:len(target)] = target
|
882 |
+
data['label'] = np.array(padded_text)
|
883 |
+
return data
|
884 |
+
|
885 |
+
def get_ignored_tokens(self):
|
886 |
+
return [self.padding_idx]
|
887 |
+
|
888 |
+
|
889 |
+
class PRENLabelEncode(BaseRecLabelEncode):
|
890 |
+
def __init__(self,
|
891 |
+
max_text_length,
|
892 |
+
character_dict_path,
|
893 |
+
use_space_char=False,
|
894 |
+
**kwargs):
|
895 |
+
super(PRENLabelEncode, self).__init__(
|
896 |
+
max_text_length, character_dict_path, use_space_char)
|
897 |
+
|
898 |
+
def add_special_char(self, dict_character):
|
899 |
+
padding_str = '<PAD>' # 0
|
900 |
+
end_str = '<EOS>' # 1
|
901 |
+
unknown_str = '<UNK>' # 2
|
902 |
+
|
903 |
+
dict_character = [padding_str, end_str, unknown_str] + dict_character
|
904 |
+
self.padding_idx = 0
|
905 |
+
self.end_idx = 1
|
906 |
+
self.unknown_idx = 2
|
907 |
+
|
908 |
+
return dict_character
|
909 |
+
|
910 |
+
def encode(self, text):
|
911 |
+
if len(text) == 0 or len(text) >= self.max_text_len:
|
912 |
+
return None
|
913 |
+
if self.lower:
|
914 |
+
text = text.lower()
|
915 |
+
text_list = []
|
916 |
+
for char in text:
|
917 |
+
if char not in self.dict:
|
918 |
+
text_list.append(self.unknown_idx)
|
919 |
+
else:
|
920 |
+
text_list.append(self.dict[char])
|
921 |
+
text_list.append(self.end_idx)
|
922 |
+
if len(text_list) < self.max_text_len:
|
923 |
+
text_list += [self.padding_idx] * (
|
924 |
+
self.max_text_len - len(text_list))
|
925 |
+
return text_list
|
926 |
+
|
927 |
+
def __call__(self, data):
|
928 |
+
text = data['label']
|
929 |
+
encoded_text = self.encode(text)
|
930 |
+
if encoded_text is None:
|
931 |
+
return None
|
932 |
+
data['label'] = np.array(encoded_text)
|
933 |
+
return data
|
934 |
+
|
935 |
+
|
936 |
+
class VQATokenLabelEncode(object):
|
937 |
+
"""
|
938 |
+
Label encode for NLP VQA methods
|
939 |
+
"""
|
940 |
+
|
941 |
+
def __init__(self,
|
942 |
+
class_path,
|
943 |
+
contains_re=False,
|
944 |
+
add_special_ids=False,
|
945 |
+
algorithm='LayoutXLM',
|
946 |
+
use_textline_bbox_info=True,
|
947 |
+
order_method=None,
|
948 |
+
infer_mode=False,
|
949 |
+
ocr_engine=None,
|
950 |
+
**kwargs):
|
951 |
+
super(VQATokenLabelEncode, self).__init__()
|
952 |
+
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
|
953 |
+
from ppocr.utils.utility import load_vqa_bio_label_maps
|
954 |
+
tokenizer_dict = {
|
955 |
+
'LayoutXLM': {
|
956 |
+
'class': LayoutXLMTokenizer,
|
957 |
+
'pretrained_model': 'layoutxlm-base-uncased'
|
958 |
+
},
|
959 |
+
'LayoutLM': {
|
960 |
+
'class': LayoutLMTokenizer,
|
961 |
+
'pretrained_model': 'layoutlm-base-uncased'
|
962 |
+
},
|
963 |
+
'LayoutLMv2': {
|
964 |
+
'class': LayoutLMv2Tokenizer,
|
965 |
+
'pretrained_model': 'layoutlmv2-base-uncased'
|
966 |
+
}
|
967 |
+
}
|
968 |
+
self.contains_re = contains_re
|
969 |
+
tokenizer_config = tokenizer_dict[algorithm]
|
970 |
+
self.tokenizer = tokenizer_config['class'].from_pretrained(
|
971 |
+
tokenizer_config['pretrained_model'])
|
972 |
+
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
|
973 |
+
self.add_special_ids = add_special_ids
|
974 |
+
self.infer_mode = infer_mode
|
975 |
+
self.ocr_engine = ocr_engine
|
976 |
+
self.use_textline_bbox_info = use_textline_bbox_info
|
977 |
+
self.order_method = order_method
|
978 |
+
assert self.order_method in [None, "tb-yx"]
|
979 |
+
|
980 |
+
def split_bbox(self, bbox, text, tokenizer):
|
981 |
+
words = text.split()
|
982 |
+
token_bboxes = []
|
983 |
+
curr_word_idx = 0
|
984 |
+
x1, y1, x2, y2 = bbox
|
985 |
+
unit_w = (x2 - x1) / len(text)
|
986 |
+
for idx, word in enumerate(words):
|
987 |
+
curr_w = len(word) * unit_w
|
988 |
+
word_bbox = [x1, y1, x1 + curr_w, y2]
|
989 |
+
token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
|
990 |
+
x1 += (len(word) + 1) * unit_w
|
991 |
+
return token_bboxes
|
992 |
+
|
993 |
+
def filter_empty_contents(self, ocr_info):
|
994 |
+
"""
|
995 |
+
find out the empty texts and remove the links
|
996 |
+
"""
|
997 |
+
new_ocr_info = []
|
998 |
+
empty_index = []
|
999 |
+
for idx, info in enumerate(ocr_info):
|
1000 |
+
if len(info["transcription"]) > 0:
|
1001 |
+
new_ocr_info.append(copy.deepcopy(info))
|
1002 |
+
else:
|
1003 |
+
empty_index.append(info["id"])
|
1004 |
+
|
1005 |
+
for idx, info in enumerate(new_ocr_info):
|
1006 |
+
new_link = []
|
1007 |
+
for link in info["linking"]:
|
1008 |
+
if link[0] in empty_index or link[1] in empty_index:
|
1009 |
+
continue
|
1010 |
+
new_link.append(link)
|
1011 |
+
new_ocr_info[idx]["linking"] = new_link
|
1012 |
+
return new_ocr_info
|
1013 |
+
|
1014 |
+
def __call__(self, data):
|
1015 |
+
# load bbox and label info
|
1016 |
+
ocr_info = self._load_ocr_info(data)
|
1017 |
+
|
1018 |
+
for idx in range(len(ocr_info)):
|
1019 |
+
if "bbox" not in ocr_info[idx]:
|
1020 |
+
ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
|
1021 |
+
"points"])
|
1022 |
+
|
1023 |
+
if self.order_method == "tb-yx":
|
1024 |
+
ocr_info = order_by_tbyx(ocr_info)
|
1025 |
+
|
1026 |
+
# for re
|
1027 |
+
train_re = self.contains_re and not self.infer_mode
|
1028 |
+
if train_re:
|
1029 |
+
ocr_info = self.filter_empty_contents(ocr_info)
|
1030 |
+
|
1031 |
+
height, width, _ = data['image'].shape
|
1032 |
+
|
1033 |
+
words_list = []
|
1034 |
+
bbox_list = []
|
1035 |
+
input_ids_list = []
|
1036 |
+
token_type_ids_list = []
|
1037 |
+
segment_offset_id = []
|
1038 |
+
gt_label_list = []
|
1039 |
+
|
1040 |
+
entities = []
|
1041 |
+
|
1042 |
+
if train_re:
|
1043 |
+
relations = []
|
1044 |
+
id2label = {}
|
1045 |
+
entity_id_to_index_map = {}
|
1046 |
+
empty_entity = set()
|
1047 |
+
|
1048 |
+
data['ocr_info'] = copy.deepcopy(ocr_info)
|
1049 |
+
|
1050 |
+
for info in ocr_info:
|
1051 |
+
text = info["transcription"]
|
1052 |
+
if len(text) <= 0:
|
1053 |
+
continue
|
1054 |
+
if train_re:
|
1055 |
+
# for re
|
1056 |
+
if len(text) == 0:
|
1057 |
+
empty_entity.add(info["id"])
|
1058 |
+
continue
|
1059 |
+
id2label[info["id"]] = info["label"]
|
1060 |
+
relations.extend([tuple(sorted(l)) for l in info["linking"]])
|
1061 |
+
# smooth_box
|
1062 |
+
info["bbox"] = self.trans_poly_to_bbox(info["points"])
|
1063 |
+
|
1064 |
+
encode_res = self.tokenizer.encode(
|
1065 |
+
text,
|
1066 |
+
pad_to_max_seq_len=False,
|
1067 |
+
return_attention_mask=True,
|
1068 |
+
return_token_type_ids=True)
|
1069 |
+
|
1070 |
+
if not self.add_special_ids:
|
1071 |
+
# TODO: use tok.all_special_ids to remove
|
1072 |
+
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
|
1073 |
+
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
|
1074 |
+
-1]
|
1075 |
+
encode_res["attention_mask"] = encode_res["attention_mask"][1:
|
1076 |
+
-1]
|
1077 |
+
|
1078 |
+
if self.use_textline_bbox_info:
|
1079 |
+
bbox = [info["bbox"]] * len(encode_res["input_ids"])
|
1080 |
+
else:
|
1081 |
+
bbox = self.split_bbox(info["bbox"], info["transcription"],
|
1082 |
+
self.tokenizer)
|
1083 |
+
if len(bbox) <= 0:
|
1084 |
+
continue
|
1085 |
+
bbox = self._smooth_box(bbox, height, width)
|
1086 |
+
if self.add_special_ids:
|
1087 |
+
bbox.insert(0, [0, 0, 0, 0])
|
1088 |
+
bbox.append([0, 0, 0, 0])
|
1089 |
+
|
1090 |
+
# parse label
|
1091 |
+
if not self.infer_mode:
|
1092 |
+
label = info['label']
|
1093 |
+
gt_label = self._parse_label(label, encode_res)
|
1094 |
+
|
1095 |
+
# construct entities for re
|
1096 |
+
if train_re:
|
1097 |
+
if gt_label[0] != self.label2id_map["O"]:
|
1098 |
+
entity_id_to_index_map[info["id"]] = len(entities)
|
1099 |
+
label = label.upper()
|
1100 |
+
entities.append({
|
1101 |
+
"start": len(input_ids_list),
|
1102 |
+
"end":
|
1103 |
+
len(input_ids_list) + len(encode_res["input_ids"]),
|
1104 |
+
"label": label.upper(),
|
1105 |
+
})
|
1106 |
+
else:
|
1107 |
+
entities.append({
|
1108 |
+
"start": len(input_ids_list),
|
1109 |
+
"end": len(input_ids_list) + len(encode_res["input_ids"]),
|
1110 |
+
"label": 'O',
|
1111 |
+
})
|
1112 |
+
input_ids_list.extend(encode_res["input_ids"])
|
1113 |
+
token_type_ids_list.extend(encode_res["token_type_ids"])
|
1114 |
+
bbox_list.extend(bbox)
|
1115 |
+
words_list.append(text)
|
1116 |
+
segment_offset_id.append(len(input_ids_list))
|
1117 |
+
if not self.infer_mode:
|
1118 |
+
gt_label_list.extend(gt_label)
|
1119 |
+
|
1120 |
+
data['input_ids'] = input_ids_list
|
1121 |
+
data['token_type_ids'] = token_type_ids_list
|
1122 |
+
data['bbox'] = bbox_list
|
1123 |
+
data['attention_mask'] = [1] * len(input_ids_list)
|
1124 |
+
data['labels'] = gt_label_list
|
1125 |
+
data['segment_offset_id'] = segment_offset_id
|
1126 |
+
data['tokenizer_params'] = dict(
|
1127 |
+
padding_side=self.tokenizer.padding_side,
|
1128 |
+
pad_token_type_id=self.tokenizer.pad_token_type_id,
|
1129 |
+
pad_token_id=self.tokenizer.pad_token_id)
|
1130 |
+
data['entities'] = entities
|
1131 |
+
|
1132 |
+
if train_re:
|
1133 |
+
data['relations'] = relations
|
1134 |
+
data['id2label'] = id2label
|
1135 |
+
data['empty_entity'] = empty_entity
|
1136 |
+
data['entity_id_to_index_map'] = entity_id_to_index_map
|
1137 |
+
return data
|
1138 |
+
|
1139 |
+
def trans_poly_to_bbox(self, poly):
|
1140 |
+
x1 = int(np.min([p[0] for p in poly]))
|
1141 |
+
x2 = int(np.max([p[0] for p in poly]))
|
1142 |
+
y1 = int(np.min([p[1] for p in poly]))
|
1143 |
+
y2 = int(np.max([p[1] for p in poly]))
|
1144 |
+
return [x1, y1, x2, y2]
|
1145 |
+
|
1146 |
+
def _load_ocr_info(self, data):
|
1147 |
+
if self.infer_mode:
|
1148 |
+
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
|
1149 |
+
ocr_info = []
|
1150 |
+
for res in ocr_result:
|
1151 |
+
ocr_info.append({
|
1152 |
+
"transcription": res[1][0],
|
1153 |
+
"bbox": self.trans_poly_to_bbox(res[0]),
|
1154 |
+
"points": res[0],
|
1155 |
+
})
|
1156 |
+
return ocr_info
|
1157 |
+
else:
|
1158 |
+
info = data['label']
|
1159 |
+
# read text info
|
1160 |
+
info_dict = json.loads(info)
|
1161 |
+
return info_dict
|
1162 |
+
|
1163 |
+
def _smooth_box(self, bboxes, height, width):
|
1164 |
+
bboxes = np.array(bboxes)
|
1165 |
+
bboxes[:, 0] = bboxes[:, 0] * 1000 / width
|
1166 |
+
bboxes[:, 2] = bboxes[:, 2] * 1000 / width
|
1167 |
+
bboxes[:, 1] = bboxes[:, 1] * 1000 / height
|
1168 |
+
bboxes[:, 3] = bboxes[:, 3] * 1000 / height
|
1169 |
+
bboxes = bboxes.astype("int64").tolist()
|
1170 |
+
return bboxes
|
1171 |
+
|
1172 |
+
def _parse_label(self, label, encode_res):
|
1173 |
+
gt_label = []
|
1174 |
+
if label.lower() in ["other", "others", "ignore"]:
|
1175 |
+
gt_label.extend([0] * len(encode_res["input_ids"]))
|
1176 |
+
else:
|
1177 |
+
gt_label.append(self.label2id_map[("b-" + label).upper()])
|
1178 |
+
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
|
1179 |
+
(len(encode_res["input_ids"]) - 1))
|
1180 |
+
return gt_label
|
1181 |
+
|
1182 |
+
|
1183 |
+
class MultiLabelEncode(BaseRecLabelEncode):
|
1184 |
+
def __init__(self,
|
1185 |
+
max_text_length,
|
1186 |
+
character_dict_path=None,
|
1187 |
+
use_space_char=False,
|
1188 |
+
**kwargs):
|
1189 |
+
super(MultiLabelEncode, self).__init__(
|
1190 |
+
max_text_length, character_dict_path, use_space_char)
|
1191 |
+
|
1192 |
+
self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
|
1193 |
+
use_space_char, **kwargs)
|
1194 |
+
self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
|
1195 |
+
use_space_char, **kwargs)
|
1196 |
+
|
1197 |
+
def __call__(self, data):
|
1198 |
+
data_ctc = copy.deepcopy(data)
|
1199 |
+
data_sar = copy.deepcopy(data)
|
1200 |
+
data_out = dict()
|
1201 |
+
data_out['img_path'] = data.get('img_path', None)
|
1202 |
+
data_out['image'] = data['image']
|
1203 |
+
ctc = self.ctc_encode.__call__(data_ctc)
|
1204 |
+
sar = self.sar_encode.__call__(data_sar)
|
1205 |
+
if ctc is None or sar is None:
|
1206 |
+
return None
|
1207 |
+
data_out['label_ctc'] = ctc['label']
|
1208 |
+
data_out['label_sar'] = sar['label']
|
1209 |
+
data_out['length'] = ctc['length']
|
1210 |
+
return data_out
|
1211 |
+
|
1212 |
+
|
1213 |
+
class NRTRLabelEncode(BaseRecLabelEncode):
|
1214 |
+
""" Convert between text-label and text-index """
|
1215 |
+
|
1216 |
+
def __init__(self,
|
1217 |
+
max_text_length,
|
1218 |
+
character_dict_path=None,
|
1219 |
+
use_space_char=False,
|
1220 |
+
**kwargs):
|
1221 |
+
|
1222 |
+
super(NRTRLabelEncode, self).__init__(
|
1223 |
+
max_text_length, character_dict_path, use_space_char)
|
1224 |
+
|
1225 |
+
def __call__(self, data):
|
1226 |
+
text = data['label']
|
1227 |
+
text = self.encode(text)
|
1228 |
+
if text is None:
|
1229 |
+
return None
|
1230 |
+
if len(text) >= self.max_text_len - 1:
|
1231 |
+
return None
|
1232 |
+
data['length'] = np.array(len(text))
|
1233 |
+
text.insert(0, 2)
|
1234 |
+
text.append(3)
|
1235 |
+
text = text + [0] * (self.max_text_len - len(text))
|
1236 |
+
data['label'] = np.array(text)
|
1237 |
+
return data
|
1238 |
+
|
1239 |
+
def add_special_char(self, dict_character):
|
1240 |
+
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
1241 |
+
return dict_character
|
1242 |
+
|
1243 |
+
|
1244 |
+
class ViTSTRLabelEncode(BaseRecLabelEncode):
|
1245 |
+
""" Convert between text-label and text-index """
|
1246 |
+
|
1247 |
+
def __init__(self,
|
1248 |
+
max_text_length,
|
1249 |
+
character_dict_path=None,
|
1250 |
+
use_space_char=False,
|
1251 |
+
ignore_index=0,
|
1252 |
+
**kwargs):
|
1253 |
+
|
1254 |
+
super(ViTSTRLabelEncode, self).__init__(
|
1255 |
+
max_text_length, character_dict_path, use_space_char)
|
1256 |
+
self.ignore_index = ignore_index
|
1257 |
+
|
1258 |
+
def __call__(self, data):
|
1259 |
+
text = data['label']
|
1260 |
+
text = self.encode(text)
|
1261 |
+
if text is None:
|
1262 |
+
return None
|
1263 |
+
if len(text) >= self.max_text_len:
|
1264 |
+
return None
|
1265 |
+
data['length'] = np.array(len(text))
|
1266 |
+
text.insert(0, self.ignore_index)
|
1267 |
+
text.append(1)
|
1268 |
+
text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
|
1269 |
+
data['label'] = np.array(text)
|
1270 |
+
return data
|
1271 |
+
|
1272 |
+
def add_special_char(self, dict_character):
|
1273 |
+
dict_character = ['<s>', '</s>'] + dict_character
|
1274 |
+
return dict_character
|
1275 |
+
|
1276 |
+
|
1277 |
+
class ABINetLabelEncode(BaseRecLabelEncode):
|
1278 |
+
""" Convert between text-label and text-index """
|
1279 |
+
|
1280 |
+
def __init__(self,
|
1281 |
+
max_text_length,
|
1282 |
+
character_dict_path=None,
|
1283 |
+
use_space_char=False,
|
1284 |
+
ignore_index=100,
|
1285 |
+
**kwargs):
|
1286 |
+
|
1287 |
+
super(ABINetLabelEncode, self).__init__(
|
1288 |
+
max_text_length, character_dict_path, use_space_char)
|
1289 |
+
self.ignore_index = ignore_index
|
1290 |
+
|
1291 |
+
def __call__(self, data):
|
1292 |
+
text = data['label']
|
1293 |
+
text = self.encode(text)
|
1294 |
+
if text is None:
|
1295 |
+
return None
|
1296 |
+
if len(text) >= self.max_text_len:
|
1297 |
+
return None
|
1298 |
+
data['length'] = np.array(len(text))
|
1299 |
+
text.append(0)
|
1300 |
+
text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
|
1301 |
+
data['label'] = np.array(text)
|
1302 |
+
return data
|
1303 |
+
|
1304 |
+
def add_special_char(self, dict_character):
|
1305 |
+
dict_character = ['</s>'] + dict_character
|
1306 |
+
return dict_character
|
1307 |
+
|
1308 |
+
|
1309 |
+
class SRLabelEncode(BaseRecLabelEncode):
|
1310 |
+
def __init__(self,
|
1311 |
+
max_text_length,
|
1312 |
+
character_dict_path=None,
|
1313 |
+
use_space_char=False,
|
1314 |
+
**kwargs):
|
1315 |
+
super(SRLabelEncode, self).__init__(max_text_length,
|
1316 |
+
character_dict_path, use_space_char)
|
1317 |
+
self.dic = {}
|
1318 |
+
with open(character_dict_path, 'r') as fin:
|
1319 |
+
for line in fin.readlines():
|
1320 |
+
line = line.strip()
|
1321 |
+
character, sequence = line.split()
|
1322 |
+
self.dic[character] = sequence
|
1323 |
+
english_stroke_alphabet = '0123456789'
|
1324 |
+
self.english_stroke_dict = {}
|
1325 |
+
for index in range(len(english_stroke_alphabet)):
|
1326 |
+
self.english_stroke_dict[english_stroke_alphabet[index]] = index
|
1327 |
+
|
1328 |
+
def encode(self, label):
|
1329 |
+
stroke_sequence = ''
|
1330 |
+
for character in label:
|
1331 |
+
if character not in self.dic:
|
1332 |
+
continue
|
1333 |
+
else:
|
1334 |
+
stroke_sequence += self.dic[character]
|
1335 |
+
stroke_sequence += '0'
|
1336 |
+
label = stroke_sequence
|
1337 |
+
|
1338 |
+
length = len(label)
|
1339 |
+
|
1340 |
+
input_tensor = np.zeros(self.max_text_len).astype("int64")
|
1341 |
+
for j in range(length - 1):
|
1342 |
+
input_tensor[j + 1] = self.english_stroke_dict[label[j]]
|
1343 |
+
|
1344 |
+
return length, input_tensor
|
1345 |
+
|
1346 |
+
def __call__(self, data):
|
1347 |
+
text = data['label']
|
1348 |
+
length, input_tensor = self.encode(text)
|
1349 |
+
|
1350 |
+
data["length"] = length
|
1351 |
+
data["input_tensor"] = input_tensor
|
1352 |
+
if text is None:
|
1353 |
+
return None
|
1354 |
+
return data
|
1355 |
+
|
1356 |
+
|
1357 |
+
class SPINLabelEncode(AttnLabelEncode):
|
1358 |
+
""" Convert between text-label and text-index """
|
1359 |
+
|
1360 |
+
def __init__(self,
|
1361 |
+
max_text_length,
|
1362 |
+
character_dict_path=None,
|
1363 |
+
use_space_char=False,
|
1364 |
+
lower=True,
|
1365 |
+
**kwargs):
|
1366 |
+
super(SPINLabelEncode, self).__init__(
|
1367 |
+
max_text_length, character_dict_path, use_space_char)
|
1368 |
+
self.lower = lower
|
1369 |
+
|
1370 |
+
def add_special_char(self, dict_character):
|
1371 |
+
self.beg_str = "sos"
|
1372 |
+
self.end_str = "eos"
|
1373 |
+
dict_character = [self.beg_str] + [self.end_str] + dict_character
|
1374 |
+
return dict_character
|
1375 |
+
|
1376 |
+
def __call__(self, data):
|
1377 |
+
text = data['label']
|
1378 |
+
text = self.encode(text)
|
1379 |
+
if text is None:
|
1380 |
+
return None
|
1381 |
+
if len(text) > self.max_text_len:
|
1382 |
+
return None
|
1383 |
+
data['length'] = np.array(len(text))
|
1384 |
+
target = [0] + text + [1]
|
1385 |
+
padded_text = [0 for _ in range(self.max_text_len + 2)]
|
1386 |
+
|
1387 |
+
padded_text[:len(target)] = target
|
1388 |
+
data['label'] = np.array(padded_text)
|
1389 |
+
return data
|
1390 |
+
|
1391 |
+
|
1392 |
+
class VLLabelEncode(BaseRecLabelEncode):
|
1393 |
+
""" Convert between text-label and text-index """
|
1394 |
+
|
1395 |
+
def __init__(self,
|
1396 |
+
max_text_length,
|
1397 |
+
character_dict_path=None,
|
1398 |
+
use_space_char=False,
|
1399 |
+
**kwargs):
|
1400 |
+
super(VLLabelEncode, self).__init__(max_text_length,
|
1401 |
+
character_dict_path, use_space_char)
|
1402 |
+
self.dict = {}
|
1403 |
+
for i, char in enumerate(self.character):
|
1404 |
+
self.dict[char] = i
|
1405 |
+
|
1406 |
+
def __call__(self, data):
|
1407 |
+
text = data['label'] # original string
|
1408 |
+
# generate occluded text
|
1409 |
+
len_str = len(text)
|
1410 |
+
if len_str <= 0:
|
1411 |
+
return None
|
1412 |
+
change_num = 1
|
1413 |
+
order = list(range(len_str))
|
1414 |
+
change_id = sample(order, change_num)[0]
|
1415 |
+
label_sub = text[change_id]
|
1416 |
+
if change_id == (len_str - 1):
|
1417 |
+
label_res = text[:change_id]
|
1418 |
+
elif change_id == 0:
|
1419 |
+
label_res = text[1:]
|
1420 |
+
else:
|
1421 |
+
label_res = text[:change_id] + text[change_id + 1:]
|
1422 |
+
|
1423 |
+
data['label_res'] = label_res # remaining string
|
1424 |
+
data['label_sub'] = label_sub # occluded character
|
1425 |
+
data['label_id'] = change_id # character index
|
1426 |
+
# encode label
|
1427 |
+
text = self.encode(text)
|
1428 |
+
if text is None:
|
1429 |
+
return None
|
1430 |
+
text = [i + 1 for i in text]
|
1431 |
+
data['length'] = np.array(len(text))
|
1432 |
+
text = text + [0] * (self.max_text_len - len(text))
|
1433 |
+
data['label'] = np.array(text)
|
1434 |
+
label_res = self.encode(label_res)
|
1435 |
+
label_sub = self.encode(label_sub)
|
1436 |
+
if label_res is None:
|
1437 |
+
label_res = []
|
1438 |
+
else:
|
1439 |
+
label_res = [i + 1 for i in label_res]
|
1440 |
+
if label_sub is None:
|
1441 |
+
label_sub = []
|
1442 |
+
else:
|
1443 |
+
label_sub = [i + 1 for i in label_sub]
|
1444 |
+
data['length_res'] = np.array(len(label_res))
|
1445 |
+
data['length_sub'] = np.array(len(label_sub))
|
1446 |
+
label_res = label_res + [0] * (self.max_text_len - len(label_res))
|
1447 |
+
label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
|
1448 |
+
data['label_res'] = np.array(label_res)
|
1449 |
+
data['label_sub'] = np.array(label_sub)
|
1450 |
+
return data
|
1451 |
+
|
1452 |
+
|
1453 |
+
class CTLabelEncode(object):
|
1454 |
+
def __init__(self, **kwargs):
|
1455 |
+
pass
|
1456 |
+
|
1457 |
+
def __call__(self, data):
|
1458 |
+
label = data['label']
|
1459 |
+
|
1460 |
+
label = json.loads(label)
|
1461 |
+
nBox = len(label)
|
1462 |
+
boxes, txts = [], []
|
1463 |
+
for bno in range(0, nBox):
|
1464 |
+
box = label[bno]['points']
|
1465 |
+
box = np.array(box)
|
1466 |
+
|
1467 |
+
boxes.append(box)
|
1468 |
+
txt = label[bno]['transcription']
|
1469 |
+
txts.append(txt)
|
1470 |
+
|
1471 |
+
if len(boxes) == 0:
|
1472 |
+
return None
|
1473 |
+
|
1474 |
+
data['polys'] = boxes
|
1475 |
+
data['texts'] = txts
|
1476 |
+
return data
|
1477 |
+
|
1478 |
+
|
1479 |
+
class CANLabelEncode(BaseRecLabelEncode):
|
1480 |
+
def __init__(self,
|
1481 |
+
character_dict_path,
|
1482 |
+
max_text_length=100,
|
1483 |
+
use_space_char=False,
|
1484 |
+
lower=True,
|
1485 |
+
**kwargs):
|
1486 |
+
super(CANLabelEncode, self).__init__(
|
1487 |
+
max_text_length, character_dict_path, use_space_char, lower)
|
1488 |
+
|
1489 |
+
def encode(self, text_seq):
|
1490 |
+
text_seq_encoded = []
|
1491 |
+
for text in text_seq:
|
1492 |
+
if text not in self.character:
|
1493 |
+
continue
|
1494 |
+
text_seq_encoded.append(self.dict.get(text))
|
1495 |
+
if len(text_seq_encoded) == 0:
|
1496 |
+
return None
|
1497 |
+
return text_seq_encoded
|
1498 |
+
|
1499 |
+
def __call__(self, data):
|
1500 |
+
label = data['label']
|
1501 |
+
if isinstance(label, str):
|
1502 |
+
label = label.strip().split()
|
1503 |
+
label.append(self.end_str)
|
1504 |
+
data['label'] = self.encode(label)
|
1505 |
+
return data
|
ppocr/data/imaug/make_border_map.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
from __future__ import print_function
|
22 |
+
from __future__ import unicode_literals
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import cv2
|
26 |
+
|
27 |
+
np.seterr(divide='ignore', invalid='ignore')
|
28 |
+
import pyclipper
|
29 |
+
from shapely.geometry import Polygon
|
30 |
+
import sys
|
31 |
+
import warnings
|
32 |
+
|
33 |
+
warnings.simplefilter("ignore")
|
34 |
+
|
35 |
+
__all__ = ['MakeBorderMap']
|
36 |
+
|
37 |
+
|
38 |
+
class MakeBorderMap(object):
|
39 |
+
def __init__(self,
|
40 |
+
shrink_ratio=0.4,
|
41 |
+
thresh_min=0.3,
|
42 |
+
thresh_max=0.7,
|
43 |
+
**kwargs):
|
44 |
+
self.shrink_ratio = shrink_ratio
|
45 |
+
self.thresh_min = thresh_min
|
46 |
+
self.thresh_max = thresh_max
|
47 |
+
|
48 |
+
def __call__(self, data):
|
49 |
+
|
50 |
+
img = data['image']
|
51 |
+
text_polys = data['polys']
|
52 |
+
ignore_tags = data['ignore_tags']
|
53 |
+
|
54 |
+
canvas = np.zeros(img.shape[:2], dtype=np.float32)
|
55 |
+
mask = np.zeros(img.shape[:2], dtype=np.float32)
|
56 |
+
|
57 |
+
for i in range(len(text_polys)):
|
58 |
+
if ignore_tags[i]:
|
59 |
+
continue
|
60 |
+
self.draw_border_map(text_polys[i], canvas, mask=mask)
|
61 |
+
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
|
62 |
+
|
63 |
+
data['threshold_map'] = canvas
|
64 |
+
data['threshold_mask'] = mask
|
65 |
+
return data
|
66 |
+
|
67 |
+
def draw_border_map(self, polygon, canvas, mask):
|
68 |
+
polygon = np.array(polygon)
|
69 |
+
assert polygon.ndim == 2
|
70 |
+
assert polygon.shape[1] == 2
|
71 |
+
|
72 |
+
polygon_shape = Polygon(polygon)
|
73 |
+
if polygon_shape.area <= 0:
|
74 |
+
return
|
75 |
+
distance = polygon_shape.area * (
|
76 |
+
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
|
77 |
+
subject = [tuple(l) for l in polygon]
|
78 |
+
padding = pyclipper.PyclipperOffset()
|
79 |
+
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
80 |
+
|
81 |
+
padded_polygon = np.array(padding.Execute(distance)[0])
|
82 |
+
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
|
83 |
+
|
84 |
+
xmin = padded_polygon[:, 0].min()
|
85 |
+
xmax = padded_polygon[:, 0].max()
|
86 |
+
ymin = padded_polygon[:, 1].min()
|
87 |
+
ymax = padded_polygon[:, 1].max()
|
88 |
+
width = xmax - xmin + 1
|
89 |
+
height = ymax - ymin + 1
|
90 |
+
|
91 |
+
polygon[:, 0] = polygon[:, 0] - xmin
|
92 |
+
polygon[:, 1] = polygon[:, 1] - ymin
|
93 |
+
|
94 |
+
xs = np.broadcast_to(
|
95 |
+
np.linspace(
|
96 |
+
0, width - 1, num=width).reshape(1, width), (height, width))
|
97 |
+
ys = np.broadcast_to(
|
98 |
+
np.linspace(
|
99 |
+
0, height - 1, num=height).reshape(height, 1), (height, width))
|
100 |
+
|
101 |
+
distance_map = np.zeros(
|
102 |
+
(polygon.shape[0], height, width), dtype=np.float32)
|
103 |
+
for i in range(polygon.shape[0]):
|
104 |
+
j = (i + 1) % polygon.shape[0]
|
105 |
+
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
|
106 |
+
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
|
107 |
+
distance_map = distance_map.min(axis=0)
|
108 |
+
|
109 |
+
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
|
110 |
+
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
|
111 |
+
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
|
112 |
+
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
|
113 |
+
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
|
114 |
+
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
|
115 |
+
xmin_valid - xmin:xmax_valid - xmax + width],
|
116 |
+
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
|
117 |
+
|
118 |
+
def _distance(self, xs, ys, point_1, point_2):
|
119 |
+
'''
|
120 |
+
compute the distance from point to a line
|
121 |
+
ys: coordinates in the first axis
|
122 |
+
xs: coordinates in the second axis
|
123 |
+
point_1, point_2: (x, y), the end of the line
|
124 |
+
'''
|
125 |
+
height, width = xs.shape[:2]
|
126 |
+
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
|
127 |
+
1])
|
128 |
+
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
|
129 |
+
1])
|
130 |
+
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
|
131 |
+
point_1[1] - point_2[1])
|
132 |
+
|
133 |
+
cosin = (square_distance - square_distance_1 - square_distance_2) / (
|
134 |
+
2 * np.sqrt(square_distance_1 * square_distance_2))
|
135 |
+
square_sin = 1 - np.square(cosin)
|
136 |
+
square_sin = np.nan_to_num(square_sin)
|
137 |
+
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
|
138 |
+
square_distance)
|
139 |
+
|
140 |
+
result[cosin <
|
141 |
+
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
|
142 |
+
< 0]
|
143 |
+
# self.extend_line(point_1, point_2, result)
|
144 |
+
return result
|
145 |
+
|
146 |
+
def extend_line(self, point_1, point_2, result, shrink_ratio):
|
147 |
+
ex_point_1 = (int(
|
148 |
+
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
|
149 |
+
int(
|
150 |
+
round(point_1[1] + (point_1[1] - point_2[1]) * (
|
151 |
+
1 + shrink_ratio))))
|
152 |
+
cv2.line(
|
153 |
+
result,
|
154 |
+
tuple(ex_point_1),
|
155 |
+
tuple(point_1),
|
156 |
+
4096.0,
|
157 |
+
1,
|
158 |
+
lineType=cv2.LINE_AA,
|
159 |
+
shift=0)
|
160 |
+
ex_point_2 = (int(
|
161 |
+
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
|
162 |
+
int(
|
163 |
+
round(point_2[1] + (point_2[1] - point_1[1]) * (
|
164 |
+
1 + shrink_ratio))))
|
165 |
+
cv2.line(
|
166 |
+
result,
|
167 |
+
tuple(ex_point_2),
|
168 |
+
tuple(point_2),
|
169 |
+
4096.0,
|
170 |
+
1,
|
171 |
+
lineType=cv2.LINE_AA,
|
172 |
+
shift=0)
|
173 |
+
return ex_point_1, ex_point_2
|
ppocr/data/imaug/make_pse_gt.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
from __future__ import unicode_literals
|
19 |
+
|
20 |
+
import cv2
|
21 |
+
import numpy as np
|
22 |
+
import pyclipper
|
23 |
+
from shapely.geometry import Polygon
|
24 |
+
|
25 |
+
__all__ = ['MakePseGt']
|
26 |
+
|
27 |
+
|
28 |
+
class MakePseGt(object):
|
29 |
+
def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
|
30 |
+
self.kernel_num = kernel_num
|
31 |
+
self.min_shrink_ratio = min_shrink_ratio
|
32 |
+
self.size = size
|
33 |
+
|
34 |
+
def __call__(self, data):
|
35 |
+
|
36 |
+
image = data['image']
|
37 |
+
text_polys = data['polys']
|
38 |
+
ignore_tags = data['ignore_tags']
|
39 |
+
|
40 |
+
h, w, _ = image.shape
|
41 |
+
short_edge = min(h, w)
|
42 |
+
if short_edge < self.size:
|
43 |
+
# keep short_size >= self.size
|
44 |
+
scale = self.size / short_edge
|
45 |
+
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
|
46 |
+
text_polys *= scale
|
47 |
+
|
48 |
+
gt_kernels = []
|
49 |
+
for i in range(1, self.kernel_num + 1):
|
50 |
+
# s1->sn, from big to small
|
51 |
+
rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
|
52 |
+
) * i
|
53 |
+
text_kernel, ignore_tags = self.generate_kernel(
|
54 |
+
image.shape[0:2], rate, text_polys, ignore_tags)
|
55 |
+
gt_kernels.append(text_kernel)
|
56 |
+
|
57 |
+
training_mask = np.ones(image.shape[0:2], dtype='uint8')
|
58 |
+
for i in range(text_polys.shape[0]):
|
59 |
+
if ignore_tags[i]:
|
60 |
+
cv2.fillPoly(training_mask,
|
61 |
+
text_polys[i].astype(np.int32)[np.newaxis, :, :],
|
62 |
+
0)
|
63 |
+
|
64 |
+
gt_kernels = np.array(gt_kernels)
|
65 |
+
gt_kernels[gt_kernels > 0] = 1
|
66 |
+
|
67 |
+
data['image'] = image
|
68 |
+
data['polys'] = text_polys
|
69 |
+
data['gt_kernels'] = gt_kernels[0:]
|
70 |
+
data['gt_text'] = gt_kernels[0]
|
71 |
+
data['mask'] = training_mask.astype('float32')
|
72 |
+
return data
|
73 |
+
|
74 |
+
def generate_kernel(self,
|
75 |
+
img_size,
|
76 |
+
shrink_ratio,
|
77 |
+
text_polys,
|
78 |
+
ignore_tags=None):
|
79 |
+
"""
|
80 |
+
Refer to part of the code:
|
81 |
+
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
|
82 |
+
"""
|
83 |
+
|
84 |
+
h, w = img_size
|
85 |
+
text_kernel = np.zeros((h, w), dtype=np.float32)
|
86 |
+
for i, poly in enumerate(text_polys):
|
87 |
+
polygon = Polygon(poly)
|
88 |
+
distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
|
89 |
+
polygon.length + 1e-6)
|
90 |
+
subject = [tuple(l) for l in poly]
|
91 |
+
pco = pyclipper.PyclipperOffset()
|
92 |
+
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
93 |
+
shrinked = np.array(pco.Execute(-distance))
|
94 |
+
|
95 |
+
if len(shrinked) == 0 or shrinked.size == 0:
|
96 |
+
if ignore_tags is not None:
|
97 |
+
ignore_tags[i] = True
|
98 |
+
continue
|
99 |
+
try:
|
100 |
+
shrinked = np.array(shrinked[0]).reshape(-1, 2)
|
101 |
+
except:
|
102 |
+
if ignore_tags is not None:
|
103 |
+
ignore_tags[i] = True
|
104 |
+
continue
|
105 |
+
cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
|
106 |
+
return text_kernel, ignore_tags
|
ppocr/data/imaug/make_shrink_map.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
from __future__ import print_function
|
22 |
+
from __future__ import unicode_literals
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import cv2
|
26 |
+
from shapely.geometry import Polygon
|
27 |
+
import pyclipper
|
28 |
+
|
29 |
+
__all__ = ['MakeShrinkMap']
|
30 |
+
|
31 |
+
|
32 |
+
class MakeShrinkMap(object):
|
33 |
+
r'''
|
34 |
+
Making binary mask from detection data with ICDAR format.
|
35 |
+
Typically following the process of class `MakeICDARData`.
|
36 |
+
'''
|
37 |
+
|
38 |
+
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
|
39 |
+
self.min_text_size = min_text_size
|
40 |
+
self.shrink_ratio = shrink_ratio
|
41 |
+
|
42 |
+
def __call__(self, data):
|
43 |
+
image = data['image']
|
44 |
+
text_polys = data['polys']
|
45 |
+
ignore_tags = data['ignore_tags']
|
46 |
+
|
47 |
+
h, w = image.shape[:2]
|
48 |
+
text_polys, ignore_tags = self.validate_polygons(text_polys,
|
49 |
+
ignore_tags, h, w)
|
50 |
+
gt = np.zeros((h, w), dtype=np.float32)
|
51 |
+
mask = np.ones((h, w), dtype=np.float32)
|
52 |
+
for i in range(len(text_polys)):
|
53 |
+
polygon = text_polys[i]
|
54 |
+
height = max(polygon[:, 1]) - min(polygon[:, 1])
|
55 |
+
width = max(polygon[:, 0]) - min(polygon[:, 0])
|
56 |
+
if ignore_tags[i] or min(height, width) < self.min_text_size:
|
57 |
+
cv2.fillPoly(mask,
|
58 |
+
polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
59 |
+
ignore_tags[i] = True
|
60 |
+
else:
|
61 |
+
polygon_shape = Polygon(polygon)
|
62 |
+
subject = [tuple(l) for l in polygon]
|
63 |
+
padding = pyclipper.PyclipperOffset()
|
64 |
+
padding.AddPath(subject, pyclipper.JT_ROUND,
|
65 |
+
pyclipper.ET_CLOSEDPOLYGON)
|
66 |
+
shrinked = []
|
67 |
+
|
68 |
+
# Increase the shrink ratio every time we get multiple polygon returned back
|
69 |
+
possible_ratios = np.arange(self.shrink_ratio, 1,
|
70 |
+
self.shrink_ratio)
|
71 |
+
np.append(possible_ratios, 1)
|
72 |
+
# print(possible_ratios)
|
73 |
+
for ratio in possible_ratios:
|
74 |
+
# print(f"Change shrink ratio to {ratio}")
|
75 |
+
distance = polygon_shape.area * (
|
76 |
+
1 - np.power(ratio, 2)) / polygon_shape.length
|
77 |
+
shrinked = padding.Execute(-distance)
|
78 |
+
if len(shrinked) == 1:
|
79 |
+
break
|
80 |
+
|
81 |
+
if shrinked == []:
|
82 |
+
cv2.fillPoly(mask,
|
83 |
+
polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
84 |
+
ignore_tags[i] = True
|
85 |
+
continue
|
86 |
+
|
87 |
+
for each_shirnk in shrinked:
|
88 |
+
shirnk = np.array(each_shirnk).reshape(-1, 2)
|
89 |
+
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
|
90 |
+
|
91 |
+
data['shrink_map'] = gt
|
92 |
+
data['shrink_mask'] = mask
|
93 |
+
return data
|
94 |
+
|
95 |
+
def validate_polygons(self, polygons, ignore_tags, h, w):
|
96 |
+
'''
|
97 |
+
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
|
98 |
+
'''
|
99 |
+
if len(polygons) == 0:
|
100 |
+
return polygons, ignore_tags
|
101 |
+
assert len(polygons) == len(ignore_tags)
|
102 |
+
for polygon in polygons:
|
103 |
+
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
|
104 |
+
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
|
105 |
+
|
106 |
+
for i in range(len(polygons)):
|
107 |
+
area = self.polygon_area(polygons[i])
|
108 |
+
if abs(area) < 1:
|
109 |
+
ignore_tags[i] = True
|
110 |
+
if area > 0:
|
111 |
+
polygons[i] = polygons[i][::-1, :]
|
112 |
+
return polygons, ignore_tags
|
113 |
+
|
114 |
+
def polygon_area(self, polygon):
|
115 |
+
"""
|
116 |
+
compute polygon area
|
117 |
+
"""
|
118 |
+
area = 0
|
119 |
+
q = polygon[-1]
|
120 |
+
for p in polygon:
|
121 |
+
area += p[0] * q[1] - p[1] * q[0]
|
122 |
+
q = p
|
123 |
+
return area / 2.0
|
ppocr/data/imaug/operators.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
from __future__ import unicode_literals
|
21 |
+
|
22 |
+
import sys
|
23 |
+
import six
|
24 |
+
import cv2
|
25 |
+
import numpy as np
|
26 |
+
import math
|
27 |
+
from PIL import Image
|
28 |
+
|
29 |
+
|
30 |
+
class DecodeImage(object):
|
31 |
+
""" decode image """
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
img_mode='RGB',
|
35 |
+
channel_first=False,
|
36 |
+
ignore_orientation=False,
|
37 |
+
**kwargs):
|
38 |
+
self.img_mode = img_mode
|
39 |
+
self.channel_first = channel_first
|
40 |
+
self.ignore_orientation = ignore_orientation
|
41 |
+
|
42 |
+
def __call__(self, data):
|
43 |
+
img = data['image']
|
44 |
+
if six.PY2:
|
45 |
+
assert type(img) is str and len(
|
46 |
+
img) > 0, "invalid input 'img' in DecodeImage"
|
47 |
+
else:
|
48 |
+
assert type(img) is bytes and len(
|
49 |
+
img) > 0, "invalid input 'img' in DecodeImage"
|
50 |
+
img = np.frombuffer(img, dtype='uint8')
|
51 |
+
if self.ignore_orientation:
|
52 |
+
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
|
53 |
+
cv2.IMREAD_COLOR)
|
54 |
+
else:
|
55 |
+
img = cv2.imdecode(img, 1)
|
56 |
+
if img is None:
|
57 |
+
return None
|
58 |
+
if self.img_mode == 'GRAY':
|
59 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
60 |
+
elif self.img_mode == 'RGB':
|
61 |
+
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
62 |
+
img = img[:, :, ::-1]
|
63 |
+
|
64 |
+
if self.channel_first:
|
65 |
+
img = img.transpose((2, 0, 1))
|
66 |
+
|
67 |
+
data['image'] = img
|
68 |
+
return data
|
69 |
+
|
70 |
+
|
71 |
+
class NormalizeImage(object):
|
72 |
+
""" normalize image such as substract mean, divide std
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
76 |
+
if isinstance(scale, str):
|
77 |
+
scale = eval(scale)
|
78 |
+
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
79 |
+
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
80 |
+
std = std if std is not None else [0.229, 0.224, 0.225]
|
81 |
+
|
82 |
+
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
83 |
+
self.mean = np.array(mean).reshape(shape).astype('float32')
|
84 |
+
self.std = np.array(std).reshape(shape).astype('float32')
|
85 |
+
|
86 |
+
def __call__(self, data):
|
87 |
+
img = data['image']
|
88 |
+
from PIL import Image
|
89 |
+
if isinstance(img, Image.Image):
|
90 |
+
img = np.array(img)
|
91 |
+
assert isinstance(img,
|
92 |
+
np.ndarray), "invalid input 'img' in NormalizeImage"
|
93 |
+
data['image'] = (
|
94 |
+
img.astype('float32') * self.scale - self.mean) / self.std
|
95 |
+
return data
|
96 |
+
|
97 |
+
|
98 |
+
class ToCHWImage(object):
|
99 |
+
""" convert hwc image to chw image
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, **kwargs):
|
103 |
+
pass
|
104 |
+
|
105 |
+
def __call__(self, data):
|
106 |
+
img = data['image']
|
107 |
+
from PIL import Image
|
108 |
+
if isinstance(img, Image.Image):
|
109 |
+
img = np.array(img)
|
110 |
+
data['image'] = img.transpose((2, 0, 1))
|
111 |
+
return data
|
112 |
+
|
113 |
+
|
114 |
+
class Fasttext(object):
|
115 |
+
def __init__(self, path="None", **kwargs):
|
116 |
+
import fasttext
|
117 |
+
self.fast_model = fasttext.load_model(path)
|
118 |
+
|
119 |
+
def __call__(self, data):
|
120 |
+
label = data['label']
|
121 |
+
fast_label = self.fast_model[label]
|
122 |
+
data['fast_label'] = fast_label
|
123 |
+
return data
|
124 |
+
|
125 |
+
|
126 |
+
class KeepKeys(object):
|
127 |
+
def __init__(self, keep_keys, **kwargs):
|
128 |
+
self.keep_keys = keep_keys
|
129 |
+
|
130 |
+
def __call__(self, data):
|
131 |
+
data_list = []
|
132 |
+
for key in self.keep_keys:
|
133 |
+
data_list.append(data[key])
|
134 |
+
return data_list
|
135 |
+
|
136 |
+
|
137 |
+
class Pad(object):
|
138 |
+
def __init__(self, size=None, size_div=32, **kwargs):
|
139 |
+
if size is not None and not isinstance(size, (int, list, tuple)):
|
140 |
+
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
141 |
+
type(size)))
|
142 |
+
if isinstance(size, int):
|
143 |
+
size = [size, size]
|
144 |
+
self.size = size
|
145 |
+
self.size_div = size_div
|
146 |
+
|
147 |
+
def __call__(self, data):
|
148 |
+
|
149 |
+
img = data['image']
|
150 |
+
img_h, img_w = img.shape[0], img.shape[1]
|
151 |
+
if self.size:
|
152 |
+
resize_h2, resize_w2 = self.size
|
153 |
+
assert (
|
154 |
+
img_h < resize_h2 and img_w < resize_w2
|
155 |
+
), '(h, w) of target size should be greater than (img_h, img_w)'
|
156 |
+
else:
|
157 |
+
resize_h2 = max(
|
158 |
+
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
159 |
+
self.size_div)
|
160 |
+
resize_w2 = max(
|
161 |
+
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
162 |
+
self.size_div)
|
163 |
+
img = cv2.copyMakeBorder(
|
164 |
+
img,
|
165 |
+
0,
|
166 |
+
resize_h2 - img_h,
|
167 |
+
0,
|
168 |
+
resize_w2 - img_w,
|
169 |
+
cv2.BORDER_CONSTANT,
|
170 |
+
value=0)
|
171 |
+
data['image'] = img
|
172 |
+
return data
|
173 |
+
|
174 |
+
|
175 |
+
class Resize(object):
|
176 |
+
def __init__(self, size=(640, 640), **kwargs):
|
177 |
+
self.size = size
|
178 |
+
|
179 |
+
def resize_image(self, img):
|
180 |
+
resize_h, resize_w = self.size
|
181 |
+
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
182 |
+
ratio_h = float(resize_h) / ori_h
|
183 |
+
ratio_w = float(resize_w) / ori_w
|
184 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
185 |
+
return img, [ratio_h, ratio_w]
|
186 |
+
|
187 |
+
def __call__(self, data):
|
188 |
+
img = data['image']
|
189 |
+
if 'polys' in data:
|
190 |
+
text_polys = data['polys']
|
191 |
+
|
192 |
+
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
193 |
+
if 'polys' in data:
|
194 |
+
new_boxes = []
|
195 |
+
for box in text_polys:
|
196 |
+
new_box = []
|
197 |
+
for cord in box:
|
198 |
+
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
199 |
+
new_boxes.append(new_box)
|
200 |
+
data['polys'] = np.array(new_boxes, dtype=np.float32)
|
201 |
+
data['image'] = img_resize
|
202 |
+
return data
|
203 |
+
|
204 |
+
|
205 |
+
class DetResizeForTest(object):
|
206 |
+
def __init__(self, **kwargs):
|
207 |
+
super(DetResizeForTest, self).__init__()
|
208 |
+
self.resize_type = 0
|
209 |
+
self.keep_ratio = False
|
210 |
+
if 'image_shape' in kwargs:
|
211 |
+
self.image_shape = kwargs['image_shape']
|
212 |
+
self.resize_type = 1
|
213 |
+
if 'keep_ratio' in kwargs:
|
214 |
+
self.keep_ratio = kwargs['keep_ratio']
|
215 |
+
elif 'limit_side_len' in kwargs:
|
216 |
+
self.limit_side_len = kwargs['limit_side_len']
|
217 |
+
self.limit_type = kwargs.get('limit_type', 'min')
|
218 |
+
elif 'resize_long' in kwargs:
|
219 |
+
self.resize_type = 2
|
220 |
+
self.resize_long = kwargs.get('resize_long', 960)
|
221 |
+
else:
|
222 |
+
self.limit_side_len = 736
|
223 |
+
self.limit_type = 'min'
|
224 |
+
|
225 |
+
def __call__(self, data):
|
226 |
+
img = data['image']
|
227 |
+
src_h, src_w, _ = img.shape
|
228 |
+
if sum([src_h, src_w]) < 64:
|
229 |
+
img = self.image_padding(img)
|
230 |
+
|
231 |
+
if self.resize_type == 0:
|
232 |
+
# img, shape = self.resize_image_type0(img)
|
233 |
+
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
234 |
+
elif self.resize_type == 2:
|
235 |
+
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
236 |
+
else:
|
237 |
+
# img, shape = self.resize_image_type1(img)
|
238 |
+
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
239 |
+
data['image'] = img
|
240 |
+
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
241 |
+
return data
|
242 |
+
|
243 |
+
def image_padding(self, im, value=0):
|
244 |
+
h, w, c = im.shape
|
245 |
+
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
|
246 |
+
im_pad[:h, :w, :] = im
|
247 |
+
return im_pad
|
248 |
+
|
249 |
+
def resize_image_type1(self, img):
|
250 |
+
resize_h, resize_w = self.image_shape
|
251 |
+
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
252 |
+
if self.keep_ratio is True:
|
253 |
+
resize_w = ori_w * resize_h / ori_h
|
254 |
+
N = math.ceil(resize_w / 32)
|
255 |
+
resize_w = N * 32
|
256 |
+
ratio_h = float(resize_h) / ori_h
|
257 |
+
ratio_w = float(resize_w) / ori_w
|
258 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
259 |
+
# return img, np.array([ori_h, ori_w])
|
260 |
+
return img, [ratio_h, ratio_w]
|
261 |
+
|
262 |
+
def resize_image_type0(self, img):
|
263 |
+
"""
|
264 |
+
resize image to a size multiple of 32 which is required by the network
|
265 |
+
args:
|
266 |
+
img(array): array with shape [h, w, c]
|
267 |
+
return(tuple):
|
268 |
+
img, (ratio_h, ratio_w)
|
269 |
+
"""
|
270 |
+
limit_side_len = self.limit_side_len
|
271 |
+
h, w, c = img.shape
|
272 |
+
|
273 |
+
# limit the max side
|
274 |
+
if self.limit_type == 'max':
|
275 |
+
if max(h, w) > limit_side_len:
|
276 |
+
if h > w:
|
277 |
+
ratio = float(limit_side_len) / h
|
278 |
+
else:
|
279 |
+
ratio = float(limit_side_len) / w
|
280 |
+
else:
|
281 |
+
ratio = 1.
|
282 |
+
elif self.limit_type == 'min':
|
283 |
+
if min(h, w) < limit_side_len:
|
284 |
+
if h < w:
|
285 |
+
ratio = float(limit_side_len) / h
|
286 |
+
else:
|
287 |
+
ratio = float(limit_side_len) / w
|
288 |
+
else:
|
289 |
+
ratio = 1.
|
290 |
+
elif self.limit_type == 'resize_long':
|
291 |
+
ratio = float(limit_side_len) / max(h, w)
|
292 |
+
else:
|
293 |
+
raise Exception('not support limit type, image ')
|
294 |
+
resize_h = int(h * ratio)
|
295 |
+
resize_w = int(w * ratio)
|
296 |
+
|
297 |
+
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
298 |
+
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
299 |
+
|
300 |
+
try:
|
301 |
+
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
302 |
+
return None, (None, None)
|
303 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
304 |
+
except:
|
305 |
+
print(img.shape, resize_w, resize_h)
|
306 |
+
sys.exit(0)
|
307 |
+
ratio_h = resize_h / float(h)
|
308 |
+
ratio_w = resize_w / float(w)
|
309 |
+
return img, [ratio_h, ratio_w]
|
310 |
+
|
311 |
+
def resize_image_type2(self, img):
|
312 |
+
h, w, _ = img.shape
|
313 |
+
|
314 |
+
resize_w = w
|
315 |
+
resize_h = h
|
316 |
+
|
317 |
+
if resize_h > resize_w:
|
318 |
+
ratio = float(self.resize_long) / resize_h
|
319 |
+
else:
|
320 |
+
ratio = float(self.resize_long) / resize_w
|
321 |
+
|
322 |
+
resize_h = int(resize_h * ratio)
|
323 |
+
resize_w = int(resize_w * ratio)
|
324 |
+
|
325 |
+
max_stride = 128
|
326 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
327 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
328 |
+
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
329 |
+
ratio_h = resize_h / float(h)
|
330 |
+
ratio_w = resize_w / float(w)
|
331 |
+
|
332 |
+
return img, [ratio_h, ratio_w]
|
333 |
+
|
334 |
+
|
335 |
+
class E2EResizeForTest(object):
|
336 |
+
def __init__(self, **kwargs):
|
337 |
+
super(E2EResizeForTest, self).__init__()
|
338 |
+
self.max_side_len = kwargs['max_side_len']
|
339 |
+
self.valid_set = kwargs['valid_set']
|
340 |
+
|
341 |
+
def __call__(self, data):
|
342 |
+
img = data['image']
|
343 |
+
src_h, src_w, _ = img.shape
|
344 |
+
if self.valid_set == 'totaltext':
|
345 |
+
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
346 |
+
img, max_side_len=self.max_side_len)
|
347 |
+
else:
|
348 |
+
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
349 |
+
img, max_side_len=self.max_side_len)
|
350 |
+
data['image'] = im_resized
|
351 |
+
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
352 |
+
return data
|
353 |
+
|
354 |
+
def resize_image_for_totaltext(self, im, max_side_len=512):
|
355 |
+
|
356 |
+
h, w, _ = im.shape
|
357 |
+
resize_w = w
|
358 |
+
resize_h = h
|
359 |
+
ratio = 1.25
|
360 |
+
if h * ratio > max_side_len:
|
361 |
+
ratio = float(max_side_len) / resize_h
|
362 |
+
resize_h = int(resize_h * ratio)
|
363 |
+
resize_w = int(resize_w * ratio)
|
364 |
+
|
365 |
+
max_stride = 128
|
366 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
367 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
368 |
+
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
369 |
+
ratio_h = resize_h / float(h)
|
370 |
+
ratio_w = resize_w / float(w)
|
371 |
+
return im, (ratio_h, ratio_w)
|
372 |
+
|
373 |
+
def resize_image(self, im, max_side_len=512):
|
374 |
+
"""
|
375 |
+
resize image to a size multiple of max_stride which is required by the network
|
376 |
+
:param im: the resized image
|
377 |
+
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
378 |
+
:return: the resized image and the resize ratio
|
379 |
+
"""
|
380 |
+
h, w, _ = im.shape
|
381 |
+
|
382 |
+
resize_w = w
|
383 |
+
resize_h = h
|
384 |
+
|
385 |
+
# Fix the longer side
|
386 |
+
if resize_h > resize_w:
|
387 |
+
ratio = float(max_side_len) / resize_h
|
388 |
+
else:
|
389 |
+
ratio = float(max_side_len) / resize_w
|
390 |
+
|
391 |
+
resize_h = int(resize_h * ratio)
|
392 |
+
resize_w = int(resize_w * ratio)
|
393 |
+
|
394 |
+
max_stride = 128
|
395 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
396 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
397 |
+
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
398 |
+
ratio_h = resize_h / float(h)
|
399 |
+
ratio_w = resize_w / float(w)
|
400 |
+
|
401 |
+
return im, (ratio_h, ratio_w)
|
402 |
+
|
403 |
+
|
404 |
+
class KieResize(object):
|
405 |
+
def __init__(self, **kwargs):
|
406 |
+
super(KieResize, self).__init__()
|
407 |
+
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
408 |
+
'img_scale'][1]
|
409 |
+
|
410 |
+
def __call__(self, data):
|
411 |
+
img = data['image']
|
412 |
+
points = data['points']
|
413 |
+
src_h, src_w, _ = img.shape
|
414 |
+
im_resized, scale_factor, [ratio_h, ratio_w
|
415 |
+
], [new_h, new_w] = self.resize_image(img)
|
416 |
+
resize_points = self.resize_boxes(img, points, scale_factor)
|
417 |
+
data['ori_image'] = img
|
418 |
+
data['ori_boxes'] = points
|
419 |
+
data['points'] = resize_points
|
420 |
+
data['image'] = im_resized
|
421 |
+
data['shape'] = np.array([new_h, new_w])
|
422 |
+
return data
|
423 |
+
|
424 |
+
def resize_image(self, img):
|
425 |
+
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
426 |
+
scale = [512, 1024]
|
427 |
+
h, w = img.shape[:2]
|
428 |
+
max_long_edge = max(scale)
|
429 |
+
max_short_edge = min(scale)
|
430 |
+
scale_factor = min(max_long_edge / max(h, w),
|
431 |
+
max_short_edge / min(h, w))
|
432 |
+
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
433 |
+
scale_factor) + 0.5)
|
434 |
+
max_stride = 32
|
435 |
+
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
436 |
+
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
437 |
+
im = cv2.resize(img, (resize_w, resize_h))
|
438 |
+
new_h, new_w = im.shape[:2]
|
439 |
+
w_scale = new_w / w
|
440 |
+
h_scale = new_h / h
|
441 |
+
scale_factor = np.array(
|
442 |
+
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
443 |
+
norm_img[:new_h, :new_w, :] = im
|
444 |
+
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
445 |
+
|
446 |
+
def resize_boxes(self, im, points, scale_factor):
|
447 |
+
points = points * scale_factor
|
448 |
+
img_shape = im.shape[:2]
|
449 |
+
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
450 |
+
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
451 |
+
return points
|
452 |
+
|
453 |
+
|
454 |
+
class SRResize(object):
|
455 |
+
def __init__(self,
|
456 |
+
imgH=32,
|
457 |
+
imgW=128,
|
458 |
+
down_sample_scale=4,
|
459 |
+
keep_ratio=False,
|
460 |
+
min_ratio=1,
|
461 |
+
mask=False,
|
462 |
+
infer_mode=False,
|
463 |
+
**kwargs):
|
464 |
+
self.imgH = imgH
|
465 |
+
self.imgW = imgW
|
466 |
+
self.keep_ratio = keep_ratio
|
467 |
+
self.min_ratio = min_ratio
|
468 |
+
self.down_sample_scale = down_sample_scale
|
469 |
+
self.mask = mask
|
470 |
+
self.infer_mode = infer_mode
|
471 |
+
|
472 |
+
def __call__(self, data):
|
473 |
+
imgH = self.imgH
|
474 |
+
imgW = self.imgW
|
475 |
+
images_lr = data["image_lr"]
|
476 |
+
transform2 = ResizeNormalize(
|
477 |
+
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
|
478 |
+
images_lr = transform2(images_lr)
|
479 |
+
data["img_lr"] = images_lr
|
480 |
+
if self.infer_mode:
|
481 |
+
return data
|
482 |
+
|
483 |
+
images_HR = data["image_hr"]
|
484 |
+
label_strs = data["label"]
|
485 |
+
transform = ResizeNormalize((imgW, imgH))
|
486 |
+
images_HR = transform(images_HR)
|
487 |
+
data["img_hr"] = images_HR
|
488 |
+
return data
|
489 |
+
|
490 |
+
|
491 |
+
class ResizeNormalize(object):
|
492 |
+
def __init__(self, size, interpolation=Image.BICUBIC):
|
493 |
+
self.size = size
|
494 |
+
self.interpolation = interpolation
|
495 |
+
|
496 |
+
def __call__(self, img):
|
497 |
+
img = img.resize(self.size, self.interpolation)
|
498 |
+
img_numpy = np.array(img).astype("float32")
|
499 |
+
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
|
500 |
+
return img_numpy
|
501 |
+
|
502 |
+
|
503 |
+
class GrayImageChannelFormat(object):
|
504 |
+
"""
|
505 |
+
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
506 |
+
Args:
|
507 |
+
inverse: inverse gray image
|
508 |
+
"""
|
509 |
+
|
510 |
+
def __init__(self, inverse=False, **kwargs):
|
511 |
+
self.inverse = inverse
|
512 |
+
|
513 |
+
def __call__(self, data):
|
514 |
+
img = data['image']
|
515 |
+
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
516 |
+
img_expanded = np.expand_dims(img_single_channel, 0)
|
517 |
+
|
518 |
+
if self.inverse:
|
519 |
+
data['image'] = np.abs(img_expanded - 1)
|
520 |
+
else:
|
521 |
+
data['image'] = img_expanded
|
522 |
+
|
523 |
+
data['src_image'] = img
|
524 |
+
return data
|
ppocr/data/imaug/pg_process.py
ADDED
@@ -0,0 +1,1034 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import cv2
|
17 |
+
import numpy as np
|
18 |
+
from skimage.morphology._skeletonize import thin
|
19 |
+
from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
|
20 |
+
|
21 |
+
__all__ = ['PGProcessTrain']
|
22 |
+
|
23 |
+
|
24 |
+
class PGProcessTrain(object):
|
25 |
+
def __init__(self,
|
26 |
+
character_dict_path,
|
27 |
+
max_text_length,
|
28 |
+
max_text_nums,
|
29 |
+
tcl_len,
|
30 |
+
batch_size=14,
|
31 |
+
use_resize=True,
|
32 |
+
use_random_crop=False,
|
33 |
+
min_crop_size=24,
|
34 |
+
min_text_size=4,
|
35 |
+
max_text_size=512,
|
36 |
+
point_gather_mode=None,
|
37 |
+
**kwargs):
|
38 |
+
self.tcl_len = tcl_len
|
39 |
+
self.max_text_length = max_text_length
|
40 |
+
self.max_text_nums = max_text_nums
|
41 |
+
self.batch_size = batch_size
|
42 |
+
if use_random_crop is True:
|
43 |
+
self.min_crop_size = min_crop_size
|
44 |
+
self.use_random_crop = use_random_crop
|
45 |
+
self.min_text_size = min_text_size
|
46 |
+
self.max_text_size = max_text_size
|
47 |
+
self.use_resize = use_resize
|
48 |
+
self.point_gather_mode = point_gather_mode
|
49 |
+
self.Lexicon_Table = self.get_dict(character_dict_path)
|
50 |
+
self.pad_num = len(self.Lexicon_Table)
|
51 |
+
self.img_id = 0
|
52 |
+
|
53 |
+
def get_dict(self, character_dict_path):
|
54 |
+
character_str = ""
|
55 |
+
with open(character_dict_path, "rb") as fin:
|
56 |
+
lines = fin.readlines()
|
57 |
+
for line in lines:
|
58 |
+
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
59 |
+
character_str += line
|
60 |
+
dict_character = list(character_str)
|
61 |
+
return dict_character
|
62 |
+
|
63 |
+
def quad_area(self, poly):
|
64 |
+
"""
|
65 |
+
compute area of a polygon
|
66 |
+
:param poly:
|
67 |
+
:return:
|
68 |
+
"""
|
69 |
+
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
70 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
71 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
72 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
73 |
+
return np.sum(edge) / 2.
|
74 |
+
|
75 |
+
def gen_quad_from_poly(self, poly):
|
76 |
+
"""
|
77 |
+
Generate min area quad from poly.
|
78 |
+
"""
|
79 |
+
point_num = poly.shape[0]
|
80 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
81 |
+
rect = cv2.minAreaRect(poly.astype(
|
82 |
+
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
83 |
+
box = np.array(cv2.boxPoints(rect))
|
84 |
+
|
85 |
+
first_point_idx = 0
|
86 |
+
min_dist = 1e4
|
87 |
+
for i in range(4):
|
88 |
+
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
89 |
+
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
90 |
+
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
91 |
+
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
92 |
+
if dist < min_dist:
|
93 |
+
min_dist = dist
|
94 |
+
first_point_idx = i
|
95 |
+
for i in range(4):
|
96 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
97 |
+
|
98 |
+
return min_area_quad
|
99 |
+
|
100 |
+
def check_and_validate_polys(self, polys, tags, im_size):
|
101 |
+
"""
|
102 |
+
check so that the text poly is in the same direction,
|
103 |
+
and also filter some invalid polygons
|
104 |
+
:param polys:
|
105 |
+
:param tags:
|
106 |
+
:return:
|
107 |
+
"""
|
108 |
+
(h, w) = im_size
|
109 |
+
if polys.shape[0] == 0:
|
110 |
+
return polys, np.array([]), np.array([])
|
111 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
112 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
113 |
+
|
114 |
+
validated_polys = []
|
115 |
+
validated_tags = []
|
116 |
+
hv_tags = []
|
117 |
+
for poly, tag in zip(polys, tags):
|
118 |
+
quad = self.gen_quad_from_poly(poly)
|
119 |
+
p_area = self.quad_area(quad)
|
120 |
+
if abs(p_area) < 1:
|
121 |
+
print('invalid poly')
|
122 |
+
continue
|
123 |
+
if p_area > 0:
|
124 |
+
if tag == False:
|
125 |
+
print('poly in wrong direction')
|
126 |
+
tag = True # reversed cases should be ignore
|
127 |
+
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
128 |
+
1), :]
|
129 |
+
quad = quad[(0, 3, 2, 1), :]
|
130 |
+
|
131 |
+
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
132 |
+
quad[2])
|
133 |
+
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
134 |
+
quad[2])
|
135 |
+
hv_tag = 1
|
136 |
+
|
137 |
+
if len_w * 2.0 < len_h:
|
138 |
+
hv_tag = 0
|
139 |
+
|
140 |
+
validated_polys.append(poly)
|
141 |
+
validated_tags.append(tag)
|
142 |
+
hv_tags.append(hv_tag)
|
143 |
+
return np.array(validated_polys), np.array(validated_tags), np.array(
|
144 |
+
hv_tags)
|
145 |
+
|
146 |
+
def crop_area(self,
|
147 |
+
im,
|
148 |
+
polys,
|
149 |
+
tags,
|
150 |
+
hv_tags,
|
151 |
+
txts,
|
152 |
+
crop_background=False,
|
153 |
+
max_tries=25):
|
154 |
+
"""
|
155 |
+
make random crop from the input image
|
156 |
+
:param im:
|
157 |
+
:param polys: [b,4,2]
|
158 |
+
:param tags:
|
159 |
+
:param crop_background:
|
160 |
+
:param max_tries: 50 -> 25
|
161 |
+
:return:
|
162 |
+
"""
|
163 |
+
h, w, _ = im.shape
|
164 |
+
pad_h = h // 10
|
165 |
+
pad_w = w // 10
|
166 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
167 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
168 |
+
for poly in polys:
|
169 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
170 |
+
minx = np.min(poly[:, 0])
|
171 |
+
maxx = np.max(poly[:, 0])
|
172 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
173 |
+
miny = np.min(poly[:, 1])
|
174 |
+
maxy = np.max(poly[:, 1])
|
175 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
176 |
+
# ensure the cropped area not across a text
|
177 |
+
h_axis = np.where(h_array == 0)[0]
|
178 |
+
w_axis = np.where(w_array == 0)[0]
|
179 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
180 |
+
return im, polys, tags, hv_tags, txts
|
181 |
+
for i in range(max_tries):
|
182 |
+
xx = np.random.choice(w_axis, size=2)
|
183 |
+
xmin = np.min(xx) - pad_w
|
184 |
+
xmax = np.max(xx) - pad_w
|
185 |
+
xmin = np.clip(xmin, 0, w - 1)
|
186 |
+
xmax = np.clip(xmax, 0, w - 1)
|
187 |
+
yy = np.random.choice(h_axis, size=2)
|
188 |
+
ymin = np.min(yy) - pad_h
|
189 |
+
ymax = np.max(yy) - pad_h
|
190 |
+
ymin = np.clip(ymin, 0, h - 1)
|
191 |
+
ymax = np.clip(ymax, 0, h - 1)
|
192 |
+
if xmax - xmin < self.min_crop_size or \
|
193 |
+
ymax - ymin < self.min_crop_size:
|
194 |
+
continue
|
195 |
+
if polys.shape[0] != 0:
|
196 |
+
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
197 |
+
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
198 |
+
selected_polys = np.where(
|
199 |
+
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
200 |
+
else:
|
201 |
+
selected_polys = []
|
202 |
+
if len(selected_polys) == 0:
|
203 |
+
# no text in this area
|
204 |
+
if crop_background:
|
205 |
+
txts_tmp = []
|
206 |
+
for selected_poly in selected_polys:
|
207 |
+
txts_tmp.append(txts[selected_poly])
|
208 |
+
txts = txts_tmp
|
209 |
+
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
|
210 |
+
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
211 |
+
else:
|
212 |
+
continue
|
213 |
+
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
214 |
+
polys = polys[selected_polys]
|
215 |
+
tags = tags[selected_polys]
|
216 |
+
hv_tags = hv_tags[selected_polys]
|
217 |
+
txts_tmp = []
|
218 |
+
for selected_poly in selected_polys:
|
219 |
+
txts_tmp.append(txts[selected_poly])
|
220 |
+
txts = txts_tmp
|
221 |
+
polys[:, :, 0] -= xmin
|
222 |
+
polys[:, :, 1] -= ymin
|
223 |
+
return im, polys, tags, hv_tags, txts
|
224 |
+
|
225 |
+
return im, polys, tags, hv_tags, txts
|
226 |
+
|
227 |
+
def fit_and_gather_tcl_points_v2(self,
|
228 |
+
min_area_quad,
|
229 |
+
poly,
|
230 |
+
max_h,
|
231 |
+
max_w,
|
232 |
+
fixed_point_num=64,
|
233 |
+
img_id=0,
|
234 |
+
reference_height=3):
|
235 |
+
"""
|
236 |
+
Find the center point of poly as key_points, then fit and gather.
|
237 |
+
"""
|
238 |
+
key_point_xys = []
|
239 |
+
point_num = poly.shape[0]
|
240 |
+
for idx in range(point_num // 2):
|
241 |
+
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
|
242 |
+
key_point_xys.append(center_point)
|
243 |
+
|
244 |
+
tmp_image = np.zeros(
|
245 |
+
shape=(
|
246 |
+
max_h,
|
247 |
+
max_w, ), dtype='float32')
|
248 |
+
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
|
249 |
+
False, 1.0)
|
250 |
+
ys, xs = np.where(tmp_image > 0)
|
251 |
+
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
|
252 |
+
|
253 |
+
left_center_pt = (
|
254 |
+
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
|
255 |
+
right_center_pt = (
|
256 |
+
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
|
257 |
+
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
258 |
+
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
259 |
+
proj_unit_vec_tile = np.tile(proj_unit_vec,
|
260 |
+
(xy_text.shape[0], 1)) # (n, 2)
|
261 |
+
left_center_pt_tile = np.tile(left_center_pt,
|
262 |
+
(xy_text.shape[0], 1)) # (n, 2)
|
263 |
+
xy_text_to_left_center = xy_text - left_center_pt_tile
|
264 |
+
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
|
265 |
+
xy_text = xy_text[np.argsort(proj_value)]
|
266 |
+
|
267 |
+
# convert to np and keep the num of point not greater then fixed_point_num
|
268 |
+
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
|
269 |
+
point_num = len(pos_info)
|
270 |
+
if point_num > fixed_point_num:
|
271 |
+
keep_ids = [
|
272 |
+
int((point_num * 1.0 / fixed_point_num) * x)
|
273 |
+
for x in range(fixed_point_num)
|
274 |
+
]
|
275 |
+
pos_info = pos_info[keep_ids, :]
|
276 |
+
|
277 |
+
keep = int(min(len(pos_info), fixed_point_num))
|
278 |
+
if np.random.rand() < 0.2 and reference_height >= 3:
|
279 |
+
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
|
280 |
+
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
|
281 |
+
[keep, 1])
|
282 |
+
pos_info += random_float
|
283 |
+
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
284 |
+
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
285 |
+
|
286 |
+
# padding to fixed length
|
287 |
+
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
288 |
+
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
289 |
+
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
290 |
+
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
291 |
+
pos_m[:keep] = 1.0
|
292 |
+
return pos_l, pos_m
|
293 |
+
|
294 |
+
def fit_and_gather_tcl_points_v3(self,
|
295 |
+
min_area_quad,
|
296 |
+
poly,
|
297 |
+
max_h,
|
298 |
+
max_w,
|
299 |
+
fixed_point_num=64,
|
300 |
+
img_id=0,
|
301 |
+
reference_height=3):
|
302 |
+
"""
|
303 |
+
Find the center point of poly as key_points, then fit and gather.
|
304 |
+
"""
|
305 |
+
det_mask = np.zeros((int(max_h / self.ds_ratio),
|
306 |
+
int(max_w / self.ds_ratio))).astype(np.float32)
|
307 |
+
|
308 |
+
# score_big_map
|
309 |
+
cv2.fillPoly(det_mask,
|
310 |
+
np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
|
311 |
+
det_mask = cv2.resize(
|
312 |
+
det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
|
313 |
+
det_mask = np.array(det_mask > 1e-3, dtype='float32')
|
314 |
+
|
315 |
+
f_direction = self.f_direction
|
316 |
+
skeleton_map = thin(det_mask.astype(np.uint8))
|
317 |
+
instance_count, instance_label_map = cv2.connectedComponents(
|
318 |
+
skeleton_map.astype(np.uint8), connectivity=8)
|
319 |
+
|
320 |
+
ys, xs = np.where(instance_label_map == 1)
|
321 |
+
pos_list = list(zip(ys, xs))
|
322 |
+
if len(pos_list) < 3:
|
323 |
+
return None
|
324 |
+
pos_list_sorted = sort_and_expand_with_direction_v2(
|
325 |
+
pos_list, f_direction, det_mask)
|
326 |
+
|
327 |
+
pos_list_sorted = np.array(pos_list_sorted)
|
328 |
+
length = len(pos_list_sorted) - 1
|
329 |
+
insert_num = 0
|
330 |
+
for index in range(length):
|
331 |
+
stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
|
332 |
+
pos_list_sorted[index + 1 + insert_num][0])
|
333 |
+
stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
|
334 |
+
pos_list_sorted[index + 1 + insert_num][1])
|
335 |
+
max_points = int(max(stride_x, stride_y))
|
336 |
+
|
337 |
+
stride = (pos_list_sorted[index + insert_num] -
|
338 |
+
pos_list_sorted[index + 1 + insert_num]) / (max_points)
|
339 |
+
insert_num_temp = max_points - 1
|
340 |
+
|
341 |
+
for i in range(int(insert_num_temp)):
|
342 |
+
insert_value = pos_list_sorted[index + insert_num] - (i + 1
|
343 |
+
) * stride
|
344 |
+
insert_index = index + i + 1 + insert_num
|
345 |
+
pos_list_sorted = np.insert(
|
346 |
+
pos_list_sorted, insert_index, insert_value, axis=0)
|
347 |
+
insert_num += insert_num_temp
|
348 |
+
|
349 |
+
pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
|
350 |
+
np.float32) # xy-> yx
|
351 |
+
|
352 |
+
point_num = len(pos_info)
|
353 |
+
if point_num > fixed_point_num:
|
354 |
+
keep_ids = [
|
355 |
+
int((point_num * 1.0 / fixed_point_num) * x)
|
356 |
+
for x in range(fixed_point_num)
|
357 |
+
]
|
358 |
+
pos_info = pos_info[keep_ids, :]
|
359 |
+
|
360 |
+
keep = int(min(len(pos_info), fixed_point_num))
|
361 |
+
reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
|
362 |
+
np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
|
363 |
+
if np.random.rand() < 1:
|
364 |
+
dh = (np.random.rand(keep) - 0.5) * reference_height
|
365 |
+
offset = np.random.rand() - 0.5
|
366 |
+
dw = np.array([[0, offset * reference_width * 0.2]])
|
367 |
+
random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
|
368 |
+
[keep, 1])
|
369 |
+
random_float_w = dw.repeat(keep, axis=0)
|
370 |
+
pos_info += random_float_h
|
371 |
+
pos_info += random_float_w
|
372 |
+
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
373 |
+
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
374 |
+
|
375 |
+
# padding to fixed length
|
376 |
+
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
377 |
+
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
378 |
+
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
379 |
+
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
380 |
+
pos_m[:keep] = 1.0
|
381 |
+
return pos_l, pos_m
|
382 |
+
|
383 |
+
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
384 |
+
"""
|
385 |
+
"""
|
386 |
+
width_list = []
|
387 |
+
height_list = []
|
388 |
+
for quad in poly_quads:
|
389 |
+
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
390 |
+
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
391 |
+
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
392 |
+
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
393 |
+
width_list.append(quad_w)
|
394 |
+
height_list.append(quad_h)
|
395 |
+
norm_width = max(sum(width_list) / n_char, 1.0)
|
396 |
+
average_height = max(sum(height_list) / len(height_list), 1.0)
|
397 |
+
k = 1
|
398 |
+
for quad in poly_quads:
|
399 |
+
direct_vector_full = (
|
400 |
+
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
401 |
+
direct_vector = direct_vector_full / (
|
402 |
+
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
403 |
+
direction_label = tuple(
|
404 |
+
map(float,
|
405 |
+
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
|
406 |
+
cv2.fillPoly(direction_map,
|
407 |
+
quad.round().astype(np.int32)[np.newaxis, :, :],
|
408 |
+
direction_label)
|
409 |
+
k += 1
|
410 |
+
return direction_map
|
411 |
+
|
412 |
+
def calculate_average_height(self, poly_quads):
|
413 |
+
"""
|
414 |
+
"""
|
415 |
+
height_list = []
|
416 |
+
for quad in poly_quads:
|
417 |
+
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
418 |
+
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
419 |
+
height_list.append(quad_h)
|
420 |
+
average_height = max(sum(height_list) / len(height_list), 1.0)
|
421 |
+
return average_height
|
422 |
+
|
423 |
+
def generate_tcl_ctc_label(self,
|
424 |
+
h,
|
425 |
+
w,
|
426 |
+
polys,
|
427 |
+
tags,
|
428 |
+
text_strs,
|
429 |
+
ds_ratio,
|
430 |
+
tcl_ratio=0.3,
|
431 |
+
shrink_ratio_of_width=0.15):
|
432 |
+
"""
|
433 |
+
Generate polygon.
|
434 |
+
"""
|
435 |
+
self.ds_ratio = ds_ratio
|
436 |
+
score_map_big = np.zeros(
|
437 |
+
(
|
438 |
+
h,
|
439 |
+
w, ), dtype=np.float32)
|
440 |
+
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
441 |
+
polys = polys * ds_ratio
|
442 |
+
|
443 |
+
score_map = np.zeros(
|
444 |
+
(
|
445 |
+
h,
|
446 |
+
w, ), dtype=np.float32)
|
447 |
+
score_label_map = np.zeros(
|
448 |
+
(
|
449 |
+
h,
|
450 |
+
w, ), dtype=np.float32)
|
451 |
+
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
452 |
+
training_mask = np.ones(
|
453 |
+
(
|
454 |
+
h,
|
455 |
+
w, ), dtype=np.float32)
|
456 |
+
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
457 |
+
[1, 1, 3]).astype(np.float32)
|
458 |
+
|
459 |
+
label_idx = 0
|
460 |
+
score_label_map_text_label_list = []
|
461 |
+
pos_list, pos_mask, label_list = [], [], []
|
462 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
463 |
+
poly = poly_tag[0]
|
464 |
+
tag = poly_tag[1]
|
465 |
+
|
466 |
+
# generate min_area_quad
|
467 |
+
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
468 |
+
min_area_quad_h = 0.5 * (
|
469 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
470 |
+
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
471 |
+
min_area_quad_w = 0.5 * (
|
472 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
473 |
+
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
474 |
+
|
475 |
+
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
476 |
+
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
477 |
+
continue
|
478 |
+
|
479 |
+
if tag:
|
480 |
+
cv2.fillPoly(training_mask,
|
481 |
+
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
482 |
+
else:
|
483 |
+
text_label = text_strs[poly_idx]
|
484 |
+
text_label = self.prepare_text_label(text_label,
|
485 |
+
self.Lexicon_Table)
|
486 |
+
text_label_index_list = [[self.Lexicon_Table.index(c_)]
|
487 |
+
for c_ in text_label
|
488 |
+
if c_ in self.Lexicon_Table]
|
489 |
+
if len(text_label_index_list) < 1:
|
490 |
+
continue
|
491 |
+
|
492 |
+
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
493 |
+
tcl_quads = self.poly2quads(tcl_poly)
|
494 |
+
poly_quads = self.poly2quads(poly)
|
495 |
+
|
496 |
+
stcl_quads, quad_index = self.shrink_poly_along_width(
|
497 |
+
tcl_quads,
|
498 |
+
shrink_ratio_of_width=shrink_ratio_of_width,
|
499 |
+
expand_height_ratio=1.0 / tcl_ratio)
|
500 |
+
|
501 |
+
cv2.fillPoly(score_map,
|
502 |
+
np.round(stcl_quads).astype(np.int32), 1.0)
|
503 |
+
cv2.fillPoly(score_map_big,
|
504 |
+
np.round(stcl_quads / ds_ratio).astype(np.int32),
|
505 |
+
1.0)
|
506 |
+
|
507 |
+
for idx, quad in enumerate(stcl_quads):
|
508 |
+
quad_mask = np.zeros((h, w), dtype=np.float32)
|
509 |
+
quad_mask = cv2.fillPoly(
|
510 |
+
quad_mask,
|
511 |
+
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
512 |
+
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
513 |
+
quad_mask, tbo_map)
|
514 |
+
|
515 |
+
# score label map and score_label_map_text_label_list for refine
|
516 |
+
if label_idx == 0:
|
517 |
+
text_pos_list_ = [[len(self.Lexicon_Table)], ]
|
518 |
+
score_label_map_text_label_list.append(text_pos_list_)
|
519 |
+
|
520 |
+
label_idx += 1
|
521 |
+
cv2.fillPoly(score_label_map,
|
522 |
+
np.round(poly_quads).astype(np.int32), label_idx)
|
523 |
+
score_label_map_text_label_list.append(text_label_index_list)
|
524 |
+
|
525 |
+
# direction info, fix-me
|
526 |
+
n_char = len(text_label_index_list)
|
527 |
+
direction_map = self.generate_direction_map(poly_quads, n_char,
|
528 |
+
direction_map)
|
529 |
+
|
530 |
+
# pos info
|
531 |
+
average_shrink_height = self.calculate_average_height(
|
532 |
+
stcl_quads)
|
533 |
+
|
534 |
+
if self.point_gather_mode == 'align':
|
535 |
+
self.f_direction = direction_map[:, :, :-1].copy()
|
536 |
+
pos_res = self.fit_and_gather_tcl_points_v3(
|
537 |
+
min_area_quad,
|
538 |
+
stcl_quads,
|
539 |
+
max_h=h,
|
540 |
+
max_w=w,
|
541 |
+
fixed_point_num=64,
|
542 |
+
img_id=self.img_id,
|
543 |
+
reference_height=average_shrink_height)
|
544 |
+
if pos_res is None:
|
545 |
+
continue
|
546 |
+
pos_l, pos_m = pos_res[0], pos_res[1]
|
547 |
+
|
548 |
+
else:
|
549 |
+
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
550 |
+
min_area_quad,
|
551 |
+
poly,
|
552 |
+
max_h=h,
|
553 |
+
max_w=w,
|
554 |
+
fixed_point_num=64,
|
555 |
+
img_id=self.img_id,
|
556 |
+
reference_height=average_shrink_height)
|
557 |
+
|
558 |
+
label_l = text_label_index_list
|
559 |
+
if len(text_label_index_list) < 2:
|
560 |
+
continue
|
561 |
+
|
562 |
+
pos_list.append(pos_l)
|
563 |
+
pos_mask.append(pos_m)
|
564 |
+
label_list.append(label_l)
|
565 |
+
|
566 |
+
# use big score_map for smooth tcl lines
|
567 |
+
score_map_big_resized = cv2.resize(
|
568 |
+
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
|
569 |
+
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
|
570 |
+
|
571 |
+
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
|
572 |
+
pos_list, pos_mask, label_list, score_label_map_text_label_list
|
573 |
+
|
574 |
+
def adjust_point(self, poly):
|
575 |
+
"""
|
576 |
+
adjust point order.
|
577 |
+
"""
|
578 |
+
point_num = poly.shape[0]
|
579 |
+
if point_num == 4:
|
580 |
+
len_1 = np.linalg.norm(poly[0] - poly[1])
|
581 |
+
len_2 = np.linalg.norm(poly[1] - poly[2])
|
582 |
+
len_3 = np.linalg.norm(poly[2] - poly[3])
|
583 |
+
len_4 = np.linalg.norm(poly[3] - poly[0])
|
584 |
+
|
585 |
+
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
586 |
+
poly = poly[[1, 2, 3, 0], :]
|
587 |
+
|
588 |
+
elif point_num > 4:
|
589 |
+
vector_1 = poly[0] - poly[1]
|
590 |
+
vector_2 = poly[1] - poly[2]
|
591 |
+
cos_theta = np.dot(vector_1, vector_2) / (
|
592 |
+
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
593 |
+
theta = np.arccos(np.round(cos_theta, decimals=4))
|
594 |
+
|
595 |
+
if abs(theta) > (70 / 180 * math.pi):
|
596 |
+
index = list(range(1, point_num)) + [0]
|
597 |
+
poly = poly[np.array(index), :]
|
598 |
+
return poly
|
599 |
+
|
600 |
+
def gen_min_area_quad_from_poly(self, poly):
|
601 |
+
"""
|
602 |
+
Generate min area quad from poly.
|
603 |
+
"""
|
604 |
+
point_num = poly.shape[0]
|
605 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
606 |
+
if point_num == 4:
|
607 |
+
min_area_quad = poly
|
608 |
+
center_point = np.sum(poly, axis=0) / 4
|
609 |
+
else:
|
610 |
+
rect = cv2.minAreaRect(poly.astype(
|
611 |
+
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
612 |
+
center_point = rect[0]
|
613 |
+
box = np.array(cv2.boxPoints(rect))
|
614 |
+
|
615 |
+
first_point_idx = 0
|
616 |
+
min_dist = 1e4
|
617 |
+
for i in range(4):
|
618 |
+
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
619 |
+
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
620 |
+
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
621 |
+
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
622 |
+
if dist < min_dist:
|
623 |
+
min_dist = dist
|
624 |
+
first_point_idx = i
|
625 |
+
|
626 |
+
for i in range(4):
|
627 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
628 |
+
|
629 |
+
return min_area_quad, center_point
|
630 |
+
|
631 |
+
def shrink_quad_along_width(self,
|
632 |
+
quad,
|
633 |
+
begin_width_ratio=0.,
|
634 |
+
end_width_ratio=1.):
|
635 |
+
"""
|
636 |
+
Generate shrink_quad_along_width.
|
637 |
+
"""
|
638 |
+
ratio_pair = np.array(
|
639 |
+
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
640 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
641 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
642 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
643 |
+
|
644 |
+
def shrink_poly_along_width(self,
|
645 |
+
quads,
|
646 |
+
shrink_ratio_of_width,
|
647 |
+
expand_height_ratio=1.0):
|
648 |
+
"""
|
649 |
+
shrink poly with given length.
|
650 |
+
"""
|
651 |
+
upper_edge_list = []
|
652 |
+
|
653 |
+
def get_cut_info(edge_len_list, cut_len):
|
654 |
+
for idx, edge_len in enumerate(edge_len_list):
|
655 |
+
cut_len -= edge_len
|
656 |
+
if cut_len <= 0.000001:
|
657 |
+
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
658 |
+
return idx, ratio
|
659 |
+
|
660 |
+
for quad in quads:
|
661 |
+
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
662 |
+
upper_edge_list.append(upper_edge_len)
|
663 |
+
|
664 |
+
# length of left edge and right edge.
|
665 |
+
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
666 |
+
3]) * expand_height_ratio
|
667 |
+
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
668 |
+
2]) * expand_height_ratio
|
669 |
+
|
670 |
+
shrink_length = min(left_length, right_length,
|
671 |
+
sum(upper_edge_list)) * shrink_ratio_of_width
|
672 |
+
# shrinking length
|
673 |
+
upper_len_left = shrink_length
|
674 |
+
upper_len_right = sum(upper_edge_list) - shrink_length
|
675 |
+
|
676 |
+
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
677 |
+
left_quad = self.shrink_quad_along_width(
|
678 |
+
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
679 |
+
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
680 |
+
right_quad = self.shrink_quad_along_width(
|
681 |
+
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
682 |
+
|
683 |
+
out_quad_list = []
|
684 |
+
if left_idx == right_idx:
|
685 |
+
out_quad_list.append(
|
686 |
+
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
687 |
+
else:
|
688 |
+
out_quad_list.append(left_quad)
|
689 |
+
for idx in range(left_idx + 1, right_idx):
|
690 |
+
out_quad_list.append(quads[idx])
|
691 |
+
out_quad_list.append(right_quad)
|
692 |
+
|
693 |
+
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
694 |
+
|
695 |
+
def prepare_text_label(self, label_str, Lexicon_Table):
|
696 |
+
"""
|
697 |
+
Prepare text lablel by given Lexicon_Table.
|
698 |
+
"""
|
699 |
+
if len(Lexicon_Table) == 36:
|
700 |
+
return label_str.lower()
|
701 |
+
else:
|
702 |
+
return label_str
|
703 |
+
|
704 |
+
def vector_angle(self, A, B):
|
705 |
+
"""
|
706 |
+
Calculate the angle between vector AB and x-axis positive direction.
|
707 |
+
"""
|
708 |
+
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
709 |
+
return np.arctan2(*AB)
|
710 |
+
|
711 |
+
def theta_line_cross_point(self, theta, point):
|
712 |
+
"""
|
713 |
+
Calculate the line through given point and angle in ax + by + c =0 form.
|
714 |
+
"""
|
715 |
+
x, y = point
|
716 |
+
cos = np.cos(theta)
|
717 |
+
sin = np.sin(theta)
|
718 |
+
return [sin, -cos, cos * y - sin * x]
|
719 |
+
|
720 |
+
def line_cross_two_point(self, A, B):
|
721 |
+
"""
|
722 |
+
Calculate the line through given point A and B in ax + by + c =0 form.
|
723 |
+
"""
|
724 |
+
angle = self.vector_angle(A, B)
|
725 |
+
return self.theta_line_cross_point(angle, A)
|
726 |
+
|
727 |
+
def average_angle(self, poly):
|
728 |
+
"""
|
729 |
+
Calculate the average angle between left and right edge in given poly.
|
730 |
+
"""
|
731 |
+
p0, p1, p2, p3 = poly
|
732 |
+
angle30 = self.vector_angle(p3, p0)
|
733 |
+
angle21 = self.vector_angle(p2, p1)
|
734 |
+
return (angle30 + angle21) / 2
|
735 |
+
|
736 |
+
def line_cross_point(self, line1, line2):
|
737 |
+
"""
|
738 |
+
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
739 |
+
"""
|
740 |
+
a1, b1, c1 = line1
|
741 |
+
a2, b2, c2 = line2
|
742 |
+
d = a1 * b2 - a2 * b1
|
743 |
+
|
744 |
+
if d == 0:
|
745 |
+
print('Cross point does not exist')
|
746 |
+
return np.array([0, 0], dtype=np.float32)
|
747 |
+
else:
|
748 |
+
x = (b1 * c2 - b2 * c1) / d
|
749 |
+
y = (a2 * c1 - a1 * c2) / d
|
750 |
+
|
751 |
+
return np.array([x, y], dtype=np.float32)
|
752 |
+
|
753 |
+
def quad2tcl(self, poly, ratio):
|
754 |
+
"""
|
755 |
+
Generate center line by poly clock-wise point. (4, 2)
|
756 |
+
"""
|
757 |
+
ratio_pair = np.array(
|
758 |
+
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
759 |
+
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
760 |
+
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
761 |
+
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
762 |
+
|
763 |
+
def poly2tcl(self, poly, ratio):
|
764 |
+
"""
|
765 |
+
Generate center line by poly clock-wise point.
|
766 |
+
"""
|
767 |
+
ratio_pair = np.array(
|
768 |
+
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
769 |
+
tcl_poly = np.zeros_like(poly)
|
770 |
+
point_num = poly.shape[0]
|
771 |
+
|
772 |
+
for idx in range(point_num // 2):
|
773 |
+
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
774 |
+
) * ratio_pair
|
775 |
+
tcl_poly[idx] = point_pair[0]
|
776 |
+
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
777 |
+
return tcl_poly
|
778 |
+
|
779 |
+
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
780 |
+
"""
|
781 |
+
Generate tbo_map for give quad.
|
782 |
+
"""
|
783 |
+
# upper and lower line function: ax + by + c = 0;
|
784 |
+
up_line = self.line_cross_two_point(quad[0], quad[1])
|
785 |
+
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
786 |
+
|
787 |
+
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
788 |
+
np.linalg.norm(quad[1] - quad[2]))
|
789 |
+
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
790 |
+
np.linalg.norm(quad[2] - quad[3]))
|
791 |
+
|
792 |
+
# average angle of left and right line.
|
793 |
+
angle = self.average_angle(quad)
|
794 |
+
|
795 |
+
xy_in_poly = np.argwhere(tcl_mask == 1)
|
796 |
+
for y, x in xy_in_poly:
|
797 |
+
point = (x, y)
|
798 |
+
line = self.theta_line_cross_point(angle, point)
|
799 |
+
cross_point_upper = self.line_cross_point(up_line, line)
|
800 |
+
cross_point_lower = self.line_cross_point(lower_line, line)
|
801 |
+
##FIX, offset reverse
|
802 |
+
upper_offset_x, upper_offset_y = cross_point_upper - point
|
803 |
+
lower_offset_x, lower_offset_y = cross_point_lower - point
|
804 |
+
tbo_map[y, x, 0] = upper_offset_y
|
805 |
+
tbo_map[y, x, 1] = upper_offset_x
|
806 |
+
tbo_map[y, x, 2] = lower_offset_y
|
807 |
+
tbo_map[y, x, 3] = lower_offset_x
|
808 |
+
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
809 |
+
return tbo_map
|
810 |
+
|
811 |
+
def poly2quads(self, poly):
|
812 |
+
"""
|
813 |
+
Split poly into quads.
|
814 |
+
"""
|
815 |
+
quad_list = []
|
816 |
+
point_num = poly.shape[0]
|
817 |
+
|
818 |
+
# point pair
|
819 |
+
point_pair_list = []
|
820 |
+
for idx in range(point_num // 2):
|
821 |
+
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
822 |
+
point_pair_list.append(point_pair)
|
823 |
+
|
824 |
+
quad_num = point_num // 2 - 1
|
825 |
+
for idx in range(quad_num):
|
826 |
+
# reshape and adjust to clock-wise
|
827 |
+
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
828 |
+
).reshape(4, 2)[[0, 2, 3, 1]])
|
829 |
+
|
830 |
+
return np.array(quad_list)
|
831 |
+
|
832 |
+
def rotate_im_poly(self, im, text_polys):
|
833 |
+
"""
|
834 |
+
rotate image with 90 / 180 / 270 degre
|
835 |
+
"""
|
836 |
+
im_w, im_h = im.shape[1], im.shape[0]
|
837 |
+
dst_im = im.copy()
|
838 |
+
dst_polys = []
|
839 |
+
rand_degree_ratio = np.random.rand()
|
840 |
+
rand_degree_cnt = 1
|
841 |
+
if rand_degree_ratio > 0.5:
|
842 |
+
rand_degree_cnt = 3
|
843 |
+
for i in range(rand_degree_cnt):
|
844 |
+
dst_im = np.rot90(dst_im)
|
845 |
+
rot_degree = -90 * rand_degree_cnt
|
846 |
+
rot_angle = rot_degree * math.pi / 180.0
|
847 |
+
n_poly = text_polys.shape[0]
|
848 |
+
cx, cy = 0.5 * im_w, 0.5 * im_h
|
849 |
+
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
850 |
+
for i in range(n_poly):
|
851 |
+
wordBB = text_polys[i]
|
852 |
+
poly = []
|
853 |
+
for j in range(4): # 16->4
|
854 |
+
sx, sy = wordBB[j][0], wordBB[j][1]
|
855 |
+
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
|
856 |
+
sy - cy) + ncx
|
857 |
+
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
|
858 |
+
sy - cy) + ncy
|
859 |
+
poly.append([dx, dy])
|
860 |
+
dst_polys.append(poly)
|
861 |
+
return dst_im, np.array(dst_polys, dtype=np.float32)
|
862 |
+
|
863 |
+
def __call__(self, data):
|
864 |
+
input_size = 512
|
865 |
+
im = data['image']
|
866 |
+
text_polys = data['polys']
|
867 |
+
text_tags = data['ignore_tags']
|
868 |
+
text_strs = data['texts']
|
869 |
+
h, w, _ = im.shape
|
870 |
+
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
871 |
+
text_polys, text_tags, (h, w))
|
872 |
+
if text_polys.shape[0] <= 0:
|
873 |
+
return None
|
874 |
+
# set aspect ratio and keep area fix
|
875 |
+
asp_scales = np.arange(1.0, 1.55, 0.1)
|
876 |
+
asp_scale = np.random.choice(asp_scales)
|
877 |
+
if np.random.rand() < 0.5:
|
878 |
+
asp_scale = 1.0 / asp_scale
|
879 |
+
asp_scale = math.sqrt(asp_scale)
|
880 |
+
|
881 |
+
asp_wx = asp_scale
|
882 |
+
asp_hy = 1.0 / asp_scale
|
883 |
+
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
884 |
+
text_polys[:, :, 0] *= asp_wx
|
885 |
+
text_polys[:, :, 1] *= asp_hy
|
886 |
+
|
887 |
+
if self.use_resize is True:
|
888 |
+
ori_h, ori_w, _ = im.shape
|
889 |
+
if max(ori_h, ori_w) < 200:
|
890 |
+
ratio = 200 / max(ori_h, ori_w)
|
891 |
+
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
|
892 |
+
text_polys[:, :, 0] *= ratio
|
893 |
+
text_polys[:, :, 1] *= ratio
|
894 |
+
|
895 |
+
if max(ori_h, ori_w) > 512:
|
896 |
+
ratio = 512 / max(ori_h, ori_w)
|
897 |
+
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
|
898 |
+
text_polys[:, :, 0] *= ratio
|
899 |
+
text_polys[:, :, 1] *= ratio
|
900 |
+
elif self.use_random_crop is True:
|
901 |
+
h, w, _ = im.shape
|
902 |
+
if max(h, w) > 2048:
|
903 |
+
rd_scale = 2048.0 / max(h, w)
|
904 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
905 |
+
text_polys *= rd_scale
|
906 |
+
h, w, _ = im.shape
|
907 |
+
if min(h, w) < 16:
|
908 |
+
return None
|
909 |
+
|
910 |
+
# no background
|
911 |
+
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
912 |
+
im,
|
913 |
+
text_polys,
|
914 |
+
text_tags,
|
915 |
+
hv_tags,
|
916 |
+
text_strs,
|
917 |
+
crop_background=False)
|
918 |
+
|
919 |
+
if text_polys.shape[0] == 0:
|
920 |
+
return None
|
921 |
+
# continue for all ignore case
|
922 |
+
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
923 |
+
return None
|
924 |
+
new_h, new_w, _ = im.shape
|
925 |
+
if (new_h is None) or (new_w is None):
|
926 |
+
return None
|
927 |
+
# resize image
|
928 |
+
std_ratio = float(input_size) / max(new_w, new_h)
|
929 |
+
rand_scales = np.array(
|
930 |
+
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
931 |
+
rz_scale = std_ratio * np.random.choice(rand_scales)
|
932 |
+
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
933 |
+
text_polys[:, :, 0] *= rz_scale
|
934 |
+
text_polys[:, :, 1] *= rz_scale
|
935 |
+
|
936 |
+
# add gaussian blur
|
937 |
+
if np.random.rand() < 0.1 * 0.5:
|
938 |
+
ks = np.random.permutation(5)[0] + 1
|
939 |
+
ks = int(ks / 2) * 2 + 1
|
940 |
+
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
941 |
+
# add brighter
|
942 |
+
if np.random.rand() < 0.1 * 0.5:
|
943 |
+
im = im * (1.0 + np.random.rand() * 0.5)
|
944 |
+
im = np.clip(im, 0.0, 255.0)
|
945 |
+
# add darker
|
946 |
+
if np.random.rand() < 0.1 * 0.5:
|
947 |
+
im = im * (1.0 - np.random.rand() * 0.5)
|
948 |
+
im = np.clip(im, 0.0, 255.0)
|
949 |
+
|
950 |
+
# Padding the im to [input_size, input_size]
|
951 |
+
new_h, new_w, _ = im.shape
|
952 |
+
if min(new_w, new_h) < input_size * 0.5:
|
953 |
+
return None
|
954 |
+
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
|
955 |
+
im_padded[:, :, 2] = 0.485 * 255
|
956 |
+
im_padded[:, :, 1] = 0.456 * 255
|
957 |
+
im_padded[:, :, 0] = 0.406 * 255
|
958 |
+
|
959 |
+
# Random the start position
|
960 |
+
del_h = input_size - new_h
|
961 |
+
del_w = input_size - new_w
|
962 |
+
sh, sw = 0, 0
|
963 |
+
if del_h > 1:
|
964 |
+
sh = int(np.random.rand() * del_h)
|
965 |
+
if del_w > 1:
|
966 |
+
sw = int(np.random.rand() * del_w)
|
967 |
+
|
968 |
+
# Padding
|
969 |
+
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
970 |
+
text_polys[:, :, 0] += sw
|
971 |
+
text_polys[:, :, 1] += sh
|
972 |
+
|
973 |
+
score_map, score_label_map, border_map, direction_map, training_mask, \
|
974 |
+
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
|
975 |
+
input_size,
|
976 |
+
text_polys,
|
977 |
+
text_tags,
|
978 |
+
text_strs, 0.25)
|
979 |
+
if len(label_list) <= 0: # eliminate negative samples
|
980 |
+
return None
|
981 |
+
pos_list_temp = np.zeros([64, 3])
|
982 |
+
pos_mask_temp = np.zeros([64, 1])
|
983 |
+
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
|
984 |
+
|
985 |
+
for i, label in enumerate(label_list):
|
986 |
+
n = len(label)
|
987 |
+
if n > self.max_text_length:
|
988 |
+
label_list[i] = label[:self.max_text_length]
|
989 |
+
continue
|
990 |
+
while n < self.max_text_length:
|
991 |
+
label.append([self.pad_num])
|
992 |
+
n += 1
|
993 |
+
|
994 |
+
for i in range(len(label_list)):
|
995 |
+
label_list[i] = np.array(label_list[i])
|
996 |
+
|
997 |
+
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
|
998 |
+
return None
|
999 |
+
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
|
1000 |
+
pos_list.append(pos_list_temp)
|
1001 |
+
pos_mask.append(pos_mask_temp)
|
1002 |
+
label_list.append(label_list_temp)
|
1003 |
+
|
1004 |
+
if self.img_id == self.batch_size - 1:
|
1005 |
+
self.img_id = 0
|
1006 |
+
else:
|
1007 |
+
self.img_id += 1
|
1008 |
+
|
1009 |
+
im_padded[:, :, 2] -= 0.485 * 255
|
1010 |
+
im_padded[:, :, 1] -= 0.456 * 255
|
1011 |
+
im_padded[:, :, 0] -= 0.406 * 255
|
1012 |
+
im_padded[:, :, 2] /= (255.0 * 0.229)
|
1013 |
+
im_padded[:, :, 1] /= (255.0 * 0.224)
|
1014 |
+
im_padded[:, :, 0] /= (255.0 * 0.225)
|
1015 |
+
im_padded = im_padded.transpose((2, 0, 1))
|
1016 |
+
images = im_padded[::-1, :, :]
|
1017 |
+
tcl_maps = score_map[np.newaxis, :, :]
|
1018 |
+
tcl_label_maps = score_label_map[np.newaxis, :, :]
|
1019 |
+
border_maps = border_map.transpose((2, 0, 1))
|
1020 |
+
direction_maps = direction_map.transpose((2, 0, 1))
|
1021 |
+
training_masks = training_mask[np.newaxis, :, :]
|
1022 |
+
pos_list = np.array(pos_list)
|
1023 |
+
pos_mask = np.array(pos_mask)
|
1024 |
+
label_list = np.array(label_list)
|
1025 |
+
data['images'] = images
|
1026 |
+
data['tcl_maps'] = tcl_maps
|
1027 |
+
data['tcl_label_maps'] = tcl_label_maps
|
1028 |
+
data['border_maps'] = border_maps
|
1029 |
+
data['direction_maps'] = direction_maps
|
1030 |
+
data['training_masks'] = training_masks
|
1031 |
+
data['label_list'] = label_list
|
1032 |
+
data['pos_list'] = pos_list
|
1033 |
+
data['pos_mask'] = pos_mask
|
1034 |
+
return data
|
ppocr/data/imaug/randaugment.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
from __future__ import unicode_literals
|
19 |
+
|
20 |
+
from PIL import Image, ImageEnhance, ImageOps
|
21 |
+
import numpy as np
|
22 |
+
import random
|
23 |
+
import six
|
24 |
+
|
25 |
+
|
26 |
+
class RawRandAugment(object):
|
27 |
+
def __init__(self,
|
28 |
+
num_layers=2,
|
29 |
+
magnitude=5,
|
30 |
+
fillcolor=(128, 128, 128),
|
31 |
+
**kwargs):
|
32 |
+
self.num_layers = num_layers
|
33 |
+
self.magnitude = magnitude
|
34 |
+
self.max_level = 10
|
35 |
+
|
36 |
+
abso_level = self.magnitude / self.max_level
|
37 |
+
self.level_map = {
|
38 |
+
"shearX": 0.3 * abso_level,
|
39 |
+
"shearY": 0.3 * abso_level,
|
40 |
+
"translateX": 150.0 / 331 * abso_level,
|
41 |
+
"translateY": 150.0 / 331 * abso_level,
|
42 |
+
"rotate": 30 * abso_level,
|
43 |
+
"color": 0.9 * abso_level,
|
44 |
+
"posterize": int(4.0 * abso_level),
|
45 |
+
"solarize": 256.0 * abso_level,
|
46 |
+
"contrast": 0.9 * abso_level,
|
47 |
+
"sharpness": 0.9 * abso_level,
|
48 |
+
"brightness": 0.9 * abso_level,
|
49 |
+
"autocontrast": 0,
|
50 |
+
"equalize": 0,
|
51 |
+
"invert": 0
|
52 |
+
}
|
53 |
+
|
54 |
+
# from https://stackoverflow.com/questions/5252170/
|
55 |
+
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
56 |
+
def rotate_with_fill(img, magnitude):
|
57 |
+
rot = img.convert("RGBA").rotate(magnitude)
|
58 |
+
return Image.composite(rot,
|
59 |
+
Image.new("RGBA", rot.size, (128, ) * 4),
|
60 |
+
rot).convert(img.mode)
|
61 |
+
|
62 |
+
rnd_ch_op = random.choice
|
63 |
+
|
64 |
+
self.func = {
|
65 |
+
"shearX": lambda img, magnitude: img.transform(
|
66 |
+
img.size,
|
67 |
+
Image.AFFINE,
|
68 |
+
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
|
69 |
+
Image.BICUBIC,
|
70 |
+
fillcolor=fillcolor),
|
71 |
+
"shearY": lambda img, magnitude: img.transform(
|
72 |
+
img.size,
|
73 |
+
Image.AFFINE,
|
74 |
+
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
|
75 |
+
Image.BICUBIC,
|
76 |
+
fillcolor=fillcolor),
|
77 |
+
"translateX": lambda img, magnitude: img.transform(
|
78 |
+
img.size,
|
79 |
+
Image.AFFINE,
|
80 |
+
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
|
81 |
+
fillcolor=fillcolor),
|
82 |
+
"translateY": lambda img, magnitude: img.transform(
|
83 |
+
img.size,
|
84 |
+
Image.AFFINE,
|
85 |
+
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
|
86 |
+
fillcolor=fillcolor),
|
87 |
+
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
|
88 |
+
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
|
89 |
+
1 + magnitude * rnd_ch_op([-1, 1])),
|
90 |
+
"posterize": lambda img, magnitude:
|
91 |
+
ImageOps.posterize(img, magnitude),
|
92 |
+
"solarize": lambda img, magnitude:
|
93 |
+
ImageOps.solarize(img, magnitude),
|
94 |
+
"contrast": lambda img, magnitude:
|
95 |
+
ImageEnhance.Contrast(img).enhance(
|
96 |
+
1 + magnitude * rnd_ch_op([-1, 1])),
|
97 |
+
"sharpness": lambda img, magnitude:
|
98 |
+
ImageEnhance.Sharpness(img).enhance(
|
99 |
+
1 + magnitude * rnd_ch_op([-1, 1])),
|
100 |
+
"brightness": lambda img, magnitude:
|
101 |
+
ImageEnhance.Brightness(img).enhance(
|
102 |
+
1 + magnitude * rnd_ch_op([-1, 1])),
|
103 |
+
"autocontrast": lambda img, magnitude:
|
104 |
+
ImageOps.autocontrast(img),
|
105 |
+
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
106 |
+
"invert": lambda img, magnitude: ImageOps.invert(img)
|
107 |
+
}
|
108 |
+
|
109 |
+
def __call__(self, img):
|
110 |
+
avaiable_op_names = list(self.level_map.keys())
|
111 |
+
for layer_num in range(self.num_layers):
|
112 |
+
op_name = np.random.choice(avaiable_op_names)
|
113 |
+
img = self.func[op_name](img, self.level_map[op_name])
|
114 |
+
return img
|
115 |
+
|
116 |
+
|
117 |
+
class RandAugment(RawRandAugment):
|
118 |
+
""" RandAugment wrapper to auto fit different img types """
|
119 |
+
|
120 |
+
def __init__(self, prob=0.5, *args, **kwargs):
|
121 |
+
self.prob = prob
|
122 |
+
if six.PY2:
|
123 |
+
super(RandAugment, self).__init__(*args, **kwargs)
|
124 |
+
else:
|
125 |
+
super().__init__(*args, **kwargs)
|
126 |
+
|
127 |
+
def __call__(self, data):
|
128 |
+
if np.random.rand() > self.prob:
|
129 |
+
return data
|
130 |
+
img = data['image']
|
131 |
+
if not isinstance(img, Image.Image):
|
132 |
+
img = np.ascontiguousarray(img)
|
133 |
+
img = Image.fromarray(img)
|
134 |
+
|
135 |
+
if six.PY2:
|
136 |
+
img = super(RandAugment, self).__call__(img)
|
137 |
+
else:
|
138 |
+
img = super().__call__(img)
|
139 |
+
|
140 |
+
if isinstance(img, Image.Image):
|
141 |
+
img = np.asarray(img)
|
142 |
+
data['image'] = img
|
143 |
+
return data
|
ppocr/data/imaug/random_crop_data.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
from __future__ import print_function
|
22 |
+
from __future__ import unicode_literals
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import cv2
|
26 |
+
import random
|
27 |
+
|
28 |
+
|
29 |
+
def is_poly_in_rect(poly, x, y, w, h):
|
30 |
+
poly = np.array(poly)
|
31 |
+
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
|
32 |
+
return False
|
33 |
+
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
|
34 |
+
return False
|
35 |
+
return True
|
36 |
+
|
37 |
+
|
38 |
+
def is_poly_outside_rect(poly, x, y, w, h):
|
39 |
+
poly = np.array(poly)
|
40 |
+
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
|
41 |
+
return True
|
42 |
+
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
|
43 |
+
return True
|
44 |
+
return False
|
45 |
+
|
46 |
+
|
47 |
+
def split_regions(axis):
|
48 |
+
regions = []
|
49 |
+
min_axis = 0
|
50 |
+
for i in range(1, axis.shape[0]):
|
51 |
+
if axis[i] != axis[i - 1] + 1:
|
52 |
+
region = axis[min_axis:i]
|
53 |
+
min_axis = i
|
54 |
+
regions.append(region)
|
55 |
+
return regions
|
56 |
+
|
57 |
+
|
58 |
+
def random_select(axis, max_size):
|
59 |
+
xx = np.random.choice(axis, size=2)
|
60 |
+
xmin = np.min(xx)
|
61 |
+
xmax = np.max(xx)
|
62 |
+
xmin = np.clip(xmin, 0, max_size - 1)
|
63 |
+
xmax = np.clip(xmax, 0, max_size - 1)
|
64 |
+
return xmin, xmax
|
65 |
+
|
66 |
+
|
67 |
+
def region_wise_random_select(regions, max_size):
|
68 |
+
selected_index = list(np.random.choice(len(regions), 2))
|
69 |
+
selected_values = []
|
70 |
+
for index in selected_index:
|
71 |
+
axis = regions[index]
|
72 |
+
xx = int(np.random.choice(axis, size=1))
|
73 |
+
selected_values.append(xx)
|
74 |
+
xmin = min(selected_values)
|
75 |
+
xmax = max(selected_values)
|
76 |
+
return xmin, xmax
|
77 |
+
|
78 |
+
|
79 |
+
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
|
80 |
+
h, w, _ = im.shape
|
81 |
+
h_array = np.zeros(h, dtype=np.int32)
|
82 |
+
w_array = np.zeros(w, dtype=np.int32)
|
83 |
+
for points in text_polys:
|
84 |
+
points = np.round(points, decimals=0).astype(np.int32)
|
85 |
+
minx = np.min(points[:, 0])
|
86 |
+
maxx = np.max(points[:, 0])
|
87 |
+
w_array[minx:maxx] = 1
|
88 |
+
miny = np.min(points[:, 1])
|
89 |
+
maxy = np.max(points[:, 1])
|
90 |
+
h_array[miny:maxy] = 1
|
91 |
+
# ensure the cropped area not across a text
|
92 |
+
h_axis = np.where(h_array == 0)[0]
|
93 |
+
w_axis = np.where(w_array == 0)[0]
|
94 |
+
|
95 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
96 |
+
return 0, 0, w, h
|
97 |
+
|
98 |
+
h_regions = split_regions(h_axis)
|
99 |
+
w_regions = split_regions(w_axis)
|
100 |
+
|
101 |
+
for i in range(max_tries):
|
102 |
+
if len(w_regions) > 1:
|
103 |
+
xmin, xmax = region_wise_random_select(w_regions, w)
|
104 |
+
else:
|
105 |
+
xmin, xmax = random_select(w_axis, w)
|
106 |
+
if len(h_regions) > 1:
|
107 |
+
ymin, ymax = region_wise_random_select(h_regions, h)
|
108 |
+
else:
|
109 |
+
ymin, ymax = random_select(h_axis, h)
|
110 |
+
|
111 |
+
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
|
112 |
+
# area too small
|
113 |
+
continue
|
114 |
+
num_poly_in_rect = 0
|
115 |
+
for poly in text_polys:
|
116 |
+
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
|
117 |
+
ymax - ymin):
|
118 |
+
num_poly_in_rect += 1
|
119 |
+
break
|
120 |
+
|
121 |
+
if num_poly_in_rect > 0:
|
122 |
+
return xmin, ymin, xmax - xmin, ymax - ymin
|
123 |
+
|
124 |
+
return 0, 0, w, h
|
125 |
+
|
126 |
+
|
127 |
+
class EastRandomCropData(object):
|
128 |
+
def __init__(self,
|
129 |
+
size=(640, 640),
|
130 |
+
max_tries=10,
|
131 |
+
min_crop_side_ratio=0.1,
|
132 |
+
keep_ratio=True,
|
133 |
+
**kwargs):
|
134 |
+
self.size = size
|
135 |
+
self.max_tries = max_tries
|
136 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
137 |
+
self.keep_ratio = keep_ratio
|
138 |
+
|
139 |
+
def __call__(self, data):
|
140 |
+
img = data['image']
|
141 |
+
text_polys = data['polys']
|
142 |
+
ignore_tags = data['ignore_tags']
|
143 |
+
texts = data['texts']
|
144 |
+
all_care_polys = [
|
145 |
+
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
|
146 |
+
]
|
147 |
+
# 计算crop区域
|
148 |
+
crop_x, crop_y, crop_w, crop_h = crop_area(
|
149 |
+
img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
|
150 |
+
# crop 图片 保持比例填充
|
151 |
+
scale_w = self.size[0] / crop_w
|
152 |
+
scale_h = self.size[1] / crop_h
|
153 |
+
scale = min(scale_w, scale_h)
|
154 |
+
h = int(crop_h * scale)
|
155 |
+
w = int(crop_w * scale)
|
156 |
+
if self.keep_ratio:
|
157 |
+
padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
|
158 |
+
img.dtype)
|
159 |
+
padimg[:h, :w] = cv2.resize(
|
160 |
+
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
|
161 |
+
img = padimg
|
162 |
+
else:
|
163 |
+
img = cv2.resize(
|
164 |
+
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
|
165 |
+
tuple(self.size))
|
166 |
+
# crop 文本框
|
167 |
+
text_polys_crop = []
|
168 |
+
ignore_tags_crop = []
|
169 |
+
texts_crop = []
|
170 |
+
for poly, text, tag in zip(text_polys, texts, ignore_tags):
|
171 |
+
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
|
172 |
+
if not is_poly_outside_rect(poly, 0, 0, w, h):
|
173 |
+
text_polys_crop.append(poly)
|
174 |
+
ignore_tags_crop.append(tag)
|
175 |
+
texts_crop.append(text)
|
176 |
+
data['image'] = img
|
177 |
+
data['polys'] = np.array(text_polys_crop)
|
178 |
+
data['ignore_tags'] = ignore_tags_crop
|
179 |
+
data['texts'] = texts_crop
|
180 |
+
return data
|
181 |
+
|
182 |
+
|
183 |
+
class RandomCropImgMask(object):
|
184 |
+
def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
|
185 |
+
self.size = size
|
186 |
+
self.main_key = main_key
|
187 |
+
self.crop_keys = crop_keys
|
188 |
+
self.p = p
|
189 |
+
|
190 |
+
def __call__(self, data):
|
191 |
+
image = data['image']
|
192 |
+
|
193 |
+
h, w = image.shape[0:2]
|
194 |
+
th, tw = self.size
|
195 |
+
if w == tw and h == th:
|
196 |
+
return data
|
197 |
+
|
198 |
+
mask = data[self.main_key]
|
199 |
+
if np.max(mask) > 0 and random.random() > self.p:
|
200 |
+
# make sure to crop the text region
|
201 |
+
tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
|
202 |
+
tl[tl < 0] = 0
|
203 |
+
br = np.max(np.where(mask > 0), axis=1) - (th, tw)
|
204 |
+
br[br < 0] = 0
|
205 |
+
|
206 |
+
br[0] = min(br[0], h - th)
|
207 |
+
br[1] = min(br[1], w - tw)
|
208 |
+
|
209 |
+
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
|
210 |
+
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
|
211 |
+
else:
|
212 |
+
i = random.randint(0, h - th) if h - th > 0 else 0
|
213 |
+
j = random.randint(0, w - tw) if w - tw > 0 else 0
|
214 |
+
|
215 |
+
# return i, j, th, tw
|
216 |
+
for k in data:
|
217 |
+
if k in self.crop_keys:
|
218 |
+
if len(data[k].shape) == 3:
|
219 |
+
if np.argmin(data[k].shape) == 0:
|
220 |
+
img = data[k][:, i:i + th, j:j + tw]
|
221 |
+
if img.shape[1] != img.shape[2]:
|
222 |
+
a = 1
|
223 |
+
elif np.argmin(data[k].shape) == 2:
|
224 |
+
img = data[k][i:i + th, j:j + tw, :]
|
225 |
+
if img.shape[1] != img.shape[0]:
|
226 |
+
a = 1
|
227 |
+
else:
|
228 |
+
img = data[k]
|
229 |
+
else:
|
230 |
+
img = data[k][i:i + th, j:j + tw]
|
231 |
+
if img.shape[0] != img.shape[1]:
|
232 |
+
a = 1
|
233 |
+
data[k] = img
|
234 |
+
return data
|
ppocr/data/imaug/rec_img_aug.py
ADDED
@@ -0,0 +1,825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import cv2
|
17 |
+
import numpy as np
|
18 |
+
import random
|
19 |
+
import copy
|
20 |
+
from PIL import Image
|
21 |
+
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
22 |
+
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
|
23 |
+
from paddle.vision.transforms import Compose
|
24 |
+
|
25 |
+
|
26 |
+
class RecAug(object):
|
27 |
+
def __init__(self,
|
28 |
+
tia_prob=0.4,
|
29 |
+
crop_prob=0.4,
|
30 |
+
reverse_prob=0.4,
|
31 |
+
noise_prob=0.4,
|
32 |
+
jitter_prob=0.4,
|
33 |
+
blur_prob=0.4,
|
34 |
+
hsv_aug_prob=0.4,
|
35 |
+
**kwargs):
|
36 |
+
self.tia_prob = tia_prob
|
37 |
+
self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob,
|
38 |
+
jitter_prob, blur_prob, hsv_aug_prob)
|
39 |
+
|
40 |
+
def __call__(self, data):
|
41 |
+
img = data['image']
|
42 |
+
h, w, _ = img.shape
|
43 |
+
|
44 |
+
# tia
|
45 |
+
if random.random() <= self.tia_prob:
|
46 |
+
if h >= 20 and w >= 20:
|
47 |
+
img = tia_distort(img, random.randint(3, 6))
|
48 |
+
img = tia_stretch(img, random.randint(3, 6))
|
49 |
+
img = tia_perspective(img)
|
50 |
+
|
51 |
+
# bda
|
52 |
+
data['image'] = img
|
53 |
+
data = self.bda(data)
|
54 |
+
return data
|
55 |
+
|
56 |
+
|
57 |
+
class BaseDataAugmentation(object):
|
58 |
+
def __init__(self,
|
59 |
+
crop_prob=0.4,
|
60 |
+
reverse_prob=0.4,
|
61 |
+
noise_prob=0.4,
|
62 |
+
jitter_prob=0.4,
|
63 |
+
blur_prob=0.4,
|
64 |
+
hsv_aug_prob=0.4,
|
65 |
+
**kwargs):
|
66 |
+
self.crop_prob = crop_prob
|
67 |
+
self.reverse_prob = reverse_prob
|
68 |
+
self.noise_prob = noise_prob
|
69 |
+
self.jitter_prob = jitter_prob
|
70 |
+
self.blur_prob = blur_prob
|
71 |
+
self.hsv_aug_prob = hsv_aug_prob
|
72 |
+
|
73 |
+
def __call__(self, data):
|
74 |
+
img = data['image']
|
75 |
+
h, w, _ = img.shape
|
76 |
+
|
77 |
+
if random.random() <= self.crop_prob and h >= 20 and w >= 20:
|
78 |
+
img = get_crop(img)
|
79 |
+
|
80 |
+
if random.random() <= self.blur_prob:
|
81 |
+
img = blur(img)
|
82 |
+
|
83 |
+
if random.random() <= self.hsv_aug_prob:
|
84 |
+
img = hsv_aug(img)
|
85 |
+
|
86 |
+
if random.random() <= self.jitter_prob:
|
87 |
+
img = jitter(img)
|
88 |
+
|
89 |
+
if random.random() <= self.noise_prob:
|
90 |
+
img = add_gasuss_noise(img)
|
91 |
+
|
92 |
+
if random.random() <= self.reverse_prob:
|
93 |
+
img = 255 - img
|
94 |
+
|
95 |
+
data['image'] = img
|
96 |
+
return data
|
97 |
+
|
98 |
+
|
99 |
+
class ABINetRecAug(object):
|
100 |
+
def __init__(self,
|
101 |
+
geometry_p=0.5,
|
102 |
+
deterioration_p=0.25,
|
103 |
+
colorjitter_p=0.25,
|
104 |
+
**kwargs):
|
105 |
+
self.transforms = Compose([
|
106 |
+
CVGeometry(
|
107 |
+
degrees=45,
|
108 |
+
translate=(0.0, 0.0),
|
109 |
+
scale=(0.5, 2.),
|
110 |
+
shear=(45, 15),
|
111 |
+
distortion=0.5,
|
112 |
+
p=geometry_p), CVDeterioration(
|
113 |
+
var=20, degrees=6, factor=4, p=deterioration_p),
|
114 |
+
CVColorJitter(
|
115 |
+
brightness=0.5,
|
116 |
+
contrast=0.5,
|
117 |
+
saturation=0.5,
|
118 |
+
hue=0.1,
|
119 |
+
p=colorjitter_p)
|
120 |
+
])
|
121 |
+
|
122 |
+
def __call__(self, data):
|
123 |
+
img = data['image']
|
124 |
+
img = self.transforms(img)
|
125 |
+
data['image'] = img
|
126 |
+
return data
|
127 |
+
|
128 |
+
|
129 |
+
class RecConAug(object):
|
130 |
+
def __init__(self,
|
131 |
+
prob=0.5,
|
132 |
+
image_shape=(32, 320, 3),
|
133 |
+
max_text_length=25,
|
134 |
+
ext_data_num=1,
|
135 |
+
**kwargs):
|
136 |
+
self.ext_data_num = ext_data_num
|
137 |
+
self.prob = prob
|
138 |
+
self.max_text_length = max_text_length
|
139 |
+
self.image_shape = image_shape
|
140 |
+
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
|
141 |
+
|
142 |
+
def merge_ext_data(self, data, ext_data):
|
143 |
+
ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
|
144 |
+
self.image_shape[0])
|
145 |
+
ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
|
146 |
+
self.image_shape[0])
|
147 |
+
data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
|
148 |
+
ext_data['image'] = cv2.resize(ext_data['image'],
|
149 |
+
(ext_w, self.image_shape[0]))
|
150 |
+
data['image'] = np.concatenate(
|
151 |
+
[data['image'], ext_data['image']], axis=1)
|
152 |
+
data["label"] += ext_data["label"]
|
153 |
+
return data
|
154 |
+
|
155 |
+
def __call__(self, data):
|
156 |
+
rnd_num = random.random()
|
157 |
+
if rnd_num > self.prob:
|
158 |
+
return data
|
159 |
+
for idx, ext_data in enumerate(data["ext_data"]):
|
160 |
+
if len(data["label"]) + len(ext_data[
|
161 |
+
"label"]) > self.max_text_length:
|
162 |
+
break
|
163 |
+
concat_ratio = data['image'].shape[1] / data['image'].shape[
|
164 |
+
0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
|
165 |
+
if concat_ratio > self.max_wh_ratio:
|
166 |
+
break
|
167 |
+
data = self.merge_ext_data(data, ext_data)
|
168 |
+
data.pop("ext_data")
|
169 |
+
return data
|
170 |
+
|
171 |
+
|
172 |
+
class SVTRRecAug(object):
|
173 |
+
def __init__(self,
|
174 |
+
aug_type=0,
|
175 |
+
geometry_p=0.5,
|
176 |
+
deterioration_p=0.25,
|
177 |
+
colorjitter_p=0.25,
|
178 |
+
**kwargs):
|
179 |
+
self.transforms = Compose([
|
180 |
+
SVTRGeometry(
|
181 |
+
aug_type=aug_type,
|
182 |
+
degrees=45,
|
183 |
+
translate=(0.0, 0.0),
|
184 |
+
scale=(0.5, 2.),
|
185 |
+
shear=(45, 15),
|
186 |
+
distortion=0.5,
|
187 |
+
p=geometry_p), SVTRDeterioration(
|
188 |
+
var=20, degrees=6, factor=4, p=deterioration_p),
|
189 |
+
CVColorJitter(
|
190 |
+
brightness=0.5,
|
191 |
+
contrast=0.5,
|
192 |
+
saturation=0.5,
|
193 |
+
hue=0.1,
|
194 |
+
p=colorjitter_p)
|
195 |
+
])
|
196 |
+
|
197 |
+
def __call__(self, data):
|
198 |
+
img = data['image']
|
199 |
+
img = self.transforms(img)
|
200 |
+
data['image'] = img
|
201 |
+
return data
|
202 |
+
|
203 |
+
|
204 |
+
class ClsResizeImg(object):
|
205 |
+
def __init__(self, image_shape, **kwargs):
|
206 |
+
self.image_shape = image_shape
|
207 |
+
|
208 |
+
def __call__(self, data):
|
209 |
+
img = data['image']
|
210 |
+
norm_img, _ = resize_norm_img(img, self.image_shape)
|
211 |
+
data['image'] = norm_img
|
212 |
+
return data
|
213 |
+
|
214 |
+
|
215 |
+
class RecResizeImg(object):
|
216 |
+
def __init__(self,
|
217 |
+
image_shape,
|
218 |
+
infer_mode=False,
|
219 |
+
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
|
220 |
+
padding=True,
|
221 |
+
**kwargs):
|
222 |
+
self.image_shape = image_shape
|
223 |
+
self.infer_mode = infer_mode
|
224 |
+
self.character_dict_path = character_dict_path
|
225 |
+
self.padding = padding
|
226 |
+
|
227 |
+
def __call__(self, data):
|
228 |
+
img = data['image']
|
229 |
+
if self.infer_mode and self.character_dict_path is not None:
|
230 |
+
norm_img, valid_ratio = resize_norm_img_chinese(img,
|
231 |
+
self.image_shape)
|
232 |
+
else:
|
233 |
+
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
|
234 |
+
self.padding)
|
235 |
+
data['image'] = norm_img
|
236 |
+
data['valid_ratio'] = valid_ratio
|
237 |
+
return data
|
238 |
+
|
239 |
+
|
240 |
+
class VLRecResizeImg(object):
|
241 |
+
def __init__(self,
|
242 |
+
image_shape,
|
243 |
+
infer_mode=False,
|
244 |
+
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
|
245 |
+
padding=True,
|
246 |
+
**kwargs):
|
247 |
+
self.image_shape = image_shape
|
248 |
+
self.infer_mode = infer_mode
|
249 |
+
self.character_dict_path = character_dict_path
|
250 |
+
self.padding = padding
|
251 |
+
|
252 |
+
def __call__(self, data):
|
253 |
+
img = data['image']
|
254 |
+
|
255 |
+
imgC, imgH, imgW = self.image_shape
|
256 |
+
resized_image = cv2.resize(
|
257 |
+
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
258 |
+
resized_w = imgW
|
259 |
+
resized_image = resized_image.astype('float32')
|
260 |
+
if self.image_shape[0] == 1:
|
261 |
+
resized_image = resized_image / 255
|
262 |
+
norm_img = resized_image[np.newaxis, :]
|
263 |
+
else:
|
264 |
+
norm_img = resized_image.transpose((2, 0, 1)) / 255
|
265 |
+
valid_ratio = min(1.0, float(resized_w / imgW))
|
266 |
+
|
267 |
+
data['image'] = norm_img
|
268 |
+
data['valid_ratio'] = valid_ratio
|
269 |
+
return data
|
270 |
+
|
271 |
+
|
272 |
+
class RFLRecResizeImg(object):
|
273 |
+
def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
|
274 |
+
self.image_shape = image_shape
|
275 |
+
self.padding = padding
|
276 |
+
|
277 |
+
self.interpolation = interpolation
|
278 |
+
if self.interpolation == 0:
|
279 |
+
self.interpolation = cv2.INTER_NEAREST
|
280 |
+
elif self.interpolation == 1:
|
281 |
+
self.interpolation = cv2.INTER_LINEAR
|
282 |
+
elif self.interpolation == 2:
|
283 |
+
self.interpolation = cv2.INTER_CUBIC
|
284 |
+
elif self.interpolation == 3:
|
285 |
+
self.interpolation = cv2.INTER_AREA
|
286 |
+
else:
|
287 |
+
raise Exception("Unsupported interpolation type !!!")
|
288 |
+
|
289 |
+
def __call__(self, data):
|
290 |
+
img = data['image']
|
291 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
292 |
+
norm_img, valid_ratio = resize_norm_img(
|
293 |
+
img, self.image_shape, self.padding, self.interpolation)
|
294 |
+
data['image'] = norm_img
|
295 |
+
data['valid_ratio'] = valid_ratio
|
296 |
+
return data
|
297 |
+
|
298 |
+
|
299 |
+
class SRNRecResizeImg(object):
|
300 |
+
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
301 |
+
self.image_shape = image_shape
|
302 |
+
self.num_heads = num_heads
|
303 |
+
self.max_text_length = max_text_length
|
304 |
+
|
305 |
+
def __call__(self, data):
|
306 |
+
img = data['image']
|
307 |
+
norm_img = resize_norm_img_srn(img, self.image_shape)
|
308 |
+
data['image'] = norm_img
|
309 |
+
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
310 |
+
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
|
311 |
+
|
312 |
+
data['encoder_word_pos'] = encoder_word_pos
|
313 |
+
data['gsrm_word_pos'] = gsrm_word_pos
|
314 |
+
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
|
315 |
+
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
|
316 |
+
return data
|
317 |
+
|
318 |
+
|
319 |
+
class SARRecResizeImg(object):
|
320 |
+
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
|
321 |
+
self.image_shape = image_shape
|
322 |
+
self.width_downsample_ratio = width_downsample_ratio
|
323 |
+
|
324 |
+
def __call__(self, data):
|
325 |
+
img = data['image']
|
326 |
+
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
|
327 |
+
img, self.image_shape, self.width_downsample_ratio)
|
328 |
+
data['image'] = norm_img
|
329 |
+
data['resized_shape'] = resize_shape
|
330 |
+
data['pad_shape'] = pad_shape
|
331 |
+
data['valid_ratio'] = valid_ratio
|
332 |
+
return data
|
333 |
+
|
334 |
+
|
335 |
+
class PRENResizeImg(object):
|
336 |
+
def __init__(self, image_shape, **kwargs):
|
337 |
+
"""
|
338 |
+
Accroding to original paper's realization, it's a hard resize method here.
|
339 |
+
So maybe you should optimize it to fit for your task better.
|
340 |
+
"""
|
341 |
+
self.dst_h, self.dst_w = image_shape
|
342 |
+
|
343 |
+
def __call__(self, data):
|
344 |
+
img = data['image']
|
345 |
+
resized_img = cv2.resize(
|
346 |
+
img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
|
347 |
+
resized_img = resized_img.transpose((2, 0, 1)) / 255
|
348 |
+
resized_img -= 0.5
|
349 |
+
resized_img /= 0.5
|
350 |
+
data['image'] = resized_img.astype(np.float32)
|
351 |
+
return data
|
352 |
+
|
353 |
+
|
354 |
+
class SPINRecResizeImg(object):
|
355 |
+
def __init__(self,
|
356 |
+
image_shape,
|
357 |
+
interpolation=2,
|
358 |
+
mean=(127.5, 127.5, 127.5),
|
359 |
+
std=(127.5, 127.5, 127.5),
|
360 |
+
**kwargs):
|
361 |
+
self.image_shape = image_shape
|
362 |
+
|
363 |
+
self.mean = np.array(mean, dtype=np.float32)
|
364 |
+
self.std = np.array(std, dtype=np.float32)
|
365 |
+
self.interpolation = interpolation
|
366 |
+
|
367 |
+
def __call__(self, data):
|
368 |
+
img = data['image']
|
369 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
370 |
+
# different interpolation type corresponding the OpenCV
|
371 |
+
if self.interpolation == 0:
|
372 |
+
interpolation = cv2.INTER_NEAREST
|
373 |
+
elif self.interpolation == 1:
|
374 |
+
interpolation = cv2.INTER_LINEAR
|
375 |
+
elif self.interpolation == 2:
|
376 |
+
interpolation = cv2.INTER_CUBIC
|
377 |
+
elif self.interpolation == 3:
|
378 |
+
interpolation = cv2.INTER_AREA
|
379 |
+
else:
|
380 |
+
raise Exception("Unsupported interpolation type !!!")
|
381 |
+
# Deal with the image error during image loading
|
382 |
+
if img is None:
|
383 |
+
return None
|
384 |
+
|
385 |
+
img = cv2.resize(img, tuple(self.image_shape), interpolation)
|
386 |
+
img = np.array(img, np.float32)
|
387 |
+
img = np.expand_dims(img, -1)
|
388 |
+
img = img.transpose((2, 0, 1))
|
389 |
+
# normalize the image
|
390 |
+
img = img.copy().astype(np.float32)
|
391 |
+
mean = np.float64(self.mean.reshape(1, -1))
|
392 |
+
stdinv = 1 / np.float64(self.std.reshape(1, -1))
|
393 |
+
img -= mean
|
394 |
+
img *= stdinv
|
395 |
+
data['image'] = img
|
396 |
+
return data
|
397 |
+
|
398 |
+
|
399 |
+
class GrayRecResizeImg(object):
|
400 |
+
def __init__(self,
|
401 |
+
image_shape,
|
402 |
+
resize_type,
|
403 |
+
inter_type='Image.ANTIALIAS',
|
404 |
+
scale=True,
|
405 |
+
padding=False,
|
406 |
+
**kwargs):
|
407 |
+
self.image_shape = image_shape
|
408 |
+
self.resize_type = resize_type
|
409 |
+
self.padding = padding
|
410 |
+
self.inter_type = eval(inter_type)
|
411 |
+
self.scale = scale
|
412 |
+
|
413 |
+
def __call__(self, data):
|
414 |
+
img = data['image']
|
415 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
416 |
+
image_shape = self.image_shape
|
417 |
+
if self.padding:
|
418 |
+
imgC, imgH, imgW = image_shape
|
419 |
+
# todo: change to 0 and modified image shape
|
420 |
+
h = img.shape[0]
|
421 |
+
w = img.shape[1]
|
422 |
+
ratio = w / float(h)
|
423 |
+
if math.ceil(imgH * ratio) > imgW:
|
424 |
+
resized_w = imgW
|
425 |
+
else:
|
426 |
+
resized_w = int(math.ceil(imgH * ratio))
|
427 |
+
resized_image = cv2.resize(img, (resized_w, imgH))
|
428 |
+
norm_img = np.expand_dims(resized_image, -1)
|
429 |
+
norm_img = norm_img.transpose((2, 0, 1))
|
430 |
+
resized_image = norm_img.astype(np.float32) / 128. - 1.
|
431 |
+
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
432 |
+
padding_im[:, :, 0:resized_w] = resized_image
|
433 |
+
data['image'] = padding_im
|
434 |
+
return data
|
435 |
+
if self.resize_type == 'PIL':
|
436 |
+
image_pil = Image.fromarray(np.uint8(img))
|
437 |
+
img = image_pil.resize(self.image_shape, self.inter_type)
|
438 |
+
img = np.array(img)
|
439 |
+
if self.resize_type == 'OpenCV':
|
440 |
+
img = cv2.resize(img, self.image_shape)
|
441 |
+
norm_img = np.expand_dims(img, -1)
|
442 |
+
norm_img = norm_img.transpose((2, 0, 1))
|
443 |
+
if self.scale:
|
444 |
+
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
445 |
+
else:
|
446 |
+
data['image'] = norm_img.astype(np.float32) / 255.
|
447 |
+
return data
|
448 |
+
|
449 |
+
|
450 |
+
class ABINetRecResizeImg(object):
|
451 |
+
def __init__(self, image_shape, **kwargs):
|
452 |
+
self.image_shape = image_shape
|
453 |
+
|
454 |
+
def __call__(self, data):
|
455 |
+
img = data['image']
|
456 |
+
norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
|
457 |
+
data['image'] = norm_img
|
458 |
+
data['valid_ratio'] = valid_ratio
|
459 |
+
return data
|
460 |
+
|
461 |
+
|
462 |
+
class SVTRRecResizeImg(object):
|
463 |
+
def __init__(self, image_shape, padding=True, **kwargs):
|
464 |
+
self.image_shape = image_shape
|
465 |
+
self.padding = padding
|
466 |
+
|
467 |
+
def __call__(self, data):
|
468 |
+
img = data['image']
|
469 |
+
|
470 |
+
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
|
471 |
+
self.padding)
|
472 |
+
data['image'] = norm_img
|
473 |
+
data['valid_ratio'] = valid_ratio
|
474 |
+
return data
|
475 |
+
|
476 |
+
|
477 |
+
class RobustScannerRecResizeImg(object):
|
478 |
+
def __init__(self,
|
479 |
+
image_shape,
|
480 |
+
max_text_length,
|
481 |
+
width_downsample_ratio=0.25,
|
482 |
+
**kwargs):
|
483 |
+
self.image_shape = image_shape
|
484 |
+
self.width_downsample_ratio = width_downsample_ratio
|
485 |
+
self.max_text_length = max_text_length
|
486 |
+
|
487 |
+
def __call__(self, data):
|
488 |
+
img = data['image']
|
489 |
+
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
|
490 |
+
img, self.image_shape, self.width_downsample_ratio)
|
491 |
+
word_positons = np.array(range(0, self.max_text_length)).astype('int64')
|
492 |
+
data['image'] = norm_img
|
493 |
+
data['resized_shape'] = resize_shape
|
494 |
+
data['pad_shape'] = pad_shape
|
495 |
+
data['valid_ratio'] = valid_ratio
|
496 |
+
data['word_positons'] = word_positons
|
497 |
+
return data
|
498 |
+
|
499 |
+
|
500 |
+
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
501 |
+
imgC, imgH, imgW_min, imgW_max = image_shape
|
502 |
+
h = img.shape[0]
|
503 |
+
w = img.shape[1]
|
504 |
+
valid_ratio = 1.0
|
505 |
+
# make sure new_width is an integral multiple of width_divisor.
|
506 |
+
width_divisor = int(1 / width_downsample_ratio)
|
507 |
+
# resize
|
508 |
+
ratio = w / float(h)
|
509 |
+
resize_w = math.ceil(imgH * ratio)
|
510 |
+
if resize_w % width_divisor != 0:
|
511 |
+
resize_w = round(resize_w / width_divisor) * width_divisor
|
512 |
+
if imgW_min is not None:
|
513 |
+
resize_w = max(imgW_min, resize_w)
|
514 |
+
if imgW_max is not None:
|
515 |
+
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
516 |
+
resize_w = min(imgW_max, resize_w)
|
517 |
+
resized_image = cv2.resize(img, (resize_w, imgH))
|
518 |
+
resized_image = resized_image.astype('float32')
|
519 |
+
# norm
|
520 |
+
if image_shape[0] == 1:
|
521 |
+
resized_image = resized_image / 255
|
522 |
+
resized_image = resized_image[np.newaxis, :]
|
523 |
+
else:
|
524 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
525 |
+
resized_image -= 0.5
|
526 |
+
resized_image /= 0.5
|
527 |
+
resize_shape = resized_image.shape
|
528 |
+
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
529 |
+
padding_im[:, :, 0:resize_w] = resized_image
|
530 |
+
pad_shape = padding_im.shape
|
531 |
+
|
532 |
+
return padding_im, resize_shape, pad_shape, valid_ratio
|
533 |
+
|
534 |
+
|
535 |
+
def resize_norm_img(img,
|
536 |
+
image_shape,
|
537 |
+
padding=True,
|
538 |
+
interpolation=cv2.INTER_LINEAR):
|
539 |
+
imgC, imgH, imgW = image_shape
|
540 |
+
h = img.shape[0]
|
541 |
+
w = img.shape[1]
|
542 |
+
if not padding:
|
543 |
+
resized_image = cv2.resize(
|
544 |
+
img, (imgW, imgH), interpolation=interpolation)
|
545 |
+
resized_w = imgW
|
546 |
+
else:
|
547 |
+
ratio = w / float(h)
|
548 |
+
if math.ceil(imgH * ratio) > imgW:
|
549 |
+
resized_w = imgW
|
550 |
+
else:
|
551 |
+
resized_w = int(math.ceil(imgH * ratio))
|
552 |
+
resized_image = cv2.resize(img, (resized_w, imgH))
|
553 |
+
resized_image = resized_image.astype('float32')
|
554 |
+
if image_shape[0] == 1:
|
555 |
+
resized_image = resized_image / 255
|
556 |
+
resized_image = resized_image[np.newaxis, :]
|
557 |
+
else:
|
558 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
559 |
+
resized_image -= 0.5
|
560 |
+
resized_image /= 0.5
|
561 |
+
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
562 |
+
padding_im[:, :, 0:resized_w] = resized_image
|
563 |
+
valid_ratio = min(1.0, float(resized_w / imgW))
|
564 |
+
return padding_im, valid_ratio
|
565 |
+
|
566 |
+
|
567 |
+
def resize_norm_img_chinese(img, image_shape):
|
568 |
+
imgC, imgH, imgW = image_shape
|
569 |
+
# todo: change to 0 and modified image shape
|
570 |
+
max_wh_ratio = imgW * 1.0 / imgH
|
571 |
+
h, w = img.shape[0], img.shape[1]
|
572 |
+
ratio = w * 1.0 / h
|
573 |
+
max_wh_ratio = max(max_wh_ratio, ratio)
|
574 |
+
imgW = int(imgH * max_wh_ratio)
|
575 |
+
if math.ceil(imgH * ratio) > imgW:
|
576 |
+
resized_w = imgW
|
577 |
+
else:
|
578 |
+
resized_w = int(math.ceil(imgH * ratio))
|
579 |
+
resized_image = cv2.resize(img, (resized_w, imgH))
|
580 |
+
resized_image = resized_image.astype('float32')
|
581 |
+
if image_shape[0] == 1:
|
582 |
+
resized_image = resized_image / 255
|
583 |
+
resized_image = resized_image[np.newaxis, :]
|
584 |
+
else:
|
585 |
+
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
586 |
+
resized_image -= 0.5
|
587 |
+
resized_image /= 0.5
|
588 |
+
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
589 |
+
padding_im[:, :, 0:resized_w] = resized_image
|
590 |
+
valid_ratio = min(1.0, float(resized_w / imgW))
|
591 |
+
return padding_im, valid_ratio
|
592 |
+
|
593 |
+
|
594 |
+
def resize_norm_img_srn(img, image_shape):
|
595 |
+
imgC, imgH, imgW = image_shape
|
596 |
+
|
597 |
+
img_black = np.zeros((imgH, imgW))
|
598 |
+
im_hei = img.shape[0]
|
599 |
+
im_wid = img.shape[1]
|
600 |
+
|
601 |
+
if im_wid <= im_hei * 1:
|
602 |
+
img_new = cv2.resize(img, (imgH * 1, imgH))
|
603 |
+
elif im_wid <= im_hei * 2:
|
604 |
+
img_new = cv2.resize(img, (imgH * 2, imgH))
|
605 |
+
elif im_wid <= im_hei * 3:
|
606 |
+
img_new = cv2.resize(img, (imgH * 3, imgH))
|
607 |
+
else:
|
608 |
+
img_new = cv2.resize(img, (imgW, imgH))
|
609 |
+
|
610 |
+
img_np = np.asarray(img_new)
|
611 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
612 |
+
img_black[:, 0:img_np.shape[1]] = img_np
|
613 |
+
img_black = img_black[:, :, np.newaxis]
|
614 |
+
|
615 |
+
row, col, c = img_black.shape
|
616 |
+
c = 1
|
617 |
+
|
618 |
+
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
619 |
+
|
620 |
+
|
621 |
+
def resize_norm_img_abinet(img, image_shape):
|
622 |
+
imgC, imgH, imgW = image_shape
|
623 |
+
|
624 |
+
resized_image = cv2.resize(
|
625 |
+
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
626 |
+
resized_w = imgW
|
627 |
+
resized_image = resized_image.astype('float32')
|
628 |
+
resized_image = resized_image / 255.
|
629 |
+
|
630 |
+
mean = np.array([0.485, 0.456, 0.406])
|
631 |
+
std = np.array([0.229, 0.224, 0.225])
|
632 |
+
resized_image = (
|
633 |
+
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
634 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
635 |
+
resized_image = resized_image.astype('float32')
|
636 |
+
|
637 |
+
valid_ratio = min(1.0, float(resized_w / imgW))
|
638 |
+
return resized_image, valid_ratio
|
639 |
+
|
640 |
+
|
641 |
+
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
642 |
+
|
643 |
+
imgC, imgH, imgW = image_shape
|
644 |
+
feature_dim = int((imgH / 8) * (imgW / 8))
|
645 |
+
|
646 |
+
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
647 |
+
(feature_dim, 1)).astype('int64')
|
648 |
+
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
649 |
+
(max_text_length, 1)).astype('int64')
|
650 |
+
|
651 |
+
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
652 |
+
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
653 |
+
[1, max_text_length, max_text_length])
|
654 |
+
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
655 |
+
[num_heads, 1, 1]) * [-1e9]
|
656 |
+
|
657 |
+
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
658 |
+
[1, max_text_length, max_text_length])
|
659 |
+
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
660 |
+
[num_heads, 1, 1]) * [-1e9]
|
661 |
+
|
662 |
+
return [
|
663 |
+
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
664 |
+
gsrm_slf_attn_bias2
|
665 |
+
]
|
666 |
+
|
667 |
+
|
668 |
+
def flag():
|
669 |
+
"""
|
670 |
+
flag
|
671 |
+
"""
|
672 |
+
return 1 if random.random() > 0.5000001 else -1
|
673 |
+
|
674 |
+
|
675 |
+
def hsv_aug(img):
|
676 |
+
"""
|
677 |
+
cvtColor
|
678 |
+
"""
|
679 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
680 |
+
delta = 0.001 * random.random() * flag()
|
681 |
+
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
|
682 |
+
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
683 |
+
return new_img
|
684 |
+
|
685 |
+
|
686 |
+
def blur(img):
|
687 |
+
"""
|
688 |
+
blur
|
689 |
+
"""
|
690 |
+
h, w, _ = img.shape
|
691 |
+
if h > 10 and w > 10:
|
692 |
+
return cv2.GaussianBlur(img, (5, 5), 1)
|
693 |
+
else:
|
694 |
+
return img
|
695 |
+
|
696 |
+
|
697 |
+
def jitter(img):
|
698 |
+
"""
|
699 |
+
jitter
|
700 |
+
"""
|
701 |
+
w, h, _ = img.shape
|
702 |
+
if h > 10 and w > 10:
|
703 |
+
thres = min(w, h)
|
704 |
+
s = int(random.random() * thres * 0.01)
|
705 |
+
src_img = img.copy()
|
706 |
+
for i in range(s):
|
707 |
+
img[i:, i:, :] = src_img[:w - i, :h - i, :]
|
708 |
+
return img
|
709 |
+
else:
|
710 |
+
return img
|
711 |
+
|
712 |
+
|
713 |
+
def add_gasuss_noise(image, mean=0, var=0.1):
|
714 |
+
"""
|
715 |
+
Gasuss noise
|
716 |
+
"""
|
717 |
+
|
718 |
+
noise = np.random.normal(mean, var**0.5, image.shape)
|
719 |
+
out = image + 0.5 * noise
|
720 |
+
out = np.clip(out, 0, 255)
|
721 |
+
out = np.uint8(out)
|
722 |
+
return out
|
723 |
+
|
724 |
+
|
725 |
+
def get_crop(image):
|
726 |
+
"""
|
727 |
+
random crop
|
728 |
+
"""
|
729 |
+
h, w, _ = image.shape
|
730 |
+
top_min = 1
|
731 |
+
top_max = 8
|
732 |
+
top_crop = int(random.randint(top_min, top_max))
|
733 |
+
top_crop = min(top_crop, h - 1)
|
734 |
+
crop_img = image.copy()
|
735 |
+
ratio = random.randint(0, 1)
|
736 |
+
if ratio:
|
737 |
+
crop_img = crop_img[top_crop:h, :, :]
|
738 |
+
else:
|
739 |
+
crop_img = crop_img[0:h - top_crop, :, :]
|
740 |
+
return crop_img
|
741 |
+
|
742 |
+
|
743 |
+
def rad(x):
|
744 |
+
"""
|
745 |
+
rad
|
746 |
+
"""
|
747 |
+
return x * np.pi / 180
|
748 |
+
|
749 |
+
|
750 |
+
def get_warpR(config):
|
751 |
+
"""
|
752 |
+
get_warpR
|
753 |
+
"""
|
754 |
+
anglex, angley, anglez, fov, w, h, r = \
|
755 |
+
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
|
756 |
+
if w > 69 and w < 112:
|
757 |
+
anglex = anglex * 1.5
|
758 |
+
|
759 |
+
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
|
760 |
+
# Homogeneous coordinate transformation matrix
|
761 |
+
rx = np.array([[1, 0, 0, 0],
|
762 |
+
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
|
763 |
+
0,
|
764 |
+
-np.sin(rad(anglex)),
|
765 |
+
np.cos(rad(anglex)),
|
766 |
+
0,
|
767 |
+
], [0, 0, 0, 1]], np.float32)
|
768 |
+
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
|
769 |
+
[0, 1, 0, 0], [
|
770 |
+
-np.sin(rad(angley)),
|
771 |
+
0,
|
772 |
+
np.cos(rad(angley)),
|
773 |
+
0,
|
774 |
+
], [0, 0, 0, 1]], np.float32)
|
775 |
+
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
|
776 |
+
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
|
777 |
+
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
|
778 |
+
r = rx.dot(ry).dot(rz)
|
779 |
+
# generate 4 points
|
780 |
+
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
|
781 |
+
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
|
782 |
+
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
|
783 |
+
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
|
784 |
+
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
|
785 |
+
dst1 = r.dot(p1)
|
786 |
+
dst2 = r.dot(p2)
|
787 |
+
dst3 = r.dot(p3)
|
788 |
+
dst4 = r.dot(p4)
|
789 |
+
list_dst = np.array([dst1, dst2, dst3, dst4])
|
790 |
+
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
|
791 |
+
dst = np.zeros((4, 2), np.float32)
|
792 |
+
# Project onto the image plane
|
793 |
+
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
|
794 |
+
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
|
795 |
+
|
796 |
+
warpR = cv2.getPerspectiveTransform(org, dst)
|
797 |
+
|
798 |
+
dst1, dst2, dst3, dst4 = dst
|
799 |
+
r1 = int(min(dst1[1], dst2[1]))
|
800 |
+
r2 = int(max(dst3[1], dst4[1]))
|
801 |
+
c1 = int(min(dst1[0], dst3[0]))
|
802 |
+
c2 = int(max(dst2[0], dst4[0]))
|
803 |
+
|
804 |
+
try:
|
805 |
+
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
|
806 |
+
|
807 |
+
dx = -c1
|
808 |
+
dy = -r1
|
809 |
+
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
|
810 |
+
ret = T1.dot(warpR)
|
811 |
+
except:
|
812 |
+
ratio = 1.0
|
813 |
+
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
|
814 |
+
ret = T1
|
815 |
+
return ret, (-r1, -c1), ratio, dst
|
816 |
+
|
817 |
+
|
818 |
+
def get_warpAffine(config):
|
819 |
+
"""
|
820 |
+
get_warpAffine
|
821 |
+
"""
|
822 |
+
anglez = config.anglez
|
823 |
+
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
|
824 |
+
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
|
825 |
+
return rz
|
ppocr/data/imaug/sast_process.py
ADDED
@@ -0,0 +1,777 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
#Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
#you may not use this file except in compliance with the License.
|
5 |
+
#You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
#Unless required by applicable law or agreed to in writing, software
|
10 |
+
#distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
#See the License for the specific language governing permissions and
|
13 |
+
#limitations under the License.
|
14 |
+
"""
|
15 |
+
This part code is refered from:
|
16 |
+
https://github.com/songdejia/EAST/blob/master/data_utils.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import json
|
22 |
+
import sys
|
23 |
+
import os
|
24 |
+
|
25 |
+
__all__ = ['SASTProcessTrain']
|
26 |
+
|
27 |
+
|
28 |
+
class SASTProcessTrain(object):
|
29 |
+
def __init__(self,
|
30 |
+
image_shape=[512, 512],
|
31 |
+
min_crop_size=24,
|
32 |
+
min_crop_side_ratio=0.3,
|
33 |
+
min_text_size=10,
|
34 |
+
max_text_size=512,
|
35 |
+
**kwargs):
|
36 |
+
self.input_size = image_shape[1]
|
37 |
+
self.min_crop_size = min_crop_size
|
38 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
39 |
+
self.min_text_size = min_text_size
|
40 |
+
self.max_text_size = max_text_size
|
41 |
+
|
42 |
+
def quad_area(self, poly):
|
43 |
+
"""
|
44 |
+
compute area of a polygon
|
45 |
+
:param poly:
|
46 |
+
:return:
|
47 |
+
"""
|
48 |
+
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
49 |
+
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
50 |
+
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
51 |
+
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
52 |
+
return np.sum(edge) / 2.
|
53 |
+
|
54 |
+
def gen_quad_from_poly(self, poly):
|
55 |
+
"""
|
56 |
+
Generate min area quad from poly.
|
57 |
+
"""
|
58 |
+
point_num = poly.shape[0]
|
59 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
60 |
+
if True:
|
61 |
+
rect = cv2.minAreaRect(poly.astype(
|
62 |
+
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
63 |
+
center_point = rect[0]
|
64 |
+
box = np.array(cv2.boxPoints(rect))
|
65 |
+
|
66 |
+
first_point_idx = 0
|
67 |
+
min_dist = 1e4
|
68 |
+
for i in range(4):
|
69 |
+
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
70 |
+
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
71 |
+
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
72 |
+
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
73 |
+
if dist < min_dist:
|
74 |
+
min_dist = dist
|
75 |
+
first_point_idx = i
|
76 |
+
for i in range(4):
|
77 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
78 |
+
|
79 |
+
return min_area_quad
|
80 |
+
|
81 |
+
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
82 |
+
"""
|
83 |
+
check so that the text poly is in the same direction,
|
84 |
+
and also filter some invalid polygons
|
85 |
+
:param polys:
|
86 |
+
:param tags:
|
87 |
+
:return:
|
88 |
+
"""
|
89 |
+
(h, w) = xxx_todo_changeme
|
90 |
+
if polys.shape[0] == 0:
|
91 |
+
return polys, np.array([]), np.array([])
|
92 |
+
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
93 |
+
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
94 |
+
|
95 |
+
validated_polys = []
|
96 |
+
validated_tags = []
|
97 |
+
hv_tags = []
|
98 |
+
for poly, tag in zip(polys, tags):
|
99 |
+
quad = self.gen_quad_from_poly(poly)
|
100 |
+
p_area = self.quad_area(quad)
|
101 |
+
if abs(p_area) < 1:
|
102 |
+
print('invalid poly')
|
103 |
+
continue
|
104 |
+
if p_area > 0:
|
105 |
+
if tag == False:
|
106 |
+
print('poly in wrong direction')
|
107 |
+
tag = True # reversed cases should be ignore
|
108 |
+
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
109 |
+
1), :]
|
110 |
+
quad = quad[(0, 3, 2, 1), :]
|
111 |
+
|
112 |
+
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
113 |
+
quad[2])
|
114 |
+
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
115 |
+
quad[2])
|
116 |
+
hv_tag = 1
|
117 |
+
|
118 |
+
if len_w * 2.0 < len_h:
|
119 |
+
hv_tag = 0
|
120 |
+
|
121 |
+
validated_polys.append(poly)
|
122 |
+
validated_tags.append(tag)
|
123 |
+
hv_tags.append(hv_tag)
|
124 |
+
return np.array(validated_polys), np.array(validated_tags), np.array(
|
125 |
+
hv_tags)
|
126 |
+
|
127 |
+
def crop_area(self,
|
128 |
+
im,
|
129 |
+
polys,
|
130 |
+
tags,
|
131 |
+
hv_tags,
|
132 |
+
crop_background=False,
|
133 |
+
max_tries=25):
|
134 |
+
"""
|
135 |
+
make random crop from the input image
|
136 |
+
:param im:
|
137 |
+
:param polys:
|
138 |
+
:param tags:
|
139 |
+
:param crop_background:
|
140 |
+
:param max_tries: 50 -> 25
|
141 |
+
:return:
|
142 |
+
"""
|
143 |
+
h, w, _ = im.shape
|
144 |
+
pad_h = h // 10
|
145 |
+
pad_w = w // 10
|
146 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
147 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
148 |
+
for poly in polys:
|
149 |
+
poly = np.round(poly, decimals=0).astype(np.int32)
|
150 |
+
minx = np.min(poly[:, 0])
|
151 |
+
maxx = np.max(poly[:, 0])
|
152 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
153 |
+
miny = np.min(poly[:, 1])
|
154 |
+
maxy = np.max(poly[:, 1])
|
155 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
156 |
+
# ensure the cropped area not across a text
|
157 |
+
h_axis = np.where(h_array == 0)[0]
|
158 |
+
w_axis = np.where(w_array == 0)[0]
|
159 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
160 |
+
return im, polys, tags, hv_tags
|
161 |
+
for i in range(max_tries):
|
162 |
+
xx = np.random.choice(w_axis, size=2)
|
163 |
+
xmin = np.min(xx) - pad_w
|
164 |
+
xmax = np.max(xx) - pad_w
|
165 |
+
xmin = np.clip(xmin, 0, w - 1)
|
166 |
+
xmax = np.clip(xmax, 0, w - 1)
|
167 |
+
yy = np.random.choice(h_axis, size=2)
|
168 |
+
ymin = np.min(yy) - pad_h
|
169 |
+
ymax = np.max(yy) - pad_h
|
170 |
+
ymin = np.clip(ymin, 0, h - 1)
|
171 |
+
ymax = np.clip(ymax, 0, h - 1)
|
172 |
+
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
|
173 |
+
# ymax - ymin < ARGS.min_crop_side_ratio * h:
|
174 |
+
if xmax - xmin < self.min_crop_size or \
|
175 |
+
ymax - ymin < self.min_crop_size:
|
176 |
+
# area too small
|
177 |
+
continue
|
178 |
+
if polys.shape[0] != 0:
|
179 |
+
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
180 |
+
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
181 |
+
selected_polys = np.where(
|
182 |
+
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
183 |
+
else:
|
184 |
+
selected_polys = []
|
185 |
+
if len(selected_polys) == 0:
|
186 |
+
# no text in this area
|
187 |
+
if crop_background:
|
188 |
+
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
|
189 |
+
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
|
190 |
+
else:
|
191 |
+
continue
|
192 |
+
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
193 |
+
polys = polys[selected_polys]
|
194 |
+
tags = tags[selected_polys]
|
195 |
+
hv_tags = hv_tags[selected_polys]
|
196 |
+
polys[:, :, 0] -= xmin
|
197 |
+
polys[:, :, 1] -= ymin
|
198 |
+
return im, polys, tags, hv_tags
|
199 |
+
|
200 |
+
return im, polys, tags, hv_tags
|
201 |
+
|
202 |
+
def generate_direction_map(self, poly_quads, direction_map):
|
203 |
+
"""
|
204 |
+
"""
|
205 |
+
width_list = []
|
206 |
+
height_list = []
|
207 |
+
for quad in poly_quads:
|
208 |
+
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
209 |
+
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
210 |
+
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
211 |
+
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
212 |
+
width_list.append(quad_w)
|
213 |
+
height_list.append(quad_h)
|
214 |
+
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
215 |
+
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
|
216 |
+
|
217 |
+
for quad in poly_quads:
|
218 |
+
direct_vector_full = (
|
219 |
+
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
220 |
+
direct_vector = direct_vector_full / (
|
221 |
+
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
222 |
+
direction_label = tuple(
|
223 |
+
map(float, [
|
224 |
+
direct_vector[0], direct_vector[1], 1.0 / (average_height +
|
225 |
+
1e-6)
|
226 |
+
]))
|
227 |
+
cv2.fillPoly(direction_map,
|
228 |
+
quad.round().astype(np.int32)[np.newaxis, :, :],
|
229 |
+
direction_label)
|
230 |
+
return direction_map
|
231 |
+
|
232 |
+
def calculate_average_height(self, poly_quads):
|
233 |
+
"""
|
234 |
+
"""
|
235 |
+
height_list = []
|
236 |
+
for quad in poly_quads:
|
237 |
+
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
238 |
+
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
239 |
+
height_list.append(quad_h)
|
240 |
+
average_height = max(sum(height_list) / len(height_list), 1.0)
|
241 |
+
return average_height
|
242 |
+
|
243 |
+
def generate_tcl_label(self,
|
244 |
+
hw,
|
245 |
+
polys,
|
246 |
+
tags,
|
247 |
+
ds_ratio,
|
248 |
+
tcl_ratio=0.3,
|
249 |
+
shrink_ratio_of_width=0.15):
|
250 |
+
"""
|
251 |
+
Generate polygon.
|
252 |
+
"""
|
253 |
+
h, w = hw
|
254 |
+
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
255 |
+
polys = polys * ds_ratio
|
256 |
+
|
257 |
+
score_map = np.zeros(
|
258 |
+
(
|
259 |
+
h,
|
260 |
+
w, ), dtype=np.float32)
|
261 |
+
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
262 |
+
training_mask = np.ones(
|
263 |
+
(
|
264 |
+
h,
|
265 |
+
w, ), dtype=np.float32)
|
266 |
+
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
267 |
+
[1, 1, 3]).astype(np.float32)
|
268 |
+
|
269 |
+
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
270 |
+
poly = poly_tag[0]
|
271 |
+
tag = poly_tag[1]
|
272 |
+
|
273 |
+
# generate min_area_quad
|
274 |
+
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
275 |
+
min_area_quad_h = 0.5 * (
|
276 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
277 |
+
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
278 |
+
min_area_quad_w = 0.5 * (
|
279 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
280 |
+
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
281 |
+
|
282 |
+
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
283 |
+
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
284 |
+
continue
|
285 |
+
|
286 |
+
if tag:
|
287 |
+
# continue
|
288 |
+
cv2.fillPoly(training_mask,
|
289 |
+
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
290 |
+
else:
|
291 |
+
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
292 |
+
tcl_quads = self.poly2quads(tcl_poly)
|
293 |
+
poly_quads = self.poly2quads(poly)
|
294 |
+
# stcl map
|
295 |
+
stcl_quads, quad_index = self.shrink_poly_along_width(
|
296 |
+
tcl_quads,
|
297 |
+
shrink_ratio_of_width=shrink_ratio_of_width,
|
298 |
+
expand_height_ratio=1.0 / tcl_ratio)
|
299 |
+
# generate tcl map
|
300 |
+
cv2.fillPoly(score_map,
|
301 |
+
np.round(stcl_quads).astype(np.int32), 1.0)
|
302 |
+
|
303 |
+
# generate tbo map
|
304 |
+
for idx, quad in enumerate(stcl_quads):
|
305 |
+
quad_mask = np.zeros((h, w), dtype=np.float32)
|
306 |
+
quad_mask = cv2.fillPoly(
|
307 |
+
quad_mask,
|
308 |
+
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
309 |
+
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
310 |
+
quad_mask, tbo_map)
|
311 |
+
return score_map, tbo_map, training_mask
|
312 |
+
|
313 |
+
def generate_tvo_and_tco(self,
|
314 |
+
hw,
|
315 |
+
polys,
|
316 |
+
tags,
|
317 |
+
tcl_ratio=0.3,
|
318 |
+
ds_ratio=0.25):
|
319 |
+
"""
|
320 |
+
Generate tcl map, tvo map and tbo map.
|
321 |
+
"""
|
322 |
+
h, w = hw
|
323 |
+
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
324 |
+
polys = polys * ds_ratio
|
325 |
+
poly_mask = np.zeros((h, w), dtype=np.float32)
|
326 |
+
|
327 |
+
tvo_map = np.ones((9, h, w), dtype=np.float32)
|
328 |
+
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
|
329 |
+
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
|
330 |
+
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
|
331 |
+
|
332 |
+
# tco map
|
333 |
+
tco_map = np.ones((3, h, w), dtype=np.float32)
|
334 |
+
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
|
335 |
+
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
|
336 |
+
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
|
337 |
+
|
338 |
+
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
|
339 |
+
|
340 |
+
for poly, poly_tag in zip(polys, tags):
|
341 |
+
|
342 |
+
if poly_tag == True:
|
343 |
+
continue
|
344 |
+
|
345 |
+
# adjust point order for vertical poly
|
346 |
+
poly = self.adjust_point(poly)
|
347 |
+
|
348 |
+
# generate min_area_quad
|
349 |
+
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
350 |
+
min_area_quad_h = 0.5 * (
|
351 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
352 |
+
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
353 |
+
min_area_quad_w = 0.5 * (
|
354 |
+
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
355 |
+
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
356 |
+
|
357 |
+
# generate tcl map and text, 128 * 128
|
358 |
+
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
359 |
+
|
360 |
+
# generate poly_tv_xy_map
|
361 |
+
for idx in range(4):
|
362 |
+
cv2.fillPoly(
|
363 |
+
poly_tv_xy_map[2 * idx],
|
364 |
+
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
365 |
+
float(min(max(min_area_quad[idx, 0], 0), w)))
|
366 |
+
cv2.fillPoly(
|
367 |
+
poly_tv_xy_map[2 * idx + 1],
|
368 |
+
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
369 |
+
float(min(max(min_area_quad[idx, 1], 0), h)))
|
370 |
+
|
371 |
+
# generate poly_tc_xy_map
|
372 |
+
for idx in range(2):
|
373 |
+
cv2.fillPoly(
|
374 |
+
poly_tc_xy_map[idx],
|
375 |
+
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
376 |
+
float(center_point[idx]))
|
377 |
+
|
378 |
+
# generate poly_short_edge_map
|
379 |
+
cv2.fillPoly(
|
380 |
+
poly_short_edge_map,
|
381 |
+
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
382 |
+
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
383 |
+
|
384 |
+
# generate poly_mask and training_mask
|
385 |
+
cv2.fillPoly(poly_mask,
|
386 |
+
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
387 |
+
1)
|
388 |
+
|
389 |
+
tvo_map *= poly_mask
|
390 |
+
tvo_map[:8] -= poly_tv_xy_map
|
391 |
+
tvo_map[-1] /= poly_short_edge_map
|
392 |
+
tvo_map = tvo_map.transpose((1, 2, 0))
|
393 |
+
|
394 |
+
tco_map *= poly_mask
|
395 |
+
tco_map[:2] -= poly_tc_xy_map
|
396 |
+
tco_map[-1] /= poly_short_edge_map
|
397 |
+
tco_map = tco_map.transpose((1, 2, 0))
|
398 |
+
|
399 |
+
return tvo_map, tco_map
|
400 |
+
|
401 |
+
def adjust_point(self, poly):
|
402 |
+
"""
|
403 |
+
adjust point order.
|
404 |
+
"""
|
405 |
+
point_num = poly.shape[0]
|
406 |
+
if point_num == 4:
|
407 |
+
len_1 = np.linalg.norm(poly[0] - poly[1])
|
408 |
+
len_2 = np.linalg.norm(poly[1] - poly[2])
|
409 |
+
len_3 = np.linalg.norm(poly[2] - poly[3])
|
410 |
+
len_4 = np.linalg.norm(poly[3] - poly[0])
|
411 |
+
|
412 |
+
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
413 |
+
poly = poly[[1, 2, 3, 0], :]
|
414 |
+
|
415 |
+
elif point_num > 4:
|
416 |
+
vector_1 = poly[0] - poly[1]
|
417 |
+
vector_2 = poly[1] - poly[2]
|
418 |
+
cos_theta = np.dot(vector_1, vector_2) / (
|
419 |
+
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
420 |
+
theta = np.arccos(np.round(cos_theta, decimals=4))
|
421 |
+
|
422 |
+
if abs(theta) > (70 / 180 * math.pi):
|
423 |
+
index = list(range(1, point_num)) + [0]
|
424 |
+
poly = poly[np.array(index), :]
|
425 |
+
return poly
|
426 |
+
|
427 |
+
def gen_min_area_quad_from_poly(self, poly):
|
428 |
+
"""
|
429 |
+
Generate min area quad from poly.
|
430 |
+
"""
|
431 |
+
point_num = poly.shape[0]
|
432 |
+
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
433 |
+
if point_num == 4:
|
434 |
+
min_area_quad = poly
|
435 |
+
center_point = np.sum(poly, axis=0) / 4
|
436 |
+
else:
|
437 |
+
rect = cv2.minAreaRect(poly.astype(
|
438 |
+
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
439 |
+
center_point = rect[0]
|
440 |
+
box = np.array(cv2.boxPoints(rect))
|
441 |
+
|
442 |
+
first_point_idx = 0
|
443 |
+
min_dist = 1e4
|
444 |
+
for i in range(4):
|
445 |
+
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
446 |
+
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
447 |
+
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
448 |
+
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
449 |
+
if dist < min_dist:
|
450 |
+
min_dist = dist
|
451 |
+
first_point_idx = i
|
452 |
+
|
453 |
+
for i in range(4):
|
454 |
+
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
455 |
+
|
456 |
+
return min_area_quad, center_point
|
457 |
+
|
458 |
+
def shrink_quad_along_width(self,
|
459 |
+
quad,
|
460 |
+
begin_width_ratio=0.,
|
461 |
+
end_width_ratio=1.):
|
462 |
+
"""
|
463 |
+
Generate shrink_quad_along_width.
|
464 |
+
"""
|
465 |
+
ratio_pair = np.array(
|
466 |
+
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
467 |
+
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
468 |
+
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
469 |
+
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
470 |
+
|
471 |
+
def shrink_poly_along_width(self,
|
472 |
+
quads,
|
473 |
+
shrink_ratio_of_width,
|
474 |
+
expand_height_ratio=1.0):
|
475 |
+
"""
|
476 |
+
shrink poly with given length.
|
477 |
+
"""
|
478 |
+
upper_edge_list = []
|
479 |
+
|
480 |
+
def get_cut_info(edge_len_list, cut_len):
|
481 |
+
for idx, edge_len in enumerate(edge_len_list):
|
482 |
+
cut_len -= edge_len
|
483 |
+
if cut_len <= 0.000001:
|
484 |
+
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
485 |
+
return idx, ratio
|
486 |
+
|
487 |
+
for quad in quads:
|
488 |
+
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
489 |
+
upper_edge_list.append(upper_edge_len)
|
490 |
+
|
491 |
+
# length of left edge and right edge.
|
492 |
+
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
493 |
+
3]) * expand_height_ratio
|
494 |
+
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
495 |
+
2]) * expand_height_ratio
|
496 |
+
|
497 |
+
shrink_length = min(left_length, right_length,
|
498 |
+
sum(upper_edge_list)) * shrink_ratio_of_width
|
499 |
+
# shrinking length
|
500 |
+
upper_len_left = shrink_length
|
501 |
+
upper_len_right = sum(upper_edge_list) - shrink_length
|
502 |
+
|
503 |
+
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
504 |
+
left_quad = self.shrink_quad_along_width(
|
505 |
+
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
506 |
+
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
507 |
+
right_quad = self.shrink_quad_along_width(
|
508 |
+
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
509 |
+
|
510 |
+
out_quad_list = []
|
511 |
+
if left_idx == right_idx:
|
512 |
+
out_quad_list.append(
|
513 |
+
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
514 |
+
else:
|
515 |
+
out_quad_list.append(left_quad)
|
516 |
+
for idx in range(left_idx + 1, right_idx):
|
517 |
+
out_quad_list.append(quads[idx])
|
518 |
+
out_quad_list.append(right_quad)
|
519 |
+
|
520 |
+
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
521 |
+
|
522 |
+
def vector_angle(self, A, B):
|
523 |
+
"""
|
524 |
+
Calculate the angle between vector AB and x-axis positive direction.
|
525 |
+
"""
|
526 |
+
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
527 |
+
return np.arctan2(*AB)
|
528 |
+
|
529 |
+
def theta_line_cross_point(self, theta, point):
|
530 |
+
"""
|
531 |
+
Calculate the line through given point and angle in ax + by + c =0 form.
|
532 |
+
"""
|
533 |
+
x, y = point
|
534 |
+
cos = np.cos(theta)
|
535 |
+
sin = np.sin(theta)
|
536 |
+
return [sin, -cos, cos * y - sin * x]
|
537 |
+
|
538 |
+
def line_cross_two_point(self, A, B):
|
539 |
+
"""
|
540 |
+
Calculate the line through given point A and B in ax + by + c =0 form.
|
541 |
+
"""
|
542 |
+
angle = self.vector_angle(A, B)
|
543 |
+
return self.theta_line_cross_point(angle, A)
|
544 |
+
|
545 |
+
def average_angle(self, poly):
|
546 |
+
"""
|
547 |
+
Calculate the average angle between left and right edge in given poly.
|
548 |
+
"""
|
549 |
+
p0, p1, p2, p3 = poly
|
550 |
+
angle30 = self.vector_angle(p3, p0)
|
551 |
+
angle21 = self.vector_angle(p2, p1)
|
552 |
+
return (angle30 + angle21) / 2
|
553 |
+
|
554 |
+
def line_cross_point(self, line1, line2):
|
555 |
+
"""
|
556 |
+
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
557 |
+
"""
|
558 |
+
a1, b1, c1 = line1
|
559 |
+
a2, b2, c2 = line2
|
560 |
+
d = a1 * b2 - a2 * b1
|
561 |
+
|
562 |
+
if d == 0:
|
563 |
+
#print("line1", line1)
|
564 |
+
#print("line2", line2)
|
565 |
+
print('Cross point does not exist')
|
566 |
+
return np.array([0, 0], dtype=np.float32)
|
567 |
+
else:
|
568 |
+
x = (b1 * c2 - b2 * c1) / d
|
569 |
+
y = (a2 * c1 - a1 * c2) / d
|
570 |
+
|
571 |
+
return np.array([x, y], dtype=np.float32)
|
572 |
+
|
573 |
+
def quad2tcl(self, poly, ratio):
|
574 |
+
"""
|
575 |
+
Generate center line by poly clock-wise point. (4, 2)
|
576 |
+
"""
|
577 |
+
ratio_pair = np.array(
|
578 |
+
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
579 |
+
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
580 |
+
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
581 |
+
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
582 |
+
|
583 |
+
def poly2tcl(self, poly, ratio):
|
584 |
+
"""
|
585 |
+
Generate center line by poly clock-wise point.
|
586 |
+
"""
|
587 |
+
ratio_pair = np.array(
|
588 |
+
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
589 |
+
tcl_poly = np.zeros_like(poly)
|
590 |
+
point_num = poly.shape[0]
|
591 |
+
|
592 |
+
for idx in range(point_num // 2):
|
593 |
+
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
594 |
+
) * ratio_pair
|
595 |
+
tcl_poly[idx] = point_pair[0]
|
596 |
+
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
597 |
+
return tcl_poly
|
598 |
+
|
599 |
+
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
600 |
+
"""
|
601 |
+
Generate tbo_map for give quad.
|
602 |
+
"""
|
603 |
+
# upper and lower line function: ax + by + c = 0;
|
604 |
+
up_line = self.line_cross_two_point(quad[0], quad[1])
|
605 |
+
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
606 |
+
|
607 |
+
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
608 |
+
np.linalg.norm(quad[1] - quad[2]))
|
609 |
+
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
610 |
+
np.linalg.norm(quad[2] - quad[3]))
|
611 |
+
|
612 |
+
# average angle of left and right line.
|
613 |
+
angle = self.average_angle(quad)
|
614 |
+
|
615 |
+
xy_in_poly = np.argwhere(tcl_mask == 1)
|
616 |
+
for y, x in xy_in_poly:
|
617 |
+
point = (x, y)
|
618 |
+
line = self.theta_line_cross_point(angle, point)
|
619 |
+
cross_point_upper = self.line_cross_point(up_line, line)
|
620 |
+
cross_point_lower = self.line_cross_point(lower_line, line)
|
621 |
+
##FIX, offset reverse
|
622 |
+
upper_offset_x, upper_offset_y = cross_point_upper - point
|
623 |
+
lower_offset_x, lower_offset_y = cross_point_lower - point
|
624 |
+
tbo_map[y, x, 0] = upper_offset_y
|
625 |
+
tbo_map[y, x, 1] = upper_offset_x
|
626 |
+
tbo_map[y, x, 2] = lower_offset_y
|
627 |
+
tbo_map[y, x, 3] = lower_offset_x
|
628 |
+
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
629 |
+
return tbo_map
|
630 |
+
|
631 |
+
def poly2quads(self, poly):
|
632 |
+
"""
|
633 |
+
Split poly into quads.
|
634 |
+
"""
|
635 |
+
quad_list = []
|
636 |
+
point_num = poly.shape[0]
|
637 |
+
|
638 |
+
# point pair
|
639 |
+
point_pair_list = []
|
640 |
+
for idx in range(point_num // 2):
|
641 |
+
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
642 |
+
point_pair_list.append(point_pair)
|
643 |
+
|
644 |
+
quad_num = point_num // 2 - 1
|
645 |
+
for idx in range(quad_num):
|
646 |
+
# reshape and adjust to clock-wise
|
647 |
+
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
648 |
+
).reshape(4, 2)[[0, 2, 3, 1]])
|
649 |
+
|
650 |
+
return np.array(quad_list)
|
651 |
+
|
652 |
+
def __call__(self, data):
|
653 |
+
im = data['image']
|
654 |
+
text_polys = data['polys']
|
655 |
+
text_tags = data['ignore_tags']
|
656 |
+
if im is None:
|
657 |
+
return None
|
658 |
+
if text_polys.shape[0] == 0:
|
659 |
+
return None
|
660 |
+
|
661 |
+
h, w, _ = im.shape
|
662 |
+
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
663 |
+
text_polys, text_tags, (h, w))
|
664 |
+
|
665 |
+
if text_polys.shape[0] == 0:
|
666 |
+
return None
|
667 |
+
|
668 |
+
#set aspect ratio and keep area fix
|
669 |
+
asp_scales = np.arange(1.0, 1.55, 0.1)
|
670 |
+
asp_scale = np.random.choice(asp_scales)
|
671 |
+
|
672 |
+
if np.random.rand() < 0.5:
|
673 |
+
asp_scale = 1.0 / asp_scale
|
674 |
+
asp_scale = math.sqrt(asp_scale)
|
675 |
+
|
676 |
+
asp_wx = asp_scale
|
677 |
+
asp_hy = 1.0 / asp_scale
|
678 |
+
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
679 |
+
text_polys[:, :, 0] *= asp_wx
|
680 |
+
text_polys[:, :, 1] *= asp_hy
|
681 |
+
|
682 |
+
h, w, _ = im.shape
|
683 |
+
if max(h, w) > 2048:
|
684 |
+
rd_scale = 2048.0 / max(h, w)
|
685 |
+
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
686 |
+
text_polys *= rd_scale
|
687 |
+
h, w, _ = im.shape
|
688 |
+
if min(h, w) < 16:
|
689 |
+
return None
|
690 |
+
|
691 |
+
#no background
|
692 |
+
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
|
693 |
+
text_polys, text_tags, hv_tags, crop_background=False)
|
694 |
+
|
695 |
+
if text_polys.shape[0] == 0:
|
696 |
+
return None
|
697 |
+
#continue for all ignore case
|
698 |
+
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
699 |
+
return None
|
700 |
+
new_h, new_w, _ = im.shape
|
701 |
+
if (new_h is None) or (new_w is None):
|
702 |
+
return None
|
703 |
+
#resize image
|
704 |
+
std_ratio = float(self.input_size) / max(new_w, new_h)
|
705 |
+
rand_scales = np.array(
|
706 |
+
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
707 |
+
rz_scale = std_ratio * np.random.choice(rand_scales)
|
708 |
+
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
709 |
+
text_polys[:, :, 0] *= rz_scale
|
710 |
+
text_polys[:, :, 1] *= rz_scale
|
711 |
+
|
712 |
+
#add gaussian blur
|
713 |
+
if np.random.rand() < 0.1 * 0.5:
|
714 |
+
ks = np.random.permutation(5)[0] + 1
|
715 |
+
ks = int(ks / 2) * 2 + 1
|
716 |
+
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
717 |
+
#add brighter
|
718 |
+
if np.random.rand() < 0.1 * 0.5:
|
719 |
+
im = im * (1.0 + np.random.rand() * 0.5)
|
720 |
+
im = np.clip(im, 0.0, 255.0)
|
721 |
+
#add darker
|
722 |
+
if np.random.rand() < 0.1 * 0.5:
|
723 |
+
im = im * (1.0 - np.random.rand() * 0.5)
|
724 |
+
im = np.clip(im, 0.0, 255.0)
|
725 |
+
|
726 |
+
# Padding the im to [input_size, input_size]
|
727 |
+
new_h, new_w, _ = im.shape
|
728 |
+
if min(new_w, new_h) < self.input_size * 0.5:
|
729 |
+
return None
|
730 |
+
|
731 |
+
im_padded = np.ones(
|
732 |
+
(self.input_size, self.input_size, 3), dtype=np.float32)
|
733 |
+
im_padded[:, :, 2] = 0.485 * 255
|
734 |
+
im_padded[:, :, 1] = 0.456 * 255
|
735 |
+
im_padded[:, :, 0] = 0.406 * 255
|
736 |
+
|
737 |
+
# Random the start position
|
738 |
+
del_h = self.input_size - new_h
|
739 |
+
del_w = self.input_size - new_w
|
740 |
+
sh, sw = 0, 0
|
741 |
+
if del_h > 1:
|
742 |
+
sh = int(np.random.rand() * del_h)
|
743 |
+
if del_w > 1:
|
744 |
+
sw = int(np.random.rand() * del_w)
|
745 |
+
|
746 |
+
# Padding
|
747 |
+
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
748 |
+
text_polys[:, :, 0] += sw
|
749 |
+
text_polys[:, :, 1] += sh
|
750 |
+
|
751 |
+
score_map, border_map, training_mask = self.generate_tcl_label(
|
752 |
+
(self.input_size, self.input_size), text_polys, text_tags, 0.25)
|
753 |
+
|
754 |
+
# SAST head
|
755 |
+
tvo_map, tco_map = self.generate_tvo_and_tco(
|
756 |
+
(self.input_size, self.input_size),
|
757 |
+
text_polys,
|
758 |
+
text_tags,
|
759 |
+
tcl_ratio=0.3,
|
760 |
+
ds_ratio=0.25)
|
761 |
+
# print("test--------tvo_map shape:", tvo_map.shape)
|
762 |
+
|
763 |
+
im_padded[:, :, 2] -= 0.485 * 255
|
764 |
+
im_padded[:, :, 1] -= 0.456 * 255
|
765 |
+
im_padded[:, :, 0] -= 0.406 * 255
|
766 |
+
im_padded[:, :, 2] /= (255.0 * 0.229)
|
767 |
+
im_padded[:, :, 1] /= (255.0 * 0.224)
|
768 |
+
im_padded[:, :, 0] /= (255.0 * 0.225)
|
769 |
+
im_padded = im_padded.transpose((2, 0, 1))
|
770 |
+
|
771 |
+
data['image'] = im_padded[::-1, :, :]
|
772 |
+
data['score_map'] = score_map[np.newaxis, :, :]
|
773 |
+
data['border_map'] = border_map.transpose((2, 0, 1))
|
774 |
+
data['training_mask'] = training_mask[np.newaxis, :, :]
|
775 |
+
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
|
776 |
+
data['tco_map'] = tco_map.transpose((2, 0, 1))
|
777 |
+
return data
|
ppocr/data/imaug/ssl_img_aug.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import cv2
|
17 |
+
import numpy as np
|
18 |
+
import random
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
from .rec_img_aug import resize_norm_img
|
22 |
+
|
23 |
+
|
24 |
+
class SSLRotateResize(object):
|
25 |
+
def __init__(self,
|
26 |
+
image_shape,
|
27 |
+
padding=False,
|
28 |
+
select_all=True,
|
29 |
+
mode="train",
|
30 |
+
**kwargs):
|
31 |
+
self.image_shape = image_shape
|
32 |
+
self.padding = padding
|
33 |
+
self.select_all = select_all
|
34 |
+
self.mode = mode
|
35 |
+
|
36 |
+
def __call__(self, data):
|
37 |
+
img = data["image"]
|
38 |
+
|
39 |
+
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
|
40 |
+
data["image_r180"] = cv2.rotate(data["image_r90"],
|
41 |
+
cv2.ROTATE_90_CLOCKWISE)
|
42 |
+
data["image_r270"] = cv2.rotate(data["image_r180"],
|
43 |
+
cv2.ROTATE_90_CLOCKWISE)
|
44 |
+
|
45 |
+
images = []
|
46 |
+
for key in ["image", "image_r90", "image_r180", "image_r270"]:
|
47 |
+
images.append(
|
48 |
+
resize_norm_img(
|
49 |
+
data.pop(key),
|
50 |
+
image_shape=self.image_shape,
|
51 |
+
padding=self.padding)[0])
|
52 |
+
data["image"] = np.stack(images, axis=0)
|
53 |
+
data["label"] = np.array(list(range(4)))
|
54 |
+
if not self.select_all:
|
55 |
+
data["image"] = data["image"][0::2] # just choose 0 and 180
|
56 |
+
data["label"] = data["label"][0:2] # label needs to be continuous
|
57 |
+
if self.mode == "test":
|
58 |
+
data["image"] = data["image"][0]
|
59 |
+
data["label"] = data["label"][0]
|
60 |
+
return data
|
ppocr/data/imaug/table_ops.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
from __future__ import unicode_literals
|
21 |
+
|
22 |
+
import sys
|
23 |
+
import six
|
24 |
+
import cv2
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
class GenTableMask(object):
|
29 |
+
""" gen table mask """
|
30 |
+
|
31 |
+
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
32 |
+
self.shrink_h_max = 5
|
33 |
+
self.shrink_w_max = 5
|
34 |
+
self.mask_type = mask_type
|
35 |
+
|
36 |
+
def projection(self, erosion, h, w, spilt_threshold=0):
|
37 |
+
# 水平投影
|
38 |
+
projection_map = np.ones_like(erosion)
|
39 |
+
project_val_array = [0 for _ in range(0, h)]
|
40 |
+
|
41 |
+
for j in range(0, h):
|
42 |
+
for i in range(0, w):
|
43 |
+
if erosion[j, i] == 255:
|
44 |
+
project_val_array[j] += 1
|
45 |
+
# 根据数组,获取切割点
|
46 |
+
start_idx = 0 # 记录进入字符区的索引
|
47 |
+
end_idx = 0 # 记录进入空白区域的索引
|
48 |
+
in_text = False # 是否遍历到了字符区内
|
49 |
+
box_list = []
|
50 |
+
for i in range(len(project_val_array)):
|
51 |
+
if in_text == False and project_val_array[
|
52 |
+
i] > spilt_threshold: # 进入字符区了
|
53 |
+
in_text = True
|
54 |
+
start_idx = i
|
55 |
+
elif project_val_array[
|
56 |
+
i] <= spilt_threshold and in_text == True: # 进入空白区了
|
57 |
+
end_idx = i
|
58 |
+
in_text = False
|
59 |
+
if end_idx - start_idx <= 2:
|
60 |
+
continue
|
61 |
+
box_list.append((start_idx, end_idx + 1))
|
62 |
+
|
63 |
+
if in_text:
|
64 |
+
box_list.append((start_idx, h - 1))
|
65 |
+
# 绘制投影直方图
|
66 |
+
for j in range(0, h):
|
67 |
+
for i in range(0, project_val_array[j]):
|
68 |
+
projection_map[j, i] = 0
|
69 |
+
return box_list, projection_map
|
70 |
+
|
71 |
+
def projection_cx(self, box_img):
|
72 |
+
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
73 |
+
h, w = box_gray_img.shape
|
74 |
+
# 灰度图片进行二值化处理
|
75 |
+
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
|
76 |
+
cv2.THRESH_BINARY_INV)
|
77 |
+
# 纵向腐蚀
|
78 |
+
if h < w:
|
79 |
+
kernel = np.ones((2, 1), np.uint8)
|
80 |
+
erode = cv2.erode(thresh1, kernel, iterations=1)
|
81 |
+
else:
|
82 |
+
erode = thresh1
|
83 |
+
# 水平膨胀
|
84 |
+
kernel = np.ones((1, 5), np.uint8)
|
85 |
+
erosion = cv2.dilate(erode, kernel, iterations=1)
|
86 |
+
# 水平投影
|
87 |
+
projection_map = np.ones_like(erosion)
|
88 |
+
project_val_array = [0 for _ in range(0, h)]
|
89 |
+
|
90 |
+
for j in range(0, h):
|
91 |
+
for i in range(0, w):
|
92 |
+
if erosion[j, i] == 255:
|
93 |
+
project_val_array[j] += 1
|
94 |
+
# 根据数组,获取切割点
|
95 |
+
start_idx = 0 # 记录进入字符区的索引
|
96 |
+
end_idx = 0 # 记录进入空白区域的索引
|
97 |
+
in_text = False # 是否遍历到了字符区内
|
98 |
+
box_list = []
|
99 |
+
spilt_threshold = 0
|
100 |
+
for i in range(len(project_val_array)):
|
101 |
+
if in_text == False and project_val_array[
|
102 |
+
i] > spilt_threshold: # 进入字符区了
|
103 |
+
in_text = True
|
104 |
+
start_idx = i
|
105 |
+
elif project_val_array[
|
106 |
+
i] <= spilt_threshold and in_text == True: # 进入空白区了
|
107 |
+
end_idx = i
|
108 |
+
in_text = False
|
109 |
+
if end_idx - start_idx <= 2:
|
110 |
+
continue
|
111 |
+
box_list.append((start_idx, end_idx + 1))
|
112 |
+
|
113 |
+
if in_text:
|
114 |
+
box_list.append((start_idx, h - 1))
|
115 |
+
# 绘制投影直方图
|
116 |
+
for j in range(0, h):
|
117 |
+
for i in range(0, project_val_array[j]):
|
118 |
+
projection_map[j, i] = 0
|
119 |
+
split_bbox_list = []
|
120 |
+
if len(box_list) > 1:
|
121 |
+
for i, (h_start, h_end) in enumerate(box_list):
|
122 |
+
if i == 0:
|
123 |
+
h_start = 0
|
124 |
+
if i == len(box_list):
|
125 |
+
h_end = h
|
126 |
+
word_img = erosion[h_start:h_end + 1, :]
|
127 |
+
word_h, word_w = word_img.shape
|
128 |
+
w_split_list, w_projection_map = self.projection(word_img.T,
|
129 |
+
word_w, word_h)
|
130 |
+
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
131 |
+
if h_start > 0:
|
132 |
+
h_start -= 1
|
133 |
+
h_end += 1
|
134 |
+
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
|
135 |
+
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
136 |
+
else:
|
137 |
+
split_bbox_list.append([0, 0, w, h])
|
138 |
+
return split_bbox_list
|
139 |
+
|
140 |
+
def shrink_bbox(self, bbox):
|
141 |
+
left, top, right, bottom = bbox
|
142 |
+
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
143 |
+
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
144 |
+
left_new = left + sh_w
|
145 |
+
right_new = right - sh_w
|
146 |
+
top_new = top + sh_h
|
147 |
+
bottom_new = bottom - sh_h
|
148 |
+
if left_new >= right_new:
|
149 |
+
left_new = left
|
150 |
+
right_new = right
|
151 |
+
if top_new >= bottom_new:
|
152 |
+
top_new = top
|
153 |
+
bottom_new = bottom
|
154 |
+
return [left_new, top_new, right_new, bottom_new]
|
155 |
+
|
156 |
+
def __call__(self, data):
|
157 |
+
img = data['image']
|
158 |
+
cells = data['cells']
|
159 |
+
height, width = img.shape[0:2]
|
160 |
+
if self.mask_type == 1:
|
161 |
+
mask_img = np.zeros((height, width), dtype=np.float32)
|
162 |
+
else:
|
163 |
+
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
164 |
+
cell_num = len(cells)
|
165 |
+
for cno in range(cell_num):
|
166 |
+
if "bbox" in cells[cno]:
|
167 |
+
bbox = cells[cno]['bbox']
|
168 |
+
left, top, right, bottom = bbox
|
169 |
+
box_img = img[top:bottom, left:right, :].copy()
|
170 |
+
split_bbox_list = self.projection_cx(box_img)
|
171 |
+
for sno in range(len(split_bbox_list)):
|
172 |
+
split_bbox_list[sno][0] += left
|
173 |
+
split_bbox_list[sno][1] += top
|
174 |
+
split_bbox_list[sno][2] += left
|
175 |
+
split_bbox_list[sno][3] += top
|
176 |
+
|
177 |
+
for sno in range(len(split_bbox_list)):
|
178 |
+
left, top, right, bottom = split_bbox_list[sno]
|
179 |
+
left, top, right, bottom = self.shrink_bbox(
|
180 |
+
[left, top, right, bottom])
|
181 |
+
if self.mask_type == 1:
|
182 |
+
mask_img[top:bottom, left:right] = 1.0
|
183 |
+
data['mask_img'] = mask_img
|
184 |
+
else:
|
185 |
+
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
186 |
+
data['image'] = mask_img
|
187 |
+
return data
|
188 |
+
|
189 |
+
|
190 |
+
class ResizeTableImage(object):
|
191 |
+
def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
|
192 |
+
**kwargs):
|
193 |
+
super(ResizeTableImage, self).__init__()
|
194 |
+
self.max_len = max_len
|
195 |
+
self.resize_bboxes = resize_bboxes
|
196 |
+
self.infer_mode = infer_mode
|
197 |
+
|
198 |
+
def __call__(self, data):
|
199 |
+
img = data['image']
|
200 |
+
height, width = img.shape[0:2]
|
201 |
+
ratio = self.max_len / (max(height, width) * 1.0)
|
202 |
+
resize_h = int(height * ratio)
|
203 |
+
resize_w = int(width * ratio)
|
204 |
+
resize_img = cv2.resize(img, (resize_w, resize_h))
|
205 |
+
if self.resize_bboxes and not self.infer_mode:
|
206 |
+
data['bboxes'] = data['bboxes'] * ratio
|
207 |
+
data['image'] = resize_img
|
208 |
+
data['src_img'] = img
|
209 |
+
data['shape'] = np.array([height, width, ratio, ratio])
|
210 |
+
data['max_len'] = self.max_len
|
211 |
+
return data
|
212 |
+
|
213 |
+
|
214 |
+
class PaddingTableImage(object):
|
215 |
+
def __init__(self, size, **kwargs):
|
216 |
+
super(PaddingTableImage, self).__init__()
|
217 |
+
self.size = size
|
218 |
+
|
219 |
+
def __call__(self, data):
|
220 |
+
img = data['image']
|
221 |
+
pad_h, pad_w = self.size
|
222 |
+
padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
|
223 |
+
height, width = img.shape[0:2]
|
224 |
+
padding_img[0:height, 0:width, :] = img.copy()
|
225 |
+
data['image'] = padding_img
|
226 |
+
shape = data['shape'].tolist()
|
227 |
+
shape.extend([pad_h, pad_w])
|
228 |
+
data['shape'] = np.array(shape)
|
229 |
+
return data
|
ppocr/data/imaug/text_image_aug/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .augment import tia_perspective, tia_distort, tia_stretch
|
16 |
+
|
17 |
+
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
|
ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (260 Bytes). View file
|
|
ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (297 Bytes). View file
|
|
ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-37.pyc
ADDED
Binary file (2.15 kB). View file
|
|
ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc
ADDED
Binary file (2.19 kB). View file
|
|
ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-37.pyc
ADDED
Binary file (3.89 kB). View file
|
|
ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc
ADDED
Binary file (3.96 kB). View file
|
|
ppocr/data/imaug/text_image_aug/augment.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
from .warp_mls import WarpMLS
|
21 |
+
|
22 |
+
|
23 |
+
def tia_distort(src, segment=4):
|
24 |
+
img_h, img_w = src.shape[:2]
|
25 |
+
|
26 |
+
cut = img_w // segment
|
27 |
+
thresh = cut // 3
|
28 |
+
|
29 |
+
src_pts = list()
|
30 |
+
dst_pts = list()
|
31 |
+
|
32 |
+
src_pts.append([0, 0])
|
33 |
+
src_pts.append([img_w, 0])
|
34 |
+
src_pts.append([img_w, img_h])
|
35 |
+
src_pts.append([0, img_h])
|
36 |
+
|
37 |
+
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
|
38 |
+
dst_pts.append(
|
39 |
+
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
|
40 |
+
dst_pts.append(
|
41 |
+
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
42 |
+
dst_pts.append(
|
43 |
+
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
44 |
+
|
45 |
+
half_thresh = thresh * 0.5
|
46 |
+
|
47 |
+
for cut_idx in np.arange(1, segment, 1):
|
48 |
+
src_pts.append([cut * cut_idx, 0])
|
49 |
+
src_pts.append([cut * cut_idx, img_h])
|
50 |
+
dst_pts.append([
|
51 |
+
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
52 |
+
np.random.randint(thresh) - half_thresh
|
53 |
+
])
|
54 |
+
dst_pts.append([
|
55 |
+
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
56 |
+
img_h + np.random.randint(thresh) - half_thresh
|
57 |
+
])
|
58 |
+
|
59 |
+
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
60 |
+
dst = trans.generate()
|
61 |
+
|
62 |
+
return dst
|
63 |
+
|
64 |
+
|
65 |
+
def tia_stretch(src, segment=4):
|
66 |
+
img_h, img_w = src.shape[:2]
|
67 |
+
|
68 |
+
cut = img_w // segment
|
69 |
+
thresh = cut * 4 // 5
|
70 |
+
|
71 |
+
src_pts = list()
|
72 |
+
dst_pts = list()
|
73 |
+
|
74 |
+
src_pts.append([0, 0])
|
75 |
+
src_pts.append([img_w, 0])
|
76 |
+
src_pts.append([img_w, img_h])
|
77 |
+
src_pts.append([0, img_h])
|
78 |
+
|
79 |
+
dst_pts.append([0, 0])
|
80 |
+
dst_pts.append([img_w, 0])
|
81 |
+
dst_pts.append([img_w, img_h])
|
82 |
+
dst_pts.append([0, img_h])
|
83 |
+
|
84 |
+
half_thresh = thresh * 0.5
|
85 |
+
|
86 |
+
for cut_idx in np.arange(1, segment, 1):
|
87 |
+
move = np.random.randint(thresh) - half_thresh
|
88 |
+
src_pts.append([cut * cut_idx, 0])
|
89 |
+
src_pts.append([cut * cut_idx, img_h])
|
90 |
+
dst_pts.append([cut * cut_idx + move, 0])
|
91 |
+
dst_pts.append([cut * cut_idx + move, img_h])
|
92 |
+
|
93 |
+
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
94 |
+
dst = trans.generate()
|
95 |
+
|
96 |
+
return dst
|
97 |
+
|
98 |
+
|
99 |
+
def tia_perspective(src):
|
100 |
+
img_h, img_w = src.shape[:2]
|
101 |
+
|
102 |
+
thresh = img_h // 2
|
103 |
+
|
104 |
+
src_pts = list()
|
105 |
+
dst_pts = list()
|
106 |
+
|
107 |
+
src_pts.append([0, 0])
|
108 |
+
src_pts.append([img_w, 0])
|
109 |
+
src_pts.append([img_w, img_h])
|
110 |
+
src_pts.append([0, img_h])
|
111 |
+
|
112 |
+
dst_pts.append([0, np.random.randint(thresh)])
|
113 |
+
dst_pts.append([img_w, np.random.randint(thresh)])
|
114 |
+
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
|
115 |
+
dst_pts.append([0, img_h - np.random.randint(thresh)])
|
116 |
+
|
117 |
+
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
118 |
+
dst = trans.generate()
|
119 |
+
|
120 |
+
return dst
|
ppocr/data/imaug/text_image_aug/warp_mls.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
This code is refer from:
|
16 |
+
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
|
22 |
+
class WarpMLS:
|
23 |
+
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
|
24 |
+
self.src = src
|
25 |
+
self.src_pts = src_pts
|
26 |
+
self.dst_pts = dst_pts
|
27 |
+
self.pt_count = len(self.dst_pts)
|
28 |
+
self.dst_w = dst_w
|
29 |
+
self.dst_h = dst_h
|
30 |
+
self.trans_ratio = trans_ratio
|
31 |
+
self.grid_size = 100
|
32 |
+
self.rdx = np.zeros((self.dst_h, self.dst_w))
|
33 |
+
self.rdy = np.zeros((self.dst_h, self.dst_w))
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def __bilinear_interp(x, y, v11, v12, v21, v22):
|
37 |
+
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
|
38 |
+
(1 - y) + v22 * y) * x
|
39 |
+
|
40 |
+
def generate(self):
|
41 |
+
self.calc_delta()
|
42 |
+
return self.gen_img()
|
43 |
+
|
44 |
+
def calc_delta(self):
|
45 |
+
w = np.zeros(self.pt_count, dtype=np.float32)
|
46 |
+
|
47 |
+
if self.pt_count < 2:
|
48 |
+
return
|
49 |
+
|
50 |
+
i = 0
|
51 |
+
while 1:
|
52 |
+
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
|
53 |
+
i = self.dst_w - 1
|
54 |
+
elif i >= self.dst_w:
|
55 |
+
break
|
56 |
+
|
57 |
+
j = 0
|
58 |
+
while 1:
|
59 |
+
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
|
60 |
+
j = self.dst_h - 1
|
61 |
+
elif j >= self.dst_h:
|
62 |
+
break
|
63 |
+
|
64 |
+
sw = 0
|
65 |
+
swp = np.zeros(2, dtype=np.float32)
|
66 |
+
swq = np.zeros(2, dtype=np.float32)
|
67 |
+
new_pt = np.zeros(2, dtype=np.float32)
|
68 |
+
cur_pt = np.array([i, j], dtype=np.float32)
|
69 |
+
|
70 |
+
k = 0
|
71 |
+
for k in range(self.pt_count):
|
72 |
+
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
73 |
+
break
|
74 |
+
|
75 |
+
w[k] = 1. / (
|
76 |
+
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
|
77 |
+
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
|
78 |
+
|
79 |
+
sw += w[k]
|
80 |
+
swp = swp + w[k] * np.array(self.dst_pts[k])
|
81 |
+
swq = swq + w[k] * np.array(self.src_pts[k])
|
82 |
+
|
83 |
+
if k == self.pt_count - 1:
|
84 |
+
pstar = 1 / sw * swp
|
85 |
+
qstar = 1 / sw * swq
|
86 |
+
|
87 |
+
miu_s = 0
|
88 |
+
for k in range(self.pt_count):
|
89 |
+
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
90 |
+
continue
|
91 |
+
pt_i = self.dst_pts[k] - pstar
|
92 |
+
miu_s += w[k] * np.sum(pt_i * pt_i)
|
93 |
+
|
94 |
+
cur_pt -= pstar
|
95 |
+
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
|
96 |
+
|
97 |
+
for k in range(self.pt_count):
|
98 |
+
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
99 |
+
continue
|
100 |
+
|
101 |
+
pt_i = self.dst_pts[k] - pstar
|
102 |
+
pt_j = np.array([-pt_i[1], pt_i[0]])
|
103 |
+
|
104 |
+
tmp_pt = np.zeros(2, dtype=np.float32)
|
105 |
+
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
|
106 |
+
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
|
107 |
+
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
|
108 |
+
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
|
109 |
+
tmp_pt *= (w[k] / miu_s)
|
110 |
+
new_pt += tmp_pt
|
111 |
+
|
112 |
+
new_pt += qstar
|
113 |
+
else:
|
114 |
+
new_pt = self.src_pts[k]
|
115 |
+
|
116 |
+
self.rdx[j, i] = new_pt[0] - i
|
117 |
+
self.rdy[j, i] = new_pt[1] - j
|
118 |
+
|
119 |
+
j += self.grid_size
|
120 |
+
i += self.grid_size
|
121 |
+
|
122 |
+
def gen_img(self):
|
123 |
+
src_h, src_w = self.src.shape[:2]
|
124 |
+
dst = np.zeros_like(self.src, dtype=np.float32)
|
125 |
+
|
126 |
+
for i in np.arange(0, self.dst_h, self.grid_size):
|
127 |
+
for j in np.arange(0, self.dst_w, self.grid_size):
|
128 |
+
ni = i + self.grid_size
|
129 |
+
nj = j + self.grid_size
|
130 |
+
w = h = self.grid_size
|
131 |
+
if ni >= self.dst_h:
|
132 |
+
ni = self.dst_h - 1
|
133 |
+
h = ni - i + 1
|
134 |
+
if nj >= self.dst_w:
|
135 |
+
nj = self.dst_w - 1
|
136 |
+
w = nj - j + 1
|
137 |
+
|
138 |
+
di = np.reshape(np.arange(h), (-1, 1))
|
139 |
+
dj = np.reshape(np.arange(w), (1, -1))
|
140 |
+
delta_x = self.__bilinear_interp(
|
141 |
+
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
|
142 |
+
self.rdx[ni, j], self.rdx[ni, nj])
|
143 |
+
delta_y = self.__bilinear_interp(
|
144 |
+
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
|
145 |
+
self.rdy[ni, j], self.rdy[ni, nj])
|
146 |
+
nx = j + dj + delta_x * self.trans_ratio
|
147 |
+
ny = i + di + delta_y * self.trans_ratio
|
148 |
+
nx = np.clip(nx, 0, src_w - 1)
|
149 |
+
ny = np.clip(ny, 0, src_h - 1)
|
150 |
+
nxi = np.array(np.floor(nx), dtype=np.int32)
|
151 |
+
nyi = np.array(np.floor(ny), dtype=np.int32)
|
152 |
+
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
|
153 |
+
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
|
154 |
+
|
155 |
+
if len(self.src.shape) == 3:
|
156 |
+
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
|
157 |
+
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
|
158 |
+
else:
|
159 |
+
x = ny - nyi
|
160 |
+
y = nx - nxi
|
161 |
+
dst[i:i + h, j:j + w] = self.__bilinear_interp(
|
162 |
+
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
|
163 |
+
self.src[nyi1, nxi], self.src[nyi1, nxi1])
|
164 |
+
|
165 |
+
dst = np.clip(dst, 0, 255)
|
166 |
+
dst = np.array(dst, dtype=np.uint8)
|
167 |
+
|
168 |
+
return dst
|
ppocr/data/imaug/vqa/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
|
19 |
+
'TensorizeEntitiesRelations'
|
20 |
+
]
|
ppocr/data/imaug/vqa/augment.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import numpy as np
|
18 |
+
import random
|
19 |
+
from copy import deepcopy
|
20 |
+
|
21 |
+
|
22 |
+
def order_by_tbyx(ocr_info):
|
23 |
+
res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
|
24 |
+
for i in range(len(res) - 1):
|
25 |
+
for j in range(i, 0, -1):
|
26 |
+
if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
|
27 |
+
(res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
|
28 |
+
tmp = deepcopy(res[j])
|
29 |
+
res[j] = deepcopy(res[j + 1])
|
30 |
+
res[j + 1] = deepcopy(tmp)
|
31 |
+
else:
|
32 |
+
break
|
33 |
+
return res
|
ppocr/data/imaug/vqa/token/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
|
16 |
+
from .vqa_token_pad import VQATokenPad
|
17 |
+
from .vqa_token_relation import VQAReTokenRelation
|
18 |
+
from .vqa_re_convert import TensorizeEntitiesRelations
|
ppocr/data/imaug/vqa/token/vqa_re_convert.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
class TensorizeEntitiesRelations(object):
|
19 |
+
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
|
20 |
+
self.max_seq_len = max_seq_len
|
21 |
+
self.infer_mode = infer_mode
|
22 |
+
|
23 |
+
def __call__(self, data):
|
24 |
+
entities = data['entities']
|
25 |
+
relations = data['relations']
|
26 |
+
|
27 |
+
entities_new = np.full(
|
28 |
+
shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
|
29 |
+
entities_new[0, 0] = len(entities['start'])
|
30 |
+
entities_new[0, 1] = len(entities['end'])
|
31 |
+
entities_new[0, 2] = len(entities['label'])
|
32 |
+
entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
|
33 |
+
'start'])
|
34 |
+
entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
|
35 |
+
entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
|
36 |
+
'label'])
|
37 |
+
|
38 |
+
relations_new = np.full(
|
39 |
+
shape=[self.max_seq_len * self.max_seq_len + 1, 2],
|
40 |
+
fill_value=-1,
|
41 |
+
dtype='int64')
|
42 |
+
relations_new[0, 0] = len(relations['head'])
|
43 |
+
relations_new[0, 1] = len(relations['tail'])
|
44 |
+
relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
|
45 |
+
'head'])
|
46 |
+
relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
|
47 |
+
'tail'])
|
48 |
+
|
49 |
+
data['entities'] = entities_new
|
50 |
+
data['relations'] = relations_new
|
51 |
+
return data
|
ppocr/data/imaug/vqa/token/vqa_token_chunk.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import defaultdict
|
16 |
+
|
17 |
+
|
18 |
+
class VQASerTokenChunk(object):
|
19 |
+
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
|
20 |
+
self.max_seq_len = max_seq_len
|
21 |
+
self.infer_mode = infer_mode
|
22 |
+
|
23 |
+
def __call__(self, data):
|
24 |
+
encoded_inputs_all = []
|
25 |
+
seq_len = len(data['input_ids'])
|
26 |
+
for index in range(0, seq_len, self.max_seq_len):
|
27 |
+
chunk_beg = index
|
28 |
+
chunk_end = min(index + self.max_seq_len, seq_len)
|
29 |
+
encoded_inputs_example = {}
|
30 |
+
for key in data:
|
31 |
+
if key in [
|
32 |
+
'label', 'input_ids', 'labels', 'token_type_ids',
|
33 |
+
'bbox', 'attention_mask'
|
34 |
+
]:
|
35 |
+
if self.infer_mode and key == 'labels':
|
36 |
+
encoded_inputs_example[key] = data[key]
|
37 |
+
else:
|
38 |
+
encoded_inputs_example[key] = data[key][chunk_beg:
|
39 |
+
chunk_end]
|
40 |
+
else:
|
41 |
+
encoded_inputs_example[key] = data[key]
|
42 |
+
|
43 |
+
encoded_inputs_all.append(encoded_inputs_example)
|
44 |
+
if len(encoded_inputs_all) == 0:
|
45 |
+
return None
|
46 |
+
return encoded_inputs_all[0]
|
47 |
+
|
48 |
+
|
49 |
+
class VQAReTokenChunk(object):
|
50 |
+
def __init__(self,
|
51 |
+
max_seq_len=512,
|
52 |
+
entities_labels=None,
|
53 |
+
infer_mode=False,
|
54 |
+
**kwargs):
|
55 |
+
self.max_seq_len = max_seq_len
|
56 |
+
self.entities_labels = {
|
57 |
+
'HEADER': 0,
|
58 |
+
'QUESTION': 1,
|
59 |
+
'ANSWER': 2
|
60 |
+
} if entities_labels is None else entities_labels
|
61 |
+
self.infer_mode = infer_mode
|
62 |
+
|
63 |
+
def __call__(self, data):
|
64 |
+
# prepare data
|
65 |
+
entities = data.pop('entities')
|
66 |
+
relations = data.pop('relations')
|
67 |
+
encoded_inputs_all = []
|
68 |
+
for index in range(0, len(data["input_ids"]), self.max_seq_len):
|
69 |
+
item = {}
|
70 |
+
for key in data:
|
71 |
+
if key in [
|
72 |
+
'label', 'input_ids', 'labels', 'token_type_ids',
|
73 |
+
'bbox', 'attention_mask'
|
74 |
+
]:
|
75 |
+
if self.infer_mode and key == 'labels':
|
76 |
+
item[key] = data[key]
|
77 |
+
else:
|
78 |
+
item[key] = data[key][index:index + self.max_seq_len]
|
79 |
+
else:
|
80 |
+
item[key] = data[key]
|
81 |
+
# select entity in current chunk
|
82 |
+
entities_in_this_span = []
|
83 |
+
global_to_local_map = {} #
|
84 |
+
for entity_id, entity in enumerate(entities):
|
85 |
+
if (index <= entity["start"] < index + self.max_seq_len and
|
86 |
+
index <= entity["end"] < index + self.max_seq_len):
|
87 |
+
entity["start"] = entity["start"] - index
|
88 |
+
entity["end"] = entity["end"] - index
|
89 |
+
global_to_local_map[entity_id] = len(entities_in_this_span)
|
90 |
+
entities_in_this_span.append(entity)
|
91 |
+
|
92 |
+
# select relations in current chunk
|
93 |
+
relations_in_this_span = []
|
94 |
+
for relation in relations:
|
95 |
+
if (index <= relation["start_index"] < index + self.max_seq_len
|
96 |
+
and index <= relation["end_index"] <
|
97 |
+
index + self.max_seq_len):
|
98 |
+
relations_in_this_span.append({
|
99 |
+
"head": global_to_local_map[relation["head"]],
|
100 |
+
"tail": global_to_local_map[relation["tail"]],
|
101 |
+
"start_index": relation["start_index"] - index,
|
102 |
+
"end_index": relation["end_index"] - index,
|
103 |
+
})
|
104 |
+
item.update({
|
105 |
+
"entities": self.reformat(entities_in_this_span),
|
106 |
+
"relations": self.reformat(relations_in_this_span),
|
107 |
+
})
|
108 |
+
if len(item['entities']) > 0:
|
109 |
+
item['entities']['label'] = [
|
110 |
+
self.entities_labels[x] for x in item['entities']['label']
|
111 |
+
]
|
112 |
+
encoded_inputs_all.append(item)
|
113 |
+
if len(encoded_inputs_all) == 0:
|
114 |
+
return None
|
115 |
+
return encoded_inputs_all[0]
|
116 |
+
|
117 |
+
def reformat(self, data):
|
118 |
+
new_data = defaultdict(list)
|
119 |
+
for item in data:
|
120 |
+
for k, v in item.items():
|
121 |
+
new_data[k].append(v)
|
122 |
+
return new_data
|
ppocr/data/imaug/vqa/token/vqa_token_pad.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import paddle
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
class VQATokenPad(object):
|
19 |
+
def __init__(self,
|
20 |
+
max_seq_len=512,
|
21 |
+
pad_to_max_seq_len=True,
|
22 |
+
return_attention_mask=True,
|
23 |
+
return_token_type_ids=True,
|
24 |
+
truncation_strategy="longest_first",
|
25 |
+
return_overflowing_tokens=False,
|
26 |
+
return_special_tokens_mask=False,
|
27 |
+
infer_mode=False,
|
28 |
+
**kwargs):
|
29 |
+
self.max_seq_len = max_seq_len
|
30 |
+
self.pad_to_max_seq_len = max_seq_len
|
31 |
+
self.return_attention_mask = return_attention_mask
|
32 |
+
self.return_token_type_ids = return_token_type_ids
|
33 |
+
self.truncation_strategy = truncation_strategy
|
34 |
+
self.return_overflowing_tokens = return_overflowing_tokens
|
35 |
+
self.return_special_tokens_mask = return_special_tokens_mask
|
36 |
+
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
|
37 |
+
self.infer_mode = infer_mode
|
38 |
+
|
39 |
+
def __call__(self, data):
|
40 |
+
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
|
41 |
+
"input_ids"]) < self.max_seq_len
|
42 |
+
|
43 |
+
if needs_to_be_padded:
|
44 |
+
if 'tokenizer_params' in data:
|
45 |
+
tokenizer_params = data.pop('tokenizer_params')
|
46 |
+
else:
|
47 |
+
tokenizer_params = dict(
|
48 |
+
padding_side='right', pad_token_type_id=0, pad_token_id=1)
|
49 |
+
|
50 |
+
difference = self.max_seq_len - len(data["input_ids"])
|
51 |
+
if tokenizer_params['padding_side'] == 'right':
|
52 |
+
if self.return_attention_mask:
|
53 |
+
data["attention_mask"] = [1] * len(data[
|
54 |
+
"input_ids"]) + [0] * difference
|
55 |
+
if self.return_token_type_ids:
|
56 |
+
data["token_type_ids"] = (
|
57 |
+
data["token_type_ids"] +
|
58 |
+
[tokenizer_params['pad_token_type_id']] * difference)
|
59 |
+
if self.return_special_tokens_mask:
|
60 |
+
data["special_tokens_mask"] = data[
|
61 |
+
"special_tokens_mask"] + [1] * difference
|
62 |
+
data["input_ids"] = data["input_ids"] + [
|
63 |
+
tokenizer_params['pad_token_id']
|
64 |
+
] * difference
|
65 |
+
if not self.infer_mode:
|
66 |
+
data["labels"] = data[
|
67 |
+
"labels"] + [self.pad_token_label_id] * difference
|
68 |
+
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
|
69 |
+
elif tokenizer_params['padding_side'] == 'left':
|
70 |
+
if self.return_attention_mask:
|
71 |
+
data["attention_mask"] = [0] * difference + [
|
72 |
+
1
|
73 |
+
] * len(data["input_ids"])
|
74 |
+
if self.return_token_type_ids:
|
75 |
+
data["token_type_ids"] = (
|
76 |
+
[tokenizer_params['pad_token_type_id']] * difference +
|
77 |
+
data["token_type_ids"])
|
78 |
+
if self.return_special_tokens_mask:
|
79 |
+
data["special_tokens_mask"] = [
|
80 |
+
1
|
81 |
+
] * difference + data["special_tokens_mask"]
|
82 |
+
data["input_ids"] = [tokenizer_params['pad_token_id']
|
83 |
+
] * difference + data["input_ids"]
|
84 |
+
if not self.infer_mode:
|
85 |
+
data["labels"] = [self.pad_token_label_id
|
86 |
+
] * difference + data["labels"]
|
87 |
+
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
|
88 |
+
else:
|
89 |
+
if self.return_attention_mask:
|
90 |
+
data["attention_mask"] = [1] * len(data["input_ids"])
|
91 |
+
|
92 |
+
for key in data:
|
93 |
+
if key in [
|
94 |
+
'input_ids', 'labels', 'token_type_ids', 'bbox',
|
95 |
+
'attention_mask'
|
96 |
+
]:
|
97 |
+
if self.infer_mode:
|
98 |
+
if key != 'labels':
|
99 |
+
length = min(len(data[key]), self.max_seq_len)
|
100 |
+
data[key] = data[key][:length]
|
101 |
+
else:
|
102 |
+
continue
|
103 |
+
data[key] = np.array(data[key], dtype='int64')
|
104 |
+
return data
|
ppocr/data/imaug/vqa/token/vqa_token_relation.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
class VQAReTokenRelation(object):
|
17 |
+
def __init__(self, **kwargs):
|
18 |
+
pass
|
19 |
+
|
20 |
+
def __call__(self, data):
|
21 |
+
"""
|
22 |
+
build relations
|
23 |
+
"""
|
24 |
+
entities = data['entities']
|
25 |
+
relations = data['relations']
|
26 |
+
id2label = data.pop('id2label')
|
27 |
+
empty_entity = data.pop('empty_entity')
|
28 |
+
entity_id_to_index_map = data.pop('entity_id_to_index_map')
|
29 |
+
|
30 |
+
relations = list(set(relations))
|
31 |
+
relations = [
|
32 |
+
rel for rel in relations
|
33 |
+
if rel[0] not in empty_entity and rel[1] not in empty_entity
|
34 |
+
]
|
35 |
+
kv_relations = []
|
36 |
+
for rel in relations:
|
37 |
+
pair = [id2label[rel[0]], id2label[rel[1]]]
|
38 |
+
if pair == ["question", "answer"]:
|
39 |
+
kv_relations.append({
|
40 |
+
"head": entity_id_to_index_map[rel[0]],
|
41 |
+
"tail": entity_id_to_index_map[rel[1]]
|
42 |
+
})
|
43 |
+
elif pair == ["answer", "question"]:
|
44 |
+
kv_relations.append({
|
45 |
+
"head": entity_id_to_index_map[rel[1]],
|
46 |
+
"tail": entity_id_to_index_map[rel[0]]
|
47 |
+
})
|
48 |
+
else:
|
49 |
+
continue
|
50 |
+
relations = sorted(
|
51 |
+
[{
|
52 |
+
"head": rel["head"],
|
53 |
+
"tail": rel["tail"],
|
54 |
+
"start_index": self.get_relation_span(rel, entities)[0],
|
55 |
+
"end_index": self.get_relation_span(rel, entities)[1],
|
56 |
+
} for rel in kv_relations],
|
57 |
+
key=lambda x: x["head"], )
|
58 |
+
|
59 |
+
data['relations'] = relations
|
60 |
+
return data
|
61 |
+
|
62 |
+
def get_relation_span(self, rel, entities):
|
63 |
+
bound = []
|
64 |
+
for entity_index in [rel["head"], rel["tail"]]:
|
65 |
+
bound.append(entities[entity_index]["start"])
|
66 |
+
bound.append(entities[entity_index]["end"])
|
67 |
+
return min(bound), max(bound)
|
ppocr/data/lmdb_dataset.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
from paddle.io import Dataset
|
17 |
+
import lmdb
|
18 |
+
import cv2
|
19 |
+
import string
|
20 |
+
import six
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from .imaug import transform, create_operators
|
24 |
+
|
25 |
+
|
26 |
+
class LMDBDataSet(Dataset):
|
27 |
+
def __init__(self, config, mode, logger, seed=None):
|
28 |
+
super(LMDBDataSet, self).__init__()
|
29 |
+
|
30 |
+
global_config = config['Global']
|
31 |
+
dataset_config = config[mode]['dataset']
|
32 |
+
loader_config = config[mode]['loader']
|
33 |
+
batch_size = loader_config['batch_size_per_card']
|
34 |
+
data_dir = dataset_config['data_dir']
|
35 |
+
self.do_shuffle = loader_config['shuffle']
|
36 |
+
|
37 |
+
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
|
38 |
+
logger.info("Initialize indexs of datasets:%s" % data_dir)
|
39 |
+
self.data_idx_order_list = self.dataset_traversal()
|
40 |
+
if self.do_shuffle:
|
41 |
+
np.random.shuffle(self.data_idx_order_list)
|
42 |
+
self.ops = create_operators(dataset_config['transforms'], global_config)
|
43 |
+
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
|
44 |
+
1)
|
45 |
+
|
46 |
+
ratio_list = dataset_config.get("ratio_list", [1.0])
|
47 |
+
self.need_reset = True in [x < 1 for x in ratio_list]
|
48 |
+
|
49 |
+
def load_hierarchical_lmdb_dataset(self, data_dir):
|
50 |
+
lmdb_sets = {}
|
51 |
+
dataset_idx = 0
|
52 |
+
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
|
53 |
+
if not dirnames:
|
54 |
+
env = lmdb.open(
|
55 |
+
dirpath,
|
56 |
+
max_readers=32,
|
57 |
+
readonly=True,
|
58 |
+
lock=False,
|
59 |
+
readahead=False,
|
60 |
+
meminit=False)
|
61 |
+
txn = env.begin(write=False)
|
62 |
+
num_samples = int(txn.get('num-samples'.encode()))
|
63 |
+
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
|
64 |
+
"txn":txn, "num_samples":num_samples}
|
65 |
+
dataset_idx += 1
|
66 |
+
return lmdb_sets
|
67 |
+
|
68 |
+
def dataset_traversal(self):
|
69 |
+
lmdb_num = len(self.lmdb_sets)
|
70 |
+
total_sample_num = 0
|
71 |
+
for lno in range(lmdb_num):
|
72 |
+
total_sample_num += self.lmdb_sets[lno]['num_samples']
|
73 |
+
data_idx_order_list = np.zeros((total_sample_num, 2))
|
74 |
+
beg_idx = 0
|
75 |
+
for lno in range(lmdb_num):
|
76 |
+
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
|
77 |
+
end_idx = beg_idx + tmp_sample_num
|
78 |
+
data_idx_order_list[beg_idx:end_idx, 0] = lno
|
79 |
+
data_idx_order_list[beg_idx:end_idx, 1] \
|
80 |
+
= list(range(tmp_sample_num))
|
81 |
+
data_idx_order_list[beg_idx:end_idx, 1] += 1
|
82 |
+
beg_idx = beg_idx + tmp_sample_num
|
83 |
+
return data_idx_order_list
|
84 |
+
|
85 |
+
def get_img_data(self, value):
|
86 |
+
"""get_img_data"""
|
87 |
+
if not value:
|
88 |
+
return None
|
89 |
+
imgdata = np.frombuffer(value, dtype='uint8')
|
90 |
+
if imgdata is None:
|
91 |
+
return None
|
92 |
+
imgori = cv2.imdecode(imgdata, 1)
|
93 |
+
if imgori is None:
|
94 |
+
return None
|
95 |
+
return imgori
|
96 |
+
|
97 |
+
def get_ext_data(self):
|
98 |
+
ext_data_num = 0
|
99 |
+
for op in self.ops:
|
100 |
+
if hasattr(op, 'ext_data_num'):
|
101 |
+
ext_data_num = getattr(op, 'ext_data_num')
|
102 |
+
break
|
103 |
+
load_data_ops = self.ops[:self.ext_op_transform_idx]
|
104 |
+
ext_data = []
|
105 |
+
|
106 |
+
while len(ext_data) < ext_data_num:
|
107 |
+
lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
|
108 |
+
len(self))]
|
109 |
+
lmdb_idx = int(lmdb_idx)
|
110 |
+
file_idx = int(file_idx)
|
111 |
+
sample_info = self.get_lmdb_sample_info(
|
112 |
+
self.lmdb_sets[lmdb_idx]['txn'], file_idx)
|
113 |
+
if sample_info is None:
|
114 |
+
continue
|
115 |
+
img, label = sample_info
|
116 |
+
data = {'image': img, 'label': label}
|
117 |
+
data = transform(data, load_data_ops)
|
118 |
+
if data is None:
|
119 |
+
continue
|
120 |
+
ext_data.append(data)
|
121 |
+
return ext_data
|
122 |
+
|
123 |
+
def get_lmdb_sample_info(self, txn, index):
|
124 |
+
label_key = 'label-%09d'.encode() % index
|
125 |
+
label = txn.get(label_key)
|
126 |
+
if label is None:
|
127 |
+
return None
|
128 |
+
label = label.decode('utf-8')
|
129 |
+
img_key = 'image-%09d'.encode() % index
|
130 |
+
imgbuf = txn.get(img_key)
|
131 |
+
return imgbuf, label
|
132 |
+
|
133 |
+
def __getitem__(self, idx):
|
134 |
+
lmdb_idx, file_idx = self.data_idx_order_list[idx]
|
135 |
+
lmdb_idx = int(lmdb_idx)
|
136 |
+
file_idx = int(file_idx)
|
137 |
+
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
|
138 |
+
file_idx)
|
139 |
+
if sample_info is None:
|
140 |
+
return self.__getitem__(np.random.randint(self.__len__()))
|
141 |
+
img, label = sample_info
|
142 |
+
data = {'image': img, 'label': label}
|
143 |
+
data['ext_data'] = self.get_ext_data()
|
144 |
+
outs = transform(data, self.ops)
|
145 |
+
if outs is None:
|
146 |
+
return self.__getitem__(np.random.randint(self.__len__()))
|
147 |
+
return outs
|
148 |
+
|
149 |
+
def __len__(self):
|
150 |
+
return self.data_idx_order_list.shape[0]
|
151 |
+
|
152 |
+
|
153 |
+
class LMDBDataSetSR(LMDBDataSet):
|
154 |
+
def buf2PIL(self, txn, key, type='RGB'):
|
155 |
+
imgbuf = txn.get(key)
|
156 |
+
buf = six.BytesIO()
|
157 |
+
buf.write(imgbuf)
|
158 |
+
buf.seek(0)
|
159 |
+
im = Image.open(buf).convert(type)
|
160 |
+
return im
|
161 |
+
|
162 |
+
def str_filt(self, str_, voc_type):
|
163 |
+
alpha_dict = {
|
164 |
+
'digit': string.digits,
|
165 |
+
'lower': string.digits + string.ascii_lowercase,
|
166 |
+
'upper': string.digits + string.ascii_letters,
|
167 |
+
'all': string.digits + string.ascii_letters + string.punctuation
|
168 |
+
}
|
169 |
+
if voc_type == 'lower':
|
170 |
+
str_ = str_.lower()
|
171 |
+
for char in str_:
|
172 |
+
if char not in alpha_dict[voc_type]:
|
173 |
+
str_ = str_.replace(char, '')
|
174 |
+
return str_
|
175 |
+
|
176 |
+
def get_lmdb_sample_info(self, txn, index):
|
177 |
+
self.voc_type = 'upper'
|
178 |
+
self.max_len = 100
|
179 |
+
self.test = False
|
180 |
+
label_key = b'label-%09d' % index
|
181 |
+
word = str(txn.get(label_key).decode())
|
182 |
+
img_HR_key = b'image_hr-%09d' % index # 128*32
|
183 |
+
img_lr_key = b'image_lr-%09d' % index # 64*16
|
184 |
+
try:
|
185 |
+
img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
|
186 |
+
img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
|
187 |
+
except IOError or len(word) > self.max_len:
|
188 |
+
return self[index + 1]
|
189 |
+
label_str = self.str_filt(word, self.voc_type)
|
190 |
+
return img_HR, img_lr, label_str
|
191 |
+
|
192 |
+
def __getitem__(self, idx):
|
193 |
+
lmdb_idx, file_idx = self.data_idx_order_list[idx]
|
194 |
+
lmdb_idx = int(lmdb_idx)
|
195 |
+
file_idx = int(file_idx)
|
196 |
+
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
|
197 |
+
file_idx)
|
198 |
+
if sample_info is None:
|
199 |
+
return self.__getitem__(np.random.randint(self.__len__()))
|
200 |
+
img_HR, img_lr, label_str = sample_info
|
201 |
+
data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
|
202 |
+
outs = transform(data, self.ops)
|
203 |
+
if outs is None:
|
204 |
+
return self.__getitem__(np.random.randint(self.__len__()))
|
205 |
+
return outs
|
ppocr/data/pgnet_dataset.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
from paddle.io import Dataset
|
17 |
+
from .imaug import transform, create_operators
|
18 |
+
import random
|
19 |
+
|
20 |
+
|
21 |
+
class PGDataSet(Dataset):
|
22 |
+
def __init__(self, config, mode, logger, seed=None):
|
23 |
+
super(PGDataSet, self).__init__()
|
24 |
+
|
25 |
+
self.logger = logger
|
26 |
+
self.seed = seed
|
27 |
+
self.mode = mode
|
28 |
+
global_config = config['Global']
|
29 |
+
dataset_config = config[mode]['dataset']
|
30 |
+
loader_config = config[mode]['loader']
|
31 |
+
|
32 |
+
self.delimiter = dataset_config.get('delimiter', '\t')
|
33 |
+
label_file_list = dataset_config.pop('label_file_list')
|
34 |
+
data_source_num = len(label_file_list)
|
35 |
+
ratio_list = dataset_config.get("ratio_list", [1.0])
|
36 |
+
if isinstance(ratio_list, (float, int)):
|
37 |
+
ratio_list = [float(ratio_list)] * int(data_source_num)
|
38 |
+
assert len(
|
39 |
+
ratio_list
|
40 |
+
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
41 |
+
self.data_dir = dataset_config['data_dir']
|
42 |
+
self.do_shuffle = loader_config['shuffle']
|
43 |
+
|
44 |
+
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
45 |
+
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
46 |
+
self.data_idx_order_list = list(range(len(self.data_lines)))
|
47 |
+
if mode.lower() == "train":
|
48 |
+
self.shuffle_data_random()
|
49 |
+
|
50 |
+
self.ops = create_operators(dataset_config['transforms'], global_config)
|
51 |
+
|
52 |
+
self.need_reset = True in [x < 1 for x in ratio_list]
|
53 |
+
|
54 |
+
def shuffle_data_random(self):
|
55 |
+
if self.do_shuffle:
|
56 |
+
random.seed(self.seed)
|
57 |
+
random.shuffle(self.data_lines)
|
58 |
+
return
|
59 |
+
|
60 |
+
def get_image_info_list(self, file_list, ratio_list):
|
61 |
+
if isinstance(file_list, str):
|
62 |
+
file_list = [file_list]
|
63 |
+
data_lines = []
|
64 |
+
for idx, file in enumerate(file_list):
|
65 |
+
with open(file, "rb") as f:
|
66 |
+
lines = f.readlines()
|
67 |
+
if self.mode == "train" or ratio_list[idx] < 1.0:
|
68 |
+
random.seed(self.seed)
|
69 |
+
lines = random.sample(lines,
|
70 |
+
round(len(lines) * ratio_list[idx]))
|
71 |
+
data_lines.extend(lines)
|
72 |
+
return data_lines
|
73 |
+
|
74 |
+
def __getitem__(self, idx):
|
75 |
+
file_idx = self.data_idx_order_list[idx]
|
76 |
+
data_line = self.data_lines[file_idx]
|
77 |
+
img_id = 0
|
78 |
+
try:
|
79 |
+
data_line = data_line.decode('utf-8')
|
80 |
+
substr = data_line.strip("\n").split(self.delimiter)
|
81 |
+
file_name = substr[0]
|
82 |
+
label = substr[1]
|
83 |
+
img_path = os.path.join(self.data_dir, file_name)
|
84 |
+
if self.mode.lower() == 'eval':
|
85 |
+
try:
|
86 |
+
img_id = int(data_line.split(".")[0][7:])
|
87 |
+
except:
|
88 |
+
img_id = 0
|
89 |
+
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
90 |
+
if not os.path.exists(img_path):
|
91 |
+
raise Exception("{} does not exist!".format(img_path))
|
92 |
+
with open(data['img_path'], 'rb') as f:
|
93 |
+
img = f.read()
|
94 |
+
data['image'] = img
|
95 |
+
outs = transform(data, self.ops)
|
96 |
+
except Exception as e:
|
97 |
+
self.logger.error(
|
98 |
+
"When parsing line {}, error happened with msg: {}".format(
|
99 |
+
self.data_idx_order_list[idx], e))
|
100 |
+
outs = None
|
101 |
+
if outs is None:
|
102 |
+
return self.__getitem__(np.random.randint(self.__len__()))
|
103 |
+
return outs
|
104 |
+
|
105 |
+
def __len__(self):
|
106 |
+
return len(self.data_idx_order_list)
|
ppocr/data/pubtab_dataset.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
import random
|
17 |
+
from paddle.io import Dataset
|
18 |
+
import json
|
19 |
+
from copy import deepcopy
|
20 |
+
|
21 |
+
from .imaug import transform, create_operators
|
22 |
+
|
23 |
+
|
24 |
+
class PubTabDataSet(Dataset):
|
25 |
+
def __init__(self, config, mode, logger, seed=None):
|
26 |
+
super(PubTabDataSet, self).__init__()
|
27 |
+
self.logger = logger
|
28 |
+
|
29 |
+
global_config = config['Global']
|
30 |
+
dataset_config = config[mode]['dataset']
|
31 |
+
loader_config = config[mode]['loader']
|
32 |
+
|
33 |
+
label_file_list = dataset_config.pop('label_file_list')
|
34 |
+
data_source_num = len(label_file_list)
|
35 |
+
ratio_list = dataset_config.get("ratio_list", [1.0])
|
36 |
+
if isinstance(ratio_list, (float, int)):
|
37 |
+
ratio_list = [float(ratio_list)] * int(data_source_num)
|
38 |
+
|
39 |
+
assert len(
|
40 |
+
ratio_list
|
41 |
+
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
42 |
+
|
43 |
+
self.data_dir = dataset_config['data_dir']
|
44 |
+
self.do_shuffle = loader_config['shuffle']
|
45 |
+
|
46 |
+
self.seed = seed
|
47 |
+
self.mode = mode.lower()
|
48 |
+
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
49 |
+
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
50 |
+
# self.check(config['Global']['max_text_length'])
|
51 |
+
|
52 |
+
if mode.lower() == "train" and self.do_shuffle:
|
53 |
+
self.shuffle_data_random()
|
54 |
+
self.ops = create_operators(dataset_config['transforms'], global_config)
|
55 |
+
self.need_reset = True in [x < 1 for x in ratio_list]
|
56 |
+
|
57 |
+
def get_image_info_list(self, file_list, ratio_list):
|
58 |
+
if isinstance(file_list, str):
|
59 |
+
file_list = [file_list]
|
60 |
+
data_lines = []
|
61 |
+
for idx, file in enumerate(file_list):
|
62 |
+
with open(file, "rb") as f:
|
63 |
+
lines = f.readlines()
|
64 |
+
if self.mode == "train" or ratio_list[idx] < 1.0:
|
65 |
+
random.seed(self.seed)
|
66 |
+
lines = random.sample(lines,
|
67 |
+
round(len(lines) * ratio_list[idx]))
|
68 |
+
data_lines.extend(lines)
|
69 |
+
return data_lines
|
70 |
+
|
71 |
+
def check(self, max_text_length):
|
72 |
+
data_lines = []
|
73 |
+
for line in self.data_lines:
|
74 |
+
data_line = line.decode('utf-8').strip("\n")
|
75 |
+
info = json.loads(data_line)
|
76 |
+
file_name = info['filename']
|
77 |
+
cells = info['html']['cells'].copy()
|
78 |
+
structure = info['html']['structure']['tokens'].copy()
|
79 |
+
|
80 |
+
img_path = os.path.join(self.data_dir, file_name)
|
81 |
+
if not os.path.exists(img_path):
|
82 |
+
self.logger.warning("{} does not exist!".format(img_path))
|
83 |
+
continue
|
84 |
+
if len(structure) == 0 or len(structure) > max_text_length:
|
85 |
+
continue
|
86 |
+
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
|
87 |
+
data_lines.append(line)
|
88 |
+
self.data_lines = data_lines
|
89 |
+
|
90 |
+
def shuffle_data_random(self):
|
91 |
+
if self.do_shuffle:
|
92 |
+
random.seed(self.seed)
|
93 |
+
random.shuffle(self.data_lines)
|
94 |
+
return
|
95 |
+
|
96 |
+
def __getitem__(self, idx):
|
97 |
+
try:
|
98 |
+
data_line = self.data_lines[idx]
|
99 |
+
data_line = data_line.decode('utf-8').strip("\n")
|
100 |
+
info = json.loads(data_line)
|
101 |
+
file_name = info['filename']
|
102 |
+
cells = info['html']['cells'].copy()
|
103 |
+
structure = info['html']['structure']['tokens'].copy()
|
104 |
+
|
105 |
+
img_path = os.path.join(self.data_dir, file_name)
|
106 |
+
if not os.path.exists(img_path):
|
107 |
+
raise Exception("{} does not exist!".format(img_path))
|
108 |
+
data = {
|
109 |
+
'img_path': img_path,
|
110 |
+
'cells': cells,
|
111 |
+
'structure': structure,
|
112 |
+
'file_name': file_name
|
113 |
+
}
|
114 |
+
|
115 |
+
with open(data['img_path'], 'rb') as f:
|
116 |
+
img = f.read()
|
117 |
+
data['image'] = img
|
118 |
+
outs = transform(data, self.ops)
|
119 |
+
except:
|
120 |
+
import traceback
|
121 |
+
err = traceback.format_exc()
|
122 |
+
self.logger.error(
|
123 |
+
"When parsing line {}, error happened with msg: {}".format(
|
124 |
+
data_line, err))
|
125 |
+
outs = None
|
126 |
+
if outs is None:
|
127 |
+
rnd_idx = np.random.randint(self.__len__(
|
128 |
+
)) if self.mode == "train" else (idx + 1) % self.__len__()
|
129 |
+
return self.__getitem__(rnd_idx)
|
130 |
+
return outs
|
131 |
+
|
132 |
+
def __len__(self):
|
133 |
+
return len(self.data_lines)
|
ppocr/data/simple_dataset.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
import json
|
17 |
+
import random
|
18 |
+
import traceback
|
19 |
+
from paddle.io import Dataset
|
20 |
+
from .imaug import transform, create_operators
|
21 |
+
|
22 |
+
|
23 |
+
class SimpleDataSet(Dataset):
|
24 |
+
def __init__(self, config, mode, logger, seed=None):
|
25 |
+
super(SimpleDataSet, self).__init__()
|
26 |
+
self.logger = logger
|
27 |
+
self.mode = mode.lower()
|
28 |
+
|
29 |
+
global_config = config['Global']
|
30 |
+
dataset_config = config[mode]['dataset']
|
31 |
+
loader_config = config[mode]['loader']
|
32 |
+
|
33 |
+
self.delimiter = dataset_config.get('delimiter', '\t')
|
34 |
+
label_file_list = dataset_config.pop('label_file_list')
|
35 |
+
data_source_num = len(label_file_list)
|
36 |
+
ratio_list = dataset_config.get("ratio_list", 1.0)
|
37 |
+
if isinstance(ratio_list, (float, int)):
|
38 |
+
ratio_list = [float(ratio_list)] * int(data_source_num)
|
39 |
+
|
40 |
+
assert len(
|
41 |
+
ratio_list
|
42 |
+
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
43 |
+
self.data_dir = dataset_config['data_dir']
|
44 |
+
self.do_shuffle = loader_config['shuffle']
|
45 |
+
self.seed = seed
|
46 |
+
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
47 |
+
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
48 |
+
self.data_idx_order_list = list(range(len(self.data_lines)))
|
49 |
+
if self.mode == "train" and self.do_shuffle:
|
50 |
+
self.shuffle_data_random()
|
51 |
+
self.ops = create_operators(dataset_config['transforms'], global_config)
|
52 |
+
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
|
53 |
+
2)
|
54 |
+
self.need_reset = True in [x < 1 for x in ratio_list]
|
55 |
+
|
56 |
+
def get_image_info_list(self, file_list, ratio_list):
|
57 |
+
if isinstance(file_list, str):
|
58 |
+
file_list = [file_list]
|
59 |
+
data_lines = []
|
60 |
+
for idx, file in enumerate(file_list):
|
61 |
+
with open(file, "rb") as f:
|
62 |
+
lines = f.readlines()
|
63 |
+
if self.mode == "train" or ratio_list[idx] < 1.0:
|
64 |
+
random.seed(self.seed)
|
65 |
+
lines = random.sample(lines,
|
66 |
+
round(len(lines) * ratio_list[idx]))
|
67 |
+
data_lines.extend(lines)
|
68 |
+
return data_lines
|
69 |
+
|
70 |
+
def shuffle_data_random(self):
|
71 |
+
random.seed(self.seed)
|
72 |
+
random.shuffle(self.data_lines)
|
73 |
+
return
|
74 |
+
|
75 |
+
def _try_parse_filename_list(self, file_name):
|
76 |
+
# multiple images -> one gt label
|
77 |
+
if len(file_name) > 0 and file_name[0] == "[":
|
78 |
+
try:
|
79 |
+
info = json.loads(file_name)
|
80 |
+
file_name = random.choice(info)
|
81 |
+
except:
|
82 |
+
pass
|
83 |
+
return file_name
|
84 |
+
|
85 |
+
def get_ext_data(self):
|
86 |
+
ext_data_num = 0
|
87 |
+
for op in self.ops:
|
88 |
+
if hasattr(op, 'ext_data_num'):
|
89 |
+
ext_data_num = getattr(op, 'ext_data_num')
|
90 |
+
break
|
91 |
+
load_data_ops = self.ops[:self.ext_op_transform_idx]
|
92 |
+
ext_data = []
|
93 |
+
|
94 |
+
while len(ext_data) < ext_data_num:
|
95 |
+
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
|
96 |
+
))]
|
97 |
+
data_line = self.data_lines[file_idx]
|
98 |
+
data_line = data_line.decode('utf-8')
|
99 |
+
substr = data_line.strip("\n").split(self.delimiter)
|
100 |
+
file_name = substr[0]
|
101 |
+
file_name = self._try_parse_filename_list(file_name)
|
102 |
+
label = substr[1]
|
103 |
+
img_path = os.path.join(self.data_dir, file_name)
|
104 |
+
data = {'img_path': img_path, 'label': label}
|
105 |
+
if not os.path.exists(img_path):
|
106 |
+
continue
|
107 |
+
with open(data['img_path'], 'rb') as f:
|
108 |
+
img = f.read()
|
109 |
+
data['image'] = img
|
110 |
+
data = transform(data, load_data_ops)
|
111 |
+
|
112 |
+
if data is None:
|
113 |
+
continue
|
114 |
+
if 'polys' in data.keys():
|
115 |
+
if data['polys'].shape[1] != 4:
|
116 |
+
continue
|
117 |
+
ext_data.append(data)
|
118 |
+
return ext_data
|
119 |
+
|
120 |
+
def __getitem__(self, idx):
|
121 |
+
file_idx = self.data_idx_order_list[idx]
|
122 |
+
data_line = self.data_lines[file_idx]
|
123 |
+
try:
|
124 |
+
data_line = data_line.decode('utf-8')
|
125 |
+
substr = data_line.strip("\n").split(self.delimiter)
|
126 |
+
file_name = substr[0]
|
127 |
+
file_name = self._try_parse_filename_list(file_name)
|
128 |
+
label = substr[1]
|
129 |
+
img_path = os.path.join(self.data_dir, file_name)
|
130 |
+
data = {'img_path': img_path, 'label': label}
|
131 |
+
if not os.path.exists(img_path):
|
132 |
+
raise Exception("{} does not exist!".format(img_path))
|
133 |
+
with open(data['img_path'], 'rb') as f:
|
134 |
+
img = f.read()
|
135 |
+
data['image'] = img
|
136 |
+
data['ext_data'] = self.get_ext_data()
|
137 |
+
outs = transform(data, self.ops)
|
138 |
+
except:
|
139 |
+
self.logger.error(
|
140 |
+
"When parsing line {}, error happened with msg: {}".format(
|
141 |
+
data_line, traceback.format_exc()))
|
142 |
+
outs = None
|
143 |
+
if outs is None:
|
144 |
+
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
145 |
+
rnd_idx = np.random.randint(self.__len__(
|
146 |
+
)) if self.mode == "train" else (idx + 1) % self.__len__()
|
147 |
+
return self.__getitem__(rnd_idx)
|
148 |
+
return outs
|
149 |
+
|
150 |
+
def __len__(self):
|
151 |
+
return len(self.data_idx_order_list)
|
ppocr/ext_op/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .roi_align_rotated.roi_align_rotated import RoIAlignRotated
|
ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
// This code is refer from:
|
3 |
+
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp
|
4 |
+
|
5 |
+
#include <cassert>
|
6 |
+
#include <cmath>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
#include "paddle/extension.h"
|
10 |
+
|
11 |
+
#define PADDLE_WITH_CUDA
|
12 |
+
#define CHECK_INPUT_SAME(x1, x2) \
|
13 |
+
PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
|
14 |
+
#define CHECK_INPUT_CPU(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
|
15 |
+
|
16 |
+
template <typename T> struct PreCalc {
|
17 |
+
int pos1;
|
18 |
+
int pos2;
|
19 |
+
int pos3;
|
20 |
+
int pos4;
|
21 |
+
T w1;
|
22 |
+
T w2;
|
23 |
+
T w3;
|
24 |
+
T w4;
|
25 |
+
};
|
26 |
+
|
27 |
+
template <typename T>
|
28 |
+
void pre_calc_for_bilinear_interpolate(
|
29 |
+
const int height, const int width, const int pooled_height,
|
30 |
+
const int pooled_width, const int iy_upper, const int ix_upper,
|
31 |
+
T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
|
32 |
+
int roi_bin_grid_h, int roi_bin_grid_w, T roi_center_h, T roi_center_w,
|
33 |
+
T cos_theta, T sin_theta, std::vector<PreCalc<T>> &pre_calc) {
|
34 |
+
int pre_calc_index = 0;
|
35 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
36 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
37 |
+
for (int iy = 0; iy < iy_upper; iy++) {
|
38 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
39 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
40 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
41 |
+
for (int ix = 0; ix < ix_upper; ix++) {
|
42 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
43 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
44 |
+
static_cast<T>(roi_bin_grid_w);
|
45 |
+
|
46 |
+
// Rotate by theta around the center and translate
|
47 |
+
// In image space, (y, x) is the order for Right Handed System,
|
48 |
+
// and this is essentially multiplying the point by a rotation matrix
|
49 |
+
// to rotate it counterclockwise through angle theta.
|
50 |
+
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
51 |
+
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
52 |
+
// deal with: inverse elements are out of feature map boundary
|
53 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
54 |
+
// empty
|
55 |
+
PreCalc<T> pc;
|
56 |
+
pc.pos1 = 0;
|
57 |
+
pc.pos2 = 0;
|
58 |
+
pc.pos3 = 0;
|
59 |
+
pc.pos4 = 0;
|
60 |
+
pc.w1 = 0;
|
61 |
+
pc.w2 = 0;
|
62 |
+
pc.w3 = 0;
|
63 |
+
pc.w4 = 0;
|
64 |
+
pre_calc[pre_calc_index] = pc;
|
65 |
+
pre_calc_index += 1;
|
66 |
+
continue;
|
67 |
+
}
|
68 |
+
|
69 |
+
if (y < 0) {
|
70 |
+
y = 0;
|
71 |
+
}
|
72 |
+
if (x < 0) {
|
73 |
+
x = 0;
|
74 |
+
}
|
75 |
+
|
76 |
+
int y_low = (int)y;
|
77 |
+
int x_low = (int)x;
|
78 |
+
int y_high;
|
79 |
+
int x_high;
|
80 |
+
|
81 |
+
if (y_low >= height - 1) {
|
82 |
+
y_high = y_low = height - 1;
|
83 |
+
y = (T)y_low;
|
84 |
+
} else {
|
85 |
+
y_high = y_low + 1;
|
86 |
+
}
|
87 |
+
|
88 |
+
if (x_low >= width - 1) {
|
89 |
+
x_high = x_low = width - 1;
|
90 |
+
x = (T)x_low;
|
91 |
+
} else {
|
92 |
+
x_high = x_low + 1;
|
93 |
+
}
|
94 |
+
|
95 |
+
T ly = y - y_low;
|
96 |
+
T lx = x - x_low;
|
97 |
+
T hy = 1. - ly, hx = 1. - lx;
|
98 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
99 |
+
|
100 |
+
// save weights and indices
|
101 |
+
PreCalc<T> pc;
|
102 |
+
pc.pos1 = y_low * width + x_low;
|
103 |
+
pc.pos2 = y_low * width + x_high;
|
104 |
+
pc.pos3 = y_high * width + x_low;
|
105 |
+
pc.pos4 = y_high * width + x_high;
|
106 |
+
pc.w1 = w1;
|
107 |
+
pc.w2 = w2;
|
108 |
+
pc.w3 = w3;
|
109 |
+
pc.w4 = w4;
|
110 |
+
pre_calc[pre_calc_index] = pc;
|
111 |
+
|
112 |
+
pre_calc_index += 1;
|
113 |
+
}
|
114 |
+
}
|
115 |
+
}
|
116 |
+
}
|
117 |
+
}
|
118 |
+
|
119 |
+
template <typename T>
|
120 |
+
void roi_align_rotated_cpu_forward(const int nthreads, const T *input,
|
121 |
+
const T &spatial_scale, const bool aligned,
|
122 |
+
const bool clockwise, const int channels,
|
123 |
+
const int height, const int width,
|
124 |
+
const int pooled_height,
|
125 |
+
const int pooled_width,
|
126 |
+
const int sampling_ratio, const T *rois,
|
127 |
+
T *output) {
|
128 |
+
int n_rois = nthreads / channels / pooled_width / pooled_height;
|
129 |
+
// (n, c, ph, pw) is an element in the pooled output
|
130 |
+
// can be parallelized using omp
|
131 |
+
// #pragma omp parallel for num_threads(32)
|
132 |
+
for (int n = 0; n < n_rois; n++) {
|
133 |
+
int index_n = n * channels * pooled_width * pooled_height;
|
134 |
+
|
135 |
+
const T *current_roi = rois + n * 6;
|
136 |
+
int roi_batch_ind = current_roi[0];
|
137 |
+
|
138 |
+
// Do not use rounding; this implementation detail is critical
|
139 |
+
T offset = aligned ? (T)0.5 : (T)0.0;
|
140 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
141 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
142 |
+
T roi_width = current_roi[3] * spatial_scale;
|
143 |
+
T roi_height = current_roi[4] * spatial_scale;
|
144 |
+
T theta = current_roi[5];
|
145 |
+
if (clockwise) {
|
146 |
+
theta = -theta; // If clockwise, the angle needs to be reversed.
|
147 |
+
}
|
148 |
+
T cos_theta = cos(theta);
|
149 |
+
T sin_theta = sin(theta);
|
150 |
+
|
151 |
+
if (aligned) {
|
152 |
+
assert(roi_width >= 0 && roi_height >= 0);
|
153 |
+
} else { // for backward-compatibility only
|
154 |
+
roi_width = std::max(roi_width, (T)1.);
|
155 |
+
roi_height = std::max(roi_height, (T)1.);
|
156 |
+
}
|
157 |
+
|
158 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
159 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
160 |
+
|
161 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
162 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
163 |
+
? sampling_ratio
|
164 |
+
: ceilf(roi_height / pooled_height); // e.g., = 2
|
165 |
+
int roi_bin_grid_w =
|
166 |
+
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
167 |
+
|
168 |
+
// We do average (integral) pooling inside a bin
|
169 |
+
const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
|
170 |
+
|
171 |
+
// we want to precalculate indices and weights shared by all channels,
|
172 |
+
// this is the key point of optimization
|
173 |
+
std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
|
174 |
+
pooled_width * pooled_height);
|
175 |
+
|
176 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
177 |
+
// Appropriate translation needs to be applied after.
|
178 |
+
T roi_start_h = -roi_height / 2.0;
|
179 |
+
T roi_start_w = -roi_width / 2.0;
|
180 |
+
|
181 |
+
pre_calc_for_bilinear_interpolate(
|
182 |
+
height, width, pooled_height, pooled_width, roi_bin_grid_h,
|
183 |
+
roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
|
184 |
+
roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta,
|
185 |
+
sin_theta, pre_calc);
|
186 |
+
|
187 |
+
for (int c = 0; c < channels; c++) {
|
188 |
+
int index_n_c = index_n + c * pooled_width * pooled_height;
|
189 |
+
const T *offset_input =
|
190 |
+
input + (roi_batch_ind * channels + c) * height * width;
|
191 |
+
int pre_calc_index = 0;
|
192 |
+
|
193 |
+
for (int ph = 0; ph < pooled_height; ph++) {
|
194 |
+
for (int pw = 0; pw < pooled_width; pw++) {
|
195 |
+
int index = index_n_c + ph * pooled_width + pw;
|
196 |
+
|
197 |
+
T output_val = 0.;
|
198 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
199 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
200 |
+
PreCalc<T> pc = pre_calc[pre_calc_index];
|
201 |
+
output_val += pc.w1 * offset_input[pc.pos1] +
|
202 |
+
pc.w2 * offset_input[pc.pos2] +
|
203 |
+
pc.w3 * offset_input[pc.pos3] +
|
204 |
+
pc.w4 * offset_input[pc.pos4];
|
205 |
+
|
206 |
+
pre_calc_index += 1;
|
207 |
+
}
|
208 |
+
}
|
209 |
+
output_val /= count;
|
210 |
+
|
211 |
+
output[index] = output_val;
|
212 |
+
} // for pw
|
213 |
+
} // for ph
|
214 |
+
} // for c
|
215 |
+
} // for n
|
216 |
+
}
|
217 |
+
|
218 |
+
template <typename T>
|
219 |
+
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
|
220 |
+
T &w1, T &w2, T &w3, T &w4, int &x_low,
|
221 |
+
int &x_high, int &y_low, int &y_high) {
|
222 |
+
// deal with cases that inverse elements are out of feature map boundary
|
223 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
224 |
+
// empty
|
225 |
+
w1 = w2 = w3 = w4 = 0.;
|
226 |
+
x_low = x_high = y_low = y_high = -1;
|
227 |
+
return;
|
228 |
+
}
|
229 |
+
|
230 |
+
if (y < 0) {
|
231 |
+
y = 0;
|
232 |
+
}
|
233 |
+
|
234 |
+
if (x < 0) {
|
235 |
+
x = 0;
|
236 |
+
}
|
237 |
+
|
238 |
+
y_low = (int)y;
|
239 |
+
x_low = (int)x;
|
240 |
+
|
241 |
+
if (y_low >= height - 1) {
|
242 |
+
y_high = y_low = height - 1;
|
243 |
+
y = (T)y_low;
|
244 |
+
} else {
|
245 |
+
y_high = y_low + 1;
|
246 |
+
}
|
247 |
+
|
248 |
+
if (x_low >= width - 1) {
|
249 |
+
x_high = x_low = width - 1;
|
250 |
+
x = (T)x_low;
|
251 |
+
} else {
|
252 |
+
x_high = x_low + 1;
|
253 |
+
}
|
254 |
+
|
255 |
+
T ly = y - y_low;
|
256 |
+
T lx = x - x_low;
|
257 |
+
T hy = 1. - ly, hx = 1. - lx;
|
258 |
+
|
259 |
+
// reference in forward
|
260 |
+
// T v1 = input[y_low * width + x_low];
|
261 |
+
// T v2 = input[y_low * width + x_high];
|
262 |
+
// T v3 = input[y_high * width + x_low];
|
263 |
+
// T v4 = input[y_high * width + x_high];
|
264 |
+
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
265 |
+
|
266 |
+
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
267 |
+
|
268 |
+
return;
|
269 |
+
}
|
270 |
+
|
271 |
+
template <class T> inline void add(T *address, const T &val) {
|
272 |
+
*address += val;
|
273 |
+
}
|
274 |
+
|
275 |
+
template <typename T>
|
276 |
+
void roi_align_rotated_cpu_backward(
|
277 |
+
const int nthreads,
|
278 |
+
// may not be contiguous. should index using n_stride, etc
|
279 |
+
const T *grad_output, const T &spatial_scale, const bool aligned,
|
280 |
+
const bool clockwise, const int channels, const int height, const int width,
|
281 |
+
const int pooled_height, const int pooled_width, const int sampling_ratio,
|
282 |
+
T *grad_input, const T *rois, const int n_stride, const int c_stride,
|
283 |
+
const int h_stride, const int w_stride) {
|
284 |
+
for (int index = 0; index < nthreads; index++) {
|
285 |
+
// (n, c, ph, pw) is an element in the pooled output
|
286 |
+
int pw = index % pooled_width;
|
287 |
+
int ph = (index / pooled_width) % pooled_height;
|
288 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
289 |
+
int n = index / pooled_width / pooled_height / channels;
|
290 |
+
|
291 |
+
const T *current_roi = rois + n * 6;
|
292 |
+
int roi_batch_ind = current_roi[0];
|
293 |
+
|
294 |
+
// Do not use rounding; this implementation detail is critical
|
295 |
+
T offset = aligned ? (T)0.5 : (T)0.0;
|
296 |
+
T roi_center_w = current_roi[1] * spatial_scale - offset;
|
297 |
+
T roi_center_h = current_roi[2] * spatial_scale - offset;
|
298 |
+
T roi_width = current_roi[3] * spatial_scale;
|
299 |
+
T roi_height = current_roi[4] * spatial_scale;
|
300 |
+
T theta = current_roi[5];
|
301 |
+
if (clockwise) {
|
302 |
+
theta = -theta; // If clockwise, the angle needs to be reversed.
|
303 |
+
}
|
304 |
+
T cos_theta = cos(theta);
|
305 |
+
T sin_theta = sin(theta);
|
306 |
+
|
307 |
+
if (aligned) {
|
308 |
+
assert(roi_width >= 0 && roi_height >= 0);
|
309 |
+
} else { // for backward-compatibility only
|
310 |
+
roi_width = std::max(roi_width, (T)1.);
|
311 |
+
roi_height = std::max(roi_height, (T)1.);
|
312 |
+
}
|
313 |
+
|
314 |
+
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
315 |
+
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
316 |
+
|
317 |
+
T *offset_grad_input =
|
318 |
+
grad_input + ((roi_batch_ind * channels + c) * height * width);
|
319 |
+
|
320 |
+
int output_offset = n * n_stride + c * c_stride;
|
321 |
+
const T *offset_grad_output = grad_output + output_offset;
|
322 |
+
const T grad_output_this_bin =
|
323 |
+
offset_grad_output[ph * h_stride + pw * w_stride];
|
324 |
+
|
325 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
326 |
+
int roi_bin_grid_h = (sampling_ratio > 0)
|
327 |
+
? sampling_ratio
|
328 |
+
: ceilf(roi_height / pooled_height); // e.g., = 2
|
329 |
+
int roi_bin_grid_w =
|
330 |
+
(sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
331 |
+
|
332 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
333 |
+
// Appropriate translation needs to be applied after.
|
334 |
+
T roi_start_h = -roi_height / 2.0;
|
335 |
+
T roi_start_w = -roi_width / 2.0;
|
336 |
+
|
337 |
+
// We do average (integral) pooling inside a bin
|
338 |
+
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
339 |
+
|
340 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
341 |
+
const T yy = roi_start_h + ph * bin_size_h +
|
342 |
+
static_cast<T>(iy + .5f) * bin_size_h /
|
343 |
+
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
344 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
345 |
+
const T xx = roi_start_w + pw * bin_size_w +
|
346 |
+
static_cast<T>(ix + .5f) * bin_size_w /
|
347 |
+
static_cast<T>(roi_bin_grid_w);
|
348 |
+
|
349 |
+
// Rotate by theta around the center and translate
|
350 |
+
T y = yy * cos_theta - xx * sin_theta + roi_center_h;
|
351 |
+
T x = yy * sin_theta + xx * cos_theta + roi_center_w;
|
352 |
+
|
353 |
+
T w1, w2, w3, w4;
|
354 |
+
int x_low, x_high, y_low, y_high;
|
355 |
+
|
356 |
+
bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
|
357 |
+
x_low, x_high, y_low, y_high);
|
358 |
+
|
359 |
+
T g1 = grad_output_this_bin * w1 / count;
|
360 |
+
T g2 = grad_output_this_bin * w2 / count;
|
361 |
+
T g3 = grad_output_this_bin * w3 / count;
|
362 |
+
T g4 = grad_output_this_bin * w4 / count;
|
363 |
+
|
364 |
+
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
|
365 |
+
// atomic add is not needed for now since it is single threaded
|
366 |
+
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
|
367 |
+
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
|
368 |
+
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
|
369 |
+
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
|
370 |
+
} // if
|
371 |
+
} // ix
|
372 |
+
} // iy
|
373 |
+
} // for
|
374 |
+
} // ROIAlignRotatedBackward
|
375 |
+
|
376 |
+
std::vector<paddle::Tensor>
|
377 |
+
RoIAlignRotatedCPUForward(const paddle::Tensor &input,
|
378 |
+
const paddle::Tensor &rois, int aligned_height,
|
379 |
+
int aligned_width, float spatial_scale,
|
380 |
+
int sampling_ratio, bool aligned, bool clockwise) {
|
381 |
+
CHECK_INPUT_CPU(input);
|
382 |
+
CHECK_INPUT_CPU(rois);
|
383 |
+
|
384 |
+
auto num_rois = rois.shape()[0];
|
385 |
+
|
386 |
+
auto channels = input.shape()[1];
|
387 |
+
auto height = input.shape()[2];
|
388 |
+
auto width = input.shape()[3];
|
389 |
+
|
390 |
+
auto output =
|
391 |
+
paddle::empty({num_rois, channels, aligned_height, aligned_width},
|
392 |
+
input.type(), paddle::CPUPlace());
|
393 |
+
auto output_size = output.numel();
|
394 |
+
|
395 |
+
PD_DISPATCH_FLOATING_TYPES(
|
396 |
+
input.type(), "roi_align_rotated_cpu_forward", ([&] {
|
397 |
+
roi_align_rotated_cpu_forward<data_t>(
|
398 |
+
output_size, input.data<data_t>(),
|
399 |
+
static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
|
400 |
+
height, width, aligned_height, aligned_width, sampling_ratio,
|
401 |
+
rois.data<data_t>(), output.data<data_t>());
|
402 |
+
}));
|
403 |
+
|
404 |
+
return {output};
|
405 |
+
}
|
406 |
+
|
407 |
+
std::vector<paddle::Tensor> RoIAlignRotatedCPUBackward(
|
408 |
+
const paddle::Tensor &input, const paddle::Tensor &rois,
|
409 |
+
const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
|
410 |
+
float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
|
411 |
+
|
412 |
+
auto batch_size = input.shape()[0];
|
413 |
+
auto channels = input.shape()[1];
|
414 |
+
auto height = input.shape()[2];
|
415 |
+
auto width = input.shape()[3];
|
416 |
+
|
417 |
+
auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
|
418 |
+
input.type(), paddle::CPUPlace());
|
419 |
+
|
420 |
+
// get stride values to ensure indexing into gradients is correct.
|
421 |
+
int n_stride = grad_output.shape()[0];
|
422 |
+
int c_stride = grad_output.shape()[1];
|
423 |
+
int h_stride = grad_output.shape()[2];
|
424 |
+
int w_stride = grad_output.shape()[3];
|
425 |
+
|
426 |
+
PD_DISPATCH_FLOATING_TYPES(
|
427 |
+
grad_output.type(), "roi_align_rotated_cpu_backward", [&] {
|
428 |
+
roi_align_rotated_cpu_backward<data_t>(
|
429 |
+
grad_output.numel(), grad_output.data<data_t>(),
|
430 |
+
static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
|
431 |
+
height, width, aligned_height, aligned_width, sampling_ratio,
|
432 |
+
grad_input.data<data_t>(), rois.data<data_t>(), n_stride, c_stride,
|
433 |
+
h_stride, w_stride);
|
434 |
+
});
|
435 |
+
return {grad_input};
|
436 |
+
}
|
437 |
+
|
438 |
+
#ifdef PADDLE_WITH_CUDA
|
439 |
+
std::vector<paddle::Tensor>
|
440 |
+
RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
|
441 |
+
const paddle::Tensor &rois, int aligned_height,
|
442 |
+
int aligned_width, float spatial_scale,
|
443 |
+
int sampling_ratio, bool aligned, bool clockwise);
|
444 |
+
#endif
|
445 |
+
|
446 |
+
#ifdef PADDLE_WITH_CUDA
|
447 |
+
std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
|
448 |
+
const paddle::Tensor &input, const paddle::Tensor &rois,
|
449 |
+
const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
|
450 |
+
float spatial_scale, int sampling_ratio, bool aligned, bool clockwise);
|
451 |
+
#endif
|
452 |
+
|
453 |
+
std::vector<paddle::Tensor>
|
454 |
+
RoIAlignRotatedForward(const paddle::Tensor &input, const paddle::Tensor &rois,
|
455 |
+
int aligned_height, int aligned_width,
|
456 |
+
float spatial_scale, int sampling_ratio, bool aligned,
|
457 |
+
bool clockwise) {
|
458 |
+
CHECK_INPUT_SAME(input, rois);
|
459 |
+
if (input.is_cpu()) {
|
460 |
+
return RoIAlignRotatedCPUForward(input, rois, aligned_height, aligned_width,
|
461 |
+
spatial_scale, sampling_ratio, aligned,
|
462 |
+
clockwise);
|
463 |
+
#ifdef PADDLE_WITH_CUDA
|
464 |
+
} else if (input.is_gpu()) {
|
465 |
+
return RoIAlignRotatedCUDAForward(input, rois, aligned_height,
|
466 |
+
aligned_width, spatial_scale,
|
467 |
+
sampling_ratio, aligned, clockwise);
|
468 |
+
#endif
|
469 |
+
} else {
|
470 |
+
PD_THROW("Unsupported device type for forward function of roi align "
|
471 |
+
"rotated operator.");
|
472 |
+
}
|
473 |
+
}
|
474 |
+
|
475 |
+
std::vector<paddle::Tensor>
|
476 |
+
RoIAlignRotatedBackward(const paddle::Tensor &input, const paddle::Tensor &rois,
|
477 |
+
const paddle::Tensor &grad_output, int aligned_height,
|
478 |
+
int aligned_width, float spatial_scale,
|
479 |
+
int sampling_ratio, bool aligned, bool clockwise) {
|
480 |
+
CHECK_INPUT_SAME(input, rois);
|
481 |
+
if (input.is_cpu()) {
|
482 |
+
return RoIAlignRotatedCPUBackward(input, rois, grad_output, aligned_height,
|
483 |
+
aligned_width, spatial_scale,
|
484 |
+
sampling_ratio, aligned, clockwise);
|
485 |
+
#ifdef PADDLE_WITH_CUDA
|
486 |
+
} else if (input.is_gpu()) {
|
487 |
+
return RoIAlignRotatedCUDABackward(input, rois, grad_output, aligned_height,
|
488 |
+
aligned_width, spatial_scale,
|
489 |
+
sampling_ratio, aligned, clockwise);
|
490 |
+
#endif
|
491 |
+
} else {
|
492 |
+
PD_THROW("Unsupported device type for forward function of roi align "
|
493 |
+
"rotated operator.");
|
494 |
+
}
|
495 |
+
}
|
496 |
+
|
497 |
+
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> input_shape,
|
498 |
+
std::vector<int64_t> rois_shape) {
|
499 |
+
return {{rois_shape[0], input_shape[1], input_shape[2], input_shape[3]}};
|
500 |
+
}
|
501 |
+
|
502 |
+
std::vector<std::vector<int64_t>>
|
503 |
+
InferBackShape(std::vector<int64_t> input_shape,
|
504 |
+
std::vector<int64_t> rois_shape) {
|
505 |
+
return {input_shape};
|
506 |
+
}
|
507 |
+
|
508 |
+
std::vector<paddle::DataType> InferDtype(paddle::DataType input_dtype,
|
509 |
+
paddle::DataType rois_dtype) {
|
510 |
+
return {input_dtype};
|
511 |
+
}
|
512 |
+
|
513 |
+
PD_BUILD_OP(roi_align_rotated)
|
514 |
+
.Inputs({"Input", "Rois"})
|
515 |
+
.Outputs({"Output"})
|
516 |
+
.Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
|
517 |
+
"sampling_ratio: int", "aligned: bool", "clockwise: bool"})
|
518 |
+
.SetKernelFn(PD_KERNEL(RoIAlignRotatedForward))
|
519 |
+
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
|
520 |
+
.SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
|
521 |
+
|
522 |
+
PD_BUILD_GRAD_OP(roi_align_rotated)
|
523 |
+
.Inputs({"Input", "Rois", paddle::Grad("Output")})
|
524 |
+
.Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
|
525 |
+
"sampling_ratio: int", "aligned: bool", "clockwise: bool"})
|
526 |
+
.Outputs({paddle::Grad("Input")})
|
527 |
+
.SetKernelFn(PD_KERNEL(RoIAlignRotatedBackward))
|
528 |
+
.SetInferShapeFn(PD_INFER_SHAPE(InferBackShape));
|
ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
// This code is refer from:
|
3 |
+
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
|
4 |
+
|
5 |
+
#include <cassert>
|
6 |
+
#include <cmath>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
#include "paddle/extension.h"
|
10 |
+
#include <cuda.h>
|
11 |
+
|
12 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
13 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
14 |
+
i += blockDim.x * gridDim.x)
|
15 |
+
|
16 |
+
#define THREADS_PER_BLOCK 512
|
17 |
+
|
18 |
+
inline int GET_BLOCKS(const int N) {
|
19 |
+
int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
20 |
+
int max_block_num = 4096;
|
21 |
+
return min(optimal_block_num, max_block_num);
|
22 |
+
}
|
23 |
+
|
24 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
|
25 |
+
|
26 |
+
static __inline__ __device__ double atomicAdd(double *address, double val) {
|
27 |
+
unsigned long long int *address_as_ull = (unsigned long long int *)address;
|
28 |
+
unsigned long long int old = *address_as_ull, assumed;
|
29 |
+
if (val == 0.0)
|
30 |
+
return __longlong_as_double(old);
|
31 |
+
do {
|
32 |
+
assumed = old;
|
33 |
+
old = atomicCAS(address_as_ull, assumed,
|
34 |
+
__double_as_longlong(val + __longlong_as_double(assumed)));
|
35 |
+
} while (assumed != old);
|
36 |
+
return __longlong_as_double(old);
|
37 |
+
}
|
38 |
+
|
39 |
+
#endif
|
40 |
+
|
41 |
+
template <typename T>
|
42 |
+
__device__ T bilinear_interpolate(const T *input, const int height,
|
43 |
+
const int width, T y, T x,
|
44 |
+
const int index /* index for debug only*/) {
|
45 |
+
// deal with cases that inverse elements are out of feature map boundary
|
46 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width)
|
47 |
+
return 0;
|
48 |
+
|
49 |
+
if (y <= 0)
|
50 |
+
y = 0;
|
51 |
+
if (x <= 0)
|
52 |
+
x = 0;
|
53 |
+
|
54 |
+
int y_low = (int)y;
|
55 |
+
int x_low = (int)x;
|
56 |
+
int y_high;
|
57 |
+
int x_high;
|
58 |
+
|
59 |
+
if (y_low >= height - 1) {
|
60 |
+
y_high = y_low = height - 1;
|
61 |
+
y = (T)y_low;
|
62 |
+
} else {
|
63 |
+
y_high = y_low + 1;
|
64 |
+
}
|
65 |
+
|
66 |
+
if (x_low >= width - 1) {
|
67 |
+
x_high = x_low = width - 1;
|
68 |
+
x = (T)x_low;
|
69 |
+
} else {
|
70 |
+
x_high = x_low + 1;
|
71 |
+
}
|
72 |
+
|
73 |
+
T ly = y - y_low;
|
74 |
+
T lx = x - x_low;
|
75 |
+
T hy = 1. - ly, hx = 1. - lx;
|
76 |
+
// do bilinear interpolation
|
77 |
+
T v1 = input[y_low * width + x_low];
|
78 |
+
T v2 = input[y_low * width + x_high];
|
79 |
+
T v3 = input[y_high * width + x_low];
|
80 |
+
T v4 = input[y_high * width + x_high];
|
81 |
+
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
82 |
+
|
83 |
+
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
84 |
+
|
85 |
+
return val;
|
86 |
+
}
|
87 |
+
|
88 |
+
template <typename T>
|
89 |
+
__device__ void
|
90 |
+
bilinear_interpolate_gradient(const int height, const int width, T y, T x,
|
91 |
+
T &w1, T &w2, T &w3, T &w4, int &x_low,
|
92 |
+
int &x_high, int &y_low, int &y_high,
|
93 |
+
const int index /* index for debug only*/) {
|
94 |
+
// deal with cases that inverse elements are out of feature map boundary
|
95 |
+
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
96 |
+
// empty
|
97 |
+
w1 = w2 = w3 = w4 = 0.;
|
98 |
+
x_low = x_high = y_low = y_high = -1;
|
99 |
+
return;
|
100 |
+
}
|
101 |
+
|
102 |
+
if (y <= 0)
|
103 |
+
y = 0;
|
104 |
+
if (x <= 0)
|
105 |
+
x = 0;
|
106 |
+
|
107 |
+
y_low = (int)y;
|
108 |
+
x_low = (int)x;
|
109 |
+
|
110 |
+
if (y_low >= height - 1) {
|
111 |
+
y_high = y_low = height - 1;
|
112 |
+
y = (T)y_low;
|
113 |
+
} else {
|
114 |
+
y_high = y_low + 1;
|
115 |
+
}
|
116 |
+
|
117 |
+
if (x_low >= width - 1) {
|
118 |
+
x_high = x_low = width - 1;
|
119 |
+
x = (T)x_low;
|
120 |
+
} else {
|
121 |
+
x_high = x_low + 1;
|
122 |
+
}
|
123 |
+
|
124 |
+
T ly = y - y_low;
|
125 |
+
T lx = x - x_low;
|
126 |
+
T hy = 1. - ly, hx = 1. - lx;
|
127 |
+
|
128 |
+
// reference in forward
|
129 |
+
// T v1 = input[y_low * width + x_low];
|
130 |
+
// T v2 = input[y_low * width + x_high];
|
131 |
+
// T v3 = input[y_high * width + x_low];
|
132 |
+
// T v4 = input[y_high * width + x_high];
|
133 |
+
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
134 |
+
|
135 |
+
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
136 |
+
|
137 |
+
return;
|
138 |
+
}
|
139 |
+
|
140 |
+
/*** Forward ***/
|
141 |
+
template <typename scalar_t>
|
142 |
+
__global__ void roi_align_rotated_cuda_forward_kernel(
|
143 |
+
const int nthreads, const scalar_t *bottom_data,
|
144 |
+
const scalar_t *bottom_rois, const scalar_t spatial_scale,
|
145 |
+
const int sample_num, const bool aligned, const bool clockwise,
|
146 |
+
const int channels, const int height, const int width,
|
147 |
+
const int pooled_height, const int pooled_width, scalar_t *top_data) {
|
148 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
149 |
+
// (n, c, ph, pw) is an element in the pooled output
|
150 |
+
int pw = index % pooled_width;
|
151 |
+
int ph = (index / pooled_width) % pooled_height;
|
152 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
153 |
+
int n = index / pooled_width / pooled_height / channels;
|
154 |
+
|
155 |
+
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
|
156 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
157 |
+
|
158 |
+
// Do not using rounding; this implementation detail is critical
|
159 |
+
scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
|
160 |
+
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
|
161 |
+
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
|
162 |
+
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
|
163 |
+
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
|
164 |
+
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
|
165 |
+
scalar_t theta = offset_bottom_rois[5];
|
166 |
+
if (clockwise) {
|
167 |
+
theta = -theta; // If clockwise, the angle needs to be reversed.
|
168 |
+
}
|
169 |
+
if (!aligned) { // for backward-compatibility only
|
170 |
+
// Force malformed ROIs to be 1x1
|
171 |
+
roi_width = max(roi_width, (scalar_t)1.);
|
172 |
+
roi_height = max(roi_height, (scalar_t)1.);
|
173 |
+
}
|
174 |
+
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
|
175 |
+
static_cast<scalar_t>(pooled_height);
|
176 |
+
scalar_t bin_size_w =
|
177 |
+
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
|
178 |
+
|
179 |
+
const scalar_t *offset_bottom_data =
|
180 |
+
bottom_data + (roi_batch_ind * channels + c) * height * width;
|
181 |
+
|
182 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
183 |
+
int roi_bin_grid_h = (sample_num > 0)
|
184 |
+
? sample_num
|
185 |
+
: ceilf(roi_height / pooled_height); // e.g., = 2
|
186 |
+
int roi_bin_grid_w =
|
187 |
+
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
|
188 |
+
|
189 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
190 |
+
// Appropriate translation needs to be applied after.
|
191 |
+
scalar_t roi_start_h = -roi_height / 2.0;
|
192 |
+
scalar_t roi_start_w = -roi_width / 2.0;
|
193 |
+
scalar_t cosscalar_theta = cos(theta);
|
194 |
+
scalar_t sinscalar_theta = sin(theta);
|
195 |
+
|
196 |
+
// We do average (integral) pooling inside a bin
|
197 |
+
const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
|
198 |
+
|
199 |
+
scalar_t output_val = 0.;
|
200 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
|
201 |
+
const scalar_t yy =
|
202 |
+
roi_start_h + ph * bin_size_h +
|
203 |
+
static_cast<scalar_t>(iy + .5f) * bin_size_h /
|
204 |
+
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
205 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
206 |
+
const scalar_t xx = roi_start_w + pw * bin_size_w +
|
207 |
+
static_cast<scalar_t>(ix + .5f) * bin_size_w /
|
208 |
+
static_cast<scalar_t>(roi_bin_grid_w);
|
209 |
+
|
210 |
+
// Rotate by theta (counterclockwise) around the center and translate
|
211 |
+
scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
|
212 |
+
scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
|
213 |
+
|
214 |
+
scalar_t val = bilinear_interpolate<scalar_t>(
|
215 |
+
offset_bottom_data, height, width, y, x, index);
|
216 |
+
output_val += val;
|
217 |
+
}
|
218 |
+
}
|
219 |
+
output_val /= count;
|
220 |
+
|
221 |
+
top_data[index] = output_val;
|
222 |
+
}
|
223 |
+
}
|
224 |
+
|
225 |
+
/*** Backward ***/
|
226 |
+
template <typename scalar_t>
|
227 |
+
__global__ void roi_align_rotated_backward_cuda_kernel(
|
228 |
+
const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
|
229 |
+
const scalar_t spatial_scale, const int sample_num, const bool aligned,
|
230 |
+
const bool clockwise, const int channels, const int height, const int width,
|
231 |
+
const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
|
232 |
+
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
233 |
+
// (n, c, ph, pw) is an element in the pooled output
|
234 |
+
int pw = index % pooled_width;
|
235 |
+
int ph = (index / pooled_width) % pooled_height;
|
236 |
+
int c = (index / pooled_width / pooled_height) % channels;
|
237 |
+
int n = index / pooled_width / pooled_height / channels;
|
238 |
+
|
239 |
+
const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
|
240 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
241 |
+
|
242 |
+
// Do not round
|
243 |
+
scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
|
244 |
+
scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
|
245 |
+
scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
|
246 |
+
scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
|
247 |
+
scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
|
248 |
+
// scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
|
249 |
+
scalar_t theta = offset_bottom_rois[5];
|
250 |
+
if (clockwise) {
|
251 |
+
theta = -theta; // If clockwise, the angle needs to be reversed.
|
252 |
+
}
|
253 |
+
if (!aligned) { // for backward-compatibility only
|
254 |
+
// Force malformed ROIs to be 1x1
|
255 |
+
roi_width = max(roi_width, (scalar_t)1.);
|
256 |
+
roi_height = max(roi_height, (scalar_t)1.);
|
257 |
+
}
|
258 |
+
scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
|
259 |
+
static_cast<scalar_t>(pooled_height);
|
260 |
+
scalar_t bin_size_w =
|
261 |
+
static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
|
262 |
+
|
263 |
+
scalar_t *offset_bottom_diff =
|
264 |
+
bottom_diff + (roi_batch_ind * channels + c) * height * width;
|
265 |
+
|
266 |
+
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
267 |
+
const scalar_t *offset_top_diff = top_diff + top_offset;
|
268 |
+
const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
|
269 |
+
|
270 |
+
// We use roi_bin_grid to sample the grid and mimic integral
|
271 |
+
int roi_bin_grid_h = (sample_num > 0)
|
272 |
+
? sample_num
|
273 |
+
: ceilf(roi_height / pooled_height); // e.g., = 2
|
274 |
+
int roi_bin_grid_w =
|
275 |
+
(sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
|
276 |
+
|
277 |
+
// roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
|
278 |
+
// Appropriate translation needs to be applied after.
|
279 |
+
scalar_t roi_start_h = -roi_height / 2.0;
|
280 |
+
scalar_t roi_start_w = -roi_width / 2.0;
|
281 |
+
scalar_t cosTheta = cos(theta);
|
282 |
+
scalar_t sinTheta = sin(theta);
|
283 |
+
|
284 |
+
// We do average (integral) pooling inside a bin
|
285 |
+
const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
286 |
+
|
287 |
+
for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
|
288 |
+
const scalar_t yy =
|
289 |
+
roi_start_h + ph * bin_size_h +
|
290 |
+
static_cast<scalar_t>(iy + .5f) * bin_size_h /
|
291 |
+
static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
292 |
+
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
293 |
+
const scalar_t xx = roi_start_w + pw * bin_size_w +
|
294 |
+
static_cast<scalar_t>(ix + .5f) * bin_size_w /
|
295 |
+
static_cast<scalar_t>(roi_bin_grid_w);
|
296 |
+
|
297 |
+
// Rotate by theta around the center and translate
|
298 |
+
scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
|
299 |
+
scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
|
300 |
+
|
301 |
+
scalar_t w1, w2, w3, w4;
|
302 |
+
int x_low, x_high, y_low, y_high;
|
303 |
+
|
304 |
+
bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
|
305 |
+
w4, x_low, x_high, y_low,
|
306 |
+
y_high, index);
|
307 |
+
|
308 |
+
scalar_t g1 = top_diff_this_bin * w1 / count;
|
309 |
+
scalar_t g2 = top_diff_this_bin * w2 / count;
|
310 |
+
scalar_t g3 = top_diff_this_bin * w3 / count;
|
311 |
+
scalar_t g4 = top_diff_this_bin * w4 / count;
|
312 |
+
|
313 |
+
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
|
314 |
+
atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
|
315 |
+
atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
|
316 |
+
atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
|
317 |
+
atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
|
318 |
+
} // if
|
319 |
+
} // ix
|
320 |
+
} // iy
|
321 |
+
} // CUDA_1D_KERNEL_LOOP
|
322 |
+
} // RoIAlignBackward
|
323 |
+
|
324 |
+
std::vector<paddle::Tensor>
|
325 |
+
RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
|
326 |
+
const paddle::Tensor &rois, int aligned_height,
|
327 |
+
int aligned_width, float spatial_scale,
|
328 |
+
int sampling_ratio, bool aligned, bool clockwise) {
|
329 |
+
|
330 |
+
auto num_rois = rois.shape()[0];
|
331 |
+
|
332 |
+
auto channels = input.shape()[1];
|
333 |
+
auto height = input.shape()[2];
|
334 |
+
auto width = input.shape()[3];
|
335 |
+
|
336 |
+
auto output =
|
337 |
+
paddle::empty({num_rois, channels, aligned_height, aligned_width},
|
338 |
+
input.type(), paddle::GPUPlace());
|
339 |
+
auto output_size = output.numel();
|
340 |
+
|
341 |
+
PD_DISPATCH_FLOATING_TYPES(
|
342 |
+
input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
|
343 |
+
roi_align_rotated_cuda_forward_kernel<
|
344 |
+
data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
|
345 |
+
output_size, input.data<data_t>(), rois.data<data_t>(),
|
346 |
+
static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
|
347 |
+
clockwise, channels, height, width, aligned_height, aligned_width,
|
348 |
+
output.data<data_t>());
|
349 |
+
}));
|
350 |
+
|
351 |
+
return {output};
|
352 |
+
}
|
353 |
+
|
354 |
+
std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
|
355 |
+
const paddle::Tensor &input, const paddle::Tensor &rois,
|
356 |
+
const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
|
357 |
+
float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
|
358 |
+
|
359 |
+
auto num_rois = rois.shape()[0];
|
360 |
+
|
361 |
+
auto batch_size = input.shape()[0];
|
362 |
+
auto channels = input.shape()[1];
|
363 |
+
auto height = input.shape()[2];
|
364 |
+
auto width = input.shape()[3];
|
365 |
+
|
366 |
+
auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
|
367 |
+
input.type(), paddle::GPUPlace());
|
368 |
+
|
369 |
+
const int output_size = num_rois * aligned_height * aligned_width * channels;
|
370 |
+
|
371 |
+
PD_DISPATCH_FLOATING_TYPES(
|
372 |
+
grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
|
373 |
+
roi_align_rotated_backward_cuda_kernel<
|
374 |
+
data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
|
375 |
+
output_size, grad_output.data<data_t>(), rois.data<data_t>(),
|
376 |
+
spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
|
377 |
+
width, aligned_height, aligned_width, grad_input.data<data_t>());
|
378 |
+
}));
|
379 |
+
return {grad_input};
|
380 |
+
}
|