Danieldu committed on
Commit
a89d9fd
1 Parent(s): 2e90087
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. app.py +56 -0
  3. ppocr/__init__.py +13 -0
  4. ppocr/data/__init__.py +110 -0
  5. ppocr/data/collate_fn.py +118 -0
  6. ppocr/data/imaug/ColorJitter.py +26 -0
  7. ppocr/data/imaug/__init__.py +80 -0
  8. ppocr/data/imaug/abinet_aug.py +458 -0
  9. ppocr/data/imaug/copy_paste.py +174 -0
  10. ppocr/data/imaug/ct_process.py +355 -0
  11. ppocr/data/imaug/drrg_targets.py +696 -0
  12. ppocr/data/imaug/east_process.py +436 -0
  13. ppocr/data/imaug/fce_aug.py +564 -0
  14. ppocr/data/imaug/fce_targets.py +666 -0
  15. ppocr/data/imaug/iaa_augment.py +105 -0
  16. ppocr/data/imaug/label_ops.py +1505 -0
  17. ppocr/data/imaug/make_border_map.py +173 -0
  18. ppocr/data/imaug/make_pse_gt.py +106 -0
  19. ppocr/data/imaug/make_shrink_map.py +123 -0
  20. ppocr/data/imaug/operators.py +524 -0
  21. ppocr/data/imaug/pg_process.py +1034 -0
  22. ppocr/data/imaug/randaugment.py +143 -0
  23. ppocr/data/imaug/random_crop_data.py +234 -0
  24. ppocr/data/imaug/rec_img_aug.py +825 -0
  25. ppocr/data/imaug/sast_process.py +777 -0
  26. ppocr/data/imaug/ssl_img_aug.py +60 -0
  27. ppocr/data/imaug/table_ops.py +229 -0
  28. ppocr/data/imaug/text_image_aug/__init__.py +17 -0
  29. ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-37.pyc +0 -0
  30. ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc +0 -0
  31. ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-37.pyc +0 -0
  32. ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc +0 -0
  33. ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-37.pyc +0 -0
  34. ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc +0 -0
  35. ppocr/data/imaug/text_image_aug/augment.py +120 -0
  36. ppocr/data/imaug/text_image_aug/warp_mls.py +168 -0
  37. ppocr/data/imaug/vqa/__init__.py +20 -0
  38. ppocr/data/imaug/vqa/augment.py +33 -0
  39. ppocr/data/imaug/vqa/token/__init__.py +18 -0
  40. ppocr/data/imaug/vqa/token/vqa_re_convert.py +51 -0
  41. ppocr/data/imaug/vqa/token/vqa_token_chunk.py +122 -0
  42. ppocr/data/imaug/vqa/token/vqa_token_pad.py +104 -0
  43. ppocr/data/imaug/vqa/token/vqa_token_relation.py +67 -0
  44. ppocr/data/lmdb_dataset.py +205 -0
  45. ppocr/data/pgnet_dataset.py +106 -0
  46. ppocr/data/pubtab_dataset.py +133 -0
  47. ppocr/data/simple_dataset.py +151 -0
  48. ppocr/ext_op/__init__.py +1 -0
  49. ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc +528 -0
  50. ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu +380 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ simfang.ttf filter=lfs diff=lfs merge=lfs -text
+ ppocr/utils/dict/confuse.pkl filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,56 @@
+ import os, io
+ from paddleocr import PaddleOCR, draw_ocr
+ from PIL import Image, ImageDraw
+ import gradio as gr
+
+
+ # Set the Hugging Face Hub access token
+ os.environ["HF_TOKEN"] = "TWOCR"
+
+ def inference(img_path):
+
+     ocr = PaddleOCR(
+         rec_char_dict_path='zhtw_common_dict.txt',
+         use_gpu=False,
+         rec_image_shape="3, 48, 320"
+     )
+
+     result = ocr.ocr(img_path)
+
+     for idx in range(len(result)):
+         res = result[idx]
+         for line in res:
+             print(line)
+
+     result = result[0]
+     image = Image.open(img_path).convert('RGB')
+     boxes = [line[0] for line in result]
+     txts = [line[1][0] if line[1] else '' for line in result]  # make sure txts falls back to an empty string when no text is recognized
+     scores = [line[1][1] for line in result]
+     im_show_pil = draw_ocr(image, boxes, txts, scores, font_path="./simfang.ttf")
+
+     return im_show_pil, "\n".join(txts)
+
+ title = "<p style='text-align: center'><a href='https://www.twman.org/AI/CV' target='_blank'>繁體中文醫療診斷書和收據OCR:PaddleOCR</a></p>"
+
+ description = """
+ <p style='text-align: center'><a href="https://blog.twman.org/2023/07/wsl.html" target='_blank'>用PaddleOCR的PPOCRLabel來微調醫療診斷書和收據</a></p><br>
+ <p style='text-align: center'><a href="https://github.com/Deep-Learning-101" target='_blank'>https://github.com/Deep-Learning-101</a></p><br>
+ <p style='text-align: center'><a href="https://github.com/Deep-Learning-101/Computer-Vision-Paper" target='_blank'>https://github.com/Deep-Learning-101/Computer-Vision-Paper</a></p><br>
+ """
+
+
+ css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
+
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type='filepath', label='圖片上傳')],
+     outputs=[
+         gr.outputs.Image(type="pil", label="識別結果"),
+         "text"
+     ],
+     title=title,
+     description=description,
+     css=css,
+     enable_queue=True
+ ).launch(debug=True)
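The recognition flow above can be smoke-tested without the Gradio UI; a minimal sketch, assuming a local placeholder image 'test.jpg' (not part of the commit) next to zhtw_common_dict.txt:

    from paddleocr import PaddleOCR

    # Hypothetical smoke test mirroring the inference() flow above.
    ocr = PaddleOCR(
        rec_char_dict_path='zhtw_common_dict.txt',
        use_gpu=False,
        rec_image_shape="3, 48, 320")
    result = ocr.ocr('test.jpg')  # 'test.jpg' is a placeholder path
    for box, (text, score) in result[0]:
        print(text, score)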
ppocr/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
ppocr/data/__init__.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ import os
+ import sys
+ import numpy as np
+ import skimage
+ import paddle
+ import signal
+ import random
+
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+ import copy
+ from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
+ import paddle.distributed as dist
+
+ from ppocr.data.imaug import transform, create_operators
+ from ppocr.data.simple_dataset import SimpleDataSet
+ from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR
+ from ppocr.data.pgnet_dataset import PGDataSet
+ from ppocr.data.pubtab_dataset import PubTabDataSet
+
+ __all__ = ['build_dataloader', 'transform', 'create_operators']
+
+
+ def term_mp(sig_num, frame):
+     """ kill all child processes
+     """
+     pid = os.getpid()
+     pgid = os.getpgid(os.getpid())
+     print("main proc {} exit, kill process group {}".format(pid, pgid))
+     os.killpg(pgid, signal.SIGKILL)
+
+
+ def build_dataloader(config, mode, device, logger, seed=None):
+     config = copy.deepcopy(config)
+
+     support_dict = [
+         'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet',
+         'LMDBDataSetSR'
+     ]
+     module_name = config[mode]['dataset']['name']
+     assert module_name in support_dict, Exception(
+         'DataSet only support {}'.format(support_dict))
+     assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
+
+     dataset = eval(module_name)(config, mode, logger, seed)
+     loader_config = config[mode]['loader']
+     batch_size = loader_config['batch_size_per_card']
+     drop_last = loader_config['drop_last']
+     shuffle = loader_config['shuffle']
+     num_workers = loader_config['num_workers']
+     if 'use_shared_memory' in loader_config.keys():
+         use_shared_memory = loader_config['use_shared_memory']
+     else:
+         use_shared_memory = True
+
+     if mode == "Train":
+         # Distribute data to multiple cards
+         batch_sampler = DistributedBatchSampler(
+             dataset=dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             drop_last=drop_last)
+     else:
+         # Distribute data to a single card
+         batch_sampler = BatchSampler(
+             dataset=dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             drop_last=drop_last)
+
+     if 'collate_fn' in loader_config:
+         from . import collate_fn
+         collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+     else:
+         collate_fn = None
+     data_loader = DataLoader(
+         dataset=dataset,
+         batch_sampler=batch_sampler,
+         places=device,
+         num_workers=num_workers,
+         return_list=True,
+         use_shared_memory=use_shared_memory,
+         collate_fn=collate_fn)
+
+     # support exit using ctrl+c
+     signal.signal(signal.SIGINT, term_mp)
+     signal.signal(signal.SIGTERM, term_mp)
+
+     return data_loader
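The config layout build_dataloader expects can be read directly off the lookups above; a minimal sketch with illustrative values (not taken from the commit):

    # Hypothetical config; keys mirror the accesses in build_dataloader.
    config = {
        'Train': {
            'dataset': {'name': 'SimpleDataSet'},
            'loader': {
                'batch_size_per_card': 8,
                'drop_last': True,
                'shuffle': True,
                'num_workers': 4,
                # 'use_shared_memory' and 'collate_fn' are optional keys
            },
        },
    }
    # loader = build_dataloader(config, 'Train', device, logger)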
ppocr/data/collate_fn.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import paddle
+ import numbers
+ import numpy as np
+ from collections import defaultdict
+
+
+ class DictCollator(object):
+     """
+     data batch
+     """
+
+     def __call__(self, batch):
+         # todo: support batch operators
+         data_dict = defaultdict(list)
+         to_tensor_keys = []
+         for sample in batch:
+             for k, v in sample.items():
+                 if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+                     if k not in to_tensor_keys:
+                         to_tensor_keys.append(k)
+                 data_dict[k].append(v)
+         for k in to_tensor_keys:
+             data_dict[k] = paddle.to_tensor(data_dict[k])
+         return data_dict
+
+
+ class ListCollator(object):
+     """
+     data batch
+     """
+
+     def __call__(self, batch):
+         # todo: support batch operators
+         data_dict = defaultdict(list)
+         to_tensor_idxs = []
+         for sample in batch:
+             for idx, v in enumerate(sample):
+                 if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+                     if idx not in to_tensor_idxs:
+                         to_tensor_idxs.append(idx)
+                 data_dict[idx].append(v)
+         for idx in to_tensor_idxs:
+             data_dict[idx] = paddle.to_tensor(data_dict[idx])
+         return list(data_dict.values())
+
+
+ class SSLRotateCollate(object):
+     """
+     batch: [
+         [(4*3xH*W), (4,)]
+         [(4*3xH*W), (4,)]
+         ...
+     ]
+     """
+
+     def __call__(self, batch):
+         output = [np.concatenate(d, axis=0) for d in zip(*batch)]
+         return output
+
+
+ class DyMaskCollator(object):
+     """
+     batch: [
+         image [batch_size, channel, maxHinbatch, maxWinbatch]
+         image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
+         label [batch_size, maxLabelLen]
+         label_mask [batch_size, maxLabelLen]
+         ...
+     ]
+     """
+
+     def __call__(self, batch):
+         max_width, max_height, max_length = 0, 0, 0
+         bs, channel = len(batch), batch[0][0].shape[0]
+         proper_items = []
+         for item in batch:
+             if item[0].shape[1] * max_width > 1600 * 320 or \
+                     item[0].shape[2] * max_height > 1600 * 320:
+                 continue
+             max_height = item[0].shape[1] if item[0].shape[1] > max_height else max_height
+             max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
+             max_length = len(item[1]) if len(item[1]) > max_length else max_length
+             proper_items.append(item)
+
+         images = np.zeros(
+             (len(proper_items), channel, max_height, max_width),
+             dtype='float32')
+         image_masks = np.zeros(
+             (len(proper_items), 1, max_height, max_width), dtype='float32')
+         labels = np.zeros((len(proper_items), max_length), dtype='int64')
+         label_masks = np.zeros((len(proper_items), max_length), dtype='int64')
+
+         for i in range(len(proper_items)):
+             _, h, w = proper_items[i][0].shape
+             images[i][:, :h, :w] = proper_items[i][0]
+             image_masks[i][:, :h, :w] = 1
+             l = len(proper_items[i][1])
+             labels[i][:l] = proper_items[i][1]
+             label_masks[i][:l] = 1
+
+         return images, image_masks, labels, label_masks
ppocr/data/imaug/ColorJitter.py ADDED
@@ -0,0 +1,26 @@
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from paddle.vision.transforms import ColorJitter as pp_ColorJitter
+
+ __all__ = ['ColorJitter']
+
+ class ColorJitter(object):
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
+         self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
+
+     def __call__(self, data):
+         image = data['image']
+         image = self.aug(image)
+         data['image'] = image
+         return data
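A minimal sketch of how this wrapper is applied, assuming a data dict that carries an image under the 'image' key (the only key the operator touches); 'sample.jpg' is a placeholder:

    from PIL import Image

    # Hypothetical usage of the ColorJitter operator defined above.
    op = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
    data = {'image': Image.open('sample.jpg')}
    data = op(data)  # data['image'] is now the color-jittered image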
ppocr/data/imaug/__init__.py ADDED
@@ -0,0 +1,80 @@
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ from .iaa_augment import IaaAugment
+ from .make_border_map import MakeBorderMap
+ from .make_shrink_map import MakeShrinkMap
+ from .random_crop_data import EastRandomCropData, RandomCropImgMask
+ from .make_pse_gt import MakePseGt
+
+ from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
+     SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+     ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
+     RFLRecResizeImg, SVTRRecAug
+ from .ssl_img_aug import SSLRotateResize
+ from .randaugment import RandAugment
+ from .copy_paste import CopyPaste
+ from .ColorJitter import ColorJitter
+ from .operators import *
+ from .label_ops import *
+
+ from .east_process import *
+ from .sast_process import *
+ from .pg_process import *
+ from .table_ops import *
+
+ from .vqa import *
+
+ from .fce_aug import *
+ from .fce_targets import FCENetTargets
+ from .ct_process import *
+ from .drrg_targets import DRRGTargets
+
+
+ def transform(data, ops=None):
+     """ transform """
+     if ops is None:
+         ops = []
+     for op in ops:
+         data = op(data)
+         if data is None:
+             return None
+     return data
+
+
+ def create_operators(op_param_list, global_config=None):
+     """
+     create operators based on the config
+
+     Args:
+         params(list): a dict list, used to create some operators
+     """
+     assert isinstance(op_param_list, list), ('operator config should be a list')
+     ops = []
+     for operator in op_param_list:
+         assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
+         op_name = list(operator)[0]
+         param = {} if operator[op_name] is None else operator[op_name]
+         if global_config is not None:
+             param.update(global_config)
+         op = eval(op_name)(**param)
+         ops.append(op)
+     return ops
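create_operators consumes a list of one-key dicts (the YAML convention the assert above enforces) and transform chains the resulting ops; a minimal sketch with illustrative parameters:

    # Hypothetical op list in the expected one-key-per-dict format.
    op_param_list = [
        {'IaaAugment': None},                 # None means no extra args
        {'ColorJitter': {'brightness': 0.5}},
    ]
    ops = create_operators(op_param_list)
    # data = transform({'image': img}, ops)  # returns None if any op rejects the sample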
ppocr/data/imaug/abinet_aug.py ADDED
@@ -0,0 +1,458 @@
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ This code is adapted from:
+ https://github.com/FangShancheng/ABINet/blob/main/transforms.py
+ """
+ import math
+ import numbers
+ import random
+
+ import cv2
+ import numpy as np
+ from paddle.vision.transforms import Compose, ColorJitter
+
+
+ def sample_asym(magnitude, size=None):
+     return np.random.beta(1, 4, size) * magnitude
+
+
+ def sample_sym(magnitude, size=None):
+     return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+
+
+ def sample_uniform(low, high, size=None):
+     return np.random.uniform(low, high, size=size)
+
+
+ def get_interpolation(type='random'):
+     if type == 'random':
+         choice = [
+             cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
+         ]
+         interpolation = choice[random.randint(0, len(choice) - 1)]
+     elif type == 'nearest':
+         interpolation = cv2.INTER_NEAREST
+     elif type == 'linear':
+         interpolation = cv2.INTER_LINEAR
+     elif type == 'cubic':
+         interpolation = cv2.INTER_CUBIC
+     elif type == 'area':
+         interpolation = cv2.INTER_AREA
+     else:
+         raise TypeError(
+             'Interpolation types only nearest, linear, cubic, area are supported!'
+         )
+     return interpolation
+
+
+ class CVRandomRotation(object):
+     def __init__(self, degrees=15):
+         assert isinstance(degrees,
+                           numbers.Number), "degree should be a single number."
+         assert degrees >= 0, "degree must be positive."
+         self.degrees = degrees
+
+     @staticmethod
+     def get_params(degrees):
+         return sample_sym(degrees)
+
+     def __call__(self, img):
+         angle = self.get_params(self.degrees)
+         src_h, src_w = img.shape[:2]
+         M = cv2.getRotationMatrix2D(
+             center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
+         abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
+         dst_w = int(src_h * abs_sin + src_w * abs_cos)
+         dst_h = int(src_h * abs_cos + src_w * abs_sin)
+         M[0, 2] += (dst_w - src_w) / 2
+         M[1, 2] += (dst_h - src_h) / 2
+
+         flags = get_interpolation()
+         return cv2.warpAffine(
+             img,
+             M, (dst_w, dst_h),
+             flags=flags,
+             borderMode=cv2.BORDER_REPLICATE)
+
+
+ class CVRandomAffine(object):
+     def __init__(self, degrees, translate=None, scale=None, shear=None):
+         assert isinstance(degrees,
+                           numbers.Number), "degree should be a single number."
+         assert degrees >= 0, "degree must be positive."
+         self.degrees = degrees
+
+         if translate is not None:
+             assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                 "translate should be a list or tuple and it must be of length 2."
+             for t in translate:
+                 if not (0.0 <= t <= 1.0):
+                     raise ValueError(
+                         "translation values should be between 0 and 1")
+         self.translate = translate
+
+         if scale is not None:
+             assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                 "scale should be a list or tuple and it must be of length 2."
+             for s in scale:
+                 if s <= 0:
+                     raise ValueError("scale values should be positive")
+         self.scale = scale
+
+         if shear is not None:
+             if isinstance(shear, numbers.Number):
+                 if shear < 0:
+                     raise ValueError(
+                         "If shear is a single number, it must be positive.")
+                 self.shear = [shear]
+             else:
+                 assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
+                     "shear should be a list or tuple and it must be of length 2."
+                 self.shear = shear
+         else:
+             self.shear = shear
+
+     def _get_inverse_affine_matrix(self, center, angle, translate, scale,
+                                    shear):
+         # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
+         from numpy import sin, cos, tan
+
+         if isinstance(shear, numbers.Number):
+             shear = [shear, 0]
+
+         if not isinstance(shear, (tuple, list)) and len(shear) == 2:
+             raise ValueError(
+                 "Shear should be a single value or a tuple/list containing " +
+                 "two values. Got {}".format(shear))
+
+         rot = math.radians(angle)
+         sx, sy = [math.radians(s) for s in shear]
+
+         cx, cy = center
+         tx, ty = translate
+
+         # RSS without scaling
+         a = cos(rot - sy) / cos(sy)
+         b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
+         c = sin(rot - sy) / cos(sy)
+         d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+
+         # Inverted rotation matrix with scale and shear
+         # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+         M = [d, -b, 0, -c, a, 0]
+         M = [x / scale for x in M]
+
+         # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+         M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
+         M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+
+         # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+         M[2] += cx
+         M[5] += cy
+         return M
+
+     @staticmethod
+     def get_params(degrees, translate, scale_ranges, shears, height):
+         angle = sample_sym(degrees)
+         if translate is not None:
+             max_dx = translate[0] * height
+             max_dy = translate[1] * height
+             translations = (np.round(sample_sym(max_dx)),
+                             np.round(sample_sym(max_dy)))
+         else:
+             translations = (0, 0)
+
+         if scale_ranges is not None:
+             scale = sample_uniform(scale_ranges[0], scale_ranges[1])
+         else:
+             scale = 1.0
+
+         if shears is not None:
+             if len(shears) == 1:
+                 shear = [sample_sym(shears[0]), 0.]
+             elif len(shears) == 2:
+                 shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+         else:
+             shear = 0.0
+
+         return angle, translations, scale, shear
+
+     def __call__(self, img):
+         src_h, src_w = img.shape[:2]
+         angle, translate, scale, shear = self.get_params(
+             self.degrees, self.translate, self.scale, self.shear, src_h)
+
+         M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
+                                             (0, 0), scale, shear)
+         M = np.array(M).reshape(2, 3)
+
+         startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
+                        (0, src_h - 1)]
+         project = lambda x, y, a, b, c: int(a * x + b * y + c)
+         endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
+                      for x, y in startpoints]
+
+         rect = cv2.minAreaRect(np.array(endpoints))
+         bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+         max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+         min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+
+         dst_w = int(max_x - min_x)
+         dst_h = int(max_y - min_y)
+         M[0, 2] += (dst_w - src_w) / 2
+         M[1, 2] += (dst_h - src_h) / 2
+
+         # add translate
+         dst_w += int(abs(translate[0]))
+         dst_h += int(abs(translate[1]))
+         if translate[0] < 0: M[0, 2] += abs(translate[0])
+         if translate[1] < 0: M[1, 2] += abs(translate[1])
+
+         flags = get_interpolation()
+         return cv2.warpAffine(
+             img,
+             M, (dst_w, dst_h),
+             flags=flags,
+             borderMode=cv2.BORDER_REPLICATE)
+
+
+ class CVRandomPerspective(object):
+     def __init__(self, distortion=0.5):
+         self.distortion = distortion
+
+     def get_params(self, width, height, distortion):
+         offset_h = sample_asym(
+             distortion * height / 2, size=4).astype(dtype=np.int)
+         offset_w = sample_asym(
+             distortion * width / 2, size=4).astype(dtype=np.int)
+         topleft = (offset_w[0], offset_h[0])
+         topright = (width - 1 - offset_w[1], offset_h[1])
+         botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+         botleft = (offset_w[3], height - 1 - offset_h[3])
+
+         startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
+                        (0, height - 1)]
+         endpoints = [topleft, topright, botright, botleft]
+         return np.array(
+             startpoints, dtype=np.float32), np.array(
+                 endpoints, dtype=np.float32)
+
+     def __call__(self, img):
+         height, width = img.shape[:2]
+         startpoints, endpoints = self.get_params(width, height, self.distortion)
+         M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+         # TODO: more robust way to crop image
+         rect = cv2.minAreaRect(endpoints)
+         bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+         max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+         min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+         min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+         flags = get_interpolation()
+         img = cv2.warpPerspective(
+             img,
+             M, (max_x, max_y),
+             flags=flags,
+             borderMode=cv2.BORDER_REPLICATE)
+         img = img[min_y:, min_x:]
+         return img
+
+
+ class CVRescale(object):
+     def __init__(self, factor=4, base_size=(128, 512)):
+         """ Define image scales using a gaussian pyramid and rescale the image to the target scale.
+
+         Args:
+             factor: the decay factor from the base size; factor=4 keeps the target scale by default.
+             base_size: base size used to build the bottom layer of the pyramid
+         """
+         if isinstance(factor, numbers.Number):
+             self.factor = round(sample_uniform(0, factor))
+         elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+             self.factor = round(sample_uniform(factor[0], factor[1]))
+         else:
+             raise Exception('factor must be number or list with length 2')
+         # assert factor is valid
+         self.base_h, self.base_w = base_size[:2]
+
+     def __call__(self, img):
+         if self.factor == 0: return img
+         src_h, src_w = img.shape[:2]
+         cur_w, cur_h = self.base_w, self.base_h
+         scale_img = cv2.resize(
+             img, (cur_w, cur_h), interpolation=get_interpolation())
+         for _ in range(self.factor):
+             scale_img = cv2.pyrDown(scale_img)
+         scale_img = cv2.resize(
+             scale_img, (src_w, src_h), interpolation=get_interpolation())
+         return scale_img
+
+
+ class CVGaussianNoise(object):
+     def __init__(self, mean=0, var=20):
+         self.mean = mean
+         if isinstance(var, numbers.Number):
+             self.var = max(int(sample_asym(var)), 1)
+         elif isinstance(var, (tuple, list)) and len(var) == 2:
+             self.var = int(sample_uniform(var[0], var[1]))
+         else:
+             raise Exception('degree must be number or list with length 2')
+
+     def __call__(self, img):
+         noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+         img = np.clip(img + noise, 0, 255).astype(np.uint8)
+         return img
+
+
+ class CVMotionBlur(object):
+     def __init__(self, degrees=12, angle=90):
+         if isinstance(degrees, numbers.Number):
+             self.degree = max(int(sample_asym(degrees)), 1)
+         elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+             self.degree = int(sample_uniform(degrees[0], degrees[1]))
+         else:
+             raise Exception('degree must be number or list with length 2')
+         self.angle = sample_uniform(-angle, angle)
+
+     def __call__(self, img):
+         M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
+                                     self.angle, 1)
+         motion_blur_kernel = np.zeros((self.degree, self.degree))
+         motion_blur_kernel[self.degree // 2, :] = 1
+         motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
+                                             (self.degree, self.degree))
+         motion_blur_kernel = motion_blur_kernel / self.degree
+         img = cv2.filter2D(img, -1, motion_blur_kernel)
+         img = np.clip(img, 0, 255).astype(np.uint8)
+         return img
+
+
+ class CVGeometry(object):
+     def __init__(self,
+                  degrees=15,
+                  translate=(0.3, 0.3),
+                  scale=(0.5, 2.),
+                  shear=(45, 15),
+                  distortion=0.5,
+                  p=0.5):
+         self.p = p
+         type_p = random.random()
+         if type_p < 0.33:
+             self.transforms = CVRandomRotation(degrees=degrees)
+         elif type_p < 0.66:
+             self.transforms = CVRandomAffine(
+                 degrees=degrees, translate=translate, scale=scale, shear=shear)
+         else:
+             self.transforms = CVRandomPerspective(distortion=distortion)
+
+     def __call__(self, img):
+         if random.random() < self.p:
+             return self.transforms(img)
+         else:
+             return img
+
+
+ class CVDeterioration(object):
+     def __init__(self, var, degrees, factor, p=0.5):
+         self.p = p
+         transforms = []
+         if var is not None:
+             transforms.append(CVGaussianNoise(var=var))
+         if degrees is not None:
+             transforms.append(CVMotionBlur(degrees=degrees))
+         if factor is not None:
+             transforms.append(CVRescale(factor=factor))
+
+         random.shuffle(transforms)
+         transforms = Compose(transforms)
+         self.transforms = transforms
+
+     def __call__(self, img):
+         if random.random() < self.p:
+             return self.transforms(img)
+         else:
+             return img
+
+
+ class CVColorJitter(object):
+     def __init__(self,
+                  brightness=0.5,
+                  contrast=0.5,
+                  saturation=0.5,
+                  hue=0.1,
+                  p=0.5):
+         self.p = p
+         self.transforms = ColorJitter(
+             brightness=brightness,
+             contrast=contrast,
+             saturation=saturation,
+             hue=hue)
+
+     def __call__(self, img):
+         if random.random() < self.p: return self.transforms(img)
+         else: return img
+
+
+ class SVTRDeterioration(object):
+     def __init__(self, var, degrees, factor, p=0.5):
+         self.p = p
+         transforms = []
+         if var is not None:
+             transforms.append(CVGaussianNoise(var=var))
+         if degrees is not None:
+             transforms.append(CVMotionBlur(degrees=degrees))
+         if factor is not None:
+             transforms.append(CVRescale(factor=factor))
+         self.transforms = transforms
+
+     def __call__(self, img):
+         if random.random() < self.p:
+             random.shuffle(self.transforms)
+             transforms = Compose(self.transforms)
+             return transforms(img)
+         else:
+             return img
+
+
+ class SVTRGeometry(object):
+     def __init__(self,
+                  aug_type=0,
+                  degrees=15,
+                  translate=(0.3, 0.3),
+                  scale=(0.5, 2.),
+                  shear=(45, 15),
+                  distortion=0.5,
+                  p=0.5):
+         self.aug_type = aug_type
+         self.p = p
+         self.transforms = []
+         self.transforms.append(CVRandomRotation(degrees=degrees))
+         self.transforms.append(CVRandomAffine(
+             degrees=degrees, translate=translate, scale=scale, shear=shear))
+         self.transforms.append(CVRandomPerspective(distortion=distortion))
+
+     def __call__(self, img):
+         if random.random() < self.p:
+             if self.aug_type:
+                 random.shuffle(self.transforms)
+                 transforms = Compose(self.transforms[:random.randint(1, 3)])
+                 img = transforms(img)
+             else:
+                 img = self.transforms[random.randint(0, 2)](img)
+             return img
+         else:
+             return img
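A minimal sketch of composing these augmenters ABINet-style (parameter values and the input image are illustrative, not taken from the commit):

    import numpy as np

    # Hypothetical composition; each op takes and returns an HxWxC uint8 array.
    geometry = CVGeometry(degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
                          shear=(45, 15), distortion=0.5, p=0.5)
    deterioration = CVDeterioration(var=20, degrees=6, factor=4, p=0.25)
    color = CVColorJitter(p=0.5)
    img = np.random.randint(0, 256, (32, 128, 3), dtype=np.uint8)  # placeholder
    for aug in (geometry, deterioration, color):
        img = aug(img)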
ppocr/data/imaug/copy_paste.py ADDED
@@ -0,0 +1,174 @@
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import copy
+ import cv2
+ import random
+ import numpy as np
+ from PIL import Image
+ from shapely.geometry import Polygon
+
+ from ppocr.data.imaug.iaa_augment import IaaAugment
+ from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
+ from tools.infer.utility import get_rotate_crop_image
+
+
+ class CopyPaste(object):
+     def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
+         self.ext_data_num = 1
+         self.objects_paste_ratio = objects_paste_ratio
+         self.limit_paste = limit_paste
+         augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
+         self.aug = IaaAugment(augmenter_args)
+
+     def __call__(self, data):
+         point_num = data['polys'].shape[1]
+         src_img = data['image']
+         src_polys = data['polys'].tolist()
+         src_texts = data['texts']
+         src_ignores = data['ignore_tags'].tolist()
+         ext_data = data['ext_data'][0]
+         ext_image = ext_data['image']
+         ext_polys = ext_data['polys']
+         ext_texts = ext_data['texts']
+         ext_ignores = ext_data['ignore_tags']
+
+         indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
+         select_num = max(
+             1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
+
+         random.shuffle(indexs)
+         select_idxs = indexs[:select_num]
+         select_polys = ext_polys[select_idxs]
+         select_ignores = ext_ignores[select_idxs]
+
+         src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+         ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
+         src_img = Image.fromarray(src_img).convert('RGBA')
+         for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
+             box_img = get_rotate_crop_image(ext_image, poly)
+
+             src_img, box = self.paste_img(src_img, box_img, src_polys)
+             if box is not None:
+                 box = box.tolist()
+                 for _ in range(len(box), point_num):
+                     box.append(box[-1])
+                 src_polys.append(box)
+                 src_texts.append(ext_texts[idx])
+                 src_ignores.append(tag)
+         src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
+         h, w = src_img.shape[:2]
+         src_polys = np.array(src_polys)
+         src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
+         src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
+         data['image'] = src_img
+         data['polys'] = src_polys
+         data['texts'] = src_texts
+         data['ignore_tags'] = np.array(src_ignores)
+         return data
+
+     def paste_img(self, src_img, box_img, src_polys):
+         box_img_pil = Image.fromarray(box_img).convert('RGBA')
+         src_w, src_h = src_img.size
+         box_w, box_h = box_img_pil.size
+
+         angle = np.random.randint(0, 360)
+         box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
+         box = rotate_bbox(box_img, box, angle)[0]
+         box_img_pil = box_img_pil.rotate(angle, expand=1)
+         box_w, box_h = box_img_pil.width, box_img_pil.height
+         if src_w - box_w < 0 or src_h - box_h < 0:
+             return src_img, None
+
+         paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
+                                              src_h - box_h)
+         if paste_x is None:
+             return src_img, None
+         box[:, 0] += paste_x
+         box[:, 1] += paste_y
+         r, g, b, A = box_img_pil.split()
+         src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
+
+         return src_img, box
+
+     def select_coord(self, src_polys, box, endx, endy):
+         if self.limit_paste:
+             xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(), \
+                 box[:, 0].max(), box[:, 1].max()
+             for _ in range(50):
+                 paste_x = random.randint(0, endx)
+                 paste_y = random.randint(0, endy)
+                 xmin1 = xmin + paste_x
+                 xmax1 = xmax + paste_x
+                 ymin1 = ymin + paste_y
+                 ymax1 = ymax + paste_y
+
+                 num_poly_in_rect = 0
+                 for poly in src_polys:
+                     if not is_poly_outside_rect(poly, xmin1, ymin1,
+                                                 xmax1 - xmin1, ymax1 - ymin1):
+                         num_poly_in_rect += 1
+                         break
+                 if num_poly_in_rect == 0:
+                     return paste_x, paste_y
+             return None, None
+         else:
+             paste_x = random.randint(0, endx)
+             paste_y = random.randint(0, endy)
+             return paste_x, paste_y
+
+
+ def get_union(pD, pG):
+     return Polygon(pD).union(Polygon(pG)).area
+
+
+ def get_intersection_over_union(pD, pG):
+     return get_intersection(pD, pG) / get_union(pD, pG)
+
+
+ def get_intersection(pD, pG):
+     return Polygon(pD).intersection(Polygon(pG)).area
+
+
+ def rotate_bbox(img, text_polys, angle, scale=1):
+     """
+     from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
+     Args:
+         img: np.ndarray
+         text_polys: np.ndarray N*4*2
+         angle: int
+         scale: int
+
+     Returns:
+
+     """
+     w = img.shape[1]
+     h = img.shape[0]
+
+     rangle = np.deg2rad(angle)
+     nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
+     nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+     rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
+     rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
+     rot_mat[0, 2] += rot_move[0]
+     rot_mat[1, 2] += rot_move[1]
+
+     # ---------------------- rotate box ----------------------
+     rot_text_polys = list()
+     for bbox in text_polys:
+         point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
+         point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
+         point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
+         point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
+         rot_text_polys.append([point1, point2, point3, point4])
+     return np.array(rot_text_polys, dtype=np.float32)
ppocr/data/imaug/ct_process.py ADDED
@@ -0,0 +1,355 @@
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import cv2
+ import random
+ import pyclipper
+ import paddle
+
+ import numpy as np
+ import Polygon as plg
+ import scipy.io as scio
+
+ from PIL import Image
+ import paddle.vision.transforms as transforms
+
+
+ class RandomScale():
+     def __init__(self, short_size=640, **kwargs):
+         self.short_size = short_size
+
+     def scale_aligned(self, img, scale):
+         oh, ow = img.shape[0:2]
+         h = int(oh * scale + 0.5)
+         w = int(ow * scale + 0.5)
+         if h % 32 != 0:
+             h = h + (32 - h % 32)
+         if w % 32 != 0:
+             w = w + (32 - w % 32)
+         img = cv2.resize(img, dsize=(w, h))
+         factor_h = h / oh
+         factor_w = w / ow
+         return img, factor_h, factor_w
+
+     def __call__(self, data):
+         img = data['image']
+
+         h, w = img.shape[0:2]
+         random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
+         scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
+         img, factor_h, factor_w = self.scale_aligned(img, scale)
+
+         data['scale_factor'] = (factor_w, factor_h)
+         data['image'] = img
+         return data
+
+
+ class MakeShrink():
+     def __init__(self, kernel_scale=0.7, **kwargs):
+         self.kernel_scale = kernel_scale
+
+     def dist(self, a, b):
+         return np.linalg.norm((a - b), ord=2, axis=0)
+
+     def perimeter(self, bbox):
+         peri = 0.0
+         for i in range(bbox.shape[0]):
+             peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
+         return peri
+
+     def shrink(self, bboxes, rate, max_shr=20):
+         rate = rate * rate
+         shrinked_bboxes = []
+         for bbox in bboxes:
+             area = plg.Polygon(bbox).area()
+             peri = self.perimeter(bbox)
+
+             try:
+                 pco = pyclipper.PyclipperOffset()
+                 pco.AddPath(bbox, pyclipper.JT_ROUND,
+                             pyclipper.ET_CLOSEDPOLYGON)
+                 offset = min(
+                     int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
+
+                 shrinked_bbox = pco.Execute(-offset)
+                 if len(shrinked_bbox) == 0:
+                     shrinked_bboxes.append(bbox)
+                     continue
+
+                 shrinked_bbox = np.array(shrinked_bbox[0])
+                 if shrinked_bbox.shape[0] <= 2:
+                     shrinked_bboxes.append(bbox)
+                     continue
+
+                 shrinked_bboxes.append(shrinked_bbox)
+             except Exception as e:
+                 shrinked_bboxes.append(bbox)
+
+         return shrinked_bboxes
+
+     def __call__(self, data):
+         img = data['image']
+         bboxes = data['polys']
+         words = data['texts']
+         scale_factor = data['scale_factor']
+
+         gt_instance = np.zeros(img.shape[0:2], dtype='uint8')  # h,w
+         training_mask = np.ones(img.shape[0:2], dtype='uint8')
+         training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
+
+         for i in range(len(bboxes)):
+             bboxes[i] = np.reshape(bboxes[i] * (
+                 [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
+                 (bboxes[i].shape[0] // 2, 2)).astype('int32')
+
+         for i in range(len(bboxes)):
+             # different value for different bbox
+             cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
+
+             # set training mask to 0
+             cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
+
+             # for inaccurate annotations, use training_mask_distance
+             if words[i] == '###' or words[i] == '???':
+                 cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
+
+         # make shrink
+         gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
+         kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
+         for i in range(len(bboxes)):
+             cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
+                              -1)
+
+             # for the training mask, kernel and background = 1, box region = 0
+             if words[i] != '###' and words[i] != '???':
+                 cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
+
+         gt_kernel = gt_kernel_instance.copy()
+         # for gt_kernel, kernel = 1
+         gt_kernel[gt_kernel > 0] = 1
+
+         # shrink 2 times
+         tmp1 = gt_kernel_instance.copy()
+         erode_kernel = np.ones((3, 3), np.uint8)
+         tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1)
+         tmp2 = tmp1.copy()
+         tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1)
+
+         # compute text region
+         gt_kernel_inner = tmp1 - tmp2
+
+         # gt_instance: text instance, bg=0, different words use different values
+         # training_mask: text instance mask, word=0, kernel and bg=1
+         # gt_kernel_instance: text kernel instance, bg=0, different words use different values
+         # gt_kernel: text kernel, bg=0, different words use the same value
+         # gt_kernel_inner: text kernel reference
+         # training_mask_distance: word without annotation = 0, else 1
+
+         data['image'] = [
+             img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
+             gt_kernel_inner, training_mask_distance
+         ]
+         return data
+
+
+ class GroupRandomHorizontalFlip():
+     def __init__(self, p=0.5, **kwargs):
+         self.p = p
+
+     def __call__(self, data):
+         imgs = data['image']
+
+         if random.random() < self.p:
+             for i in range(len(imgs)):
+                 imgs[i] = np.flip(imgs[i], axis=1).copy()
+         data['image'] = imgs
+         return data
+
+
+ class GroupRandomRotate():
+     def __init__(self, **kwargs):
+         pass
+
+     def __call__(self, data):
+         imgs = data['image']
+
+         max_angle = 10
+         angle = random.random() * 2 * max_angle - max_angle
+         for i in range(len(imgs)):
+             img = imgs[i]
+             w, h = img.shape[:2]
+             rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
+             img_rotation = cv2.warpAffine(
+                 img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
+             imgs[i] = img_rotation
+
+         data['image'] = imgs
+         return data
+
+
+ class GroupRandomCropPadding():
+     def __init__(self, target_size=(640, 640), **kwargs):
+         self.target_size = target_size
+
+     def __call__(self, data):
+         imgs = data['image']
+
+         h, w = imgs[0].shape[0:2]
+         t_w, t_h = self.target_size
+         p_w, p_h = self.target_size
+         if w == t_w and h == t_h:
+             return data
+
+         t_h = t_h if t_h < h else h
+         t_w = t_w if t_w < w else w
+
+         if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
+             # make sure to crop the text region
+             tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
+             tl[tl < 0] = 0
+             br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
+             br[br < 0] = 0
+             br[0] = min(br[0], h - t_h)
+             br[1] = min(br[1], w - t_w)
+
+             i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+             j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
+         else:
+             i = random.randint(0, h - t_h) if h - t_h > 0 else 0
+             j = random.randint(0, w - t_w) if w - t_w > 0 else 0
+
+         n_imgs = []
+         for idx in range(len(imgs)):
+             if len(imgs[idx].shape) == 3:
+                 s3_length = int(imgs[idx].shape[-1])
+                 img = imgs[idx][i:i + t_h, j:j + t_w, :]
+                 img_p = cv2.copyMakeBorder(
+                     img,
+                     0,
+                     p_h - t_h,
+                     0,
+                     p_w - t_w,
+                     borderType=cv2.BORDER_CONSTANT,
+                     value=tuple(0 for i in range(s3_length)))
+             else:
+                 img = imgs[idx][i:i + t_h, j:j + t_w]
+                 img_p = cv2.copyMakeBorder(
+                     img,
+                     0,
+                     p_h - t_h,
+                     0,
+                     p_w - t_w,
+                     borderType=cv2.BORDER_CONSTANT,
+                     value=(0, ))
+             n_imgs.append(img_p)
+
+         data['image'] = n_imgs
+         return data
+
+
+ class MakeCentripetalShift():
+     def __init__(self, **kwargs):
+         pass
+
+     def jaccard(self, As, Bs):
+         A = As.shape[0]  # small
+         B = Bs.shape[0]  # large
+
+         dis = np.sqrt(
+             np.sum((As[:, np.newaxis, :].repeat(
+                 B, axis=1) - Bs[np.newaxis, :, :].repeat(
+                     A, axis=0))**2,
+                    axis=-1))
+
+         ind = np.argmin(dis, axis=-1)
+
+         return ind
+
+     def __call__(self, data):
+         imgs = data['image']
+
+         img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
+             imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
+
+         max_instance = np.max(gt_instance)  # num bbox
+
+         # make centripetal shift
+         gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
+         for i in range(1, max_instance + 1):
+             # kernel_reference
+             ind = (gt_kernel_inner == i)
+
+             if np.sum(ind) == 0:
+                 training_mask[gt_instance == i] = 0
+                 training_mask_distance[gt_instance == i] = 0
+                 continue
+
+             kpoints = np.array(np.where(ind)).transpose(
+                 (1, 0))[:, ::-1].astype('float32')
+
+             ind = (gt_instance == i) * (gt_kernel_instance == 0)
+             if np.sum(ind) == 0:
+                 continue
+             pixels = np.where(ind)
+
+             points = np.array(pixels).transpose(
+                 (1, 0))[:, ::-1].astype('float32')
+
+             bbox_ind = self.jaccard(points, kpoints)
+
+             offset_gt = kpoints[bbox_ind] - points
+
+             gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
+
+         img = Image.fromarray(img)
+         img = img.convert('RGB')
+
+         data["image"] = img
+         data["gt_kernel"] = gt_kernel.astype("int64")
+         data["training_mask"] = training_mask.astype("int64")
+         data["gt_instance"] = gt_instance.astype("int64")
+         data["gt_kernel_instance"] = gt_kernel_instance.astype("int64")
+         data["training_mask_distance"] = training_mask_distance.astype("int64")
+         data["gt_distance"] = gt_distance.astype("float32")
+
+         return data
+
+
+ class ScaleAlignedShort():
+     def __init__(self, short_size=640, **kwargs):
+         self.short_size = short_size
+
+     def __call__(self, data):
+         img = data['image']
+
+         org_img_shape = img.shape
+
+         h, w = img.shape[0:2]
+         scale = self.short_size * 1.0 / min(h, w)
+         h = int(h * scale + 0.5)
+         w = int(w * scale + 0.5)
+         if h % 32 != 0:
+             h = h + (32 - h % 32)
+         if w % 32 != 0:
+             w = w + (32 - w % 32)
+         img = cv2.resize(img, dsize=(w, h))
+
+         new_img_shape = img.shape
+         img_shape = np.array(org_img_shape + new_img_shape)
+
+         data['shape'] = img_shape
+         data['image'] = img
+
+         return data
ppocr/data/imaug/drrg_targets.py ADDED
@@ -0,0 +1,696 @@
+ # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ This code is adapted from:
+ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
+ """
+
+ import cv2
+ import numpy as np
+ from lanms import merge_quadrangle_n9 as la_nms
+ from numpy.linalg import norm
+
+
+ class DRRGTargets(object):
+     def __init__(self,
+                  orientation_thr=2.0,
+                  resample_step=8.0,
+                  num_min_comps=9,
+                  num_max_comps=600,
+                  min_width=8.0,
+                  max_width=24.0,
+                  center_region_shrink_ratio=0.3,
+                  comp_shrink_ratio=1.0,
+                  comp_w_h_ratio=0.3,
+                  text_comp_nms_thr=0.25,
+                  min_rand_half_height=8.0,
+                  max_rand_half_height=24.0,
+                  jitter_level=0.2,
+                  **kwargs):
+
+         super().__init__()
+         self.orientation_thr = orientation_thr
+         self.resample_step = resample_step
+         self.num_max_comps = num_max_comps
+         self.num_min_comps = num_min_comps
+         self.min_width = min_width
+         self.max_width = max_width
+         self.center_region_shrink_ratio = center_region_shrink_ratio
+         self.comp_shrink_ratio = comp_shrink_ratio
+         self.comp_w_h_ratio = comp_w_h_ratio
+         self.text_comp_nms_thr = text_comp_nms_thr
+         self.min_rand_half_height = min_rand_half_height
+         self.max_rand_half_height = max_rand_half_height
+         self.jitter_level = jitter_level
+         self.eps = 1e-8
+
+     def vector_angle(self, vec1, vec2):
+         if vec1.ndim > 1:
+             unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps).reshape((-1, 1))
+         else:
+             unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps)
+         if vec2.ndim > 1:
+             unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
+         else:
+             unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
+         return np.arccos(
+             np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+     def vector_slope(self, vec):
+         assert len(vec) == 2
+         return abs(vec[1] / (vec[0] + self.eps))
+
+     def vector_sin(self, vec):
+         assert len(vec) == 2
+         return vec[1] / (norm(vec) + self.eps)
+
+     def vector_cos(self, vec):
+         assert len(vec) == 2
+         return vec[0] / (norm(vec) + self.eps)
+
+     def find_head_tail(self, points, orientation_thr):
+
+         assert points.ndim == 2
+         assert points.shape[0] >= 4
+         assert points.shape[1] == 2
+         assert isinstance(orientation_thr, float)
+
+         if len(points) > 4:
+             pad_points = np.vstack([points, points[0]])
+             edge_vec = pad_points[1:] - pad_points[:-1]
+
+             theta_sum = []
+             adjacent_vec_theta = []
+             for i, edge_vec1 in enumerate(edge_vec):
+                 adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                 adjacent_edge_vec = edge_vec[adjacent_ind]
+                 temp_theta_sum = np.sum(
+                     self.vector_angle(edge_vec1, adjacent_edge_vec))
+                 temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                         adjacent_edge_vec[1])
+                 theta_sum.append(temp_theta_sum)
+                 adjacent_vec_theta.append(temp_adjacent_theta)
+             theta_sum_score = np.array(theta_sum) / np.pi
+             adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+             poly_center = np.mean(points, axis=0)
+             edge_dist = np.maximum(
+                 norm(pad_points[1:] - poly_center, axis=-1),
+                 norm(pad_points[:-1] - poly_center, axis=-1))
+             dist_score = edge_dist / (np.max(edge_dist) + self.eps)
+             position_score = np.zeros(len(edge_vec))
+             score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+             score += 0.35 * dist_score
+             if len(points) % 2 == 0:
+                 position_score[(len(score) // 2 - 1)] += 1
+                 position_score[-1] += 1
+             score += 0.1 * position_score
+             pad_score = np.concatenate([score, score])
+             score_matrix = np.zeros((len(score), len(score) - 3))
+             x = np.arange(len(score) - 3) / float(len(score) - 4)
+             gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                 (x - 0.5) / 0.5, 2.) / 2)
+             gaussian = gaussian / np.max(gaussian)
+             for i in range(len(score)):
+                 score_matrix[i, :] = score[i] + pad_score[
+                     (i + 2):(i + len(score) - 1)] * gaussian * 0.3
+
+             head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+                                                           score_matrix.shape)
+             tail_start = (head_start + tail_increment + 2) % len(points)
+             head_end = (head_start + 1) % len(points)
+             tail_end = (tail_start + 1) % len(points)
+
+             if head_end > tail_end:
+                 head_start, tail_start = tail_start, head_start
+                 head_end, tail_end = tail_end, head_end
+             head_inds = [head_start, head_end]
+             tail_inds = [tail_start, tail_end]
+         else:
+             if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                     points[3] - points[2]) < self.vector_slope(
+                     points[2] - points[1]) + self.vector_slope(points[0] - points[3]):
+                 horizontal_edge_inds = [[0, 1], [2, 3]]
+                 vertical_edge_inds = [[3, 0], [1, 2]]
+             else:
+                 horizontal_edge_inds = [[3, 0], [1, 2]]
+                 vertical_edge_inds = [[0, 1], [2, 3]]
+
+             vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                 vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                     0]] - points[vertical_edge_inds[1][1]])
+             horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                 0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                     horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
159
+ [1]])
160
+
161
+ if vertical_len_sum > horizontal_len_sum * orientation_thr:
162
+ head_inds = horizontal_edge_inds[0]
163
+ tail_inds = horizontal_edge_inds[1]
164
+ else:
165
+ head_inds = vertical_edge_inds[0]
166
+ tail_inds = vertical_edge_inds[1]
167
+
168
+ return head_inds, tail_inds
169
+
170
+ def reorder_poly_edge(self, points):
171
+
172
+ assert points.ndim == 2
173
+ assert points.shape[0] >= 4
174
+ assert points.shape[1] == 2
175
+
176
+ head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
177
+ head_edge, tail_edge = points[head_inds], points[tail_inds]
178
+
179
+ pad_points = np.vstack([points, points])
180
+ if tail_inds[1] < 1:
181
+ tail_inds[1] = len(points)
182
+ sideline1 = pad_points[head_inds[1]:tail_inds[1]]
183
+ sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
184
+ sideline_mean_shift = np.mean(
185
+ sideline1, axis=0) - np.mean(
186
+ sideline2, axis=0)
187
+
188
+ if sideline_mean_shift[1] > 0:
189
+ top_sideline, bot_sideline = sideline2, sideline1
190
+ else:
191
+ top_sideline, bot_sideline = sideline1, sideline2
192
+
193
+ return head_edge, tail_edge, top_sideline, bot_sideline
194
+
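A quick way to sanity-check reorder_poly_edge is to feed it an axis-aligned rectangle; the top sideline should come back as the upper edge. A minimal sketch, not part of this diff (the default constructor arguments are assumed to suffice):

    import numpy as np

    drrg = DRRGTargets()
    rect = np.array([[0, 0], [10, 0], [10, 4], [0, 4]], dtype=np.float32)
    head, tail, top, bot = drrg.reorder_poly_edge(rect)
    print(top)  # expected [[ 0. 0.] [10. 0.]] -- the y == 0 edge
    print(bot)  # expected [[10. 4.] [ 0. 4.]] -- the y == 4 edge, head-to-tail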
195
+ def cal_curve_length(self, line):
196
+
197
+ assert line.ndim == 2
198
+ assert len(line) >= 2
199
+
200
+ edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 + (line[
201
+ 1:, 1] - line[:-1, 1])**2)
202
+ total_length = np.sum(edges_length)
203
+ return edges_length, total_length
204
+
205
+ def resample_line(self, line, n):
206
+
207
+ assert line.ndim == 2
208
+ assert line.shape[0] >= 2
209
+ assert line.shape[1] == 2
210
+ assert isinstance(n, int)
211
+ assert n > 2
212
+
213
+ edges_length, total_length = self.cal_curve_length(line)
214
+ t_org = np.insert(np.cumsum(edges_length), 0, 0)
215
+ unit_t = total_length / (n - 1)
216
+ t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t
217
+ edge_ind = 0
218
+ points = [line[0]]
219
+ for t in t_equidistant:
220
+ while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
221
+ edge_ind += 1
222
+ t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
223
+ weight = np.array(
224
+ [t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
225
+ p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
226
+ points.append(p_coords)
227
+ points.append(line[-1])
228
+ resampled_line = np.vstack(points)
229
+
230
+ return resampled_line
231
+
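resample_line walks the cumulative arc length and linearly interpolates at equidistant parameters. The same idea can be written compactly with np.interp; this standalone sketch is illustrative and not part of the diff:

    import numpy as np

    def resample_polyline(line, n):
        # cumulative arc length at each vertex
        seg = np.sqrt(((line[1:] - line[:-1]) ** 2).sum(axis=1))
        t = np.insert(np.cumsum(seg), 0, 0)
        # n equidistant parameters along the total length
        targets = np.linspace(0, t[-1], n)
        x = np.interp(targets, t, line[:, 0])
        y = np.interp(targets, t, line[:, 1])
        return np.stack([x, y], axis=1)

    pts = np.array([[0., 0.], [4., 0.], [4., 3.]])  # total length 7
    print(resample_polyline(pts, 8))  # 8 points, ~1.0 apart along the path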
232
+ def resample_sidelines(self, sideline1, sideline2, resample_step):
233
+
234
+ assert sideline1.ndim == sideline2.ndim == 2
235
+ assert sideline1.shape[1] == sideline2.shape[1] == 2
236
+ assert sideline1.shape[0] >= 2
237
+ assert sideline2.shape[0] >= 2
238
+ assert isinstance(resample_step, float)
239
+
240
+ _, length1 = self.cal_curve_length(sideline1)
241
+ _, length2 = self.cal_curve_length(sideline2)
242
+
243
+ avg_length = (length1 + length2) / 2
244
+ resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)
245
+
246
+ resampled_line1 = self.resample_line(sideline1, resample_point_num)
247
+ resampled_line2 = self.resample_line(sideline2, resample_point_num)
248
+
249
+ return resampled_line1, resampled_line2
250
+
251
+ def dist_point2line(self, point, line):
252
+
253
+ assert isinstance(line, tuple)
254
+ point1, point2 = line
255
+ d = abs(np.cross(point2 - point1, point - point1)) / (
256
+ norm(point2 - point1) + 1e-8)
257
+ return d
258
+
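dist_point2line relies on the 2-D cross product: |(p2 - p1) x (p - p1)| is twice the triangle area, so dividing by |p2 - p1| yields the perpendicular distance. A one-line check with assumed values:

    import numpy as np

    p1, p2 = np.array([0.0, 0.0]), np.array([10.0, 0.0])
    point = np.array([3.0, 4.0])
    print(abs(np.cross(p2 - p1, point - p1)) / np.linalg.norm(p2 - p1))  # 4.0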
259
+ def draw_center_region_maps(self, top_line, bot_line, center_line,
260
+ center_region_mask, top_height_map,
261
+ bot_height_map, sin_map, cos_map,
262
+ region_shrink_ratio):
263
+
264
+ assert top_line.shape == bot_line.shape == center_line.shape
265
+ assert (center_region_mask.shape == top_height_map.shape ==
266
+ bot_height_map.shape == sin_map.shape == cos_map.shape)
267
+ assert isinstance(region_shrink_ratio, float)
268
+
269
+ h, w = center_region_mask.shape
270
+ for i in range(0, len(center_line) - 1):
271
+
272
+ top_mid_point = (top_line[i] + top_line[i + 1]) / 2
273
+ bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
274
+
275
+ sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
276
+ cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
277
+
278
+ tl = center_line[i] + (top_line[i] - center_line[i]
279
+ ) * region_shrink_ratio
280
+ tr = center_line[i + 1] + (top_line[i + 1] - center_line[i + 1]
281
+ ) * region_shrink_ratio
282
+ br = center_line[i + 1] + (bot_line[i + 1] - center_line[i + 1]
283
+ ) * region_shrink_ratio
284
+ bl = center_line[i] + (bot_line[i] - center_line[i]
285
+ ) * region_shrink_ratio
286
+ current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
287
+
288
+ cv2.fillPoly(center_region_mask, [current_center_box], color=1)
289
+ cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
290
+ cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
291
+
292
+ current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
293
+ w - 1)
294
+ current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
295
+ h - 1)
296
+ min_coord = np.min(current_center_box, axis=0).astype(np.int32)
297
+ max_coord = np.max(current_center_box, axis=0).astype(np.int32)
298
+ current_center_box = current_center_box - min_coord
299
+ box_sz = (max_coord - min_coord + 1)
300
+
301
+ center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
302
+ cv2.fillPoly(center_box_mask, [current_center_box], color=1)
303
+
304
+ inds = np.argwhere(center_box_mask > 0)
305
+ inds = inds + (min_coord[1], min_coord[0])
306
+ inds_xy = np.fliplr(inds)
307
+ top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
308
+ inds_xy, (top_line[i], top_line[i + 1]))
309
+ bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
310
+ inds_xy, (bot_line[i], bot_line[i + 1]))
311
+
312
+ def generate_center_mask_attrib_maps(self, img_size, text_polys):
313
+
314
+ assert isinstance(img_size, tuple)
315
+
316
+ h, w = img_size
317
+
318
+ center_lines = []
319
+ center_region_mask = np.zeros((h, w), np.uint8)
320
+ top_height_map = np.zeros((h, w), dtype=np.float32)
321
+ bot_height_map = np.zeros((h, w), dtype=np.float32)
322
+ sin_map = np.zeros((h, w), dtype=np.float32)
323
+ cos_map = np.zeros((h, w), dtype=np.float32)
324
+
325
+ for poly in text_polys:
326
+ polygon_points = poly
327
+ _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
328
+ resampled_top_line, resampled_bot_line = self.resample_sidelines(
329
+ top_line, bot_line, self.resample_step)
330
+ resampled_bot_line = resampled_bot_line[::-1]
331
+ center_line = (resampled_top_line + resampled_bot_line) / 2
332
+
333
+ if self.vector_slope(center_line[-1] - center_line[0]) > 2:
334
+ if (center_line[-1] - center_line[0])[1] < 0:
335
+ center_line = center_line[::-1]
336
+ resampled_top_line = resampled_top_line[::-1]
337
+ resampled_bot_line = resampled_bot_line[::-1]
338
+ else:
339
+ if (center_line[-1] - center_line[0])[0] < 0:
340
+ center_line = center_line[::-1]
341
+ resampled_top_line = resampled_top_line[::-1]
342
+ resampled_bot_line = resampled_bot_line[::-1]
343
+
344
+ line_head_shrink_len = np.clip(
345
+ (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
346
+ self.min_width, self.max_width) / 2
347
+ line_tail_shrink_len = np.clip(
348
+ (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
349
+ self.min_width, self.max_width) / 2
350
+ num_head_shrink = int(line_head_shrink_len // self.resample_step)
351
+ num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
352
+ if len(center_line) > num_head_shrink + num_tail_shrink + 2:
353
+ center_line = center_line[num_head_shrink:len(center_line) -
354
+ num_tail_shrink]
355
+ resampled_top_line = resampled_top_line[num_head_shrink:len(
356
+ resampled_top_line) - num_tail_shrink]
357
+ resampled_bot_line = resampled_bot_line[num_head_shrink:len(
358
+ resampled_bot_line) - num_tail_shrink]
359
+ center_lines.append(center_line.astype(np.int32))
360
+
361
+ self.draw_center_region_maps(
362
+ resampled_top_line, resampled_bot_line, center_line,
363
+ center_region_mask, top_height_map, bot_height_map, sin_map,
364
+ cos_map, self.center_region_shrink_ratio)
365
+
366
+ return (center_lines, center_region_mask, top_height_map,
367
+ bot_height_map, sin_map, cos_map)
368
+
369
+ def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
370
+
371
+ assert isinstance(num_rand_comps, int)
372
+ assert num_rand_comps > 0
373
+ assert center_sample_mask.ndim == 2
374
+
375
+ h, w = center_sample_mask.shape
376
+
377
+ max_rand_half_height = self.max_rand_half_height
378
+ min_rand_half_height = self.min_rand_half_height
379
+ max_rand_height = max_rand_half_height * 2
380
+ max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
381
+ self.min_width, self.max_width)
382
+ margin = int(
383
+ np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
384
+
385
+ if 2 * margin + 1 > min(h, w):
386
+
387
+ assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
388
+ max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
389
+ min_rand_half_height = max(max_rand_half_height / 4,
390
+ self.min_width / 2)
391
+
392
+ max_rand_height = max_rand_half_height * 2
393
+ max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
394
+ self.min_width, self.max_width)
395
+ margin = int(
396
+ np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1
397
+
398
+ inner_center_sample_mask = np.zeros_like(center_sample_mask)
399
+ inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
400
+ center_sample_mask[margin:h - margin, margin:w - margin]
401
+ kernel_size = int(np.clip(max_rand_half_height, 7, 21))
402
+ inner_center_sample_mask = cv2.erode(
403
+ inner_center_sample_mask,
404
+ np.ones((kernel_size, kernel_size), np.uint8))
405
+
406
+ center_candidates = np.argwhere(inner_center_sample_mask > 0)
407
+ num_center_candidates = len(center_candidates)
408
+ sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
409
+ rand_centers = center_candidates[sample_inds]
410
+
411
+ rand_top_height = np.random.randint(
412
+ min_rand_half_height,
413
+ max_rand_half_height,
414
+ size=(len(rand_centers), 1))
415
+ rand_bot_height = np.random.randint(
416
+ min_rand_half_height,
417
+ max_rand_half_height,
418
+ size=(len(rand_centers), 1))
419
+
420
+ rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
421
+ rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
422
+ scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
423
+ rand_cos = rand_cos * scale
424
+ rand_sin = rand_sin * scale
425
+
426
+ height = (rand_top_height + rand_bot_height)
427
+ width = np.clip(height * self.comp_w_h_ratio, self.min_width,
428
+ self.max_width)
429
+
430
+ rand_comp_attribs = np.hstack([
431
+ rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
432
+ np.zeros_like(rand_sin)
433
+ ]).astype(np.float32)
434
+
435
+ return rand_comp_attribs
436
+
437
+ def jitter_comp_attribs(self, comp_attribs, jitter_level):
438
+ """Jitter text components attributes.
439
+
440
+ Args:
441
+ comp_attribs (ndarray): The text component attributes.
442
+ jitter_level (float): The jitter level of text components
443
+ attributes.
444
+
445
+ Returns:
446
+ jittered_comp_attribs (ndarray): The jittered text component
447
+ attributes (x, y, h, w, cos, sin, comp_label).
448
+ """
449
+
450
+ assert comp_attribs.shape[1] == 7
451
+ assert comp_attribs.shape[0] > 0
452
+ assert isinstance(jitter_level, float)
453
+
454
+ x = comp_attribs[:, 0].reshape((-1, 1))
455
+ y = comp_attribs[:, 1].reshape((-1, 1))
456
+ h = comp_attribs[:, 2].reshape((-1, 1))
457
+ w = comp_attribs[:, 3].reshape((-1, 1))
458
+ cos = comp_attribs[:, 4].reshape((-1, 1))
459
+ sin = comp_attribs[:, 5].reshape((-1, 1))
460
+ comp_labels = comp_attribs[:, 6].reshape((-1, 1))
461
+
462
+ x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
463
+ h * np.abs(cos) + w * np.abs(sin)) * jitter_level
464
+ y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
465
+ h * np.abs(sin) + w * np.abs(cos)) * jitter_level
466
+
467
+ h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
468
+ ) * h * jitter_level
469
+ w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
470
+ ) * w * jitter_level
471
+
472
+ cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
473
+ ) * 2 * jitter_level
474
+ sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
475
+ ) * 2 * jitter_level
476
+
477
+ scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
478
+ cos = cos * scale
479
+ sin = sin * scale
480
+
481
+ jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
482
+
483
+ return jittered_comp_attribs
484
+
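Note how the jittered (cos, sin) pair is renormalized back onto the unit circle at the end; a toy check of that step (values assumed):

    import numpy as np

    cos, sin = 0.9, 0.6  # no longer a unit vector after jittering
    scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
    cos, sin = cos * scale, sin * scale
    print(cos**2 + sin**2)  # ~1.0 again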
485
+ def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
486
+ top_height_map, bot_height_map, sin_map, cos_map):
487
+ """Generate text component attributes.
488
+
489
+ Args:
490
+ center_lines (list[ndarray]): The list of text center lines.
491
+ text_mask (ndarray): The text region mask.
492
+ center_region_mask (ndarray): The text center region mask.
493
+ top_height_map (ndarray): The map on which the distance from points
494
+ to top side lines will be drawn for each pixel in text center
495
+ regions.
496
+ bot_height_map (ndarray): The map on which the distance from points
497
+ to bottom side lines will be drawn for each pixel in text
498
+ center regions.
499
+ sin_map (ndarray): The sin(theta) map where theta is the angle
500
+ between vector (top point - bottom point) and vector (1, 0).
501
+ cos_map (ndarray): The cos(theta) map where theta is the angle
502
+ between vector (top point - bottom point) and vector (1, 0).
503
+
504
+ Returns:
505
+ pad_comp_attribs (ndarray): The padded text component attributes
506
+ of a fixed size.
507
+ """
508
+
509
+ assert isinstance(center_lines, list)
510
+ assert (
511
+ text_mask.shape == center_region_mask.shape == top_height_map.shape
512
+ == bot_height_map.shape == sin_map.shape == cos_map.shape)
513
+
514
+ center_lines_mask = np.zeros_like(center_region_mask)
515
+ cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
516
+ center_lines_mask = center_lines_mask * center_region_mask
517
+ comp_centers = np.argwhere(center_lines_mask > 0)
518
+
519
+ y = comp_centers[:, 0]
520
+ x = comp_centers[:, 1]
521
+
522
+ top_height = top_height_map[y, x].reshape(
523
+ (-1, 1)) * self.comp_shrink_ratio
524
+ bot_height = bot_height_map[y, x].reshape(
525
+ (-1, 1)) * self.comp_shrink_ratio
526
+ sin = sin_map[y, x].reshape((-1, 1))
527
+ cos = cos_map[y, x].reshape((-1, 1))
528
+
529
+ top_mid_points = comp_centers + np.hstack(
530
+ [top_height * sin, top_height * cos])
531
+ bot_mid_points = comp_centers - np.hstack(
532
+ [bot_height * sin, bot_height * cos])
533
+
534
+ width = (top_height + bot_height) * self.comp_w_h_ratio
535
+ width = np.clip(width, self.min_width, self.max_width)
536
+ r = width / 2
537
+
538
+ tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
539
+ tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
540
+ br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
541
+ bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
542
+ text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
543
+
544
+ score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
545
+ text_comps = np.hstack([text_comps, score])
546
+ text_comps = la_nms(text_comps, self.text_comp_nms_thr)
547
+
548
+ if text_comps.shape[0] >= 1:
549
+ img_h, img_w = center_region_mask.shape
550
+ text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
551
+ text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
552
+
553
+ comp_centers = np.mean(
554
+ text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1).astype(np.int32)
555
+ x = comp_centers[:, 0]
556
+ y = comp_centers[:, 1]
557
+
558
+ height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
559
+ (-1, 1))
560
+ width = np.clip(height * self.comp_w_h_ratio, self.min_width,
561
+ self.max_width)
562
+
563
+ cos = cos_map[y, x].reshape((-1, 1))
564
+ sin = sin_map[y, x].reshape((-1, 1))
565
+
566
+ _, comp_label_mask = cv2.connectedComponents(
567
+ center_region_mask, connectivity=8)
568
+ comp_labels = comp_label_mask[y, x].reshape(
569
+ (-1, 1)).astype(np.float32)
570
+
571
+ x = x.reshape((-1, 1)).astype(np.float32)
572
+ y = y.reshape((-1, 1)).astype(np.float32)
573
+ comp_attribs = np.hstack(
574
+ [x, y, height, width, cos, sin, comp_labels])
575
+ comp_attribs = self.jitter_comp_attribs(comp_attribs,
576
+ self.jitter_level)
577
+
578
+ if comp_attribs.shape[0] < self.num_min_comps:
579
+ num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
580
+ rand_comp_attribs = self.generate_rand_comp_attribs(
581
+ num_rand_comps, 1 - text_mask)
582
+ comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
583
+ else:
584
+ comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
585
+ 1 - text_mask)
586
+
587
+ num_comps = (np.ones(
588
+ (comp_attribs.shape[0], 1),
589
+ dtype=np.float32) * comp_attribs.shape[0])
590
+ comp_attribs = np.hstack([num_comps, comp_attribs])
591
+
592
+ if comp_attribs.shape[0] > self.num_max_comps:
593
+ comp_attribs = comp_attribs[:self.num_max_comps, :]
594
+ comp_attribs[:, 0] = self.num_max_comps
595
+
596
+ pad_comp_attribs = np.zeros(
597
+ (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
598
+ pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs
599
+
600
+ return pad_comp_attribs
601
+
602
+ def generate_text_region_mask(self, img_size, text_polys):
603
+ """Generate text center region mask and geometry attribute maps.
604
+
605
+ Args:
606
+ img_size (tuple): The image size (height, width).
607
+ text_polys (list[list[ndarray]]): The list of text polygons.
608
+
609
+ Returns:
610
+ text_region_mask (ndarray): The text region mask.
611
+ """
612
+
613
+ assert isinstance(img_size, tuple)
614
+
615
+ h, w = img_size
616
+ text_region_mask = np.zeros((h, w), dtype=np.uint8)
617
+
618
+ for poly in text_polys:
619
+ polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
620
+ cv2.fillPoly(text_region_mask, polygon, 1)
621
+
622
+ return text_region_mask
623
+
624
+ def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
625
+ """Generate effective mask by setting the ineffective regions to 0 and
626
+ effective regions to 1.
627
+
628
+ Args:
629
+ mask_size (tuple): The mask size.
630
+ polygons_ignore (list[ndarray]): The list of ignored text
631
+ polygons.
632
+
633
+ Returns:
634
+ mask (ndarray): The effective mask of (height, width).
635
+ """
636
+ mask = np.ones(mask_size, dtype=np.uint8)
637
+
638
+ for poly in polygons_ignore:
639
+ instance = poly.astype(np.int32).reshape(1, -1, 2)
640
+ cv2.fillPoly(mask, instance, 0)
641
+
642
+ return mask
643
+
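A small sketch of what the effective mask looks like for one ignored box (the polygon coordinates are made up for illustration):

    import numpy as np

    drrg = DRRGTargets()
    ignored = [np.array([[2, 2], [6, 2], [6, 5], [2, 5]], dtype=np.float32)]
    mask = drrg.generate_effective_mask((8, 8), ignored)
    print(mask.sum())  # 64 minus the zeroed box, roughly 44 depending on fill borders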
644
+ def generate_targets(self, data):
645
+ """Generate the gt targets for DRRG.
646
+
647
+ Args:
648
+ data (dict): The input result dictionary.
649
+
650
+ Returns:
651
+ data (dict): The output result dictionary.
652
+ """
653
+
654
+ assert isinstance(data, dict)
655
+
656
+ image = data['image']
657
+ polygons = data['polys']
658
+ ignore_tags = data['ignore_tags']
659
+ h, w, _ = image.shape
660
+
661
+ polygon_masks = []
662
+ polygon_masks_ignore = []
663
+ for tag, polygon in zip(ignore_tags, polygons):
664
+ if tag is True:
665
+ polygon_masks_ignore.append(polygon)
666
+ else:
667
+ polygon_masks.append(polygon)
668
+
669
+ gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
670
+ gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
671
+ (center_lines, gt_center_region_mask, gt_top_height_map,
672
+ gt_bot_height_map, gt_sin_map,
673
+ gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
674
+ polygon_masks)
675
+
676
+ gt_comp_attribs = self.generate_comp_attribs(
677
+ center_lines, gt_text_mask, gt_center_region_mask,
678
+ gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)
679
+
680
+ mapping = {
681
+ 'gt_text_mask': gt_text_mask,
682
+ 'gt_center_region_mask': gt_center_region_mask,
683
+ 'gt_mask': gt_mask,
684
+ 'gt_top_height_map': gt_top_height_map,
685
+ 'gt_bot_height_map': gt_bot_height_map,
686
+ 'gt_sin_map': gt_sin_map,
687
+ 'gt_cos_map': gt_cos_map
688
+ }
689
+
690
+ data.update(mapping)
691
+ data['gt_comp_attribs'] = gt_comp_attribs
692
+ return data
693
+
694
+ def __call__(self, data):
695
+ data = self.generate_targets(data)
696
+ return data
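Roughly, the transform is driven like any other PaddleOCR imaug op: build it once, then call it on the data dict. The dummy sample below is only illustrative (it also assumes the lanms package imported at the top is installed):

    import numpy as np

    drrg = DRRGTargets(resample_step=8.0, num_min_comps=9, num_max_comps=600)
    data = {
        'image': np.zeros((640, 640, 3), dtype=np.uint8),
        'polys': [np.array([[100, 100], [300, 100], [300, 160], [100, 160]],
                           dtype=np.float32)],
        'ignore_tags': [False],
    }
    out = drrg(data)
    print(out['gt_comp_attribs'].shape)  # should be (600, 8)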
ppocr/data/imaug/east_process.py ADDED
@@ -0,0 +1,436 @@
1
+ #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ #Licensed under the Apache License, Version 2.0 (the "License");
4
+ #you may not use this file except in compliance with the License.
5
+ #You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ #Unless required by applicable law or agreed to in writing, software
10
+ #distributed under the License is distributed on an "AS IS" BASIS,
11
+ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ #See the License for the specific language governing permissions and
13
+ #limitations under the License.
14
+ """
15
+ This code is adapted from:
16
+ https://github.com/songdejia/EAST/blob/master/data_utils.py
17
+ """
18
+ import math
19
+ import cv2
20
+ import numpy as np
21
+ import json
22
+ import sys
23
+ import os
24
+
25
+ __all__ = ['EASTProcessTrain']
26
+
27
+
28
+ class EASTProcessTrain(object):
29
+ def __init__(self,
30
+ image_shape=[512, 512],
31
+ background_ratio=0.125,
32
+ min_crop_side_ratio=0.1,
33
+ min_text_size=10,
34
+ **kwargs):
35
+ self.input_size = image_shape[1]
36
+ self.random_scale = np.array([0.5, 1, 2.0, 3.0])
37
+ self.background_ratio = background_ratio
38
+ self.min_crop_side_ratio = min_crop_side_ratio
39
+ self.min_text_size = min_text_size
40
+
41
+ def preprocess(self, im):
42
+ input_size = self.input_size
43
+ im_shape = im.shape
44
+ im_size_min = np.min(im_shape[0:2])
45
+ im_size_max = np.max(im_shape[0:2])
46
+ im_scale = float(input_size) / float(im_size_max)
47
+ im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
48
+ img_mean = [0.485, 0.456, 0.406]
49
+ img_std = [0.229, 0.224, 0.225]
50
+ # im = im[:, :, ::-1].astype(np.float32)
51
+ im = im / 255
52
+ im -= img_mean
53
+ im /= img_std
54
+ new_h, new_w, _ = im.shape
55
+ im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
56
+ im_padded[:new_h, :new_w, :] = im
57
+ im_padded = im_padded.transpose((2, 0, 1))
58
+ im_padded = im_padded[np.newaxis, :]
59
+ return im_padded, im_scale
60
+
61
+ def rotate_im_poly(self, im, text_polys):
62
+ """
63
+ Rotate the image by 90 / 180 / 270 degrees.
64
+ """
65
+ im_w, im_h = im.shape[1], im.shape[0]
66
+ dst_im = im.copy()
67
+ dst_polys = []
68
+ rand_degree_ratio = np.random.rand()
69
+ rand_degree_cnt = 1
70
+ if 0.333 < rand_degree_ratio < 0.666:
71
+ rand_degree_cnt = 2
72
+ elif rand_degree_ratio > 0.666:
73
+ rand_degree_cnt = 3
74
+ for i in range(rand_degree_cnt):
75
+ dst_im = np.rot90(dst_im)
76
+ rot_degree = -90 * rand_degree_cnt
77
+ rot_angle = rot_degree * math.pi / 180.0
78
+ n_poly = text_polys.shape[0]
79
+ cx, cy = 0.5 * im_w, 0.5 * im_h
80
+ ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
81
+ for i in range(n_poly):
82
+ wordBB = text_polys[i]
83
+ poly = []
84
+ for j in range(4):
85
+ sx, sy = wordBB[j][0], wordBB[j][1]
86
+ dx = math.cos(rot_angle) * (sx - cx)\
87
+ - math.sin(rot_angle) * (sy - cy) + ncx
88
+ dy = math.sin(rot_angle) * (sx - cx)\
89
+ + math.cos(rot_angle) * (sy - cy) + ncy
90
+ poly.append([dx, dy])
91
+ dst_polys.append(poly)
92
+ dst_polys = np.array(dst_polys, dtype=np.float32)
93
+ return dst_im, dst_polys
94
+
95
+ def polygon_area(self, poly):
96
+ """
97
+ compute area of a polygon
98
+ :param poly:
99
+ :return:
100
+ """
101
+ edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
102
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
103
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
104
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
105
+ return np.sum(edge) / 2.
106
+
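polygon_area is the shoelace formula in trapezoid form; the sign of the result encodes vertex orientation, which check_and_validate_polys uses below to detect reversed boxes. A quick check with an assumed unit square:

    import numpy as np

    sq = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
    edges = [(sq[(i + 1) % 4][0] - sq[i][0]) * (sq[(i + 1) % 4][1] + sq[i][1])
             for i in range(4)]
    print(sum(edges) / 2.0)  # 1.0 for this vertex order; the reverse order gives -1.0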
107
+ def check_and_validate_polys(self, polys, tags, img_height, img_width):
108
+ """
109
+ check that the text polys are in a consistent direction,
110
+ and filter out invalid polygons
111
+ :param polys:
112
+ :param tags:
113
+ :return:
114
+ """
115
+ h, w = img_height, img_width
116
+ if polys.shape[0] == 0:
117
+ return polys
118
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
119
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
120
+
121
+ validated_polys = []
122
+ validated_tags = []
123
+ for poly, tag in zip(polys, tags):
124
+ p_area = self.polygon_area(poly)
125
+ #invalid poly
126
+ if abs(p_area) < 1:
127
+ continue
128
+ if p_area > 0:
129
+ #'poly in wrong direction'
130
+ if not tag:
131
+ tag = True  # reversed cases should be ignored
132
+ poly = poly[(0, 3, 2, 1), :]
133
+ validated_polys.append(poly)
134
+ validated_tags.append(tag)
135
+ return np.array(validated_polys), np.array(validated_tags)
136
+
137
+ def draw_img_polys(self, img, polys):
138
+ if len(img.shape) == 4:
139
+ img = np.squeeze(img, axis=0)
140
+ if img.shape[0] == 3:
141
+ img = img.transpose((1, 2, 0))
142
+ img[:, :, 2] += 123.68
143
+ img[:, :, 1] += 116.78
144
+ img[:, :, 0] += 103.94
145
+ cv2.imwrite("tmp.jpg", img)
146
+ img = cv2.imread("tmp.jpg")
147
+ for box in polys:
148
+ box = box.astype(np.int32).reshape((-1, 1, 2))
149
+ cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
150
+ import random
151
+ ino = random.randint(0, 100)
152
+ cv2.imwrite("tmp_%d.jpg" % ino, img)
153
+ return
154
+
155
+ def shrink_poly(self, poly, r):
156
+ """
157
+ fit a poly inside the original poly; there may be bugs here...
158
+ used to generate the score map
159
+ :param poly: the text poly
160
+ :param r: r in the paper
161
+ :return: the shrunk poly
162
+ """
163
+ # shrink ratio
164
+ R = 0.3
165
+ # find the longer pair
166
+ dist0 = np.linalg.norm(poly[0] - poly[1])
167
+ dist1 = np.linalg.norm(poly[2] - poly[3])
168
+ dist2 = np.linalg.norm(poly[0] - poly[3])
169
+ dist3 = np.linalg.norm(poly[1] - poly[2])
170
+ if dist0 + dist1 > dist2 + dist3:
171
+ # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
172
+ ## p0, p1
173
+ theta = np.arctan2((poly[1][1] - poly[0][1]),
174
+ (poly[1][0] - poly[0][0]))
175
+ poly[0][0] += R * r[0] * np.cos(theta)
176
+ poly[0][1] += R * r[0] * np.sin(theta)
177
+ poly[1][0] -= R * r[1] * np.cos(theta)
178
+ poly[1][1] -= R * r[1] * np.sin(theta)
179
+ ## p2, p3
180
+ theta = np.arctan2((poly[2][1] - poly[3][1]),
181
+ (poly[2][0] - poly[3][0]))
182
+ poly[3][0] += R * r[3] * np.cos(theta)
183
+ poly[3][1] += R * r[3] * np.sin(theta)
184
+ poly[2][0] -= R * r[2] * np.cos(theta)
185
+ poly[2][1] -= R * r[2] * np.sin(theta)
186
+ ## p0, p3
187
+ theta = np.arctan2((poly[3][0] - poly[0][0]),
188
+ (poly[3][1] - poly[0][1]))
189
+ poly[0][0] += R * r[0] * np.sin(theta)
190
+ poly[0][1] += R * r[0] * np.cos(theta)
191
+ poly[3][0] -= R * r[3] * np.sin(theta)
192
+ poly[3][1] -= R * r[3] * np.cos(theta)
193
+ ## p1, p2
194
+ theta = np.arctan2((poly[2][0] - poly[1][0]),
195
+ (poly[2][1] - poly[1][1]))
196
+ poly[1][0] += R * r[1] * np.sin(theta)
197
+ poly[1][1] += R * r[1] * np.cos(theta)
198
+ poly[2][0] -= R * r[2] * np.sin(theta)
199
+ poly[2][1] -= R * r[2] * np.cos(theta)
200
+ else:
201
+ ## p0, p3
202
+ # print poly
203
+ theta = np.arctan2((poly[3][0] - poly[0][0]),
204
+ (poly[3][1] - poly[0][1]))
205
+ poly[0][0] += R * r[0] * np.sin(theta)
206
+ poly[0][1] += R * r[0] * np.cos(theta)
207
+ poly[3][0] -= R * r[3] * np.sin(theta)
208
+ poly[3][1] -= R * r[3] * np.cos(theta)
209
+ ## p1, p2
210
+ theta = np.arctan2((poly[2][0] - poly[1][0]),
211
+ (poly[2][1] - poly[1][1]))
212
+ poly[1][0] += R * r[1] * np.sin(theta)
213
+ poly[1][1] += R * r[1] * np.cos(theta)
214
+ poly[2][0] -= R * r[2] * np.sin(theta)
215
+ poly[2][1] -= R * r[2] * np.cos(theta)
216
+ ## p0, p1
217
+ theta = np.arctan2((poly[1][1] - poly[0][1]),
218
+ (poly[1][0] - poly[0][0]))
219
+ poly[0][0] += R * r[0] * np.cos(theta)
220
+ poly[0][1] += R * r[0] * np.sin(theta)
221
+ poly[1][0] -= R * r[1] * np.cos(theta)
222
+ poly[1][1] -= R * r[1] * np.sin(theta)
223
+ ## p2, p3
224
+ theta = np.arctan2((poly[2][1] - poly[3][1]),
225
+ (poly[2][0] - poly[3][0]))
226
+ poly[3][0] += R * r[3] * np.cos(theta)
227
+ poly[3][1] += R * r[3] * np.sin(theta)
228
+ poly[2][0] -= R * r[2] * np.cos(theta)
229
+ poly[2][1] -= R * r[2] * np.sin(theta)
230
+ return poly
231
+
232
+ def generate_quad(self, im_size, polys, tags):
233
+ """
234
+ Generate the score map, geo map and training mask for quadrangles.
235
+ """
236
+ h, w = im_size
237
+ poly_mask = np.zeros((h, w), dtype=np.uint8)
238
+ score_map = np.zeros((h, w), dtype=np.uint8)
239
+ # (x1, y1, ..., x4, y4, short_edge_norm)
240
+ geo_map = np.zeros((h, w, 9), dtype=np.float32)
241
+ # mask used during training, to ignore some hard areas
242
+ training_mask = np.ones((h, w), dtype=np.uint8)
243
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
244
+ poly = poly_tag[0]
245
+ tag = poly_tag[1]
246
+
247
+ r = [None, None, None, None]
248
+ for i in range(4):
249
+ dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
250
+ dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
251
+ r[i] = min(dist1, dist2)
252
+ # score map
253
+ shrinked_poly = self.shrink_poly(
254
+ poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
255
+ cv2.fillPoly(score_map, shrinked_poly, 1)
256
+ cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
257
+ # if the poly is too small, then ignore it during training
258
+ poly_h = min(
259
+ np.linalg.norm(poly[0] - poly[3]),
260
+ np.linalg.norm(poly[1] - poly[2]))
261
+ poly_w = min(
262
+ np.linalg.norm(poly[0] - poly[1]),
263
+ np.linalg.norm(poly[2] - poly[3]))
264
+ if min(poly_h, poly_w) < self.min_text_size:
265
+ cv2.fillPoly(training_mask,
266
+ poly.astype(np.int32)[np.newaxis, :, :], 0)
267
+
268
+ if tag:
269
+ cv2.fillPoly(training_mask,
270
+ poly.astype(np.int32)[np.newaxis, :, :], 0)
271
+
272
+ xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
273
+ # geo map.
274
+ y_in_poly = xy_in_poly[:, 0]
275
+ x_in_poly = xy_in_poly[:, 1]
276
+ poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
277
+ poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
278
+ for pno in range(4):
279
+ geo_channel_beg = pno * 2
280
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
281
+ x_in_poly - poly[pno, 0]
282
+ geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
283
+ y_in_poly - poly[pno, 1]
284
+ geo_map[y_in_poly, x_in_poly, 8] = \
285
+ 1.0 / max(min(poly_h, poly_w), 1.0)
286
+ return score_map, geo_map, training_mask
287
+
288
+ def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
289
+ """
290
+ make random crop from the input image
291
+ :param im:
292
+ :param polys:
293
+ :param tags:
294
+ :param crop_background:
295
+ :param max_tries:
296
+ :return:
297
+ """
298
+ h, w, _ = im.shape
299
+ pad_h = h // 10
300
+ pad_w = w // 10
301
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
302
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
303
+ for poly in polys:
304
+ poly = np.round(poly, decimals=0).astype(np.int32)
305
+ minx = np.min(poly[:, 0])
306
+ maxx = np.max(poly[:, 0])
307
+ w_array[minx + pad_w:maxx + pad_w] = 1
308
+ miny = np.min(poly[:, 1])
309
+ maxy = np.max(poly[:, 1])
310
+ h_array[miny + pad_h:maxy + pad_h] = 1
311
+ # ensure the cropped area not across a text
312
+ h_axis = np.where(h_array == 0)[0]
313
+ w_axis = np.where(w_array == 0)[0]
314
+ if len(h_axis) == 0 or len(w_axis) == 0:
315
+ return im, polys, tags
316
+
317
+ for i in range(max_tries):
318
+ xx = np.random.choice(w_axis, size=2)
319
+ xmin = np.min(xx) - pad_w
320
+ xmax = np.max(xx) - pad_w
321
+ xmin = np.clip(xmin, 0, w - 1)
322
+ xmax = np.clip(xmax, 0, w - 1)
323
+ yy = np.random.choice(h_axis, size=2)
324
+ ymin = np.min(yy) - pad_h
325
+ ymax = np.max(yy) - pad_h
326
+ ymin = np.clip(ymin, 0, h - 1)
327
+ ymax = np.clip(ymax, 0, h - 1)
328
+ if xmax - xmin < self.min_crop_side_ratio * w or \
329
+ ymax - ymin < self.min_crop_side_ratio * h:
330
+ # area too small
331
+ continue
332
+ if polys.shape[0] != 0:
333
+ poly_axis_in_area = (polys[:, :, 0] >= xmin)\
334
+ & (polys[:, :, 0] <= xmax)\
335
+ & (polys[:, :, 1] >= ymin)\
336
+ & (polys[:, :, 1] <= ymax)
337
+ selected_polys = np.where(
338
+ np.sum(poly_axis_in_area, axis=1) == 4)[0]
339
+ else:
340
+ selected_polys = []
341
+
342
+ if len(selected_polys) == 0:
343
+ # no text in this area
344
+ if crop_background:
345
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
346
+ polys = []
347
+ tags = []
348
+ return im, polys, tags
349
+ else:
350
+ continue
351
+
352
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
353
+ polys = polys[selected_polys]
354
+ tags = tags[selected_polys]
355
+ polys[:, :, 0] -= xmin
356
+ polys[:, :, 1] -= ymin
357
+ return im, polys, tags
358
+ return im, polys, tags
359
+
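crop_area avoids slicing through text by building per-axis occupancy histograms and sampling crop borders only from the free coordinates. A stripped-down 1-D illustration of that trick (box extents are made up):

    import numpy as np

    w, pad_w = 20, 2
    w_array = np.zeros(w + 2 * pad_w, dtype=np.int32)
    for minx, maxx in [(3, 7), (12, 15)]:       # x-extents of two text boxes
        w_array[minx + pad_w:maxx + pad_w] = 1  # mark occupied columns
    w_axis = np.where(w_array == 0)[0]          # legal border positions
    xx = np.random.choice(w_axis, size=2)
    xmin, xmax = np.min(xx) - pad_w, np.max(xx) - pad_w
    print(np.clip(sorted([xmin, xmax]), 0, w - 1))  # borders never land inside a box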
360
+ def crop_background_infor(self, im, text_polys, text_tags):
361
+ im, text_polys, text_tags = self.crop_area(
362
+ im, text_polys, text_tags, crop_background=True)
363
+
364
+ if len(text_polys) > 0:
365
+ return None
366
+ # pad and resize image
367
+ input_size = self.input_size
368
+ im, ratio = self.preprocess(im)
369
+ score_map = np.zeros((input_size, input_size), dtype=np.float32)
370
+ geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
371
+ training_mask = np.ones((input_size, input_size), dtype=np.float32)
372
+ return im, score_map, geo_map, training_mask
373
+
374
+ def crop_foreground_infor(self, im, text_polys, text_tags):
375
+ im, text_polys, text_tags = self.crop_area(
376
+ im, text_polys, text_tags, crop_background=False)
377
+
378
+ if text_polys.shape[0] == 0:
379
+ return None
380
+ # skip the sample when all text instances are ignored
381
+ if np.sum((text_tags * 1.0)) >= text_tags.size:
382
+ return None
383
+ # pad and resize image
384
+ input_size = self.input_size
385
+ im, ratio = self.preprocess(im)
386
+ text_polys[:, :, 0] *= ratio
387
+ text_polys[:, :, 1] *= ratio
388
+ _, _, new_h, new_w = im.shape
389
+ # print(im.shape)
390
+ # self.draw_img_polys(im, text_polys)
391
+ score_map, geo_map, training_mask = self.generate_quad(
392
+ (new_h, new_w), text_polys, text_tags)
393
+ return im, score_map, geo_map, training_mask
394
+
395
+ def __call__(self, data):
396
+ im = data['image']
397
+ text_polys = data['polys']
398
+ text_tags = data['ignore_tags']
399
+ if im is None:
400
+ return None
401
+ if text_polys.shape[0] == 0:
402
+ return None
403
+
404
+ #add rotate cases
405
+ if np.random.rand() < 0.5:
406
+ im, text_polys = self.rotate_im_poly(im, text_polys)
407
+ h, w, _ = im.shape
408
+ text_polys, text_tags = self.check_and_validate_polys(text_polys,
409
+ text_tags, h, w)
410
+ if text_polys.shape[0] == 0:
411
+ return None
412
+
413
+ # random scale this image
414
+ rd_scale = np.random.choice(self.random_scale)
415
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
416
+ text_polys *= rd_scale
417
+ if np.random.rand() < self.background_ratio:
418
+ outs = self.crop_background_infor(im, text_polys, text_tags)
419
+ else:
420
+ outs = self.crop_foreground_infor(im, text_polys, text_tags)
421
+
422
+ if outs is None:
423
+ return None
424
+ im, score_map, geo_map, training_mask = outs
425
+ score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
426
+ geo_map = np.swapaxes(geo_map, 1, 2)
427
+ geo_map = np.swapaxes(geo_map, 1, 0)
428
+ geo_map = geo_map[:, ::4, ::4].astype(np.float32)
429
+ training_mask = training_mask[np.newaxis, ::4, ::4]
430
+ training_mask = training_mask.astype(np.float32)
431
+
432
+ data['image'] = im[0]
433
+ data['score_map'] = score_map
434
+ data['geo_map'] = geo_map
435
+ data['training_mask'] = training_mask
436
+ return data
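A rough usage sketch with dummy tensors (the op may legitimately return None, in which case the dataset layer resamples):

    import numpy as np

    east = EASTProcessTrain(image_shape=[512, 512])
    data = {
        'image': np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8),
        'polys': np.array([[[100, 100], [300, 100], [300, 160], [100, 160]]],
                          dtype=np.float32),
        'ignore_tags': np.array([False]),
    }
    out = east(data)
    if out is not None:
        print(out['score_map'].shape, out['geo_map'].shape)  # (1, 128, 128) (9, 128, 128)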
ppocr/data/imaug/fce_aug.py ADDED
@@ -0,0 +1,564 @@
1
+ # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This code is adapted from:
16
+ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
17
+ """
18
+ import numpy as np
19
+ from PIL import Image, ImageDraw
20
+ import cv2
21
+ from shapely.geometry import Polygon
22
+ import math
23
+ from ppocr.utils.poly_nms import poly_intersection
24
+
25
+
26
+ class RandomScaling:
27
+ def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
28
+ """Random scale the image while keeping aspect.
29
+
30
+ Args:
31
+ size (int): Base size before scaling.
32
+ scale (tuple(float)): The range of scaling.
33
+ """
34
+ assert isinstance(size, int)
35
+ assert isinstance(scale, float) or isinstance(scale, tuple)
36
+ self.size = size
37
+ self.scale = scale if isinstance(scale, tuple) \
38
+ else (1 - scale, 1 + scale)
39
+
40
+ def __call__(self, data):
41
+ image = data['image']
42
+ text_polys = data['polys']
43
+ h, w, _ = image.shape
44
+
45
+ aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
46
+ scales = self.size * 1.0 / max(h, w) * aspect_ratio
47
+ scales = np.array([scales, scales])
48
+ out_size = (int(h * scales[1]), int(w * scales[0]))
49
+ image = cv2.resize(image, out_size[::-1])
50
+
51
+ data['image'] = image
52
+ text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
53
+ text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
54
+ data['polys'] = text_polys
55
+
56
+ return data
57
+
58
+
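A usage sketch with dummy data (shapes assumed):

    import numpy as np

    scaler = RandomScaling(size=800, scale=(3. / 4, 5. / 2))
    data = {
        'image': np.zeros((400, 600, 3), dtype=np.uint8),
        'polys': np.zeros((1, 4, 2), dtype=np.float32),
    }
    out = scaler(data)
    print(out['image'].shape)  # longest side near 800 times a random factor in scale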
59
+ class RandomCropFlip:
60
+ def __init__(self,
61
+ pad_ratio=0.1,
62
+ crop_ratio=0.5,
63
+ iter_num=1,
64
+ min_area_ratio=0.2,
65
+ **kwargs):
66
+ """Random crop and flip a patch of the image.
67
+
68
+ Args:
69
+ crop_ratio (float): The probability of applying the crop-and-flip.
70
+ iter_num (int): Number of operations.
71
+ min_area_ratio (float): Minimal area ratio between cropped patch
72
+ and original image.
73
+ """
74
+ assert isinstance(crop_ratio, float)
75
+ assert isinstance(iter_num, int)
76
+ assert isinstance(min_area_ratio, float)
77
+
78
+ self.pad_ratio = pad_ratio
79
+ self.epsilon = 1e-2
80
+ self.crop_ratio = crop_ratio
81
+ self.iter_num = iter_num
82
+ self.min_area_ratio = min_area_ratio
83
+
84
+ def __call__(self, results):
85
+ for i in range(self.iter_num):
86
+ results = self.random_crop_flip(results)
87
+
88
+ return results
89
+
90
+ def random_crop_flip(self, results):
91
+ image = results['image']
92
+ polygons = results['polys']
93
+ ignore_tags = results['ignore_tags']
94
+ if len(polygons) == 0:
95
+ return results
96
+
97
+ if np.random.random() >= self.crop_ratio:
98
+ return results
99
+
100
+ h, w, _ = image.shape
101
+ area = h * w
102
+ pad_h = int(h * self.pad_ratio)
103
+ pad_w = int(w * self.pad_ratio)
104
+ h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
105
+ pad_w)
106
+ if len(h_axis) == 0 or len(w_axis) == 0:
107
+ return results
108
+
109
+ attempt = 0
110
+ while attempt < 50:
111
+ attempt += 1
112
+ polys_keep = []
113
+ polys_new = []
114
+ ignore_tags_keep = []
115
+ ignore_tags_new = []
116
+ xx = np.random.choice(w_axis, size=2)
117
+ xmin = np.min(xx) - pad_w
118
+ xmax = np.max(xx) - pad_w
119
+ xmin = np.clip(xmin, 0, w - 1)
120
+ xmax = np.clip(xmax, 0, w - 1)
121
+ yy = np.random.choice(h_axis, size=2)
122
+ ymin = np.min(yy) - pad_h
123
+ ymax = np.max(yy) - pad_h
124
+ ymin = np.clip(ymin, 0, h - 1)
125
+ ymax = np.clip(ymax, 0, h - 1)
126
+ if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
127
+ # area too small
128
+ continue
129
+
130
+ pts = np.stack([[xmin, xmax, xmax, xmin],
131
+ [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
132
+ pp = Polygon(pts)
133
+ fail_flag = False
134
+ for polygon, ignore_tag in zip(polygons, ignore_tags):
135
+ ppi = Polygon(polygon.reshape(-1, 2))
136
+ ppiou, _ = poly_intersection(ppi, pp, buffer=0)
137
+ if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
138
+ np.abs(ppiou) > self.epsilon:
139
+ fail_flag = True
140
+ break
141
+ elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
142
+ polys_new.append(polygon)
143
+ ignore_tags_new.append(ignore_tag)
144
+ else:
145
+ polys_keep.append(polygon)
146
+ ignore_tags_keep.append(ignore_tag)
147
+
148
+ if fail_flag:
149
+ continue
150
+ else:
151
+ break
152
+
153
+ cropped = image[ymin:ymax, xmin:xmax, :]
154
+ select_type = np.random.randint(3)
155
+ if select_type == 0:
156
+ img = np.ascontiguousarray(cropped[:, ::-1])
157
+ elif select_type == 1:
158
+ img = np.ascontiguousarray(cropped[::-1, :])
159
+ else:
160
+ img = np.ascontiguousarray(cropped[::-1, ::-1])
161
+ image[ymin:ymax, xmin:xmax, :] = img
162
+ results['image'] = image
163
+
164
+ if len(polys_new) != 0:
165
+ height, width, _ = cropped.shape
166
+ if select_type == 0:
167
+ for idx, polygon in enumerate(polys_new):
168
+ poly = polygon.reshape(-1, 2)
169
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
170
+ polys_new[idx] = poly
171
+ elif select_type == 1:
172
+ for idx, polygon in enumerate(polys_new):
173
+ poly = polygon.reshape(-1, 2)
174
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
175
+ polys_new[idx] = poly
176
+ else:
177
+ for idx, polygon in enumerate(polys_new):
178
+ poly = polygon.reshape(-1, 2)
179
+ poly[:, 0] = width - poly[:, 0] + 2 * xmin
180
+ poly[:, 1] = height - poly[:, 1] + 2 * ymin
181
+ polys_new[idx] = poly
182
+ polygons = polys_keep + polys_new
183
+ ignore_tags = ignore_tags_keep + ignore_tags_new
184
+ results['polys'] = np.array(polygons)
185
+ results['ignore_tags'] = ignore_tags
186
+
187
+ return results
188
+
189
+ def generate_crop_target(self, image, all_polys, pad_h, pad_w):
190
+ """Generate crop target and make sure not to crop the polygon
191
+ instances.
192
+
193
+ Args:
194
+ image (ndarray): The image to be cropped.
195
+ all_polys (list[list[ndarray]]): All polygons including ground
196
+ truth polygons and ground truth ignored polygons.
197
+ pad_h (int): Padding length of height.
198
+ pad_w (int): Padding length of width.
199
+ Returns:
200
+ h_axis (ndarray): Vertical cropping range.
201
+ w_axis (ndarray): Horizontal cropping range.
202
+ """
203
+ h, w, _ = image.shape
204
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
205
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
206
+
207
+ text_polys = []
208
+ for polygon in all_polys:
209
+ rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
210
+ box = cv2.boxPoints(rect)
211
+ box = np.int0(box)
212
+ text_polys.append([box[0], box[1], box[2], box[3]])
213
+
214
+ polys = np.array(text_polys, dtype=np.int32)
215
+ for poly in polys:
216
+ poly = np.round(poly, decimals=0).astype(np.int32)
217
+ minx = np.min(poly[:, 0])
218
+ maxx = np.max(poly[:, 0])
219
+ w_array[minx + pad_w:maxx + pad_w] = 1
220
+ miny = np.min(poly[:, 1])
221
+ maxy = np.max(poly[:, 1])
222
+ h_array[miny + pad_h:maxy + pad_h] = 1
223
+
224
+ h_axis = np.where(h_array == 0)[0]
225
+ w_axis = np.where(w_array == 0)[0]
226
+ return h_axis, w_axis
227
+
228
+
229
+ class RandomCropPolyInstances:
230
+ """Randomly crop images and make sure to contain at least one intact
231
+ instance."""
232
+
233
+ def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
234
+ super().__init__()
235
+ self.crop_ratio = crop_ratio
236
+ self.min_side_ratio = min_side_ratio
237
+
238
+ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
239
+
240
+ assert isinstance(min_len, int)
241
+ assert len(valid_array) > min_len
242
+
243
+ start_array = valid_array.copy()
244
+ max_start = min(len(start_array) - min_len, max_start)
245
+ start_array[max_start:] = 0
246
+ start_array[0] = 1
247
+ diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
248
+ region_starts = np.where(diff_array < 0)[0]
249
+ region_ends = np.where(diff_array > 0)[0]
250
+ region_ind = np.random.randint(0, len(region_starts))
251
+ start = np.random.randint(region_starts[region_ind],
252
+ region_ends[region_ind])
253
+
254
+ end_array = valid_array.copy()
255
+ min_end = max(start + min_len, min_end)
256
+ end_array[:min_end] = 0
257
+ end_array[-1] = 1
258
+ diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
259
+ region_starts = np.where(diff_array < 0)[0]
260
+ region_ends = np.where(diff_array > 0)[0]
261
+ region_ind = np.random.randint(0, len(region_starts))
262
+ end = np.random.randint(region_starts[region_ind],
263
+ region_ends[region_ind])
264
+ return start, end
265
+
266
+ def sample_crop_box(self, img_size, results):
267
+ """Generate crop box and make sure not to crop the polygon instances.
268
+
269
+ Args:
270
+ img_size (tuple(int)): The image size (h, w).
271
+ results (dict): The results dict.
272
+ """
273
+
274
+ assert isinstance(img_size, tuple)
275
+ h, w = img_size[:2]
276
+
277
+ key_masks = results['polys']
278
+
279
+ x_valid_array = np.ones(w, dtype=np.int32)
280
+ y_valid_array = np.ones(h, dtype=np.int32)
281
+
282
+ selected_mask = key_masks[np.random.randint(0, len(key_masks))]
283
+ selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
284
+ max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
285
+ min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
286
+ max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
287
+ min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
288
+
289
+ for mask in key_masks:
290
+ mask = mask.reshape((-1, 2)).astype(np.int32)
291
+ clip_x = np.clip(mask[:, 0], 0, w - 1)
292
+ clip_y = np.clip(mask[:, 1], 0, h - 1)
293
+ min_x, max_x = np.min(clip_x), np.max(clip_x)
294
+ min_y, max_y = np.min(clip_y), np.max(clip_y)
295
+
296
+ x_valid_array[min_x - 2:max_x + 3] = 0
297
+ y_valid_array[min_y - 2:max_y + 3] = 0
298
+
299
+ min_w = int(w * self.min_side_ratio)
300
+ min_h = int(h * self.min_side_ratio)
301
+
302
+ x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
303
+ min_x_end)
304
+ y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
305
+ min_y_end)
306
+
307
+ return np.array([x1, y1, x2, y2])
308
+
309
+ def crop_img(self, img, bbox):
310
+ assert img.ndim == 3
311
+ h, w, _ = img.shape
312
+ assert 0 <= bbox[1] < bbox[3] <= h
313
+ assert 0 <= bbox[0] < bbox[2] <= w
314
+ return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
315
+
316
+ def __call__(self, results):
317
+ image = results['image']
318
+ polygons = results['polys']
319
+ ignore_tags = results['ignore_tags']
320
+ if len(polygons) < 1:
321
+ return results
322
+
323
+ if np.random.random_sample() < self.crop_ratio:
324
+
325
+ crop_box = self.sample_crop_box(image.shape, results)
326
+ img = self.crop_img(image, crop_box)
327
+ results['image'] = img
328
+ # crop and filter masks
329
+ x1, y1, x2, y2 = crop_box
330
+ w = max(x2 - x1, 1)
331
+ h = max(y2 - y1, 1)
332
+ polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
333
+ polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1
334
+
335
+ valid_masks_list = []
336
+ valid_tags_list = []
337
+ for ind, polygon in enumerate(polygons):
338
+ if (polygon[:, ::2] > -4).all() and (
339
+ polygon[:, ::2] < w + 4).all() and (
340
+ polygon[:, 1::2] > -4).all() and (
341
+ polygon[:, 1::2] < h + 4).all():
342
+ polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
343
+ polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
344
+ valid_masks_list.append(polygon)
345
+ valid_tags_list.append(ignore_tags[ind])
346
+
347
+ results['polys'] = np.array(valid_masks_list)
348
+ results['ignore_tags'] = valid_tags_list
349
+
350
+ return results
351
+
352
+ def __repr__(self):
353
+ repr_str = self.__class__.__name__
354
+ return repr_str
355
+
356
+
357
+ class RandomRotatePolyInstances:
358
+ def __init__(self,
359
+ rotate_ratio=0.5,
360
+ max_angle=10,
361
+ pad_with_fixed_color=False,
362
+ pad_value=(0, 0, 0),
363
+ **kwargs):
364
+ """Randomly rotate images and polygon masks.
365
+
366
+ Args:
367
+ rotate_ratio (float): The ratio of samples to operate rotation.
368
+ max_angle (int): The maximum rotation angle.
369
+ pad_with_fixed_color (bool): The flag for whether to pad the rotated
370
+ image with a fixed value. If set to False, the rotated image will
371
+ be padded onto a cropped patch of the image.
372
+ pad_value (tuple(int)): The color value for padding the rotated image.
373
+ """
374
+ self.rotate_ratio = rotate_ratio
375
+ self.max_angle = max_angle
376
+ self.pad_with_fixed_color = pad_with_fixed_color
377
+ self.pad_value = pad_value
378
+
379
+ def rotate(self, center, points, theta, center_shift=(0, 0)):
380
+ # rotate points.
381
+ (center_x, center_y) = center
382
+ center_y = -center_y
383
+ x, y = points[:, ::2], points[:, 1::2]
384
+ y = -y
385
+
386
+ theta = theta / 180 * math.pi
387
+ cos = math.cos(theta)
388
+ sin = math.sin(theta)
389
+
390
+ x = (x - center_x)
391
+ y = (y - center_y)
392
+
393
+ _x = center_x + x * cos - y * sin + center_shift[0]
394
+ _y = -(center_y + x * sin + y * cos) + center_shift[1]
395
+
396
+ points[:, ::2], points[:, 1::2] = _x, _y
397
+ return points
398
+
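rotate() works in image coordinates (y grows downward), hence the sign flips on y. A worked example with assumed values: rotating (2, 1) by 90 degrees about (1, 1) moves the point one pixel up:

    import numpy as np

    rot = RandomRotatePolyInstances()
    pts = np.array([[2.0, 1.0]])  # one point in the flattened (x, y, ...) layout
    print(rot.rotate((1.0, 1.0), pts.copy(), 90))  # -> [[1. 0.]]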
399
+ def cal_canvas_size(self, ori_size, degree):
400
+ assert isinstance(ori_size, tuple)
401
+ angle = degree * math.pi / 180.0
402
+ h, w = ori_size[:2]
403
+
404
+ cos = math.cos(angle)
405
+ sin = math.sin(angle)
406
+ canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
407
+ canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
408
+
409
+ canvas_size = (canvas_h, canvas_w)
410
+ return canvas_size
411
+
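cal_canvas_size computes the bounding box of the rotated image: canvas_h = w*|sin| + h*|cos| and canvas_w = w*|cos| + h*|sin|. A quick numeric check (sizes assumed):

    import math

    h, w, deg = 400, 300, 30
    a = math.radians(deg)
    print(int(w * abs(math.sin(a)) + h * abs(math.cos(a))),  # canvas_h, about 496
          int(w * abs(math.cos(a)) + h * abs(math.sin(a))))  # canvas_w, about 459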
412
+ def sample_angle(self, max_angle):
413
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
414
+ return angle
415
+
416
+ def rotate_img(self, img, angle, canvas_size):
417
+ h, w = img.shape[:2]
418
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
419
+ rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
420
+ rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
421
+
422
+ if self.pad_with_fixed_color:
423
+ target_img = cv2.warpAffine(
424
+ img,
425
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
426
+ flags=cv2.INTER_NEAREST,
427
+ borderValue=self.pad_value)
428
+ else:
429
+ mask = np.zeros_like(img)
430
+ (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
431
+ np.random.randint(0, w * 7 // 8))
432
+ img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
433
+ img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
434
+
435
+ mask = cv2.warpAffine(
436
+ mask,
437
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
438
+ borderValue=[1, 1, 1])
439
+ target_img = cv2.warpAffine(
440
+ img,
441
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
442
+ borderValue=[0, 0, 0])
443
+ target_img = target_img + img_cut * mask
444
+
445
+ return target_img
446
+
447
+ def __call__(self, results):
448
+ if np.random.random_sample() < self.rotate_ratio:
449
+ image = results['image']
450
+ polygons = results['polys']
451
+ h, w = image.shape[:2]
452
+
453
+ angle = self.sample_angle(self.max_angle)
454
+ canvas_size = self.cal_canvas_size((h, w), angle)
455
+ center_shift = (int((canvas_size[1] - w) / 2), int(
456
+ (canvas_size[0] - h) / 2))
457
+ image = self.rotate_img(image, angle, canvas_size)
458
+ results['image'] = image
459
+ # rotate polygons
460
+ rotated_masks = []
461
+ for mask in polygons:
462
+ rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
463
+ center_shift)
464
+ rotated_masks.append(rotated_mask)
465
+ results['polys'] = np.array(rotated_masks)
466
+
467
+ return results
468
+
469
+ def __repr__(self):
470
+ repr_str = self.__class__.__name__
471
+ return repr_str
472
+
473
+
474
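A quick standalone check of the rotate() math above (hypothetical values; not part of this commit). The double y-negation converts image coordinates (y grows downward) into math coordinates and back, so the result should agree with cv2's own rotation matrix:

import math
import cv2
import numpy as np

center, theta = (50.0, 40.0), 30.0      # assumed rotation center / angle
pt = np.array([70.0, 40.0])             # one hypothetical point (x, y)

cx, cy = center[0], -center[1]          # same steps as rotate()
x, y = pt[0] - cx, -pt[1] - cy
rad = theta / 180 * math.pi
rx = cx + x * math.cos(rad) - y * math.sin(rad)
ry = -(cy + x * math.sin(rad) + y * math.cos(rad))

m = cv2.getRotationMatrix2D(center, theta, 1)   # reference transform
assert np.allclose([rx, ry], m @ np.array([pt[0], pt[1], 1.0]))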
+class SquareResizePad:
+    def __init__(self,
+                 target_size,
+                 pad_ratio=0.6,
+                 pad_with_fixed_color=False,
+                 pad_value=(0, 0, 0),
+                 **kwargs):
+        """Resize or pad images to be square shape.
+
+        Args:
+            target_size (int): The target size of square shaped image.
+            pad_with_fixed_color (bool): The flag for whether to pad the
+                rescaled image with a fixed color. If set to False, the
+                rescaled image is padded with a patch cropped from the
+                original image.
+            pad_value (tuple(int)): The color value for padding rotated image.
+        """
+        assert isinstance(target_size, int)
+        assert isinstance(pad_ratio, float)
+        assert isinstance(pad_with_fixed_color, bool)
+        assert isinstance(pad_value, tuple)
+
+        self.target_size = target_size
+        self.pad_ratio = pad_ratio
+        self.pad_with_fixed_color = pad_with_fixed_color
+        self.pad_value = pad_value
+
+    def resize_img(self, img, keep_ratio=True):
+        h, w, _ = img.shape
+        if keep_ratio:
+            t_h = self.target_size if h >= w else int(h * self.target_size / w)
+            t_w = self.target_size if h <= w else int(w * self.target_size / h)
+        else:
+            t_h = t_w = self.target_size
+        img = cv2.resize(img, (t_w, t_h))
+        return img, (t_h, t_w)
+
+    def square_pad(self, img):
+        h, w = img.shape[:2]
+        if h == w:
+            return img, (0, 0)
+        pad_size = max(h, w)
+        if self.pad_with_fixed_color:
+            expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
+            expand_img[:] = self.pad_value
+        else:
+            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
+                              np.random.randint(0, w * 7 // 8))
+            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+            expand_img = cv2.resize(img_cut, (pad_size, pad_size))
+        if h > w:
+            y0, x0 = 0, (h - w) // 2
+        else:
+            y0, x0 = (w - h) // 2, 0
+        expand_img[y0:y0 + h, x0:x0 + w] = img
+        offset = (x0, y0)
+
+        return expand_img, offset
+
+    def square_pad_mask(self, points, offset):
+        x0, y0 = offset
+        pad_points = points.copy()
+        pad_points[::2] = pad_points[::2] + x0
+        pad_points[1::2] = pad_points[1::2] + y0
+        return pad_points
+
+    def __call__(self, results):
+        image = results['image']
+        polygons = results['polys']
+        h, w = image.shape[:2]
+
+        if np.random.random_sample() < self.pad_ratio:
+            image, out_size = self.resize_img(image, keep_ratio=True)
+            image, offset = self.square_pad(image)
+        else:
+            image, out_size = self.resize_img(image, keep_ratio=False)
+            offset = (0, 0)
+        results['image'] = image
+        try:
+            polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
+                1] / w + offset[0]
+            polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
+                0] / h + offset[1]
+        except Exception:
+            pass
+        results['polys'] = polygons
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        return repr_str
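A minimal usage sketch for SquareResizePad (synthetic inputs; not part of this commit). With pad_ratio=1.0 the keep-ratio branch always runs, so a 100x200 image is rescaled to 160x320, padded to a 320x320 square, and the polygons are rescaled and shifted by the same offset:

import numpy as np

op = SquareResizePad(target_size=320, pad_ratio=1.0)
data = {
    'image': np.zeros((100, 200, 3), dtype=np.uint8),
    'polys': np.array([[[10., 10.], [50., 10.], [50., 30.], [10., 30.]]]),
}
out = op(data)
print(out['image'].shape)   # (320, 320, 3)
print(out['polys'][0, 0])   # corner scaled by 320/200 and 160/100, plus offset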
ppocr/data/imaug/fce_targets.py ADDED
@@ -0,0 +1,666 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py
+"""
+
+import cv2
+import numpy as np
+from numpy.fft import fft
+from numpy.linalg import norm
+import sys
+
+
+def vector_slope(vec):
+    assert len(vec) == 2
+    return abs(vec[1] / (vec[0] + 1e-8))
+
+
+class FCENetTargets:
+    """Generate the ground truth targets of FCENet: Fourier Contour Embedding
+    for Arbitrary-Shaped Text Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+
+    Args:
+        fourier_degree (int): The maximum Fourier transform degree k.
+        resample_step (float): The step size for resampling the text center
+            line (TCL). It's better not to exceed half of the minimum width.
+        center_region_shrink_ratio (float): The shrink ratio of text center
+            region.
+        level_size_divisors (tuple(int)): The downsample ratio on each level.
+        level_proportion_range (tuple(tuple(int))): The range of text sizes
+            assigned to each level.
+    """
+
+    def __init__(self,
+                 fourier_degree=5,
+                 resample_step=4.0,
+                 center_region_shrink_ratio=0.3,
+                 level_size_divisors=(8, 16, 32),
+                 level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
+                 orientation_thr=2.0,
+                 **kwargs):
+
+        super().__init__()
+        assert isinstance(level_size_divisors, tuple)
+        assert isinstance(level_proportion_range, tuple)
+        assert len(level_size_divisors) == len(level_proportion_range)
+        self.fourier_degree = fourier_degree
+        self.resample_step = resample_step
+        self.center_region_shrink_ratio = center_region_shrink_ratio
+        self.level_size_divisors = level_size_divisors
+        self.level_proportion_range = level_proportion_range
+
+        self.orientation_thr = orientation_thr
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(
+            np.clip(
+                np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+    def resample_line(self, line, n):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+
+            while current_edge_ind + 1 < len(
+                    length_cumsum) and current_line_len >= length_cumsum[
+                        current_edge_ind + 1]:
+                current_edge_ind += 1
+
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                      - line[current_edge_ind]
+                                                      ) * end_shift_ratio
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+
+        return resampled_line
+
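A small sketch of resample_line (hypothetical polyline; not part of this commit). The endpoints are always kept and the interior samples are spaced at roughly equal arc-length steps, so asking for n points returns n + 1:

import numpy as np

targets = FCENetTargets()
line = np.array([[0., 0.], [10., 0.], [10., 10.]])   # L-shape, total length 20
resampled = targets.resample_line(line, 4)
print(resampled.shape)               # (5, 2)
print(resampled[0], resampled[-1])   # [0. 0.] [10. 10.]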
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        head_inds, tail_inds = self.find_head_tail(points,
+                                                   self.orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        sideline_mean_shift = np.mean(
+            sideline1, axis=0) - np.mean(
+                sideline2, axis=0)
+
+        if sideline_mean_shift[1] > 0:
+            top_sideline, bot_sideline = sideline2, sideline1
+        else:
+            top_sideline, bot_sideline = sideline1, sideline2
+
+        return head_edge, tail_edge, top_sideline, bot_sideline
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(
+                    self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                        adjacent_edge_vec[1])
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(
+                    pad_points[1:] - poly_center, axis=-1),
+                norm(
+                    pad_points[:-1] - poly_center, axis=-1))
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                (x - 0.5) / 0.5, 2.) / 2)
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+                    score) - 1)] * gaussian * 0.3
+
+            head_start, tail_increment = np.unravel_index(
+                score_matrix.argmax(), score_matrix.shape)
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if vector_slope(points[1] - points[0]) + vector_slope(
+                    points[3] - points[2]) < vector_slope(points[
+                        2] - points[1]) + vector_slope(points[0] - points[
+                            3]):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                    0]] - points[vertical_edge_inds[1][1]])
+            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+                                                         [1]])
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def resample_sidelines(self, sideline1, sideline2, resample_step):
+        """Resample two sidelines to be of the same points number according to
+        step size.
+
+        Args:
+            sideline1 (ndarray): The points composing a sideline of a text
+                polygon.
+            sideline2 (ndarray): The points composing another sideline of a
+                text polygon.
+            resample_step (float): The resampled step size.
+
+        Returns:
+            resampled_line1 (ndarray): The resampled line 1.
+            resampled_line2 (ndarray): The resampled line 2.
+        """
+
+        assert sideline1.ndim == sideline2.ndim == 2
+        assert sideline1.shape[1] == sideline2.shape[1] == 2
+        assert sideline1.shape[0] >= 2
+        assert sideline2.shape[0] >= 2
+        assert isinstance(resample_step, float)
+
+        length1 = sum([
+            norm(sideline1[i + 1] - sideline1[i])
+            for i in range(len(sideline1) - 1)
+        ])
+        length2 = sum([
+            norm(sideline2[i + 1] - sideline2[i])
+            for i in range(len(sideline2) - 1)
+        ])
+
+        total_length = (length1 + length2) / 2
+        resample_point_num = max(int(float(total_length) / resample_step), 1)
+
+        resampled_line1 = self.resample_line(sideline1, resample_point_num)
+        resampled_line2 = self.resample_line(sideline2, resample_point_num)
+
+        return resampled_line1, resampled_line2
+
+    def generate_center_region_mask(self, img_size, text_polys):
+        """Generate text center region mask.
+
+        Args:
+            img_size (tuple): The image size of (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            center_region_mask (ndarray): The text center region mask.
+        """
+
+        assert isinstance(img_size, tuple)
+        # assert check_argument.is_2dlist(text_polys)
+
+        h, w = img_size
+
+        center_region_mask = np.zeros((h, w), np.uint8)
+
+        center_region_boxes = []
+        for poly in text_polys:
+            # assert len(poly) == 1
+            polygon_points = poly.reshape(-1, 2)
+            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
+            resampled_top_line, resampled_bot_line = self.resample_sidelines(
+                top_line, bot_line, self.resample_step)
+            resampled_bot_line = resampled_bot_line[::-1]
+            if len(resampled_top_line) != len(resampled_bot_line):
+                continue
+            center_line = (resampled_top_line + resampled_bot_line) / 2
+
+            line_head_shrink_len = norm(resampled_top_line[0] -
+                                        resampled_bot_line[0]) / 4.0
+            line_tail_shrink_len = norm(resampled_top_line[-1] -
+                                        resampled_bot_line[-1]) / 4.0
+            head_shrink_num = int(line_head_shrink_len // self.resample_step)
+            tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
+            if len(center_line) > head_shrink_num + tail_shrink_num + 2:
+                center_line = center_line[head_shrink_num:len(center_line) -
+                                          tail_shrink_num]
+                resampled_top_line = resampled_top_line[head_shrink_num:len(
+                    resampled_top_line) - tail_shrink_num]
+                resampled_bot_line = resampled_bot_line[head_shrink_num:len(
+                    resampled_bot_line) - tail_shrink_num]
+
+            for i in range(0, len(center_line) - 1):
+                tl = center_line[i] + (resampled_top_line[i] - center_line[i]
+                                       ) * self.center_region_shrink_ratio
+                tr = center_line[i + 1] + (resampled_top_line[i + 1] -
+                                           center_line[i + 1]
+                                           ) * self.center_region_shrink_ratio
+                br = center_line[i + 1] + (resampled_bot_line[i + 1] -
+                                           center_line[i + 1]
+                                           ) * self.center_region_shrink_ratio
+                bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
+                                       ) * self.center_region_shrink_ratio
+                current_center_box = np.vstack([tl, tr, br,
+                                                bl]).astype(np.int32)
+                center_region_boxes.append(current_center_box)
+
+        cv2.fillPoly(center_region_mask, center_region_boxes, 1)
+        return center_region_mask
+
+    def resample_polygon(self, polygon, n=400):
+        """Resample one polygon with n points on its boundary.
+
+        Args:
+            polygon (list[float]): The input polygon.
+            n (int): The number of resampled points.
+        Returns:
+            resampled_polygon (list[float]): The resampled polygon.
+        """
+        length = []
+
+        for i in range(len(polygon)):
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+            length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5)
+
+        total_length = sum(length)
+        n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
+        n_on_each_line = n_on_each_line.astype(np.int32)
+        new_polygon = []
+
+        for i in range(len(polygon)):
+            num = n_on_each_line[i]
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+
+            if num == 0:
+                continue
+
+            dxdy = (p2 - p1) / num
+            for j in range(num):
+                point = p1 + dxdy * j
+                new_polygon.append(point)
+
+        return np.array(new_polygon)
+
+    def normalize_polygon(self, polygon):
+        """Normalize one polygon so that its start point is at right most.
+
+        Args:
+            polygon (list[float]): The origin polygon.
+        Returns:
+            new_polygon (list[float]): The polygon with start point at right.
+        """
+        temp_polygon = polygon - polygon.mean(axis=0)
+        x = np.abs(temp_polygon[:, 0])
+        y = temp_polygon[:, 1]
+        index_x = np.argsort(x)
+        index_y = np.argmin(y[index_x[:8]])
+        index = index_x[index_y]
+        new_polygon = np.concatenate([polygon[index:], polygon[:index]])
+        return new_polygon
+
+    def poly2fourier(self, polygon, fourier_degree):
+        """Perform Fourier transformation to generate Fourier coefficients ck
+        from polygon.
+
+        Args:
+            polygon (ndarray): An input polygon.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            c (ndarray(complex)): Fourier coefficients.
+        """
+        points = polygon[:, 0] + polygon[:, 1] * 1j
+        c_fft = fft(points) / len(points)
+        c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
+        return c
+
+    def clockwise(self, c, fourier_degree):
+        """Make sure the polygon reconstructed from Fourier coefficients c is
+        in the clockwise direction.
+
+        Args:
+            c (ndarray(complex)): The Fourier coefficients.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            c (ndarray(complex)): The Fourier coefficients in clockwise order.
+        """
+        if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
+            return c
+        elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
+            return c[::-1]
+        else:
+            if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
+                return c
+            else:
+                return c[::-1]
+
+    def cal_fourier_signature(self, polygon, fourier_degree):
+        """Calculate Fourier signature from input polygon.
+
+        Args:
+            polygon (ndarray): The input polygon.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            fourier_signature (ndarray): An array shaped (2k+1, 2) containing
+                the real part and imaginary part of 2k+1 Fourier coefficients.
+        """
+        resampled_polygon = self.resample_polygon(polygon)
+        resampled_polygon = self.normalize_polygon(resampled_polygon)
+
+        fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
+        fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
+
+        real_part = np.real(fourier_coeff).reshape((-1, 1))
+        image_part = np.imag(fourier_coeff).reshape((-1, 1))
+        fourier_signature = np.hstack([real_part, image_part])
+
+        return fourier_signature
+
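A round-trip sketch for the Fourier signature (hypothetical square; not part of this commit). cal_fourier_signature returns the real/imaginary parts of the 2k+1 coefficients c_-k ... c_k, and the contour is approximately recovered by the truncated inverse series sum_i c_i * exp(2*pi*j*i*t):

import numpy as np

targets = FCENetTargets(fourier_degree=5)
square = np.array([[0., 0.], [100., 0.], [100., 100.], [0., 100.]])
k = targets.fourier_degree
sig = targets.cal_fourier_signature(square, k)
print(sig.shape)   # (11, 2)

c = sig[:, 0] + 1j * sig[:, 1]
t = np.arange(400) / 400.0
recon = sum(c[i + k] * np.exp(2j * np.pi * i * t) for i in range(-k, k + 1))
# recon.real / recon.imag now trace a smoothed version of the square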
+    def generate_fourier_maps(self, img_size, text_polys):
+        """Generate Fourier coefficient maps.
+
+        Args:
+            img_size (tuple): The image size of (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            fourier_real_map (ndarray): The Fourier coefficient real part maps.
+            fourier_image_map (ndarray): The Fourier coefficient imaginary
+                part maps.
+        """
+
+        assert isinstance(img_size, tuple)
+
+        h, w = img_size
+        k = self.fourier_degree
+        real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+        imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+
+        for poly in text_polys:
+            mask = np.zeros((h, w), dtype=np.uint8)
+            polygon = np.array(poly).reshape((1, -1, 2))
+            cv2.fillPoly(mask, polygon.astype(np.int32), 1)
+            fourier_coeff = self.cal_fourier_signature(polygon[0], k)
+            for i in range(-k, k + 1):
+                if i != 0:
+                    real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
+                        1 - mask) * real_map[i + k, :, :]
+                    imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
+                        1 - mask) * imag_map[i + k, :, :]
+                else:
+                    yx = np.argwhere(mask > 0.5)
+                    k_ind = np.ones((len(yx)), dtype=np.int64) * k
+                    y, x = yx[:, 0], yx[:, 1]
+                    real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
+                    imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
+
+        return real_map, imag_map
+
+    def generate_text_region_mask(self, img_size, text_polys):
+        """Generate text center region mask and geometry attribute maps.
+
+        Args:
+            img_size (tuple): The image size (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            text_region_mask (ndarray): The text region mask.
+        """
+
+        assert isinstance(img_size, tuple)
+
+        h, w = img_size
+        text_region_mask = np.zeros((h, w), dtype=np.uint8)
+
+        for poly in text_polys:
+            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
+            cv2.fillPoly(text_region_mask, polygon, 1)
+
+        return text_region_mask
+
+    def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
+        """Generate effective mask by setting the ineffective regions to 0 and
+        effective regions to 1.
+
+        Args:
+            mask_size (tuple): The mask size.
+            polygons_ignore (list[ndarray]): The list of ignored text
+                polygons.
+
+        Returns:
+            mask (ndarray): The effective mask of (height, width).
+        """
+
+        mask = np.ones(mask_size, dtype=np.uint8)
+
+        for poly in polygons_ignore:
+            instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
+            cv2.fillPoly(mask, instance, 0)
+
+        return mask
+
+    def generate_level_targets(self, img_size, text_polys, ignore_polys):
+        """Generate ground truth target on each level.
+
+        Args:
+            img_size (list[int]): Shape of input image.
+            text_polys (list[list[ndarray]]): A list of ground truth polygons.
+            ignore_polys (list[list[ndarray]]): A list of ignored polygons.
+        Returns:
+            level_maps (list(ndarray)): A list of ground target on each level.
+        """
+        h, w = img_size
+        lv_size_divs = self.level_size_divisors
+        lv_proportion_range = self.level_proportion_range
+        lv_text_polys = [[] for i in range(len(lv_size_divs))]
+        lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
+        level_maps = []
+        for poly in text_polys:
+            # np.int was removed in recent NumPy; use np.int32 explicitly
+            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_text_polys[ind].append(poly / lv_size_divs[ind])
+
+        for ignore_poly in ignore_polys:
+            polygon = np.array(ignore_poly, dtype=np.int32).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_ignore_polys[ind].append(
+                        ignore_poly / lv_size_divs[ind])
+
+        for ind, size_divisor in enumerate(lv_size_divs):
+            current_level_maps = []
+            level_img_size = (h // size_divisor, w // size_divisor)
+
+            text_region = self.generate_text_region_mask(
+                level_img_size, lv_text_polys[ind])[None]
+            current_level_maps.append(text_region)
+
+            center_region = self.generate_center_region_mask(
+                level_img_size, lv_text_polys[ind])[None]
+            current_level_maps.append(center_region)
+
+            effective_mask = self.generate_effective_mask(
+                level_img_size, lv_ignore_polys[ind])[None]
+            current_level_maps.append(effective_mask)
+
+            fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
+                level_img_size, lv_text_polys[ind])
+            current_level_maps.append(fourier_real_map)
+            current_level_maps.append(fourier_image_maps)
+
+            level_maps.append(np.concatenate(current_level_maps))
+
+        return level_maps
+
+    def generate_targets(self, results):
+        """Generate the ground truth targets for FCENet.
+
+        Args:
+            results (dict): The input result dictionary.
+
+        Returns:
+            results (dict): The output result dictionary.
+        """
+
+        assert isinstance(results, dict)
+        image = results['image']
+        polygons = results['polys']
+        ignore_tags = results['ignore_tags']
+        h, w, _ = image.shape
+
+        polygon_masks = []
+        polygon_masks_ignore = []
+        for tag, polygon in zip(ignore_tags, polygons):
+            if tag is True:
+                polygon_masks_ignore.append(polygon)
+            else:
+                polygon_masks.append(polygon)
+
+        level_maps = self.generate_level_targets((h, w), polygon_masks,
+                                                 polygon_masks_ignore)
+
+        mapping = {
+            'p3_maps': level_maps[0],
+            'p4_maps': level_maps[1],
+            'p5_maps': level_maps[2]
+        }
+        for key, value in mapping.items():
+            results[key] = value
+
+        return results
+
+    def __call__(self, results):
+        results = self.generate_targets(results)
+        return results
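An end-to-end usage sketch (synthetic sample; not part of this commit). Each pX_maps stacks one text-region, one center-region and one effective-mask channel plus (2k+1) real and (2k+1) imaginary Fourier channels, i.e. 3 + 2 * (2 * 5 + 1) = 25 channels at the default fourier_degree=5:

import numpy as np

gen = FCENetTargets()
data = {
    'image': np.zeros((640, 640, 3), dtype=np.uint8),
    'polys': [np.array([[100., 100.], [300., 100.], [300., 200.], [100., 200.]])],
    'ignore_tags': [False],
}
out = gen(data)
print(out['p3_maps'].shape)   # (25, 80, 80): 640 / 8 for the level with divisor 8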
ppocr/data/imaug/iaa_augment.py ADDED
@@ -0,0 +1,105 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import imgaug
+import imgaug.augmenters as iaa
+
+
+class AugmenterBuilder(object):
+    def __init__(self):
+        pass
+
+    def build(self, args, root=True):
+        if args is None or len(args) == 0:
+            return None
+        elif isinstance(args, list):
+            if root:
+                sequence = [self.build(value, root=False) for value in args]
+                return iaa.Sequential(sequence)
+            else:
+                return getattr(iaa, args[0])(
+                    *[self.to_tuple_if_list(a) for a in args[1:]])
+        elif isinstance(args, dict):
+            cls = getattr(iaa, args['type'])
+            return cls(**{
+                k: self.to_tuple_if_list(v)
+                for k, v in args['args'].items()
+            })
+        else:
+            raise RuntimeError('unknown augmenter arg: ' + str(args))
+
+    def to_tuple_if_list(self, obj):
+        if isinstance(obj, list):
+            return tuple(obj)
+        return obj
+
+
+class IaaAugment():
+    def __init__(self, augmenter_args=None, **kwargs):
+        if augmenter_args is None:
+            augmenter_args = [{
+                'type': 'Fliplr',
+                'args': {
+                    'p': 0.5
+                }
+            }, {
+                'type': 'Affine',
+                'args': {
+                    'rotate': [-10, 10]
+                }
+            }, {
+                'type': 'Resize',
+                'args': {
+                    'size': [0.5, 3]
+                }
+            }]
+        self.augmenter = AugmenterBuilder().build(augmenter_args)
+
+    def __call__(self, data):
+        image = data['image']
+        shape = image.shape
+
+        if self.augmenter:
+            aug = self.augmenter.to_deterministic()
+            data['image'] = aug.augment_image(image)
+            data = self.may_augment_annotation(aug, data, shape)
+        return data
+
+    def may_augment_annotation(self, aug, data, shape):
+        if aug is None:
+            return data
+
+        line_polys = []
+        for poly in data['polys']:
+            new_poly = self.may_augment_poly(aug, shape, poly)
+            line_polys.append(new_poly)
+        data['polys'] = np.array(line_polys)
+        return data
+
+    def may_augment_poly(self, aug, img_shape, poly):
+        keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
+        keypoints = aug.augment_keypoints(
+            [imgaug.KeypointsOnImage(
+                keypoints, shape=img_shape)])[0].keypoints
+        poly = [(p.x, p.y) for p in keypoints]
+        return poly
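A configuration sketch (not part of this commit): augmenter_args names imgaug augmenters directly, so a custom pipeline is just a list of {'type', 'args'} dicts; lists inside 'args' become tuples, e.g. a (min, max) rotation range:

import numpy as np

aug = IaaAugment(augmenter_args=[
    {'type': 'Fliplr', 'args': {'p': 0.5}},
    {'type': 'Affine', 'args': {'rotate': [-15, 15]}},
])
data = {
    'image': np.zeros((100, 100, 3), dtype=np.uint8),
    'polys': np.array([[[10., 10.], [90., 10.], [90., 40.], [10., 40.]]]),
}
data = aug(data)   # image and polys are transformed with the same parameters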
ppocr/data/imaug/label_ops.py ADDED
@@ -0,0 +1,1505 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import copy
+import numpy as np
+import string
+from shapely.geometry import LineString, Point, Polygon
+import json
+from random import sample
+
+from ppocr.utils.logging import get_logger
+from ppocr.data.imaug.vqa.augment import order_by_tbyx
+
+
+class ClsLabelEncode(object):
+    def __init__(self, label_list, **kwargs):
+        self.label_list = label_list
+
+    def __call__(self, data):
+        label = data['label']
+        if label not in self.label_list:
+            return None
+        label = self.label_list.index(label)
+        data['label'] = label
+        return data
+
+
+class DetLabelEncode(object):
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        if len(boxes) == 0:
+            return None
+        boxes = self.expand_points_num(boxes)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=bool)
+
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
+        return data
+
+    def order_points_clockwise(self, pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
+        diff = np.diff(np.array(tmp), axis=1)
+        rect[1] = tmp[np.argmin(diff)]
+        rect[3] = tmp[np.argmax(diff)]
+        return rect
+
+    def expand_points_num(self, boxes):
+        max_points_num = 0
+        for box in boxes:
+            if len(box) > max_points_num:
+                max_points_num = len(box)
+        ex_boxes = []
+        for box in boxes:
+            ex_box = box + [box[-1]] * (max_points_num - len(box))
+            ex_boxes.append(ex_box)
+        return ex_boxes
+
+
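A sketch of the label format DetLabelEncode expects (one hypothetical annotation; not part of this commit): 'label' is a JSON list of boxes, and '*' / '###' transcriptions become ignore_tags:

import json
import numpy as np

data = {'label': json.dumps([
    {'points': [[10, 10], [80, 10], [80, 30], [10, 30]], 'transcription': 'text'},
    {'points': [[5, 50], [60, 50], [60, 70], [5, 70]], 'transcription': '###'},
])}
out = DetLabelEncode()(data)
print(out['polys'].shape)   # (2, 4, 2)
print(out['ignore_tags'])   # [False  True]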
+class BaseRecLabelEncode(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 lower=False):
+
+        self.max_text_len = max_text_length
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        self.lower = lower
+
+        if character_dict_path is None:
+            logger = get_logger()
+            logger.warning(
+                "The character_dict_path is None, model can only recognize numbers and lowercase letters"
+            )
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            dict_character = list(self.character_str)
+            self.lower = True
+        else:
+            self.character_str = []
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode('utf-8').strip("\n").strip("\r\n")
+                    self.character_str.append(line)
+            if use_space_char:
+                self.character_str.append(" ")
+            dict_character = list(self.character_str)
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+    def add_special_char(self, dict_character):
+        return dict_character
+
+    def encode(self, text):
+        """convert text-label into text-index.
+        input:
+            text: text labels of each image. [batch_size]
+
+        output:
+            text: concatenated text index for CTCLoss.
+                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
+            length: length of each text. [batch_size]
+        """
+        if len(text) == 0 or len(text) > self.max_text_len:
+            return None
+        if self.lower:
+            text = text.lower()
+        text_list = []
+        for char in text:
+            if char not in self.dict:
+                # logger = get_logger()
+                # logger.warning('{} is not in dict'.format(char))
+                continue
+            text_list.append(self.dict[char])
+        if len(text_list) == 0:
+            return None
+        return text_list
+
+
+class CTCLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(CTCLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        data['length'] = np.array(len(text))
+        text = text + [0] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+
+        label = [0] * len(self.character)
+        for x in text:
+            label[x] += 1
+        data['label_ace'] = np.array(label)
+        return data
+
+    def add_special_char(self, dict_character):
+        dict_character = ['blank'] + dict_character
+        return dict_character
+
+
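A worked sketch with the built-in fallback dictionary (no dict file; not part of this commit): CTCLabelEncode prepends a 'blank' class at index 0 and pads with it, so 'ocr' encodes as the shifted character indices:

enc = CTCLabelEncode(max_text_length=8)   # warns and falls back to 0-9a-z
out = enc({'label': 'ocr'})
print(out['length'])   # 3
print(out['label'])    # [25 13 28  0  0  0  0  0]  ('o'=25, 'c'=13, 'r'=28)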
+class E2ELabelEncodeTest(BaseRecLabelEncode):
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(E2ELabelEncodeTest, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def __call__(self, data):
+        import json
+        padnum = len(self.dict)
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=bool)
+        data['polys'] = boxes
+        data['ignore_tags'] = txt_tags
+        temp_texts = []
+        for text in txts:
+            text = text.lower()
+            text = self.encode(text)
+            if text is None:
+                return None
+            text = text + [padnum] * (self.max_text_len - len(text)
+                                      )  # pad with the dict size (36 by default)
+            temp_texts.append(text)
+        data['texts'] = np.array(temp_texts)
+        return data
+
+
+class E2ELabelEncodeTrain(object):
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        import json
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=bool)
+
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
+        return data
+
+
+class KieLabelEncode(object):
+    def __init__(self,
+                 character_dict_path,
+                 class_path,
+                 norm=10,
+                 directed=False,
+                 **kwargs):
+        super(KieLabelEncode, self).__init__()
+        self.dict = dict({'': 0})
+        self.label2classid_map = dict()
+        with open(character_dict_path, 'r', encoding='utf-8') as fr:
+            idx = 1
+            for line in fr:
+                char = line.strip()
+                self.dict[char] = idx
+                idx += 1
+        with open(class_path, "r") as fin:
+            lines = fin.readlines()
+            for idx, line in enumerate(lines):
+                line = line.strip("\n")
+                self.label2classid_map[line] = idx
+        self.norm = norm
+        self.directed = directed
+
+    def compute_relation(self, boxes):
+        """Compute relation between every two boxes."""
+        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+        dxs = (x1s[:, 0][None] - x1s) / self.norm
+        dys = (y1s[:, 0][None] - y1s) / self.norm
+        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+        whs = ws / hs + np.zeros_like(xhhs)
+        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+        return relations, bboxes
+
+    def pad_text_indices(self, text_inds):
+        """Pad text index to same length."""
+        max_len = 300
+        recoder_len = max([len(text_ind) for text_ind in text_inds])
+        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+        for idx, text_ind in enumerate(text_inds):
+            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+        return padded_text_inds, recoder_len
+
+    def list_to_numpy(self, ann_infos):
+        """Convert bboxes, relations, texts and labels to ndarray."""
+        boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
+        boxes = np.array(boxes, np.int32)
+        relations, bboxes = self.compute_relation(boxes)
+
+        labels = ann_infos.get('labels', None)
+        if labels is not None:
+            labels = np.array(labels, np.int32)
+            edges = ann_infos.get('edges', None)
+            if edges is not None:
+                labels = labels[:, None]
+                edges = np.array(edges)
+                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+                if self.directed:
+                    edges = (edges & labels == 1).astype(np.int32)
+                np.fill_diagonal(edges, -1)
+                labels = np.concatenate([labels, edges], -1)
+        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
+        max_num = 300
+        temp_bboxes = np.zeros([max_num, 4])
+        h, _ = bboxes.shape
+        temp_bboxes[:h, :] = bboxes
+
+        temp_relations = np.zeros([max_num, max_num, 5])
+        temp_relations[:h, :h, :] = relations
+
+        temp_padded_text_inds = np.zeros([max_num, max_num])
+        temp_padded_text_inds[:h, :] = padded_text_inds
+
+        temp_labels = np.zeros([max_num, max_num])
+        temp_labels[:h, :h + 1] = labels
+
+        tag = np.array([h, recoder_len])
+        return dict(
+            image=ann_infos['image'],
+            points=temp_bboxes,
+            relations=temp_relations,
+            texts=temp_padded_text_inds,
+            labels=temp_labels,
+            tag=tag)
+
+    def convert_canonical(self, points_x, points_y):
+
+        assert len(points_x) == 4
+        assert len(points_y) == 4
+
+        points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+        polygon = Polygon([(p.x, p.y) for p in points])
+        min_x, min_y, _, _ = polygon.bounds
+        points_to_lefttop = [
+            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+        ]
+        distances = np.array([line.length for line in points_to_lefttop])
+        sort_dist_idx = np.argsort(distances)
+        lefttop_idx = sort_dist_idx[0]
+
+        if lefttop_idx == 0:
+            point_orders = [0, 1, 2, 3]
+        elif lefttop_idx == 1:
+            point_orders = [1, 2, 3, 0]
+        elif lefttop_idx == 2:
+            point_orders = [2, 3, 0, 1]
+        else:
+            point_orders = [3, 0, 1, 2]
+
+        sorted_points_x = [points_x[i] for i in point_orders]
+        sorted_points_y = [points_y[j] for j in point_orders]
+
+        return sorted_points_x, sorted_points_y
+
+    def sort_vertex(self, points_x, points_y):
+
+        assert len(points_x) == 4
+        assert len(points_y) == 4
+
+        x = np.array(points_x)
+        y = np.array(points_y)
+        center_x = np.sum(x) * 0.25
+        center_y = np.sum(y) * 0.25
+
+        x_arr = np.array(x - center_x)
+        y_arr = np.array(y - center_y)
+
+        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+        sort_idx = np.argsort(angle)
+
+        sorted_points_x, sorted_points_y = [], []
+        for i in range(4):
+            sorted_points_x.append(points_x[sort_idx[i]])
+            sorted_points_y.append(points_y[sort_idx[i]])
+
+        return self.convert_canonical(sorted_points_x, sorted_points_y)
+
+    def __call__(self, data):
+        import json
+        label = data['label']
+        annotations = json.loads(label)
+        boxes, texts, text_inds, labels, edges = [], [], [], [], []
+        for ann in annotations:
+            box = ann['points']
+            x_list = [box[i][0] for i in range(4)]
+            y_list = [box[i][1] for i in range(4)]
+            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
+            sorted_box = []
+            for x, y in zip(sorted_x_list, sorted_y_list):
+                sorted_box.append(x)
+                sorted_box.append(y)
+            boxes.append(sorted_box)
+            text = ann['transcription']
+            texts.append(ann['transcription'])
+            text_ind = [self.dict[c] for c in text if c in self.dict]
+            text_inds.append(text_ind)
+            if 'label' in ann.keys():
+                labels.append(self.label2classid_map[ann['label']])
+            elif 'key_cls' in ann.keys():
+                labels.append(ann['key_cls'])
+            else:
+                raise ValueError(
+                    "Cannot find 'key_cls' in ann.keys(); please check your training annotation."
+                )
+            edges.append(ann.get('edge', 0))
+        ann_infos = dict(
+            image=data['image'],
+            points=boxes,
+            texts=texts,
+            text_inds=text_inds,
+            edges=edges,
+            labels=labels)
+
+        return self.list_to_numpy(ann_infos)
+
+
+class AttnLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(AttnLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+                                                               - len(text) - 2)
+        data['label'] = np.array(text)
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
+
+
+class RFLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(RFLLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def encode_cnt(self, text):
+        cnt_label = [0.0] * len(self.character)
+        for char_ in text:
+            cnt_label[char_] += 1
+        return np.array(cnt_label)
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        cnt_label = self.encode_cnt(text)
+        data['length'] = np.array(len(text))
+        text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+                                                               - len(text) - 2)
+        if len(text) != self.max_text_len:
+            return None
+        data['label'] = np.array(text)
+        data['cnt_label'] = cnt_label
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
+
+
+class SEEDLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.padding = "padding"
+        self.end_str = "eos"
+        self.unknown = "unknown"
+        dict_character = dict_character + [
+            self.end_str, self.padding, self.unknown
+        ]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        data['length'] = np.array(len(text)) + 1  # include eos
+        text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
+            self.max_text_len - len(text) - 1)
+        data['label'] = np.array(text)
+        return data
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length=25,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(SRNLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def add_special_char(self, dict_character):
+        dict_character = dict_character + [self.beg_str, self.end_str]
+        return dict_character
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        char_num = len(self.character)
+        if text is None:
+            return None
+        if len(text) > self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        text = text + [char_num - 1] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
+
+
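For contrast with the CTC sketch above, the attention-style layout (same fallback dictionary; not part of this commit): AttnLabelEncode wraps the indices with sos/eos, so the padded label is [sos, chars..., eos, 0, ...] with eos = len(character) - 1:

enc = AttnLabelEncode(max_text_length=8)
out = enc({'label': 'ocr'})
print(out['label'])   # [ 0 25 13 28 37  0  0  0]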
+ class TableLabelEncode(AttnLabelEncode):
626
+ """ Convert between text-label and text-index """
627
+
628
+ def __init__(self,
629
+ max_text_length,
630
+ character_dict_path,
631
+ replace_empty_cell_token=False,
632
+ merge_no_span_structure=False,
633
+ learn_empty_box=False,
634
+ loc_reg_num=4,
635
+ **kwargs):
636
+ self.max_text_len = max_text_length
637
+ self.lower = False
638
+ self.learn_empty_box = learn_empty_box
639
+ self.merge_no_span_structure = merge_no_span_structure
640
+ self.replace_empty_cell_token = replace_empty_cell_token
641
+
642
+ dict_character = []
643
+ with open(character_dict_path, "rb") as fin:
644
+ lines = fin.readlines()
645
+ for line in lines:
646
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
647
+ dict_character.append(line)
648
+
649
+ if self.merge_no_span_structure:
650
+ if "<td></td>" not in dict_character:
651
+ dict_character.append("<td></td>")
652
+ if "<td>" in dict_character:
653
+ dict_character.remove("<td>")
654
+
655
+ dict_character = self.add_special_char(dict_character)
656
+ self.dict = {}
657
+ for i, char in enumerate(dict_character):
658
+ self.dict[char] = i
659
+ self.idx2char = {v: k for k, v in self.dict.items()}
660
+
661
+ self.character = dict_character
662
+ self.loc_reg_num = loc_reg_num
663
+ self.pad_idx = self.dict[self.beg_str]
664
+ self.start_idx = self.dict[self.beg_str]
665
+ self.end_idx = self.dict[self.end_str]
666
+
667
+ self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
668
+ self.empty_bbox_token_dict = {
669
+ "[]": '<eb></eb>',
670
+ "[' ']": '<eb1></eb1>',
671
+ "['<b>', ' ', '</b>']": '<eb2></eb2>',
672
+ "['\\u2028', '\\u2028']": '<eb3></eb3>',
673
+ "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
674
+ "['<b>', '</b>']": '<eb5></eb5>',
675
+ "['<i>', ' ', '</i>']": '<eb6></eb6>',
676
+ "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
677
+ "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
678
+ "['<i>', '</i>']": '<eb9></eb9>',
679
+ "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
680
+ '<eb10></eb10>',
681
+ }
682
+
683
+ @property
684
+ def _max_text_len(self):
685
+ return self.max_text_len + 2
686
+
687
+ def __call__(self, data):
688
+ cells = data['cells']
689
+ structure = data['structure']
690
+ if self.merge_no_span_structure:
691
+ structure = self._merge_no_span_structure(structure)
692
+ if self.replace_empty_cell_token:
693
+ structure = self._replace_empty_cell_token(structure, cells)
694
+ # remove empty token and add " " to span token
695
+ new_structure = []
696
+ for token in structure:
697
+ if token != '':
698
+ if 'span' in token and token[0] != ' ':
699
+ token = ' ' + token
700
+ new_structure.append(token)
701
+ # encode structure
702
+ structure = self.encode(new_structure)
703
+ if structure is None:
704
+ return None
705
+
706
+ structure = [self.start_idx] + structure + [self.end_idx
707
+ ] # add sos and eos
708
+ structure = structure + [self.pad_idx] * (self._max_text_len -
709
+ len(structure)) # pad
710
+ structure = np.array(structure)
711
+ data['structure'] = structure
712
+
713
+ if len(structure) > self._max_text_len:
714
+ return None
715
+
716
+ # encode box
717
+ bboxes = np.zeros(
718
+ (self._max_text_len, self.loc_reg_num), dtype=np.float32)
719
+ bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
720
+
721
+ bbox_idx = 0
722
+
723
+ for i, token in enumerate(structure):
724
+ if self.idx2char[token] in self.td_token:
725
+ if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
726
+ 'tokens']) > 0:
727
+ bbox = cells[bbox_idx]['bbox'].copy()
728
+ bbox = np.array(bbox, dtype=np.float32).reshape(-1)
729
+ bboxes[i] = bbox
730
+ bbox_masks[i] = 1.0
731
+ if self.learn_empty_box:
732
+ bbox_masks[i] = 1.0
733
+ bbox_idx += 1
734
+ data['bboxes'] = bboxes
735
+ data['bbox_masks'] = bbox_masks
736
+ return data
737
+
738
+ def _merge_no_span_structure(self, structure):
739
+ """
740
+ This code is adapted from:
741
+ https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
742
+ """
743
+ new_structure = []
744
+ i = 0
745
+ while i < len(structure):
746
+ token = structure[i]
747
+ if token == '<td>':
748
+ token = '<td></td>'
749
+ i += 1
750
+ new_structure.append(token)
751
+ i += 1
752
+ return new_structure
753
+
754
+ def _replace_empty_cell_token(self, token_list, cells):
755
+ """
756
+ This function is adapted from:
757
+ https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
758
+ """
759
+
760
+ bbox_idx = 0
761
+ add_empty_bbox_token_list = []
762
+ for token in token_list:
763
+ if token in ['<td></td>', '<td', '<td>']:
764
+ if 'bbox' not in cells[bbox_idx].keys():
765
+ content = str(cells[bbox_idx]['tokens'])
766
+ token = self.empty_bbox_token_dict[content]
767
+ add_empty_bbox_token_list.append(token)
768
+ bbox_idx += 1
769
+ else:
770
+ add_empty_bbox_token_list.append(token)
771
+ return add_empty_bbox_token_list
772
+
773
+
774
+ class TableMasterLabelEncode(TableLabelEncode):
775
+ """ Convert between text-label and text-index """
776
+
777
+ def __init__(self,
778
+ max_text_length,
779
+ character_dict_path,
780
+ replace_empty_cell_token=False,
781
+ merge_no_span_structure=False,
782
+ learn_empty_box=False,
783
+ loc_reg_num=4,
784
+ **kwargs):
785
+ super(TableMasterLabelEncode, self).__init__(
786
+ max_text_length, character_dict_path, replace_empty_cell_token,
787
+ merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
788
+ self.pad_idx = self.dict[self.pad_str]
789
+ self.unknown_idx = self.dict[self.unknown_str]
790
+
791
+ @property
792
+ def _max_text_len(self):
793
+ return self.max_text_len
794
+
795
+ def add_special_char(self, dict_character):
796
+ self.beg_str = '<SOS>'
797
+ self.end_str = '<EOS>'
798
+ self.unknown_str = '<UKN>'
799
+ self.pad_str = '<PAD>'
800
+ dict_character = dict_character
801
+ dict_character = dict_character + [
802
+ self.unknown_str, self.beg_str, self.end_str, self.pad_str
803
+ ]
804
+ return dict_character
805
+
806
+
807
+ class TableBoxEncode(object):
808
+ def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
809
+ assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
810
+ self.in_box_format = in_box_format
811
+ self.out_box_format = out_box_format
812
+
813
+ def __call__(self, data):
814
+ img_height, img_width = data['image'].shape[:2]
815
+ bboxes = data['bboxes']
816
+ if self.in_box_format != self.out_box_format:
817
+ if self.out_box_format == 'xywh':
818
+ if self.in_box_format == 'xyxyxyxy':
819
+ bboxes = self.xyxyxyxy2xywh(bboxes)
820
+ elif self.in_box_format == 'xyxy':
821
+ bboxes = self.xyxy2xywh(bboxes)
822
+
823
+ bboxes[:, 0::2] /= img_width
824
+ bboxes[:, 1::2] /= img_height
825
+ data['bboxes'] = bboxes
826
+ return data
827
+
828
+ def xyxyxyxy2xywh(self, bboxes):
829
+ new_bboxes = np.zeros([len(bboxes), 4])
830
+ new_bboxes[:, 0] = bboxes[:, 0::2].min(axis=1) # x1
831
+ new_bboxes[:, 1] = bboxes[:, 1::2].min(axis=1) # y1
832
+ new_bboxes[:, 2] = bboxes[:, 0::2].max(axis=1) - new_bboxes[:, 0] # w
833
+ new_bboxes[:, 3] = bboxes[:, 1::2].max(axis=1) - new_bboxes[:, 1] # h
834
+ return new_bboxes
835
+
836
+ def xyxy2xywh(self, bboxes):
837
+ new_bboxes = np.empty_like(bboxes)
838
+ new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
839
+ new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
840
+ new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
841
+ new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
842
+ return new_bboxes
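# Quick numeric check of TableBoxEncode (illustrative, not part of this
# commit): one xyxy box on a 100x200 image, converted to normalized xywh.
import numpy as np
enc = TableBoxEncode(in_box_format='xyxy', out_box_format='xywh')
data = {'image': np.zeros((100, 200, 3), np.uint8),
        'bboxes': np.array([[30., 40., 50., 50.]])}
print(enc(data)['bboxes'])  # [[0.2, 0.45, 0.1, 0.1]] -- center/size over w, h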
843
+
844
+
845
+ class SARLabelEncode(BaseRecLabelEncode):
846
+ """ Convert between text-label and text-index """
847
+
848
+ def __init__(self,
849
+ max_text_length,
850
+ character_dict_path=None,
851
+ use_space_char=False,
852
+ **kwargs):
853
+ super(SARLabelEncode, self).__init__(
854
+ max_text_length, character_dict_path, use_space_char)
855
+
856
+ def add_special_char(self, dict_character):
857
+ beg_end_str = "<BOS/EOS>"
858
+ unknown_str = "<UKN>"
859
+ padding_str = "<PAD>"
860
+ dict_character = dict_character + [unknown_str]
861
+ self.unknown_idx = len(dict_character) - 1
862
+ dict_character = dict_character + [beg_end_str]
863
+ self.start_idx = len(dict_character) - 1
864
+ self.end_idx = len(dict_character) - 1
865
+ dict_character = dict_character + [padding_str]
866
+ self.padding_idx = len(dict_character) - 1
867
+
868
+ return dict_character
869
+
870
+ def __call__(self, data):
871
+ text = data['label']
872
+ text = self.encode(text)
873
+ if text is None:
874
+ return None
875
+ if len(text) >= self.max_text_len - 1:
876
+ return None
877
+ data['length'] = np.array(len(text))
878
+ target = [self.start_idx] + text + [self.end_idx]
879
+ padded_text = [self.padding_idx for _ in range(self.max_text_len)]
880
+
881
+ padded_text[:len(target)] = target
882
+ data['label'] = np.array(padded_text)
883
+ return data
884
+
885
+ def get_ignored_tokens(self):
886
+ return [self.padding_idx]
887
+
888
+
889
+ class PRENLabelEncode(BaseRecLabelEncode):
890
+ def __init__(self,
891
+ max_text_length,
892
+ character_dict_path,
893
+ use_space_char=False,
894
+ **kwargs):
895
+ super(PRENLabelEncode, self).__init__(
896
+ max_text_length, character_dict_path, use_space_char)
897
+
898
+ def add_special_char(self, dict_character):
899
+ padding_str = '<PAD>' # 0
900
+ end_str = '<EOS>' # 1
901
+ unknown_str = '<UNK>' # 2
902
+
903
+ dict_character = [padding_str, end_str, unknown_str] + dict_character
904
+ self.padding_idx = 0
905
+ self.end_idx = 1
906
+ self.unknown_idx = 2
907
+
908
+ return dict_character
909
+
910
+ def encode(self, text):
911
+ if len(text) == 0 or len(text) >= self.max_text_len:
912
+ return None
913
+ if self.lower:
914
+ text = text.lower()
915
+ text_list = []
916
+ for char in text:
917
+ if char not in self.dict:
918
+ text_list.append(self.unknown_idx)
919
+ else:
920
+ text_list.append(self.dict[char])
921
+ text_list.append(self.end_idx)
922
+ if len(text_list) < self.max_text_len:
923
+ text_list += [self.padding_idx] * (
924
+ self.max_text_len - len(text_list))
925
+ return text_list
926
+
927
+ def __call__(self, data):
928
+ text = data['label']
929
+ encoded_text = self.encode(text)
930
+ if encoded_text is None:
931
+ return None
932
+ data['label'] = np.array(encoded_text)
933
+ return data
934
+
935
+
936
+ class VQATokenLabelEncode(object):
937
+ """
938
+ Label encode for NLP VQA methods
939
+ """
940
+
941
+ def __init__(self,
942
+ class_path,
943
+ contains_re=False,
944
+ add_special_ids=False,
945
+ algorithm='LayoutXLM',
946
+ use_textline_bbox_info=True,
947
+ order_method=None,
948
+ infer_mode=False,
949
+ ocr_engine=None,
950
+ **kwargs):
951
+ super(VQATokenLabelEncode, self).__init__()
952
+ from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
953
+ from ppocr.utils.utility import load_vqa_bio_label_maps
954
+ tokenizer_dict = {
955
+ 'LayoutXLM': {
956
+ 'class': LayoutXLMTokenizer,
957
+ 'pretrained_model': 'layoutxlm-base-uncased'
958
+ },
959
+ 'LayoutLM': {
960
+ 'class': LayoutLMTokenizer,
961
+ 'pretrained_model': 'layoutlm-base-uncased'
962
+ },
963
+ 'LayoutLMv2': {
964
+ 'class': LayoutLMv2Tokenizer,
965
+ 'pretrained_model': 'layoutlmv2-base-uncased'
966
+ }
967
+ }
968
+ self.contains_re = contains_re
969
+ tokenizer_config = tokenizer_dict[algorithm]
970
+ self.tokenizer = tokenizer_config['class'].from_pretrained(
971
+ tokenizer_config['pretrained_model'])
972
+ self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
973
+ self.add_special_ids = add_special_ids
974
+ self.infer_mode = infer_mode
975
+ self.ocr_engine = ocr_engine
976
+ self.use_textline_bbox_info = use_textline_bbox_info
977
+ self.order_method = order_method
978
+ assert self.order_method in [None, "tb-yx"]
979
+
980
+ def split_bbox(self, bbox, text, tokenizer):
981
+ words = text.split()
982
+ token_bboxes = []
983
+ curr_word_idx = 0
984
+ x1, y1, x2, y2 = bbox
985
+ unit_w = (x2 - x1) / len(text)
986
+ for idx, word in enumerate(words):
987
+ curr_w = len(word) * unit_w
988
+ word_bbox = [x1, y1, x1 + curr_w, y2]
989
+ token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
990
+ x1 += (len(word) + 1) * unit_w
991
+ return token_bboxes
992
+
993
+ def filter_empty_contents(self, ocr_info):
994
+ """
995
+ find out the empty texts and remove the links
996
+ """
997
+ new_ocr_info = []
998
+ empty_index = []
999
+ for idx, info in enumerate(ocr_info):
1000
+ if len(info["transcription"]) > 0:
1001
+ new_ocr_info.append(copy.deepcopy(info))
1002
+ else:
1003
+ empty_index.append(info["id"])
1004
+
1005
+ for idx, info in enumerate(new_ocr_info):
1006
+ new_link = []
1007
+ for link in info["linking"]:
1008
+ if link[0] in empty_index or link[1] in empty_index:
1009
+ continue
1010
+ new_link.append(link)
1011
+ new_ocr_info[idx]["linking"] = new_link
1012
+ return new_ocr_info
1013
+
1014
+ def __call__(self, data):
1015
+ # load bbox and label info
1016
+ ocr_info = self._load_ocr_info(data)
1017
+
1018
+ for idx in range(len(ocr_info)):
1019
+ if "bbox" not in ocr_info[idx]:
1020
+ ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
1021
+ "points"])
1022
+
1023
+ if self.order_method == "tb-yx":
1024
+ ocr_info = order_by_tbyx(ocr_info)
1025
+
1026
+ # for re
1027
+ train_re = self.contains_re and not self.infer_mode
1028
+ if train_re:
1029
+ ocr_info = self.filter_empty_contents(ocr_info)
1030
+
1031
+ height, width, _ = data['image'].shape
1032
+
1033
+ words_list = []
1034
+ bbox_list = []
1035
+ input_ids_list = []
1036
+ token_type_ids_list = []
1037
+ segment_offset_id = []
1038
+ gt_label_list = []
1039
+
1040
+ entities = []
1041
+
1042
+ if train_re:
1043
+ relations = []
1044
+ id2label = {}
1045
+ entity_id_to_index_map = {}
1046
+ empty_entity = set()
1047
+
1048
+ data['ocr_info'] = copy.deepcopy(ocr_info)
1049
+
1050
+ for info in ocr_info:
1051
+ text = info["transcription"]
1052
+ if len(text) <= 0:
1053
+ continue
1054
+ if train_re:
1055
+ # for re
1056
+ if len(text) == 0:
1057
+ empty_entity.add(info["id"])
1058
+ continue
1059
+ id2label[info["id"]] = info["label"]
1060
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
1061
+ # smooth_box
1062
+ info["bbox"] = self.trans_poly_to_bbox(info["points"])
1063
+
1064
+ encode_res = self.tokenizer.encode(
1065
+ text,
1066
+ pad_to_max_seq_len=False,
1067
+ return_attention_mask=True,
1068
+ return_token_type_ids=True)
1069
+
1070
+ if not self.add_special_ids:
1071
+ # TODO: use tok.all_special_ids to remove
1072
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
1073
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
1074
+ -1]
1075
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
1076
+ -1]
1077
+
1078
+ if self.use_textline_bbox_info:
1079
+ bbox = [info["bbox"]] * len(encode_res["input_ids"])
1080
+ else:
1081
+ bbox = self.split_bbox(info["bbox"], info["transcription"],
1082
+ self.tokenizer)
1083
+ if len(bbox) <= 0:
1084
+ continue
1085
+ bbox = self._smooth_box(bbox, height, width)
1086
+ if self.add_special_ids:
1087
+ bbox.insert(0, [0, 0, 0, 0])
1088
+ bbox.append([0, 0, 0, 0])
1089
+
1090
+ # parse label
1091
+ if not self.infer_mode:
1092
+ label = info['label']
1093
+ gt_label = self._parse_label(label, encode_res)
1094
+
1095
+ # construct entities for re
1096
+ if train_re:
1097
+ if gt_label[0] != self.label2id_map["O"]:
1098
+ entity_id_to_index_map[info["id"]] = len(entities)
1099
+ label = label.upper()
1100
+ entities.append({
1101
+ "start": len(input_ids_list),
1102
+ "end":
1103
+ len(input_ids_list) + len(encode_res["input_ids"]),
1104
+ "label": label.upper(),
1105
+ })
1106
+ else:
1107
+ entities.append({
1108
+ "start": len(input_ids_list),
1109
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
1110
+ "label": 'O',
1111
+ })
1112
+ input_ids_list.extend(encode_res["input_ids"])
1113
+ token_type_ids_list.extend(encode_res["token_type_ids"])
1114
+ bbox_list.extend(bbox)
1115
+ words_list.append(text)
1116
+ segment_offset_id.append(len(input_ids_list))
1117
+ if not self.infer_mode:
1118
+ gt_label_list.extend(gt_label)
1119
+
1120
+ data['input_ids'] = input_ids_list
1121
+ data['token_type_ids'] = token_type_ids_list
1122
+ data['bbox'] = bbox_list
1123
+ data['attention_mask'] = [1] * len(input_ids_list)
1124
+ data['labels'] = gt_label_list
1125
+ data['segment_offset_id'] = segment_offset_id
1126
+ data['tokenizer_params'] = dict(
1127
+ padding_side=self.tokenizer.padding_side,
1128
+ pad_token_type_id=self.tokenizer.pad_token_type_id,
1129
+ pad_token_id=self.tokenizer.pad_token_id)
1130
+ data['entities'] = entities
1131
+
1132
+ if train_re:
1133
+ data['relations'] = relations
1134
+ data['id2label'] = id2label
1135
+ data['empty_entity'] = empty_entity
1136
+ data['entity_id_to_index_map'] = entity_id_to_index_map
1137
+ return data
1138
+
1139
+ def trans_poly_to_bbox(self, poly):
1140
+ x1 = int(np.min([p[0] for p in poly]))
1141
+ x2 = int(np.max([p[0] for p in poly]))
1142
+ y1 = int(np.min([p[1] for p in poly]))
1143
+ y2 = int(np.max([p[1] for p in poly]))
1144
+ return [x1, y1, x2, y2]
1145
+
1146
+ def _load_ocr_info(self, data):
1147
+ if self.infer_mode:
1148
+ ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
1149
+ ocr_info = []
1150
+ for res in ocr_result:
1151
+ ocr_info.append({
1152
+ "transcription": res[1][0],
1153
+ "bbox": self.trans_poly_to_bbox(res[0]),
1154
+ "points": res[0],
1155
+ })
1156
+ return ocr_info
1157
+ else:
1158
+ info = data['label']
1159
+ # read text info
1160
+ info_dict = json.loads(info)
1161
+ return info_dict
1162
+
1163
+ def _smooth_box(self, bboxes, height, width):
1164
+ bboxes = np.array(bboxes)
1165
+ bboxes[:, 0] = bboxes[:, 0] * 1000 / width
1166
+ bboxes[:, 2] = bboxes[:, 2] * 1000 / width
1167
+ bboxes[:, 1] = bboxes[:, 1] * 1000 / height
1168
+ bboxes[:, 3] = bboxes[:, 3] * 1000 / height
1169
+ bboxes = bboxes.astype("int64").tolist()
1170
+ return bboxes
1171
+
1172
+ def _parse_label(self, label, encode_res):
1173
+ gt_label = []
1174
+ if label.lower() in ["other", "others", "ignore"]:
1175
+ gt_label.extend([0] * len(encode_res["input_ids"]))
1176
+ else:
1177
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
1178
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
1179
+ (len(encode_res["input_ids"]) - 1))
1180
+ return gt_label
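# Worked example of the BIO tagging in _parse_label (illustrative; the id
# map below is hypothetical -- real ids come from load_vqa_bio_label_maps):
label2id_map = {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2}
input_ids = [101, 102, 103, 104]  # 4 sub-word tokens for one text line
label = 'question'
gt_label = [label2id_map[('b-' + label).upper()]]
gt_label += [label2id_map[('i-' + label).upper()]] * (len(input_ids) - 1)
print(gt_label)  # [1, 2, 2, 2] -- one B- tag, then I- tags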
1181
+
1182
+
1183
+ class MultiLabelEncode(BaseRecLabelEncode):
1184
+ def __init__(self,
1185
+ max_text_length,
1186
+ character_dict_path=None,
1187
+ use_space_char=False,
1188
+ **kwargs):
1189
+ super(MultiLabelEncode, self).__init__(
1190
+ max_text_length, character_dict_path, use_space_char)
1191
+
1192
+ self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
1193
+ use_space_char, **kwargs)
1194
+ self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
1195
+ use_space_char, **kwargs)
1196
+
1197
+ def __call__(self, data):
1198
+ data_ctc = copy.deepcopy(data)
1199
+ data_sar = copy.deepcopy(data)
1200
+ data_out = dict()
1201
+ data_out['img_path'] = data.get('img_path', None)
1202
+ data_out['image'] = data['image']
1203
+ ctc = self.ctc_encode.__call__(data_ctc)
1204
+ sar = self.sar_encode.__call__(data_sar)
1205
+ if ctc is None or sar is None:
1206
+ return None
1207
+ data_out['label_ctc'] = ctc['label']
1208
+ data_out['label_sar'] = sar['label']
1209
+ data_out['length'] = ctc['length']
1210
+ return data_out
1211
+
1212
+
1213
+ class NRTRLabelEncode(BaseRecLabelEncode):
1214
+ """ Convert between text-label and text-index """
1215
+
1216
+ def __init__(self,
1217
+ max_text_length,
1218
+ character_dict_path=None,
1219
+ use_space_char=False,
1220
+ **kwargs):
1221
+
1222
+ super(NRTRLabelEncode, self).__init__(
1223
+ max_text_length, character_dict_path, use_space_char)
1224
+
1225
+ def __call__(self, data):
1226
+ text = data['label']
1227
+ text = self.encode(text)
1228
+ if text is None:
1229
+ return None
1230
+ if len(text) >= self.max_text_len - 1:
1231
+ return None
1232
+ data['length'] = np.array(len(text))
1233
+ text.insert(0, 2)
1234
+ text.append(3)
1235
+ text = text + [0] * (self.max_text_len - len(text))
1236
+ data['label'] = np.array(text)
1237
+ return data
1238
+
1239
+ def add_special_char(self, dict_character):
1240
+ dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
1241
+ return dict_character
1242
+
1243
+
1244
+ class ViTSTRLabelEncode(BaseRecLabelEncode):
1245
+ """ Convert between text-label and text-index """
1246
+
1247
+ def __init__(self,
1248
+ max_text_length,
1249
+ character_dict_path=None,
1250
+ use_space_char=False,
1251
+ ignore_index=0,
1252
+ **kwargs):
1253
+
1254
+ super(ViTSTRLabelEncode, self).__init__(
1255
+ max_text_length, character_dict_path, use_space_char)
1256
+ self.ignore_index = ignore_index
1257
+
1258
+ def __call__(self, data):
1259
+ text = data['label']
1260
+ text = self.encode(text)
1261
+ if text is None:
1262
+ return None
1263
+ if len(text) >= self.max_text_len:
1264
+ return None
1265
+ data['length'] = np.array(len(text))
1266
+ text.insert(0, self.ignore_index)
1267
+ text.append(1)
1268
+ text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
1269
+ data['label'] = np.array(text)
1270
+ return data
1271
+
1272
+ def add_special_char(self, dict_character):
1273
+ dict_character = ['<s>', '</s>'] + dict_character
1274
+ return dict_character
1275
+
1276
+
1277
+ class ABINetLabelEncode(BaseRecLabelEncode):
1278
+ """ Convert between text-label and text-index """
1279
+
1280
+ def __init__(self,
1281
+ max_text_length,
1282
+ character_dict_path=None,
1283
+ use_space_char=False,
1284
+ ignore_index=100,
1285
+ **kwargs):
1286
+
1287
+ super(ABINetLabelEncode, self).__init__(
1288
+ max_text_length, character_dict_path, use_space_char)
1289
+ self.ignore_index = ignore_index
1290
+
1291
+ def __call__(self, data):
1292
+ text = data['label']
1293
+ text = self.encode(text)
1294
+ if text is None:
1295
+ return None
1296
+ if len(text) >= self.max_text_len:
1297
+ return None
1298
+ data['length'] = np.array(len(text))
1299
+ text.append(0)
1300
+ text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
1301
+ data['label'] = np.array(text)
1302
+ return data
1303
+
1304
+ def add_special_char(self, dict_character):
1305
+ dict_character = ['</s>'] + dict_character
1306
+ return dict_character
1307
+
1308
+
1309
+ class SRLabelEncode(BaseRecLabelEncode):
1310
+ def __init__(self,
1311
+ max_text_length,
1312
+ character_dict_path=None,
1313
+ use_space_char=False,
1314
+ **kwargs):
1315
+ super(SRLabelEncode, self).__init__(max_text_length,
1316
+ character_dict_path, use_space_char)
1317
+ self.dic = {}
1318
+ with open(character_dict_path, 'r') as fin:
1319
+ for line in fin.readlines():
1320
+ line = line.strip()
1321
+ character, sequence = line.split()
1322
+ self.dic[character] = sequence
1323
+ english_stroke_alphabet = '0123456789'
1324
+ self.english_stroke_dict = {}
1325
+ for index in range(len(english_stroke_alphabet)):
1326
+ self.english_stroke_dict[english_stroke_alphabet[index]] = index
1327
+
1328
+ def encode(self, label):
1329
+ stroke_sequence = ''
1330
+ for character in label:
1331
+ if character not in self.dic:
1332
+ continue
1333
+ else:
1334
+ stroke_sequence += self.dic[character]
1335
+ stroke_sequence += '0'
1336
+ label = stroke_sequence
1337
+
1338
+ length = len(label)
1339
+
1340
+ input_tensor = np.zeros(self.max_text_len).astype("int64")
1341
+ for j in range(length - 1):
1342
+ input_tensor[j + 1] = self.english_stroke_dict[label[j]]
1343
+
1344
+ return length, input_tensor
1345
+
1346
+ def __call__(self, data):
1347
+ text = data['label']
1348
+ length, input_tensor = self.encode(text)
1349
+
1350
+ data["length"] = length
1351
+ data["input_tensor"] = input_tensor
1352
+ if text is None:
1353
+ return None
1354
+ return data
1355
+
1356
+
1357
+ class SPINLabelEncode(AttnLabelEncode):
1358
+ """ Convert between text-label and text-index """
1359
+
1360
+ def __init__(self,
1361
+ max_text_length,
1362
+ character_dict_path=None,
1363
+ use_space_char=False,
1364
+ lower=True,
1365
+ **kwargs):
1366
+ super(SPINLabelEncode, self).__init__(
1367
+ max_text_length, character_dict_path, use_space_char)
1368
+ self.lower = lower
1369
+
1370
+ def add_special_char(self, dict_character):
1371
+ self.beg_str = "sos"
1372
+ self.end_str = "eos"
1373
+ dict_character = [self.beg_str] + [self.end_str] + dict_character
1374
+ return dict_character
1375
+
1376
+ def __call__(self, data):
1377
+ text = data['label']
1378
+ text = self.encode(text)
1379
+ if text is None:
1380
+ return None
1381
+ if len(text) > self.max_text_len:
1382
+ return None
1383
+ data['length'] = np.array(len(text))
1384
+ target = [0] + text + [1]
1385
+ padded_text = [0 for _ in range(self.max_text_len + 2)]
1386
+
1387
+ padded_text[:len(target)] = target
1388
+ data['label'] = np.array(padded_text)
1389
+ return data
1390
+
1391
+
1392
+ class VLLabelEncode(BaseRecLabelEncode):
1393
+ """ Convert between text-label and text-index """
1394
+
1395
+ def __init__(self,
1396
+ max_text_length,
1397
+ character_dict_path=None,
1398
+ use_space_char=False,
1399
+ **kwargs):
1400
+ super(VLLabelEncode, self).__init__(max_text_length,
1401
+ character_dict_path, use_space_char)
1402
+ self.dict = {}
1403
+ for i, char in enumerate(self.character):
1404
+ self.dict[char] = i
1405
+
1406
+ def __call__(self, data):
1407
+ text = data['label'] # original string
1408
+ # generate occluded text
1409
+ len_str = len(text)
1410
+ if len_str <= 0:
1411
+ return None
1412
+ change_num = 1
1413
+ order = list(range(len_str))
1414
+ change_id = sample(order, change_num)[0]
1415
+ label_sub = text[change_id]
1416
+ if change_id == (len_str - 1):
1417
+ label_res = text[:change_id]
1418
+ elif change_id == 0:
1419
+ label_res = text[1:]
1420
+ else:
1421
+ label_res = text[:change_id] + text[change_id + 1:]
1422
+
1423
+ data['label_res'] = label_res # remaining string
1424
+ data['label_sub'] = label_sub # occluded character
1425
+ data['label_id'] = change_id # character index
1426
+ # encode label
1427
+ text = self.encode(text)
1428
+ if text is None:
1429
+ return None
1430
+ text = [i + 1 for i in text]
1431
+ data['length'] = np.array(len(text))
1432
+ text = text + [0] * (self.max_text_len - len(text))
1433
+ data['label'] = np.array(text)
1434
+ label_res = self.encode(label_res)
1435
+ label_sub = self.encode(label_sub)
1436
+ if label_res is None:
1437
+ label_res = []
1438
+ else:
1439
+ label_res = [i + 1 for i in label_res]
1440
+ if label_sub is None:
1441
+ label_sub = []
1442
+ else:
1443
+ label_sub = [i + 1 for i in label_sub]
1444
+ data['length_res'] = np.array(len(label_res))
1445
+ data['length_sub'] = np.array(len(label_sub))
1446
+ label_res = label_res + [0] * (self.max_text_len - len(label_res))
1447
+ label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
1448
+ data['label_res'] = np.array(label_res)
1449
+ data['label_sub'] = np.array(label_sub)
1450
+ return data
1451
+
1452
+
1453
+ class CTLabelEncode(object):
1454
+ def __init__(self, **kwargs):
1455
+ pass
1456
+
1457
+ def __call__(self, data):
1458
+ label = data['label']
1459
+
1460
+ label = json.loads(label)
1461
+ nBox = len(label)
1462
+ boxes, txts = [], []
1463
+ for bno in range(0, nBox):
1464
+ box = label[bno]['points']
1465
+ box = np.array(box)
1466
+
1467
+ boxes.append(box)
1468
+ txt = label[bno]['transcription']
1469
+ txts.append(txt)
1470
+
1471
+ if len(boxes) == 0:
1472
+ return None
1473
+
1474
+ data['polys'] = boxes
1475
+ data['texts'] = txts
1476
+ return data
1477
+
1478
+
1479
+ class CANLabelEncode(BaseRecLabelEncode):
1480
+ def __init__(self,
1481
+ character_dict_path,
1482
+ max_text_length=100,
1483
+ use_space_char=False,
1484
+ lower=True,
1485
+ **kwargs):
1486
+ super(CANLabelEncode, self).__init__(
1487
+ max_text_length, character_dict_path, use_space_char, lower)
1488
+
1489
+ def encode(self, text_seq):
1490
+ text_seq_encoded = []
1491
+ for text in text_seq:
1492
+ if text not in self.character:
1493
+ continue
1494
+ text_seq_encoded.append(self.dict.get(text))
1495
+ if len(text_seq_encoded) == 0:
1496
+ return None
1497
+ return text_seq_encoded
1498
+
1499
+ def __call__(self, data):
1500
+ label = data['label']
1501
+ if isinstance(label, str):
1502
+ label = label.strip().split()
1503
+ label.append(self.end_str)
1504
+ data['label'] = self.encode(label)
1505
+ return data
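# Usage sketch for the encoders above (illustrative, not part of this
# commit). Each encoder consumes and returns the shared `data` dict;
# returning None tells the loader to drop the sample. character_dict_path=None
# is assumed to make the base class fall back to its default lower-case
# alphanumeric charset.
import numpy as np
encoder = SARLabelEncode(max_text_length=25)
data = {'image': np.zeros((32, 100, 3), np.uint8), 'label': 'hello'}
data = encoder(data)
if data is not None:
    print(data['length'], data['label'].shape)  # 5 (25,)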
ppocr/data/imaug/make_border_map.py ADDED
@@ -0,0 +1,173 @@
1
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This code is adapted from:
16
+ https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
17
+ """
18
+
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+ from __future__ import print_function
22
+ from __future__ import unicode_literals
23
+
24
+ import numpy as np
25
+ import cv2
26
+
27
+ np.seterr(divide='ignore', invalid='ignore')
28
+ import pyclipper
29
+ from shapely.geometry import Polygon
30
+ import sys
31
+ import warnings
32
+
33
+ warnings.simplefilter("ignore")
34
+
35
+ __all__ = ['MakeBorderMap']
36
+
37
+
38
+ class MakeBorderMap(object):
39
+ def __init__(self,
40
+ shrink_ratio=0.4,
41
+ thresh_min=0.3,
42
+ thresh_max=0.7,
43
+ **kwargs):
44
+ self.shrink_ratio = shrink_ratio
45
+ self.thresh_min = thresh_min
46
+ self.thresh_max = thresh_max
47
+
48
+ def __call__(self, data):
49
+
50
+ img = data['image']
51
+ text_polys = data['polys']
52
+ ignore_tags = data['ignore_tags']
53
+
54
+ canvas = np.zeros(img.shape[:2], dtype=np.float32)
55
+ mask = np.zeros(img.shape[:2], dtype=np.float32)
56
+
57
+ for i in range(len(text_polys)):
58
+ if ignore_tags[i]:
59
+ continue
60
+ self.draw_border_map(text_polys[i], canvas, mask=mask)
61
+ canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
62
+
63
+ data['threshold_map'] = canvas
64
+ data['threshold_mask'] = mask
65
+ return data
66
+
67
+ def draw_border_map(self, polygon, canvas, mask):
68
+ polygon = np.array(polygon)
69
+ assert polygon.ndim == 2
70
+ assert polygon.shape[1] == 2
71
+
72
+ polygon_shape = Polygon(polygon)
73
+ if polygon_shape.area <= 0:
74
+ return
75
+ distance = polygon_shape.area * (
76
+ 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
77
+ subject = [tuple(l) for l in polygon]
78
+ padding = pyclipper.PyclipperOffset()
79
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
80
+
81
+ padded_polygon = np.array(padding.Execute(distance)[0])
82
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
83
+
84
+ xmin = padded_polygon[:, 0].min()
85
+ xmax = padded_polygon[:, 0].max()
86
+ ymin = padded_polygon[:, 1].min()
87
+ ymax = padded_polygon[:, 1].max()
88
+ width = xmax - xmin + 1
89
+ height = ymax - ymin + 1
90
+
91
+ polygon[:, 0] = polygon[:, 0] - xmin
92
+ polygon[:, 1] = polygon[:, 1] - ymin
93
+
94
+ xs = np.broadcast_to(
95
+ np.linspace(
96
+ 0, width - 1, num=width).reshape(1, width), (height, width))
97
+ ys = np.broadcast_to(
98
+ np.linspace(
99
+ 0, height - 1, num=height).reshape(height, 1), (height, width))
100
+
101
+ distance_map = np.zeros(
102
+ (polygon.shape[0], height, width), dtype=np.float32)
103
+ for i in range(polygon.shape[0]):
104
+ j = (i + 1) % polygon.shape[0]
105
+ absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
106
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
107
+ distance_map = distance_map.min(axis=0)
108
+
109
+ xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
110
+ xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
111
+ ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
112
+ ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
113
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
114
+ 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
115
+ xmin_valid - xmin:xmax_valid - xmax + width],
116
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
117
+
118
+ def _distance(self, xs, ys, point_1, point_2):
119
+ '''
120
+ compute the distance from point to a line
121
+ ys: coordinates in the first axis
122
+ xs: coordinates in the second axis
123
+ point_1, point_2: (x, y), the two endpoints of the line segment
124
+ '''
125
+ height, width = xs.shape[:2]
126
+ square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
127
+ 1])
128
+ square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
129
+ 1])
130
+ square_distance = np.square(point_1[0] - point_2[0]) + np.square(
131
+ point_1[1] - point_2[1])
132
+
133
+ cosin = (square_distance - square_distance_1 - square_distance_2) / (
134
+ 2 * np.sqrt(square_distance_1 * square_distance_2))
135
+ square_sin = 1 - np.square(cosin)
136
+ square_sin = np.nan_to_num(square_sin)
137
+ result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
138
+ square_distance)
139
+
140
+ result[cosin <
141
+ 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
142
+ < 0]
143
+ # self.extend_line(point_1, point_2, result)
144
+ return result
145
+
146
+ def extend_line(self, point_1, point_2, result, shrink_ratio):
147
+ ex_point_1 = (int(
148
+ round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
149
+ int(
150
+ round(point_1[1] + (point_1[1] - point_2[1]) * (
151
+ 1 + shrink_ratio))))
152
+ cv2.line(
153
+ result,
154
+ tuple(ex_point_1),
155
+ tuple(point_1),
156
+ 4096.0,
157
+ 1,
158
+ lineType=cv2.LINE_AA,
159
+ shift=0)
160
+ ex_point_2 = (int(
161
+ round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
162
+ int(
163
+ round(point_2[1] + (point_2[1] - point_1[1]) * (
164
+ 1 + shrink_ratio))))
165
+ cv2.line(
166
+ result,
167
+ tuple(ex_point_2),
168
+ tuple(point_2),
169
+ 4096.0,
170
+ 1,
171
+ lineType=cv2.LINE_AA,
172
+ shift=0)
173
+ return ex_point_1, ex_point_2
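# Usage sketch for MakeBorderMap (illustrative, not part of this commit;
# the polygon and image are made up, shapes follow the __call__ contract):
import numpy as np
op = MakeBorderMap(shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7)
data = {
    'image': np.zeros((100, 100, 3), np.uint8),
    'polys': [np.array([[20, 20], [80, 20], [80, 40], [20, 40]], np.float32)],
    'ignore_tags': [False],
}
out = op(data)
# threshold_map peaks near the text border and is thresh_min elsewhere
print(out['threshold_map'].shape, out['threshold_mask'].shape)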
ppocr/data/imaug/make_pse_gt.py ADDED
@@ -0,0 +1,106 @@
1
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+ from __future__ import unicode_literals
19
+
20
+ import cv2
21
+ import numpy as np
22
+ import pyclipper
23
+ from shapely.geometry import Polygon
24
+
25
+ __all__ = ['MakePseGt']
26
+
27
+
28
+ class MakePseGt(object):
29
+ def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
30
+ self.kernel_num = kernel_num
31
+ self.min_shrink_ratio = min_shrink_ratio
32
+ self.size = size
33
+
34
+ def __call__(self, data):
35
+
36
+ image = data['image']
37
+ text_polys = data['polys']
38
+ ignore_tags = data['ignore_tags']
39
+
40
+ h, w, _ = image.shape
41
+ short_edge = min(h, w)
42
+ if short_edge < self.size:
43
+ # keep the short edge >= self.size
44
+ scale = self.size / short_edge
45
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
46
+ text_polys *= scale
47
+
48
+ gt_kernels = []
49
+ for i in range(1, self.kernel_num + 1):
50
+ # s1->sn, from big to small
51
+ rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
52
+ ) * i
53
+ text_kernel, ignore_tags = self.generate_kernel(
54
+ image.shape[0:2], rate, text_polys, ignore_tags)
55
+ gt_kernels.append(text_kernel)
56
+
57
+ training_mask = np.ones(image.shape[0:2], dtype='uint8')
58
+ for i in range(text_polys.shape[0]):
59
+ if ignore_tags[i]:
60
+ cv2.fillPoly(training_mask,
61
+ text_polys[i].astype(np.int32)[np.newaxis, :, :],
62
+ 0)
63
+
64
+ gt_kernels = np.array(gt_kernels)
65
+ gt_kernels[gt_kernels > 0] = 1
66
+
67
+ data['image'] = image
68
+ data['polys'] = text_polys
69
+ data['gt_kernels'] = gt_kernels[0:]
70
+ data['gt_text'] = gt_kernels[0]
71
+ data['mask'] = training_mask.astype('float32')
72
+ return data
73
+
74
+ def generate_kernel(self,
75
+ img_size,
76
+ shrink_ratio,
77
+ text_polys,
78
+ ignore_tags=None):
79
+ """
80
+ Refer to part of the code:
81
+ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
82
+ """
83
+
84
+ h, w = img_size
85
+ text_kernel = np.zeros((h, w), dtype=np.float32)
86
+ for i, poly in enumerate(text_polys):
87
+ polygon = Polygon(poly)
88
+ distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
89
+ polygon.length + 1e-6)
90
+ subject = [tuple(l) for l in poly]
91
+ pco = pyclipper.PyclipperOffset()
92
+ pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
93
+ shrinked = np.array(pco.Execute(-distance))
94
+
95
+ if len(shrinked) == 0 or shrinked.size == 0:
96
+ if ignore_tags is not None:
97
+ ignore_tags[i] = True
98
+ continue
99
+ try:
100
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
101
+ except Exception:
102
+ if ignore_tags is not None:
103
+ ignore_tags[i] = True
104
+ continue
105
+ cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
106
+ return text_kernel, ignore_tags
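# Usage sketch for MakePseGt (illustrative, not part of this commit; polys
# must be an (N, 4, 2) float array because __call__ rescales it in place):
import numpy as np
op = MakePseGt(kernel_num=7, size=640, min_shrink_ratio=0.4)
data = {
    'image': np.zeros((640, 640, 3), np.uint8),
    'polys': np.array([[[100, 100], [300, 100], [300, 160], [100, 160]]],
                      np.float32),
    'ignore_tags': [False],
}
out = op(data)
print(out['gt_kernels'].shape)  # (7, 640, 640), kernels ordered big -> small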
ppocr/data/imaug/make_shrink_map.py ADDED
@@ -0,0 +1,123 @@
1
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This code is adapted from:
16
+ https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
17
+ """
18
+
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+ from __future__ import print_function
22
+ from __future__ import unicode_literals
23
+
24
+ import numpy as np
25
+ import cv2
26
+ from shapely.geometry import Polygon
27
+ import pyclipper
28
+
29
+ __all__ = ['MakeShrinkMap']
30
+
31
+
32
+ class MakeShrinkMap(object):
33
+ r'''
34
+ Making binary mask from detection data with ICDAR format.
35
+ Typically following the process of class `MakeICDARData`.
36
+ '''
37
+
38
+ def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
39
+ self.min_text_size = min_text_size
40
+ self.shrink_ratio = shrink_ratio
41
+
42
+ def __call__(self, data):
43
+ image = data['image']
44
+ text_polys = data['polys']
45
+ ignore_tags = data['ignore_tags']
46
+
47
+ h, w = image.shape[:2]
48
+ text_polys, ignore_tags = self.validate_polygons(text_polys,
49
+ ignore_tags, h, w)
50
+ gt = np.zeros((h, w), dtype=np.float32)
51
+ mask = np.ones((h, w), dtype=np.float32)
52
+ for i in range(len(text_polys)):
53
+ polygon = text_polys[i]
54
+ height = max(polygon[:, 1]) - min(polygon[:, 1])
55
+ width = max(polygon[:, 0]) - min(polygon[:, 0])
56
+ if ignore_tags[i] or min(height, width) < self.min_text_size:
57
+ cv2.fillPoly(mask,
58
+ polygon.astype(np.int32)[np.newaxis, :, :], 0)
59
+ ignore_tags[i] = True
60
+ else:
61
+ polygon_shape = Polygon(polygon)
62
+ subject = [tuple(l) for l in polygon]
63
+ padding = pyclipper.PyclipperOffset()
64
+ padding.AddPath(subject, pyclipper.JT_ROUND,
65
+ pyclipper.ET_CLOSEDPOLYGON)
66
+ shrinked = []
67
+
68
+ # Increase the shrink ratio whenever the offset returns multiple polygons
69
+ possible_ratios = np.arange(self.shrink_ratio, 1,
70
+ self.shrink_ratio)
71
+ np.append(possible_ratios, 1)
72
+ # print(possible_ratios)
73
+ for ratio in possible_ratios:
74
+ # print(f"Change shrink ratio to {ratio}")
75
+ distance = polygon_shape.area * (
76
+ 1 - np.power(ratio, 2)) / polygon_shape.length
77
+ shrinked = padding.Execute(-distance)
78
+ if len(shrinked) == 1:
79
+ break
80
+
81
+ if shrinked == []:
82
+ cv2.fillPoly(mask,
83
+ polygon.astype(np.int32)[np.newaxis, :, :], 0)
84
+ ignore_tags[i] = True
85
+ continue
86
+
87
+ for each_shirnk in shrinked:
88
+ shirnk = np.array(each_shirnk).reshape(-1, 2)
89
+ cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
90
+
91
+ data['shrink_map'] = gt
92
+ data['shrink_mask'] = mask
93
+ return data
94
+
95
+ def validate_polygons(self, polygons, ignore_tags, h, w):
96
+ '''
97
+ polygons (numpy.array, required): of shape (num_instances, num_points, 2)
98
+ '''
99
+ if len(polygons) == 0:
100
+ return polygons, ignore_tags
101
+ assert len(polygons) == len(ignore_tags)
102
+ for polygon in polygons:
103
+ polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
104
+ polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
105
+
106
+ for i in range(len(polygons)):
107
+ area = self.polygon_area(polygons[i])
108
+ if abs(area) < 1:
109
+ ignore_tags[i] = True
110
+ if area > 0:
111
+ polygons[i] = polygons[i][::-1, :]
112
+ return polygons, ignore_tags
113
+
114
+ def polygon_area(self, polygon):
115
+ """
116
+ compute polygon area
117
+ """
118
+ area = 0
119
+ q = polygon[-1]
120
+ for p in polygon:
121
+ area += p[0] * q[1] - p[1] * q[0]
122
+ q = p
123
+ return area / 2.0
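# Worked example of the shrink offset used above (illustrative): for a
# 60x20 rectangle and shrink_ratio=0.4,
#   distance = area * (1 - ratio^2) / perimeter = 1200 * 0.84 / 160 = 6.3
import numpy as np
from shapely.geometry import Polygon
rect = Polygon([(0, 0), (60, 0), (60, 20), (0, 20)])
print(rect.area * (1 - np.power(0.4, 2)) / rect.length)  # 6.3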
ppocr/data/imaug/operators.py ADDED
@@ -0,0 +1,524 @@
1
+ """
2
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+ from __future__ import unicode_literals
21
+
22
+ import sys
23
+ import six
24
+ import cv2
25
+ import numpy as np
26
+ import math
27
+ from PIL import Image
28
+
29
+
30
+ class DecodeImage(object):
31
+ """ decode image """
32
+
33
+ def __init__(self,
34
+ img_mode='RGB',
35
+ channel_first=False,
36
+ ignore_orientation=False,
37
+ **kwargs):
38
+ self.img_mode = img_mode
39
+ self.channel_first = channel_first
40
+ self.ignore_orientation = ignore_orientation
41
+
42
+ def __call__(self, data):
43
+ img = data['image']
44
+ if six.PY2:
45
+ assert type(img) is str and len(
46
+ img) > 0, "invalid input 'img' in DecodeImage"
47
+ else:
48
+ assert type(img) is bytes and len(
49
+ img) > 0, "invalid input 'img' in DecodeImage"
50
+ img = np.frombuffer(img, dtype='uint8')
51
+ if self.ignore_orientation:
52
+ img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
53
+ cv2.IMREAD_COLOR)
54
+ else:
55
+ img = cv2.imdecode(img, 1)
56
+ if img is None:
57
+ return None
58
+ if self.img_mode == 'GRAY':
59
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
60
+ elif self.img_mode == 'RGB':
61
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
62
+ img = img[:, :, ::-1]
63
+
64
+ if self.channel_first:
65
+ img = img.transpose((2, 0, 1))
66
+
67
+ data['image'] = img
68
+ return data
69
+
70
+
71
+ class NormalizeImage(object):
72
+ """ normalize image such as substract mean, divide std
73
+ """
74
+
75
+ def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
76
+ if isinstance(scale, str):
77
+ scale = eval(scale)
78
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
79
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
80
+ std = std if std is not None else [0.229, 0.224, 0.225]
81
+
82
+ shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
83
+ self.mean = np.array(mean).reshape(shape).astype('float32')
84
+ self.std = np.array(std).reshape(shape).astype('float32')
85
+
86
+ def __call__(self, data):
87
+ img = data['image']
88
+ from PIL import Image
89
+ if isinstance(img, Image.Image):
90
+ img = np.array(img)
91
+ assert isinstance(img,
92
+ np.ndarray), "invalid input 'img' in NormalizeImage"
93
+ data['image'] = (
94
+ img.astype('float32') * self.scale - self.mean) / self.std
95
+ return data
96
+
97
+
98
+ class ToCHWImage(object):
99
+ """ convert hwc image to chw image
100
+ """
101
+
102
+ def __init__(self, **kwargs):
103
+ pass
104
+
105
+ def __call__(self, data):
106
+ img = data['image']
107
+ from PIL import Image
108
+ if isinstance(img, Image.Image):
109
+ img = np.array(img)
110
+ data['image'] = img.transpose((2, 0, 1))
111
+ return data
112
+
113
+
114
+ class Fasttext(object):
115
+ def __init__(self, path="None", **kwargs):
116
+ import fasttext
117
+ self.fast_model = fasttext.load_model(path)
118
+
119
+ def __call__(self, data):
120
+ label = data['label']
121
+ fast_label = self.fast_model[label]
122
+ data['fast_label'] = fast_label
123
+ return data
124
+
125
+
126
+ class KeepKeys(object):
127
+ def __init__(self, keep_keys, **kwargs):
128
+ self.keep_keys = keep_keys
129
+
130
+ def __call__(self, data):
131
+ data_list = []
132
+ for key in self.keep_keys:
133
+ data_list.append(data[key])
134
+ return data_list
135
+
136
+
137
+ class Pad(object):
138
+ def __init__(self, size=None, size_div=32, **kwargs):
139
+ if size is not None and not isinstance(size, (int, list, tuple)):
140
+ raise TypeError("Type of target_size is invalid. Now is {}".format(
141
+ type(size)))
142
+ if isinstance(size, int):
143
+ size = [size, size]
144
+ self.size = size
145
+ self.size_div = size_div
146
+
147
+ def __call__(self, data):
148
+
149
+ img = data['image']
150
+ img_h, img_w = img.shape[0], img.shape[1]
151
+ if self.size:
152
+ resize_h2, resize_w2 = self.size
153
+ assert (
154
+ img_h < resize_h2 and img_w < resize_w2
155
+ ), '(h, w) of target size should be greater than (img_h, img_w)'
156
+ else:
157
+ resize_h2 = max(
158
+ int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
159
+ self.size_div)
160
+ resize_w2 = max(
161
+ int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
162
+ self.size_div)
163
+ img = cv2.copyMakeBorder(
164
+ img,
165
+ 0,
166
+ resize_h2 - img_h,
167
+ 0,
168
+ resize_w2 - img_w,
169
+ cv2.BORDER_CONSTANT,
170
+ value=0)
171
+ data['image'] = img
172
+ return data
173
+
174
+
175
+ class Resize(object):
176
+ def __init__(self, size=(640, 640), **kwargs):
177
+ self.size = size
178
+
179
+ def resize_image(self, img):
180
+ resize_h, resize_w = self.size
181
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
182
+ ratio_h = float(resize_h) / ori_h
183
+ ratio_w = float(resize_w) / ori_w
184
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
185
+ return img, [ratio_h, ratio_w]
186
+
187
+ def __call__(self, data):
188
+ img = data['image']
189
+ if 'polys' in data:
190
+ text_polys = data['polys']
191
+
192
+ img_resize, [ratio_h, ratio_w] = self.resize_image(img)
193
+ if 'polys' in data:
194
+ new_boxes = []
195
+ for box in text_polys:
196
+ new_box = []
197
+ for cord in box:
198
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
199
+ new_boxes.append(new_box)
200
+ data['polys'] = np.array(new_boxes, dtype=np.float32)
201
+ data['image'] = img_resize
202
+ return data
203
+
204
+
205
+ class DetResizeForTest(object):
206
+ def __init__(self, **kwargs):
207
+ super(DetResizeForTest, self).__init__()
208
+ self.resize_type = 0
209
+ self.keep_ratio = False
210
+ if 'image_shape' in kwargs:
211
+ self.image_shape = kwargs['image_shape']
212
+ self.resize_type = 1
213
+ if 'keep_ratio' in kwargs:
214
+ self.keep_ratio = kwargs['keep_ratio']
215
+ elif 'limit_side_len' in kwargs:
216
+ self.limit_side_len = kwargs['limit_side_len']
217
+ self.limit_type = kwargs.get('limit_type', 'min')
218
+ elif 'resize_long' in kwargs:
219
+ self.resize_type = 2
220
+ self.resize_long = kwargs.get('resize_long', 960)
221
+ else:
222
+ self.limit_side_len = 736
223
+ self.limit_type = 'min'
224
+
225
+ def __call__(self, data):
226
+ img = data['image']
227
+ src_h, src_w, _ = img.shape
228
+ if sum([src_h, src_w]) < 64:
229
+ img = self.image_padding(img)
230
+
231
+ if self.resize_type == 0:
232
+ # img, shape = self.resize_image_type0(img)
233
+ img, [ratio_h, ratio_w] = self.resize_image_type0(img)
234
+ elif self.resize_type == 2:
235
+ img, [ratio_h, ratio_w] = self.resize_image_type2(img)
236
+ else:
237
+ # img, shape = self.resize_image_type1(img)
238
+ img, [ratio_h, ratio_w] = self.resize_image_type1(img)
239
+ data['image'] = img
240
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
241
+ return data
242
+
243
+ def image_padding(self, im, value=0):
244
+ h, w, c = im.shape
245
+ im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
246
+ im_pad[:h, :w, :] = im
247
+ return im_pad
248
+
249
+ def resize_image_type1(self, img):
250
+ resize_h, resize_w = self.image_shape
251
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
252
+ if self.keep_ratio is True:
253
+ resize_w = ori_w * resize_h / ori_h
254
+ N = math.ceil(resize_w / 32)
255
+ resize_w = N * 32
256
+ ratio_h = float(resize_h) / ori_h
257
+ ratio_w = float(resize_w) / ori_w
258
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
259
+ # return img, np.array([ori_h, ori_w])
260
+ return img, [ratio_h, ratio_w]
261
+
262
+ def resize_image_type0(self, img):
263
+ """
264
+ resize image to a size multiple of 32 which is required by the network
265
+ args:
266
+ img(array): array with shape [h, w, c]
267
+ return(tuple):
268
+ img, (ratio_h, ratio_w)
269
+ """
270
+ limit_side_len = self.limit_side_len
271
+ h, w, c = img.shape
272
+
273
+ # limit the max side
274
+ if self.limit_type == 'max':
275
+ if max(h, w) > limit_side_len:
276
+ if h > w:
277
+ ratio = float(limit_side_len) / h
278
+ else:
279
+ ratio = float(limit_side_len) / w
280
+ else:
281
+ ratio = 1.
282
+ elif self.limit_type == 'min':
283
+ if min(h, w) < limit_side_len:
284
+ if h < w:
285
+ ratio = float(limit_side_len) / h
286
+ else:
287
+ ratio = float(limit_side_len) / w
288
+ else:
289
+ ratio = 1.
290
+ elif self.limit_type == 'resize_long':
291
+ ratio = float(limit_side_len) / max(h, w)
292
+ else:
293
+ raise Exception('unsupported limit type: %s' % self.limit_type)
294
+ resize_h = int(h * ratio)
295
+ resize_w = int(w * ratio)
296
+
297
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
298
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
299
+
300
+ try:
301
+ if int(resize_w) <= 0 or int(resize_h) <= 0:
302
+ return None, (None, None)
303
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
304
+ except Exception:
305
+ print(img.shape, resize_w, resize_h)
306
+ sys.exit(0)
307
+ ratio_h = resize_h / float(h)
308
+ ratio_w = resize_w / float(w)
309
+ return img, [ratio_h, ratio_w]
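# Worked example of the /32 stride rounding above (illustrative numbers):
# a 720x1280 image with limit_side_len=960 and limit_type='max' gives
#   ratio    = 960 / 1280 = 0.75
#   resize_h = max(int(round(720 * 0.75 / 32) * 32), 32) = 17 * 32 = 544
#   resize_w = max(int(round(1280 * 0.75 / 32) * 32), 32) = 30 * 32 = 960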
310
+
311
+ def resize_image_type2(self, img):
312
+ h, w, _ = img.shape
313
+
314
+ resize_w = w
315
+ resize_h = h
316
+
317
+ if resize_h > resize_w:
318
+ ratio = float(self.resize_long) / resize_h
319
+ else:
320
+ ratio = float(self.resize_long) / resize_w
321
+
322
+ resize_h = int(resize_h * ratio)
323
+ resize_w = int(resize_w * ratio)
324
+
325
+ max_stride = 128
326
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
327
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
328
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
329
+ ratio_h = resize_h / float(h)
330
+ ratio_w = resize_w / float(w)
331
+
332
+ return img, [ratio_h, ratio_w]
333
+
334
+
335
+ class E2EResizeForTest(object):
336
+ def __init__(self, **kwargs):
337
+ super(E2EResizeForTest, self).__init__()
338
+ self.max_side_len = kwargs['max_side_len']
339
+ self.valid_set = kwargs['valid_set']
340
+
341
+ def __call__(self, data):
342
+ img = data['image']
343
+ src_h, src_w, _ = img.shape
344
+ if self.valid_set == 'totaltext':
345
+ im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
346
+ img, max_side_len=self.max_side_len)
347
+ else:
348
+ im_resized, (ratio_h, ratio_w) = self.resize_image(
349
+ img, max_side_len=self.max_side_len)
350
+ data['image'] = im_resized
351
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
352
+ return data
353
+
354
+ def resize_image_for_totaltext(self, im, max_side_len=512):
355
+
356
+ h, w, _ = im.shape
357
+ resize_w = w
358
+ resize_h = h
359
+ ratio = 1.25
360
+ if h * ratio > max_side_len:
361
+ ratio = float(max_side_len) / resize_h
362
+ resize_h = int(resize_h * ratio)
363
+ resize_w = int(resize_w * ratio)
364
+
365
+ max_stride = 128
366
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
367
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
368
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
369
+ ratio_h = resize_h / float(h)
370
+ ratio_w = resize_w / float(w)
371
+ return im, (ratio_h, ratio_w)
372
+
373
+ def resize_image(self, im, max_side_len=512):
374
+ """
375
+ resize image to a size multiple of max_stride which is required by the network
376
+ :param im: the resized image
377
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
378
+ :return: the resized image and the resize ratio
379
+ """
380
+ h, w, _ = im.shape
381
+
382
+ resize_w = w
383
+ resize_h = h
384
+
385
+ # Fix the longer side
386
+ if resize_h > resize_w:
387
+ ratio = float(max_side_len) / resize_h
388
+ else:
389
+ ratio = float(max_side_len) / resize_w
390
+
391
+ resize_h = int(resize_h * ratio)
392
+ resize_w = int(resize_w * ratio)
393
+
394
+ max_stride = 128
395
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
396
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
397
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
398
+ ratio_h = resize_h / float(h)
399
+ ratio_w = resize_w / float(w)
400
+
401
+ return im, (ratio_h, ratio_w)
402
+
403
+
404
+ class KieResize(object):
405
+ def __init__(self, **kwargs):
406
+ super(KieResize, self).__init__()
407
+ self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
408
+ 'img_scale'][1]
409
+
410
+ def __call__(self, data):
411
+ img = data['image']
412
+ points = data['points']
413
+ src_h, src_w, _ = img.shape
414
+ im_resized, scale_factor, [ratio_h, ratio_w
415
+ ], [new_h, new_w] = self.resize_image(img)
416
+ resize_points = self.resize_boxes(img, points, scale_factor)
417
+ data['ori_image'] = img
418
+ data['ori_boxes'] = points
419
+ data['points'] = resize_points
420
+ data['image'] = im_resized
421
+ data['shape'] = np.array([new_h, new_w])
422
+ return data
423
+
424
+ def resize_image(self, img):
425
+ norm_img = np.zeros([1024, 1024, 3], dtype='float32')
426
+ scale = [512, 1024]
427
+ h, w = img.shape[:2]
428
+ max_long_edge = max(scale)
429
+ max_short_edge = min(scale)
430
+ scale_factor = min(max_long_edge / max(h, w),
431
+ max_short_edge / min(h, w))
432
+ resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
433
+ scale_factor) + 0.5)
434
+ max_stride = 32
435
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
436
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
437
+ im = cv2.resize(img, (resize_w, resize_h))
438
+ new_h, new_w = im.shape[:2]
439
+ w_scale = new_w / w
440
+ h_scale = new_h / h
441
+ scale_factor = np.array(
442
+ [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
443
+ norm_img[:new_h, :new_w, :] = im
444
+ return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
445
+
446
+ def resize_boxes(self, im, points, scale_factor):
447
+ points = points * scale_factor
448
+ img_shape = im.shape[:2]
449
+ points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
450
+ points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
451
+ return points
452
+
453
+
454
+ class SRResize(object):
455
+ def __init__(self,
456
+ imgH=32,
457
+ imgW=128,
458
+ down_sample_scale=4,
459
+ keep_ratio=False,
460
+ min_ratio=1,
461
+ mask=False,
462
+ infer_mode=False,
463
+ **kwargs):
464
+ self.imgH = imgH
465
+ self.imgW = imgW
466
+ self.keep_ratio = keep_ratio
467
+ self.min_ratio = min_ratio
468
+ self.down_sample_scale = down_sample_scale
469
+ self.mask = mask
470
+ self.infer_mode = infer_mode
471
+
472
+ def __call__(self, data):
473
+ imgH = self.imgH
474
+ imgW = self.imgW
475
+ images_lr = data["image_lr"]
476
+ transform2 = ResizeNormalize(
477
+ (imgW // self.down_sample_scale, imgH // self.down_sample_scale))
478
+ images_lr = transform2(images_lr)
479
+ data["img_lr"] = images_lr
480
+ if self.infer_mode:
481
+ return data
482
+
483
+ images_HR = data["image_hr"]
484
+ label_strs = data["label"]
485
+ transform = ResizeNormalize((imgW, imgH))
486
+ images_HR = transform(images_HR)
487
+ data["img_hr"] = images_HR
488
+ return data
489
+
490
+
491
+ class ResizeNormalize(object):
492
+ def __init__(self, size, interpolation=Image.BICUBIC):
493
+ self.size = size
494
+ self.interpolation = interpolation
495
+
496
+ def __call__(self, img):
497
+ img = img.resize(self.size, self.interpolation)
498
+ img_numpy = np.array(img).astype("float32")
499
+ img_numpy = img_numpy.transpose((2, 0, 1)) / 255
500
+ return img_numpy
501
+
502
+
503
+ class GrayImageChannelFormat(object):
504
+ """
505
+ convert an image to single-channel format: (h, w, 3) BGR -> (1, h, w) gray
506
+ Args:
507
+ inverse: whether to invert the gray image
508
+ """
509
+
510
+ def __init__(self, inverse=False, **kwargs):
511
+ self.inverse = inverse
512
+
513
+ def __call__(self, data):
514
+ img = data['image']
515
+ img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
516
+ img_expanded = np.expand_dims(img_single_channel, 0)
517
+
518
+ if self.inverse:
519
+ data['image'] = np.abs(img_expanded - 1)
520
+ else:
521
+ data['image'] = img_expanded
522
+
523
+ data['src_image'] = img
524
+ return data
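A minimal usage sketch for E2EResizeForTest above (a hypothetical call on a synthetic image; max_side_len=768 is just an example value, and it assumes the cv2/numpy imports at the top of operators.py):

import numpy as np

resize_op = E2EResizeForTest(max_side_len=768, valid_set='totaltext')
data = {'image': np.zeros((720, 1280, 3), dtype=np.uint8)}
data = resize_op(data)
# both sides are rounded up to multiples of 128, and data['shape'] keeps
# [src_h, src_w, ratio_h, ratio_w] so predictions can be mapped back
print(data['image'].shape, data['shape'])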
ppocr/data/imaug/pg_process.py ADDED
@@ -0,0 +1,1034 @@
1
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import cv2
17
+ import numpy as np
18
+ from skimage.morphology._skeletonize import thin
19
+ from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
20
+
21
+ __all__ = ['PGProcessTrain']
22
+
23
+
24
+ class PGProcessTrain(object):
25
+ def __init__(self,
26
+ character_dict_path,
27
+ max_text_length,
28
+ max_text_nums,
29
+ tcl_len,
30
+ batch_size=14,
31
+ use_resize=True,
32
+ use_random_crop=False,
33
+ min_crop_size=24,
34
+ min_text_size=4,
35
+ max_text_size=512,
36
+ point_gather_mode=None,
37
+ **kwargs):
38
+ self.tcl_len = tcl_len
39
+ self.max_text_length = max_text_length
40
+ self.max_text_nums = max_text_nums
41
+ self.batch_size = batch_size
42
+ if use_random_crop is True:
43
+ self.min_crop_size = min_crop_size
44
+ self.use_random_crop = use_random_crop
45
+ self.min_text_size = min_text_size
46
+ self.max_text_size = max_text_size
47
+ self.use_resize = use_resize
48
+ self.point_gather_mode = point_gather_mode
49
+ self.Lexicon_Table = self.get_dict(character_dict_path)
50
+ self.pad_num = len(self.Lexicon_Table)
51
+ self.img_id = 0
52
+
53
+ def get_dict(self, character_dict_path):
54
+ character_str = ""
55
+ with open(character_dict_path, "rb") as fin:
56
+ lines = fin.readlines()
57
+ for line in lines:
58
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
59
+ character_str += line
60
+ dict_character = list(character_str)
61
+ return dict_character
62
+
63
+ def quad_area(self, poly):
64
+ """
65
+ compute the signed area of a quad via the shoelace formula (the sign encodes orientation)
66
+ :param poly:
67
+ :return:
68
+ """
69
+ edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
70
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
71
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
72
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
73
+ return np.sum(edge) / 2.
74
+
75
+ def gen_quad_from_poly(self, poly):
76
+ """
77
+ Generate min area quad from poly.
78
+ """
79
+ point_num = poly.shape[0]
80
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
81
+ rect = cv2.minAreaRect(poly.astype(
82
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
83
+ box = np.array(cv2.boxPoints(rect))
84
+
85
+ first_point_idx = 0
86
+ min_dist = 1e4
87
+ for i in range(4):
88
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
89
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
90
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
91
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
92
+ if dist < min_dist:
93
+ min_dist = dist
94
+ first_point_idx = i
95
+ for i in range(4):
96
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
97
+
98
+ return min_area_quad
99
+
100
+ def check_and_validate_polys(self, polys, tags, im_size):
101
+ """
102
+ check so that the text poly is in the same direction,
103
+ and also filter some invalid polygons
104
+ :param polys:
105
+ :param tags:
106
+ :return:
107
+ """
108
+ (h, w) = im_size
109
+ if polys.shape[0] == 0:
110
+ return polys, np.array([]), np.array([])
111
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
112
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
113
+
114
+ validated_polys = []
115
+ validated_tags = []
116
+ hv_tags = []
117
+ for poly, tag in zip(polys, tags):
118
+ quad = self.gen_quad_from_poly(poly)
119
+ p_area = self.quad_area(quad)
120
+ if abs(p_area) < 1:
121
+ print('invalid poly')
122
+ continue
123
+ if p_area > 0:
124
+ if not tag:
125
+ print('poly in wrong direction')
126
+ tag = True  # reversed cases should be ignored
127
+ poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
128
+ 1), :]
129
+ quad = quad[(0, 3, 2, 1), :]
130
+
131
+ len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
132
+ quad[2])
133
+ len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
134
+ quad[2])
135
+ hv_tag = 1
136
+
137
+ if len_w * 2.0 < len_h:
138
+ hv_tag = 0
139
+
140
+ validated_polys.append(poly)
141
+ validated_tags.append(tag)
142
+ hv_tags.append(hv_tag)
143
+ return np.array(validated_polys), np.array(validated_tags), np.array(
144
+ hv_tags)
145
+
146
+ def crop_area(self,
147
+ im,
148
+ polys,
149
+ tags,
150
+ hv_tags,
151
+ txts,
152
+ crop_background=False,
153
+ max_tries=25):
154
+ """
155
+ make random crop from the input image
156
+ :param im:
157
+ :param polys: [b,4,2]
158
+ :param tags:
159
+ :param crop_background:
160
+ :param max_tries: 50 -> 25
161
+ :return:
162
+ """
163
+ h, w, _ = im.shape
164
+ pad_h = h // 10
165
+ pad_w = w // 10
166
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
167
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
168
+ for poly in polys:
169
+ poly = np.round(poly, decimals=0).astype(np.int32)
170
+ minx = np.min(poly[:, 0])
171
+ maxx = np.max(poly[:, 0])
172
+ w_array[minx + pad_w:maxx + pad_w] = 1
173
+ miny = np.min(poly[:, 1])
174
+ maxy = np.max(poly[:, 1])
175
+ h_array[miny + pad_h:maxy + pad_h] = 1
176
+ # ensure the cropped area does not cross a text region
177
+ h_axis = np.where(h_array == 0)[0]
178
+ w_axis = np.where(w_array == 0)[0]
179
+ if len(h_axis) == 0 or len(w_axis) == 0:
180
+ return im, polys, tags, hv_tags, txts
181
+ for i in range(max_tries):
182
+ xx = np.random.choice(w_axis, size=2)
183
+ xmin = np.min(xx) - pad_w
184
+ xmax = np.max(xx) - pad_w
185
+ xmin = np.clip(xmin, 0, w - 1)
186
+ xmax = np.clip(xmax, 0, w - 1)
187
+ yy = np.random.choice(h_axis, size=2)
188
+ ymin = np.min(yy) - pad_h
189
+ ymax = np.max(yy) - pad_h
190
+ ymin = np.clip(ymin, 0, h - 1)
191
+ ymax = np.clip(ymax, 0, h - 1)
192
+ if xmax - xmin < self.min_crop_size or \
193
+ ymax - ymin < self.min_crop_size:
194
+ continue
195
+ if polys.shape[0] != 0:
196
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
197
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
198
+ selected_polys = np.where(
199
+ np.sum(poly_axis_in_area, axis=1) == 4)[0]
200
+ else:
201
+ selected_polys = []
202
+ if len(selected_polys) == 0:
203
+ # no text in this area
204
+ if crop_background:
205
+ txts_tmp = []
206
+ for selected_poly in selected_polys:
207
+ txts_tmp.append(txts[selected_poly])
208
+ txts = txts_tmp
209
+ return im[ymin: ymax + 1, xmin: xmax + 1, :], \
210
+ polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
211
+ else:
212
+ continue
213
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
214
+ polys = polys[selected_polys]
215
+ tags = tags[selected_polys]
216
+ hv_tags = hv_tags[selected_polys]
217
+ txts_tmp = []
218
+ for selected_poly in selected_polys:
219
+ txts_tmp.append(txts[selected_poly])
220
+ txts = txts_tmp
221
+ polys[:, :, 0] -= xmin
222
+ polys[:, :, 1] -= ymin
223
+ return im, polys, tags, hv_tags, txts
224
+
225
+ return im, polys, tags, hv_tags, txts
226
+
227
+ def fit_and_gather_tcl_points_v2(self,
228
+ min_area_quad,
229
+ poly,
230
+ max_h,
231
+ max_w,
232
+ fixed_point_num=64,
233
+ img_id=0,
234
+ reference_height=3):
235
+ """
236
+ Find the center point of poly as key_points, then fit and gather.
237
+ """
238
+ key_point_xys = []
239
+ point_num = poly.shape[0]
240
+ for idx in range(point_num // 2):
241
+ center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
242
+ key_point_xys.append(center_point)
243
+
244
+ tmp_image = np.zeros(
245
+ shape=(
246
+ max_h,
247
+ max_w, ), dtype='float32')
248
+ cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
249
+ False, 1.0)
250
+ ys, xs = np.where(tmp_image > 0)
251
+ xy_text = np.array(list(zip(xs, ys)), dtype='float32')
252
+
253
+ left_center_pt = (
254
+ (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
255
+ right_center_pt = (
256
+ (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
257
+ proj_unit_vec = (right_center_pt - left_center_pt) / (
258
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
259
+ proj_unit_vec_tile = np.tile(proj_unit_vec,
260
+ (xy_text.shape[0], 1)) # (n, 2)
261
+ left_center_pt_tile = np.tile(left_center_pt,
262
+ (xy_text.shape[0], 1)) # (n, 2)
263
+ xy_text_to_left_center = xy_text - left_center_pt_tile
264
+ proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
265
+ xy_text = xy_text[np.argsort(proj_value)]
266
+
267
+ # convert to np and keep the number of points no greater than fixed_point_num
268
+ pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
269
+ point_num = len(pos_info)
270
+ if point_num > fixed_point_num:
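+ # uniformly subsample the sorted center-line points down to fixed_point_num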
271
+ keep_ids = [
272
+ int((point_num * 1.0 / fixed_point_num) * x)
273
+ for x in range(fixed_point_num)
274
+ ]
275
+ pos_info = pos_info[keep_ids, :]
276
+
277
+ keep = int(min(len(pos_info), fixed_point_num))
278
+ if np.random.rand() < 0.2 and reference_height >= 3:
279
+ dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
280
+ random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
281
+ [keep, 1])
282
+ pos_info += random_float
283
+ pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
284
+ pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
285
+
286
+ # padding to fixed length
287
+ pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
288
+ pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
289
+ pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
290
+ pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
291
+ pos_m[:keep] = 1.0
292
+ return pos_l, pos_m
293
+
294
+ def fit_and_gather_tcl_points_v3(self,
295
+ min_area_quad,
296
+ poly,
297
+ max_h,
298
+ max_w,
299
+ fixed_point_num=64,
300
+ img_id=0,
301
+ reference_height=3):
302
+ """
303
+ Find the center point of poly as key_points, then fit and gather.
304
+ """
305
+ det_mask = np.zeros((int(max_h / self.ds_ratio),
306
+ int(max_w / self.ds_ratio))).astype(np.float32)
307
+
308
+ # score_big_map
309
+ cv2.fillPoly(det_mask,
310
+ np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
311
+ det_mask = cv2.resize(
312
+ det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
313
+ det_mask = np.array(det_mask > 1e-3, dtype='float32')
314
+
315
+ f_direction = self.f_direction
316
+ skeleton_map = thin(det_mask.astype(np.uint8))
317
+ instance_count, instance_label_map = cv2.connectedComponents(
318
+ skeleton_map.astype(np.uint8), connectivity=8)
319
+
320
+ ys, xs = np.where(instance_label_map == 1)
321
+ pos_list = list(zip(ys, xs))
322
+ if len(pos_list) < 3:
323
+ return None
324
+ pos_list_sorted = sort_and_expand_with_direction_v2(
325
+ pos_list, f_direction, det_mask)
326
+
327
+ pos_list_sorted = np.array(pos_list_sorted)
328
+ length = len(pos_list_sorted) - 1
329
+ insert_num = 0
330
+ for index in range(length):
331
+ stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
332
+ pos_list_sorted[index + 1 + insert_num][0])
333
+ stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
334
+ pos_list_sorted[index + 1 + insert_num][1])
335
+ max_points = int(max(stride_x, stride_y))
336
+
337
+ stride = (pos_list_sorted[index + insert_num] -
338
+ pos_list_sorted[index + 1 + insert_num]) / (max_points)
339
+ insert_num_temp = max_points - 1
340
+
341
+ for i in range(int(insert_num_temp)):
342
+ insert_value = pos_list_sorted[index + insert_num] - (i + 1
343
+ ) * stride
344
+ insert_index = index + i + 1 + insert_num
345
+ pos_list_sorted = np.insert(
346
+ pos_list_sorted, insert_index, insert_value, axis=0)
347
+ insert_num += insert_num_temp
348
+
349
+ pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
350
+ np.float32) # xy-> yx
351
+
352
+ point_num = len(pos_info)
353
+ if point_num > fixed_point_num:
354
+ keep_ids = [
355
+ int((point_num * 1.0 / fixed_point_num) * x)
356
+ for x in range(fixed_point_num)
357
+ ]
358
+ pos_info = pos_info[keep_ids, :]
359
+
360
+ keep = int(min(len(pos_info), fixed_point_num))
361
+ reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
362
+ np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
363
+ if np.random.rand() < 1:
364
+ dh = (np.random.rand(keep) - 0.5) * reference_height
365
+ offset = np.random.rand() - 0.5
366
+ dw = np.array([[0, offset * reference_width * 0.2]])
367
+ random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
368
+ [keep, 1])
369
+ random_float_w = dw.repeat(keep, axis=0)
370
+ pos_info += random_float_h
371
+ pos_info += random_float_w
372
+ pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
373
+ pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
374
+
375
+ # padding to fixed length
376
+ pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
377
+ pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
378
+ pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
379
+ pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
380
+ pos_m[:keep] = 1.0
381
+ return pos_l, pos_m
382
+
383
+ def generate_direction_map(self, poly_quads, n_char, direction_map):
384
+ """
385
+ """
386
+ width_list = []
387
+ height_list = []
388
+ for quad in poly_quads:
389
+ quad_w = (np.linalg.norm(quad[0] - quad[1]) +
390
+ np.linalg.norm(quad[2] - quad[3])) / 2.0
391
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
392
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
393
+ width_list.append(quad_w)
394
+ height_list.append(quad_h)
395
+ norm_width = max(sum(width_list) / n_char, 1.0)
396
+ average_height = max(sum(height_list) / len(height_list), 1.0)
397
+ k = 1
398
+ for quad in poly_quads:
399
+ direct_vector_full = (
400
+ (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
401
+ direct_vector = direct_vector_full / (
402
+ np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
403
+ direction_label = tuple(
404
+ map(float,
405
+ [direct_vector[0], direct_vector[1], 1.0 / average_height]))
406
+ cv2.fillPoly(direction_map,
407
+ quad.round().astype(np.int32)[np.newaxis, :, :],
408
+ direction_label)
409
+ k += 1
410
+ return direction_map
411
+
412
+ def calculate_average_height(self, poly_quads):
413
+ """
414
+ """
415
+ height_list = []
416
+ for quad in poly_quads:
417
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
418
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
419
+ height_list.append(quad_h)
420
+ average_height = max(sum(height_list) / len(height_list), 1.0)
421
+ return average_height
422
+
423
+ def generate_tcl_ctc_label(self,
424
+ h,
425
+ w,
426
+ polys,
427
+ tags,
428
+ text_strs,
429
+ ds_ratio,
430
+ tcl_ratio=0.3,
431
+ shrink_ratio_of_width=0.15):
432
+ """
433
+ Generate polygon.
434
+ """
435
+ self.ds_ratio = ds_ratio
436
+ score_map_big = np.zeros(
437
+ (
438
+ h,
439
+ w, ), dtype=np.float32)
440
+ h, w = int(h * ds_ratio), int(w * ds_ratio)
441
+ polys = polys * ds_ratio
442
+
443
+ score_map = np.zeros(
444
+ (
445
+ h,
446
+ w, ), dtype=np.float32)
447
+ score_label_map = np.zeros(
448
+ (
449
+ h,
450
+ w, ), dtype=np.float32)
451
+ tbo_map = np.zeros((h, w, 5), dtype=np.float32)
452
+ training_mask = np.ones(
453
+ (
454
+ h,
455
+ w, ), dtype=np.float32)
456
+ direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
457
+ [1, 1, 3]).astype(np.float32)
458
+
459
+ label_idx = 0
460
+ score_label_map_text_label_list = []
461
+ pos_list, pos_mask, label_list = [], [], []
462
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
463
+ poly = poly_tag[0]
464
+ tag = poly_tag[1]
465
+
466
+ # generate min_area_quad
467
+ min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
468
+ min_area_quad_h = 0.5 * (
469
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
470
+ np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
471
+ min_area_quad_w = 0.5 * (
472
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
473
+ np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
474
+
475
+ if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
476
+ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
477
+ continue
478
+
479
+ if tag:
480
+ cv2.fillPoly(training_mask,
481
+ poly.astype(np.int32)[np.newaxis, :, :], 0.15)
482
+ else:
483
+ text_label = text_strs[poly_idx]
484
+ text_label = self.prepare_text_label(text_label,
485
+ self.Lexicon_Table)
486
+ text_label_index_list = [[self.Lexicon_Table.index(c_)]
487
+ for c_ in text_label
488
+ if c_ in self.Lexicon_Table]
489
+ if len(text_label_index_list) < 1:
490
+ continue
491
+
492
+ tcl_poly = self.poly2tcl(poly, tcl_ratio)
493
+ tcl_quads = self.poly2quads(tcl_poly)
494
+ poly_quads = self.poly2quads(poly)
495
+
496
+ stcl_quads, quad_index = self.shrink_poly_along_width(
497
+ tcl_quads,
498
+ shrink_ratio_of_width=shrink_ratio_of_width,
499
+ expand_height_ratio=1.0 / tcl_ratio)
500
+
501
+ cv2.fillPoly(score_map,
502
+ np.round(stcl_quads).astype(np.int32), 1.0)
503
+ cv2.fillPoly(score_map_big,
504
+ np.round(stcl_quads / ds_ratio).astype(np.int32),
505
+ 1.0)
506
+
507
+ for idx, quad in enumerate(stcl_quads):
508
+ quad_mask = np.zeros((h, w), dtype=np.float32)
509
+ quad_mask = cv2.fillPoly(
510
+ quad_mask,
511
+ np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
512
+ tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
513
+ quad_mask, tbo_map)
514
+
515
+ # score label map and score_label_map_text_label_list for refine
516
+ if label_idx == 0:
517
+ text_pos_list_ = [[len(self.Lexicon_Table)], ]
518
+ score_label_map_text_label_list.append(text_pos_list_)
519
+
520
+ label_idx += 1
521
+ cv2.fillPoly(score_label_map,
522
+ np.round(poly_quads).astype(np.int32), label_idx)
523
+ score_label_map_text_label_list.append(text_label_index_list)
524
+
525
+ # direction info, fix-me
526
+ n_char = len(text_label_index_list)
527
+ direction_map = self.generate_direction_map(poly_quads, n_char,
528
+ direction_map)
529
+
530
+ # pos info
531
+ average_shrink_height = self.calculate_average_height(
532
+ stcl_quads)
533
+
534
+ if self.point_gather_mode == 'align':
535
+ self.f_direction = direction_map[:, :, :-1].copy()
536
+ pos_res = self.fit_and_gather_tcl_points_v3(
537
+ min_area_quad,
538
+ stcl_quads,
539
+ max_h=h,
540
+ max_w=w,
541
+ fixed_point_num=64,
542
+ img_id=self.img_id,
543
+ reference_height=average_shrink_height)
544
+ if pos_res is None:
545
+ continue
546
+ pos_l, pos_m = pos_res[0], pos_res[1]
547
+
548
+ else:
549
+ pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
550
+ min_area_quad,
551
+ poly,
552
+ max_h=h,
553
+ max_w=w,
554
+ fixed_point_num=64,
555
+ img_id=self.img_id,
556
+ reference_height=average_shrink_height)
557
+
558
+ label_l = text_label_index_list
559
+ if len(text_label_index_list) < 2:
560
+ continue
561
+
562
+ pos_list.append(pos_l)
563
+ pos_mask.append(pos_m)
564
+ label_list.append(label_l)
565
+
566
+ # use big score_map for smooth tcl lines
567
+ score_map_big_resized = cv2.resize(
568
+ score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
569
+ score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
570
+
571
+ return score_map, score_label_map, tbo_map, direction_map, training_mask, \
572
+ pos_list, pos_mask, label_list, score_label_map_text_label_list
573
+
574
+ def adjust_point(self, poly):
575
+ """
576
+ adjust point order.
577
+ """
578
+ point_num = poly.shape[0]
579
+ if point_num == 4:
580
+ len_1 = np.linalg.norm(poly[0] - poly[1])
581
+ len_2 = np.linalg.norm(poly[1] - poly[2])
582
+ len_3 = np.linalg.norm(poly[2] - poly[3])
583
+ len_4 = np.linalg.norm(poly[3] - poly[0])
584
+
585
+ if (len_1 + len_3) * 1.5 < (len_2 + len_4):
586
+ poly = poly[[1, 2, 3, 0], :]
587
+
588
+ elif point_num > 4:
589
+ vector_1 = poly[0] - poly[1]
590
+ vector_2 = poly[1] - poly[2]
591
+ cos_theta = np.dot(vector_1, vector_2) / (
592
+ np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
593
+ theta = np.arccos(np.round(cos_theta, decimals=4))
594
+
595
+ if abs(theta) > (70 / 180 * math.pi):
596
+ index = list(range(1, point_num)) + [0]
597
+ poly = poly[np.array(index), :]
598
+ return poly
599
+
600
+ def gen_min_area_quad_from_poly(self, poly):
601
+ """
602
+ Generate min area quad from poly.
603
+ """
604
+ point_num = poly.shape[0]
605
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
606
+ if point_num == 4:
607
+ min_area_quad = poly
608
+ center_point = np.sum(poly, axis=0) / 4
609
+ else:
610
+ rect = cv2.minAreaRect(poly.astype(
611
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
612
+ center_point = rect[0]
613
+ box = np.array(cv2.boxPoints(rect))
614
+
615
+ first_point_idx = 0
616
+ min_dist = 1e4
617
+ for i in range(4):
618
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
619
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
620
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
621
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
622
+ if dist < min_dist:
623
+ min_dist = dist
624
+ first_point_idx = i
625
+
626
+ for i in range(4):
627
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
628
+
629
+ return min_area_quad, center_point
630
+
631
+ def shrink_quad_along_width(self,
632
+ quad,
633
+ begin_width_ratio=0.,
634
+ end_width_ratio=1.):
635
+ """
636
+ Generate shrink_quad_along_width.
637
+ """
638
+ ratio_pair = np.array(
639
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
640
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
641
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
642
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
643
+
644
+ def shrink_poly_along_width(self,
645
+ quads,
646
+ shrink_ratio_of_width,
647
+ expand_height_ratio=1.0):
648
+ """
649
+ shrink poly with given length.
650
+ """
651
+ upper_edge_list = []
652
+
653
+ def get_cut_info(edge_len_list, cut_len):
654
+ for idx, edge_len in enumerate(edge_len_list):
655
+ cut_len -= edge_len
656
+ if cut_len <= 0.000001:
657
+ ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
658
+ return idx, ratio
659
+
660
+ for quad in quads:
661
+ upper_edge_len = np.linalg.norm(quad[0] - quad[1])
662
+ upper_edge_list.append(upper_edge_len)
663
+
664
+ # length of left edge and right edge.
665
+ left_length = np.linalg.norm(quads[0][0] - quads[0][
666
+ 3]) * expand_height_ratio
667
+ right_length = np.linalg.norm(quads[-1][1] - quads[-1][
668
+ 2]) * expand_height_ratio
669
+
670
+ shrink_length = min(left_length, right_length,
671
+ sum(upper_edge_list)) * shrink_ratio_of_width
672
+ # shrinking length
673
+ upper_len_left = shrink_length
674
+ upper_len_right = sum(upper_edge_list) - shrink_length
675
+
676
+ left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
677
+ left_quad = self.shrink_quad_along_width(
678
+ quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
679
+ right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
680
+ right_quad = self.shrink_quad_along_width(
681
+ quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
682
+
683
+ out_quad_list = []
684
+ if left_idx == right_idx:
685
+ out_quad_list.append(
686
+ [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
687
+ else:
688
+ out_quad_list.append(left_quad)
689
+ for idx in range(left_idx + 1, right_idx):
690
+ out_quad_list.append(quads[idx])
691
+ out_quad_list.append(right_quad)
692
+
693
+ return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
694
+
695
+ def prepare_text_label(self, label_str, Lexicon_Table):
696
+ """
697
+ Prepare the text label with the given Lexicon_Table.
698
+ """
699
+ if len(Lexicon_Table) == 36:
700
+ return label_str.lower()
701
+ else:
702
+ return label_str
703
+
704
+ def vector_angle(self, A, B):
705
+ """
706
+ Calculate the angle between vector AB and x-axis positive direction.
707
+ """
708
+ AB = np.array([B[1] - A[1], B[0] - A[0]])
709
+ return np.arctan2(*AB)
710
+
711
+ def theta_line_cross_point(self, theta, point):
712
+ """
713
+ Calculate the line through the given point at the given angle, in ax + by + c = 0 form.
714
+ """
715
+ x, y = point
716
+ cos = np.cos(theta)
717
+ sin = np.sin(theta)
718
+ return [sin, -cos, cos * y - sin * x]
719
+
720
+ def line_cross_two_point(self, A, B):
721
+ """
722
+ Calculate the line through given point A and B in ax + by + c =0 form.
723
+ """
724
+ angle = self.vector_angle(A, B)
725
+ return self.theta_line_cross_point(angle, A)
726
+
727
+ def average_angle(self, poly):
728
+ """
729
+ Calculate the average angle between left and right edge in given poly.
730
+ """
731
+ p0, p1, p2, p3 = poly
732
+ angle30 = self.vector_angle(p3, p0)
733
+ angle21 = self.vector_angle(p2, p1)
734
+ return (angle30 + angle21) / 2
735
+
736
+ def line_cross_point(self, line1, line2):
737
+ """
738
+ line1 and line2 are in ax + by + c = 0 form; compute their intersection point
739
+ """
740
+ a1, b1, c1 = line1
741
+ a2, b2, c2 = line2
742
+ d = a1 * b2 - a2 * b1
743
+
744
+ if d == 0:
745
+ print('Cross point does not exist')
746
+ return np.array([0, 0], dtype=np.float32)
747
+ else:
748
+ x = (b1 * c2 - b2 * c1) / d
749
+ y = (a2 * c1 - a1 * c2) / d
750
+
751
+ return np.array([x, y], dtype=np.float32)
752
+
753
+ def quad2tcl(self, poly, ratio):
754
+ """
755
+ Generate center line by poly clock-wise point. (4, 2)
756
+ """
757
+ ratio_pair = np.array(
758
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
759
+ p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
760
+ p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
761
+ return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
762
+
763
+ def poly2tcl(self, poly, ratio):
764
+ """
765
+ Generate center line by poly clock-wise point.
766
+ """
767
+ ratio_pair = np.array(
768
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
769
+ tcl_poly = np.zeros_like(poly)
770
+ point_num = poly.shape[0]
771
+
772
+ for idx in range(point_num // 2):
773
+ point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
774
+ ) * ratio_pair
775
+ tcl_poly[idx] = point_pair[0]
776
+ tcl_poly[point_num - 1 - idx] = point_pair[1]
777
+ return tcl_poly
778
+
779
+ def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
780
+ """
781
+ Generate tbo_map for give quad.
782
+ """
783
+ # upper and lower line function: ax + by + c = 0;
784
+ up_line = self.line_cross_two_point(quad[0], quad[1])
785
+ lower_line = self.line_cross_two_point(quad[3], quad[2])
786
+
787
+ quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
788
+ np.linalg.norm(quad[1] - quad[2]))
789
+ quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
790
+ np.linalg.norm(quad[2] - quad[3]))
791
+
792
+ # average angle of left and right line.
793
+ angle = self.average_angle(quad)
794
+
795
+ xy_in_poly = np.argwhere(tcl_mask == 1)
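+ # for every pixel on the shrunk TCL, channels 0-3 store the (dy, dx) offsets to the upper and lower borders; channel 4 stores a normalized height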
796
+ for y, x in xy_in_poly:
797
+ point = (x, y)
798
+ line = self.theta_line_cross_point(angle, point)
799
+ cross_point_upper = self.line_cross_point(up_line, line)
800
+ cross_point_lower = self.line_cross_point(lower_line, line)
801
+ ##FIX, offset reverse
802
+ upper_offset_x, upper_offset_y = cross_point_upper - point
803
+ lower_offset_x, lower_offset_y = cross_point_lower - point
804
+ tbo_map[y, x, 0] = upper_offset_y
805
+ tbo_map[y, x, 1] = upper_offset_x
806
+ tbo_map[y, x, 2] = lower_offset_y
807
+ tbo_map[y, x, 3] = lower_offset_x
808
+ tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
809
+ return tbo_map
810
+
811
+ def poly2quads(self, poly):
812
+ """
813
+ Split poly into quads.
814
+ """
815
+ quad_list = []
816
+ point_num = poly.shape[0]
817
+
818
+ # point pair
819
+ point_pair_list = []
820
+ for idx in range(point_num // 2):
821
+ point_pair = [poly[idx], poly[point_num - 1 - idx]]
822
+ point_pair_list.append(point_pair)
823
+
824
+ quad_num = point_num // 2 - 1
825
+ for idx in range(quad_num):
826
+ # reshape and adjust to clock-wise
827
+ quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
828
+ ).reshape(4, 2)[[0, 2, 3, 1]])
829
+
830
+ return np.array(quad_list)
831
+
832
+ def rotate_im_poly(self, im, text_polys):
833
+ """
834
+ rotate the image by 90 / 180 / 270 degrees
835
+ """
836
+ im_w, im_h = im.shape[1], im.shape[0]
837
+ dst_im = im.copy()
838
+ dst_polys = []
839
+ rand_degree_ratio = np.random.rand()
840
+ rand_degree_cnt = 1
841
+ if rand_degree_ratio > 0.5:
842
+ rand_degree_cnt = 3
843
+ for i in range(rand_degree_cnt):
844
+ dst_im = np.rot90(dst_im)
845
+ rot_degree = -90 * rand_degree_cnt
846
+ rot_angle = rot_degree * math.pi / 180.0
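+ # rotate each box corner about the original image center (cx, cy), then translate to the new center (ncx, ncy)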
847
+ n_poly = text_polys.shape[0]
848
+ cx, cy = 0.5 * im_w, 0.5 * im_h
849
+ ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
850
+ for i in range(n_poly):
851
+ wordBB = text_polys[i]
852
+ poly = []
853
+ for j in range(4): # 16->4
854
+ sx, sy = wordBB[j][0], wordBB[j][1]
855
+ dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
856
+ sy - cy) + ncx
857
+ dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
858
+ sy - cy) + ncy
859
+ poly.append([dx, dy])
860
+ dst_polys.append(poly)
861
+ return dst_im, np.array(dst_polys, dtype=np.float32)
862
+
863
+ def __call__(self, data):
864
+ input_size = 512
865
+ im = data['image']
866
+ text_polys = data['polys']
867
+ text_tags = data['ignore_tags']
868
+ text_strs = data['texts']
869
+ h, w, _ = im.shape
870
+ text_polys, text_tags, hv_tags = self.check_and_validate_polys(
871
+ text_polys, text_tags, (h, w))
872
+ if text_polys.shape[0] <= 0:
873
+ return None
874
+ # set aspect ratio and keep area fix
875
+ asp_scales = np.arange(1.0, 1.55, 0.1)
876
+ asp_scale = np.random.choice(asp_scales)
877
+ if np.random.rand() < 0.5:
878
+ asp_scale = 1.0 / asp_scale
879
+ asp_scale = math.sqrt(asp_scale)
880
+
881
+ asp_wx = asp_scale
882
+ asp_hy = 1.0 / asp_scale
883
+ im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
884
+ text_polys[:, :, 0] *= asp_wx
885
+ text_polys[:, :, 1] *= asp_hy
886
+
887
+ if self.use_resize is True:
888
+ ori_h, ori_w, _ = im.shape
889
+ if max(ori_h, ori_w) < 200:
890
+ ratio = 200 / max(ori_h, ori_w)
891
+ im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
892
+ text_polys[:, :, 0] *= ratio
893
+ text_polys[:, :, 1] *= ratio
894
+
895
+ if max(ori_h, ori_w) > 512:
896
+ ratio = 512 / max(ori_h, ori_w)
897
+ im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
898
+ text_polys[:, :, 0] *= ratio
899
+ text_polys[:, :, 1] *= ratio
900
+ elif self.use_random_crop is True:
901
+ h, w, _ = im.shape
902
+ if max(h, w) > 2048:
903
+ rd_scale = 2048.0 / max(h, w)
904
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
905
+ text_polys *= rd_scale
906
+ h, w, _ = im.shape
907
+ if min(h, w) < 16:
908
+ return None
909
+
910
+ # no background
911
+ im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
912
+ im,
913
+ text_polys,
914
+ text_tags,
915
+ hv_tags,
916
+ text_strs,
917
+ crop_background=False)
918
+
919
+ if text_polys.shape[0] == 0:
920
+ return None
921
+ # continue for all ignore case
922
+ if np.sum((text_tags * 1.0)) >= text_tags.size:
923
+ return None
924
+ new_h, new_w, _ = im.shape
925
+ if (new_h is None) or (new_w is None):
926
+ return None
927
+ # resize image
928
+ std_ratio = float(input_size) / max(new_w, new_h)
929
+ rand_scales = np.array(
930
+ [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
931
+ rz_scale = std_ratio * np.random.choice(rand_scales)
932
+ im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
933
+ text_polys[:, :, 0] *= rz_scale
934
+ text_polys[:, :, 1] *= rz_scale
935
+
936
+ # add gaussian blur
937
+ if np.random.rand() < 0.1 * 0.5:
938
+ ks = np.random.permutation(5)[0] + 1
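+ # force an odd kernel size in {1, 3, 5}, as required by cv2.GaussianBlur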
939
+ ks = int(ks / 2) * 2 + 1
940
+ im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
941
+ # randomly brighten
942
+ if np.random.rand() < 0.1 * 0.5:
943
+ im = im * (1.0 + np.random.rand() * 0.5)
944
+ im = np.clip(im, 0.0, 255.0)
945
+ # randomly darken
946
+ if np.random.rand() < 0.1 * 0.5:
947
+ im = im * (1.0 - np.random.rand() * 0.5)
948
+ im = np.clip(im, 0.0, 255.0)
949
+
950
+ # Pad the image to [input_size, input_size]
951
+ new_h, new_w, _ = im.shape
952
+ if min(new_w, new_h) < input_size * 0.5:
953
+ return None
954
+ im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
955
+ im_padded[:, :, 2] = 0.485 * 255
956
+ im_padded[:, :, 1] = 0.456 * 255
957
+ im_padded[:, :, 0] = 0.406 * 255
958
+
959
+ # Randomize the start position
960
+ del_h = input_size - new_h
961
+ del_w = input_size - new_w
962
+ sh, sw = 0, 0
963
+ if del_h > 1:
964
+ sh = int(np.random.rand() * del_h)
965
+ if del_w > 1:
966
+ sw = int(np.random.rand() * del_w)
967
+
968
+ # Padding
969
+ im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
970
+ text_polys[:, :, 0] += sw
971
+ text_polys[:, :, 1] += sh
972
+
973
+ score_map, score_label_map, border_map, direction_map, training_mask, \
974
+ pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
975
+ input_size,
976
+ text_polys,
977
+ text_tags,
978
+ text_strs, 0.25)
979
+ if len(label_list) <= 0: # eliminate negative samples
980
+ return None
981
+ pos_list_temp = np.zeros([64, 3])
982
+ pos_mask_temp = np.zeros([64, 1])
983
+ label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
984
+
985
+ for i, label in enumerate(label_list):
986
+ n = len(label)
987
+ if n > self.max_text_length:
988
+ label_list[i] = label[:self.max_text_length]
989
+ continue
990
+ while n < self.max_text_length:
991
+ label.append([self.pad_num])
992
+ n += 1
993
+
994
+ for i in range(len(label_list)):
995
+ label_list[i] = np.array(label_list[i])
996
+
997
+ if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
998
+ return None
999
+ for __ in range(self.max_text_nums - len(pos_list), 0, -1):
1000
+ pos_list.append(pos_list_temp)
1001
+ pos_mask.append(pos_mask_temp)
1002
+ label_list.append(label_list_temp)
1003
+
1004
+ if self.img_id == self.batch_size - 1:
1005
+ self.img_id = 0
1006
+ else:
1007
+ self.img_id += 1
1008
+
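+ # normalize with the standard ImageNet channel means/stds (0.485/0.456/0.406 and 0.229/0.224/0.225), scaled to 0-255 pixel values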
1009
+ im_padded[:, :, 2] -= 0.485 * 255
1010
+ im_padded[:, :, 1] -= 0.456 * 255
1011
+ im_padded[:, :, 0] -= 0.406 * 255
1012
+ im_padded[:, :, 2] /= (255.0 * 0.229)
1013
+ im_padded[:, :, 1] /= (255.0 * 0.224)
1014
+ im_padded[:, :, 0] /= (255.0 * 0.225)
1015
+ im_padded = im_padded.transpose((2, 0, 1))
1016
+ images = im_padded[::-1, :, :]
1017
+ tcl_maps = score_map[np.newaxis, :, :]
1018
+ tcl_label_maps = score_label_map[np.newaxis, :, :]
1019
+ border_maps = border_map.transpose((2, 0, 1))
1020
+ direction_maps = direction_map.transpose((2, 0, 1))
1021
+ training_masks = training_mask[np.newaxis, :, :]
1022
+ pos_list = np.array(pos_list)
1023
+ pos_mask = np.array(pos_mask)
1024
+ label_list = np.array(label_list)
1025
+ data['images'] = images
1026
+ data['tcl_maps'] = tcl_maps
1027
+ data['tcl_label_maps'] = tcl_label_maps
1028
+ data['border_maps'] = border_maps
1029
+ data['direction_maps'] = direction_maps
1030
+ data['training_masks'] = training_masks
1031
+ data['label_list'] = label_list
1032
+ data['pos_list'] = pos_list
1033
+ data['pos_mask'] = pos_mask
1034
+ return data
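A quick standalone check of the signed-area convention used by PGProcessTrain.quad_area above (an illustrative re-derivation, not part of the commit):

import numpy as np

def shoelace_area(quad):
    # same edge sum as quad_area: (x_{i+1} - x_i) * (y_{i+1} + y_i), summed and halved
    x, y = quad[:, 0], quad[:, 1]
    return np.sum((np.roll(x, -1) - x) * (np.roll(y, -1) + y)) / 2.0

quad = np.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=np.float32)
# reversing the point order flips the sign, which is what
# check_and_validate_polys relies on to detect wrongly-oriented polygons
print(shoelace_area(quad), shoelace_area(quad[::-1]))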
ppocr/data/imaug/randaugment.py ADDED
@@ -0,0 +1,143 @@
1
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+ from __future__ import unicode_literals
19
+
20
+ from PIL import Image, ImageEnhance, ImageOps
21
+ import numpy as np
22
+ import random
23
+ import six
24
+
25
+
26
+ class RawRandAugment(object):
27
+ def __init__(self,
28
+ num_layers=2,
29
+ magnitude=5,
30
+ fillcolor=(128, 128, 128),
31
+ **kwargs):
32
+ self.num_layers = num_layers
33
+ self.magnitude = magnitude
34
+ self.max_level = 10
35
+
36
+ abso_level = self.magnitude / self.max_level
37
+ self.level_map = {
38
+ "shearX": 0.3 * abso_level,
39
+ "shearY": 0.3 * abso_level,
40
+ "translateX": 150.0 / 331 * abso_level,
41
+ "translateY": 150.0 / 331 * abso_level,
42
+ "rotate": 30 * abso_level,
43
+ "color": 0.9 * abso_level,
44
+ "posterize": int(4.0 * abso_level),
45
+ "solarize": 256.0 * abso_level,
46
+ "contrast": 0.9 * abso_level,
47
+ "sharpness": 0.9 * abso_level,
48
+ "brightness": 0.9 * abso_level,
49
+ "autocontrast": 0,
50
+ "equalize": 0,
51
+ "invert": 0
52
+ }
53
+
54
+ # from https://stackoverflow.com/questions/5252170/
55
+ # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
56
+ def rotate_with_fill(img, magnitude):
57
+ rot = img.convert("RGBA").rotate(magnitude)
58
+ return Image.composite(rot,
59
+ Image.new("RGBA", rot.size, (128, ) * 4),
60
+ rot).convert(img.mode)
61
+
62
+ rnd_ch_op = random.choice
63
+
64
+ self.func = {
65
+ "shearX": lambda img, magnitude: img.transform(
66
+ img.size,
67
+ Image.AFFINE,
68
+ (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
69
+ Image.BICUBIC,
70
+ fillcolor=fillcolor),
71
+ "shearY": lambda img, magnitude: img.transform(
72
+ img.size,
73
+ Image.AFFINE,
74
+ (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
75
+ Image.BICUBIC,
76
+ fillcolor=fillcolor),
77
+ "translateX": lambda img, magnitude: img.transform(
78
+ img.size,
79
+ Image.AFFINE,
80
+ (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
81
+ fillcolor=fillcolor),
82
+ "translateY": lambda img, magnitude: img.transform(
83
+ img.size,
84
+ Image.AFFINE,
85
+ (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
86
+ fillcolor=fillcolor),
87
+ "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
88
+ "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
89
+ 1 + magnitude * rnd_ch_op([-1, 1])),
90
+ "posterize": lambda img, magnitude:
91
+ ImageOps.posterize(img, magnitude),
92
+ "solarize": lambda img, magnitude:
93
+ ImageOps.solarize(img, magnitude),
94
+ "contrast": lambda img, magnitude:
95
+ ImageEnhance.Contrast(img).enhance(
96
+ 1 + magnitude * rnd_ch_op([-1, 1])),
97
+ "sharpness": lambda img, magnitude:
98
+ ImageEnhance.Sharpness(img).enhance(
99
+ 1 + magnitude * rnd_ch_op([-1, 1])),
100
+ "brightness": lambda img, magnitude:
101
+ ImageEnhance.Brightness(img).enhance(
102
+ 1 + magnitude * rnd_ch_op([-1, 1])),
103
+ "autocontrast": lambda img, magnitude:
104
+ ImageOps.autocontrast(img),
105
+ "equalize": lambda img, magnitude: ImageOps.equalize(img),
106
+ "invert": lambda img, magnitude: ImageOps.invert(img)
107
+ }
108
+
109
+ def __call__(self, img):
110
+ available_op_names = list(self.level_map.keys())
111
+ for layer_num in range(self.num_layers):
112
+ op_name = np.random.choice(available_op_names)
113
+ img = self.func[op_name](img, self.level_map[op_name])
114
+ return img
115
+
116
+
117
+ class RandAugment(RawRandAugment):
118
+ """ RandAugment wrapper to auto fit different img types """
119
+
120
+ def __init__(self, prob=0.5, *args, **kwargs):
121
+ self.prob = prob
122
+ if six.PY2:
123
+ super(RandAugment, self).__init__(*args, **kwargs)
124
+ else:
125
+ super().__init__(*args, **kwargs)
126
+
127
+ def __call__(self, data):
128
+ if np.random.rand() > self.prob:
129
+ return data
130
+ img = data['image']
131
+ if not isinstance(img, Image.Image):
132
+ img = np.ascontiguousarray(img)
133
+ img = Image.fromarray(img)
134
+
135
+ if six.PY2:
136
+ img = super(RandAugment, self).__call__(img)
137
+ else:
138
+ img = super().__call__(img)
139
+
140
+ if isinstance(img, Image.Image):
141
+ img = np.asarray(img)
142
+ data['image'] = img
143
+ return data
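A minimal usage sketch for RandAugment above (synthetic input; prob, num_layers and magnitude are example values):

import numpy as np

aug = RandAugment(prob=0.5, num_layers=2, magnitude=5)
data = {'image': np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)}
data = aug(data)  # with probability prob, applies num_layers random PIL ops
print(data['image'].shape)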
ppocr/data/imaug/random_crop_data.py ADDED
@@ -0,0 +1,234 @@
1
+ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ This code is refer from:
16
+ https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
17
+ """
18
+
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+ from __future__ import print_function
22
+ from __future__ import unicode_literals
23
+
24
+ import numpy as np
25
+ import cv2
26
+ import random
27
+
28
+
29
+ def is_poly_in_rect(poly, x, y, w, h):
30
+ poly = np.array(poly)
31
+ if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
32
+ return False
33
+ if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
34
+ return False
35
+ return True
36
+
37
+
38
+ def is_poly_outside_rect(poly, x, y, w, h):
39
+ poly = np.array(poly)
40
+ if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
41
+ return True
42
+ if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
43
+ return True
44
+ return False
45
+
46
+
47
+ def split_regions(axis):
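+ # split a sorted index array into maximal runs of consecutive indices (note: the trailing run is never appended)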
48
+ regions = []
49
+ min_axis = 0
50
+ for i in range(1, axis.shape[0]):
51
+ if axis[i] != axis[i - 1] + 1:
52
+ region = axis[min_axis:i]
53
+ min_axis = i
54
+ regions.append(region)
55
+ return regions
56
+
57
+
58
+ def random_select(axis, max_size):
59
+ xx = np.random.choice(axis, size=2)
60
+ xmin = np.min(xx)
61
+ xmax = np.max(xx)
62
+ xmin = np.clip(xmin, 0, max_size - 1)
63
+ xmax = np.clip(xmax, 0, max_size - 1)
64
+ return xmin, xmax
65
+
66
+
67
+ def region_wise_random_select(regions, max_size):
68
+ selected_index = list(np.random.choice(len(regions), 2))
69
+ selected_values = []
70
+ for index in selected_index:
71
+ axis = regions[index]
72
+ xx = int(np.random.choice(axis, size=1))
73
+ selected_values.append(xx)
74
+ xmin = min(selected_values)
75
+ xmax = max(selected_values)
76
+ return xmin, xmax
77
+
78
+
79
+ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
80
+ h, w, _ = im.shape
81
+ h_array = np.zeros(h, dtype=np.int32)
82
+ w_array = np.zeros(w, dtype=np.int32)
83
+ for points in text_polys:
84
+ points = np.round(points, decimals=0).astype(np.int32)
85
+ minx = np.min(points[:, 0])
86
+ maxx = np.max(points[:, 0])
87
+ w_array[minx:maxx] = 1
88
+ miny = np.min(points[:, 1])
89
+ maxy = np.max(points[:, 1])
90
+ h_array[miny:maxy] = 1
91
+ # ensure the cropped area does not cross a text region
92
+ h_axis = np.where(h_array == 0)[0]
93
+ w_axis = np.where(w_array == 0)[0]
94
+
95
+ if len(h_axis) == 0 or len(w_axis) == 0:
96
+ return 0, 0, w, h
97
+
98
+ h_regions = split_regions(h_axis)
99
+ w_regions = split_regions(w_axis)
100
+
101
+ for i in range(max_tries):
102
+ if len(w_regions) > 1:
103
+ xmin, xmax = region_wise_random_select(w_regions, w)
104
+ else:
105
+ xmin, xmax = random_select(w_axis, w)
106
+ if len(h_regions) > 1:
107
+ ymin, ymax = region_wise_random_select(h_regions, h)
108
+ else:
109
+ ymin, ymax = random_select(h_axis, h)
110
+
111
+ if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
112
+ # area too small
113
+ continue
114
+ num_poly_in_rect = 0
115
+ for poly in text_polys:
116
+ if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
117
+ ymax - ymin):
118
+ num_poly_in_rect += 1
119
+ break
120
+
121
+ if num_poly_in_rect > 0:
122
+ return xmin, ymin, xmax - xmin, ymax - ymin
123
+
124
+ return 0, 0, w, h
125
+
126
+
127
+ class EastRandomCropData(object):
128
+ def __init__(self,
129
+ size=(640, 640),
130
+ max_tries=10,
131
+ min_crop_side_ratio=0.1,
132
+ keep_ratio=True,
133
+ **kwargs):
134
+ self.size = size
135
+ self.max_tries = max_tries
136
+ self.min_crop_side_ratio = min_crop_side_ratio
137
+ self.keep_ratio = keep_ratio
138
+
139
+ def __call__(self, data):
140
+ img = data['image']
141
+ text_polys = data['polys']
142
+ ignore_tags = data['ignore_tags']
143
+ texts = data['texts']
144
+ all_care_polys = [
145
+ text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
146
+ ]
147
+ # compute the crop region
148
+ crop_x, crop_y, crop_w, crop_h = crop_area(
149
+ img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
150
+ # crop the image, padding to keep the aspect ratio
151
+ scale_w = self.size[0] / crop_w
152
+ scale_h = self.size[1] / crop_h
153
+ scale = min(scale_w, scale_h)
154
+ h = int(crop_h * scale)
155
+ w = int(crop_w * scale)
156
+ if self.keep_ratio:
157
+ padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
158
+ img.dtype)
159
+ padimg[:h, :w] = cv2.resize(
160
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
161
+ img = padimg
162
+ else:
163
+ img = cv2.resize(
164
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
165
+ tuple(self.size))
166
+ # crop the text boxes
167
+ text_polys_crop = []
168
+ ignore_tags_crop = []
169
+ texts_crop = []
170
+ for poly, text, tag in zip(text_polys, texts, ignore_tags):
171
+ poly = ((poly - (crop_x, crop_y)) * scale).tolist()
172
+ if not is_poly_outside_rect(poly, 0, 0, w, h):
173
+ text_polys_crop.append(poly)
174
+ ignore_tags_crop.append(tag)
175
+ texts_crop.append(text)
176
+ data['image'] = img
177
+ data['polys'] = np.array(text_polys_crop)
178
+ data['ignore_tags'] = ignore_tags_crop
179
+ data['texts'] = texts_crop
180
+ return data
181
+
182
+
183
+ class RandomCropImgMask(object):
184
+ def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
185
+ self.size = size
186
+ self.main_key = main_key
187
+ self.crop_keys = crop_keys
188
+ self.p = p
189
+
190
+ def __call__(self, data):
191
+ image = data['image']
192
+
193
+ h, w = image.shape[0:2]
194
+ th, tw = self.size
195
+ if w == tw and h == th:
196
+ return data
197
+
198
+ mask = data[self.main_key]
199
+ if np.max(mask) > 0 and random.random() > self.p:
200
+ # make sure to crop the text region
201
+ tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
202
+ tl[tl < 0] = 0
203
+ br = np.max(np.where(mask > 0), axis=1) - (th, tw)
204
+ br[br < 0] = 0
205
+
206
+ br[0] = min(br[0], h - th)
207
+ br[1] = min(br[1], w - tw)
208
+
209
+ i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
210
+ j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
211
+ else:
212
+ i = random.randint(0, h - th) if h - th > 0 else 0
213
+ j = random.randint(0, w - tw) if w - tw > 0 else 0
214
+
215
+ # return i, j, th, tw
216
+ for k in data:
217
+ if k in self.crop_keys:
218
+ if len(data[k].shape) == 3:
219
+ if np.argmin(data[k].shape) == 0:
220
+ img = data[k][:, i:i + th, j:j + tw]
221
223
+ elif np.argmin(data[k].shape) == 2:
224
+ img = data[k][i:i + th, j:j + tw, :]
225
227
+ else:
228
+ img = data[k]
229
+ else:
230
+ img = data[k][i:i + th, j:j + tw]
231
233
+ data[k] = img
234
+ return data
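A minimal usage sketch for EastRandomCropData above (a synthetic sample; the field values are illustrative):

import numpy as np

crop_op = EastRandomCropData(size=(640, 640), max_tries=10, keep_ratio=True)
data = {
    'image': np.zeros((720, 1280, 3), dtype=np.uint8),
    'polys': np.array([[[10, 10], [100, 10], [100, 40], [10, 40]]], dtype=np.float32),
    'ignore_tags': [False],
    'texts': ['hello'],
}
data = crop_op(data)
print(data['image'].shape)  # always (640, 640, 3); polys are shifted and scaled to match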
ppocr/data/imaug/rec_img_aug.py ADDED
@@ -0,0 +1,825 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import random
+import copy
+from PIL import Image
+from .text_image_aug import tia_perspective, tia_stretch, tia_distort
+from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
+from paddle.vision.transforms import Compose
+
+
+class RecAug(object):
+    def __init__(self,
+                 tia_prob=0.4,
+                 crop_prob=0.4,
+                 reverse_prob=0.4,
+                 noise_prob=0.4,
+                 jitter_prob=0.4,
+                 blur_prob=0.4,
+                 hsv_aug_prob=0.4,
+                 **kwargs):
+        self.tia_prob = tia_prob
+        self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob,
+                                        jitter_prob, blur_prob, hsv_aug_prob)
+
+    def __call__(self, data):
+        img = data['image']
+        h, w, _ = img.shape
+
+        # tia
+        if random.random() <= self.tia_prob:
+            if h >= 20 and w >= 20:
+                img = tia_distort(img, random.randint(3, 6))
+                img = tia_stretch(img, random.randint(3, 6))
+            img = tia_perspective(img)
+
+        # bda
+        data['image'] = img
+        data = self.bda(data)
+        return data
+
+
+class BaseDataAugmentation(object):
+    def __init__(self,
+                 crop_prob=0.4,
+                 reverse_prob=0.4,
+                 noise_prob=0.4,
+                 jitter_prob=0.4,
+                 blur_prob=0.4,
+                 hsv_aug_prob=0.4,
+                 **kwargs):
+        self.crop_prob = crop_prob
+        self.reverse_prob = reverse_prob
+        self.noise_prob = noise_prob
+        self.jitter_prob = jitter_prob
+        self.blur_prob = blur_prob
+        self.hsv_aug_prob = hsv_aug_prob
+
+    def __call__(self, data):
+        img = data['image']
+        h, w, _ = img.shape
+
+        if random.random() <= self.crop_prob and h >= 20 and w >= 20:
+            img = get_crop(img)
+
+        if random.random() <= self.blur_prob:
+            img = blur(img)
+
+        if random.random() <= self.hsv_aug_prob:
+            img = hsv_aug(img)
+
+        if random.random() <= self.jitter_prob:
+            img = jitter(img)
+
+        if random.random() <= self.noise_prob:
+            img = add_gasuss_noise(img)
+
+        if random.random() <= self.reverse_prob:
+            img = 255 - img
+
+        data['image'] = img
+        return data
+
+
+class ABINetRecAug(object):
+    def __init__(self,
+                 geometry_p=0.5,
+                 deterioration_p=0.25,
+                 colorjitter_p=0.25,
+                 **kwargs):
+        self.transforms = Compose([
+            CVGeometry(
+                degrees=45,
+                translate=(0.0, 0.0),
+                scale=(0.5, 2.),
+                shear=(45, 15),
+                distortion=0.5,
+                p=geometry_p), CVDeterioration(
+                    var=20, degrees=6, factor=4, p=deterioration_p),
+            CVColorJitter(
+                brightness=0.5,
+                contrast=0.5,
+                saturation=0.5,
+                hue=0.1,
+                p=colorjitter_p)
+        ])
+
+    def __call__(self, data):
+        img = data['image']
+        img = self.transforms(img)
+        data['image'] = img
+        return data
+
+
+class RecConAug(object):
+    def __init__(self,
+                 prob=0.5,
+                 image_shape=(32, 320, 3),
+                 max_text_length=25,
+                 ext_data_num=1,
+                 **kwargs):
+        self.ext_data_num = ext_data_num
+        self.prob = prob
+        self.max_text_length = max_text_length
+        self.image_shape = image_shape
+        self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
+
+    def merge_ext_data(self, data, ext_data):
+        ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
+                      self.image_shape[0])
+        ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
+                      self.image_shape[0])
+        data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
+        ext_data['image'] = cv2.resize(ext_data['image'],
+                                       (ext_w, self.image_shape[0]))
+        data['image'] = np.concatenate(
+            [data['image'], ext_data['image']], axis=1)
+        data["label"] += ext_data["label"]
+        return data
+
+    def __call__(self, data):
+        rnd_num = random.random()
+        if rnd_num > self.prob:
+            return data
+        for idx, ext_data in enumerate(data["ext_data"]):
+            if len(data["label"]) + len(ext_data[
+                    "label"]) > self.max_text_length:
+                break
+            concat_ratio = data['image'].shape[1] / data['image'].shape[
+                0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
+            if concat_ratio > self.max_wh_ratio:
+                break
+            data = self.merge_ext_data(data, ext_data)
+        data.pop("ext_data")
+        return data
+
+
+class SVTRRecAug(object):
+    def __init__(self,
+                 aug_type=0,
+                 geometry_p=0.5,
+                 deterioration_p=0.25,
+                 colorjitter_p=0.25,
+                 **kwargs):
+        self.transforms = Compose([
+            SVTRGeometry(
+                aug_type=aug_type,
+                degrees=45,
+                translate=(0.0, 0.0),
+                scale=(0.5, 2.),
+                shear=(45, 15),
+                distortion=0.5,
+                p=geometry_p), SVTRDeterioration(
+                    var=20, degrees=6, factor=4, p=deterioration_p),
+            CVColorJitter(
+                brightness=0.5,
+                contrast=0.5,
+                saturation=0.5,
+                hue=0.1,
+                p=colorjitter_p)
+        ])
+
+    def __call__(self, data):
+        img = data['image']
+        img = self.transforms(img)
+        data['image'] = img
+        return data
+
+
+class ClsResizeImg(object):
+    def __init__(self, image_shape, **kwargs):
+        self.image_shape = image_shape
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img, _ = resize_norm_img(img, self.image_shape)
+        data['image'] = norm_img
+        return data
+
+
+class RecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 infer_mode=False,
+                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+                 padding=True,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+        self.character_dict_path = character_dict_path
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+        if self.infer_mode and self.character_dict_path is not None:
+            norm_img, valid_ratio = resize_norm_img_chinese(img,
+                                                            self.image_shape)
+        else:
+            norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+                                                    self.padding)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class VLRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 infer_mode=False,
+                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+                 padding=True,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+        self.character_dict_path = character_dict_path
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+
+        imgC, imgH, imgW = self.image_shape
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_w = imgW
+        resized_image = resized_image.astype('float32')
+        if self.image_shape[0] == 1:
+            resized_image = resized_image / 255
+            norm_img = resized_image[np.newaxis, :]
+        else:
+            norm_img = resized_image.transpose((2, 0, 1)) / 255
+        valid_ratio = min(1.0, float(resized_w / imgW))
+
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class RFLRecResizeImg(object):
+    def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+
+        self.interpolation = interpolation
+        if self.interpolation == 0:
+            self.interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            self.interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            self.interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            self.interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type!")
+
+    def __call__(self, data):
+        img = data['image']
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        norm_img, valid_ratio = resize_norm_img(
+            img, self.image_shape, self.padding, self.interpolation)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class SRNRecResizeImg(object):
+    def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
+        self.image_shape = image_shape
+        self.num_heads = num_heads
+        self.max_text_length = max_text_length
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img = resize_norm_img_srn(img, self.image_shape)
+        data['image'] = norm_img
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+            srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+        data['encoder_word_pos'] = encoder_word_pos
+        data['gsrm_word_pos'] = gsrm_word_pos
+        data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
+        data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+        return data
+
+
+class SARRecResizeImg(object):
+    def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
+        self.image_shape = image_shape
+        self.width_downsample_ratio = width_downsample_ratio
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+            img, self.image_shape, self.width_downsample_ratio)
+        data['image'] = norm_img
+        data['resized_shape'] = resize_shape
+        data['pad_shape'] = pad_shape
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class PRENResizeImg(object):
+    def __init__(self, image_shape, **kwargs):
+        """
+        According to the original paper's implementation, this is a hard
+        resize, so you may want to adapt it to fit your task better.
+        """
+        self.dst_h, self.dst_w = image_shape
+
+    def __call__(self, data):
+        img = data['image']
+        resized_img = cv2.resize(
+            img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
+        resized_img = resized_img.transpose((2, 0, 1)) / 255
+        resized_img -= 0.5
+        resized_img /= 0.5
+        data['image'] = resized_img.astype(np.float32)
+        return data
+
+
+class SPINRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 interpolation=2,
+                 mean=(127.5, 127.5, 127.5),
+                 std=(127.5, 127.5, 127.5),
+                 **kwargs):
+        self.image_shape = image_shape
+
+        self.mean = np.array(mean, dtype=np.float32)
+        self.std = np.array(std, dtype=np.float32)
+        self.interpolation = interpolation
+
+    def __call__(self, data):
+        img = data['image']
+        # deal with images that failed to load (checked before any cv2 call)
+        if img is None:
+            return None
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # map the interpolation flag to the corresponding OpenCV constant
+        if self.interpolation == 0:
+            interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type!")
+
+        img = cv2.resize(img, tuple(self.image_shape), interpolation)
+        img = np.array(img, np.float32)
+        img = np.expand_dims(img, -1)
+        img = img.transpose((2, 0, 1))
+        # normalize the image
+        img = img.copy().astype(np.float32)
+        mean = np.float64(self.mean.reshape(1, -1))
+        stdinv = 1 / np.float64(self.std.reshape(1, -1))
+        img -= mean
+        img *= stdinv
+        data['image'] = img
+        return data
+
+
+class GrayRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 resize_type,
+                 inter_type='Image.ANTIALIAS',
+                 scale=True,
+                 padding=False,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.resize_type = resize_type
+        self.padding = padding
+        self.inter_type = eval(inter_type)
+        self.scale = scale
+
+    def __call__(self, data):
+        img = data['image']
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        image_shape = self.image_shape
+        if self.padding:
+            imgC, imgH, imgW = image_shape
+            # todo: change to 0 and modify the image shape
+            h = img.shape[0]
+            w = img.shape[1]
+            ratio = w / float(h)
+            if math.ceil(imgH * ratio) > imgW:
+                resized_w = imgW
+            else:
+                resized_w = int(math.ceil(imgH * ratio))
+            resized_image = cv2.resize(img, (resized_w, imgH))
+            norm_img = np.expand_dims(resized_image, -1)
+            norm_img = norm_img.transpose((2, 0, 1))
+            resized_image = norm_img.astype(np.float32) / 128. - 1.
+            padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+            padding_im[:, :, 0:resized_w] = resized_image
+            data['image'] = padding_im
+            return data
+        if self.resize_type == 'PIL':
+            image_pil = Image.fromarray(np.uint8(img))
+            img = image_pil.resize(self.image_shape, self.inter_type)
+            img = np.array(img)
+        if self.resize_type == 'OpenCV':
+            img = cv2.resize(img, self.image_shape)
+        norm_img = np.expand_dims(img, -1)
+        norm_img = norm_img.transpose((2, 0, 1))
+        if self.scale:
+            data['image'] = norm_img.astype(np.float32) / 128. - 1.
+        else:
+            data['image'] = norm_img.astype(np.float32) / 255.
+        return data
+
+
+class ABINetRecResizeImg(object):
+    def __init__(self, image_shape, **kwargs):
+        self.image_shape = image_shape
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img, valid_ratio = resize_norm_img_abinet(img, self.image_shape)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class SVTRRecResizeImg(object):
+    def __init__(self, image_shape, padding=True, **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+
+        norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+                                                self.padding)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
+class RobustScannerRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 max_text_length,
+                 width_downsample_ratio=0.25,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.width_downsample_ratio = width_downsample_ratio
+        self.max_text_length = max_text_length
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+            img, self.image_shape, self.width_downsample_ratio)
+        word_positons = np.array(range(0, self.max_text_length)).astype('int64')
+        data['image'] = norm_img
+        data['resized_shape'] = resize_shape
+        data['pad_shape'] = pad_shape
+        data['valid_ratio'] = valid_ratio
+        data['word_positons'] = word_positons
+        return data
+
+
+def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
+    imgC, imgH, imgW_min, imgW_max = image_shape
+    h = img.shape[0]
+    w = img.shape[1]
+    valid_ratio = 1.0
+    # make sure new_width is an integral multiple of width_divisor.
+    width_divisor = int(1 / width_downsample_ratio)
+    # resize
+    ratio = w / float(h)
+    resize_w = math.ceil(imgH * ratio)
+    if resize_w % width_divisor != 0:
+        resize_w = round(resize_w / width_divisor) * width_divisor
+    if imgW_min is not None:
+        resize_w = max(imgW_min, resize_w)
+    if imgW_max is not None:
+        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+        resize_w = min(imgW_max, resize_w)
+    resized_image = cv2.resize(img, (resize_w, imgH))
+    resized_image = resized_image.astype('float32')
+    # norm
+    if image_shape[0] == 1:
+        resized_image = resized_image / 255
+        resized_image = resized_image[np.newaxis, :]
+    else:
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+    resized_image -= 0.5
+    resized_image /= 0.5
+    resize_shape = resized_image.shape
+    padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+    padding_im[:, :, 0:resize_w] = resized_image
+    pad_shape = padding_im.shape
+
+    return padding_im, resize_shape, pad_shape, valid_ratio
+
+
+def resize_norm_img(img,
+                    image_shape,
+                    padding=True,
+                    interpolation=cv2.INTER_LINEAR):
+    imgC, imgH, imgW = image_shape
+    h = img.shape[0]
+    w = img.shape[1]
+    if not padding:
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=interpolation)
+        resized_w = imgW
+    else:
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
+    resized_image = resized_image.astype('float32')
+    if image_shape[0] == 1:
+        resized_image = resized_image / 255
+        resized_image = resized_image[np.newaxis, :]
+    else:
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+    resized_image -= 0.5
+    resized_image /= 0.5
+    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+    padding_im[:, :, 0:resized_w] = resized_image
+    valid_ratio = min(1.0, float(resized_w / imgW))
+    return padding_im, valid_ratio
+
+
+def resize_norm_img_chinese(img, image_shape):
+    imgC, imgH, imgW = image_shape
+    # todo: change to 0 and modify the image shape
+    max_wh_ratio = imgW * 1.0 / imgH
+    h, w = img.shape[0], img.shape[1]
+    ratio = w * 1.0 / h
+    max_wh_ratio = max(max_wh_ratio, ratio)
+    imgW = int(imgH * max_wh_ratio)
+    if math.ceil(imgH * ratio) > imgW:
+        resized_w = imgW
+    else:
+        resized_w = int(math.ceil(imgH * ratio))
+    resized_image = cv2.resize(img, (resized_w, imgH))
+    resized_image = resized_image.astype('float32')
+    if image_shape[0] == 1:
+        resized_image = resized_image / 255
+        resized_image = resized_image[np.newaxis, :]
+    else:
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+    resized_image -= 0.5
+    resized_image /= 0.5
+    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+    padding_im[:, :, 0:resized_w] = resized_image
+    valid_ratio = min(1.0, float(resized_w / imgW))
+    return padding_im, valid_ratio
+
+
+def resize_norm_img_srn(img, image_shape):
+    imgC, imgH, imgW = image_shape
+
+    img_black = np.zeros((imgH, imgW))
+    im_hei = img.shape[0]
+    im_wid = img.shape[1]
+
+    if im_wid <= im_hei * 1:
+        img_new = cv2.resize(img, (imgH * 1, imgH))
+    elif im_wid <= im_hei * 2:
+        img_new = cv2.resize(img, (imgH * 2, imgH))
+    elif im_wid <= im_hei * 3:
+        img_new = cv2.resize(img, (imgH * 3, imgH))
+    else:
+        img_new = cv2.resize(img, (imgW, imgH))
+
+    img_np = np.asarray(img_new)
+    img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+    img_black[:, 0:img_np.shape[1]] = img_np
+    img_black = img_black[:, :, np.newaxis]
+
+    row, col, c = img_black.shape
+    c = 1
+
+    return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+
+def resize_norm_img_abinet(img, image_shape):
+    imgC, imgH, imgW = image_shape
+
+    resized_image = cv2.resize(
+        img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+    resized_w = imgW
+    resized_image = resized_image.astype('float32')
+    resized_image = resized_image / 255.
+
+    mean = np.array([0.485, 0.456, 0.406])
+    std = np.array([0.229, 0.224, 0.225])
+    resized_image = (
+        resized_image - mean[None, None, ...]) / std[None, None, ...]
+    resized_image = resized_image.transpose((2, 0, 1))
+    resized_image = resized_image.astype('float32')
+
+    valid_ratio = min(1.0, float(resized_w / imgW))
+    return resized_image, valid_ratio
+
+
+def srn_other_inputs(image_shape, num_heads, max_text_length):
+
+    imgC, imgH, imgW = image_shape
+    feature_dim = int((imgH / 8) * (imgW / 8))
+
+    encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+        (feature_dim, 1)).astype('int64')
+    gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+        (max_text_length, 1)).astype('int64')
+
+    gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+    gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+        [1, max_text_length, max_text_length])
+    gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+                                  [num_heads, 1, 1]) * [-1e9]
+
+    gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+        [1, max_text_length, max_text_length])
+    gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+                                  [num_heads, 1, 1]) * [-1e9]
+
+    return [
+        encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+        gsrm_slf_attn_bias2
+    ]
+
+
+def flag():
+    """
+    Return 1 or -1 at random.
+    """
+    return 1 if random.random() > 0.5000001 else -1
+
+
+def hsv_aug(img):
+    """
+    Randomly perturb the brightness channel in HSV space.
+    """
+    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+    delta = 0.001 * random.random() * flag()
+    hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
+    new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+    return new_img
+
+
+def blur(img):
+    """
+    Gaussian blur.
+    """
+    h, w, _ = img.shape
+    if h > 10 and w > 10:
+        return cv2.GaussianBlur(img, (5, 5), 1)
+    else:
+        return img
+
+
+def jitter(img):
+    """
+    Jitter by shifting the image along its diagonal.
+    """
+    # note: w and h are swapped here relative to the usual (h, w, c)
+    # convention, but they are used consistently below
+    w, h, _ = img.shape
+    if h > 10 and w > 10:
+        thres = min(w, h)
+        s = int(random.random() * thres * 0.01)
+        src_img = img.copy()
+        for i in range(s):
+            img[i:, i:, :] = src_img[:w - i, :h - i, :]
+        return img
+    else:
+        return img
+
+
+def add_gasuss_noise(image, mean=0, var=0.1):
+    """
+    Add Gaussian noise.
+    """
+    noise = np.random.normal(mean, var**0.5, image.shape)
+    out = image + 0.5 * noise
+    out = np.clip(out, 0, 255)
+    out = np.uint8(out)
+    return out
+
+
+def get_crop(image):
+    """
+    Randomly crop a few rows from the top or the bottom.
+    """
+    h, w, _ = image.shape
+    top_min = 1
+    top_max = 8
+    top_crop = int(random.randint(top_min, top_max))
+    top_crop = min(top_crop, h - 1)
+    crop_img = image.copy()
+    ratio = random.randint(0, 1)
+    if ratio:
+        crop_img = crop_img[top_crop:h, :, :]
+    else:
+        crop_img = crop_img[0:h - top_crop, :, :]
+    return crop_img
+
+
+def rad(x):
+    """
+    Convert degrees to radians.
+    """
+    return x * np.pi / 180
+
+
+def get_warpR(config):
+    """
+    Build a perspective warp matrix from the given rotation config.
+    """
+    anglex, angley, anglez, fov, w, h, r = \
+        config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
+    if w > 69 and w < 112:
+        anglex = anglex * 1.5
+
+    z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
+    # homogeneous coordinate transformation matrices
+    rx = np.array([[1, 0, 0, 0],
+                   [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0],
+                   [0, -np.sin(rad(anglex)), np.cos(rad(anglex)), 0],
+                   [0, 0, 0, 1]], np.float32)
+    ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
+                   [0, 1, 0, 0],
+                   [-np.sin(rad(angley)), 0, np.cos(rad(angley)), 0],
+                   [0, 0, 0, 1]], np.float32)
+    rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
+                   [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
+                   [0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
+    r = rx.dot(ry).dot(rz)
+    # generate 4 corner points
+    pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
+    p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
+    p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
+    p3 = np.array([0, h, 0, 0], np.float32) - pcenter
+    p4 = np.array([w, h, 0, 0], np.float32) - pcenter
+    dst1 = r.dot(p1)
+    dst2 = r.dot(p2)
+    dst3 = r.dot(p3)
+    dst4 = r.dot(p4)
+    list_dst = np.array([dst1, dst2, dst3, dst4])
+    org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
+    dst = np.zeros((4, 2), np.float32)
+    # project onto the image plane
+    dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
+    dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
+
+    warpR = cv2.getPerspectiveTransform(org, dst)
+
+    dst1, dst2, dst3, dst4 = dst
+    r1 = int(min(dst1[1], dst2[1]))
+    r2 = int(max(dst3[1], dst4[1]))
+    c1 = int(min(dst1[0], dst3[0]))
+    c2 = int(max(dst2[0], dst4[0]))
+
+    try:
+        ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
+
+        dx = -c1
+        dy = -r1
+        T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
+        ret = T1.dot(warpR)
+    except:
+        ratio = 1.0
+        T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
+        ret = T1
+    return ret, (-r1, -c1), ratio, dst
+
+
+def get_warpAffine(config):
+    """
+    Build a 2x3 affine rotation matrix from the given config.
+    """
+    anglez = config.anglez
+    rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
+                   [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
+    return rz
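
Note: the resize helpers above all share one contract: they return a CHW float32 tensor normalized to roughly [-1, 1] (right-padded when padding is enabled) together with a valid_ratio giving the fraction of the target width covered by real pixels. A minimal standalone sketch of resize_norm_img, assuming the common 3x32x320 recognition shape and a random stand-in image (both are illustrative choices, not part of this commit):

import numpy as np
from ppocr.data.imaug.rec_img_aug import resize_norm_img

# stand-in for a decoded BGR text-line crop (h=48, w=180)
img = (np.random.rand(48, 180, 3) * 255).astype(np.uint8)

norm_img, valid_ratio = resize_norm_img(img, image_shape=[3, 32, 320])
# norm_img: float32 CHW tensor in [-1, 1], right-padded to width 320
# valid_ratio: 120/320 = 0.375 here, the share of real (non-pad) pixels
print(norm_img.shape, valid_ratio)
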
ppocr/data/imaug/sast_process.py ADDED
@@ -0,0 +1,777 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This part of the code is referred from:
+https://github.com/songdejia/EAST/blob/master/data_utils.py
+"""
+import math
+import cv2
+import numpy as np
+import json
+import sys
+import os
+
+__all__ = ['SASTProcessTrain']
+
+
+class SASTProcessTrain(object):
+    def __init__(self,
+                 image_shape=[512, 512],
+                 min_crop_size=24,
+                 min_crop_side_ratio=0.3,
+                 min_text_size=10,
+                 max_text_size=512,
+                 **kwargs):
+        self.input_size = image_shape[1]
+        self.min_crop_size = min_crop_size
+        self.min_crop_side_ratio = min_crop_side_ratio
+        self.min_text_size = min_text_size
+        self.max_text_size = max_text_size
+
+    def quad_area(self, poly):
+        """
+        Compute the signed area of a quad.
+        :param poly:
+        :return:
+        """
+        edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+                (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+                (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+                (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+        return np.sum(edge) / 2.
+
+    def gen_quad_from_poly(self, poly):
+        """
+        Generate min area quad from poly.
+        """
+        point_num = poly.shape[0]
+        min_area_quad = np.zeros((4, 2), dtype=np.float32)
+        rect = cv2.minAreaRect(poly.astype(
+            np.int32))  # (center (x,y), (width, height), angle of rotation)
+        center_point = rect[0]
+        box = np.array(cv2.boxPoints(rect))
+
+        first_point_idx = 0
+        min_dist = 1e4
+        for i in range(4):
+            dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+                   np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+                   np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+                   np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+            if dist < min_dist:
+                min_dist = dist
+                first_point_idx = i
+        for i in range(4):
+            min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+        return min_area_quad
+
+    def check_and_validate_polys(self, polys, tags, im_size):
+        """
+        Check that the text polys are in the same direction,
+        and filter out invalid polygons.
+        :param polys:
+        :param tags:
+        :return:
+        """
+        (h, w) = im_size
+        if polys.shape[0] == 0:
+            return polys, np.array([]), np.array([])
+        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+        validated_polys = []
+        validated_tags = []
+        hv_tags = []
+        for poly, tag in zip(polys, tags):
+            quad = self.gen_quad_from_poly(poly)
+            p_area = self.quad_area(quad)
+            if abs(p_area) < 1:
+                print('invalid poly')
+                continue
+            if p_area > 0:
+                if not tag:
+                    print('poly in wrong direction')
+                    tag = True  # reversed cases should be ignored
+                poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
+                             1), :]
+                quad = quad[(0, 3, 2, 1), :]
+
+            len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
+                                                                       quad[2])
+            len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
+                                                                       quad[2])
+            hv_tag = 1
+
+            if len_w * 2.0 < len_h:
+                hv_tag = 0
+
+            validated_polys.append(poly)
+            validated_tags.append(tag)
+            hv_tags.append(hv_tag)
+        return np.array(validated_polys), np.array(validated_tags), np.array(
+            hv_tags)
+
+    def crop_area(self,
+                  im,
+                  polys,
+                  tags,
+                  hv_tags,
+                  crop_background=False,
+                  max_tries=25):
+        """
+        Make a random crop from the input image.
+        :param im:
+        :param polys:
+        :param tags:
+        :param crop_background:
+        :param max_tries: 50 -> 25
+        :return:
+        """
+        h, w, _ = im.shape
+        pad_h = h // 10
+        pad_w = w // 10
+        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+        for poly in polys:
+            poly = np.round(poly, decimals=0).astype(np.int32)
+            minx = np.min(poly[:, 0])
+            maxx = np.max(poly[:, 0])
+            w_array[minx + pad_w:maxx + pad_w] = 1
+            miny = np.min(poly[:, 1])
+            maxy = np.max(poly[:, 1])
+            h_array[miny + pad_h:maxy + pad_h] = 1
+        # ensure the cropped area does not cut across a text region
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return im, polys, tags, hv_tags
+        for i in range(max_tries):
+            xx = np.random.choice(w_axis, size=2)
+            xmin = np.min(xx) - pad_w
+            xmax = np.max(xx) - pad_w
+            xmin = np.clip(xmin, 0, w - 1)
+            xmax = np.clip(xmax, 0, w - 1)
+            yy = np.random.choice(h_axis, size=2)
+            ymin = np.min(yy) - pad_h
+            ymax = np.max(yy) - pad_h
+            ymin = np.clip(ymin, 0, h - 1)
+            ymax = np.clip(ymax, 0, h - 1)
+            # if xmax - xmin < ARGS.min_crop_side_ratio * w or \
+            #         ymax - ymin < ARGS.min_crop_side_ratio * h:
+            if xmax - xmin < self.min_crop_size or \
+                    ymax - ymin < self.min_crop_size:
+                # area too small
+                continue
+            if polys.shape[0] != 0:
+                poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+                    & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+                selected_polys = np.where(
+                    np.sum(poly_axis_in_area, axis=1) == 4)[0]
+            else:
+                selected_polys = []
+            if len(selected_polys) == 0:
+                # no text in this area
+                if crop_background:
+                    return im[ymin:ymax + 1, xmin:xmax + 1, :], \
+                        polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
+                else:
+                    continue
+            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+            polys = polys[selected_polys]
+            tags = tags[selected_polys]
+            hv_tags = hv_tags[selected_polys]
+            polys[:, :, 0] -= xmin
+            polys[:, :, 1] -= ymin
+            return im, polys, tags, hv_tags
+
+        return im, polys, tags, hv_tags
+
+    def generate_direction_map(self, poly_quads, direction_map):
+        """
+        Fill the direction map with per-quad direction labels.
+        """
+        width_list = []
+        height_list = []
+        for quad in poly_quads:
+            quad_w = (np.linalg.norm(quad[0] - quad[1]) +
+                      np.linalg.norm(quad[2] - quad[3])) / 2.0
+            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+                      np.linalg.norm(quad[2] - quad[1])) / 2.0
+            width_list.append(quad_w)
+            height_list.append(quad_h)
+        norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
+        average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
+
+        for quad in poly_quads:
+            direct_vector_full = (
+                (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+            direct_vector = direct_vector_full / (
+                np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+            direction_label = tuple(
+                map(float, [
+                    direct_vector[0], direct_vector[1], 1.0 / (average_height +
+                                                               1e-6)
+                ]))
+            cv2.fillPoly(direction_map,
+                         quad.round().astype(np.int32)[np.newaxis, :, :],
+                         direction_label)
+        return direction_map
+
+    def calculate_average_height(self, poly_quads):
+        """
+        Calculate the average height of the given quads.
+        """
+        height_list = []
+        for quad in poly_quads:
+            quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+                      np.linalg.norm(quad[2] - quad[1])) / 2.0
+            height_list.append(quad_h)
+        average_height = max(sum(height_list) / len(height_list), 1.0)
+        return average_height
+
+    def generate_tcl_label(self,
+                           hw,
+                           polys,
+                           tags,
+                           ds_ratio,
+                           tcl_ratio=0.3,
+                           shrink_ratio_of_width=0.15):
+        """
+        Generate the text center line (TCL) labels.
+        """
+        h, w = hw
+        h, w = int(h * ds_ratio), int(w * ds_ratio)
+        polys = polys * ds_ratio
+
+        score_map = np.zeros((h, w), dtype=np.float32)
+        tbo_map = np.zeros((h, w, 5), dtype=np.float32)
+        training_mask = np.ones((h, w), dtype=np.float32)
+        direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
+            [1, 1, 3]).astype(np.float32)
+
+        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
+            poly = poly_tag[0]
+            tag = poly_tag[1]
+
+            # generate min_area_quad
+            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+            min_area_quad_h = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+            min_area_quad_w = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+            if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
+                    or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+                continue
+
+            if tag:
+                # ignored region: down-weight it in the training mask
+                cv2.fillPoly(training_mask,
+                             poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+            else:
+                tcl_poly = self.poly2tcl(poly, tcl_ratio)
+                tcl_quads = self.poly2quads(tcl_poly)
+                poly_quads = self.poly2quads(poly)
+                # stcl map
+                stcl_quads, quad_index = self.shrink_poly_along_width(
+                    tcl_quads,
+                    shrink_ratio_of_width=shrink_ratio_of_width,
+                    expand_height_ratio=1.0 / tcl_ratio)
+                # generate tcl map
+                cv2.fillPoly(score_map,
+                             np.round(stcl_quads).astype(np.int32), 1.0)
+
+                # generate tbo map
+                for idx, quad in enumerate(stcl_quads):
+                    quad_mask = np.zeros((h, w), dtype=np.float32)
+                    quad_mask = cv2.fillPoly(
+                        quad_mask,
+                        np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
+                    tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
+                                                quad_mask, tbo_map)
+        return score_map, tbo_map, training_mask
+
+    def generate_tvo_and_tco(self,
+                             hw,
+                             polys,
+                             tags,
+                             tcl_ratio=0.3,
+                             ds_ratio=0.25):
+        """
+        Generate the tvo map and the tco map.
+        """
+        h, w = hw
+        h, w = int(h * ds_ratio), int(w * ds_ratio)
+        polys = polys * ds_ratio
+        poly_mask = np.zeros((h, w), dtype=np.float32)
+
+        tvo_map = np.ones((9, h, w), dtype=np.float32)
+        tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
+        tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
+        poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
+
+        # tco map
+        tco_map = np.ones((3, h, w), dtype=np.float32)
+        tco_map[0] = np.tile(np.arange(0, w), (h, 1))
+        tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
+        poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
+
+        poly_short_edge_map = np.ones((h, w), dtype=np.float32)
+
+        for poly, poly_tag in zip(polys, tags):
+
+            if poly_tag:
+                continue
+
+            # adjust point order for vertical poly
+            poly = self.adjust_point(poly)
+
+            # generate min_area_quad
+            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+            min_area_quad_h = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+            min_area_quad_w = 0.5 * (
+                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+            # generate tcl map and text, 128 * 128
+            tcl_poly = self.poly2tcl(poly, tcl_ratio)
+
+            # generate poly_tv_xy_map
+            for idx in range(4):
+                cv2.fillPoly(
+                    poly_tv_xy_map[2 * idx],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(min(max(min_area_quad[idx, 0], 0), w)))
+                cv2.fillPoly(
+                    poly_tv_xy_map[2 * idx + 1],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(min(max(min_area_quad[idx, 1], 0), h)))
+
+            # generate poly_tc_xy_map
+            for idx in range(2):
+                cv2.fillPoly(
+                    poly_tc_xy_map[idx],
+                    np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                    float(center_point[idx]))
+
+            # generate poly_short_edge_map
+            cv2.fillPoly(
+                poly_short_edge_map,
+                np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
+
+            # generate poly_mask and training_mask
+            cv2.fillPoly(poly_mask,
+                         np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
+                         1)
+
+        tvo_map *= poly_mask
+        tvo_map[:8] -= poly_tv_xy_map
+        tvo_map[-1] /= poly_short_edge_map
+        tvo_map = tvo_map.transpose((1, 2, 0))
+
+        tco_map *= poly_mask
+        tco_map[:2] -= poly_tc_xy_map
+        tco_map[-1] /= poly_short_edge_map
+        tco_map = tco_map.transpose((1, 2, 0))
+
+        return tvo_map, tco_map
+
+    def adjust_point(self, poly):
+        """
+        Adjust point order.
+        """
+        point_num = poly.shape[0]
+        if point_num == 4:
+            len_1 = np.linalg.norm(poly[0] - poly[1])
+            len_2 = np.linalg.norm(poly[1] - poly[2])
+            len_3 = np.linalg.norm(poly[2] - poly[3])
+            len_4 = np.linalg.norm(poly[3] - poly[0])
+
+            if (len_1 + len_3) * 1.5 < (len_2 + len_4):
+                poly = poly[[1, 2, 3, 0], :]
+
+        elif point_num > 4:
+            vector_1 = poly[0] - poly[1]
+            vector_2 = poly[1] - poly[2]
+            cos_theta = np.dot(vector_1, vector_2) / (
+                np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+            theta = np.arccos(np.round(cos_theta, decimals=4))
+
+            if abs(theta) > (70 / 180 * math.pi):
+                index = list(range(1, point_num)) + [0]
+                poly = poly[np.array(index), :]
+        return poly
+
+    def gen_min_area_quad_from_poly(self, poly):
+        """
+        Generate min area quad from poly.
+        """
+        point_num = poly.shape[0]
+        min_area_quad = np.zeros((4, 2), dtype=np.float32)
+        if point_num == 4:
+            min_area_quad = poly
+            center_point = np.sum(poly, axis=0) / 4
+        else:
+            rect = cv2.minAreaRect(poly.astype(
+                np.int32))  # (center (x,y), (width, height), angle of rotation)
+            center_point = rect[0]
+            box = np.array(cv2.boxPoints(rect))
+
+            first_point_idx = 0
+            min_dist = 1e4
+            for i in range(4):
+                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+                       np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+                       np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+                       np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+                if dist < min_dist:
+                    min_dist = dist
+                    first_point_idx = i
+
+            for i in range(4):
+                min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+        return min_area_quad, center_point
+
+    def shrink_quad_along_width(self,
+                                quad,
+                                begin_width_ratio=0.,
+                                end_width_ratio=1.):
+        """
+        Shrink a quad along its width.
+        """
+        ratio_pair = np.array(
+            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+    def shrink_poly_along_width(self,
+                                quads,
+                                shrink_ratio_of_width,
+                                expand_height_ratio=1.0):
+        """
+        Shrink a poly by the given length.
+        """
+        upper_edge_list = []
+
+        def get_cut_info(edge_len_list, cut_len):
+            for idx, edge_len in enumerate(edge_len_list):
+                cut_len -= edge_len
+                if cut_len <= 0.000001:
+                    ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
+                    return idx, ratio
+
+        for quad in quads:
+            upper_edge_len = np.linalg.norm(quad[0] - quad[1])
+            upper_edge_list.append(upper_edge_len)
+
+        # length of the left edge and the right edge.
+        left_length = np.linalg.norm(quads[0][0] - quads[0][
+            3]) * expand_height_ratio
+        right_length = np.linalg.norm(quads[-1][1] - quads[-1][
+            2]) * expand_height_ratio
+
+        shrink_length = min(left_length, right_length,
+                            sum(upper_edge_list)) * shrink_ratio_of_width
+        # shrinking length
+        upper_len_left = shrink_length
+        upper_len_right = sum(upper_edge_list) - shrink_length
+
+        left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
+        left_quad = self.shrink_quad_along_width(
+            quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+        right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
+        right_quad = self.shrink_quad_along_width(
+            quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+
+        out_quad_list = []
+        if left_idx == right_idx:
+            out_quad_list.append(
+                [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+        else:
+            out_quad_list.append(left_quad)
+            for idx in range(left_idx + 1, right_idx):
+                out_quad_list.append(quads[idx])
+            out_quad_list.append(right_quad)
+
+        return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
+
+    def vector_angle(self, A, B):
+        """
+        Calculate the angle between vector AB and the positive x-axis.
+        """
+        AB = np.array([B[1] - A[1], B[0] - A[0]])
+        return np.arctan2(*AB)
+
+    def theta_line_cross_point(self, theta, point):
+        """
+        Calculate the line through the given point at the given angle,
+        in ax + by + c = 0 form.
+        """
+        x, y = point
+        cos = np.cos(theta)
+        sin = np.sin(theta)
+        return [sin, -cos, cos * y - sin * x]
+
+    def line_cross_two_point(self, A, B):
+        """
+        Calculate the line through points A and B in ax + by + c = 0 form.
+        """
+        angle = self.vector_angle(A, B)
+        return self.theta_line_cross_point(angle, A)
+
+    def average_angle(self, poly):
+        """
+        Calculate the average angle of the left and right edges of the poly.
+        """
+        p0, p1, p2, p3 = poly
+        angle30 = self.vector_angle(p3, p0)
+        angle21 = self.vector_angle(p2, p1)
+        return (angle30 + angle21) / 2
+
+    def line_cross_point(self, line1, line2):
+        """
+        Compute the intersection of line1 and line2, each given in
+        ax + by + c = 0 form.
+        """
+        a1, b1, c1 = line1
+        a2, b2, c2 = line2
+        d = a1 * b2 - a2 * b1
+
+        if d == 0:
+            print('Cross point does not exist')
+            return np.array([0, 0], dtype=np.float32)
+        else:
+            x = (b1 * c2 - b2 * c1) / d
+            y = (a2 * c1 - a1 * c2) / d
+
+        return np.array([x, y], dtype=np.float32)
+
+    def quad2tcl(self, poly, ratio):
+        """
+        Generate the center line from clockwise quad points. (4, 2)
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
+        p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
+        return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
+
+    def poly2tcl(self, poly, ratio):
+        """
+        Generate the center line from clockwise poly points.
+        """
+        ratio_pair = np.array(
+            [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+        tcl_poly = np.zeros_like(poly)
+        point_num = poly.shape[0]
+
+        for idx in range(point_num // 2):
+            point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
+                                      ) * ratio_pair
+            tcl_poly[idx] = point_pair[0]
+            tcl_poly[point_num - 1 - idx] = point_pair[1]
+        return tcl_poly
+
+    def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
+        """
+        Generate the tbo_map for the given quad.
+        """
+        # upper and lower line function: ax + by + c = 0;
+        up_line = self.line_cross_two_point(quad[0], quad[1])
+        lower_line = self.line_cross_two_point(quad[3], quad[2])
+
+        quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
+                        np.linalg.norm(quad[1] - quad[2]))
+        quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
+                        np.linalg.norm(quad[2] - quad[3]))
+
+        # average angle of the left and right lines.
+        angle = self.average_angle(quad)
+
+        xy_in_poly = np.argwhere(tcl_mask == 1)
+        for y, x in xy_in_poly:
+            point = (x, y)
+            line = self.theta_line_cross_point(angle, point)
+            cross_point_upper = self.line_cross_point(up_line, line)
+            cross_point_lower = self.line_cross_point(lower_line, line)
+            ##FIX, offset reverse
+            upper_offset_x, upper_offset_y = cross_point_upper - point
+            lower_offset_x, lower_offset_y = cross_point_lower - point
+            tbo_map[y, x, 0] = upper_offset_y
+            tbo_map[y, x, 1] = upper_offset_x
+            tbo_map[y, x, 2] = lower_offset_y
+            tbo_map[y, x, 3] = lower_offset_x
+            tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
+        return tbo_map
+
+    def poly2quads(self, poly):
+        """
+        Split a poly into quads.
+        """
+        quad_list = []
+        point_num = poly.shape[0]
+
+        # point pair
+        point_pair_list = []
+        for idx in range(point_num // 2):
+            point_pair = [poly[idx], poly[point_num - 1 - idx]]
+            point_pair_list.append(point_pair)
+
+        quad_num = point_num // 2 - 1
+        for idx in range(quad_num):
+            # reshape and adjust to clockwise
+            quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
+                              ).reshape(4, 2)[[0, 2, 3, 1]])
+
+        return np.array(quad_list)
+
+    def __call__(self, data):
+        im = data['image']
+        text_polys = data['polys']
+        text_tags = data['ignore_tags']
+        if im is None:
+            return None
+        if text_polys.shape[0] == 0:
+            return None
+
+        h, w, _ = im.shape
+        text_polys, text_tags, hv_tags = self.check_and_validate_polys(
+            text_polys, text_tags, (h, w))
+
+        if text_polys.shape[0] == 0:
+            return None
+
+        # set the aspect ratio while keeping the area fixed
+        asp_scales = np.arange(1.0, 1.55, 0.1)
+        asp_scale = np.random.choice(asp_scales)
+
+        if np.random.rand() < 0.5:
+            asp_scale = 1.0 / asp_scale
+        asp_scale = math.sqrt(asp_scale)
+
+        asp_wx = asp_scale
+        asp_hy = 1.0 / asp_scale
+        im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
+        text_polys[:, :, 0] *= asp_wx
+        text_polys[:, :, 1] *= asp_hy
+
+        h, w, _ = im.shape
+        if max(h, w) > 2048:
+            rd_scale = 2048.0 / max(h, w)
+            im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+            text_polys *= rd_scale
+        h, w, _ = im.shape
+        if min(h, w) < 16:
+            return None
+
+        # no background
+        im, text_polys, text_tags, hv_tags = self.crop_area(im, \
+            text_polys, text_tags, hv_tags, crop_background=False)
+
+        if text_polys.shape[0] == 0:
+            return None
+        # skip the sample if every region is an ignore case
+        if np.sum((text_tags * 1.0)) >= text_tags.size:
+            return None
+        new_h, new_w, _ = im.shape
+        if (new_h is None) or (new_w is None):
+            return None
+        # resize image
+        std_ratio = float(self.input_size) / max(new_w, new_h)
+        rand_scales = np.array(
+            [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+        rz_scale = std_ratio * np.random.choice(rand_scales)
+        im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
+        text_polys[:, :, 0] *= rz_scale
+        text_polys[:, :, 1] *= rz_scale
+
+        # add gaussian blur
+        if np.random.rand() < 0.1 * 0.5:
+            ks = np.random.permutation(5)[0] + 1
+            ks = int(ks / 2) * 2 + 1
+            im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
+        # brighten
+        if np.random.rand() < 0.1 * 0.5:
+            im = im * (1.0 + np.random.rand() * 0.5)
+            im = np.clip(im, 0.0, 255.0)
+        # darken
+        if np.random.rand() < 0.1 * 0.5:
+            im = im * (1.0 - np.random.rand() * 0.5)
+            im = np.clip(im, 0.0, 255.0)
+
+        # pad the image to [input_size, input_size]
+        new_h, new_w, _ = im.shape
+        if min(new_w, new_h) < self.input_size * 0.5:
+            return None
+
+        im_padded = np.ones(
+            (self.input_size, self.input_size, 3), dtype=np.float32)
+        im_padded[:, :, 2] = 0.485 * 255
+        im_padded[:, :, 1] = 0.456 * 255
+        im_padded[:, :, 0] = 0.406 * 255
+
+        # randomize the start position
+        del_h = self.input_size - new_h
+        del_w = self.input_size - new_w
+        sh, sw = 0, 0
+        if del_h > 1:
+            sh = int(np.random.rand() * del_h)
+        if del_w > 1:
+            sw = int(np.random.rand() * del_w)
+
+        # padding
+        im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+        text_polys[:, :, 0] += sw
+        text_polys[:, :, 1] += sh
+
+        score_map, border_map, training_mask = self.generate_tcl_label(
+            (self.input_size, self.input_size), text_polys, text_tags, 0.25)
+
+        # SAST head
+        tvo_map, tco_map = self.generate_tvo_and_tco(
+            (self.input_size, self.input_size),
+            text_polys,
+            text_tags,
+            tcl_ratio=0.3,
+            ds_ratio=0.25)
+
+        im_padded[:, :, 2] -= 0.485 * 255
+        im_padded[:, :, 1] -= 0.456 * 255
+        im_padded[:, :, 0] -= 0.406 * 255
+        im_padded[:, :, 2] /= (255.0 * 0.229)
+        im_padded[:, :, 1] /= (255.0 * 0.224)
+        im_padded[:, :, 0] /= (255.0 * 0.225)
+        im_padded = im_padded.transpose((2, 0, 1))
+
+        data['image'] = im_padded[::-1, :, :]
+        data['score_map'] = score_map[np.newaxis, :, :]
+        data['border_map'] = border_map.transpose((2, 0, 1))
+        data['training_mask'] = training_mask[np.newaxis, :, :]
+        data['tvo_map'] = tvo_map.transpose((2, 0, 1))
+        data['tco_map'] = tco_map.transpose((2, 0, 1))
+        return data
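
Note: SASTProcessTrain consumes a sample dict holding 'image', 'polys' (N x 16 x 2 text polygons) and 'ignore_tags', and either returns the normalized image plus the score/border/tvo/tco training maps or returns None when random cropping and filtering reject the sample. A hedged sketch under assumed inputs (the synthetic page and the hand-built polygon below are illustrative only, not part of this commit):

import numpy as np
from ppocr.data.imaug.sast_process import SASTProcessTrain

op = SASTProcessTrain(image_shape=[512, 512])

# hand-built 16-point text polygon: top edge left-to-right, bottom edge back
xs = np.linspace(100, 400, 8)
top = np.stack([xs, np.full(8, 200.0)], axis=1)
bottom = np.stack([xs[::-1], np.full(8, 260.0)], axis=1)
poly = np.concatenate([top, bottom], axis=0)[np.newaxis].astype(np.float32)

data = {
    'image': np.full((640, 640, 3), 255, dtype=np.uint8),
    'polys': poly,
    'ignore_tags': np.array([False]),
}
out = op(data)
if out is not None:  # None means the sample was rejected; try another sample
    print(out['image'].shape, out['score_map'].shape, out['tvo_map'].shape)
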
ppocr/data/imaug/ssl_img_aug.py ADDED
@@ -0,0 +1,60 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import random
+from PIL import Image
+
+from .rec_img_aug import resize_norm_img
+
+
+class SSLRotateResize(object):
+    def __init__(self,
+                 image_shape,
+                 padding=False,
+                 select_all=True,
+                 mode="train",
+                 **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+        self.select_all = select_all
+        self.mode = mode
+
+    def __call__(self, data):
+        img = data["image"]
+
+        data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+        data["image_r180"] = cv2.rotate(data["image_r90"],
+                                        cv2.ROTATE_90_CLOCKWISE)
+        data["image_r270"] = cv2.rotate(data["image_r180"],
+                                        cv2.ROTATE_90_CLOCKWISE)
+
+        images = []
+        for key in ["image", "image_r90", "image_r180", "image_r270"]:
+            images.append(
+                resize_norm_img(
+                    data.pop(key),
+                    image_shape=self.image_shape,
+                    padding=self.padding)[0])
+        data["image"] = np.stack(images, axis=0)
+        data["label"] = np.array(list(range(4)))
+        if not self.select_all:
+            data["image"] = data["image"][0::2]  # just choose 0 and 180
+            data["label"] = data["label"][0:2]  # label needs to be continuous
+        if self.mode == "test":
+            data["image"] = data["image"][0]
+            data["label"] = data["label"][0]
+        return data
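
Note: SSLRotateResize builds the rotation-prediction pretext task used for self-supervised pretraining: each input line image becomes a stack of its 0/90/180/270-degree rotations with labels 0-3. A small sketch, assuming a random stand-in image and a 3x48x320 target shape (both illustrative, not from the commit):

import numpy as np
from ppocr.data.imaug.ssl_img_aug import SSLRotateResize

op = SSLRotateResize(image_shape=[3, 48, 320], select_all=True, mode="train")
data = {"image": (np.random.rand(48, 160, 3) * 255).astype(np.uint8)}
out = op(data)
print(out["image"].shape)  # (4, 3, 48, 320): 0/90/180/270-degree views
print(out["label"])        # [0 1 2 3], one rotation class per view
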
ppocr/data/imaug/table_ops.py ADDED
@@ -0,0 +1,229 @@
1
+ """
2
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+ from __future__ import unicode_literals
21
+
22
+ import sys
23
+ import six
24
+ import cv2
25
+ import numpy as np
26
+
27
+
28
+ class GenTableMask(object):
29
+ """ gen table mask """
30
+
31
+ def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
32
+ self.shrink_h_max = 5
33
+ self.shrink_w_max = 5
34
+ self.mask_type = mask_type
35
+
+    def projection(self, erosion, h, w, split_threshold=0):
+        # horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # use the projection array to find the split points
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether we are currently inside a text region
+        box_list = []
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[
+                    i] > split_threshold:  # entered a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[
+                    i] <= split_threshold and in_text == True:  # entered a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        return box_list, projection_map
+
+    def projection_cx(self, box_img):
+        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
+        h, w = box_gray_img.shape
+        # binarize the grayscale image
+        ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
+                                     cv2.THRESH_BINARY_INV)
+        # vertical erosion
+        if h < w:
+            kernel = np.ones((2, 1), np.uint8)
+            erode = cv2.erode(thresh1, kernel, iterations=1)
+        else:
+            erode = thresh1
+        # horizontal dilation
+        kernel = np.ones((1, 5), np.uint8)
+        erosion = cv2.dilate(erode, kernel, iterations=1)
+        # horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # use the projection array to find the split points
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether we are currently inside a text region
+        box_list = []
+        split_threshold = 0
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[
+                    i] > split_threshold:  # entered a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[
+                    i] <= split_threshold and in_text == True:  # entered a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        split_bbox_list = []
+        if len(box_list) > 1:
+            for i, (h_start, h_end) in enumerate(box_list):
+                if i == 0:
+                    h_start = 0
+                if i == len(box_list) - 1:
+                    h_end = h
+                word_img = erosion[h_start:h_end + 1, :]
+                word_h, word_w = word_img.shape
+                w_split_list, w_projection_map = self.projection(word_img.T,
+                                                                 word_w, word_h)
+                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
+                if h_start > 0:
+                    h_start -= 1
+                h_end += 1
+                word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
+                split_bbox_list.append([w_start, h_start, w_end, h_end])
+        else:
+            split_bbox_list.append([0, 0, w, h])
+        return split_bbox_list
+
+    def shrink_bbox(self, bbox):
+        left, top, right, bottom = bbox
+        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
+        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
+        left_new = left + sh_w
+        right_new = right - sh_w
+        top_new = top + sh_h
+        bottom_new = bottom - sh_h
+        if left_new >= right_new:
+            left_new = left
+            right_new = right
+        if top_new >= bottom_new:
+            top_new = top
+            bottom_new = bottom
+        return [left_new, top_new, right_new, bottom_new]
+
+    def __call__(self, data):
+        img = data['image']
+        cells = data['cells']
+        height, width = img.shape[0:2]
+        if self.mask_type == 1:
+            mask_img = np.zeros((height, width), dtype=np.float32)
+        else:
+            mask_img = np.zeros((height, width, 3), dtype=np.float32)
+        cell_num = len(cells)
+        for cno in range(cell_num):
+            if "bbox" in cells[cno]:
+                bbox = cells[cno]['bbox']
+                left, top, right, bottom = bbox
+                box_img = img[top:bottom, left:right, :].copy()
+                split_bbox_list = self.projection_cx(box_img)
+                for sno in range(len(split_bbox_list)):
+                    split_bbox_list[sno][0] += left
+                    split_bbox_list[sno][1] += top
+                    split_bbox_list[sno][2] += left
+                    split_bbox_list[sno][3] += top
+
+                for sno in range(len(split_bbox_list)):
+                    left, top, right, bottom = split_bbox_list[sno]
+                    left, top, right, bottom = self.shrink_bbox(
+                        [left, top, right, bottom])
+                    if self.mask_type == 1:
+                        mask_img[top:bottom, left:right] = 1.0
+                        data['mask_img'] = mask_img
+                    else:
+                        mask_img[top:bottom, left:right, :] = (255, 255, 255)
+                        data['image'] = mask_img
+        return data
+
+
+class ResizeTableImage(object):
+    def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
+                 **kwargs):
+        super(ResizeTableImage, self).__init__()
+        self.max_len = max_len
+        self.resize_bboxes = resize_bboxes
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data['image']
+        height, width = img.shape[0:2]
+        ratio = self.max_len / (max(height, width) * 1.0)
+        resize_h = int(height * ratio)
+        resize_w = int(width * ratio)
+        resize_img = cv2.resize(img, (resize_w, resize_h))
+        if self.resize_bboxes and not self.infer_mode:
+            data['bboxes'] = data['bboxes'] * ratio
+        data['image'] = resize_img
+        data['src_img'] = img
+        data['shape'] = np.array([height, width, ratio, ratio])
+        data['max_len'] = self.max_len
+        return data
+
+
+class PaddingTableImage(object):
+    def __init__(self, size, **kwargs):
+        super(PaddingTableImage, self).__init__()
+        self.size = size
+
+    def __call__(self, data):
+        img = data['image']
+        pad_h, pad_w = self.size
+        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+        height, width = img.shape[0:2]
+        padding_img[0:height, 0:width, :] = img.copy()
+        data['image'] = padding_img
+        shape = data['shape'].tolist()
+        shape.extend([pad_h, pad_w])
+        data['shape'] = np.array(shape)
+        return data
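
A minimal sketch of how the two resize/pad transforms above are chained on one sample dict; the 488x488 target size and the dummy image are illustrative, and the import path assumes this file lives at ppocr/data/imaug/table_ops.py with the repository root on sys.path:

import numpy as np
from ppocr.data.imaug.table_ops import ResizeTableImage, PaddingTableImage

data = {"image": np.zeros((300, 500, 3), dtype=np.uint8)}
data = ResizeTableImage(max_len=488)(data)       # longest side -> 488, aspect kept
data = PaddingTableImage(size=[488, 488])(data)  # zero-pad to a square canvas
print(data["image"].shape)  # (488, 488, 3)
print(data["shape"])        # [300. 500. 0.976 0.976 488. 488.]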
ppocr/data/imaug/text_image_aug/__init__.py ADDED
@@ -0,0 +1,17 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .augment import tia_perspective, tia_distort, tia_stretch
+
+__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (260 Bytes).
ppocr/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (297 Bytes).
ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-37.pyc ADDED
Binary file (2.15 kB).
ppocr/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc ADDED
Binary file (2.19 kB).
ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-37.pyc ADDED
Binary file (3.89 kB).
ppocr/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc ADDED
Binary file (3.96 kB).
ppocr/data/imaug/text_image_aug/augment.py ADDED
@@ -0,0 +1,120 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
+"""
+
+import numpy as np
+from .warp_mls import WarpMLS
+
+
+def tia_distort(src, segment=4):
+    img_h, img_w = src.shape[:2]
+
+    cut = img_w // segment
+    thresh = cut // 3
+
+    src_pts = list()
+    dst_pts = list()
+
+    src_pts.append([0, 0])
+    src_pts.append([img_w, 0])
+    src_pts.append([img_w, img_h])
+    src_pts.append([0, img_h])
+
+    dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
+    dst_pts.append(
+        [img_w - np.random.randint(thresh), np.random.randint(thresh)])
+    dst_pts.append(
+        [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
+    dst_pts.append(
+        [np.random.randint(thresh), img_h - np.random.randint(thresh)])
+
+    half_thresh = thresh * 0.5
+
+    for cut_idx in np.arange(1, segment, 1):
+        src_pts.append([cut * cut_idx, 0])
+        src_pts.append([cut * cut_idx, img_h])
+        dst_pts.append([
+            cut * cut_idx + np.random.randint(thresh) - half_thresh,
+            np.random.randint(thresh) - half_thresh
+        ])
+        dst_pts.append([
+            cut * cut_idx + np.random.randint(thresh) - half_thresh,
+            img_h + np.random.randint(thresh) - half_thresh
+        ])
+
+    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
+    dst = trans.generate()
+
+    return dst
+
+
+def tia_stretch(src, segment=4):
+    img_h, img_w = src.shape[:2]
+
+    cut = img_w // segment
+    thresh = cut * 4 // 5
+
+    src_pts = list()
+    dst_pts = list()
+
+    src_pts.append([0, 0])
+    src_pts.append([img_w, 0])
+    src_pts.append([img_w, img_h])
+    src_pts.append([0, img_h])
+
+    dst_pts.append([0, 0])
+    dst_pts.append([img_w, 0])
+    dst_pts.append([img_w, img_h])
+    dst_pts.append([0, img_h])
+
+    half_thresh = thresh * 0.5
+
+    for cut_idx in np.arange(1, segment, 1):
+        move = np.random.randint(thresh) - half_thresh
+        src_pts.append([cut * cut_idx, 0])
+        src_pts.append([cut * cut_idx, img_h])
+        dst_pts.append([cut * cut_idx + move, 0])
+        dst_pts.append([cut * cut_idx + move, img_h])
+
+    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
+    dst = trans.generate()
+
+    return dst
+
+
+def tia_perspective(src):
+    img_h, img_w = src.shape[:2]
+
+    thresh = img_h // 2
+
+    src_pts = list()
+    dst_pts = list()
+
+    src_pts.append([0, 0])
+    src_pts.append([img_w, 0])
+    src_pts.append([img_w, img_h])
+    src_pts.append([0, img_h])
+
+    dst_pts.append([0, np.random.randint(thresh)])
+    dst_pts.append([img_w, np.random.randint(thresh)])
+    dst_pts.append([img_w, img_h - np.random.randint(thresh)])
+    dst_pts.append([0, img_h - np.random.randint(thresh)])
+
+    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
+    dst = trans.generate()
+
+    return dst
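
In the recognition pipeline these three warps are applied stochastically. A small sketch of that usage, assuming a BGR text-line crop loaded with OpenCV ("word.png" is a placeholder that is assumed to load successfully) and illustrative probabilities and segment counts:

import cv2
import numpy as np
from ppocr.data.imaug.text_image_aug import tia_distort, tia_stretch, tia_perspective

img = cv2.imread("word.png")  # placeholder text-line crop
h, w = img.shape[:2]
if h >= 20 and w >= 20:  # the warps need some room, so tiny crops are skipped
    if np.random.rand() < 0.4:
        img = tia_distort(img, np.random.randint(3, 6))
    if np.random.rand() < 0.4:
        img = tia_stretch(img, np.random.randint(3, 6))
    if np.random.rand() < 0.4:
        img = tia_perspective(img)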
ppocr/data/imaug/text_image_aug/warp_mls.py ADDED
@@ -0,0 +1,168 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
+"""
+
+import numpy as np
+
+
+class WarpMLS:
+    def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
+        self.src = src
+        self.src_pts = src_pts
+        self.dst_pts = dst_pts
+        self.pt_count = len(self.dst_pts)
+        self.dst_w = dst_w
+        self.dst_h = dst_h
+        self.trans_ratio = trans_ratio
+        self.grid_size = 100
+        self.rdx = np.zeros((self.dst_h, self.dst_w))
+        self.rdy = np.zeros((self.dst_h, self.dst_w))
+
+    @staticmethod
+    def __bilinear_interp(x, y, v11, v12, v21, v22):
+        return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
+                                                      (1 - y) + v22 * y) * x
+
+    def generate(self):
+        self.calc_delta()
+        return self.gen_img()
+
+    def calc_delta(self):
+        w = np.zeros(self.pt_count, dtype=np.float32)
+
+        if self.pt_count < 2:
+            return
+
+        i = 0
+        while 1:
+            if self.dst_w <= i < self.dst_w + self.grid_size - 1:
+                i = self.dst_w - 1
+            elif i >= self.dst_w:
+                break
+
+            j = 0
+            while 1:
+                if self.dst_h <= j < self.dst_h + self.grid_size - 1:
+                    j = self.dst_h - 1
+                elif j >= self.dst_h:
+                    break
+
+                sw = 0
+                swp = np.zeros(2, dtype=np.float32)
+                swq = np.zeros(2, dtype=np.float32)
+                new_pt = np.zeros(2, dtype=np.float32)
+                cur_pt = np.array([i, j], dtype=np.float32)
+
+                k = 0
+                for k in range(self.pt_count):
+                    if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
+                        break
+
+                    w[k] = 1. / (
+                        (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
+                        (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
+
+                    sw += w[k]
+                    swp = swp + w[k] * np.array(self.dst_pts[k])
+                    swq = swq + w[k] * np.array(self.src_pts[k])
+
+                if k == self.pt_count - 1:
+                    pstar = 1 / sw * swp
+                    qstar = 1 / sw * swq
+
+                    miu_s = 0
+                    for k in range(self.pt_count):
+                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
+                            continue
+                        pt_i = self.dst_pts[k] - pstar
+                        miu_s += w[k] * np.sum(pt_i * pt_i)
+
+                    cur_pt -= pstar
+                    cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
+
+                    for k in range(self.pt_count):
+                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
+                            continue
+
+                        pt_i = self.dst_pts[k] - pstar
+                        pt_j = np.array([-pt_i[1], pt_i[0]])
+
+                        tmp_pt = np.zeros(2, dtype=np.float32)
+                        tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
+                                    np.sum(pt_j * cur_pt) * self.src_pts[k][1]
+                        tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
+                                    np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
+                        tmp_pt *= (w[k] / miu_s)
+                        new_pt += tmp_pt
+
+                    new_pt += qstar
+                else:
+                    new_pt = self.src_pts[k]
+
+                self.rdx[j, i] = new_pt[0] - i
+                self.rdy[j, i] = new_pt[1] - j
+
+                j += self.grid_size
+            i += self.grid_size
+
+    def gen_img(self):
+        src_h, src_w = self.src.shape[:2]
+        dst = np.zeros_like(self.src, dtype=np.float32)
+
+        for i in np.arange(0, self.dst_h, self.grid_size):
+            for j in np.arange(0, self.dst_w, self.grid_size):
+                ni = i + self.grid_size
+                nj = j + self.grid_size
+                w = h = self.grid_size
+                if ni >= self.dst_h:
+                    ni = self.dst_h - 1
+                    h = ni - i + 1
+                if nj >= self.dst_w:
+                    nj = self.dst_w - 1
+                    w = nj - j + 1
+
+                di = np.reshape(np.arange(h), (-1, 1))
+                dj = np.reshape(np.arange(w), (1, -1))
+                delta_x = self.__bilinear_interp(
+                    di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
+                    self.rdx[ni, j], self.rdx[ni, nj])
+                delta_y = self.__bilinear_interp(
+                    di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
+                    self.rdy[ni, j], self.rdy[ni, nj])
+                nx = j + dj + delta_x * self.trans_ratio
+                ny = i + di + delta_y * self.trans_ratio
+                nx = np.clip(nx, 0, src_w - 1)
+                ny = np.clip(ny, 0, src_h - 1)
+                nxi = np.array(np.floor(nx), dtype=np.int32)
+                nyi = np.array(np.floor(ny), dtype=np.int32)
+                nxi1 = np.array(np.ceil(nx), dtype=np.int32)
+                nyi1 = np.array(np.ceil(ny), dtype=np.int32)
+
+                if len(self.src.shape) == 3:
+                    x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
+                    y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
+                else:
+                    x = ny - nyi
+                    y = nx - nxi
+                dst[i:i + h, j:j + w] = self.__bilinear_interp(
+                    x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
+                    self.src[nyi1, nxi], self.src[nyi1, nxi1])
+
+        dst = np.clip(dst, 0, 255)
+        dst = np.array(dst, dtype=np.uint8)
+
+        return dst
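
WarpMLS itself only needs matching source/destination control points; the tia_* helpers above differ solely in how they choose them. A direct sketch with hand-picked points (all values arbitrary, import path assumes the repo layout):

import numpy as np
from ppocr.data.imaug.text_image_aug.warp_mls import WarpMLS

src = (np.random.rand(32, 100, 3) * 255).astype(np.uint8)  # dummy 100x32 crop
src_pts = [[0, 0], [100, 0], [100, 32], [0, 32]]  # the four image corners
dst_pts = [[3, 2], [97, 1], [99, 30], [2, 31]]    # slightly perturbed corners
dst = WarpMLS(src, src_pts, dst_pts, dst_w=100, dst_h=32).generate()
print(dst.shape, dst.dtype)  # (32, 100, 3) uint8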
ppocr/data/imaug/vqa/__init__.py ADDED
@@ -0,0 +1,20 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations
+
+__all__ = [
+    'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
+    'TensorizeEntitiesRelations'
+]
ppocr/data/imaug/vqa/augment.py ADDED
@@ -0,0 +1,33 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+import random
+from copy import deepcopy
+
+
+def order_by_tbyx(ocr_info):
+    res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+    for i in range(len(res) - 1):
+        for j in range(i, -1, -1):
+            if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+                    (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+                tmp = deepcopy(res[j])
+                res[j] = deepcopy(res[j + 1])
+                res[j + 1] = deepcopy(tmp)
+            else:
+                break
+    return res
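
order_by_tbyx sorts OCR boxes top-to-bottom, then left-to-right among boxes whose top edges differ by less than 20 px (a fixed same-line tolerance). A quick sketch with invented boxes:

boxes = [
    {"bbox": [10, 0, 200, 20], "text": "Title"},
    {"bbox": [120, 27, 200, 45], "text": "world"},
    {"bbox": [10, 30, 100, 48], "text": "hello"},
    {"bbox": [10, 80, 80, 100], "text": "below"},
]
print([b["text"] for b in order_by_tbyx(boxes)])
# -> ['Title', 'hello', 'world', 'below']: the second line is re-ordered by x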
ppocr/data/imaug/vqa/token/__init__.py ADDED
@@ -0,0 +1,18 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
+from .vqa_token_pad import VQATokenPad
+from .vqa_token_relation import VQAReTokenRelation
+from .vqa_re_convert import TensorizeEntitiesRelations
ppocr/data/imaug/vqa/token/vqa_re_convert.py ADDED
@@ -0,0 +1,51 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class TensorizeEntitiesRelations(object):
+    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+        self.max_seq_len = max_seq_len
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        entities = data['entities']
+        relations = data['relations']
+
+        entities_new = np.full(
+            shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
+        entities_new[0, 0] = len(entities['start'])
+        entities_new[0, 1] = len(entities['end'])
+        entities_new[0, 2] = len(entities['label'])
+        entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
+            'start'])
+        entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
+        entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
+            'label'])
+
+        relations_new = np.full(
+            shape=[self.max_seq_len * self.max_seq_len + 1, 2],
+            fill_value=-1,
+            dtype='int64')
+        relations_new[0, 0] = len(relations['head'])
+        relations_new[0, 1] = len(relations['tail'])
+        relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
+            'head'])
+        relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
+            'tail'])
+
+        data['entities'] = entities_new
+        data['relations'] = relations_new
+        return data
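
The op above packs variable-length entity/relation lists into fixed-shape int64 arrays: row 0 holds the counts, the payload starts at row 1, and -1 marks unused slots. A small sketch (max_seq_len=4 keeps the arrays readable):

data = {
    "entities": {"start": [0, 5], "end": [4, 9], "label": [1, 2]},
    "relations": {"head": [0], "tail": [1]},
}
out = TensorizeEntitiesRelations(max_seq_len=4)(data)
print(out["entities"][:3])
# [[2 2 2]    <- counts of start / end / label
#  [0 4 1]    <- entity 0: start, end, label
#  [5 9 2]]   <- entity 1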
ppocr/data/imaug/vqa/token/vqa_token_chunk.py ADDED
@@ -0,0 +1,122 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+
+class VQASerTokenChunk(object):
+    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+        self.max_seq_len = max_seq_len
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        encoded_inputs_all = []
+        seq_len = len(data['input_ids'])
+        for index in range(0, seq_len, self.max_seq_len):
+            chunk_beg = index
+            chunk_end = min(index + self.max_seq_len, seq_len)
+            encoded_inputs_example = {}
+            for key in data:
+                if key in [
+                        'label', 'input_ids', 'labels', 'token_type_ids',
+                        'bbox', 'attention_mask'
+                ]:
+                    if self.infer_mode and key == 'labels':
+                        encoded_inputs_example[key] = data[key]
+                    else:
+                        encoded_inputs_example[key] = data[key][chunk_beg:
+                                                                chunk_end]
+                else:
+                    encoded_inputs_example[key] = data[key]
+
+            encoded_inputs_all.append(encoded_inputs_example)
+        if len(encoded_inputs_all) == 0:
+            return None
+        return encoded_inputs_all[0]
+
+
+class VQAReTokenChunk(object):
+    def __init__(self,
+                 max_seq_len=512,
+                 entities_labels=None,
+                 infer_mode=False,
+                 **kwargs):
+        self.max_seq_len = max_seq_len
+        self.entities_labels = {
+            'HEADER': 0,
+            'QUESTION': 1,
+            'ANSWER': 2
+        } if entities_labels is None else entities_labels
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        # prepare data
+        entities = data.pop('entities')
+        relations = data.pop('relations')
+        encoded_inputs_all = []
+        for index in range(0, len(data["input_ids"]), self.max_seq_len):
+            item = {}
+            for key in data:
+                if key in [
+                        'label', 'input_ids', 'labels', 'token_type_ids',
+                        'bbox', 'attention_mask'
+                ]:
+                    if self.infer_mode and key == 'labels':
+                        item[key] = data[key]
+                    else:
+                        item[key] = data[key][index:index + self.max_seq_len]
+                else:
+                    item[key] = data[key]
+            # select entities in the current chunk
+            entities_in_this_span = []
+            global_to_local_map = {}
+            for entity_id, entity in enumerate(entities):
+                if (index <= entity["start"] < index + self.max_seq_len and
+                        index <= entity["end"] < index + self.max_seq_len):
+                    entity["start"] = entity["start"] - index
+                    entity["end"] = entity["end"] - index
+                    global_to_local_map[entity_id] = len(entities_in_this_span)
+                    entities_in_this_span.append(entity)
+
+            # select relations in the current chunk
+            relations_in_this_span = []
+            for relation in relations:
+                if (index <= relation["start_index"] < index + self.max_seq_len
+                        and index <= relation["end_index"] <
+                        index + self.max_seq_len):
+                    relations_in_this_span.append({
+                        "head": global_to_local_map[relation["head"]],
+                        "tail": global_to_local_map[relation["tail"]],
+                        "start_index": relation["start_index"] - index,
+                        "end_index": relation["end_index"] - index,
+                    })
+            item.update({
+                "entities": self.reformat(entities_in_this_span),
+                "relations": self.reformat(relations_in_this_span),
+            })
+            if len(item['entities']) > 0:
+                item['entities']['label'] = [
+                    self.entities_labels[x] for x in item['entities']['label']
+                ]
+            encoded_inputs_all.append(item)
+        if len(encoded_inputs_all) == 0:
+            return None
+        return encoded_inputs_all[0]
+
+    def reformat(self, data):
+        new_data = defaultdict(list)
+        for item in data:
+            for k, v in item.items():
+                new_data[k].append(v)
+        return new_data
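
VQASerTokenChunk windows the token-level fields into max_seq_len slices and, in this implementation, returns only the first window; non-sequence keys pass through untouched. A sketch with a toy sample and max_seq_len=4:

sample = {
    "input_ids": list(range(10)),
    "labels": [0] * 10,
    "token_type_ids": [0] * 10,
    "bbox": [[0, 0, 1, 1]] * 10,
    "attention_mask": [1] * 10,
    "image": "kept-as-is",  # not in the sequence-key list, so copied verbatim
}
first_chunk = VQASerTokenChunk(max_seq_len=4)(sample)
print(first_chunk["input_ids"])  # [0, 1, 2, 3]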
ppocr/data/imaug/vqa/token/vqa_token_pad.py ADDED
@@ -0,0 +1,104 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import numpy as np
+
+
+class VQATokenPad(object):
+    def __init__(self,
+                 max_seq_len=512,
+                 pad_to_max_seq_len=True,
+                 return_attention_mask=True,
+                 return_token_type_ids=True,
+                 truncation_strategy="longest_first",
+                 return_overflowing_tokens=False,
+                 return_special_tokens_mask=False,
+                 infer_mode=False,
+                 **kwargs):
+        self.max_seq_len = max_seq_len
+        self.pad_to_max_seq_len = pad_to_max_seq_len
+        self.return_attention_mask = return_attention_mask
+        self.return_token_type_ids = return_token_type_ids
+        self.truncation_strategy = truncation_strategy
+        self.return_overflowing_tokens = return_overflowing_tokens
+        self.return_special_tokens_mask = return_special_tokens_mask
+        self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        needs_to_be_padded = self.pad_to_max_seq_len and len(data[
+            "input_ids"]) < self.max_seq_len
+
+        if needs_to_be_padded:
+            if 'tokenizer_params' in data:
+                tokenizer_params = data.pop('tokenizer_params')
+            else:
+                tokenizer_params = dict(
+                    padding_side='right', pad_token_type_id=0, pad_token_id=1)
+
+            difference = self.max_seq_len - len(data["input_ids"])
+            if tokenizer_params['padding_side'] == 'right':
+                if self.return_attention_mask:
+                    data["attention_mask"] = [1] * len(data[
+                        "input_ids"]) + [0] * difference
+                if self.return_token_type_ids:
+                    data["token_type_ids"] = (
+                        data["token_type_ids"] +
+                        [tokenizer_params['pad_token_type_id']] * difference)
+                if self.return_special_tokens_mask:
+                    data["special_tokens_mask"] = data[
+                        "special_tokens_mask"] + [1] * difference
+                data["input_ids"] = data["input_ids"] + [
+                    tokenizer_params['pad_token_id']
+                ] * difference
+                if not self.infer_mode:
+                    data["labels"] = data[
+                        "labels"] + [self.pad_token_label_id] * difference
+                data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
+            elif tokenizer_params['padding_side'] == 'left':
+                if self.return_attention_mask:
+                    data["attention_mask"] = [0] * difference + [
+                        1
+                    ] * len(data["input_ids"])
+                if self.return_token_type_ids:
+                    data["token_type_ids"] = (
+                        [tokenizer_params['pad_token_type_id']] * difference +
+                        data["token_type_ids"])
+                if self.return_special_tokens_mask:
+                    data["special_tokens_mask"] = [
+                        1
+                    ] * difference + data["special_tokens_mask"]
+                data["input_ids"] = [tokenizer_params['pad_token_id']
+                                     ] * difference + data["input_ids"]
+                if not self.infer_mode:
+                    data["labels"] = [self.pad_token_label_id
+                                      ] * difference + data["labels"]
+                data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
+        else:
+            if self.return_attention_mask:
+                data["attention_mask"] = [1] * len(data["input_ids"])
+
+        for key in data:
+            if key in [
+                    'input_ids', 'labels', 'token_type_ids', 'bbox',
+                    'attention_mask'
+            ]:
+                if self.infer_mode:
+                    if key != 'labels':
+                        length = min(len(data[key]), self.max_seq_len)
+                        data[key] = data[key][:length]
+                    else:
+                        continue
+                data[key] = np.array(data[key], dtype='int64')
+        return data
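
VQATokenPad right-pads (by default) every sequence field to max_seq_len and converts them to int64 arrays; PaddlePaddle is needed only for the CrossEntropyLoss ignore_index (-100) used to pad labels. A sketch with a short sample and the default tokenizer_params (pad_token_id=1):

sample = {
    "input_ids": [101, 7, 8, 102],
    "token_type_ids": [0, 0, 0, 0],
    "labels": [1, 2, 2, 1],
    "bbox": [[0, 0, 1, 1]] * 4,
}
padded = VQATokenPad(max_seq_len=8)(sample)
print(padded["input_ids"])       # [101   7   8 102   1   1   1   1]
print(padded["labels"][-1])      # -100, the ignore_index
print(padded["attention_mask"])  # [1 1 1 1 0 0 0 0]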
ppocr/data/imaug/vqa/token/vqa_token_relation.py ADDED
@@ -0,0 +1,67 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class VQAReTokenRelation(object):
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        """
+        build relations
+        """
+        entities = data['entities']
+        relations = data['relations']
+        id2label = data.pop('id2label')
+        empty_entity = data.pop('empty_entity')
+        entity_id_to_index_map = data.pop('entity_id_to_index_map')
+
+        relations = list(set(relations))
+        relations = [
+            rel for rel in relations
+            if rel[0] not in empty_entity and rel[1] not in empty_entity
+        ]
+        kv_relations = []
+        for rel in relations:
+            pair = [id2label[rel[0]], id2label[rel[1]]]
+            if pair == ["question", "answer"]:
+                kv_relations.append({
+                    "head": entity_id_to_index_map[rel[0]],
+                    "tail": entity_id_to_index_map[rel[1]]
+                })
+            elif pair == ["answer", "question"]:
+                kv_relations.append({
+                    "head": entity_id_to_index_map[rel[1]],
+                    "tail": entity_id_to_index_map[rel[0]]
+                })
+            else:
+                continue
+        relations = sorted(
+            [{
+                "head": rel["head"],
+                "tail": rel["tail"],
+                "start_index": self.get_relation_span(rel, entities)[0],
+                "end_index": self.get_relation_span(rel, entities)[1],
+            } for rel in kv_relations],
+            key=lambda x: x["head"])
+
+        data['relations'] = relations
+        return data
+
+    def get_relation_span(self, rel, entities):
+        bound = []
+        for entity_index in [rel["head"], rel["tail"]]:
+            bound.append(entities[entity_index]["start"])
+            bound.append(entities[entity_index]["end"])
+        return min(bound), max(bound)
ppocr/data/lmdb_dataset.py ADDED
@@ -0,0 +1,205 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+import lmdb
+import cv2
+import string
+import six
+from PIL import Image
+
+from .imaug import transform, create_operators
+
+
+class LMDBDataSet(Dataset):
+    def __init__(self, config, mode, logger, seed=None):
+        super(LMDBDataSet, self).__init__()
+
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+        batch_size = loader_config['batch_size_per_card']
+        data_dir = dataset_config['data_dir']
+        self.do_shuffle = loader_config['shuffle']
+
+        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
+        logger.info("Initialize indexes of datasets: %s" % data_dir)
+        self.data_idx_order_list = self.dataset_traversal()
+        if self.do_shuffle:
+            np.random.shuffle(self.data_idx_order_list)
+        self.ops = create_operators(dataset_config['transforms'], global_config)
+        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
+                                                       1)
+
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
+    def load_hierarchical_lmdb_dataset(self, data_dir):
+        lmdb_sets = {}
+        dataset_idx = 0
+        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
+            if not dirnames:
+                env = lmdb.open(
+                    dirpath,
+                    max_readers=32,
+                    readonly=True,
+                    lock=False,
+                    readahead=False,
+                    meminit=False)
+                txn = env.begin(write=False)
+                num_samples = int(txn.get('num-samples'.encode()))
+                lmdb_sets[dataset_idx] = {"dirpath": dirpath, "env": env, \
+                    "txn": txn, "num_samples": num_samples}
+                dataset_idx += 1
+        return lmdb_sets
+
+    def dataset_traversal(self):
+        lmdb_num = len(self.lmdb_sets)
+        total_sample_num = 0
+        for lno in range(lmdb_num):
+            total_sample_num += self.lmdb_sets[lno]['num_samples']
+        data_idx_order_list = np.zeros((total_sample_num, 2))
+        beg_idx = 0
+        for lno in range(lmdb_num):
+            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
+            end_idx = beg_idx + tmp_sample_num
+            data_idx_order_list[beg_idx:end_idx, 0] = lno
+            data_idx_order_list[beg_idx:end_idx, 1] \
+                = list(range(tmp_sample_num))
+            data_idx_order_list[beg_idx:end_idx, 1] += 1
+            beg_idx = beg_idx + tmp_sample_num
+        return data_idx_order_list
+
+    def get_img_data(self, value):
+        """get_img_data"""
+        if not value:
+            return None
+        imgdata = np.frombuffer(value, dtype='uint8')
+        if imgdata is None:
+            return None
+        imgori = cv2.imdecode(imgdata, 1)
+        if imgori is None:
+            return None
+        return imgori
+
+    def get_ext_data(self):
+        ext_data_num = 0
+        for op in self.ops:
+            if hasattr(op, 'ext_data_num'):
+                ext_data_num = getattr(op, 'ext_data_num')
+                break
+        load_data_ops = self.ops[:self.ext_op_transform_idx]
+        ext_data = []
+
+        while len(ext_data) < ext_data_num:
+            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
+                len(self))]
+            lmdb_idx = int(lmdb_idx)
+            file_idx = int(file_idx)
+            sample_info = self.get_lmdb_sample_info(
+                self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+            if sample_info is None:
+                continue
+            img, label = sample_info
+            data = {'image': img, 'label': label}
+            data = transform(data, load_data_ops)
+            if data is None:
+                continue
+            ext_data.append(data)
+        return ext_data
+
+    def get_lmdb_sample_info(self, txn, index):
+        label_key = 'label-%09d'.encode() % index
+        label = txn.get(label_key)
+        if label is None:
+            return None
+        label = label.decode('utf-8')
+        img_key = 'image-%09d'.encode() % index
+        imgbuf = txn.get(img_key)
+        return imgbuf, label
+
+    def __getitem__(self, idx):
+        lmdb_idx, file_idx = self.data_idx_order_list[idx]
+        lmdb_idx = int(lmdb_idx)
+        file_idx = int(file_idx)
+        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
+                                                file_idx)
+        if sample_info is None:
+            return self.__getitem__(np.random.randint(self.__len__()))
+        img, label = sample_info
+        data = {'image': img, 'label': label}
+        data['ext_data'] = self.get_ext_data()
+        outs = transform(data, self.ops)
+        if outs is None:
+            return self.__getitem__(np.random.randint(self.__len__()))
+        return outs
+
+    def __len__(self):
+        return self.data_idx_order_list.shape[0]
+
+
+class LMDBDataSetSR(LMDBDataSet):
+    def buf2PIL(self, txn, key, type='RGB'):
+        imgbuf = txn.get(key)
+        buf = six.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        im = Image.open(buf).convert(type)
+        return im
+
+    def str_filt(self, str_, voc_type):
+        alpha_dict = {
+            'digit': string.digits,
+            'lower': string.digits + string.ascii_lowercase,
+            'upper': string.digits + string.ascii_letters,
+            'all': string.digits + string.ascii_letters + string.punctuation
+        }
+        if voc_type == 'lower':
+            str_ = str_.lower()
+        for char in str_:
+            if char not in alpha_dict[voc_type]:
+                str_ = str_.replace(char, '')
+        return str_
+
+    def get_lmdb_sample_info(self, txn, index):
+        self.voc_type = 'upper'
+        self.max_len = 100
+        self.test = False
+        label_key = b'label-%09d' % index
+        word = str(txn.get(label_key).decode())
+        img_HR_key = b'image_hr-%09d' % index  # 128*32
+        img_lr_key = b'image_lr-%09d' % index  # 64*16
+        try:
+            img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
+            img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
+        except IOError:
+            return self[index + 1]
+        if len(word) > self.max_len:
+            return self[index + 1]
+        label_str = self.str_filt(word, self.voc_type)
+        return img_HR, img_lr, label_str
+
+    def __getitem__(self, idx):
+        lmdb_idx, file_idx = self.data_idx_order_list[idx]
+        lmdb_idx = int(lmdb_idx)
+        file_idx = int(file_idx)
+        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
+                                                file_idx)
+        if sample_info is None:
+            return self.__getitem__(np.random.randint(self.__len__()))
+        img_HR, img_lr, label_str = sample_info
+        data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
+        outs = transform(data, self.ops)
+        if outs is None:
+            return self.__getitem__(np.random.randint(self.__len__()))
+        return outs
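
LMDBDataSet takes PaddleOCR's usual config layout; every leaf directory under data_dir must hold an LMDB with num-samples, image-%09d, and label-%09d keys (indices start at 1, hence the += 1 in dataset_traversal). A minimal config sketch, with a placeholder path and a placeholder transform list:

config = {
    "Global": {},
    "Train": {
        "dataset": {
            "data_dir": "./train_data/data_lmdb_release/training/",  # placeholder
            "transforms": [{"DecodeImage": {"img_mode": "BGR"}}],    # placeholder ops
        },
        "loader": {"shuffle": True, "batch_size_per_card": 256},
    },
}
# dataset = LMDBDataSet(config, mode="Train", logger=logger)  # logger supplied by caller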
ppocr/data/pgnet_dataset.py ADDED
@@ -0,0 +1,106 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+from .imaug import transform, create_operators
+import random
+
+
+class PGDataSet(Dataset):
+    def __init__(self, config, mode, logger, seed=None):
+        super(PGDataSet, self).__init__()
+
+        self.logger = logger
+        self.seed = seed
+        self.mode = mode
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+
+        self.delimiter = dataset_config.get('delimiter', '\t')
+        label_file_list = dataset_config.pop('label_file_list')
+        data_source_num = len(label_file_list)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        if isinstance(ratio_list, (float, int)):
+            ratio_list = [float(ratio_list)] * int(data_source_num)
+        assert len(
+            ratio_list
+        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.data_dir = dataset_config['data_dir']
+        self.do_shuffle = loader_config['shuffle']
+
+        logger.info("Initialize indexes of datasets: %s" % label_file_list)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+        self.data_idx_order_list = list(range(len(self.data_lines)))
+        if mode.lower() == "train":
+            self.shuffle_data_random()
+
+        self.ops = create_operators(dataset_config['transforms'], global_config)
+
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
+    def shuffle_data_random(self):
+        if self.do_shuffle:
+            random.seed(self.seed)
+            random.shuffle(self.data_lines)
+        return
+
+    def get_image_info_list(self, file_list, ratio_list):
+        if isinstance(file_list, str):
+            file_list = [file_list]
+        data_lines = []
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
+        return data_lines
+
+    def __getitem__(self, idx):
+        file_idx = self.data_idx_order_list[idx]
+        data_line = self.data_lines[file_idx]
+        img_id = 0
+        try:
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            if self.mode.lower() == 'eval':
+                try:
+                    img_id = int(data_line.split(".")[0][7:])
+                except Exception:
+                    img_id = 0
+            data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
+            with open(data['img_path'], 'rb') as f:
+                img = f.read()
+                data['image'] = img
+            outs = transform(data, self.ops)
+        except Exception as e:
+            self.logger.error(
+                "When parsing line {}, error happened with msg: {}".format(
+                    self.data_idx_order_list[idx], e))
+            outs = None
+        if outs is None:
+            return self.__getitem__(np.random.randint(self.__len__()))
+        return outs
+
+    def __len__(self):
+        return len(self.data_idx_order_list)
ppocr/data/pubtab_dataset.py ADDED
@@ -0,0 +1,133 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+import json
+from copy import deepcopy
+
+from .imaug import transform, create_operators
+
+
+class PubTabDataSet(Dataset):
+    def __init__(self, config, mode, logger, seed=None):
+        super(PubTabDataSet, self).__init__()
+        self.logger = logger
+
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+
+        label_file_list = dataset_config.pop('label_file_list')
+        data_source_num = len(label_file_list)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        if isinstance(ratio_list, (float, int)):
+            ratio_list = [float(ratio_list)] * int(data_source_num)
+
+        assert len(
+            ratio_list
+        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+
+        self.data_dir = dataset_config['data_dir']
+        self.do_shuffle = loader_config['shuffle']
+
+        self.seed = seed
+        self.mode = mode.lower()
+        logger.info("Initialize indexes of datasets: %s" % label_file_list)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+        # self.check(config['Global']['max_text_length'])
+
+        if mode.lower() == "train" and self.do_shuffle:
+            self.shuffle_data_random()
+        self.ops = create_operators(dataset_config['transforms'], global_config)
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
+    def get_image_info_list(self, file_list, ratio_list):
+        if isinstance(file_list, str):
+            file_list = [file_list]
+        data_lines = []
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
+        return data_lines
+
+    def check(self, max_text_length):
+        data_lines = []
+        for line in self.data_lines:
+            data_line = line.decode('utf-8').strip("\n")
+            info = json.loads(data_line)
+            file_name = info['filename']
+            cells = info['html']['cells'].copy()
+            structure = info['html']['structure']['tokens'].copy()
+
+            img_path = os.path.join(self.data_dir, file_name)
+            if not os.path.exists(img_path):
+                self.logger.warning("{} does not exist!".format(img_path))
+                continue
+            if len(structure) == 0 or len(structure) > max_text_length:
+                continue
+            # data = {'img_path': img_path, 'cells': cells, 'structure': structure, 'file_name': file_name}
+            data_lines.append(line)
+        self.data_lines = data_lines
+
+    def shuffle_data_random(self):
+        if self.do_shuffle:
+            random.seed(self.seed)
+            random.shuffle(self.data_lines)
+        return
+
+    def __getitem__(self, idx):
+        try:
+            data_line = self.data_lines[idx]
+            data_line = data_line.decode('utf-8').strip("\n")
+            info = json.loads(data_line)
+            file_name = info['filename']
+            cells = info['html']['cells'].copy()
+            structure = info['html']['structure']['tokens'].copy()
+
+            img_path = os.path.join(self.data_dir, file_name)
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
+            data = {
+                'img_path': img_path,
+                'cells': cells,
+                'structure': structure,
+                'file_name': file_name
+            }
+
+            with open(data['img_path'], 'rb') as f:
+                img = f.read()
+                data['image'] = img
+            outs = transform(data, self.ops)
+        except Exception:
+            import traceback
+            err = traceback.format_exc()
+            self.logger.error(
+                "When parsing line {}, error happened with msg: {}".format(
+                    data_line, err))
+            outs = None
+        if outs is None:
+            rnd_idx = np.random.randint(self.__len__(
+            )) if self.mode == "train" else (idx + 1) % self.__len__()
+            return self.__getitem__(rnd_idx)
+        return outs
+
+    def __len__(self):
+        return len(self.data_lines)
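
Each label line consumed by PubTabDataSet is one PubTabNet-style JSON object; a sketch of the minimal fields __getitem__ actually reads (filename and tokens invented):

import json

line = json.dumps({
    "filename": "PMC_table_0.png",  # illustrative
    "html": {
        "cells": [{"tokens": ["1"], "bbox": [10, 10, 40, 30]}],
        "structure": {"tokens": ["<tr>", "<td>", "</td>", "</tr>"]},
    },
})
info = json.loads(line)
print(info["html"]["structure"]["tokens"])  # the table-structure target sequence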
ppocr/data/simple_dataset.py ADDED
@@ -0,0 +1,151 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import json
+import random
+import traceback
+from paddle.io import Dataset
+from .imaug import transform, create_operators
+
+
+class SimpleDataSet(Dataset):
+    def __init__(self, config, mode, logger, seed=None):
+        super(SimpleDataSet, self).__init__()
+        self.logger = logger
+        self.mode = mode.lower()
+
+        global_config = config['Global']
+        dataset_config = config[mode]['dataset']
+        loader_config = config[mode]['loader']
+
+        self.delimiter = dataset_config.get('delimiter', '\t')
+        label_file_list = dataset_config.pop('label_file_list')
+        data_source_num = len(label_file_list)
+        ratio_list = dataset_config.get("ratio_list", 1.0)
+        if isinstance(ratio_list, (float, int)):
+            ratio_list = [float(ratio_list)] * int(data_source_num)
+
+        assert len(
+            ratio_list
+        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.data_dir = dataset_config['data_dir']
+        self.do_shuffle = loader_config['shuffle']
+        self.seed = seed
+        logger.info("Initialize indexes of datasets: %s" % label_file_list)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+        self.data_idx_order_list = list(range(len(self.data_lines)))
+        if self.mode == "train" and self.do_shuffle:
+            self.shuffle_data_random()
+        self.ops = create_operators(dataset_config['transforms'], global_config)
+        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
+                                                       2)
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
+    def get_image_info_list(self, file_list, ratio_list):
+        if isinstance(file_list, str):
+            file_list = [file_list]
+        data_lines = []
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
+        return data_lines
+
+    def shuffle_data_random(self):
+        random.seed(self.seed)
+        random.shuffle(self.data_lines)
+        return
+
+    def _try_parse_filename_list(self, file_name):
+        # multiple images -> one gt label
+        if len(file_name) > 0 and file_name[0] == "[":
+            try:
+                info = json.loads(file_name)
+                file_name = random.choice(info)
+            except Exception:
+                pass
+        return file_name
+
+    def get_ext_data(self):
+        ext_data_num = 0
+        for op in self.ops:
+            if hasattr(op, 'ext_data_num'):
+                ext_data_num = getattr(op, 'ext_data_num')
+                break
+        load_data_ops = self.ops[:self.ext_op_transform_idx]
+        ext_data = []
+
+        while len(ext_data) < ext_data_num:
+            file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
+            ))]
+            data_line = self.data_lines[file_idx]
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            file_name = self._try_parse_filename_list(file_name)
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            data = {'img_path': img_path, 'label': label}
+            if not os.path.exists(img_path):
+                continue
+            with open(data['img_path'], 'rb') as f:
+                img = f.read()
+                data['image'] = img
+            data = transform(data, load_data_ops)
+
+            if data is None:
+                continue
+            if 'polys' in data.keys():
+                if data['polys'].shape[1] != 4:
+                    continue
+            ext_data.append(data)
+        return ext_data
+
+    def __getitem__(self, idx):
+        file_idx = self.data_idx_order_list[idx]
+        data_line = self.data_lines[file_idx]
+        try:
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            file_name = self._try_parse_filename_list(file_name)
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            data = {'img_path': img_path, 'label': label}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
+            with open(data['img_path'], 'rb') as f:
+                img = f.read()
+                data['image'] = img
+            data['ext_data'] = self.get_ext_data()
+            outs = transform(data, self.ops)
+        except Exception:
+            self.logger.error(
+                "When parsing line {}, error happened with msg: {}".format(
+                    data_line, traceback.format_exc()))
+            outs = None
+        if outs is None:
+            # during evaluation, fix the index so repeated runs give identical results
+            rnd_idx = np.random.randint(self.__len__(
+            )) if self.mode == "train" else (idx + 1) % self.__len__()
+            return self.__getitem__(rnd_idx)
+        return outs
+
+    def __len__(self):
+        return len(self.data_idx_order_list)
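
A SimpleDataSet label file holds one image-path/label pair per line, tab-delimited by default; when the first field is a JSON list, _try_parse_filename_list picks one of the paths at random. Two illustrative lines (paths and labels invented):

label_lines = [
    "word_001.jpg\tJOINT",                                  # single image -> label
    '["aug_0/word_002.jpg", "aug_1/word_002.jpg"]\tHELLO',  # image list -> one label
]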
ppocr/ext_op/__init__.py ADDED
@@ -0,0 +1 @@
+from .roi_align_rotated.roi_align_rotated import RoIAlignRotated
ppocr/ext_op/roi_align_rotated/roi_align_rotated.cc ADDED
@@ -0,0 +1,528 @@
+
+ // This code is adapted from:
+ // https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cpu/roi_align_rotated.cpp
+
+ #include <cassert>
+ #include <cmath>
+ #include <vector>
+
+ #include "paddle/extension.h"
+
+ #define PADDLE_WITH_CUDA
+ #define CHECK_INPUT_SAME(x1, x2) \
+   PD_CHECK(x1.place() == x2.place(), "inputs must be on the same place.")
+ #define CHECK_INPUT_CPU(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
+
+ template <typename T> struct PreCalc {
+   int pos1;
+   int pos2;
+   int pos3;
+   int pos4;
+   T w1;
+   T w2;
+   T w3;
+   T w4;
+ };
+
+ template <typename T>
+ void pre_calc_for_bilinear_interpolate(
+     const int height, const int width, const int pooled_height,
+     const int pooled_width, const int iy_upper, const int ix_upper,
+     T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
+     int roi_bin_grid_h, int roi_bin_grid_w, T roi_center_h, T roi_center_w,
+     T cos_theta, T sin_theta, std::vector<PreCalc<T>> &pre_calc) {
+   int pre_calc_index = 0;
+   for (int ph = 0; ph < pooled_height; ph++) {
+     for (int pw = 0; pw < pooled_width; pw++) {
+       for (int iy = 0; iy < iy_upper; iy++) {
+         const T yy = roi_start_h + ph * bin_size_h +
+                      static_cast<T>(iy + .5f) * bin_size_h /
+                          static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+         for (int ix = 0; ix < ix_upper; ix++) {
+           const T xx = roi_start_w + pw * bin_size_w +
+                        static_cast<T>(ix + .5f) * bin_size_w /
+                            static_cast<T>(roi_bin_grid_w);
+
+           // Rotate by theta around the center and translate.
+           // In image space, (y, x) is the order for a right-handed system,
+           // and this is essentially multiplying the point by a rotation
+           // matrix to rotate it counterclockwise through angle theta.
+           T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+           T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+           // deal with samples that fall outside the feature map boundary
+           if (y < -1.0 || y > height || x < -1.0 || x > width) {
+             // empty
+             PreCalc<T> pc;
+             pc.pos1 = 0;
+             pc.pos2 = 0;
+             pc.pos3 = 0;
+             pc.pos4 = 0;
+             pc.w1 = 0;
+             pc.w2 = 0;
+             pc.w3 = 0;
+             pc.w4 = 0;
+             pre_calc[pre_calc_index] = pc;
+             pre_calc_index += 1;
+             continue;
+           }
+
+           if (y < 0) {
+             y = 0;
+           }
+           if (x < 0) {
+             x = 0;
+           }
+
+           int y_low = (int)y;
+           int x_low = (int)x;
+           int y_high;
+           int x_high;
+
+           if (y_low >= height - 1) {
+             y_high = y_low = height - 1;
+             y = (T)y_low;
+           } else {
+             y_high = y_low + 1;
+           }
+
+           if (x_low >= width - 1) {
+             x_high = x_low = width - 1;
+             x = (T)x_low;
+           } else {
+             x_high = x_low + 1;
+           }
+
+           T ly = y - y_low;
+           T lx = x - x_low;
+           T hy = 1. - ly, hx = 1. - lx;
+           T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+           // save weights and indices
+           PreCalc<T> pc;
+           pc.pos1 = y_low * width + x_low;
+           pc.pos2 = y_low * width + x_high;
+           pc.pos3 = y_high * width + x_low;
+           pc.pos4 = y_high * width + x_high;
+           pc.w1 = w1;
+           pc.w2 = w2;
+           pc.w3 = w3;
+           pc.w4 = w4;
+           pre_calc[pre_calc_index] = pc;
+
+           pre_calc_index += 1;
+         }
+       }
+     }
+   }
+ }
+
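
An illustrative aside (not from the patch): the quantity tabulated per sample point above is plain bilinear interpolation, applied after the sampling grid has been rotated about the RoI centre. A plain-Python sketch of the same clamping and weight logic:

def bilinear_weights(y, x, height, width):
    """Neighbour offsets and weights for one sample, clamped like the C++."""
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return [(0, 0.0)] * 4  # sample falls outside the feature map
    y, x = max(y, 0.0), max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:
        y_low, y = height - 1, float(height - 1)
    if x_low >= width - 1:
        x_low, x = width - 1, float(width - 1)
    y_high = min(y_low + 1, height - 1)
    x_high = min(x_low + 1, width - 1)
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    return [(y_low * width + x_low, hy * hx), (y_low * width + x_high, hy * lx),
            (y_high * width + x_low, ly * hx), (y_high * width + x_high, ly * lx)]

# The (yy, xx) grid point is first rotated about the RoI centre, exactly as in
# the loop above:
#   y = yy * cos(theta) - xx * sin(theta) + roi_center_h
#   x = yy * sin(theta) + xx * cos(theta) + roi_center_w
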
+ template <typename T>
+ void roi_align_rotated_cpu_forward(const int nthreads, const T *input,
+                                    const T &spatial_scale, const bool aligned,
+                                    const bool clockwise, const int channels,
+                                    const int height, const int width,
+                                    const int pooled_height,
+                                    const int pooled_width,
+                                    const int sampling_ratio, const T *rois,
+                                    T *output) {
+   int n_rois = nthreads / channels / pooled_width / pooled_height;
+   // (n, c, ph, pw) is an element in the pooled output
+   // can be parallelized using omp
+   // #pragma omp parallel for num_threads(32)
+   for (int n = 0; n < n_rois; n++) {
+     int index_n = n * channels * pooled_width * pooled_height;
+
+     const T *current_roi = rois + n * 6;
+     int roi_batch_ind = current_roi[0];
+
+     // Do not use rounding; this implementation detail is critical
+     T offset = aligned ? (T)0.5 : (T)0.0;
+     T roi_center_w = current_roi[1] * spatial_scale - offset;
+     T roi_center_h = current_roi[2] * spatial_scale - offset;
+     T roi_width = current_roi[3] * spatial_scale;
+     T roi_height = current_roi[4] * spatial_scale;
+     T theta = current_roi[5];
+     if (clockwise) {
+       theta = -theta; // If clockwise, the angle needs to be reversed.
+     }
+     T cos_theta = cos(theta);
+     T sin_theta = sin(theta);
+
+     if (aligned) {
+       assert(roi_width >= 0 && roi_height >= 0);
+     } else { // for backward-compatibility only
+       roi_width = std::max(roi_width, (T)1.);
+       roi_height = std::max(roi_height, (T)1.);
+     }
+
+     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+     // We use roi_bin_grid to sample the grid and mimic integral pooling
+     int roi_bin_grid_h = (sampling_ratio > 0)
+                              ? sampling_ratio
+                              : ceilf(roi_height / pooled_height); // e.g., = 2
+     int roi_bin_grid_w =
+         (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+     // We do average (integral) pooling inside a bin
+     const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+     // we want to precalculate indices and weights shared by all channels;
+     // this is the key point of the optimization
+     std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
+                                      pooled_width * pooled_height);
+
+     // roi_start_h and roi_start_w are computed wrt the center of the RoI
+     // (x, y). The appropriate translation is applied afterwards.
+     T roi_start_h = -roi_height / 2.0;
+     T roi_start_w = -roi_width / 2.0;
+
+     pre_calc_for_bilinear_interpolate(
+         height, width, pooled_height, pooled_width, roi_bin_grid_h,
+         roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
+         roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta,
+         sin_theta, pre_calc);
+
+     for (int c = 0; c < channels; c++) {
+       int index_n_c = index_n + c * pooled_width * pooled_height;
+       const T *offset_input =
+           input + (roi_batch_ind * channels + c) * height * width;
+       int pre_calc_index = 0;
+
+       for (int ph = 0; ph < pooled_height; ph++) {
+         for (int pw = 0; pw < pooled_width; pw++) {
+           int index = index_n_c + ph * pooled_width + pw;
+
+           T output_val = 0.;
+           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+             for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+               PreCalc<T> pc = pre_calc[pre_calc_index];
+               output_val += pc.w1 * offset_input[pc.pos1] +
+                             pc.w2 * offset_input[pc.pos2] +
+                             pc.w3 * offset_input[pc.pos3] +
+                             pc.w4 * offset_input[pc.pos4];
+
+               pre_calc_index += 1;
+             }
+           }
+           output_val /= count;
+
+           output[index] = output_val;
+         } // for pw
+       } // for ph
+     } // for c
+   } // for n
+ }
+
+ template <typename T>
+ void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
+                                    T &w1, T &w2, T &w3, T &w4, int &x_low,
+                                    int &x_high, int &y_low, int &y_high) {
+   // deal with samples that fall outside the feature map boundary
+   if (y < -1.0 || y > height || x < -1.0 || x > width) {
+     // empty
+     w1 = w2 = w3 = w4 = 0.;
+     x_low = x_high = y_low = y_high = -1;
+     return;
+   }
+
+   if (y < 0) {
+     y = 0;
+   }
+
+   if (x < 0) {
+     x = 0;
+   }
+
+   y_low = (int)y;
+   x_low = (int)x;
+
+   if (y_low >= height - 1) {
+     y_high = y_low = height - 1;
+     y = (T)y_low;
+   } else {
+     y_high = y_low + 1;
+   }
+
+   if (x_low >= width - 1) {
+     x_high = x_low = width - 1;
+     x = (T)x_low;
+   } else {
+     x_high = x_low + 1;
+   }
+
+   T ly = y - y_low;
+   T lx = x - x_low;
+   T hy = 1. - ly, hx = 1. - lx;
+
+   // reference in forward
+   // T v1 = input[y_low * width + x_low];
+   // T v2 = input[y_low * width + x_high];
+   // T v3 = input[y_high * width + x_low];
+   // T v4 = input[y_high * width + x_high];
+   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+   return;
+ }
+
+ template <class T> inline void add(T *address, const T &val) {
+   *address += val;
+ }
+
+ template <typename T>
+ void roi_align_rotated_cpu_backward(
+     const int nthreads,
+     // may not be contiguous, so index using n_stride, etc.
+     const T *grad_output, const T &spatial_scale, const bool aligned,
+     const bool clockwise, const int channels, const int height, const int width,
+     const int pooled_height, const int pooled_width, const int sampling_ratio,
+     T *grad_input, const T *rois, const int n_stride, const int c_stride,
+     const int h_stride, const int w_stride) {
+   for (int index = 0; index < nthreads; index++) {
+     // (n, c, ph, pw) is an element in the pooled output
+     int pw = index % pooled_width;
+     int ph = (index / pooled_width) % pooled_height;
+     int c = (index / pooled_width / pooled_height) % channels;
+     int n = index / pooled_width / pooled_height / channels;
+
+     const T *current_roi = rois + n * 6;
+     int roi_batch_ind = current_roi[0];
+
+     // Do not use rounding; this implementation detail is critical
+     T offset = aligned ? (T)0.5 : (T)0.0;
+     T roi_center_w = current_roi[1] * spatial_scale - offset;
+     T roi_center_h = current_roi[2] * spatial_scale - offset;
+     T roi_width = current_roi[3] * spatial_scale;
+     T roi_height = current_roi[4] * spatial_scale;
+     T theta = current_roi[5];
+     if (clockwise) {
+       theta = -theta; // If clockwise, the angle needs to be reversed.
+     }
+     T cos_theta = cos(theta);
+     T sin_theta = sin(theta);
+
+     if (aligned) {
+       assert(roi_width >= 0 && roi_height >= 0);
+     } else { // for backward-compatibility only
+       roi_width = std::max(roi_width, (T)1.);
+       roi_height = std::max(roi_height, (T)1.);
+     }
+
+     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+     T *offset_grad_input =
+         grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+     int output_offset = n * n_stride + c * c_stride;
+     const T *offset_grad_output = grad_output + output_offset;
+     const T grad_output_this_bin =
+         offset_grad_output[ph * h_stride + pw * w_stride];
+
+     // We use roi_bin_grid to sample the grid and mimic integral pooling
+     int roi_bin_grid_h = (sampling_ratio > 0)
+                              ? sampling_ratio
+                              : ceilf(roi_height / pooled_height); // e.g., = 2
+     int roi_bin_grid_w =
+         (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
+
+     // roi_start_h and roi_start_w are computed wrt the center of the RoI
+     // (x, y). The appropriate translation is applied afterwards.
+     T roi_start_h = -roi_height / 2.0;
+     T roi_start_w = -roi_width / 2.0;
+
+     // We do average (integral) pooling inside a bin
+     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+     for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+       const T yy = roi_start_h + ph * bin_size_h +
+                    static_cast<T>(iy + .5f) * bin_size_h /
+                        static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+         const T xx = roi_start_w + pw * bin_size_w +
+                      static_cast<T>(ix + .5f) * bin_size_w /
+                          static_cast<T>(roi_bin_grid_w);
+
+         // Rotate by theta around the center and translate
+         T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+         T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+         T w1, w2, w3, w4;
+         int x_low, x_high, y_low, y_high;
+
+         bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
+                                       x_low, x_high, y_low, y_high);
+
+         T g1 = grad_output_this_bin * w1 / count;
+         T g2 = grad_output_this_bin * w2 / count;
+         T g3 = grad_output_this_bin * w3 / count;
+         T g4 = grad_output_this_bin * w4 / count;
+
+         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+           // atomic add is not needed for now since it is single threaded
+           add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+           add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+           add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+           add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+         } // if
+       } // ix
+     } // iy
+   } // for
+ } // ROIAlignRotatedBackward
+
+ std::vector<paddle::Tensor>
+ RoIAlignRotatedCPUForward(const paddle::Tensor &input,
+                           const paddle::Tensor &rois, int aligned_height,
+                           int aligned_width, float spatial_scale,
+                           int sampling_ratio, bool aligned, bool clockwise) {
+   CHECK_INPUT_CPU(input);
+   CHECK_INPUT_CPU(rois);
+
+   auto num_rois = rois.shape()[0];
+
+   auto channels = input.shape()[1];
+   auto height = input.shape()[2];
+   auto width = input.shape()[3];
+
+   auto output =
+       paddle::empty({num_rois, channels, aligned_height, aligned_width},
+                     input.type(), paddle::CPUPlace());
+   auto output_size = output.numel();
+
+   PD_DISPATCH_FLOATING_TYPES(
+       input.type(), "roi_align_rotated_cpu_forward", ([&] {
+         roi_align_rotated_cpu_forward<data_t>(
+             output_size, input.data<data_t>(),
+             static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
+             height, width, aligned_height, aligned_width, sampling_ratio,
+             rois.data<data_t>(), output.data<data_t>());
+       }));
+
+   return {output};
+ }
+
+ std::vector<paddle::Tensor> RoIAlignRotatedCPUBackward(
+     const paddle::Tensor &input, const paddle::Tensor &rois,
+     const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
+     float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
+
+   auto batch_size = input.shape()[0];
+   auto channels = input.shape()[1];
+   auto height = input.shape()[2];
+   auto width = input.shape()[3];
+
+   auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
+                                  input.type(), paddle::CPUPlace());
+
+   // grad_output from Paddle is contiguous NCHW here, so derive the strides
+   // from its shape (the original mmcv code reads the real tensor strides,
+   // which the paddle::Tensor API does not expose).
+   int h_stride = grad_output.shape()[3];
+   int c_stride = grad_output.shape()[2] * grad_output.shape()[3];
+   int n_stride = grad_output.shape()[1] * c_stride;
+   int w_stride = 1;
+
+   PD_DISPATCH_FLOATING_TYPES(
+       grad_output.type(), "roi_align_rotated_cpu_backward", [&] {
+         roi_align_rotated_cpu_backward<data_t>(
+             grad_output.numel(), grad_output.data<data_t>(),
+             static_cast<data_t>(spatial_scale), aligned, clockwise, channels,
+             height, width, aligned_height, aligned_width, sampling_ratio,
+             grad_input.data<data_t>(), rois.data<data_t>(), n_stride, c_stride,
+             h_stride, w_stride);
+       });
+   return {grad_input};
+ }
+
+ #ifdef PADDLE_WITH_CUDA
+ std::vector<paddle::Tensor>
+ RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
+                            const paddle::Tensor &rois, int aligned_height,
+                            int aligned_width, float spatial_scale,
+                            int sampling_ratio, bool aligned, bool clockwise);
+ #endif
+
+ #ifdef PADDLE_WITH_CUDA
+ std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
+     const paddle::Tensor &input, const paddle::Tensor &rois,
+     const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
+     float spatial_scale, int sampling_ratio, bool aligned, bool clockwise);
+ #endif
+
+ std::vector<paddle::Tensor>
+ RoIAlignRotatedForward(const paddle::Tensor &input, const paddle::Tensor &rois,
+                        int aligned_height, int aligned_width,
+                        float spatial_scale, int sampling_ratio, bool aligned,
+                        bool clockwise) {
+   CHECK_INPUT_SAME(input, rois);
+   if (input.is_cpu()) {
+     return RoIAlignRotatedCPUForward(input, rois, aligned_height, aligned_width,
+                                      spatial_scale, sampling_ratio, aligned,
+                                      clockwise);
+ #ifdef PADDLE_WITH_CUDA
+   } else if (input.is_gpu()) {
+     return RoIAlignRotatedCUDAForward(input, rois, aligned_height,
+                                       aligned_width, spatial_scale,
+                                       sampling_ratio, aligned, clockwise);
+ #endif
+   } else {
+     PD_THROW("Unsupported device type for forward function of roi align "
+              "rotated operator.");
+   }
+ }
+
+ std::vector<paddle::Tensor>
+ RoIAlignRotatedBackward(const paddle::Tensor &input, const paddle::Tensor &rois,
+                         const paddle::Tensor &grad_output, int aligned_height,
+                         int aligned_width, float spatial_scale,
+                         int sampling_ratio, bool aligned, bool clockwise) {
+   CHECK_INPUT_SAME(input, rois);
+   if (input.is_cpu()) {
+     return RoIAlignRotatedCPUBackward(input, rois, grad_output, aligned_height,
+                                       aligned_width, spatial_scale,
+                                       sampling_ratio, aligned, clockwise);
+ #ifdef PADDLE_WITH_CUDA
+   } else if (input.is_gpu()) {
+     return RoIAlignRotatedCUDABackward(input, rois, grad_output, aligned_height,
+                                        aligned_width, spatial_scale,
+                                        sampling_ratio, aligned, clockwise);
+ #endif
+   } else {
+     PD_THROW("Unsupported device type for backward function of roi align "
+              "rotated operator.");
+   }
+ }
+
+ std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> input_shape,
+                                              std::vector<int64_t> rois_shape) {
+   return {{rois_shape[0], input_shape[1], input_shape[2], input_shape[3]}};
+ }
+
+ std::vector<std::vector<int64_t>>
+ InferBackShape(std::vector<int64_t> input_shape,
+                std::vector<int64_t> rois_shape) {
+   return {input_shape};
+ }
+
+ std::vector<paddle::DataType> InferDtype(paddle::DataType input_dtype,
+                                          paddle::DataType rois_dtype) {
+   return {input_dtype};
+ }
+
+ PD_BUILD_OP(roi_align_rotated)
+     .Inputs({"Input", "Rois"})
+     .Outputs({"Output"})
+     .Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
+             "sampling_ratio: int", "aligned: bool", "clockwise: bool"})
+     .SetKernelFn(PD_KERNEL(RoIAlignRotatedForward))
+     .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
+     .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
+
+ PD_BUILD_GRAD_OP(roi_align_rotated)
+     .Inputs({"Input", "Rois", paddle::Grad("Output")})
+     .Attrs({"aligned_height: int", "aligned_width: int", "spatial_scale: float",
+             "sampling_ratio: int", "aligned: bool", "clockwise: bool"})
+     .Outputs({paddle::Grad("Input")})
+     .SetKernelFn(PD_KERNEL(RoIAlignRotatedBackward))
+     .SetInferShapeFn(PD_INFER_SHAPE(InferBackShape));
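
A usage sketch (assumes the op was compiled, e.g. as `custom_ops` above; not from the patch itself): per the kernels' `rois + n * 6` indexing, each RoI row is `(batch_idx, center_x, center_y, w, h, theta)` with `theta` in radians, and the forward wrapper allocates an output of shape `[num_rois, channels, aligned_height, aligned_width]`.

import paddle

feat = paddle.rand([1, 64, 32, 32])                     # NCHW feature map
rois = paddle.to_tensor([[0., 16., 16., 8., 4., 0.3]])  # one rotated box
# Positional args follow the Attrs order registered above: aligned_height,
# aligned_width, spatial_scale, sampling_ratio, aligned, clockwise.
out = custom_ops.roi_align_rotated(feat, rois, 7, 7, 1.0, 2, True, False)
print(out.shape)  # [1, 64, 7, 7]
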
ppocr/ext_op/roi_align_rotated/roi_align_rotated.cu ADDED
@@ -0,0 +1,380 @@
+
+ // This code is adapted from:
+ // https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/common/cuda/roi_align_rotated_cuda_kernel.cuh
+
+ #include <cassert>
+ #include <cmath>
+ #include <vector>
+
+ #include "paddle/extension.h"
+ #include <cuda.h>
+
+ #define CUDA_1D_KERNEL_LOOP(i, n) \
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+        i += blockDim.x * gridDim.x)
+
+ #define THREADS_PER_BLOCK 512
+
+ inline int GET_BLOCKS(const int N) {
+   int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+   int max_block_num = 4096;
+   return min(optimal_block_num, max_block_num);
+ }
+
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+
+ static __inline__ __device__ double atomicAdd(double *address, double val) {
+   unsigned long long int *address_as_ull = (unsigned long long int *)address;
+   unsigned long long int old = *address_as_ull, assumed;
+   if (val == 0.0)
+     return __longlong_as_double(old);
+   do {
+     assumed = old;
+     old = atomicCAS(address_as_ull, assumed,
+                     __double_as_longlong(val + __longlong_as_double(assumed)));
+   } while (assumed != old);
+   return __longlong_as_double(old);
+ }
+
+ #endif
+
+ template <typename T>
+ __device__ T bilinear_interpolate(const T *input, const int height,
+                                   const int width, T y, T x,
+                                   const int index /* index for debug only */) {
+   // deal with samples that fall outside the feature map boundary
+   if (y < -1.0 || y > height || x < -1.0 || x > width)
+     return 0;
+
+   if (y <= 0)
+     y = 0;
+   if (x <= 0)
+     x = 0;
+
+   int y_low = (int)y;
+   int x_low = (int)x;
+   int y_high;
+   int x_high;
+
+   if (y_low >= height - 1) {
+     y_high = y_low = height - 1;
+     y = (T)y_low;
+   } else {
+     y_high = y_low + 1;
+   }
+
+   if (x_low >= width - 1) {
+     x_high = x_low = width - 1;
+     x = (T)x_low;
+   } else {
+     x_high = x_low + 1;
+   }
+
+   T ly = y - y_low;
+   T lx = x - x_low;
+   T hy = 1. - ly, hx = 1. - lx;
+   // do bilinear interpolation
+   T v1 = input[y_low * width + x_low];
+   T v2 = input[y_low * width + x_high];
+   T v3 = input[y_high * width + x_low];
+   T v4 = input[y_high * width + x_high];
+   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+   return val;
+ }
+
+ template <typename T>
+ __device__ void
+ bilinear_interpolate_gradient(const int height, const int width, T y, T x,
+                               T &w1, T &w2, T &w3, T &w4, int &x_low,
+                               int &x_high, int &y_low, int &y_high,
+                               const int index /* index for debug only */) {
+   // deal with samples that fall outside the feature map boundary
+   if (y < -1.0 || y > height || x < -1.0 || x > width) {
+     // empty
+     w1 = w2 = w3 = w4 = 0.;
+     x_low = x_high = y_low = y_high = -1;
+     return;
+   }
+
+   if (y <= 0)
+     y = 0;
+   if (x <= 0)
+     x = 0;
+
+   y_low = (int)y;
+   x_low = (int)x;
+
+   if (y_low >= height - 1) {
+     y_high = y_low = height - 1;
+     y = (T)y_low;
+   } else {
+     y_high = y_low + 1;
+   }
+
+   if (x_low >= width - 1) {
+     x_high = x_low = width - 1;
+     x = (T)x_low;
+   } else {
+     x_high = x_low + 1;
+   }
+
+   T ly = y - y_low;
+   T lx = x - x_low;
+   T hy = 1. - ly, hx = 1. - lx;
+
+   // reference in forward
+   // T v1 = input[y_low * width + x_low];
+   // T v2 = input[y_low * width + x_high];
+   // T v3 = input[y_high * width + x_low];
+   // T v4 = input[y_high * width + x_high];
+   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+   return;
+ }
+
+
140
+ /*** Forward ***/
141
+ template <typename scalar_t>
142
+ __global__ void roi_align_rotated_cuda_forward_kernel(
143
+ const int nthreads, const scalar_t *bottom_data,
144
+ const scalar_t *bottom_rois, const scalar_t spatial_scale,
145
+ const int sample_num, const bool aligned, const bool clockwise,
146
+ const int channels, const int height, const int width,
147
+ const int pooled_height, const int pooled_width, scalar_t *top_data) {
148
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
149
+ // (n, c, ph, pw) is an element in the pooled output
150
+ int pw = index % pooled_width;
151
+ int ph = (index / pooled_width) % pooled_height;
152
+ int c = (index / pooled_width / pooled_height) % channels;
153
+ int n = index / pooled_width / pooled_height / channels;
154
+
155
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
156
+ int roi_batch_ind = offset_bottom_rois[0];
157
+
158
+ // Do not using rounding; this implementation detail is critical
159
+ scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
160
+ scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
161
+ scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
162
+ scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
163
+ scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
164
+ // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
165
+ scalar_t theta = offset_bottom_rois[5];
166
+ if (clockwise) {
167
+ theta = -theta; // If clockwise, the angle needs to be reversed.
168
+ }
169
+ if (!aligned) { // for backward-compatibility only
170
+ // Force malformed ROIs to be 1x1
171
+ roi_width = max(roi_width, (scalar_t)1.);
172
+ roi_height = max(roi_height, (scalar_t)1.);
173
+ }
174
+ scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
175
+ static_cast<scalar_t>(pooled_height);
176
+ scalar_t bin_size_w =
177
+ static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
178
+
179
+ const scalar_t *offset_bottom_data =
180
+ bottom_data + (roi_batch_ind * channels + c) * height * width;
181
+
182
+ // We use roi_bin_grid to sample the grid and mimic integral
183
+ int roi_bin_grid_h = (sample_num > 0)
184
+ ? sample_num
185
+ : ceilf(roi_height / pooled_height); // e.g., = 2
186
+ int roi_bin_grid_w =
187
+ (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
188
+
189
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
190
+ // Appropriate translation needs to be applied after.
191
+ scalar_t roi_start_h = -roi_height / 2.0;
192
+ scalar_t roi_start_w = -roi_width / 2.0;
193
+ scalar_t cosscalar_theta = cos(theta);
194
+ scalar_t sinscalar_theta = sin(theta);
195
+
196
+ // We do average (integral) pooling inside a bin
197
+ const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
198
+
199
+ scalar_t output_val = 0.;
200
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
201
+ const scalar_t yy =
202
+ roi_start_h + ph * bin_size_h +
203
+ static_cast<scalar_t>(iy + .5f) * bin_size_h /
204
+ static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
205
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
206
+ const scalar_t xx = roi_start_w + pw * bin_size_w +
207
+ static_cast<scalar_t>(ix + .5f) * bin_size_w /
208
+ static_cast<scalar_t>(roi_bin_grid_w);
209
+
210
+ // Rotate by theta (counterclockwise) around the center and translate
211
+ scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h;
212
+ scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w;
213
+
214
+ scalar_t val = bilinear_interpolate<scalar_t>(
215
+ offset_bottom_data, height, width, y, x, index);
216
+ output_val += val;
217
+ }
218
+ }
219
+ output_val /= count;
220
+
221
+ top_data[index] = output_val;
222
+ }
223
+ }
224
+
225
+ /*** Backward ***/
226
+ template <typename scalar_t>
227
+ __global__ void roi_align_rotated_backward_cuda_kernel(
228
+ const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois,
229
+ const scalar_t spatial_scale, const int sample_num, const bool aligned,
230
+ const bool clockwise, const int channels, const int height, const int width,
231
+ const int pooled_height, const int pooled_width, scalar_t *bottom_diff) {
232
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
233
+ // (n, c, ph, pw) is an element in the pooled output
234
+ int pw = index % pooled_width;
235
+ int ph = (index / pooled_width) % pooled_height;
236
+ int c = (index / pooled_width / pooled_height) % channels;
237
+ int n = index / pooled_width / pooled_height / channels;
238
+
239
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 6;
240
+ int roi_batch_ind = offset_bottom_rois[0];
241
+
242
+ // Do not round
243
+ scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0;
244
+ scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset;
245
+ scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset;
246
+ scalar_t roi_width = offset_bottom_rois[3] * spatial_scale;
247
+ scalar_t roi_height = offset_bottom_rois[4] * spatial_scale;
248
+ // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0;
249
+ scalar_t theta = offset_bottom_rois[5];
250
+ if (clockwise) {
251
+ theta = -theta; // If clockwise, the angle needs to be reversed.
252
+ }
253
+ if (!aligned) { // for backward-compatibility only
254
+ // Force malformed ROIs to be 1x1
255
+ roi_width = max(roi_width, (scalar_t)1.);
256
+ roi_height = max(roi_height, (scalar_t)1.);
257
+ }
258
+ scalar_t bin_size_h = static_cast<scalar_t>(roi_height) /
259
+ static_cast<scalar_t>(pooled_height);
260
+ scalar_t bin_size_w =
261
+ static_cast<scalar_t>(roi_width) / static_cast<scalar_t>(pooled_width);
262
+
263
+ scalar_t *offset_bottom_diff =
264
+ bottom_diff + (roi_batch_ind * channels + c) * height * width;
265
+
266
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
267
+ const scalar_t *offset_top_diff = top_diff + top_offset;
268
+ const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
269
+
270
+ // We use roi_bin_grid to sample the grid and mimic integral
271
+ int roi_bin_grid_h = (sample_num > 0)
272
+ ? sample_num
273
+ : ceilf(roi_height / pooled_height); // e.g., = 2
274
+ int roi_bin_grid_w =
275
+ (sample_num > 0) ? sample_num : ceilf(roi_width / pooled_width);
276
+
277
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
278
+ // Appropriate translation needs to be applied after.
279
+ scalar_t roi_start_h = -roi_height / 2.0;
280
+ scalar_t roi_start_w = -roi_width / 2.0;
281
+ scalar_t cosTheta = cos(theta);
282
+ scalar_t sinTheta = sin(theta);
283
+
284
+ // We do average (integral) pooling inside a bin
285
+ const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
286
+
287
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
288
+ const scalar_t yy =
289
+ roi_start_h + ph * bin_size_h +
290
+ static_cast<scalar_t>(iy + .5f) * bin_size_h /
291
+ static_cast<scalar_t>(roi_bin_grid_h); // e.g., 0.5, 1.5
292
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
293
+ const scalar_t xx = roi_start_w + pw * bin_size_w +
294
+ static_cast<scalar_t>(ix + .5f) * bin_size_w /
295
+ static_cast<scalar_t>(roi_bin_grid_w);
296
+
297
+ // Rotate by theta around the center and translate
298
+ scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h;
299
+ scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w;
300
+
301
+ scalar_t w1, w2, w3, w4;
302
+ int x_low, x_high, y_low, y_high;
303
+
304
+ bilinear_interpolate_gradient<scalar_t>(height, width, y, x, w1, w2, w3,
305
+ w4, x_low, x_high, y_low,
306
+ y_high, index);
307
+
308
+ scalar_t g1 = top_diff_this_bin * w1 / count;
309
+ scalar_t g2 = top_diff_this_bin * w2 / count;
310
+ scalar_t g3 = top_diff_this_bin * w3 / count;
311
+ scalar_t g4 = top_diff_this_bin * w4 / count;
312
+
313
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
314
+ atomicAdd(offset_bottom_diff + y_low * width + x_low, g1);
315
+ atomicAdd(offset_bottom_diff + y_low * width + x_high, g2);
316
+ atomicAdd(offset_bottom_diff + y_high * width + x_low, g3);
317
+ atomicAdd(offset_bottom_diff + y_high * width + x_high, g4);
318
+ } // if
319
+ } // ix
320
+ } // iy
321
+ } // CUDA_1D_KERNEL_LOOP
322
+ } // RoIAlignBackward
323
+
324
+ std::vector<paddle::Tensor>
325
+ RoIAlignRotatedCUDAForward(const paddle::Tensor &input,
326
+ const paddle::Tensor &rois, int aligned_height,
327
+ int aligned_width, float spatial_scale,
328
+ int sampling_ratio, bool aligned, bool clockwise) {
329
+
330
+ auto num_rois = rois.shape()[0];
331
+
332
+ auto channels = input.shape()[1];
333
+ auto height = input.shape()[2];
334
+ auto width = input.shape()[3];
335
+
336
+ auto output =
337
+ paddle::empty({num_rois, channels, aligned_height, aligned_width},
338
+ input.type(), paddle::GPUPlace());
339
+ auto output_size = output.numel();
340
+
341
+ PD_DISPATCH_FLOATING_TYPES(
342
+ input.type(), "roi_align_rotated_cuda_forward_kernel", ([&] {
343
+ roi_align_rotated_cuda_forward_kernel<
344
+ data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
345
+ output_size, input.data<data_t>(), rois.data<data_t>(),
346
+ static_cast<data_t>(spatial_scale), sampling_ratio, aligned,
347
+ clockwise, channels, height, width, aligned_height, aligned_width,
348
+ output.data<data_t>());
349
+ }));
350
+
351
+ return {output};
352
+ }
353
+
354
+ std::vector<paddle::Tensor> RoIAlignRotatedCUDABackward(
355
+ const paddle::Tensor &input, const paddle::Tensor &rois,
356
+ const paddle::Tensor &grad_output, int aligned_height, int aligned_width,
357
+ float spatial_scale, int sampling_ratio, bool aligned, bool clockwise) {
358
+
359
+ auto num_rois = rois.shape()[0];
360
+
361
+ auto batch_size = input.shape()[0];
362
+ auto channels = input.shape()[1];
363
+ auto height = input.shape()[2];
364
+ auto width = input.shape()[3];
365
+
366
+ auto grad_input = paddle::full({batch_size, channels, height, width}, 0.0,
367
+ input.type(), paddle::GPUPlace());
368
+
369
+ const int output_size = num_rois * aligned_height * aligned_width * channels;
370
+
371
+ PD_DISPATCH_FLOATING_TYPES(
372
+ grad_output.type(), "roi_align_rotated_backward_cuda_kernel", ([&] {
373
+ roi_align_rotated_backward_cuda_kernel<
374
+ data_t><<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
375
+ output_size, grad_output.data<data_t>(), rois.data<data_t>(),
376
+ spatial_scale, sampling_ratio, aligned, clockwise, channels, height,
377
+ width, aligned_height, aligned_width, grad_input.data<data_t>());
378
+ }));
379
+ return {grad_input};
380
+ }
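
Finally, a possible ahead-of-time build script for this directory (a sketch assuming Paddle's standard custom-op tooling; the JIT `load(...)` route sketched earlier also works):

# setup.py (sketch): compile both kernels into an installable module.
from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="roi_align_rotated_ext",
    ext_modules=CUDAExtension(sources=[
        "roi_align_rotated/roi_align_rotated.cc",
        "roi_align_rotated/roi_align_rotated.cu",
    ]),
)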