datdo2717 committed
Commit c5b5437 · 1 Parent(s): 688feaa
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Rotate/StyleText/README.md +219 -0
  2. Rotate/StyleText/README_ch.md +205 -0
  3. Rotate/StyleText/__init__.py +0 -0
  4. Rotate/StyleText/arch/__init__.py +0 -0
  5. Rotate/StyleText/arch/base_module.py +255 -0
  6. Rotate/StyleText/arch/decoder.py +251 -0
  7. Rotate/StyleText/arch/encoder.py +186 -0
  8. Rotate/StyleText/arch/spectral_norm.py +150 -0
  9. Rotate/StyleText/arch/style_text_rec.py +285 -0
  10. Rotate/StyleText/configs/config.yml +54 -0
  11. Rotate/StyleText/configs/dataset_config.yml +64 -0
  12. Rotate/StyleText/engine/__init__.py +0 -0
  13. Rotate/StyleText/engine/corpus_generators.py +66 -0
  14. Rotate/StyleText/engine/predictors.py +139 -0
  15. Rotate/StyleText/engine/style_samplers.py +62 -0
  16. Rotate/StyleText/engine/synthesisers.py +77 -0
  17. Rotate/StyleText/engine/text_drawers.py +85 -0
  18. Rotate/StyleText/engine/writers.py +71 -0
  19. Rotate/StyleText/examples/corpus/example.txt +2 -0
  20. Rotate/StyleText/examples/image_list.txt +2 -0
  21. Rotate/StyleText/tools/__init__.py +0 -0
  22. Rotate/StyleText/tools/synth_dataset.py +31 -0
  23. Rotate/StyleText/tools/synth_image.py +82 -0
  24. Rotate/StyleText/utils/__init__.py +0 -0
  25. Rotate/StyleText/utils/config.py +224 -0
  26. Rotate/StyleText/utils/load_params.py +27 -0
  27. Rotate/StyleText/utils/logging.py +65 -0
  28. Rotate/StyleText/utils/math_functions.py +45 -0
  29. Rotate/StyleText/utils/sys_funcs.py +67 -0
  30. Rotate/__init__.py +18 -0
  31. Rotate/ch_PP-OCRv4_det_infer/inference.pdiparams.info +0 -0
  32. Rotate/ch_PP-OCRv4_det_infer/inference.pdmodel +3 -0
  33. Rotate/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel +3 -0
  34. Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams +3 -0
  35. Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info +0 -0
  36. Rotate/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml +98 -0
  37. Rotate/configs/cls/cls_mv3.yml +94 -0
  38. Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +206 -0
  39. Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml +175 -0
  40. Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml +178 -0
  41. Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml +132 -0
  42. Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +226 -0
  43. Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml +173 -0
  44. Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml +163 -0
  45. Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_cml.yml +235 -0
  46. Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml +171 -0
  47. Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml +172 -0
  48. Rotate/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml +132 -0
  49. Rotate/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml +131 -0
  50. Rotate/configs/det/det_mv3_db.yml +133 -0
Rotate/StyleText/README.md ADDED
@@ -0,0 +1,219 @@
+ English | [简体中文](README_ch.md)
+
+ ## Style Text
+
+ ### Contents
+ - [1. Introduction](#Introduction)
+ - [2. Preparation](#Preparation)
+ - [3. Quick Start](#Quick_Start)
+ - [4. Applications](#Applications)
+ - [5. Code Structure](#Code_structure)
+
+
+ <a name="Introduction"></a>
+ ### Introduction
+
+ <div align="center">
+ <img src="doc/images/3.png" width="800">
+ </div>
+
+ <div align="center">
+ <img src="doc/images/9.png" width="600">
+ </div>
+
+
+ The Style-Text data synthesis tool is based on "Editing Text in the Wild" ([https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047)), a joint research work by Baidu and HUST.
+
+ Unlike the commonly used GAN-based data synthesis tools, the main framework of Style-Text consists of:
+ * (1) a text foreground style transfer module,
+ * (2) a background extraction module, and
+ * (3) a fusion module.
+
+ After these three steps, the style of image text can be transferred quickly. The following figure shows some results of the data synthesis tool.
+
+ <div align="center">
+ <img src="doc/images/10.png" width="1000">
+ </div>
+
+
+ <a name="Preparation"></a>
+ ### Preparation
+
+ 1. Please refer to the [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
+ 2. Download the pretrained models and unzip them:
+
+ ```bash
+ cd StyleText
+ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+ unzip style_text_models.zip
+ ```
+
+ If you save the models in another location, please modify the model paths in `configs/config.yml`; all three of the following entries must be changed at the same time:
+
+ ```
+ bg_generator:
+   pretrain: style_text_models/bg_generator
+ ...
+ text_generator:
+   pretrain: style_text_models/text_generator
+ ...
+ fusion_generator:
+   pretrain: style_text_models/fusion_generator
+ ```
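Assuming the YAML above is loaded into a plain dict (for instance with PyYAML, not shown here), a minimal sanity check that all three generators point into the same unzipped model directory could look like the sketch below; the `config` literal and the `check_pretrain_paths` helper are illustrative, not part of the tool:

```python
# Illustrative fragment mirroring the assumed structure of configs/config.yml.
config = {
    "bg_generator": {"pretrain": "style_text_models/bg_generator"},
    "text_generator": {"pretrain": "style_text_models/text_generator"},
    "fusion_generator": {"pretrain": "style_text_models/fusion_generator"},
}

def check_pretrain_paths(cfg, model_dir="style_text_models"):
    """Return the generator keys whose pretrain path lies outside model_dir."""
    generators = ("bg_generator", "text_generator", "fusion_generator")
    return [g for g in generators
            if not cfg[g]["pretrain"].startswith(model_dir + "/")]

print(check_pretrain_paths(config))  # an empty list means all three entries are consistent
```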
+
+ <a name="Quick_Start"></a>
+ ### Quick Start
+
+ #### Synthesize a single image
+
+ 1. You can run `tools/synth_image` to generate a demo image, which is saved in the current folder.
+
+ ```bash
+ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
+ ```
+
+ * Note 1: The language option corresponds to the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
+ * Note 2: Style-Text is mainly used to generate images for OCR recognition models,
+ so the height of style images should be around 32 pixels. Images of other sizes may give poor results.
+ * Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use the GPU for prediction.
+
+
+ For example, enter the following image and the corpus `PaddleOCR`.
+
+ <div align="center">
+ <img src="examples/style_images/2.jpg" width="300">
+ </div>
+
+ The result `fake_fusion.jpg` will be generated.
+
+ <div align="center">
+ <img src="doc/images/4.jpg" width="300">
+ </div>
+
+ In addition, the intermediate result `fake_bg.jpg` will also be saved, which is the background output.
+
+ <div align="center">
+ <img src="doc/images/7.jpg" width="300">
+ </div>
+
+
+ * `fake_text.jpg` is the generated image with the same font style as the style input.
+
+
+ <div align="center">
+ <img src="doc/images/8.jpg" width="300">
+ </div>
+
+
+ #### Batch synthesis
+
+ In actual application scenarios, it is often necessary to synthesize images in batches and add them to the training set. Style-Text can use a batch of style images and a corpus to synthesize data in batches. The synthesis process is as follows:
+
+ 1. The referenced dataset can be specified in `configs/dataset_config.yml`:
+
+    * `Global`:
+      * `output_dir`: output path for the synthesized data.
+    * `StyleSampler`:
+      * `image_home`: folder of the style images.
+      * `label_file`: file list of the style images; if labels are provided, this is the label file path.
+      * `with_label`: whether `label_file` is a label file.
+    * `CorpusGenerator`:
+      * `method`: method of the CorpusGenerator; `FileCorpus` and `EnNumCorpus` are supported. If `EnNumCorpus` is used, no other configuration is needed; otherwise you need to set `corpus_file` and `language`.
+      * `language`: language of the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
+      * `corpus_file`: file path of the corpus. The corpus file should be a text file that is split by line endings ('\n'); the corpus generator samples one line each time.
+
+
+ Example of a corpus file:
+ ```
+ PaddleOCR
+ 飞桨文字识别
+ StyleText
+ 风格文本图像数据合成
+ ```
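The `FileCorpus` behaviour described above (split the file by '\n', then sample one line at a time) can be sketched as follows; `sample_line` is a hypothetical helper for illustration, not the tool's actual class:

```python
import random

def sample_line(corpus_text, rng=random):
    """Split a corpus into lines and sample one non-empty line,
    as FileCorpus is described to do."""
    lines = [line for line in corpus_text.split("\n") if line.strip()]
    return rng.choice(lines)

corpus = "PaddleOCR\n飞桨文字识别\nStyleText\n风格文本图像数据合成\n"
print(sample_line(corpus))  # one of the four corpus lines, chosen at random
```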
+
+ We provide a general dataset containing Chinese, English and Korean text (50,000 images in all) for your trial ([download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar)); some examples are given below:
+
+ <div align="center">
+ <img src="doc/images/5.png" width="800">
+ </div>
+
+ 2. You can run the following command to start the synthesis task:
+
+ ```bash
+ python3 tools/synth_dataset.py -c configs/dataset_config.yml
+ ```
+
+ We also provide example corpus and images in the `examples` folder.
+ <div align="center">
+ <img src="examples/style_images/1.jpg" width="300">
+ <img src="examples/style_images/2.jpg" width="300">
+ </div>
+ If you run the command above directly, you will get example output data in the `output_data` folder.
+ You will get synthesized images and labels as below:
+ <div align="center">
+ <img src="doc/images/12.png" width="800">
+ </div>
+ There will be some cache files under the `label` folder. If the program exits unexpectedly, you can find cached labels there.
+ When the program finishes normally, you will find all the labels in `label.txt`, which gives the final results.
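Once `label.txt` is written, it can be parsed into (image path, text) pairs for training; the sketch below assumes the usual tab-separated PaddleOCR recognition label layout, so check your generated file if your version differs (`parse_label_file` is a hypothetical helper, not part of the tool):

```python
def parse_label_file(text):
    """Parse tab-separated label lines (image path, then text) into pairs.

    Assumes the common PaddleOCR recognition label format: one sample
    per line, with a tab between the image path and its transcription.
    """
    pairs = []
    for line in text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        path, _, label = line.partition("\t")
        pairs.append((path, label))
    return pairs

example = "images/0000.jpg\tPaddleOCR\nimages/0001.jpg\tStyleText\n"
print(parse_label_file(example))  # → [('images/0000.jpg', 'PaddleOCR'), ('images/0001.jpg', 'StyleText')]
```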
+
+ <a name="Applications"></a>
+ ### Applications
+ We take two scenarios as examples, metal-surface English and number recognition and general Korean recognition, to illustrate practical cases of using Style-Text to synthesize data to improve text recognition. The following figure shows some examples of real-scene images and synthesized images:
+
+ <div align="center">
+ <img src="doc/images/11.png" width="800">
+ </div>
+
+
+ After adding the above synthetic data to training, the accuracy of the recognition model improves, as shown in the following table:
+
+
+ | Scenario | Characters | Raw Data | Test Data | Recognition Accuracy<br/>(Raw Data Only) | Added Synthetic Data | Recognition Accuracy<br/>(With Synthetic Data) | Accuracy Improvement |
+ | -------- | ---------- | -------- | --------- | ---------------------------------------- | -------------------- | ---------------------------------------------- | -------------------- |
+ | Metal surface | English and numbers | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% |
+ | Random background | Korean | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% |
+
+ <a name="Code_structure"></a>
+ ### Code Structure
+
+ ```
+ StyleText
+ |-- arch                        // Network module files.
+ |   |-- base_module.py
+ |   |-- decoder.py
+ |   |-- encoder.py
+ |   |-- spectral_norm.py
+ |   `-- style_text_rec.py
+ |-- configs                     // Config files.
+ |   |-- config.yml
+ |   `-- dataset_config.yml
+ |-- engine                      // Synthesis engines.
+ |   |-- corpus_generators.py    // Sample corpus from a file or generate random corpus.
+ |   |-- predictors.py           // Predict using the network.
+ |   |-- style_samplers.py       // Sample style images.
+ |   |-- synthesisers.py         // Manage the other engines to synthesize images.
+ |   |-- text_drawers.py         // Generate standard input text images.
+ |   `-- writers.py              // Write synthesized images and labels to files.
+ |-- examples                    // Example files.
+ |   |-- corpus
+ |   |   `-- example.txt
+ |   |-- image_list.txt
+ |   `-- style_images
+ |       |-- 1.jpg
+ |       `-- 2.jpg
+ |-- fonts                       // Font files.
+ |   |-- ch_standard.ttf
+ |   |-- en_standard.ttf
+ |   `-- ko_standard.ttf
+ |-- tools                       // Program entry points.
+ |   |-- __init__.py
+ |   |-- synth_dataset.py        // Synthesize a dataset.
+ |   `-- synth_image.py          // Synthesize a single image.
+ `-- utils                       // Basic utility modules.
+     |-- config.py
+     |-- load_params.py
+     |-- logging.py
+     |-- math_functions.py
+     `-- sys_funcs.py
+ ```
Rotate/StyleText/README_ch.md ADDED
@@ -0,0 +1,205 @@
+ Simplified Chinese | [English](README.md)
+
+ ## Style Text
+
+
+ ### Contents
+ - [1. Introduction](#工具简介)
+ - [2. Environment Setup](#环境配置)
+ - [3. Quick Start](#快速上手)
+ - [4. Application Cases](#应用案例)
+ - [5. Code Structure](#代码结构)
+
+ <a name="工具简介"></a>
+ ### 1. Introduction
+ <div align="center">
+ <img src="doc/images/3.png" width="800">
+ </div>
+
+ <div align="center">
+ <img src="doc/images/1.png" width="600">
+ </div>
+
+
+ The Style-Text data synthesis tool is based on the text editing algorithm "Editing Text in the Wild" ([https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047)), a joint research work by Baidu and HUST.
+
+ Unlike the commonly used GAN-based data synthesis tools, the main framework of Style-Text consists of three modules: 1. a text foreground style transfer module; 2. a background extraction module; 3. a fusion module. After these three steps, image text style transfer can be achieved quickly. The following figure shows some results of this data synthesis tool.
+
+ <div align="center">
+ <img src="doc/images/2.png" width="1000">
+ </div>
+
+ <a name="环境配置"></a>
+ ### 2. Environment Setup
+
+ 1. Refer to [QUICK INSTALLATION](../doc/doc_ch/installation.md) to install PaddleOCR.
+ 2. Enter the `StyleText` directory, then download and unzip the models:
+
+ ```bash
+ cd StyleText
+ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+ unzip style_text_models.zip
+ ```
+
+ If you save the models in another location, please modify the model paths in `configs/config.yml`; all three of the following entries must be changed at the same time:
+
+ ```
+ bg_generator:
+   pretrain: style_text_models/bg_generator
+ ...
+ text_generator:
+   pretrain: style_text_models/text_generator
+ ...
+ fusion_generator:
+   pretrain: style_text_models/fusion_generator
+ ```
+
+ <a name="快速上手"></a>
+ ### 3. Quick Start
+
+ #### Synthesize a single image
+ Given a style image and a piece of text corpus, run `tools/synth_image` to synthesize a single image; the result is saved in the current directory:
+
+ ```bash
+ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
+ ```
+ * Note 1: The language option corresponds to the corpus; English (en), Simplified Chinese (ch) and Korean (ko) are currently supported.
+ * Note 2: The data generated by Style-Text is mainly intended for OCR recognition scenarios. Given the design of the current PaddleOCR recognition models, we mainly support style images with a height of around 32 pixels.
+ If the input image size differs too much from this, the results may be poor.
+ * Note 3: You can set the `use_gpu` parameter (true or false) in the configuration file `configs/config.yml` to decide whether to use the GPU for prediction.
+
+
+ For example, input the following image and the corpus "PaddleOCR":
+
+ <div align="center">
+ <img src="examples/style_images/2.jpg" width="300">
+ </div>
+
+ The synthesized result `fake_fusion.jpg` is generated:
+ <div align="center">
+ <img src="doc/images/4.jpg" width="300">
+ </div>
+
+ In addition, the program also generates and saves intermediate results. `fake_bg.jpg` is the background of the style reference image with the text removed;
+
+ <div align="center">
+ <img src="doc/images/7.jpg" width="300">
+ </div>
+
+ `fake_text.jpg` is an image of the provided text, rendered on a gray background in the style of the text in the style reference image.
+
+ <div align="center">
+ <img src="doc/images/8.jpg" width="300">
+ </div>
+
+ #### Batch synthesis
+ In actual application scenarios, it is often necessary to synthesize images in batches and add them to the training set. Style-Text can use a batch of style images and a corpus to synthesize data in batches. The synthesis process is as follows:
+
+ 1. Configure the paths of the target-scene style images and the corpus in `configs/dataset_config.yml`, as follows:
+
+    * `Global`:
+      * `output_dir`: directory where the synthesized data is saved.
+    * `StyleSampler`:
+      * `image_home`: directory of the style images;
+      * `label_file`: file listing the style image paths; if the dataset has labels, `label_file` is the label file path;
+      * `with_label`: whether `label_file` is a label file.
+    * `CorpusGenerator`:
+      * `method`: corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. If `EnNumCorpus` is used, no other configuration is needed; otherwise `corpus_file` and `language` must be set;
+      * `language`: language of the corpus; English (en), Simplified Chinese (ch) and Korean (ko) are currently supported;
+      * `corpus_file`: path of the corpus file. The corpus should be a text file. The corpus generator first splits the corpus by line, then randomly selects one line each time.
+
+ Example of the corpus file format:
+ ```
+ PaddleOCR
+ 飞桨文字识别
+ StyleText
+ 风格文本图像数据合成
+ ...
+ ```
+
+ Style-Text also provides a batch of 50,000 general-scene images in Chinese, English and Korean to use as text style images, making it easy to synthesize text images with rich scenes; some examples are shown below.
+
+ 50,000 general-scene images in Chinese, English and Korean: [download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar)
+
+ <div align="center">
+ <img src="doc/images/5.png" width="800">
+ </div>
+
+ 2. Run `tools/synth_dataset` to synthesize data:
+
+ ```bash
+ python3 tools/synth_dataset.py -c configs/dataset_config.yml
+ ```
+ We provide example images and corpus in the `examples` directory.
+ <div align="center">
+ <img src="examples/style_images/1.jpg" width="300">
+ <img src="examples/style_images/2.jpg" width="300">
+ </div>
+
+ Running the command above directly produces example output in `output_data`, including images and the annotation files used to train recognition models:
+ <div align="center">
+ <img src="doc/images/12.png" width="800">
+ </div>
+
+ The annotation files under the `label` directory are caches produced while the program runs; if the program terminates abnormally partway through, these cached annotation files can be used.
+ If the program runs to completion, `label.txt` is generated under `output_data` as the final annotation result.
+
+ <a name="应用案例"></a>
+ ### 4. Application Cases
+ Below we take two scenarios, metal-surface English and number recognition and general Korean recognition, as examples to show practical cases of using Style-Text to synthesize data to improve text recognition. The following figure shows some examples of real-scene images and synthesized images:
+
+ <div align="center">
+ <img src="doc/images/6.png" width="800">
+ </div>
+
+ After adding the above synthetic data to training, the recognition model improves, as shown in the following table:
+
+ | Scenario | Characters | Raw Data | Test Data | Recognition Accuracy<br/>(Raw Data Only) | Added Synthetic Data | Recognition Accuracy<br/>(With Synthetic Data) | Accuracy Improvement |
+ | -------- | ---------- | -------- | --------- | ---------------------------------------- | -------------------- | ---------------------------------------------- | -------------------- |
+ | Metal surface | English and numbers | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% |
+ | Random background | Korean | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% |
+
+
+ <a name="代码结构"></a>
+ ### 5. Code Structure
+
+ ```
+ StyleText
+ |-- arch                        // Network structure definitions.
+ |   |-- base_module.py
+ |   |-- decoder.py
+ |   |-- encoder.py
+ |   |-- spectral_norm.py
+ |   `-- style_text_rec.py
+ |-- configs                     // Config files.
+ |   |-- config.yml
+ |   `-- dataset_config.yml
+ |-- engine                      // Data synthesis engines.
+ |   |-- corpus_generators.py    // Sample corpus from text or generate random corpus.
+ |   |-- predictors.py           // Call the network to generate data.
+ |   |-- style_samplers.py       // Sample style images.
+ |   |-- synthesisers.py         // Orchestrate the modules to synthesize data.
+ |   |-- text_drawers.py         // Generate standard text images used as input.
+ |   `-- writers.py              // Write synthesized images and labels to a local directory.
+ |-- examples                    // Example files.
+ |   |-- corpus
+ |   |   `-- example.txt
+ |   |-- image_list.txt
+ |   `-- style_images
+ |       |-- 1.jpg
+ |       `-- 2.jpg
+ |-- fonts                       // Font files.
+ |   |-- ch_standard.ttf
+ |   |-- en_standard.ttf
+ |   `-- ko_standard.ttf
+ |-- tools                       // Program entry points.
+ |   |-- __init__.py
+ |   |-- synth_dataset.py        // Synthesize data in batches.
+ |   `-- synth_image.py          // Synthesize a single image.
+ `-- utils                       // Other basic utility modules.
+     |-- config.py
+     |-- load_params.py
+     |-- logging.py
+     |-- math_functions.py
+     `-- sys_funcs.py
+ ```
Rotate/StyleText/__init__.py ADDED
File without changes
Rotate/StyleText/arch/__init__.py ADDED
File without changes
Rotate/StyleText/arch/base_module.py ADDED
@@ -0,0 +1,255 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import paddle
+ import paddle.nn as nn
+
+ from arch.spectral_norm import spectral_norm
+
+
+ class CBN(nn.Layer):
+     def __init__(self,
+                  name,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  use_bias=False,
+                  norm_layer=None,
+                  act=None,
+                  act_attr=None):
+         super(CBN, self).__init__()
+         if use_bias:
+             bias_attr = paddle.ParamAttr(name=name + "_bias")
+         else:
+             bias_attr = None
+         self._conv = paddle.nn.Conv2D(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             weight_attr=paddle.ParamAttr(name=name + "_weights"),
+             bias_attr=bias_attr)
+         if norm_layer:
+             self._norm_layer = getattr(paddle.nn, norm_layer)(
+                 num_features=out_channels, name=name + "_bn")
+         else:
+             self._norm_layer = None
+         if act:
+             if act_attr:
+                 self._act = getattr(paddle.nn, act)(**act_attr,
+                                                     name=name + "_" + act)
+             else:
+                 self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+         else:
+             self._act = None
+
+     def forward(self, x):
+         out = self._conv(x)
+         if self._norm_layer:
+             out = self._norm_layer(out)
+         if self._act:
+             out = self._act(out)
+         return out
+
+
+ class SNConv(nn.Layer):
+     def __init__(self,
+                  name,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  use_bias=False,
+                  norm_layer=None,
+                  act=None,
+                  act_attr=None):
+         super(SNConv, self).__init__()
+         if use_bias:
+             bias_attr = paddle.ParamAttr(name=name + "_bias")
+         else:
+             bias_attr = None
+         self._sn_conv = spectral_norm(
+             paddle.nn.Conv2D(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 dilation=dilation,
+                 groups=groups,
+                 weight_attr=paddle.ParamAttr(name=name + "_weights"),
+                 bias_attr=bias_attr))
+         if norm_layer:
+             self._norm_layer = getattr(paddle.nn, norm_layer)(
+                 num_features=out_channels, name=name + "_bn")
+         else:
+             self._norm_layer = None
+         if act:
+             if act_attr:
+                 self._act = getattr(paddle.nn, act)(**act_attr,
+                                                     name=name + "_" + act)
+             else:
+                 self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+         else:
+             self._act = None
+
+     def forward(self, x):
+         out = self._sn_conv(x)
+         if self._norm_layer:
+             out = self._norm_layer(out)
+         if self._act:
+             out = self._act(out)
+         return out
+
+
+ class SNConvTranspose(nn.Layer):
+     def __init__(self,
+                  name,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  output_padding=0,
+                  dilation=1,
+                  groups=1,
+                  use_bias=False,
+                  norm_layer=None,
+                  act=None,
+                  act_attr=None):
+         super(SNConvTranspose, self).__init__()
+         if use_bias:
+             bias_attr = paddle.ParamAttr(name=name + "_bias")
+         else:
+             bias_attr = None
+         self._sn_conv_transpose = spectral_norm(
+             paddle.nn.Conv2DTranspose(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 output_padding=output_padding,
+                 dilation=dilation,
+                 groups=groups,
+                 weight_attr=paddle.ParamAttr(name=name + "_weights"),
+                 bias_attr=bias_attr))
+         if norm_layer:
+             self._norm_layer = getattr(paddle.nn, norm_layer)(
+                 num_features=out_channels, name=name + "_bn")
+         else:
+             self._norm_layer = None
+         if act:
+             if act_attr:
+                 self._act = getattr(paddle.nn, act)(**act_attr,
+                                                     name=name + "_" + act)
+             else:
+                 self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+         else:
+             self._act = None
+
+     def forward(self, x):
+         out = self._sn_conv_transpose(x)
+         if self._norm_layer:
+             out = self._norm_layer(out)
+         if self._act:
+             out = self._act(out)
+         return out
+
+
+ class MiddleNet(nn.Layer):
+     def __init__(self, name, in_channels, mid_channels, out_channels,
+                  use_bias):
+         super(MiddleNet, self).__init__()
+         self._sn_conv1 = SNConv(
+             name=name + "_sn_conv1",
+             in_channels=in_channels,
+             out_channels=mid_channels,
+             kernel_size=1,
+             use_bias=use_bias,
+             norm_layer=None,
+             act=None)
+         self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
+         self._sn_conv2 = SNConv(
+             name=name + "_sn_conv2",
+             in_channels=mid_channels,
+             out_channels=mid_channels,
+             kernel_size=3,
+             use_bias=use_bias)
+         self._sn_conv3 = SNConv(
+             name=name + "_sn_conv3",
+             in_channels=mid_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             use_bias=use_bias)
+
+     def forward(self, x):
+         sn_conv1 = self._sn_conv1.forward(x)
+         pad_2d = self._pad2d.forward(sn_conv1)
+         sn_conv2 = self._sn_conv2.forward(pad_2d)
+         sn_conv3 = self._sn_conv3.forward(sn_conv2)
+         return sn_conv3
+
+
+ class ResBlock(nn.Layer):
+     def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
+                  use_bias):
+         super(ResBlock, self).__init__()
+         if use_dilation:
+             padding_mat = [1, 1, 1, 1]
+         else:
+             padding_mat = [0, 0, 0, 0]
+         self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
+
+         self._sn_conv1 = SNConv(
+             name=name + "_sn_conv1",
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=3,
+             padding=0,
+             norm_layer=norm_layer,
+             use_bias=use_bias,
+             act="ReLU",
+             act_attr=None)
+         if use_dropout:
+             self._dropout = nn.Dropout(0.5)
+         else:
+             self._dropout = None
+         self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
+         self._sn_conv2 = SNConv(
+             name=name + "_sn_conv2",
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=3,
+             norm_layer=norm_layer,
+             use_bias=use_bias,
+             act="ReLU",
+             act_attr=None)
+
+     def forward(self, x):
+         pad1 = self._pad1.forward(x)
+         sn_conv1 = self._sn_conv1.forward(pad1)
+         pad2 = self._pad2.forward(sn_conv1)
+         sn_conv2 = self._sn_conv2.forward(pad2)
+         return sn_conv2 + x
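The SNConv and SNConvTranspose wrappers above rely on `spectral_norm` from `arch/spectral_norm.py`, which normalizes a weight by an estimate of its largest singular value obtained via power iteration. A dependency-free sketch of that estimate on a tiny matrix (an illustration of the technique, not the Paddle implementation):

```python
import math

def spectral_norm_estimate(w, n_iters=50):
    """Estimate the largest singular value of a 2-D matrix via power iteration."""
    rows, cols = len(w), len(w[0])
    v = [1.0 / math.sqrt(cols)] * cols  # initial right singular vector guess
    for _ in range(n_iters):
        # u <- normalize(W v)
        u = [sum(w[i][j] * v[j] for j in range(cols)) for i in range(rows)]
        nu = math.sqrt(sum(x * x for x in u)) or 1.0
        u = [x / nu for x in u]
        # v <- normalize(W^T u)
        v = [sum(w[i][j] * u[i] for i in range(rows)) for j in range(cols)]
        nv = math.sqrt(sum(x * x for x in v)) or 1.0
        v = [x / nv for x in v]
    # sigma = u^T W v
    return sum(u[i] * w[i][j] * v[j] for i in range(rows) for j in range(cols))

print(round(spectral_norm_estimate([[2.0, 0.0], [0.0, 1.0]]), 4))  # → 2.0
```

Spectral normalization then divides the weight by this sigma, bounding the layer's Lipschitz constant, which is why the GAN-style generators here wrap every convolution in it.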
Rotate/StyleText/arch/decoder.py ADDED
@@ -0,0 +1,251 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import paddle
+ import paddle.nn as nn
+
+ from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+ class Decoder(nn.Layer):
+     def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+                  act, act_attr, conv_block_dropout, conv_block_num,
+                  conv_block_dilation, out_conv_act, out_conv_act_attr):
+         super(Decoder, self).__init__()
+         conv_blocks = []
+         for i in range(conv_block_num):
+             conv_blocks.append(
+                 ResBlock(
+                     name="{}_conv_block_{}".format(name, i),
+                     channels=encode_dim * 8,
+                     norm_layer=norm_layer,
+                     use_dropout=conv_block_dropout,
+                     use_dilation=conv_block_dilation,
+                     use_bias=use_bias))
+         self.conv_blocks = nn.Sequential(*conv_blocks)
+         self._up1 = SNConvTranspose(
+             name=name + "_up1",
+             in_channels=encode_dim * 8,
+             out_channels=encode_dim * 4,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up2 = SNConvTranspose(
+             name=name + "_up2",
+             in_channels=encode_dim * 4,
+             out_channels=encode_dim * 2,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up3 = SNConvTranspose(
+             name=name + "_up3",
+             in_channels=encode_dim * 2,
+             out_channels=encode_dim,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+         self._out_conv = SNConv(
+             name=name + "_out_conv",
+             in_channels=encode_dim,
+             out_channels=out_channels,
+             kernel_size=3,
+             use_bias=use_bias,
+             norm_layer=None,
+             act=out_conv_act,
+             act_attr=out_conv_act_attr)
+
+     def forward(self, x):
+         if isinstance(x, (list, tuple)):
+             x = paddle.concat(x, axis=1)
+         output_dict = dict()
+         output_dict["conv_blocks"] = self.conv_blocks.forward(x)
+         output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+         output_dict["up2"] = self._up2.forward(output_dict["up1"])
+         output_dict["up3"] = self._up3.forward(output_dict["up2"])
+         output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+         output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+         return output_dict
+
+
+ class DecoderUnet(nn.Layer):
+     def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+                  act, act_attr, conv_block_dropout, conv_block_num,
+                  conv_block_dilation, out_conv_act, out_conv_act_attr):
+         super(DecoderUnet, self).__init__()
+         conv_blocks = []
+         for i in range(conv_block_num):
+             conv_blocks.append(
+                 ResBlock(
+                     name="{}_conv_block_{}".format(name, i),
+                     channels=encode_dim * 8,
+                     norm_layer=norm_layer,
+                     use_dropout=conv_block_dropout,
+                     use_dilation=conv_block_dilation,
+                     use_bias=use_bias))
+         self._conv_blocks = nn.Sequential(*conv_blocks)
+         self._up1 = SNConvTranspose(
+             name=name + "_up1",
+             in_channels=encode_dim * 8,
+             out_channels=encode_dim * 4,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up2 = SNConvTranspose(
+             name=name + "_up2",
+             in_channels=encode_dim * 8,
+             out_channels=encode_dim * 2,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up3 = SNConvTranspose(
+             name=name + "_up3",
+             in_channels=encode_dim * 4,
+             out_channels=encode_dim,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+         self._out_conv = SNConv(
+             name=name + "_out_conv",
+             in_channels=encode_dim,
+             out_channels=out_channels,
+             kernel_size=3,
+             use_bias=use_bias,
+             norm_layer=None,
+             act=out_conv_act,
+             act_attr=out_conv_act_attr)
+
+     def forward(self, x, y, feature2, feature1):
+         output_dict = dict()
+         output_dict["conv_blocks"] = self._conv_blocks(
+             paddle.concat(
+                 (x, y), axis=1))
+         output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+         output_dict["up2"] = self._up2.forward(
+             paddle.concat(
+                 (output_dict["up1"], feature2), axis=1))
+         output_dict["up3"] = self._up3.forward(
+             paddle.concat(
+                 (output_dict["up2"], feature1), axis=1))
+         output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+         output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+         return output_dict
+
+
+ class SingleDecoder(nn.Layer):
+     def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+                  act, act_attr, conv_block_dropout, conv_block_num,
+                  conv_block_dilation, out_conv_act, out_conv_act_attr):
+         super(SingleDecoder, self).__init__()
+         conv_blocks = []
+         for i in range(conv_block_num):
+             conv_blocks.append(
+                 ResBlock(
+                     name="{}_conv_block_{}".format(name, i),
+                     channels=encode_dim * 4,
+                     norm_layer=norm_layer,
+                     use_dropout=conv_block_dropout,
+                     use_dilation=conv_block_dilation,
+                     use_bias=use_bias))
+         self._conv_blocks = nn.Sequential(*conv_blocks)
+         self._up1 = SNConvTranspose(
+             name=name + "_up1",
+             in_channels=encode_dim * 4,
+             out_channels=encode_dim * 4,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up2 = SNConvTranspose(
+             name=name + "_up2",
+             in_channels=encode_dim * 8,
+             out_channels=encode_dim * 2,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             output_padding=1,
+             use_bias=use_bias,
+             norm_layer=norm_layer,
+             act=act,
+             act_attr=act_attr)
+         self._up3 = SNConvTranspose(
+             name=name + "_up3",
+             in_channels=encode_dim * 4,
+             out_channels=encode_dim,
+             kernel_size=3,
221
+ stride=2,
222
+ padding=1,
223
+ output_padding=1,
224
+ use_bias=use_bias,
225
+ norm_layer=norm_layer,
226
+ act=act,
227
+ act_attr=act_attr)
228
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
229
+ self._out_conv = SNConv(
230
+ name=name + "_out_conv",
231
+ in_channels=encode_dim,
232
+ out_channels=out_channels,
233
+ kernel_size=3,
234
+ use_bias=use_bias,
235
+ norm_layer=None,
236
+ act=out_conv_act,
237
+ act_attr=out_conv_act_attr)
238
+
239
+ def forward(self, x, feature2, feature1):
240
+ output_dict = dict()
241
+ output_dict["conv_blocks"] = self._conv_blocks.forward(x)
242
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
243
+ output_dict["up2"] = self._up2.forward(
244
+ paddle.concat(
245
+ (output_dict["up1"], feature2), axis=1))
246
+ output_dict["up3"] = self._up3.forward(
247
+ paddle.concat(
248
+ (output_dict["up2"], feature1), axis=1))
249
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
250
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
251
+ return output_dict
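Every up-block in these decoders uses SNConvTranspose with kernel 3, stride 2, padding 1, and output_padding 1, which exactly doubles the spatial size. As a hedged sanity check (not part of the repo, just the standard transposed-convolution size formula), three such blocks undo the encoder's three stride-2 downsamplings:

```python
def conv_transpose_out_size(size, kernel=3, stride=2, padding=1, output_padding=1):
    # Standard transposed-convolution output-size formula:
    # out = (in - 1) * stride - 2 * padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# With the decoder's settings each up-block doubles the size,
# e.g. a 4-pixel feature map grows back to 32 pixels after three blocks.
sizes = [4]
for _ in range(3):
    sizes.append(conv_transpose_out_size(sizes[-1]))
print(sizes)  # -> [4, 8, 16, 32]
```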
Rotate/StyleText/arch/encoder.py ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+class Encoder(nn.Layer):
+    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+                 act, act_attr, conv_block_dropout, conv_block_num,
+                 conv_block_dilation):
+        super(Encoder, self).__init__()
+        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+        self._in_conv = SNConv(
+            name=name + "_in_conv",
+            in_channels=in_channels,
+            out_channels=encode_dim,
+            kernel_size=7,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down1 = SNConv(
+            name=name + "_down1",
+            in_channels=encode_dim,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down2 = SNConv(
+            name=name + "_down2",
+            in_channels=encode_dim * 2,
+            out_channels=encode_dim * 4,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down3 = SNConv(
+            name=name + "_down3",
+            in_channels=encode_dim * 4,
+            out_channels=encode_dim * 4,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        conv_blocks = []
+        for i in range(conv_block_num):
+            conv_blocks.append(
+                ResBlock(
+                    name="{}_conv_block_{}".format(name, i),
+                    channels=encode_dim * 4,
+                    norm_layer=norm_layer,
+                    use_dropout=conv_block_dropout,
+                    use_dilation=conv_block_dilation,
+                    use_bias=use_bias))
+        self._conv_blocks = nn.Sequential(*conv_blocks)
+
+    def forward(self, x):
+        out_dict = dict()
+        x = self._pad2d(x)
+        out_dict["in_conv"] = self._in_conv.forward(x)
+        out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
+        out_dict["down2"] = self._down2.forward(out_dict["down1"])
+        out_dict["down3"] = self._down3.forward(out_dict["down2"])
+        out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
+        return out_dict
+
+
+class EncoderUnet(nn.Layer):
+    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+                 act, act_attr):
+        super(EncoderUnet, self).__init__()
+        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+        self._in_conv = SNConv(
+            name=name + "_in_conv",
+            in_channels=in_channels,
+            out_channels=encode_dim,
+            kernel_size=7,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down1 = SNConv(
+            name=name + "_down1",
+            in_channels=encode_dim,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down2 = SNConv(
+            name=name + "_down2",
+            in_channels=encode_dim * 2,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down3 = SNConv(
+            name=name + "_down3",
+            in_channels=encode_dim * 2,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._down4 = SNConv(
+            name=name + "_down4",
+            in_channels=encode_dim * 2,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._up1 = SNConvTranspose(
+            name=name + "_up1",
+            in_channels=encode_dim * 2,
+            out_channels=encode_dim * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+        self._up2 = SNConvTranspose(
+            name=name + "_up2",
+            in_channels=encode_dim * 4,
+            out_channels=encode_dim * 4,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act=act,
+            act_attr=act_attr)
+
+    def forward(self, x):
+        output_dict = dict()
+        x = self._pad2d(x)
+        output_dict['in_conv'] = self._in_conv.forward(x)
+        output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
+        output_dict['down2'] = self._down2.forward(output_dict['down1'])
+        output_dict['down3'] = self._down3.forward(output_dict['down2'])
+        output_dict['down4'] = self._down4.forward(output_dict['down3'])
+        output_dict['up1'] = self._up1.forward(output_dict['down4'])
+        output_dict['up2'] = self._up2.forward(
+            paddle.concat(
+                (output_dict['down3'], output_dict['up1']), axis=1))
+        output_dict['concat'] = paddle.concat(
+            (output_dict['down2'], output_dict['up2']), axis=1)
+        return output_dict
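EncoderUnet's skip connections make the channel counts easy to mispredict; as an illustrative check (plain Python, encode_dim = 64 to match the configs in this commit), the concat arithmetic works out as follows:

```python
d = 64  # encode_dim; 64 is what the configs in this commit use

# Every down block in EncoderUnet outputs 2*d channels, so after
# paddle.concat along the channel axis:
up2_in = 2 * d + 2 * d      # down3 concat up1 -> _up2's in_channels
concat_out = 2 * d + 4 * d  # down2 concat up2 -> output_dict['concat']

assert up2_in == 4 * d      # matches _up2(in_channels=encode_dim * 4)
print(up2_in, concat_out)   # -> 256 384
```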
Rotate/StyleText/arch/spectral_norm.py ADDED
@@ -0,0 +1,150 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def normal_(x, mean=0., std=1.):
+    temp_value = paddle.normal(mean, std, shape=x.shape)
+    x.set_value(temp_value)
+    return x
+
+
+class SpectralNorm(object):
+    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+        self.name = name
+        self.dim = dim
+        if n_power_iterations <= 0:
+            raise ValueError('Expected n_power_iterations to be positive, but '
+                             'got n_power_iterations={}'.format(
+                                 n_power_iterations))
+        self.n_power_iterations = n_power_iterations
+        self.eps = eps
+
+    def reshape_weight_to_matrix(self, weight):
+        weight_mat = weight
+        if self.dim != 0:
+            # transpose dim to front
+            weight_mat = weight_mat.transpose([
+                self.dim,
+                * [d for d in range(weight_mat.dim()) if d != self.dim]
+            ])
+
+        height = weight_mat.shape[0]
+
+        return weight_mat.reshape([height, -1])
+
+    def compute_weight(self, module, do_power_iteration):
+        weight = getattr(module, self.name + '_orig')
+        u = getattr(module, self.name + '_u')
+        v = getattr(module, self.name + '_v')
+        weight_mat = self.reshape_weight_to_matrix(weight)
+
+        if do_power_iteration:
+            with paddle.no_grad():
+                for _ in range(self.n_power_iterations):
+                    v.set_value(
+                        F.normalize(
+                            paddle.matmul(
+                                weight_mat,
+                                u,
+                                transpose_x=True,
+                                transpose_y=False),
+                            axis=0,
+                            epsilon=self.eps, ))
+
+                    u.set_value(
+                        F.normalize(
+                            paddle.matmul(weight_mat, v),
+                            axis=0,
+                            epsilon=self.eps, ))
+                if self.n_power_iterations > 0:
+                    u = u.clone()
+                    v = v.clone()
+
+        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
+        weight = weight / sigma
+        return weight
+
+    def remove(self, module):
+        with paddle.no_grad():
+            weight = self.compute_weight(module, do_power_iteration=False)
+        delattr(module, self.name)
+        delattr(module, self.name + '_u')
+        delattr(module, self.name + '_v')
+        delattr(module, self.name + '_orig')
+
+        module.add_parameter(self.name, weight.detach())
+
+    def __call__(self, module, inputs):
+        setattr(
+            module,
+            self.name,
+            self.compute_weight(
+                module, do_power_iteration=module.training))
+
+    @staticmethod
+    def apply(module, name, n_power_iterations, dim, eps):
+        for k, hook in module._forward_pre_hooks.items():
+            if isinstance(hook, SpectralNorm) and hook.name == name:
+                raise RuntimeError(
+                    "Cannot register two spectral_norm hooks on "
+                    "the same parameter {}".format(name))
+
+        fn = SpectralNorm(name, n_power_iterations, dim, eps)
+        weight = module._parameters[name]
+
+        with paddle.no_grad():
+            weight_mat = fn.reshape_weight_to_matrix(weight)
+            h, w = weight_mat.shape
+
+            # randomly initialize u and v
+            u = module.create_parameter([h])
+            u = normal_(u, 0., 1.)
+            v = module.create_parameter([w])
+            v = normal_(v, 0., 1.)
+            u = F.normalize(u, axis=0, epsilon=fn.eps)
+            v = F.normalize(v, axis=0, epsilon=fn.eps)
+
+        # delete fn.name from parameters, otherwise you can not set the attribute
+        del module._parameters[fn.name]
+        module.add_parameter(fn.name + "_orig", weight)
+        # still need to assign weight back as fn.name because all sorts of
+        # things may assume that it exists, e.g., when initializing weights.
+        # However, we can't directly assign as it could be a Parameter and
+        # would get added as a parameter. Instead, we register weight * 1.0 as
+        # a plain attribute.
+        setattr(module, fn.name, weight * 1.0)
+        module.register_buffer(fn.name + "_u", u)
+        module.register_buffer(fn.name + "_v", v)
+
+        module.register_forward_pre_hook(fn)
+        return fn
+
+
+def spectral_norm(module,
+                  name='weight',
+                  n_power_iterations=1,
+                  eps=1e-12,
+                  dim=None):
+
+    if dim is None:
+        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
+                               nn.Conv3DTranspose, nn.Linear)):
+            dim = 1
+        else:
+            dim = 0
+    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+    return module
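The core of SpectralNorm.compute_weight is power iteration: u and v converge to the leading singular vectors of the reshaped weight matrix, and sigma = u·(W v) approaches the largest singular value. A dependency-free sketch of that loop (plain Python standing in for paddle.matmul / F.normalize, on a matrix whose spectral norm is 3 by construction):

```python
import math

# Power iteration as in SpectralNorm.compute_weight, in plain Python.
W = [[3.0, 0.0],
     [0.0, 1.0]]          # singular values are 3 and 1

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def matvec_t(M, x):       # M^T x, i.e. matmul with transpose_x=True
    return [sum(M[i][j] * x[i] for i in range(len(M))) for j in range(len(M[0]))]

def normalize(x, eps=1e-12):
    n = math.sqrt(sum(v * v for v in x)) + eps
    return [v / n for v in x]

u = [1.0, 1.0]
for _ in range(30):       # n_power_iterations
    v = normalize(matvec_t(W, u))
    u = normalize(matvec(W, v))
sigma = sum(ui * wi for ui, wi in zip(u, matvec(W, v)))
print(round(sigma, 6))    # -> 3.0, the spectral norm of W
```

Dividing the weight by this sigma (the last lines of compute_weight) is what constrains each layer's Lipschitz constant.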
Rotate/StyleText/arch/style_text_rec.py ADDED
@@ -0,0 +1,285 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import MiddleNet, ResBlock
+from arch.encoder import Encoder
+from arch.decoder import Decoder, DecoderUnet, SingleDecoder
+from utils.load_params import load_dygraph_pretrain
+from utils.logging import get_logger
+
+
+class StyleTextRec(nn.Layer):
+    def __init__(self, config):
+        super(StyleTextRec, self).__init__()
+        self.logger = get_logger()
+        self.text_generator = TextGenerator(config["Predictor"][
+            "text_generator"])
+        self.bg_generator = BgGeneratorWithMask(config["Predictor"][
+            "bg_generator"])
+        self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
+            "fusion_generator"])
+        bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
+        text_generator_pretrain = config["Predictor"]["text_generator"][
+            "pretrain"]
+        fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
+            "pretrain"]
+        load_dygraph_pretrain(
+            self.bg_generator,
+            self.logger,
+            path=bg_generator_pretrain,
+            load_static_weights=False)
+        load_dygraph_pretrain(
+            self.text_generator,
+            self.logger,
+            path=text_generator_pretrain,
+            load_static_weights=False)
+        load_dygraph_pretrain(
+            self.fusion_generator,
+            self.logger,
+            path=fusion_generator_pretrain,
+            load_static_weights=False)
+
+    def forward(self, style_input, text_input):
+        text_gen_output = self.text_generator.forward(style_input, text_input)
+        fake_text = text_gen_output["fake_text"]
+        fake_sk = text_gen_output["fake_sk"]
+        bg_gen_output = self.bg_generator.forward(style_input)
+        bg_encode_feature = bg_gen_output["bg_encode_feature"]
+        bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
+        bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
+        fake_bg = bg_gen_output["fake_bg"]
+
+        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
+        fake_fusion = fusion_gen_output["fake_fusion"]
+        return {
+            "fake_fusion": fake_fusion,
+            "fake_text": fake_text,
+            "fake_sk": fake_sk,
+            "fake_bg": fake_bg,
+        }
+
+
+class TextGenerator(nn.Layer):
+    def __init__(self, config):
+        super(TextGenerator, self).__init__()
+        name = config["module_name"]
+        encode_dim = config["encode_dim"]
+        norm_layer = config["norm_layer"]
+        conv_block_dropout = config["conv_block_dropout"]
+        conv_block_num = config["conv_block_num"]
+        conv_block_dilation = config["conv_block_dilation"]
+        if norm_layer == "InstanceNorm2D":
+            use_bias = True
+        else:
+            use_bias = False
+        self.encoder_text = Encoder(
+            name=name + "_encoder_text",
+            in_channels=3,
+            encode_dim=encode_dim,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation)
+        self.encoder_style = Encoder(
+            name=name + "_encoder_style",
+            in_channels=3,
+            encode_dim=encode_dim,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation)
+        self.decoder_text = Decoder(
+            name=name + "_decoder_text",
+            encode_dim=encode_dim,
+            out_channels=int(encode_dim / 2),
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation,
+            out_conv_act="Tanh",
+            out_conv_act_attr=None)
+        self.decoder_sk = Decoder(
+            name=name + "_decoder_sk",
+            encode_dim=encode_dim,
+            out_channels=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation,
+            out_conv_act="Sigmoid",
+            out_conv_act_attr=None)
+
+        self.middle = MiddleNet(
+            name=name + "_middle_net",
+            in_channels=int(encode_dim / 2) + 1,
+            mid_channels=encode_dim,
+            out_channels=3,
+            use_bias=use_bias)
+
+    def forward(self, style_input, text_input):
+        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
+        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
+        fake_c_temp = self.decoder_text.forward([text_feature,
+                                                 style_feature])["out_conv"]
+        fake_sk = self.decoder_sk.forward([text_feature,
+                                           style_feature])["out_conv"]
+        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
+        return {"fake_sk": fake_sk, "fake_text": fake_text}
+
+
+class BgGeneratorWithMask(nn.Layer):
+    def __init__(self, config):
+        super(BgGeneratorWithMask, self).__init__()
+        name = config["module_name"]
+        encode_dim = config["encode_dim"]
+        norm_layer = config["norm_layer"]
+        conv_block_dropout = config["conv_block_dropout"]
+        conv_block_num = config["conv_block_num"]
+        conv_block_dilation = config["conv_block_dilation"]
+        self.output_factor = config.get("output_factor", 1.0)
+
+        if norm_layer == "InstanceNorm2D":
+            use_bias = True
+        else:
+            use_bias = False
+
+        self.encoder_bg = Encoder(
+            name=name + "_encoder_bg",
+            in_channels=3,
+            encode_dim=encode_dim,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation)
+
+        self.decoder_bg = SingleDecoder(
+            name=name + "_decoder_bg",
+            encode_dim=encode_dim,
+            out_channels=3,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation,
+            out_conv_act="Tanh",
+            out_conv_act_attr=None)
+
+        self.decoder_mask = Decoder(
+            name=name + "_decoder_mask",
+            encode_dim=encode_dim // 2,
+            out_channels=1,
+            use_bias=use_bias,
+            norm_layer=norm_layer,
+            act="ReLU",
+            act_attr=None,
+            conv_block_dropout=conv_block_dropout,
+            conv_block_num=conv_block_num,
+            conv_block_dilation=conv_block_dilation,
+            out_conv_act="Sigmoid",
+            out_conv_act_attr=None)
+
+        self.middle = MiddleNet(
+            name=name + "_middle_net",
+            in_channels=3 + 1,
+            mid_channels=encode_dim,
+            out_channels=3,
+            use_bias=use_bias)
+
+    def forward(self, style_input):
+        encode_bg_output = self.encoder_bg(style_input)
+        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
+                                           encode_bg_output["down2"],
+                                           encode_bg_output["down1"])
+
+        fake_c_temp = decode_bg_output["out_conv"]
+        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
+            "res_blocks"])["out_conv"]
+        fake_bg = self.middle(
+            paddle.concat(
+                (fake_c_temp, fake_bg_mask), axis=1))
+        return {
+            "bg_encode_feature": encode_bg_output["res_blocks"],
+            "bg_decode_feature1": decode_bg_output["up1"],
+            "bg_decode_feature2": decode_bg_output["up2"],
+            "fake_bg": fake_bg,
+            "fake_bg_mask": fake_bg_mask,
+        }
+
+
+class FusionGeneratorSimple(nn.Layer):
+    def __init__(self, config):
+        super(FusionGeneratorSimple, self).__init__()
+        name = config["module_name"]
+        encode_dim = config["encode_dim"]
+        norm_layer = config["norm_layer"]
+        conv_block_dropout = config["conv_block_dropout"]
+        conv_block_dilation = config["conv_block_dilation"]
+        if norm_layer == "InstanceNorm2D":
+            use_bias = True
+        else:
+            use_bias = False
+
+        self._conv = nn.Conv2D(
+            in_channels=6,
+            out_channels=encode_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=1,
+            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
+            bias_attr=False)
+
+        self._res_block = ResBlock(
+            name="{}_conv_block".format(name),
+            channels=encode_dim,
+            norm_layer=norm_layer,
+            use_dropout=conv_block_dropout,
+            use_dilation=conv_block_dilation,
+            use_bias=use_bias)
+
+        self._reduce_conv = nn.Conv2D(
+            in_channels=encode_dim,
+            out_channels=3,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=1,
+            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
+            bias_attr=False)
+
+    def forward(self, fake_text, fake_bg):
+        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
+        fake_concat_tmp = self._conv(fake_concat)
+        output_res = self._res_block(fake_concat_tmp)
+        fake_fusion = self._reduce_conv(output_res)
+        return {"fake_fusion": fake_fusion}
Rotate/StyleText/configs/config.yml ADDED
@@ -0,0 +1,54 @@
+Global:
+  output_num: 10
+  output_dir: output_data
+  use_gpu: false
+  image_height: 32
+  image_width: 320
+TextDrawer:
+  fonts:
+    en: fonts/en_standard.ttf
+    ch: fonts/ch_standard.ttf
+    ko: fonts/ko_standard.ttf
+Predictor:
+  method: StyleTextRecPredictor
+  algorithm: StyleTextRec
+  scale: 0.00392156862745098
+  mean:
+  - 0.5
+  - 0.5
+  - 0.5
+  std:
+  - 0.5
+  - 0.5
+  - 0.5
+  expand_result: false
+  bg_generator:
+    pretrain: style_text_models/bg_generator
+    module_name: bg_generator
+    generator_type: BgGeneratorWithMask
+    encode_dim: 64
+    norm_layer: null
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+    output_factor: 1.05
+  text_generator:
+    pretrain: style_text_models/text_generator
+    module_name: text_generator
+    generator_type: TextGenerator
+    encode_dim: 64
+    norm_layer: InstanceNorm2D
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+  fusion_generator:
+    pretrain: style_text_models/fusion_generator
+    module_name: fusion_generator
+    generator_type: FusionGeneratorSimple
+    encode_dim: 64
+    norm_layer: null
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+Writer:
+  method: SimpleWriter
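For reference, the Predictor's scale of 0.00392156862745098 is 1/255, so together with mean and std of 0.5 the preprocessing maps uint8 pixels from [0, 255] to [-1, 1], which matches the Tanh output activations configured for the generators. A minimal sketch of that normalization:

```python
# x -> (x * scale - mean) / std, with the values from this config.
scale, mean, std = 1 / 255, 0.5, 0.5

def normalize_pixel(p):
    return (p * scale - mean) / std

print(normalize_pixel(0), normalize_pixel(255))  # -> -1.0 1.0
```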
Rotate/StyleText/configs/dataset_config.yml ADDED
@@ -0,0 +1,64 @@
+Global:
+  output_num: 10
+  output_dir: output_data
+  use_gpu: false
+  image_height: 32
+  image_width: 320
+  standard_font: fonts/en_standard.ttf
+TextDrawer:
+  fonts:
+    en: fonts/en_standard.ttf
+    ch: fonts/ch_standard.ttf
+    ko: fonts/ko_standard.ttf
+StyleSampler:
+  method: DatasetSampler
+  image_home: examples
+  label_file: examples/image_list.txt
+  with_label: true
+CorpusGenerator:
+  method: FileCorpus
+  language: ch
+  corpus_file: examples/corpus/example.txt
+Predictor:
+  method: StyleTextRecPredictor
+  algorithm: StyleTextRec
+  scale: 0.00392156862745098
+  mean:
+  - 0.5
+  - 0.5
+  - 0.5
+  std:
+  - 0.5
+  - 0.5
+  - 0.5
+  expand_result: false
+  bg_generator:
+    pretrain: style_text_models/bg_generator
+    module_name: bg_generator
+    generator_type: BgGeneratorWithMask
+    encode_dim: 64
+    norm_layer: null
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+    output_factor: 1.05
+  text_generator:
+    pretrain: style_text_models/text_generator
+    module_name: text_generator
+    generator_type: TextGenerator
+    encode_dim: 64
+    norm_layer: InstanceNorm2D
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+  fusion_generator:
+    pretrain: style_text_models/fusion_generator
+    module_name: fusion_generator
+    generator_type: FusionGeneratorSimple
+    encode_dim: 64
+    norm_layer: null
+    conv_block_num: 4
+    conv_block_dropout: false
+    conv_block_dilation: true
+Writer:
+  method: SimpleWriter
Rotate/StyleText/engine/__init__.py ADDED
File without changes
Rotate/StyleText/engine/corpus_generators.py ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+from utils.logging import get_logger
+
+
+class FileCorpus(object):
+    def __init__(self, config):
+        self.logger = get_logger()
+        self.logger.info("using FileCorpus")
+
+        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+
+        corpus_file = config["CorpusGenerator"]["corpus_file"]
+        self.language = config["CorpusGenerator"]["language"]
+        with open(corpus_file, 'r') as f:
+            corpus_raw = f.read()
+        self.corpus_list = corpus_raw.split("\n")[:-1]
+        assert len(self.corpus_list) > 0
+        random.shuffle(self.corpus_list)
+        self.index = 0
+
+    def generate(self, corpus_length=0):
+        if self.index >= len(self.corpus_list):
+            self.index = 0
+            random.shuffle(self.corpus_list)
+        corpus = self.corpus_list[self.index]
+        if corpus_length != 0:
+            corpus = corpus[0:corpus_length]
+        if corpus_length > len(corpus):
+            self.logger.warning("generated corpus is shorter than expected.")
+        self.index += 1
+        return self.language, corpus
+
+
+class EnNumCorpus(object):
+    def __init__(self, config):
+        self.logger = get_logger()
+        self.logger.info("using NumberCorpus")
+        self.num_list = "0123456789"
+        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+        self.height = config["Global"]["image_height"]
+        self.max_width = config["Global"]["image_width"]
+
+    def generate(self, corpus_length=0):
+        corpus = ""
+        if corpus_length == 0:
+            corpus_length = random.randint(5, 15)
+        for i in range(corpus_length):
+            if random.random() < 0.2:
+                corpus += "{}".format(random.choice(self.en_char_list))
+            else:
+                corpus += "{}".format(random.choice(self.num_list))
+        return "en", corpus
Rotate/StyleText/engine/predictors.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import cv2
+ import math
+ import paddle
+
+ from arch import style_text_rec
+ from utils.sys_funcs import check_gpu
+ from utils.logging import get_logger
+
+
+ class StyleTextRecPredictor(object):
+     def __init__(self, config):
+         algorithm = config['Predictor']['algorithm']
+         assert algorithm in ["StyleTextRec"
+                              ], "Generator {} not supported.".format(algorithm)
+         use_gpu = config["Global"]['use_gpu']
+         check_gpu(use_gpu)
+         paddle.set_device('gpu' if use_gpu else 'cpu')
+         self.logger = get_logger()
+         self.generator = getattr(style_text_rec, algorithm)(config)
+         self.height = config["Global"]["image_height"]
+         self.width = config["Global"]["image_width"]
+         self.scale = config["Predictor"]["scale"]
+         self.mean = config["Predictor"]["mean"]
+         self.std = config["Predictor"]["std"]
+         self.expand_result = config["Predictor"]["expand_result"]
+
+     def reshape_to_same_height(self, img_list):
+         h = img_list[0].shape[0]
+         for idx in range(1, len(img_list)):
+             new_w = round(1.0 * img_list[idx].shape[1] /
+                           img_list[idx].shape[0] * h)
+             img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
+         return img_list
+
+     def predict_single_image(self, style_input, text_input):
+         style_input = self.rep_style_input(style_input, text_input)
+         tensor_style_input = self.preprocess(style_input)
+         tensor_text_input = self.preprocess(text_input)
+         style_text_result = self.generator.forward(tensor_style_input,
+                                                    tensor_text_input)
+         fake_fusion = self.postprocess(style_text_result["fake_fusion"])
+         fake_text = self.postprocess(style_text_result["fake_text"])
+         fake_sk = self.postprocess(style_text_result["fake_sk"])
+         fake_bg = self.postprocess(style_text_result["fake_bg"])
+         bbox = self.get_text_boundary(fake_text)
+         if bbox:
+             left, right, top, bottom = bbox
+             fake_fusion = fake_fusion[top:bottom, left:right, :]
+             fake_text = fake_text[top:bottom, left:right, :]
+             fake_sk = fake_sk[top:bottom, left:right, :]
+             fake_bg = fake_bg[top:bottom, left:right, :]
+
+         # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
+         return {
+             "fake_fusion": fake_fusion,
+             "fake_text": fake_text,
+             "fake_sk": fake_sk,
+             "fake_bg": fake_bg,
+         }
+
+     def predict(self, style_input, text_input_list):
+         if not isinstance(text_input_list, (tuple, list)):
+             return self.predict_single_image(style_input, text_input_list)
+
+         synth_result_list = []
+         for text_input in text_input_list:
+             synth_result = self.predict_single_image(style_input, text_input)
+             synth_result_list.append(synth_result)
+
+         for key in synth_result:
+             res = [r[key] for r in synth_result_list]
+             res = self.reshape_to_same_height(res)
+             synth_result[key] = np.concatenate(res, axis=1)
+         return synth_result
+
+     def preprocess(self, img):
+         img = (img.astype('float32') * self.scale - self.mean) / self.std
+         img_height, img_width, channel = img.shape
+         assert channel == 3, "Please use an RGB image."
+         ratio = img_width / float(img_height)
+         if math.ceil(self.height * ratio) > self.width:
+             resized_w = self.width
+         else:
+             resized_w = int(math.ceil(self.height * ratio))
+         img = cv2.resize(img, (resized_w, self.height))
+
+         new_img = np.zeros([self.height, self.width, 3]).astype('float32')
+         new_img[:, 0:resized_w, :] = img
+         img = new_img.transpose((2, 0, 1))
+         img = img[np.newaxis, :, :, :]
+         return paddle.to_tensor(img)
+
+     def postprocess(self, tensor):
+         img = tensor.numpy()[0]
+         img = img.transpose((1, 2, 0))
+         img = (img * self.std + self.mean) / self.scale
+         img = np.maximum(img, 0.0)
+         img = np.minimum(img, 255.0)
+         img = img.astype('uint8')
+         return img
+
+     def rep_style_input(self, style_input, text_input):
+         rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
+                       (style_input.shape[1] / style_input.shape[0])) + 1
+         style_input = np.tile(style_input, reps=[1, rep_num, 1])
+         max_width = int(self.width / self.height * style_input.shape[0])
+         style_input = style_input[:, :max_width, :]
+         return style_input
+
+     def get_text_boundary(self, text_img):
+         img_height = text_img.shape[0]
+         img_width = text_img.shape[1]
+         bounder = 3
+         text_canny_img = cv2.Canny(text_img, 10, 20)
+         edge_num_h = text_canny_img.sum(axis=0)
+         no_zero_list_h = np.where(edge_num_h > 0)[0]
+         edge_num_w = text_canny_img.sum(axis=1)
+         no_zero_list_w = np.where(edge_num_w > 0)[0]
+         if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
+             return None
+         left = max(no_zero_list_h[0] - bounder, 0)
+         right = min(no_zero_list_h[-1] + bounder, img_width)
+         top = max(no_zero_list_w[0] - bounder, 0)
+         bottom = min(no_zero_list_w[-1] + bounder, img_height)
+         return [left, right, top, bottom]
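The `preprocess` step in `predictors.py` normalizes an image, resizes it to a fixed height while keeping its aspect ratio, and right-pads the width with zeros. The geometry of that step can be sketched without Paddle or OpenCV; the function name `resize_and_pad` and the nearest-neighbour index sampling below are illustrative stand-ins, not the project's API:

```python
import math
import numpy as np

def resize_and_pad(img, height=32, width=320):
    """Resize to a fixed height (keeping aspect ratio) and right-pad
    with zeros to a fixed width, mirroring the shape logic of
    StyleTextRecPredictor.preprocess (numpy-only sketch)."""
    img_h, img_w = img.shape[:2]
    ratio = img_w / float(img_h)
    # clamp the resized width so the canvas is never exceeded
    resized_w = min(int(math.ceil(height * ratio)), width)
    # nearest-neighbour resize via index sampling (stand-in for cv2.resize)
    ys = np.linspace(0, img_h - 1, height).astype(int)
    xs = np.linspace(0, img_w - 1, max(resized_w, 1)).astype(int)
    resized = img[ys][:, xs]
    canvas = np.zeros((height, width, 3), dtype=np.float32)
    canvas[:, :resized_w, :] = resized
    return canvas, resized_w
```

A 16x64 input at target height 32 resizes to width 128 and is padded out to the full canvas width; anything wider than the canvas is clamped.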
Rotate/StyleText/engine/style_samplers.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import random
+ import cv2
+
+
+ class DatasetSampler(object):
+     def __init__(self, config):
+         self.image_home = config["StyleSampler"]["image_home"]
+         label_file = config["StyleSampler"]["label_file"]
+         self.dataset_with_label = config["StyleSampler"]["with_label"]
+         self.height = config["Global"]["image_height"]
+         self.index = 0
+         with open(label_file, "r") as f:
+             label_raw = f.read()
+             self.path_label_list = label_raw.split("\n")[:-1]
+         assert len(self.path_label_list) > 0
+         random.shuffle(self.path_label_list)
+
+     def sample(self):
+         if self.index >= len(self.path_label_list):
+             random.shuffle(self.path_label_list)
+             self.index = 0
+         if self.dataset_with_label:
+             path_label = self.path_label_list[self.index]
+             rel_image_path, label = path_label.split('\t')
+         else:
+             rel_image_path = self.path_label_list[self.index]
+             label = None
+         img_path = "{}/{}".format(self.image_home, rel_image_path)
+         image = cv2.imread(img_path)
+         origin_height = image.shape[0]
+         ratio = self.height / origin_height
+         width = int(image.shape[1] * ratio)
+         height = int(image.shape[0] * ratio)
+         image = cv2.resize(image, (width, height))
+
+         self.index += 1
+         if label:
+             return {"image": image, "label": label}
+         else:
+             return {"image": image}
+
+
+ def duplicate_image(image, width):
+     image_width = image.shape[1]
+     dup_num = width // image_width + 1
+     image = np.tile(image, reps=[1, dup_num, 1])
+     cropped_image = image[:, :width, :]
+     return cropped_image
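The module-level `duplicate_image` helper in `style_samplers.py` tiles a style patch horizontally until it covers a target width, then crops to exactly that width. A standalone usage sketch (the function is restated here so it runs without the repo):

```python
import numpy as np

def duplicate_image(image, width):
    """Tile the image horizontally until it is at least `width` wide,
    then crop to exactly `width` (same logic as in style_samplers.py)."""
    dup_num = width // image.shape[1] + 1
    tiled = np.tile(image, reps=[1, dup_num, 1])
    return tiled[:, :width, :]

# a 2x3 RGB patch tiled out to width 8: patch | patch | first 2 cols
patch = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
wide = duplicate_image(patch, 8)
```

The crop means the last repetition may be partial, which is fine for a background texture.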
Rotate/StyleText/engine/synthesisers.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import numpy as np
+ import cv2
+
+ from utils.config import ArgsParser, load_config, override_config
+ from utils.logging import get_logger
+ from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
+
+
+ class ImageSynthesiser(object):
+     def __init__(self):
+         self.FLAGS = ArgsParser().parse_args()
+         self.config = load_config(self.FLAGS.config)
+         self.config = override_config(self.config, options=self.FLAGS.override)
+         self.output_dir = self.config["Global"]["output_dir"]
+         if not os.path.exists(self.output_dir):
+             os.mkdir(self.output_dir)
+         self.logger = get_logger(
+             log_file='{}/predict.log'.format(self.output_dir))
+
+         self.text_drawer = text_drawers.StdTextDrawer(self.config)
+
+         predictor_method = self.config["Predictor"]["method"]
+         assert predictor_method is not None
+         self.predictor = getattr(predictors, predictor_method)(self.config)
+
+     def synth_image(self, corpus, style_input, language="en"):
+         corpus_list, text_input_list = self.text_drawer.draw_text(
+             corpus, language, style_input_width=style_input.shape[1])
+         synth_result = self.predictor.predict(style_input, text_input_list)
+         return synth_result
+
+
+ class DatasetSynthesiser(ImageSynthesiser):
+     def __init__(self):
+         super(DatasetSynthesiser, self).__init__()
+         self.tag = self.FLAGS.tag
+         self.output_num = self.config["Global"]["output_num"]
+         corpus_generator_method = self.config["CorpusGenerator"]["method"]
+         self.corpus_generator = getattr(corpus_generators,
+                                         corpus_generator_method)(self.config)
+
+         style_sampler_method = self.config["StyleSampler"]["method"]
+         assert style_sampler_method is not None
+         self.style_sampler = style_samplers.DatasetSampler(self.config)
+         self.writer = writers.SimpleWriter(self.config, self.tag)
+
+     def synth_dataset(self):
+         for i in range(self.output_num):
+             style_data = self.style_sampler.sample()
+             style_input = style_data["image"]
+             corpus_language, text_input_label = self.corpus_generator.generate()
+             text_input_label_list, text_input_list = self.text_drawer.draw_text(
+                 text_input_label,
+                 corpus_language,
+                 style_input_width=style_input.shape[1])
+
+             text_input_label = "".join(text_input_label_list)
+
+             synth_result = self.predictor.predict(style_input, text_input_list)
+             fake_fusion = synth_result["fake_fusion"]
+             self.writer.save_image(fake_fusion, text_input_label)
+         self.writer.save_label()
+         self.writer.merge_label()
Rotate/StyleText/engine/text_drawers.py ADDED
@@ -0,0 +1,85 @@
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import cv2
+ from utils.logging import get_logger
+
+
+ class StdTextDrawer(object):
+     def __init__(self, config):
+         self.logger = get_logger()
+         self.max_width = config["Global"]["image_width"]
+         self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+         self.height = config["Global"]["image_height"]
+         self.font_dict = {}
+         self.load_fonts(config["TextDrawer"]["fonts"])
+         self.support_languages = list(self.font_dict)
+
+     def load_fonts(self, fonts_config):
+         for language in fonts_config:
+             font_path = fonts_config[language]
+             font_height = self.get_valid_height(font_path)
+             font = ImageFont.truetype(font_path, font_height)
+             self.font_dict[language] = font
+
+     def get_valid_height(self, font_path):
+         font = ImageFont.truetype(font_path, self.height - 4)
+         left, top, right, bottom = font.getbbox(self.char_list)
+         _, font_height = right - left, bottom - top
+         if font_height <= self.height - 4:
+             return self.height - 4
+         else:
+             return int((self.height - 4)**2 / font_height)
+
+     def draw_text(self,
+                   corpus,
+                   language="en",
+                   crop=True,
+                   style_input_width=None):
+         if language not in self.support_languages:
+             self.logger.warning(
+                 "language {} not supported, use en instead.".format(language))
+             language = "en"
+         if crop:
+             width = min(self.max_width, len(corpus) * self.height) + 4
+         else:
+             width = len(corpus) * self.height + 4
+
+         if style_input_width is not None:
+             width = min(width, style_input_width)
+
+         corpus_list = []
+         text_input_list = []
+
+         while len(corpus) != 0:
+             bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
+             draw = ImageDraw.Draw(bg)
+             char_x = 2
+             font = self.font_dict[language]
+             i = 0
+             while i < len(corpus):
+                 char_i = corpus[i]
+                 # font.getsize was removed in Pillow 10; getlength gives the advance width
+                 char_size = int(font.getlength(char_i))
+                 # split when the next char would overflow the canvas and index is not 0
+                 # (at least one char should be written on the image)
+                 if char_x + char_size >= width and i != 0:
+                     text_input = np.array(bg).astype(np.uint8)
+                     text_input = text_input[:, 0:char_x, :]
+
+                     corpus_list.append(corpus[0:i])
+                     text_input_list.append(text_input)
+                     corpus = corpus[i:]
+                     i = 0
+                     break
+                 draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
+                 char_x += char_size
+
+                 i += 1
+             # the whole text is shorter than the style input
+             if i == len(corpus):
+                 text_input = np.array(bg).astype(np.uint8)
+                 text_input = text_input[:, 0:char_x, :]
+
+                 corpus_list.append(corpus[0:i])
+                 text_input_list.append(text_input)
+                 break
+
+         return corpus_list, text_input_list
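`StdTextDrawer.draw_text` splits the corpus greedily: it keeps appending characters until the next one would overflow the canvas, emits a chunk, and starts over with the remainder. The splitting logic alone can be sketched with a fixed per-character advance width standing in for the real font metrics (`split_corpus`, `char_width`, and `margin` are illustrative assumptions, not the project's API):

```python
def split_corpus(corpus, canvas_width, char_width=10, margin=2):
    """Greedy split mirroring StdTextDrawer.draw_text: append characters
    until the next one would overflow the canvas, then start a new chunk.
    char_width is a constant stand-in for font.getlength(char)."""
    chunks = []
    while corpus:
        x = margin
        i = 0
        while i < len(corpus):
            # at least one char always goes on the canvas (i != 0 guard)
            if x + char_width >= canvas_width and i != 0:
                break
            x += char_width
            i += 1
        chunks.append(corpus[:i])
        corpus = corpus[i:]
    return chunks
```

With a 42-pixel canvas and 10-pixel characters, `"abcdefgh"` splits into three chunks of at most three characters, and joining the chunks always recovers the original corpus.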
Rotate/StyleText/engine/writers.py ADDED
@@ -0,0 +1,71 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import cv2
+ import glob
+
+ from utils.logging import get_logger
+
+
+ class SimpleWriter(object):
+     def __init__(self, config, tag):
+         self.logger = get_logger()
+         self.output_dir = config["Global"]["output_dir"]
+         self.counter = 0
+         self.label_dict = {}
+         self.tag = tag
+         self.label_file_index = 0
+
+     def save_image(self, image, text_input_label):
+         image_home = os.path.join(self.output_dir, "images", self.tag)
+         if not os.path.exists(image_home):
+             os.makedirs(image_home)
+
+         image_path = os.path.join(image_home, "{}.png".format(self.counter))
+         # TODO: support resuming an interrupted synthesis run
+         cv2.imwrite(image_path, image)
+         self.logger.info("generate image: {}".format(image_path))
+
+         image_name = os.path.join(self.tag, "{}.png".format(self.counter))
+         self.label_dict[image_name] = text_input_label
+
+         self.counter += 1
+         if not self.counter % 100:
+             self.save_label()
+
+     def save_label(self):
+         label_raw = ""
+         label_home = os.path.join(self.output_dir, "label")
+         if not os.path.exists(label_home):
+             os.mkdir(label_home)
+         for image_path in self.label_dict:
+             label = self.label_dict[image_path]
+             label_raw += "{}\t{}\n".format(image_path, label)
+         label_file_path = os.path.join(label_home,
+                                        "{}_label.txt".format(self.tag))
+         with open(label_file_path, "w") as f:
+             f.write(label_raw)
+         self.label_file_index += 1
+
+     def merge_label(self):
+         label_raw = ""
+         label_file_regex = os.path.join(self.output_dir, "label",
+                                         "*_label.txt")
+         label_file_list = glob.glob(label_file_regex)
+         for label_file_i in label_file_list:
+             with open(label_file_i, "r") as f:
+                 label_raw += f.read()
+         label_file_path = os.path.join(self.output_dir, "label.txt")
+         with open(label_file_path, "w") as f:
+             f.write(label_raw)
Rotate/StyleText/examples/corpus/example.txt ADDED
@@ -0,0 +1,2 @@
+ Paddle
+ 飞桨文字识别
Rotate/StyleText/examples/image_list.txt ADDED
@@ -0,0 +1,2 @@
+ style_images/1.jpg NEATNESS
+ style_images/2.jpg 锁店君和宾馆
Rotate/StyleText/tools/__init__.py ADDED
File without changes
Rotate/StyleText/tools/synth_dataset.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import sys
+
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(__dir__)
+ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+ from engine.synthesisers import DatasetSynthesiser
+
+
+ def synth_dataset():
+     dataset_synthesiser = DatasetSynthesiser()
+     dataset_synthesiser.synth_dataset()
+
+
+ if __name__ == '__main__':
+     synth_dataset()
Rotate/StyleText/tools/synth_image.py ADDED
@@ -0,0 +1,82 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import cv2
+ import sys
+ import glob
+
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(__dir__)
+ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+ from utils.config import ArgsParser
+ from engine.synthesisers import ImageSynthesiser
+
+
+ def synth_image():
+     args = ArgsParser().parse_args()
+     image_synthesiser = ImageSynthesiser()
+     style_image_path = args.style_image
+     img = cv2.imread(style_image_path)
+     text_corpus = args.text_corpus
+     language = args.language
+
+     synth_result = image_synthesiser.synth_image(text_corpus, img, language)
+     fake_fusion = synth_result["fake_fusion"]
+     fake_text = synth_result["fake_text"]
+     fake_bg = synth_result["fake_bg"]
+     cv2.imwrite("fake_fusion.jpg", fake_fusion)
+     cv2.imwrite("fake_text.jpg", fake_text)
+     cv2.imwrite("fake_bg.jpg", fake_bg)
+
+
+ def batch_synth_images():
+     image_synthesiser = ImageSynthesiser()
+
+     corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
+     style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
+     save_path = "./output_data/"
+     corpus_list = []
+     with open(corpus_file, "rb") as fin:
+         lines = fin.readlines()
+         for line in lines:
+             substr = line.decode("utf-8").strip("\n").split("\t")
+             corpus_list.append(substr)
+     style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
+     corpus_num = len(corpus_list)
+     style_img_num = len(style_img_list)
+     for cno in range(corpus_num):
+         for sno in range(style_img_num):
+             corpus, lang = corpus_list[cno]
+             style_img_path = style_img_list[sno]
+             img = cv2.imread(style_img_path)
+             synth_result = image_synthesiser.synth_image(corpus, img, lang)
+             fake_fusion = synth_result["fake_fusion"]
+             fake_text = synth_result["fake_text"]
+             fake_bg = synth_result["fake_bg"]
+             for tp in range(2):
+                 if tp == 0:
+                     prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
+                 else:
+                     prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
+                 cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
+                 cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
+                 cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
+                 cv2.imwrite("%s_input_style.jpg" % prefix, img)
+             print(cno, corpus_num, sno, style_img_num)
+
+
+ if __name__ == '__main__':
+     # batch_synth_images()
+     synth_image()
Rotate/StyleText/utils/__init__.py ADDED
File without changes
Rotate/StyleText/utils/config.py ADDED
@@ -0,0 +1,224 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import yaml
+ import os
+ from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+ from utils.logging import get_logger
+
+
+ def override(dl, ks, v):
+     """
+     Recursively replace a value in a nested dict or list
+
+     Args:
+         dl(dict or list): dict or list to be replaced
+         ks(list): list of keys
+         v(str): value to set
+     """
+
+     def str2num(v):
+         try:
+             return eval(v)
+         except Exception:
+             return v
+
+     assert isinstance(dl, (list, dict)), ('{} should be a list or a dict'.format(dl))
+     assert len(ks) > 0, ('length of keys should be larger than 0')
+     if isinstance(dl, list):
+         k = str2num(ks[0])
+         if len(ks) == 1:
+             assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
+             dl[k] = str2num(v)
+         else:
+             override(dl[k], ks[1:], v)
+     else:
+         if len(ks) == 1:
+             # assert ks[0] in dl, ('{} does not exist in {}'.format(ks[0], dl))
+             if not ks[0] in dl:
+                 get_logger().warning('A new field ({}) detected!'.format(ks[0]))
+             dl[ks[0]] = str2num(v)
+         else:
+             assert ks[0] in dl, (
+                 '({}) doesn\'t exist in {}, a new dict field is invalid'.
+                 format(ks[0], dl))
+             override(dl[ks[0]], ks[1:], v)
+
+
+ def override_config(config, options=None):
+     """
+     Recursively override the config
+
+     Args:
+         config(dict): dict to be replaced
+         options(list): list of pairs(key0.key1.idx.key2=value)
+             such as: [
+                 'topk=2',
+                 'VALID.transforms.1.ResizeImage.resize_short=300'
+             ]
+
+     Returns:
+         config(dict): replaced config
+     """
+     if options is not None:
+         for opt in options:
+             assert isinstance(opt, str), (
+                 "option({}) should be a str".format(opt))
+             assert "=" in opt, (
+                 "option({}) should contain a ="
+                 "to distinguish between key and value".format(opt))
+             pair = opt.split('=')
+             assert len(pair) == 2, ("there can be only one = in the option")
+             key, value = pair
+             keys = key.split('.')
+             override(config, keys, value)
+
+     return config
+
+
+ class ArgsParser(ArgumentParser):
+     def __init__(self):
+         super(ArgsParser, self).__init__(
+             formatter_class=RawDescriptionHelpFormatter)
+         self.add_argument("-c", "--config", help="configuration file to use")
+         self.add_argument(
+             "-t", "--tag", default="0", help="tag for marking worker")
+         self.add_argument(
+             '-o',
+             '--override',
+             action='append',
+             default=[],
+             help='config options to be overridden')
+         self.add_argument(
+             "--style_image", default="examples/style_images/1.jpg",
+             help="path of the style image")
+         self.add_argument(
+             "--text_corpus", default="PaddleOCR",
+             help="text corpus to be rendered")
+         self.add_argument(
+             "--language", default="en", help="language of the text corpus")
+
+     def parse_args(self, argv=None):
+         args = super(ArgsParser, self).parse_args(argv)
+         assert args.config is not None, \
+             "Please specify --config=configure_file_path."
+         return args
+
+
+ def load_config(file_path):
+     """
+     Load config from yml/yaml file.
+     Args:
+         file_path (str): Path of the config file to be loaded.
+     Returns: config
+     """
+     ext = os.path.splitext(file_path)[1]
+     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+     with open(file_path, 'rb') as f:
+         config = yaml.load(f, Loader=yaml.Loader)
+
+     return config
+
+
+ def gen_config():
+     base_config = {
+         "Global": {
+             "algorithm": "SRNet",
+             "use_gpu": True,
+             "start_epoch": 1,
+             "stage1_epoch_num": 100,
+             "stage2_epoch_num": 100,
+             "log_smooth_window": 20,
+             "print_batch_step": 2,
+             "save_model_dir": "./output/SRNet",
+             "use_visualdl": False,
+             "save_epoch_step": 10,
+             "vgg_pretrain": "./pretrained/VGG19_pretrained",
+             "vgg_load_static_pretrain": True
+         },
+         "Architecture": {
+             "model_type": "data_aug",
+             "algorithm": "SRNet",
+             "net_g": {
+                 "name": "srnet_net_g",
+                 "encode_dim": 64,
+                 "norm": "batch",
+                 "use_dropout": False,
+                 "init_type": "xavier",
+                 "init_gain": 0.02,
+                 "use_dilation": 1
+             },
+             # input_nc, ndf, netD,
+             # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
+             "bg_discriminator": {
+                 "name": "srnet_bg_discriminator",
+                 "input_nc": 6,
+                 "ndf": 64,
+                 "netD": "basic",
+                 "norm": "none",
+                 "init_type": "xavier",
+             },
+             "fusion_discriminator": {
+                 "name": "srnet_fusion_discriminator",
+                 "input_nc": 6,
+                 "ndf": 64,
+                 "netD": "basic",
+                 "norm": "none",
+                 "init_type": "xavier",
+             }
+         },
+         "Loss": {
+             "lamb": 10,
+             "perceptual_lamb": 1,
+             "muvar_lamb": 50,
+             "style_lamb": 500
+         },
+         "Optimizer": {
+             "name": "Adam",
+             "learning_rate": {
+                 "name": "lambda",
+                 "lr": 0.0002,
+                 "lr_decay_iters": 50
+             },
+             "beta1": 0.5,
+             "beta2": 0.999,
+         },
+         "Train": {
+             "batch_size_per_card": 8,
+             "num_workers_per_card": 4,
+             "dataset": {
+                 "delimiter": "\t",
+                 "data_dir": "/",
+                 "label_file": "tmp/label.txt",
+                 "transforms": [{
+                     "DecodeImage": {
+                         "to_rgb": True,
+                         "to_np": False,
+                         "channel_first": False
+                     }
+                 }, {
+                     "NormalizeImage": {
+                         "scale": 1. / 255.,
+                         "mean": [0.485, 0.456, 0.406],
+                         "std": [0.229, 0.224, 0.225],
+                         "order": None
+                     }
+                 }, {
+                     "ToCHWImage": None
+                 }]
+             }
+         }
+     }
+     with open("config.yml", "w") as f:
+         yaml.dump(base_config, f)
+
+
+ if __name__ == '__main__':
+     gen_config()
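The `-o key0.key1.idx.key2=value` overrides in `utils/config.py` walk the nested config by dotted keys, converting list indices and numeric values from strings along the way. The core walk can be sketched standalone (`override` here is a minimal re-statement for illustration, not the full implementation with its assertions and warnings):

```python
def override(node, keys, value):
    """Walk dotted keys into nested dicts/lists and set the leaf value
    (a minimal re-statement of override() in utils/config.py)."""
    def str2num(v):
        try:
            return eval(v)  # "300" -> 300, "0.5" -> 0.5, other strings pass through
        except Exception:
            return v

    # list indices arrive as strings ("0") and must become ints
    key = str2num(keys[0]) if isinstance(node, list) else keys[0]
    if len(keys) == 1:
        node[key] = str2num(value)
    else:
        override(node[key], keys[1:], value)


config = {"VALID": {"transforms": [{"Resize": {"short": 256}}]}}
override(config, "VALID.transforms.0.Resize.short".split("."), "300")
```

After the call, the leaf holds the integer 300, which is how `-o VALID.transforms.0.Resize.short=300` takes effect from the command line.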
Rotate/StyleText/utils/load_params.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import paddle
+
+ __all__ = ['load_dygraph_pretrain']
+
+
+ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
+     if not os.path.exists(path + '.pdparams'):
+         raise ValueError("Model pretrain path {} does not "
+                          "exist.".format(path))
+     param_state_dict = paddle.load(path + '.pdparams')
+     model.set_state_dict(param_state_dict)
+     logger.info("load pretrained model from {}".format(path))
+     return
Rotate/StyleText/utils/logging.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import sys
+ import logging
+ import functools
+ import paddle.distributed as dist
+
+ logger_initialized = {}
+
+
+ @functools.lru_cache()
+ def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
+     """Initialize and get a logger by name.
+     If the logger has not been initialized, this method will initialize the
+     logger by adding one or two handlers, otherwise the initialized logger will
+     be directly returned. During initialization, a StreamHandler will always be
+     added. If `log_file` is specified a FileHandler will also be added.
+     Args:
+         name (str): Logger name.
+         log_file (str | None): The log filename. If specified, a FileHandler
+             will be added to the logger.
+         log_level (int): The logger level. Note that only the process of
+             rank 0 is affected, and other processes will set the level to
+             "Error" thus be silent most of the time.
+     Returns:
+         logging.Logger: The expected logger.
+     """
+     logger = logging.getLogger(name)
+     if name in logger_initialized:
+         return logger
+     for logger_name in logger_initialized:
+         if name.startswith(logger_name):
+             return logger
+
+     formatter = logging.Formatter(
+         '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+         datefmt="%Y/%m/%d %H:%M:%S")
+
+     stream_handler = logging.StreamHandler(stream=sys.stdout)
+     stream_handler.setFormatter(formatter)
+     logger.addHandler(stream_handler)
+     if log_file is not None and dist.get_rank() == 0:
+         log_file_folder = os.path.split(log_file)[0]
+         os.makedirs(log_file_folder, exist_ok=True)
+         file_handler = logging.FileHandler(log_file, 'a')
+         file_handler.setFormatter(formatter)
+         logger.addHandler(file_handler)
+     if dist.get_rank() == 0:
+         logger.setLevel(log_level)
+     else:
+         logger.setLevel(logging.ERROR)
+     logger_initialized[name] = True
+     return logger
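The `get_logger` helper above caches loggers with `functools.lru_cache` so that handlers are attached only once per name. A minimal stdlib-only sketch of that deduplication idea, dropping the `paddle.distributed` rank check (assumed unavailable here):

```python
import functools
import logging
import sys


@functools.lru_cache()
def get_logger(name="srnet", log_level=logging.INFO):
    # lru_cache guarantees one logger object per name, so the handler
    # below is attached exactly once even across repeated calls.
    logger = logging.getLogger(name)
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.setLevel(log_level)
    return logger


log_a = get_logger("demo")
log_b = get_logger("demo")  # cached: no second handler is added
```

The repo version additionally checks `logger_initialized` prefixes so child loggers reuse a parent's handlers, and silences non-rank-0 processes in distributed training.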
Rotate/StyleText/utils/math_functions.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import paddle
+
+
+ def compute_mean_covariance(img):
+     batch_size = img.shape[0]
+     channel_num = img.shape[1]
+     height = img.shape[2]
+     width = img.shape[3]
+     num_pixels = height * width
+
+     # batch_size * channel_num * 1 * 1
+     mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
+
+     # batch_size * channel_num * num_pixels
+     img_hat = img - mu.expand_as(img)
+     img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
+     # batch_size * num_pixels * channel_num
+     img_hat_transpose = img_hat.transpose([0, 2, 1])
+     # batch_size * channel_num * channel_num
+     covariance = paddle.bmm(img_hat, img_hat_transpose)
+     covariance = covariance / num_pixels
+
+     return mu, covariance
+
+
+ def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
+     eps = 1e-5
+     intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
+     union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
+         y_pred_cls * training_mask) + eps
+     loss = 1. - (2 * intersection / union)
+     return loss
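The two helpers above are easy to mirror in NumPy, which makes the math checkable: the covariance is the C×C Gram matrix of mean-centred pixels divided by the pixel count, and the dice loss approaches 0 for a perfect masked prediction. A NumPy sketch (not the paddle implementation itself):

```python
import numpy as np


def compute_mean_covariance(img):
    # img: (N, C, H, W). Per-channel mean plus the C x C covariance of
    # pixels, mirroring the paddle version above.
    n, c, h, w = img.shape
    mu = img.mean(axis=(2, 3), keepdims=True)             # (N, C, 1, 1)
    img_hat = (img - mu).reshape(n, c, h * w)             # (N, C, HW)
    cov = img_hat @ img_hat.transpose(0, 2, 1) / (h * w)  # (N, C, C)
    return mu, cov


def dice_coefficient(y_true, y_pred, mask, eps=1e-5):
    # 1 - 2|A∩B| / (|A| + |B|), restricted to the training mask.
    inter = np.sum(y_true * y_pred * mask)
    union = np.sum(y_true * mask) + np.sum(y_pred * mask) + eps
    return 1.0 - 2.0 * inter / union
```

A constant image yields a zero covariance matrix, and a perfect prediction yields a dice loss within `eps` of zero.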
Rotate/StyleText/utils/sys_funcs.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import sys
+ import os
+ import errno
+ import paddle
+
+
+ def get_check_global_params(mode):
+     check_params = [
+         'use_gpu', 'max_text_length', 'image_shape', 'image_shape',
+         'character_type', 'loss_type'
+     ]
+     if mode == "train_eval":
+         check_params = check_params + [
+             'train_batch_size_per_card', 'test_batch_size_per_card'
+         ]
+     elif mode == "test":
+         check_params = check_params + ['test_batch_size_per_card']
+     return check_params
+
+
+ def check_gpu(use_gpu):
+     """
+     Log error and exit when set use_gpu=true in paddlepaddle
+     cpu version.
+     """
+     err = "Config use_gpu cannot be set as true while you are " \
+           "using paddlepaddle cpu version ! \nPlease try: \n" \
+           "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+           "\t2. Set use_gpu as false in config file to run " \
+           "model on CPU"
+     if use_gpu:
+         try:
+             if not paddle.is_compiled_with_cuda():
+                 print(err)
+                 sys.exit(1)
+         except:
+             print("Fail to check gpu state.")
+             sys.exit(1)
+
+
+ def _mkdir_if_not_exist(path, logger):
+     """
+     mkdir if not exists, ignore the exception when multiprocess mkdir together
+     """
+     if not os.path.exists(path):
+         try:
+             os.makedirs(path)
+         except OSError as e:
+             if e.errno == errno.EEXIST and os.path.isdir(path):
+                 logger.warning(
+                     'be happy if some process has already created {}'.format(
+                         path))
+             else:
+                 raise OSError('Failed to mkdir {}'.format(path))
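`_mkdir_if_not_exist` above tolerates the race where several processes create the same directory at once: `EEXIST` on an existing directory is benign, anything else re-raises. A logger-free sketch of the same pattern:

```python
import errno
import os
import tempfile


def mkdir_if_not_exist(path):
    # Same race-tolerant behaviour as _mkdir_if_not_exist above:
    # EEXIST from a concurrent process is ignored, anything else re-raises.
    try:
        os.makedirs(path)
    except OSError as e:
        if not (e.errno == errno.EEXIST and os.path.isdir(path)):
            raise


with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "a", "b")
    mkdir_if_not_exist(target)
    mkdir_if_not_exist(target)  # second call is a no-op, no exception
```

On Python 3 the same effect is available directly via `os.makedirs(path, exist_ok=True)`, as the logging utility in this commit already uses.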
Rotate/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import paddleocr
+ # from .paddleocr import *
+
+ # __version__ = paddleocr.VERSION
+ # __all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
Rotate/ch_PP-OCRv4_det_infer/inference.pdiparams.info ADDED
Binary file (23.6 kB)
Rotate/ch_PP-OCRv4_det_infer/inference.pdmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ad68ed2768fe6c41166a5bc64680cc9f445390acb6528da449a4db2f7b90e14
+ size 166367
Rotate/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d86f5afbfb8cd933a1d0dbbfd8ff2b93ca3eacc6c45f4590a4a2ee107047f6d2
+ size 176
Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1efda1b80e174b4fcb168a035ac96c1af4938892bd86a55f300a6027105d08c
+ size 539978
Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info ADDED
Binary file (18.5 kB)
Rotate/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml ADDED
@@ -0,0 +1,98 @@
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: 100
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/rec_ppocr_v3_rotnet
+   save_epoch_step: 3
+   eval_batch_step: [0, 2000]
+   cal_metric_during_train: true
+   pretrained_model: null
+   checkpoints: null
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_words/ch/word_1.jpg
+   character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+   max_text_length: 25
+   infer_mode: false
+   use_space_char: true
+   save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+   regularizer:
+     name: L2
+     factor: 1.0e-05
+ Architecture:
+   model_type: cls
+   algorithm: CLS
+   Transform: null
+   Backbone:
+     name: MobileNetV1Enhance
+     scale: 0.5
+     last_conv_stride: [1, 2]
+     last_pool_type: avg
+   Neck:
+   Head:
+     name: ClsHead
+     class_dim: 4
+
+ Loss:
+   name: ClsLoss
+   main_indicator: acc
+
+ PostProcess:
+   name: ClsPostProcess
+
+ Metric:
+   name: ClsMetric
+   main_indicator: acc
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data
+     label_file_list:
+       - ./train_data/train_list.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - BaseDataAugmentation:
+       - RandAugment:
+       - SSLRotateResize:
+           image_shape: [3, 48, 320]
+       - KeepKeys:
+           keep_keys: ["image", "label"]
+   loader:
+     collate_fn: "SSLRotateCollate"
+     shuffle: true
+     batch_size_per_card: 32
+     drop_last: true
+     num_workers: 8
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data
+     label_file_list:
+       - ./train_data/val_list.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - SSLRotateResize:
+           image_shape: [3, 48, 320]
+       - KeepKeys:
+           keep_keys: ["image", "label"]
+   loader:
+     collate_fn: "SSLRotateCollate"
+     shuffle: false
+     drop_last: false
+     batch_size_per_card: 64
+     num_workers: 8
+ profiler_options: null
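The rotnet head above has `class_dim: 4`, i.e. one class per candidate rotation. A tiny illustrative decoder for such a head's logits; the angle ordering is an assumption for illustration, not taken from the repo:

```python
def rotation_from_logits(logits, angles=(0, 90, 180, 270)):
    # Map the argmax class index of a 4-way rotation head to an angle.
    # The (0, 90, 180, 270) ordering is assumed, purely illustrative.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return angles[best]
```

For example, logits peaking at index 1 would decode to a 90-degree rotation under this assumed ordering.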
Rotate/configs/cls/cls_mv3.yml ADDED
@@ -0,0 +1,94 @@
+ Global:
+   use_gpu: true
+   epoch_num: 100
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/cls/mv3/
+   save_epoch_step: 3
+   # evaluation is run every 5000 iterations after the 4000th iteration
+   eval_batch_step: [0, 1000]
+   cal_metric_during_train: True
+   pretrained_model:
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_words_en/word_10.png
+   label_list: ['0','180']
+
+ Architecture:
+   model_type: cls
+   algorithm: CLS
+   Transform:
+   Backbone:
+     name: MobileNetV3
+     scale: 0.35
+     model_name: small
+   Neck:
+   Head:
+     name: ClsHead
+     class_dim: 2
+
+ Loss:
+   name: ClsLoss
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: ClsPostProcess
+
+ Metric:
+   name: ClsMetric
+   main_indicator: acc
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/cls
+     label_file_list:
+       - ./train_data/cls/train.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - ClsLabelEncode: # Class handling label
+       - BaseDataAugmentation:
+       - RandAugment:
+       - ClsResizeImg:
+           image_shape: [3, 48, 192]
+       - KeepKeys:
+           keep_keys: ['image', 'label'] # dataloader will return list in this order
+   loader:
+     shuffle: True
+     batch_size_per_card: 512
+     drop_last: True
+     num_workers: 8
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/cls
+     label_file_list:
+       - ./train_data/cls/test.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - ClsLabelEncode: # Class handling label
+       - ClsResizeImg:
+           image_shape: [3, 48, 192]
+       - KeepKeys:
+           keep_keys: ['image', 'label'] # dataloader will return list in this order
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 512
+     num_workers: 4
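With `label_list: ['0','180']` and `class_dim: 2`, the direction classifier above reduces to a binary argmax over the two angles. A hypothetical decode step in that spirit (`cls_postprocess` is an illustrative name, not a repo function):

```python
def cls_postprocess(probs, label_list=("0", "180")):
    # Sketch of the decode step implied by label_list: ['0','180']:
    # pick the argmax class and report its label plus the score.
    idx = max(range(len(probs)), key=lambda i: probs[i])
    return label_list[idx], probs[idx]
```

So a softmax output of `[0.2, 0.8]` would decode to the label `"180"` with confidence 0.8.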
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml ADDED
@@ -0,0 +1,206 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 2
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 5000 iterations after the 4000th iteration
+   eval_batch_step: [3000, 2000]
+   cal_metric_during_train: False
+   pretrained_model:
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+   use_amp: False
+   amp_level: O2
+   amp_dtype: bfloat16
+
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Teacher:
+       pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+       freeze_params: true
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Transform:
+       Backbone:
+         name: ResNet_vd
+         layers: 18
+       Neck:
+         name: DBFPN
+         out_channels: 256
+       Head:
+         name: DBHead
+         k: 50
+     Student:
+       pretrained:
+       freeze_params: false
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: True
+       Neck:
+         name: DBFPN
+         out_channels: 96
+       Head:
+         name: DBHead
+         k: 50
+     Student2:
+       pretrained:
+       freeze_params: false
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Transform:
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: True
+       Neck:
+         name: DBFPN
+         out_channels: 96
+       Head:
+         name: DBHead
+         k: 50
+
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+   - DistillationDilaDBLoss:
+       weight: 1.0
+       model_name_pairs:
+       - ["Student", "Teacher"]
+       - ["Student2", "Teacher"]
+       key: maps
+       balance_loss: true
+       main_loss_type: DiceLoss
+       alpha: 5
+       beta: 10
+       ohem_ratio: 3
+   - DistillationDMLLoss:
+       model_name_pairs:
+       - ["Student", "Student2"]
+       maps_name: "thrink_maps"
+       weight: 1.0
+       # act: None
+       model_name_pairs: ["Student", "Student2"]
+       key: maps
+   - DistillationDBLoss:
+       weight: 1.0
+       model_name_list: ["Student", "Student2"]
+       # key: maps
+       # name: DBLoss
+       balance_loss: true
+       main_loss_type: DiceLoss
+       alpha: 5
+       beta: 10
+       ohem_ratio: 3
+
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name: ["Student", "Student2", "Teacher"]
+   # key: maps
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: "Student"
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - CopyPaste:
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
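The `unclip_ratio: 1.5` in the post-processing above controls how far DB grows each shrunken text polygon back out. In DB implementations the expansion distance handed to pyclipper is `area * ratio / perimeter`; a sketch of just that formula:

```python
def unclip_distance(area, perimeter, unclip_ratio=1.5):
    # DB post-processing grows each detected polygon outward by
    # area * ratio / perimeter (the offset fed to the polygon clipper).
    return area * unclip_ratio / perimeter


# 10 x 4 rectangle with the config's unclip_ratio: 1.5
d = unclip_distance(area=40.0, perimeter=28.0)
```

Larger `unclip_ratio` values produce looser boxes around each detected text region.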
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml ADDED
@@ -0,0 +1,175 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 2
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 5000 iterations after the 4000th iteration
+   eval_batch_step: [3000, 2000]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Student:
+       pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+       freeze_params: false
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: True
+       Neck:
+         name: DBFPN
+         out_channels: 96
+       Head:
+         name: DBHead
+         k: 50
+     Teacher:
+       pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+       freeze_params: true
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Transform:
+       Backbone:
+         name: ResNet_vd
+         layers: 18
+       Neck:
+         name: DBFPN
+         out_channels: 256
+       Head:
+         name: DBHead
+         k: 50
+
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+   - DistillationDilaDBLoss:
+       weight: 1.0
+       model_name_pairs:
+       - ["Student", "Teacher"]
+       key: maps
+       balance_loss: true
+       main_loss_type: DiceLoss
+       alpha: 5
+       beta: 10
+       ohem_ratio: 3
+   - DistillationDBLoss:
+       weight: 1.0
+       model_name_list: ["Student"]
+       name: DBLoss
+       balance_loss: true
+       main_loss_type: DiceLoss
+       alpha: 5
+       beta: 10
+       ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name: ["Student"]
+   key: head_out
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: "Student"
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - CopyPaste:
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
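The `NormalizeImage` transform in these configs scales pixels by `1./255.` and then standardizes each channel with ImageNet statistics in HWC order. A NumPy sketch of that arithmetic (not the PaddleOCR operator itself):

```python
import numpy as np


def normalize_image(img, scale=1.0 / 255.0,
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)):
    # HWC float normalization matching the NormalizeImage parameters
    # above: x -> (x * scale - mean) / std, per channel.
    img = img.astype("float32") * scale
    return (img - np.array(mean)) / np.array(std)
```

A zero (black) pixel therefore maps to `-mean / std` per channel, e.g. about -2.118 in the first channel.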
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml ADDED
@@ -0,0 +1,178 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 2
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 5000 iterations after the 4000th iteration
+   eval_batch_step: [3000, 2000]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Student:
+       pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+       freeze_params: false
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: True
+       Neck:
+         name: DBFPN
+         out_channels: 96
+       Head:
+         name: DBHead
+         k: 50
+     Teacher:
+       pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+       freeze_params: false
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Transform:
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: True
+       Neck:
+         name: DBFPN
+         out_channels: 96
+       Head:
+         name: DBHead
+         k: 50
+
+
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+   - DistillationDMLLoss:
+       model_name_pairs:
+       - ["Student", "Teacher"]
+       maps_name: "thrink_maps"
+       weight: 1.0
+       # act: None
+       model_name_pairs: ["Student", "Teacher"]
+       key: maps
+   - DistillationDBLoss:
+       weight: 1.0
+       model_name_list: ["Student", "Teacher"]
+       # key: maps
+       name: DBLoss
+       balance_loss: true
+       main_loss_type: DiceLoss
+       alpha: 5
+       beta: 10
+       ohem_ratio: 3
+
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name: ["Student", "Teacher"]
+   key: head_out
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: "Student"
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - CopyPaste:
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
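The `lr` block in these configs combines cosine decay with a two-epoch warmup (`name: Cosine`, `learning_rate: 0.001`, `warmup_epoch: 2`). A step-based sketch of that schedule shape; PaddleOCR's own scheduler works per epoch, so the step counts here are illustrative:

```python
import math


def cosine_lr(step, total_steps, base_lr=0.001, warmup_steps=0):
    # Linear warmup up to base_lr, then cosine decay from base_lr to 0.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t))
```

The schedule starts at `base_lr` (or ramps up to it during warmup) and decays smoothly to 0 at the final step.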
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml ADDED
@@ -0,0 +1,132 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 5000 iterations after the 4000th iteration
+   eval_batch_step: [0, 400]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/student.pdparams
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform:
+   Backbone:
+     name: MobileNetV3
+     scale: 0.5
+     model_name: large
+     disable_se: True
+   Neck:
+     name: DBFPN
+     out_channels: 96
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
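`MakeShrinkMap` with `shrink_ratio: 0.4` contracts each text polygon before rendering the probability-map label; the DB-style inward offset is `area * (1 - ratio**2) / perimeter`. A sketch of that offset computation:

```python
def shrink_distance(area, perimeter, shrink_ratio=0.4):
    # DB-style shrink offset: each text polygon is contracted inward by
    # area * (1 - ratio**2) / perimeter before label rendering.
    return area * (1.0 - shrink_ratio ** 2) / perimeter


# 10 x 4 rectangle with the config's shrink_ratio: 0.4
d = shrink_distance(area=40.0, perimeter=28.0)
```

A smaller `shrink_ratio` shrinks the polygon more aggressively, since `1 - ratio**2` grows as the ratio falls.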
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: 500
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/ch_PP-OCR_v3_det/
+   save_epoch_step: 100
+   eval_batch_step:
+     - 0
+     - 400
+   cal_metric_during_train: false
+   pretrained_model: null
+   checkpoints: null
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./checkpoints/det_db/predicts_db.txt
+   distributed: true
+   d2s_train_image_shape: [3, -1, -1]
+   amp_dtype: bfloat16
+
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Student:
+       pretrained:
+       model_type: det
+       algorithm: DB
+       Transform: null
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: true
+       Neck:
+         name: RSEFPN
+         out_channels: 96
+         shortcut: True
+       Head:
+         name: DBHead
+         k: 50
+     Student2:
+       pretrained:
+       model_type: det
+       algorithm: DB
+       Transform: null
+       Backbone:
+         name: MobileNetV3
+         scale: 0.5
+         model_name: large
+         disable_se: true
+       Neck:
+         name: RSEFPN
+         out_channels: 96
+         shortcut: True
+       Head:
+         name: DBHead
+         k: 50
+     Teacher:
+       freeze_params: true
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: ResNet_vd
+         in_channels: 3
+         layers: 50
+       Neck:
+         name: LKPAN
+         out_channels: 256
+       Head:
+         name: DBHead
+         kernel_list: [7,2,2]
+         k: 50
+
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+     - DistillationDilaDBLoss:
+         weight: 1.0
+         model_name_pairs:
+           - ["Student", "Teacher"]
+           - ["Student2", "Teacher"]
+         key: maps
+         balance_loss: true
+         main_loss_type: DiceLoss
+         alpha: 5
+         beta: 10
+         ohem_ratio: 3
+     - DistillationDMLLoss:
+         model_name_pairs:
+           - ["Student", "Student2"]
+         maps_name: "thrink_maps"
+         weight: 1.0
+         model_name_pairs: ["Student", "Student2"]
+         key: maps
+     - DistillationDBLoss:
+         weight: 1.0
+         model_name_list: ["Student", "Student2"]
+         balance_loss: true
+         main_loss_type: DiceLoss
+         alpha: 5
+         beta: 10
+         ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: L2
+     factor: 5.0e-05
+
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name: ["Student"]
+   key: head_out
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: "Student"
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - CopyPaste:
+       - IaaAugment:
+           augmenter_args:
+             - type: Fliplr
+               args:
+                 p: 0.5
+             - type: Affine
+               args:
+                 rotate:
+                   - -10
+                   - 10
+             - type: Resize
+               args:
+                 size:
+                   - 0.5
+                   - 3
+       - EastRandomCropData:
+           size:
+             - 960
+             - 960
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - threshold_map
+             - threshold_mask
+             - shrink_map
+             - shrink_mask
+   loader:
+     shuffle: true
+     drop_last: false
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
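Several of the configs in this commit schedule evaluation with an `eval_batch_step: [start, interval]` pair. As a rough illustration of that convention (a simplified sketch, not PaddleOCR's actual trainer code; `should_evaluate` is a hypothetical helper), the pair can be read as: begin evaluating once the global step reaches `start`, then re-evaluate every `interval` steps:

```python
# Hypothetical helper illustrating eval_batch_step semantics;
# not a PaddleOCR API.
def should_evaluate(global_step, eval_batch_step):
    start_iter, interval = eval_batch_step
    # Evaluate once training reaches start_iter, then every `interval` steps.
    return global_step >= start_iter and (global_step - start_iter) % interval == 0
```

Under this reading, `eval_batch_step: [0, 400]` evaluates at steps 0, 400, 800, ..., while `[3000, 2000]` evaluates at 3000, 5000, 7000, ...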
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml ADDED
@@ -0,0 +1,173 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 2
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 2000 iterations after the 3000th iteration
+   eval_batch_step: [3000, 2000]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Student:
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: ResNet_vd
+         in_channels: 3
+         layers: 50
+       Neck:
+         name: LKPAN
+         out_channels: 256
+       Head:
+         name: DBHead
+         kernel_list: [7,2,2]
+         k: 50
+     Student2:
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: ResNet_vd
+         in_channels: 3
+         layers: 50
+       Neck:
+         name: LKPAN
+         out_channels: 256
+       Head:
+         name: DBHead
+         kernel_list: [7,2,2]
+         k: 50
+
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+     - DistillationDMLLoss:
+         model_name_pairs:
+           - ["Student", "Student2"]
+         maps_name: "thrink_maps"
+         weight: 1.0
+         # act: None
+         model_name_pairs: ["Student", "Student2"]
+         key: maps
+     - DistillationDBLoss:
+         weight: 1.0
+         model_name_list: ["Student", "Student2"]
+         # key: maps
+         name: DBLoss
+         balance_loss: true
+         main_loss_type: DiceLoss
+         alpha: 5
+         beta: 10
+         ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name: ["Student", "Student2"]
+   key: head_out
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: "Student"
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - CopyPaste:
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml ADDED
@@ -0,0 +1,163 @@
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: 500
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/ch_PP-OCR_V3_det/
+   save_epoch_step: 100
+   eval_batch_step:
+     - 0
+     - 400
+   cal_metric_during_train: false
+   pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
+   checkpoints: null
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./checkpoints/det_db/predicts_db.txt
+   distributed: true
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform:
+   Backbone:
+     name: MobileNetV3
+     scale: 0.5
+     model_name: large
+     disable_se: True
+   Neck:
+     name: RSEFPN
+     out_channels: 96
+     shortcut: True
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: L2
+     factor: 5.0e-05
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - IaaAugment:
+           augmenter_args:
+             - type: Fliplr
+               args:
+                 p: 0.5
+             - type: Affine
+               args:
+                 rotate:
+                   - -10
+                   - 10
+             - type: Resize
+               args:
+                 size:
+                   - 0.5
+                   - 3
+       - EastRandomCropData:
+           size:
+             - 960
+             - 960
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - threshold_map
+             - threshold_mask
+             - shrink_map
+             - shrink_mask
+   loader:
+     shuffle: true
+     drop_last: false
+     batch_size_per_card: 8
+     num_workers: 4
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - DetResizeForTest: null
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - shape
+             - polys
+             - ignore_tags
+   loader:
+     shuffle: false
+     drop_last: false
+     batch_size_per_card: 1
+     num_workers: 2
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_cml.yml ADDED
@@ -0,0 +1,235 @@
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: 500
+   log_smooth_window: 20
+   print_batch_step: 20
+   save_model_dir: ./output/ch_PP-OCRv4
+   save_epoch_step: 50
+   eval_batch_step:
+     - 0
+     - 1000
+   cal_metric_during_train: true
+   checkpoints: null
+   pretrained_model: null
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./checkpoints/det_db/predicts_db.txt
+   distributed: true
+ Architecture:
+   name: DistillationModel
+   algorithm: Distillation
+   model_type: det
+   Models:
+     Student:
+       model_type: det
+       algorithm: DB
+       Transform: null
+       Backbone:
+         name: PPLCNetNew
+         scale: 0.75
+         pretrained: false
+       Neck:
+         name: RSEFPN
+         out_channels: 96
+         shortcut: true
+       Head:
+         name: DBHead
+         k: 50
+     Student2:
+       pretrained: null
+       model_type: det
+       algorithm: DB
+       Transform: null
+       Backbone:
+         name: PPLCNetNew
+         scale: 0.75
+         pretrained: true
+       Neck:
+         name: RSEFPN
+         out_channels: 96
+         shortcut: true
+       Head:
+         name: DBHead
+         k: 50
+     Teacher:
+       pretrained: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_cml_teacher_pretrained/teacher.pdparams
+       freeze_params: true
+       return_all_feats: false
+       model_type: det
+       algorithm: DB
+       Backbone:
+         name: ResNet_vd
+         in_channels: 3
+         layers: 50
+       Neck:
+         name: LKPAN
+         out_channels: 256
+       Head:
+         name: DBHead
+         kernel_list:
+           - 7
+           - 2
+           - 2
+         k: 50
+ Loss:
+   name: CombinedLoss
+   loss_config_list:
+     - DistillationDilaDBLoss:
+         weight: 1.0
+         model_name_pairs:
+           - - Student
+             - Teacher
+           - - Student2
+             - Teacher
+         key: maps
+         balance_loss: true
+         main_loss_type: DiceLoss
+         alpha: 5
+         beta: 10
+         ohem_ratio: 3
+     - DistillationDMLLoss:
+         model_name_pairs:
+           - Student
+           - Student2
+         maps_name: thrink_maps
+         weight: 1.0
+         key: maps
+     - DistillationDBLoss:
+         weight: 1.0
+         model_name_list:
+           - Student
+           - Student2
+         balance_loss: true
+         main_loss_type: DiceLoss
+         alpha: 5
+         beta: 10
+         ohem_ratio: 3
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: L2
+     factor: 5.0e-05
+ PostProcess:
+   name: DistillationDBPostProcess
+   model_name:
+     - Student
+   key: head_out
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+ Metric:
+   name: DistillationMetric
+   base_metric_name: DetMetric
+   main_indicator: hmean
+   key: Student
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - IaaAugment:
+           augmenter_args:
+             - type: Fliplr
+               args:
+                 p: 0.5
+             - type: Affine
+               args:
+                 rotate:
+                   - -10
+                   - 10
+             - type: Resize
+               args:
+                 size:
+                   - 0.5
+                   - 3
+       - EastRandomCropData:
+           size:
+             - 640
+             - 640
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+           total_epoch: 500
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+           total_epoch: 500
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - threshold_map
+             - threshold_mask
+             - shrink_map
+             - shrink_mask
+   loader:
+     shuffle: true
+     drop_last: false
+     batch_size_per_card: 16
+     num_workers: 8
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - DetResizeForTest: null
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - shape
+             - polys
+             - ignore_tags
+   loader:
+     shuffle: false
+     drop_last: false
+     batch_size_per_card: 1
+     num_workers: 2
+ profiler_options: null
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml ADDED
@@ -0,0 +1,171 @@
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: &epoch_num 500
+   log_smooth_window: 20
+   print_batch_step: 100
+   save_model_dir: ./output/ch_PP-OCRv4
+   save_epoch_step: 10
+   eval_batch_step:
+     - 0
+     - 1500
+   cal_metric_during_train: false
+   checkpoints:
+   pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./checkpoints/det_db/predicts_db.txt
+   distributed: true
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform: null
+   Backbone:
+     name: PPLCNetV3
+     scale: 0.75
+     det: True
+   Neck:
+     name: RSEFPN
+     out_channels: 96
+     shortcut: True
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001 #(8*8c)
+     warmup_epoch: 2
+   regularizer:
+     name: L2
+     factor: 5.0e-05
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - CopyPaste: null
+       - IaaAugment:
+           augmenter_args:
+             - type: Fliplr
+               args:
+                 p: 0.5
+             - type: Affine
+               args:
+                 rotate:
+                   - -10
+                   - 10
+             - type: Resize
+               args:
+                 size:
+                   - 0.5
+                   - 3
+       - EastRandomCropData:
+           size:
+             - 640
+             - 640
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+           total_epoch: *epoch_num
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+           total_epoch: *epoch_num
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - threshold_map
+             - threshold_mask
+             - shrink_map
+             - shrink_mask
+   loader:
+     shuffle: true
+     drop_last: false
+     batch_size_per_card: 8
+     num_workers: 8
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - DetResizeForTest:
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - shape
+             - polys
+             - ignore_tags
+   loader:
+     shuffle: false
+     drop_last: false
+     batch_size_per_card: 1
+     num_workers: 2
+ profiler_options: null
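Every `Optimizer` block in these configs pairs a `Cosine` learning-rate schedule with `warmup_epoch: 2`, and the v4 configs thread `epoch_num` into the schedule via the `&epoch_num` YAML anchor. As a minimal sketch of what such a warmup-plus-cosine curve typically looks like (an assumed shape for illustration, not PaddleOCR's implementation; `lr_at_epoch` is a hypothetical helper):

```python
import math

# Hypothetical sketch mirroring
# lr: {name: Cosine, learning_rate: 0.001, warmup_epoch: 2} with epoch_num 500:
# linear warmup for warmup_epoch epochs, then cosine decay toward 0.
def lr_at_epoch(epoch, base_lr=0.001, warmup_epoch=2, total_epoch=500):
    if epoch < warmup_epoch:
        # Linear ramp up to base_lr during warmup.
        return base_lr * (epoch + 1) / warmup_epoch
    # Cosine decay from base_lr toward 0 over the remaining epochs.
    progress = (epoch - warmup_epoch) / (total_epoch - warmup_epoch)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```

The anchor/alias pair (`&epoch_num` / `*epoch_num`) simply keeps `total_epoch` in the `MakeBorderMap`/`MakeShrinkMap` transforms in sync with `Global.epoch_num`.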
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml ADDED
@@ -0,0 +1,172 @@
+ Global:
+   debug: false
+   use_gpu: true
+   epoch_num: &epoch_num 500
+   log_smooth_window: 20
+   print_batch_step: 100
+   save_model_dir: ./output/ch_PP-OCRv4
+   save_epoch_step: 10
+   eval_batch_step:
+     - 0
+     - 1500
+   cal_metric_during_train: false
+   checkpoints:
+   pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPHGNet_small_ocr_det.pdparams
+   save_inference_dir: null
+   use_visualdl: false
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./checkpoints/det_db/predicts_db.txt
+   distributed: true
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform: null
+   Backbone:
+     name: PPHGNet_small
+     det: True
+   Neck:
+     name: LKPAN
+     out_channels: 256
+     intracl: true
+   Head:
+     name: PFHeadLocal
+     k: 50
+     mode: "large"
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001 #(8*8c)
+     warmup_epoch: 2
+   regularizer:
+     name: L2
+     factor: 1e-6
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - CopyPaste: null
+       - IaaAugment:
+           augmenter_args:
+             - type: Fliplr
+               args:
+                 p: 0.5
+             - type: Affine
+               args:
+                 rotate:
+                   - -10
+                   - 10
+             - type: Resize
+               args:
+                 size:
+                   - 0.5
+                   - 3
+       - EastRandomCropData:
+           size:
+             - 640
+             - 640
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+           total_epoch: *epoch_num
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+           total_epoch: *epoch_num
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - threshold_map
+             - threshold_mask
+             - shrink_map
+             - shrink_mask
+   loader:
+     shuffle: true
+     drop_last: false
+     batch_size_per_card: 8
+     num_workers: 8
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage:
+           img_mode: BGR
+           channel_first: false
+       - DetLabelEncode: null
+       - DetResizeForTest:
+       - NormalizeImage:
+           scale: 1./255.
+           mean:
+             - 0.485
+             - 0.456
+             - 0.406
+           std:
+             - 0.229
+             - 0.224
+             - 0.225
+           order: hwc
+       - ToCHWImage: null
+       - KeepKeys:
+           keep_keys:
+             - image
+             - shape
+             - polys
+             - ignore_tags
+   loader:
+     shuffle: false
+     drop_last: false
+     batch_size_per_card: 1
+     num_workers: 2
+ profiler_options: null
Rotate/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml ADDED
@@ -0,0 +1,132 @@
+ Global:
+   use_gpu: true
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 2
+   save_model_dir: ./output/ch_db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 2000 iterations after the 3000th iteration
+   eval_batch_step: [3000, 2000]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform:
+   Backbone:
+     name: MobileNetV3
+     scale: 0.5
+     model_name: large
+     disable_se: True
+   Neck:
+     name: DBFPN
+     out_channels: 96
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
Rotate/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml ADDED
@@ -0,0 +1,131 @@
1
+ Global:
2
+ use_gpu: true
3
+ epoch_num: 1200
4
+ log_smooth_window: 20
5
+ print_batch_step: 2
6
+ save_model_dir: ./output/ch_db_res18/
7
+ save_epoch_step: 1200
8
+ # evaluation is run every 5000 iterations after the 4000th iteration
9
+ eval_batch_step: [3000, 2000]
10
+ cal_metric_during_train: False
11
+ pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
12
+ checkpoints:
13
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform:
+   Backbone:
+     name: ResNet_vd
+     layers: 18
+     disable_se: True
+   Neck:
+     name: DBFPN
+     out_channels: 256
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     name: Cosine
+     learning_rate: 0.001
+     warmup_epoch: 2
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [960, 960]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 8
+     num_workers: 4
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           # image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 2
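The `lr` block above (`name: Cosine` with `warmup_epoch: 2`) pairs a linear warmup with cosine decay of `learning_rate: 0.001`. A minimal sketch of the implied schedule, assuming linear warmup from zero and standard cosine annealing to zero (Paddle's built-in `Cosine` scheduler may differ in edge details, and `warmup_epoch` is converted to steps internally):

```python
import math

def cosine_lr_with_warmup(step, total_steps, base_lr=0.001, warmup_steps=100):
    # Linear warmup from 0 to base_lr, then cosine decay toward 0.
    # warmup_steps is a stand-in for `warmup_epoch: 2` converted to iterations.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

lrs = [cosine_lr_with_warmup(s, 1000) for s in range(1000)]
```

The curve peaks at `learning_rate: 0.001` right after warmup and decays monotonically toward zero by the final step.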
Rotate/configs/det/det_mv3_db.yml ADDED
@@ -0,0 +1,133 @@
+ Global:
+   use_gpu: true
+   use_xpu: false
+   use_mlu: false
+   epoch_num: 1200
+   log_smooth_window: 20
+   print_batch_step: 10
+   save_model_dir: ./output/db_mv3/
+   save_epoch_step: 1200
+   # evaluation is run every 2000 iterations
+   eval_batch_step: [0, 2000]
+   cal_metric_during_train: False
+   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+   checkpoints:
+   save_inference_dir:
+   use_visualdl: False
+   infer_img: doc/imgs_en/img_10.jpg
+   save_res_path: ./output/det_db/predicts_db.txt
+
+ Architecture:
+   model_type: det
+   algorithm: DB
+   Transform:
+   Backbone:
+     name: MobileNetV3
+     scale: 0.5
+     model_name: large
+   Neck:
+     name: DBFPN
+     out_channels: 256
+   Head:
+     name: DBHead
+     k: 50
+
+ Loss:
+   name: DBLoss
+   balance_loss: true
+   main_loss_type: DiceLoss
+   alpha: 5
+   beta: 10
+   ohem_ratio: 3
+
+ Optimizer:
+   name: Adam
+   beta1: 0.9
+   beta2: 0.999
+   lr:
+     learning_rate: 0.001
+   regularizer:
+     name: 'L2'
+     factor: 0
+
+ PostProcess:
+   name: DBPostProcess
+   thresh: 0.3
+   box_thresh: 0.6
+   max_candidates: 1000
+   unclip_ratio: 1.5
+
+ Metric:
+   name: DetMetric
+   main_indicator: hmean
+
+ Train:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+     ratio_list: [1.0]
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - IaaAugment:
+           augmenter_args:
+             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+             - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+             - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+       - EastRandomCropData:
+           size: [640, 640]
+           max_tries: 50
+           keep_ratio: true
+       - MakeBorderMap:
+           shrink_ratio: 0.4
+           thresh_min: 0.3
+           thresh_max: 0.7
+       - MakeShrinkMap:
+           shrink_ratio: 0.4
+           min_text_size: 8
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+   loader:
+     shuffle: True
+     drop_last: False
+     batch_size_per_card: 16
+     num_workers: 8
+     use_shared_memory: True
+
+ Eval:
+   dataset:
+     name: SimpleDataSet
+     data_dir: ./train_data/icdar2015/text_localization/
+     label_file_list:
+       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+     transforms:
+       - DecodeImage: # load image
+           img_mode: BGR
+           channel_first: False
+       - DetLabelEncode: # Class handling label
+       - DetResizeForTest:
+           image_shape: [736, 1280]
+       - NormalizeImage:
+           scale: 1./255.
+           mean: [0.485, 0.456, 0.406]
+           std: [0.229, 0.224, 0.225]
+           order: 'hwc'
+       - ToCHWImage:
+       - KeepKeys:
+           keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+   loader:
+     shuffle: False
+     drop_last: False
+     batch_size_per_card: 1 # must be 1
+     num_workers: 8
+     use_shared_memory: True
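Both configs share the same `NormalizeImage` transform. As a rough sketch of what it computes, assuming the usual `(pixel * scale - mean) / std` per-channel semantics in `'hwc'` order (how the channel statistics interact with `img_mode: BGR` is left to PaddleOCR's implementation and not modeled here):

```python
# ImageNet statistics copied from the NormalizeImage block above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
SCALE = 1.0 / 255.0  # the config's `scale: 1./255.` string, evaluated

def normalize_pixel(pixel):
    """Normalize one HWC pixel (three 0-255 channel values) per channel."""
    return [(v * SCALE - m) / s for v, m, s in zip(pixel, MEAN, STD)]

white = normalize_pixel([255, 255, 255])  # a fully saturated pixel
```

A pixel whose scaled value equals the channel mean maps to zero, so the transform centers inputs around the ImageNet statistics before `ToCHWImage` reorders them for the backbone.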