rotate
- Rotate/StyleText/README.md +219 -0
- Rotate/StyleText/README_ch.md +205 -0
- Rotate/StyleText/__init__.py +0 -0
- Rotate/StyleText/arch/__init__.py +0 -0
- Rotate/StyleText/arch/base_module.py +255 -0
- Rotate/StyleText/arch/decoder.py +251 -0
- Rotate/StyleText/arch/encoder.py +186 -0
- Rotate/StyleText/arch/spectral_norm.py +150 -0
- Rotate/StyleText/arch/style_text_rec.py +285 -0
- Rotate/StyleText/configs/config.yml +54 -0
- Rotate/StyleText/configs/dataset_config.yml +64 -0
- Rotate/StyleText/engine/__init__.py +0 -0
- Rotate/StyleText/engine/corpus_generators.py +66 -0
- Rotate/StyleText/engine/predictors.py +139 -0
- Rotate/StyleText/engine/style_samplers.py +62 -0
- Rotate/StyleText/engine/synthesisers.py +77 -0
- Rotate/StyleText/engine/text_drawers.py +85 -0
- Rotate/StyleText/engine/writers.py +71 -0
- Rotate/StyleText/examples/corpus/example.txt +2 -0
- Rotate/StyleText/examples/image_list.txt +2 -0
- Rotate/StyleText/tools/__init__.py +0 -0
- Rotate/StyleText/tools/synth_dataset.py +31 -0
- Rotate/StyleText/tools/synth_image.py +82 -0
- Rotate/StyleText/utils/__init__.py +0 -0
- Rotate/StyleText/utils/config.py +224 -0
- Rotate/StyleText/utils/load_params.py +27 -0
- Rotate/StyleText/utils/logging.py +65 -0
- Rotate/StyleText/utils/math_functions.py +45 -0
- Rotate/StyleText/utils/sys_funcs.py +67 -0
- Rotate/__init__.py +18 -0
- Rotate/ch_PP-OCRv4_det_infer/inference.pdiparams.info +0 -0
- Rotate/ch_PP-OCRv4_det_infer/inference.pdmodel +3 -0
- Rotate/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel +3 -0
- Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams +3 -0
- Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info +0 -0
- Rotate/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml +98 -0
- Rotate/configs/cls/cls_mv3.yml +94 -0
- Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +206 -0
- Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml +175 -0
- Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml +178 -0
- Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml +132 -0
- Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml +226 -0
- Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml +173 -0
- Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml +163 -0
- Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_cml.yml +235 -0
- Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml +171 -0
- Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml +172 -0
- Rotate/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml +132 -0
- Rotate/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml +131 -0
- Rotate/configs/det/det_mv3_db.yml +133 -0
Rotate/StyleText/README.md
ADDED
@@ -0,0 +1,219 @@
English | [简体中文](README_ch.md)

## Style Text

### Contents
- [1. Introduction](#Introduction)
- [2. Preparation](#Preparation)
- [3. Quick Start](#Quick_Start)
- [4. Applications](#Applications)
- [5. Code Structure](#Code_structure)

<a name="Introduction"></a>
### Introduction

<div align="center">
<img src="doc/images/3.png" width="800">
</div>

<div align="center">
<img src="doc/images/9.png" width="600">
</div>

The Style-Text data synthesis tool is based on "Editing Text in the Wild" ([https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047)), a joint research work by Baidu and HUST.

Different from commonly used GAN-based data synthesis tools, the main framework of Style-Text includes:
* (1) A text foreground style transfer module.
* (2) A background extraction module.
* (3) A fusion module.

After these three steps, you can quickly perform image text style transfer. The following figure shows some results of the data synthesis tool.

<div align="center">
<img src="doc/images/10.png" width="1000">
</div>

<a name="Preparation"></a>
### Preparation

1. Please refer to [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
2. Download the pretrained models and unzip them:

```bash
cd StyleText
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```

If you save the models in another location, please modify the model paths in `configs/config.yml`. All three of the following configurations must be updated at the same time:

```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```

<a name="Quick_Start"></a>
### Quick Start

#### Synthesize a single image

1. Run `tools/synth_image.py` to generate a demo image, which is saved in the current folder:

```bash
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```

* Note 1: The language option must correspond to the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
* Note 2: Style-Text is mainly used to generate images for OCR recognition models, so the height of style images should be around 32 pixels. Images of other sizes may produce poor results.
* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use a GPU for prediction.

For example, with the following input image and the corpus `PaddleOCR`:

<div align="center">
<img src="examples/style_images/2.jpg" width="300">
</div>

the result `fake_fusion.jpg` will be generated:

<div align="center">
<img src="doc/images/4.jpg" width="300">
</div>

In addition, the intermediate results are also saved. `fake_bg.jpg` is the background output:

<div align="center">
<img src="doc/images/7.jpg" width="300">
</div>

`fake_text.jpg` is the generated text image with the same font style as the style input:

<div align="center">
<img src="doc/images/8.jpg" width="300">
</div>

#### Batch synthesis

In actual application scenarios, it is often necessary to synthesize images in batches and add them to the training set. Style-Text can use a batch of style images and a corpus to synthesize data in batches. The synthesis process is as follows:

1. Specify the reference dataset in `configs/dataset_config.yml`:

* `Global`:
  * `output_dir`: Output path for the synthesized data.
* `StyleSampler`:
  * `image_home`: Folder of style images.
  * `label_file`: File list of style images. If labels are provided, this is the label file path.
  * `with_label`: Whether `label_file` is a label file.
* `CorpusGenerator`:
  * `method`: Method of the corpus generator; `FileCorpus` and `EnNumCorpus` are supported. If `EnNumCorpus` is used, no other configuration is needed; otherwise, you need to set `corpus_file` and `language`.
  * `language`: Language of the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
  * `corpus_file`: Path of the corpus file. The corpus file should be a text file split by line endings (`'\n'`); the corpus generator samples one line at a time.

Example of a corpus file:
```
PaddleOCR
飞桨文字识别
StyleText
风格文本图像数据合成
```

We provide a general dataset containing Chinese, English and Korean text (50,000 images in all) for your trial ([download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar)). Some examples are given below:

<div align="center">
<img src="doc/images/5.png" width="800">
</div>

2. Run the following command to start the synthesis task:

```bash
python3 tools/synth_dataset.py -c configs/dataset_config.yml
```

We also provide example corpus and images in the `examples` folder.
<div align="center">
<img src="examples/style_images/1.jpg" width="300">
<img src="examples/style_images/2.jpg" width="300">
</div>

If you run the command above directly, you will get example output data in the `output_data` folder, with synthesized images and labels as below:
<div align="center">
<img src="doc/images/12.png" width="800">
</div>

Some cache files are kept under the `label` folder. If the program exits unexpectedly, you can find the cached labels there. When the program finishes normally, all the labels are written to `label.txt`, which gives the final results.

<a name="Applications"></a>
### Applications
We take two scenarios as examples, English and number recognition on metal surfaces and general Korean recognition, to illustrate practical cases of using Style-Text to synthesize data and improve text recognition. The following figure shows some examples of real scene images and synthesized images:

<div align="center">
<img src="doc/images/11.png" width="800">
</div>

After adding the above synthetic data to training, the accuracy of the recognition model improves, as shown in the following table:

| Scenario | Characters | Raw Data | Test Data | Recognition Accuracy</br>(Raw Data Only) | New Synthetic Data | Recognition Accuracy</br>(Raw + Synthetic Data) | Improvement |
| -------- | ---------- | -------- | --------- | ---------------------------------------- | ------------------ | ----------------------------------------------- | ----------- |
| Metal surface | English and numbers | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% |
| Random background | Korean | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% |

<a name="Code_structure"></a>
### Code Structure

```
StyleText
|-- arch                        // Network module files.
|   |-- base_module.py
|   |-- decoder.py
|   |-- encoder.py
|   |-- spectral_norm.py
|   `-- style_text_rec.py
|-- configs                     // Config files.
|   |-- config.yml
|   `-- dataset_config.yml
|-- engine                      // Synthesis engines.
|   |-- corpus_generators.py    // Sample corpus from a file or generate random corpus.
|   |-- predictors.py           // Predict using the networks.
|   |-- style_samplers.py       // Sample style images.
|   |-- synthesisers.py         // Manage the other engines to synthesize images.
|   |-- text_drawers.py         // Generate standard input text images.
|   `-- writers.py              // Write synthesized images and labels into files.
|-- examples                    // Example files.
|   |-- corpus
|   |   `-- example.txt
|   |-- image_list.txt
|   `-- style_images
|       |-- 1.jpg
|       `-- 2.jpg
|-- fonts                       // Font files.
|   |-- ch_standard.ttf
|   |-- en_standard.ttf
|   `-- ko_standard.ttf
|-- tools                       // Program entry points.
|   |-- __init__.py
|   |-- synth_dataset.py        // Synthesize a dataset.
|   `-- synth_image.py          // Synthesize a single image.
`-- utils                       // Basic utility modules.
    |-- config.py
    |-- load_params.py
    |-- logging.py
    |-- math_functions.py
    `-- sys_funcs.py
```
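The `FileCorpus` / `EnNumCorpus` behavior described in the README (sample one line per image from a newline-split corpus file, or generate random English/number strings) can be sketched in plain Python. This is an illustrative sketch only: the actual generators live in `engine/corpus_generators.py`, and the `min_len`/`max_len` parameters here are assumptions, not StyleText config options.

```python
import random
import string


class FileCorpus:
    """Sample one line at a time from a corpus text file (one entry per line)."""

    def __init__(self, corpus_path):
        with open(corpus_path, encoding="utf-8") as f:
            # Split by line endings and drop empty lines.
            self.lines = [line.strip() for line in f if line.strip()]

    def sample(self):
        return random.choice(self.lines)


class EnNumCorpus:
    """Generate a random English/number string; no corpus file is needed.

    min_len/max_len are hypothetical parameters for this sketch.
    """

    def __init__(self, min_len=4, max_len=10):
        self.min_len = min_len
        self.max_len = max_len

    def sample(self):
        length = random.randint(self.min_len, self.max_len)
        return "".join(
            random.choice(string.ascii_letters + string.digits)
            for _ in range(length))
```

Switching `method` in `dataset_config.yml` selects between these two behaviors; `corpus_file` and `language` only matter for the file-backed variant.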
Rotate/StyleText/README_ch.md
ADDED
@@ -0,0 +1,205 @@
Simplified Chinese | [English](README.md)

## Style Text

### Contents
- [1. Introduction](#工具简介)
- [2. Environment Setup](#环境配置)
- [3. Quick Start](#快速上手)
- [4. Applications](#应用案例)
- [5. Code Structure](#代码结构)

<a name="工具简介"></a>
### 1. Introduction
<div align="center">
<img src="doc/images/3.png" width="800">
</div>

<div align="center">
<img src="doc/images/1.png" width="600">
</div>

The Style-Text data synthesis tool is based on the text editing algorithm "Editing Text in the Wild" (https://arxiv.org/abs/1908.03047), a joint research work by Baidu and HUST.

Different from commonly used GAN-based data synthesis tools, the main framework of Style-Text includes: 1. a text foreground style transfer module; 2. a background extraction module; 3. a fusion module. After these three steps, image text style transfer can be achieved quickly. The figure below shows some results of the data synthesis tool.

<div align="center">
<img src="doc/images/2.png" width="1000">
</div>

<a name="环境配置"></a>
### 2. Environment Setup

1. Refer to [QUICK INSTALLATION](../doc/doc_ch/installation.md) to install PaddleOCR.
2. Enter the `StyleText` directory, download the models, and unzip them:

```bash
cd StyleText
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```

If you save the models in another location, please modify the model paths in `configs/config.yml`. All three of the following configurations must be updated at the same time:

```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```

<a name="快速上手"></a>
### 3. Quick Start

#### Synthesize a single image
Given a style image and a text corpus, run `tools/synth_image.py` to synthesize a single image; the result is saved in the current directory:

```bash
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
* Note 1: The language option must correspond to the corpus. Currently English (en), Simplified Chinese (ch) and Korean (ko) are supported.
* Note 2: Data generated by Style-Text is mainly used for OCR recognition. Based on the design of the current PaddleOCR recognition models, style images with a height of around 32 pixels are mainly supported; input images of very different sizes may produce poor results.
* Note 3: You can modify the `use_gpu` parameter (true or false) in `configs/config.yml` to determine whether to use a GPU for prediction.

For example, with the following input image and the corpus "PaddleOCR":

<div align="center">
<img src="examples/style_images/2.jpg" width="300">
</div>

the synthesized image `fake_fusion.jpg` is generated:
<div align="center">
<img src="doc/images/4.jpg" width="300">
</div>

In addition, the program generates and saves intermediate results. `fake_bg.jpg` is the background of the style reference image with the text removed:

<div align="center">
<img src="doc/images/7.jpg" width="300">
</div>

`fake_text.jpg` is a text image on a gray background, rendered from the provided string in the style of the text in the style reference image:

<div align="center">
<img src="doc/images/8.jpg" width="300">
</div>

#### Batch synthesis
In practice, it is often necessary to synthesize images in batches and add them to the training set. Style-Text can use a batch of style images and a corpus to synthesize data in batches. The synthesis process is as follows:

1. Configure the paths of the target-scene style images and the corpus in `configs/dataset_config.yml`:

* `Global`:
  * `output_dir`: Directory for saving the synthesized data.
* `StyleSampler`:
  * `image_home`: Directory of style images;
  * `label_file`: File list of style image paths; if the dataset has labels, `label_file` is the label file path;
  * `with_label`: Whether `label_file` is a label file.
* `CorpusGenerator`:
  * `method`: Corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. If `EnNumCorpus` is used, no other configuration is needed; otherwise, set `corpus_file` and `language`;
  * `language`: Language of the corpus; currently English (en), Simplified Chinese (ch) and Korean (ko) are supported;
  * `corpus_file`: Path of the corpus file. The corpus file should be a text file. The corpus generator first splits the corpus by line, then randomly selects one line each time.

Example corpus file format:
```
PaddleOCR
飞桨文字识别
StyleText
风格文本图像数据合成
...
```

Style-Text also provides 50,000 general-scene images in Chinese, English and Korean to use as text style images, making it easy to synthesize text images with rich scenes. The figure below shows some examples.

50,000 general-scene images in Chinese, English and Korean: [download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar)

<div align="center">
<img src="doc/images/5.png" width="800">
</div>

2. Run `tools/synth_dataset.py` to synthesize data:

```bash
python3 tools/synth_dataset.py -c configs/dataset_config.yml
```
We provide example images and corpus in the `examples` directory.
<div align="center">
<img src="examples/style_images/1.jpg" width="300">
<img src="examples/style_images/2.jpg" width="300">
</div>

Running the command above directly produces example output in `output_data`, including images and label files for training recognition models:
<div align="center">
<img src="doc/images/12.png" width="800">
</div>

The label files under the `label` directory are caches produced while the program runs; if the program terminates abnormally partway through, you can use these cached label files.
If the program finishes normally, `label.txt` is generated under `output_data` as the final labeling result.

<a name="应用案例"></a>
### 4. Applications
Taking two scenarios as examples, English and number recognition on metal surfaces and general Korean recognition, the following shows practical cases of using Style-Text to synthesize data and improve text recognition. The figure below gives some examples of real scene images and synthesized images:

<div align="center">
<img src="doc/images/6.png" width="800">
</div>

After adding the above synthetic data to training, the accuracy of the recognition model improves, as shown in the table below:

| Scenario | Characters | Raw Data | Test Data | Recognition Accuracy</br>(Raw Data Only) | New Synthetic Data | Recognition Accuracy</br>(Raw + Synthetic Data) | Improvement |
| -------- | ---------- | -------- | --------- | ---------------------------------------- | ------------------ | ----------------------------------------------- | ----------- |
| Metal surface | English and numbers | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% |
| Random background | Korean | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% |

<a name="代码结构"></a>
### 5. Code Structure

```
StyleText
|-- arch                        // Network structure definition files.
|   |-- base_module.py
|   |-- decoder.py
|   |-- encoder.py
|   |-- spectral_norm.py
|   `-- style_text_rec.py
|-- configs                     // Config files.
|   |-- config.yml
|   `-- dataset_config.yml
|-- engine                      // Data synthesis engines.
|   |-- corpus_generators.py    // Sample corpus from a file or generate random corpus.
|   |-- predictors.py           // Generate data with the networks.
|   |-- style_samplers.py       // Sample style images.
|   |-- synthesisers.py         // Coordinate the modules to synthesize data.
|   |-- text_drawers.py         // Generate standard text images as input.
|   `-- writers.py              // Write synthesized images and labels to disk.
|-- examples                    // Example files.
|   |-- corpus
|   |   `-- example.txt
|   |-- image_list.txt
|   `-- style_images
|       |-- 1.jpg
|       `-- 2.jpg
|-- fonts                       // Font files.
|   |-- ch_standard.ttf
|   |-- en_standard.ttf
|   `-- ko_standard.ttf
|-- tools                       // Program entry points.
|   |-- __init__.py
|   |-- synth_dataset.py        // Batch data synthesis.
|   `-- synth_image.py          // Single-image synthesis.
`-- utils                       // Other basic utility modules.
    |-- config.py
    |-- load_params.py
    |-- logging.py
    |-- math_functions.py
    `-- sys_funcs.py
```
Rotate/StyleText/__init__.py
ADDED
File without changes
Rotate/StyleText/arch/__init__.py
ADDED
File without changes
Rotate/StyleText/arch/base_module.py
ADDED
@@ -0,0 +1,255 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.spectral_norm import spectral_norm


class CBN(nn.Layer):
    """Conv2D followed by an optional norm layer and an optional activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(CBN, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConv(nn.Layer):
    """Spectrally normalized Conv2D with optional norm layer and activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConv, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv = spectral_norm(
            paddle.nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConvTranspose(nn.Layer):
    """Spectrally normalized Conv2DTranspose with optional norm and activation."""

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConvTranspose, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv_transpose = spectral_norm(
            paddle.nn.Conv2DTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv_transpose(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class MiddleNet(nn.Layer):
    """1x1 -> padded 3x3 -> 1x1 SNConv bottleneck."""

    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        sn_conv1 = self._sn_conv1.forward(x)
        pad_2d = self._pad2d.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad_2d)
        sn_conv3 = self._sn_conv3.forward(sn_conv2)
        return sn_conv3


class ResBlock(nn.Layer):
    """Residual block of two SNConv layers with replicate padding."""

    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        if use_dilation:
            padding_mat = [1, 1, 1, 1]
        else:
            padding_mat = [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(padding_mat, mode="replicate")

        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        if use_dropout:
            self._dropout = nn.Dropout(0.5)
        else:
            self._dropout = None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        pad1 = self._pad1.forward(x)
        sn_conv1 = self._sn_conv1.forward(pad1)
        pad2 = self._pad2.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad2)
        return sn_conv2 + x
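Each block in `base_module.py` (`CBN`, `SNConv`, `SNConvTranspose`) follows the same forward pattern: apply the convolution, then the optional normalization layer, then the optional activation. A framework-free sketch of that composition pattern, with plain callables standing in for the paddle layers (illustrative only, not the StyleText code):

```python
class ConvNormAct:
    """Compose a transform with optional norm and activation, mirroring
    the conv -> norm -> act chain in CBN.forward (illustrative sketch)."""

    def __init__(self, conv, norm_layer=None, act=None):
        self._conv = conv              # required transform
        self._norm_layer = norm_layer  # applied only if provided
        self._act = act                # applied only if provided

    def __call__(self, x):
        out = self._conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


# Stand-ins: scale the input, shift it, then clamp at zero (ReLU-like).
block = ConvNormAct(
    conv=lambda x: 2 * x,
    norm_layer=lambda x: x - 1,
    act=lambda x: max(x, 0))
```

`block(3)` applies the three stages in order (2 * 3 = 6, 6 - 1 = 5, max(5, 0) = 5) and returns 5; omitting `norm_layer` or `act` simply skips that stage, which is exactly how `norm_layer=None` and `act=None` behave in the paddle modules above.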
Rotate/StyleText/arch/decoder.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Decoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(Decoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self.conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            x = paddle.concat(x, axis=1)
        output_dict = dict()
        output_dict["conv_blocks"] = self.conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(output_dict["up1"])
        output_dict["up3"] = self._up3.forward(output_dict["up2"])
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class DecoderUnet(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(DecoderUnet, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, y, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks(
            paddle.concat(
                (x, y), axis=1))
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class SingleDecoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(SingleDecoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
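Every `SNConvTranspose` in these decoders uses `kernel_size=3, stride=2, padding=1, output_padding=1`, which exactly doubles the spatial size, so the three `_up` stages undo the encoder's three stride-2 downsamplings. A small sketch of that arithmetic (the helper name and the example input size are illustrative, not from the repo):

```python
# Standard transposed-convolution output-size formula, applied with the
# parameters used by the decoders' SNConvTranspose layers above.

def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    # out = (in - 1) * stride - 2 * padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

size = 8                      # hypothetical feature-map height entering _up1
for _ in range(3):            # _up1 -> _up2 -> _up3, each doubles the size
    size = conv_transpose_out(size)
print(size)  # 64: three doublings, 8 -> 16 -> 32 -> 64
```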
Rotate/StyleText/arch/encoder.py
ADDED
@@ -0,0 +1,186 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Encoder(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation):
        super(Encoder, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)

    def forward(self, x):
        out_dict = dict()
        x = self._pad2d(x)
        out_dict["in_conv"] = self._in_conv.forward(x)
        out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
        out_dict["down2"] = self._down2.forward(out_dict["down1"])
        out_dict["down3"] = self._down3.forward(out_dict["down2"])
        out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
        return out_dict


class EncoderUnet(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr):
        super(EncoderUnet, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down4 = SNConv(
            name=name + "_down4",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)

    def forward(self, x):
        output_dict = dict()
        x = self._pad2d(x)
        output_dict['in_conv'] = self._in_conv.forward(x)
        output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
        output_dict['down2'] = self._down2.forward(output_dict['down1'])
        output_dict['down3'] = self._down3.forward(output_dict['down2'])
        output_dict['down4'] = self._down4.forward(output_dict['down3'])
        output_dict['up1'] = self._up1.forward(output_dict['down4'])
        output_dict['up2'] = self._up2.forward(
            paddle.concat(
                (output_dict['down3'], output_dict['up1']), axis=1))
        output_dict['concat'] = paddle.concat(
            (output_dict['down2'], output_dict['up2']), axis=1)
        return output_dict
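The `_down` layers in both encoders use `kernel_size=3, stride=2, padding=1`, which halves the spatial size for even inputs; `Encoder` applies three such halvings, giving the 8x reduction that the `Decoder`'s three doubling stages later invert. A sketch of the size arithmetic (helper name and example size are illustrative, not from the repo):

```python
# Standard convolution output-size formula, applied with the parameters
# used by the encoders' stride-2 SNConv layers above.

def conv_out(size, kernel=3, stride=2, padding=1):
    # out = floor((in + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

size = 64                     # hypothetical height after the 7x7 stem conv
trace = []
for _ in range(3):            # _down1 -> _down2 -> _down3
    size = conv_out(size)
    trace.append(size)
print(trace)  # [32, 16, 8]
```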
Rotate/StyleText/arch/spectral_norm.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def normal_(x, mean=0., std=1.):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x


class SpectralNorm(object):
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                * [d for d in range(weight_mat.dim()) if d != self.dim]
            ])

        height = weight_mat.shape[0]

        return weight_mat.reshape([height, -1])

    def compute_weight(self, module, do_power_iteration):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False),
                            axis=0,
                            epsilon=self.eps, ))

                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps, ))
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')

        module.add_parameter(self.name, weight.detach())

    def __call__(self, module, inputs):
        setattr(
            module,
            self.name,
            self.compute_weight(
                module, do_power_iteration=module.training))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise you can not set attribute
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be a Parameter and
        # would get added as a parameter. Instead, we register weight * 1.0 as
        # a plain attribute.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):

    if dim is None:
        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
                               nn.Conv3DTranspose, nn.Linear)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
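`SpectralNorm.compute_weight` runs power iteration: it alternately sets `v = normalize(W^T u)` and `u = normalize(W v)`, then takes `sigma = u . (W v)` as an estimate of the largest singular value of `W` and divides the weight by it. A framework-free sketch of that iteration on a tiny matrix with known singular values (all names here are illustrative):

```python
# Power iteration for the spectral norm (largest singular value) of W,
# mirroring the u/v updates in SpectralNorm.compute_weight.

def matvec(m, v):
    return [sum(row[j] * v[j] for j in range(len(v))) for row in m]

def matvec_t(m, u):
    return [sum(m[i][j] * u[i] for i in range(len(m))) for j in range(len(m[0]))]

def normalize(v, eps=1e-12):
    n = max(sum(x * x for x in v) ** 0.5, eps)
    return [x / n for x in v]

W = [[3.0, 0.0], [0.0, 1.0]]   # diagonal, so singular values are 3 and 1
u, v = [1.0, 1.0], [1.0, 1.0]
for _ in range(20):            # plays the role of n_power_iterations
    v = normalize(matvec_t(W, u))   # v <- normalize(W^T u)
    u = normalize(matvec(W, v))     # u <- normalize(W v)
sigma = sum(a * b for a, b in zip(u, matvec(W, v)))
print(round(sigma, 6))  # 3.0 -- the spectral norm of W
```

Dividing `W` by this `sigma` (as `compute_weight` does) constrains the layer's Lipschitz constant, which is why the SN layers use it in GAN training.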
Rotate/StyleText/arch/style_text_rec.py
ADDED
@@ -0,0 +1,285 @@
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import paddle
|
15 |
+
import paddle.nn as nn
|
16 |
+
|
17 |
+
from arch.base_module import MiddleNet, ResBlock
|
18 |
+
from arch.encoder import Encoder
|
19 |
+
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
|
20 |
+
from utils.load_params import load_dygraph_pretrain
|
21 |
+
from utils.logging import get_logger
|
22 |
+
|
23 |
+
|
24 |
+
class StyleTextRec(nn.Layer):
|
25 |
+
def __init__(self, config):
|
26 |
+
super(StyleTextRec, self).__init__()
|
27 |
+
self.logger = get_logger()
|
28 |
+
self.text_generator = TextGenerator(config["Predictor"][
|
29 |
+
"text_generator"])
|
30 |
+
self.bg_generator = BgGeneratorWithMask(config["Predictor"][
|
31 |
+
"bg_generator"])
|
32 |
+
self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
|
33 |
+
"fusion_generator"])
|
34 |
+
bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
|
35 |
+
text_generator_pretrain = config["Predictor"]["text_generator"][
|
36 |
+
"pretrain"]
|
37 |
+
fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
|
38 |
+
"pretrain"]
|
39 |
+
load_dygraph_pretrain(
|
40 |
+
self.bg_generator,
|
41 |
+
self.logger,
|
42 |
+
path=bg_generator_pretrain,
|
43 |
+
load_static_weights=False)
|
44 |
+
load_dygraph_pretrain(
|
45 |
+
self.text_generator,
|
46 |
+
self.logger,
|
47 |
+
path=text_generator_pretrain,
|
48 |
+
load_static_weights=False)
|
49 |
+
load_dygraph_pretrain(
|
50 |
+
self.fusion_generator,
|
51 |
+
self.logger,
|
52 |
+
path=fusion_generator_pretrain,
|
53 |
+
load_static_weights=False)
|
54 |
+
|
55 |
+
def forward(self, style_input, text_input):
|
56 |
+
text_gen_output = self.text_generator.forward(style_input, text_input)
|
57 |
+
fake_text = text_gen_output["fake_text"]
|
58 |
+
fake_sk = text_gen_output["fake_sk"]
|
59 |
+
bg_gen_output = self.bg_generator.forward(style_input)
|
60 |
+
bg_encode_feature = bg_gen_output["bg_encode_feature"]
|
61 |
+
bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
|
62 |
+
bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
|
63 |
+
fake_bg = bg_gen_output["fake_bg"]
|
64 |
+
|
65 |
+
fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
|
66 |
+
fake_fusion = fusion_gen_output["fake_fusion"]
|
67 |
+
return {
|
68 |
+
"fake_fusion": fake_fusion,
|
69 |
+
"fake_text": fake_text,
|
70 |
+
"fake_sk": fake_sk,
|
71 |
+
"fake_bg": fake_bg,
|
72 |
+
}
|
73 |
+
|
74 |
+
|
75 |
+
class TextGenerator(nn.Layer):
|
76 |
+
def __init__(self, config):
|
77 |
+
super(TextGenerator, self).__init__()
|
78 |
+
name = config["module_name"]
|
79 |
+
encode_dim = config["encode_dim"]
|
80 |
+
norm_layer = config["norm_layer"]
|
81 |
+
conv_block_dropout = config["conv_block_dropout"]
|
82 |
+
conv_block_num = config["conv_block_num"]
|
83 |
+
conv_block_dilation = config["conv_block_dilation"]
|
84 |
+
if norm_layer == "InstanceNorm2D":
|
85 |
+
use_bias = True
|
86 |
+
else:
|
87 |
+
use_bias = False
|
88 |
+
self.encoder_text = Encoder(
|
89 |
+
name=name + "_encoder_text",
|
90 |
+
in_channels=3,
|
91 |
+
encode_dim=encode_dim,
|
92 |
+
use_bias=use_bias,
|
93 |
+
norm_layer=norm_layer,
|
94 |
+
act="ReLU",
|
95 |
+
act_attr=None,
|
96 |
+
conv_block_dropout=conv_block_dropout,
|
97 |
+
conv_block_num=conv_block_num,
|
98 |
+
conv_block_dilation=conv_block_dilation)
|
99 |
+
self.encoder_style = Encoder(
|
100 |
+
name=name + "_encoder_style",
|
101 |
+
in_channels=3,
|
102 |
+
encode_dim=encode_dim,
|
103 |
+
use_bias=use_bias,
|
104 |
+
norm_layer=norm_layer,
|
105 |
+
act="ReLU",
|
106 |
+
act_attr=None,
|
107 |
+
conv_block_dropout=conv_block_dropout,
|
108 |
+
conv_block_num=conv_block_num,
|
109 |
+
conv_block_dilation=conv_block_dilation)
|
110 |
+
self.decoder_text = Decoder(
|
111 |
+
name=name + "_decoder_text",
|
112 |
+
encode_dim=encode_dim,
|
113 |
+
out_channels=int(encode_dim / 2),
|
114 |
+
use_bias=use_bias,
|
115 |
+
norm_layer=norm_layer,
|
116 |
+
act="ReLU",
|
117 |
+
act_attr=None,
|
118 |
+
conv_block_dropout=conv_block_dropout,
|
119 |
+
conv_block_num=conv_block_num,
|
120 |
+
conv_block_dilation=conv_block_dilation,
|
121 |
+
out_conv_act="Tanh",
|
122 |
+
out_conv_act_attr=None)
|
123 |
+
self.decoder_sk = Decoder(
|
124 |
+
name=name + "_decoder_sk",
|
125 |
+
encode_dim=encode_dim,
|
126 |
+
out_channels=1,
|
127 |
+
use_bias=use_bias,
|
128 |
+
norm_layer=norm_layer,
|
129 |
+
act="ReLU",
|
130 |
+
act_attr=None,
|
131 |
+
conv_block_dropout=conv_block_dropout,
|
132 |
+
conv_block_num=conv_block_num,
|
133 |
+
conv_block_dilation=conv_block_dilation,
|
134 |
+
out_conv_act="Sigmoid",
|
135 |
+
out_conv_act_attr=None)
|
136 |
+
|
137 |
+
self.middle = MiddleNet(
|
138 |
+
name=name + "_middle_net",
|
139 |
+
in_channels=int(encode_dim / 2) + 1,
|
140 |
+
mid_channels=encode_dim,
|
141 |
+
out_channels=3,
|
142 |
+
use_bias=use_bias)
|
143 |
+
|
144 |
+
def forward(self, style_input, text_input):
|
145 |
+
style_feature = self.encoder_style.forward(style_input)["res_blocks"]
|
146 |
+
text_feature = self.encoder_text.forward(text_input)["res_blocks"]
|
147 |
+
fake_c_temp = self.decoder_text.forward([text_feature,
|
148 |
+
style_feature])["out_conv"]
|
149 |
+
fake_sk = self.decoder_sk.forward([text_feature,
|
150 |
+
style_feature])["out_conv"]
|
151 |
+
fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
|
152 |
+
return {"fake_sk": fake_sk, "fake_text": fake_text}
|
153 |
+
|
154 |
+
|
155 |
+
class BgGeneratorWithMask(nn.Layer):
|
156 |
+
def __init__(self, config):
|
157 |
+
super(BgGeneratorWithMask, self).__init__()
|
158 |
+
name = config["module_name"]
|
159 |
+
encode_dim = config["encode_dim"]
|
160 |
+
norm_layer = config["norm_layer"]
|
161 |
+
conv_block_dropout = config["conv_block_dropout"]
|
162 |
+
conv_block_num = config["conv_block_num"]
|
163 |
+
conv_block_dilation = config["conv_block_dilation"]
|
164 |
+
self.output_factor = config.get("output_factor", 1.0)
|
165 |
+
|
166 |
+
if norm_layer == "InstanceNorm2D":
|
167 |
+
use_bias = True
|
168 |
+
else:
|
169 |
+
use_bias = False
|
170 |
+
|
171 |
+
self.encoder_bg = Encoder(
|
172 |
+
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)

        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)

        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)

        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        encode_bg_output = self.encoder_bg(style_input)
        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
                                           encode_bg_output["down2"],
                                           encode_bg_output["down1"])

        fake_c_temp = decode_bg_output["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
            "res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat(
                (fake_c_temp, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encode_bg_output["res_blocks"],
            "bg_decode_feature1": decode_bg_output["up1"],
            "bg_decode_feature2": decode_bg_output["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }


class FusionGeneratorSimple(nn.Layer):
    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False

        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)

        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)

        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
        fake_concat_tmp = self._conv(fake_concat)
        output_res = self._res_block(fake_concat_tmp)
        fake_fusion = self._reduce_conv(output_res)
        return {"fake_fusion": fake_fusion}
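The channel bookkeeping in FusionGeneratorSimple can be checked without Paddle. A minimal sketch, with (N, C, H, W) shape tuples standing in for tensors and `concat_shape` as a hypothetical helper, shows why `self._conv` declares `in_channels=6`: the 3-channel `fake_text` and 3-channel `fake_bg` are concatenated along `axis=1`.

```python
# Sketch of the channel arithmetic in FusionGeneratorSimple.forward.
# Concatenation along axis=1 (channels) sums the channel counts; all
# other dimensions must agree.
def concat_shape(a, b, axis=1):
    assert all(x == y for i, (x, y) in enumerate(zip(a, b)) if i != axis)
    out = list(a)
    out[axis] = a[axis] + b[axis]
    return tuple(out)

fake_text = (1, 3, 32, 320)   # RGB text rendering
fake_bg = (1, 3, 32, 320)     # RGB background
fake_concat = concat_shape(fake_text, fake_bg)
assert fake_concat == (1, 6, 32, 320)  # matches in_channels=6 of self._conv
```

The `_reduce_conv` then maps the `encode_dim` feature channels back down to a 3-channel RGB fusion image.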
Rotate/StyleText/configs/config.yml
ADDED
@@ -0,0 +1,54 @@
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
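The Predictor's `scale`/`mean`/`std` values in this config implement the standard [-1, 1] input normalization: `scale` is 1/255, which maps uint8 pixels into [0, 1], and subtracting the 0.5 mean then dividing by the 0.5 std stretches that to [-1, 1]. A quick check of the arithmetic:

```python
# Normalization implied by config.yml's Predictor section.
scale = 0.00392156862745098   # == 1 / 255
mean, std = 0.5, 0.5

def normalize(pixel):
    # same formula as StyleTextRecPredictor.preprocess, per channel
    return (pixel * scale - mean) / std

assert normalize(0) == -1.0
assert abs(normalize(255) - 1.0) < 1e-9
assert abs(scale - 1 / 255) < 1e-12
```

Postprocessing inverts the same formula to recover uint8 pixels.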
Rotate/StyleText/configs/dataset_config.yml
ADDED
@@ -0,0 +1,64 @@
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
  standard_font: fonts/en_standard.ttf
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
StyleSampler:
  method: DatasetSampler
  image_home: examples
  label_file: examples/image_list.txt
  with_label: true
CorpusGenerator:
  method: FileCorpus
  language: ch
  corpus_file: examples/corpus/example.txt
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
Rotate/StyleText/engine/__init__.py
ADDED
File without changes
|
Rotate/StyleText/engine/corpus_generators.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

from utils.logging import get_logger


class FileCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using FileCorpus")

        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        with open(corpus_file, 'r') as f:
            corpus_raw = f.read()
        self.corpus_list = corpus_raw.split("\n")[:-1]
        assert len(self.corpus_list) > 0
        random.shuffle(self.corpus_list)
        self.index = 0

    def generate(self, corpus_length=0):
        if self.index >= len(self.corpus_list):
            self.index = 0
            random.shuffle(self.corpus_list)
        corpus = self.corpus_list[self.index]
        if corpus_length != 0:
            corpus = corpus[0:corpus_length]
        if corpus_length > len(corpus):
            self.logger.warning("generated corpus is shorter than expected.")
        self.index += 1
        return self.language, corpus


class EnNumCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using NumberCorpus")
        self.num_list = "0123456789"
        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.max_width = config["Global"]["image_width"]

    def generate(self, corpus_length=0):
        corpus = ""
        if corpus_length == 0:
            corpus_length = random.randint(5, 15)
        for i in range(corpus_length):
            if random.random() < 0.2:
                corpus += "{}".format(random.choice(self.en_char_list))
            else:
                corpus += "{}".format(random.choice(self.num_list))
        return "en", corpus
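FileCorpus's sampling policy — walk the shuffled corpus in order, reshuffle and restart once the end is reached, optionally truncating each sample — can be sketched in isolation. `RoundRobinCorpus` below is a hypothetical stand-in with a seeded RNG so the behavior is deterministic:

```python
import random

# Sketch of FileCorpus's round-robin-with-reshuffle sampling.
class RoundRobinCorpus:
    def __init__(self, lines, seed=0):
        self.rng = random.Random(seed)
        self.lines = list(lines)
        self.rng.shuffle(self.lines)
        self.index = 0

    def generate(self, corpus_length=0):
        if self.index >= len(self.lines):
            # exhausted one pass: restart with a fresh shuffle
            self.index = 0
            self.rng.shuffle(self.lines)
        corpus = self.lines[self.index]
        if corpus_length != 0:
            corpus = corpus[:corpus_length]
        self.index += 1
        return corpus

sampler = RoundRobinCorpus(["Paddle", "OCR", "StyleText"])
samples = [sampler.generate() for _ in range(6)]
# two full passes: every line is drawn exactly twice
assert sorted(samples) == ["OCR", "OCR", "Paddle", "Paddle",
                           "StyleText", "StyleText"]
```

This guarantees uniform coverage of the corpus file across an epoch, unlike independent random draws.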
Rotate/StyleText/engine/predictors.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle

from arch import style_text_rec
from utils.sys_funcs import check_gpu
from utils.logging import get_logger


class StyleTextRecPredictor(object):
    def __init__(self, config):
        algorithm = config['Predictor']['algorithm']
        assert algorithm in ["StyleTextRec"
                             ], "Generator {} not supported.".format(algorithm)
        use_gpu = config["Global"]['use_gpu']
        check_gpu(use_gpu)
        paddle.set_device('gpu' if use_gpu else 'cpu')
        self.logger = get_logger()
        self.generator = getattr(style_text_rec, algorithm)(config)
        self.height = config["Global"]["image_height"]
        self.width = config["Global"]["image_width"]
        self.scale = config["Predictor"]["scale"]
        self.mean = config["Predictor"]["mean"]
        self.std = config["Predictor"]["std"]
        self.expand_result = config["Predictor"]["expand_result"]

    def reshape_to_same_height(self, img_list):
        h = img_list[0].shape[0]
        for idx in range(1, len(img_list)):
            new_w = round(1.0 * img_list[idx].shape[1] /
                          img_list[idx].shape[0] * h)
            img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
        return img_list

    def predict_single_image(self, style_input, text_input):
        style_input = self.rep_style_input(style_input, text_input)
        tensor_style_input = self.preprocess(style_input)
        tensor_text_input = self.preprocess(text_input)
        style_text_result = self.generator.forward(tensor_style_input,
                                                   tensor_text_input)
        fake_fusion = self.postprocess(style_text_result["fake_fusion"])
        fake_text = self.postprocess(style_text_result["fake_text"])
        fake_sk = self.postprocess(style_text_result["fake_sk"])
        fake_bg = self.postprocess(style_text_result["fake_bg"])
        bbox = self.get_text_boundary(fake_text)
        if bbox:
            left, right, top, bottom = bbox
            fake_fusion = fake_fusion[top:bottom, left:right, :]
            fake_text = fake_text[top:bottom, left:right, :]
            fake_sk = fake_sk[top:bottom, left:right, :]
            fake_bg = fake_bg[top:bottom, left:right, :]

        # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }

    def predict(self, style_input, text_input_list):
        if not isinstance(text_input_list, (tuple, list)):
            return self.predict_single_image(style_input, text_input_list)

        synth_result_list = []
        for text_input in text_input_list:
            synth_result = self.predict_single_image(style_input, text_input)
            synth_result_list.append(synth_result)

        for key in synth_result:
            res = [r[key] for r in synth_result_list]
            res = self.reshape_to_same_height(res)
            synth_result[key] = np.concatenate(res, axis=1)
        return synth_result

    def preprocess(self, img):
        img = (img.astype('float32') * self.scale - self.mean) / self.std
        img_height, img_width, channel = img.shape
        assert channel == 3, "Please use an rgb image."
        ratio = img_width / float(img_height)
        if math.ceil(self.height * ratio) > self.width:
            resized_w = self.width
        else:
            resized_w = int(math.ceil(self.height * ratio))
        img = cv2.resize(img, (resized_w, self.height))

        new_img = np.zeros([self.height, self.width, 3]).astype('float32')
        new_img[:, 0:resized_w, :] = img
        img = new_img.transpose((2, 0, 1))
        img = img[np.newaxis, :, :, :]
        return paddle.to_tensor(img)

    def postprocess(self, tensor):
        img = tensor.numpy()[0]
        img = img.transpose((1, 2, 0))
        img = (img * self.std + self.mean) / self.scale
        img = np.maximum(img, 0.0)
        img = np.minimum(img, 255.0)
        img = img.astype('uint8')
        return img

    def rep_style_input(self, style_input, text_input):
        rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
                      (style_input.shape[1] / style_input.shape[0])) + 1
        style_input = np.tile(style_input, reps=[1, rep_num, 1])
        max_width = int(self.width / self.height * style_input.shape[0])
        style_input = style_input[:, :max_width, :]
        return style_input

    def get_text_boundary(self, text_img):
        img_height = text_img.shape[0]
        img_width = text_img.shape[1]
        bounder = 3
        text_canny_img = cv2.Canny(text_img, 10, 20)
        edge_num_h = text_canny_img.sum(axis=0)
        no_zero_list_h = np.where(edge_num_h > 0)[0]
        edge_num_w = text_canny_img.sum(axis=1)
        no_zero_list_w = np.where(edge_num_w > 0)[0]
        if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
            return None
        left = max(no_zero_list_h[0] - bounder, 0)
        right = min(no_zero_list_h[-1] + bounder, img_width)
        top = max(no_zero_list_w[0] - bounder, 0)
        bottom = min(no_zero_list_w[-1] + bounder, img_height)
        return [left, right, top, bottom]
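The width logic inside `preprocess` is worth isolating: the image is resized to the fixed model height while preserving aspect ratio, and the resulting width is capped at the model's input width (the rest of the canvas stays zero-padded). A self-contained version of just that arithmetic:

```python
import math

# Width computation from StyleTextRecPredictor.preprocess.
def resized_width(img_h, img_w, target_h=32, max_w=320):
    ratio = img_w / float(img_h)
    if math.ceil(target_h * ratio) > max_w:
        return max_w          # too wide: cap at the model's input width
    return int(math.ceil(target_h * ratio))

assert resized_width(64, 640) == 320   # aspect 10:1 hits the 320 cap
assert resized_width(32, 100) == 100   # already at target height: unchanged
assert resized_width(48, 96) == 64     # 32 * (96 / 48) = 64
```

Keeping the aspect ratio here matters because stretching the style image horizontally would distort stroke widths that the generators are trying to imitate.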
Rotate/StyleText/engine/style_samplers.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import cv2


class DatasetSampler(object):
    def __init__(self, config):
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        with open(label_file, "r") as f:
            label_raw = f.read()
            self.path_label_list = label_raw.split("\n")[:-1]
        assert len(self.path_label_list) > 0
        random.shuffle(self.path_label_list)

    def sample(self):
        if self.index >= len(self.path_label_list):
            random.shuffle(self.path_label_list)
            self.index = 0
        if self.dataset_with_label:
            path_label = self.path_label_list[self.index]
            rel_image_path, label = path_label.split('\t')
        else:
            rel_image_path = self.path_label_list[self.index]
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        width = int(image.shape[1] * ratio)
        height = int(image.shape[0] * ratio)
        image = cv2.resize(image, (width, height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}


def duplicate_image(image, width):
    image_width = image.shape[1]
    dup_num = width // image_width + 1
    image = np.tile(image, reps=[1, dup_num, 1])
    cropped_image = image[:, :width, :]
    return cropped_image
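`duplicate_image` tiles the style image horizontally until it covers the requested width and then crops the overhang. The same tile-then-crop arithmetic with a plain 1-D list in place of an image row:

```python
# 1-D analogue of duplicate_image's np.tile + slice.
def duplicate_row(row, width):
    dup_num = width // len(row) + 1   # enough copies to cover width
    return (row * dup_num)[:width]    # crop the overhang

assert duplicate_row([1, 2, 3], 7) == [1, 2, 3, 1, 2, 3, 1]
assert len(duplicate_row([0] * 5, 12)) == 12
```

`dup_num` deliberately overshoots by one copy so the subsequent crop is always valid, even when `width` is an exact multiple of the row length.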
Rotate/StyleText/engine/synthesisers.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import cv2

from utils.config import ArgsParser, load_config, override_config
from utils.logging import get_logger
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers


class ImageSynthesiser(object):
    def __init__(self):
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        corpus_list, text_input_list = self.text_drawer.draw_text(
            corpus, language, style_input_width=style_input.shape[1])
        synth_result = self.predictor.predict(style_input, text_input_list)
        return synth_result


class DatasetSynthesiser(ImageSynthesiser):
    def __init__(self):
        super(DatasetSynthesiser, self).__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]
        corpus_generator_method = self.config["CorpusGenerator"]["method"]
        self.corpus_generator = getattr(corpus_generators,
                                        corpus_generator_method)(self.config)

        style_sampler_method = self.config["StyleSampler"]["method"]
        assert style_sampler_method is not None
        self.style_sampler = style_samplers.DatasetSampler(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        for i in range(self.output_num):
            style_data = self.style_sampler.sample()
            style_input = style_data["image"]
            corpus_language, text_input_label = self.corpus_generator.generate()
            text_input_label_list, text_input_list = self.text_drawer.draw_text(
                text_input_label,
                corpus_language,
                style_input_width=style_input.shape[1])

            text_input_label = "".join(text_input_label_list)

            synth_result = self.predictor.predict(style_input, text_input_list)
            fake_fusion = synth_result["fake_fusion"]
            self.writer.save_image(fake_fusion, text_input_label)
        self.writer.save_label()
        self.writer.merge_label()
Rotate/StyleText/engine/text_drawers.py
ADDED
@@ -0,0 +1,85 @@
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
from utils.logging import get_logger


class StdTextDrawer(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.max_width = config["Global"]["image_width"]
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    def load_fonts(self, fonts_config):
        for language in fonts_config:
            font_path = fonts_config[language]
            font_height = self.get_valid_height(font_path)
            font = ImageFont.truetype(font_path, font_height)
            self.font_dict[language] = font

    def get_valid_height(self, font_path):
        font = ImageFont.truetype(font_path, self.height - 4)
        left, top, right, bottom = font.getbbox(self.char_list)
        _, font_height = right - left, bottom - top
        if font_height <= self.height - 4:
            return self.height - 4
        else:
            return int((self.height - 4)**2 / font_height)

    def draw_text(self,
                  corpus,
                  language="en",
                  crop=True,
                  style_input_width=None):
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"
        if crop:
            width = min(self.max_width, len(corpus) * self.height) + 4
        else:
            width = len(corpus) * self.height + 4

        if style_input_width is not None:
            width = min(width, style_input_width)

        corpus_list = []
        text_input_list = []

        while len(corpus) != 0:
            bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
            draw = ImageDraw.Draw(bg)
            char_x = 2
            font = self.font_dict[language]
            i = 0
            while i < len(corpus):
                char_i = corpus[i]
                char_size = font.getsize(char_i)[0]
                # split when char_x would exceed the canvas width; index 0 is
                # exempt so at least one char is written on each image
                if char_x + char_size >= width and i != 0:
                    text_input = np.array(bg).astype(np.uint8)
                    text_input = text_input[:, 0:char_x, :]

                    corpus_list.append(corpus[0:i])
                    text_input_list.append(text_input)
                    corpus = corpus[i:]
                    i = 0
                    break
                draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
                char_x += char_size

                i += 1
            # the whole remaining text is shorter than the style input
            if i == len(corpus):
                text_input = np.array(bg).astype(np.uint8)
                text_input = text_input[:, 0:char_x, :]

                corpus_list.append(corpus[0:i])
                text_input_list.append(text_input)
                break

        return corpus_list, text_input_list
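The splitting loop in `draw_text` is a greedy line-breaker: accumulate per-character widths until the next character would overflow the canvas, emit the chunk drawn so far, and continue with the remaining text. A minimal sketch, with a constant `char_width` standing in for `font.getsize(c)[0]` and `margin` for the 2-pixel left offset:

```python
# Greedy chunking as in StdTextDrawer.draw_text, text-only.
def split_corpus(corpus, width, char_width=16, margin=2):
    chunks = []
    while corpus:
        x, i = margin, 0
        while i < len(corpus):
            # break before overflow, but always take at least one char
            if x + char_width >= width and i != 0:
                break
            x += char_width
            i += 1
        chunks.append(corpus[:i])
        corpus = corpus[i:]
    return chunks

assert split_corpus("abcdefgh", width=50) == ["ab", "cd", "ef", "gh"]
assert split_corpus("ab", width=320) == ["ab"]
```

The "at least one char" exemption prevents an infinite loop on canvases narrower than a single glyph, at the cost of letting that one glyph overflow.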
Rotate/StyleText/engine/writers.py
ADDED
@@ -0,0 +1,71 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob

from utils.logging import get_logger


class SimpleWriter(object):
    def __init__(self, config, tag):
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        image_home = os.path.join(self.output_dir, "images", self.tag)
        if not os.path.exists(image_home):
            os.makedirs(image_home)

        image_path = os.path.join(image_home, "{}.png".format(self.counter))
        # todo support continue synth
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        image_name = os.path.join(self.tag, "{}.png".format(self.counter))
        self.label_dict[image_name] = text_input_label

        self.counter += 1
        if not self.counter % 100:
            self.save_label()

    def save_label(self):
        label_raw = ""
        label_home = os.path.join(self.output_dir, "label")
        if not os.path.exists(label_home):
            os.mkdir(label_home)
        for image_path in self.label_dict:
            label = self.label_dict[image_path]
            label_raw += "{}\t{}\n".format(image_path, label)
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        with open(label_file_path, "w") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        label_raw = ""
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        label_file_list = glob.glob(label_file_regex)
        for label_file_i in label_file_list:
            with open(label_file_i, "r") as f:
                label_raw += f.read()
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w") as f:
            f.write(label_raw)
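SimpleWriter emits one `<relative image path>\t<text>` line per sample, the same tab-separated layout PaddleOCR's recognition label files use. The string-building part of `save_label`, extracted into a hypothetical standalone helper:

```python
# Label-file serialization as in SimpleWriter.save_label.
def build_label_file(label_dict):
    return "".join("{}\t{}\n".format(path, text)
                   for path, text in label_dict.items())

labels = {"synth/0.png": "Paddle", "synth/1.png": "OCR"}
raw = build_label_file(labels)
assert raw == "synth/0.png\tPaddle\nsynth/1.png\tOCR\n"
```

Because the separator is a tab, DatasetSampler can read the merged `label.txt` back with a simple `line.split('\t')`, closing the loop between synthesis output and training input.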
Rotate/StyleText/examples/corpus/example.txt
ADDED
@@ -0,0 +1,2 @@
Paddle
飞桨文字识别
Rotate/StyleText/examples/image_list.txt
ADDED
@@ -0,0 +1,2 @@
style_images/1.jpg	NEATNESS
style_images/2.jpg	锁店君和宾馆
Rotate/StyleText/tools/__init__.py
ADDED
File without changes
|
Rotate/StyleText/tools/synth_dataset.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from engine.synthesisers import DatasetSynthesiser


def synth_dataset():
    dataset_synthesiser = DatasetSynthesiser()
    dataset_synthesiser.synth_dataset()


if __name__ == '__main__':
    synth_dataset()
Rotate/StyleText/tools/synth_image.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import cv2
+import sys
+import glob
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from utils.config import ArgsParser
+from engine.synthesisers import ImageSynthesiser
+
+
+def synth_image():
+    args = ArgsParser().parse_args()
+    image_synthesiser = ImageSynthesiser()
+    style_image_path = args.style_image
+    img = cv2.imread(style_image_path)
+    text_corpus = args.text_corpus
+    language = args.language
+
+    synth_result = image_synthesiser.synth_image(text_corpus, img, language)
+    fake_fusion = synth_result["fake_fusion"]
+    fake_text = synth_result["fake_text"]
+    fake_bg = synth_result["fake_bg"]
+    cv2.imwrite("fake_fusion.jpg", fake_fusion)
+    cv2.imwrite("fake_text.jpg", fake_text)
+    cv2.imwrite("fake_bg.jpg", fake_bg)
+
+
+def batch_synth_images():
+    image_synthesiser = ImageSynthesiser()
+
+    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
+    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
+    save_path = "./output_data/"
+    corpus_list = []
+    with open(corpus_file, "rb") as fin:
+        lines = fin.readlines()
+        for line in lines:
+            substr = line.decode("utf-8").strip("\n").split("\t")
+            corpus_list.append(substr)
+    style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
+    corpus_num = len(corpus_list)
+    style_img_num = len(style_img_list)
+    for cno in range(corpus_num):
+        for sno in range(style_img_num):
+            corpus, lang = corpus_list[cno]
+            style_img_path = style_img_list[sno]
+            img = cv2.imread(style_img_path)
+            synth_result = image_synthesiser.synth_image(corpus, img, lang)
+            fake_fusion = synth_result["fake_fusion"]
+            fake_text = synth_result["fake_text"]
+            fake_bg = synth_result["fake_bg"]
+            for tp in range(2):
+                if tp == 0:
+                    prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
+                else:
+                    prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
+                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
+                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
+                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
+                cv2.imwrite("%s_input_style.jpg" % prefix, img)
+            print(cno, corpus_num, sno, style_img_num)
+
+
+if __name__ == '__main__':
+    # batch_synth_images()
+    synth_image()
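The corpus file read in `batch_synth_images` above is a UTF-8 text file with one `text<TAB>language` pair per line. A minimal standalone sketch of that parsing step (the file name and contents here are illustrative, not from the repository's data):

```python
import os
import tempfile

def load_corpus_list(corpus_file):
    """Parse a UTF-8 corpus file with one `text<TAB>language` pair per line."""
    corpus_list = []
    with open(corpus_file, "rb") as fin:
        for line in fin.readlines():
            # Same decode/strip/split chain as batch_synth_images.
            substr = line.decode("utf-8").strip("\n").split("\t")
            corpus_list.append(substr)
    return corpus_list

# Usage: write a tiny example corpus and parse it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "corpus.txt")
    with open(path, "wb") as f:
        f.write("PaddleOCR\ten\n".encode("utf-8"))
        f.write("样式文本\tch\n".encode("utf-8"))
    pairs = load_corpus_list(path)
```

Each parsed pair then unpacks as `corpus, lang = pairs[i]`, matching the loop body above.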
Rotate/StyleText/utils/__init__.py
ADDED
File without changes
Rotate/StyleText/utils/config.py
ADDED
@@ -0,0 +1,224 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+import os
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+
+def override(dl, ks, v):
+    """
+    Recursively replace a value inside a nested dict/list.
+
+    Args:
+        dl(dict or list): dict or list to be replaced
+        ks(list): list of keys
+        v(str): value to be replaced
+    """
+
+    def str2num(v):
+        try:
+            return eval(v)
+        except Exception:
+            return v
+
+    assert isinstance(dl, (list, dict)), ("{} should be a list or a dict".format(dl))
+    assert len(ks) > 0, ('length of keys should be larger than 0')
+    if isinstance(dl, list):
+        k = str2num(ks[0])
+        if len(ks) == 1:
+            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
+            dl[k] = str2num(v)
+        else:
+            override(dl[k], ks[1:], v)
+    else:
+        if len(ks) == 1:
+            # assert ks[0] in dl, ('{} does not exist in {}'.format(ks[0], dl))
+            if not ks[0] in dl:
+                print('A new field ({}) detected!'.format(ks[0]))
+            dl[ks[0]] = str2num(v)
+        else:
+            assert ks[0] in dl, (
+                '({}) doesn\'t exist in {}, a new dict field is invalid'.
+                format(ks[0], dl))
+            override(dl[ks[0]], ks[1:], v)
+
+
+def override_config(config, options=None):
+    """
+    Recursively override the config
+
+    Args:
+        config(dict): dict to be replaced
+        options(list): list of pairs(key0.key1.idx.key2=value)
+            such as: [
+                'topk=2',
+                'VALID.transforms.1.ResizeImage.resize_short=300'
+            ]
+
+    Returns:
+        config(dict): replaced config
+    """
+    if options is not None:
+        for opt in options:
+            assert isinstance(opt, str), (
+                "option({}) should be a str".format(opt))
+            assert "=" in opt, (
+                "option({}) should contain a = "
+                "to distinguish between key and value".format(opt))
+            pair = opt.split('=')
+            assert len(pair) == 2, ("there can be only one = in the option")
+            key, value = pair
+            keys = key.split('.')
+            override(config, keys, value)
+
+    return config
+
+
+class ArgsParser(ArgumentParser):
+    def __init__(self):
+        super(ArgsParser, self).__init__(
+            formatter_class=RawDescriptionHelpFormatter)
+        self.add_argument("-c", "--config", help="configuration file to use")
+        self.add_argument(
+            "-t", "--tag", default="0", help="tag for marking worker")
+        self.add_argument(
+            '-o',
+            '--override',
+            action='append',
+            default=[],
+            help='config options to be overridden')
+        self.add_argument(
+            "--style_image", default="examples/style_images/1.jpg", help="path of the style image")
+        self.add_argument(
+            "--text_corpus", default="PaddleOCR", help="text corpus to synthesise")
+        self.add_argument(
+            "--language", default="en", help="language of the text corpus")
+
+    def parse_args(self, argv=None):
+        args = super(ArgsParser, self).parse_args(argv)
+        assert args.config is not None, \
+            "Please specify --config=configure_file_path."
+        return args
+
+
+def load_config(file_path):
+    """
+    Load config from yml/yaml file.
+    Args:
+        file_path (str): Path of the config file to be loaded.
+    Returns: config
+    """
+    ext = os.path.splitext(file_path)[1]
+    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+    with open(file_path, 'rb') as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    return config
+
+
+def gen_config():
+    base_config = {
+        "Global": {
+            "algorithm": "SRNet",
+            "use_gpu": True,
+            "start_epoch": 1,
+            "stage1_epoch_num": 100,
+            "stage2_epoch_num": 100,
+            "log_smooth_window": 20,
+            "print_batch_step": 2,
+            "save_model_dir": "./output/SRNet",
+            "use_visualdl": False,
+            "save_epoch_step": 10,
+            "vgg_pretrain": "./pretrained/VGG19_pretrained",
+            "vgg_load_static_pretrain": True
+        },
+        "Architecture": {
+            "model_type": "data_aug",
+            "algorithm": "SRNet",
+            "net_g": {
+                "name": "srnet_net_g",
+                "encode_dim": 64,
+                "norm": "batch",
+                "use_dropout": False,
+                "init_type": "xavier",
+                "init_gain": 0.02,
+                "use_dilation": 1
+            },
+            # input_nc, ndf, netD,
+            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
+            "bg_discriminator": {
+                "name": "srnet_bg_discriminator",
+                "input_nc": 6,
+                "ndf": 64,
+                "netD": "basic",
+                "norm": "none",
+                "init_type": "xavier",
+            },
+            "fusion_discriminator": {
+                "name": "srnet_fusion_discriminator",
+                "input_nc": 6,
+                "ndf": 64,
+                "netD": "basic",
+                "norm": "none",
+                "init_type": "xavier",
+            }
+        },
+        "Loss": {
+            "lamb": 10,
+            "perceptual_lamb": 1,
+            "muvar_lamb": 50,
+            "style_lamb": 500
+        },
+        "Optimizer": {
+            "name": "Adam",
+            "learning_rate": {
+                "name": "lambda",
+                "lr": 0.0002,
+                "lr_decay_iters": 50
+            },
+            "beta1": 0.5,
+            "beta2": 0.999,
+        },
+        "Train": {
+            "batch_size_per_card": 8,
+            "num_workers_per_card": 4,
+            "dataset": {
+                "delimiter": "\t",
+                "data_dir": "/",
+                "label_file": "tmp/label.txt",
+                "transforms": [{
+                    "DecodeImage": {
+                        "to_rgb": True,
+                        "to_np": False,
+                        "channel_first": False
+                    }
+                }, {
+                    "NormalizeImage": {
+                        "scale": 1. / 255.,
+                        "mean": [0.485, 0.456, 0.406],
+                        "std": [0.229, 0.224, 0.225],
+                        "order": None
+                    }
+                }, {
+                    "ToCHWImage": None
+                }]
+            }
+        }
+    }
+    with open("config.yml", "w") as f:
+        yaml.dump(base_config, f)
+
+
+if __name__ == '__main__':
+    gen_config()
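The `-o key0.key1=value` overrides handled by `override_config` above resolve dotted paths into nested in-place updates. A minimal self-contained sketch of the same idea (simplified and not the module's API: no list indices, and `int`/`float` coercion instead of `eval`):

```python
def apply_override(config, option):
    """Apply one 'a.b.c=value' style override to a nested dict in place."""
    key, value = option.split("=", 1)
    keys = key.split(".")
    node = config
    for k in keys[:-1]:       # walk down to the parent of the target key
        node = node[k]
    # Coerce numeric-looking strings, keep everything else as text.
    try:
        node[keys[-1]] = int(value)
    except ValueError:
        try:
            node[keys[-1]] = float(value)
        except ValueError:
            node[keys[-1]] = value

cfg = {"Global": {"use_gpu": True, "epoch_num": 10}}
apply_override(cfg, "Global.epoch_num=100")
apply_override(cfg, "Global.save_model_dir=./output")
```

The real `override` additionally supports list indices in the dotted path (e.g. `VALID.transforms.1.ResizeImage.resize_short=300`), which this sketch omits.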
Rotate/StyleText/utils/load_params.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import paddle
+
+__all__ = ['load_dygraph_pretrain']
+
+
+def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
+    if not os.path.exists(path + '.pdparams'):
+        raise ValueError("Model pretrain path {} does not "
+                         "exist.".format(path))
+    param_state_dict = paddle.load(path + '.pdparams')
+    model.set_state_dict(param_state_dict)
+    logger.info("load pretrained model from {}".format(path))
+    return
Rotate/StyleText/utils/logging.py
ADDED
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import logging
+import functools
+import paddle.distributed as dist
+
+logger_initialized = {}
+
+
+@functools.lru_cache()
+def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers; otherwise the initialized logger
+    will be returned directly. During initialization, a StreamHandler is
+    always added. If `log_file` is specified, a FileHandler is also added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, while other processes set the level to
+            "ERROR" and are thus silent most of the time.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    logger = logging.getLogger(name)
+    if name in logger_initialized:
+        return logger
+    for logger_name in logger_initialized:
+        if name.startswith(logger_name):
+            return logger
+
+    formatter = logging.Formatter(
+        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S")
+
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setFormatter(formatter)
+    logger.addHandler(stream_handler)
+    if log_file is not None and dist.get_rank() == 0:
+        log_file_folder = os.path.split(log_file)[0]
+        os.makedirs(log_file_folder, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, 'a')
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+    if dist.get_rank() == 0:
+        logger.setLevel(log_level)
+    else:
+        logger.setLevel(logging.ERROR)
+    logger_initialized[name] = True
+    return logger
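The `logger_initialized` dict plus `functools.lru_cache` above together ensure each named logger is configured exactly once, so repeated calls never stack duplicate handlers. The same single-initialization behavior can be sketched with the standard library alone (rank handling and `paddle.distributed` omitted; names here are illustrative):

```python
import functools
import logging
import sys

@functools.lru_cache()
def get_demo_logger(name="srnet", log_level=logging.INFO):
    """Return a logger with one stream handler, configuring it only once."""
    logger = logging.getLogger(name)
    if logger.handlers:          # already initialized on a previous call
        return logger
    formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S")
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(log_level)
    return logger

log_a = get_demo_logger("demo")
log_b = get_demo_logger("demo")   # cached: same object, no second handler
```

Without the cache (or the `logger_initialized` bookkeeping), every call would attach a fresh handler and each message would be printed multiple times.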
Rotate/StyleText/utils/math_functions.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+
+def compute_mean_covariance(img):
+    batch_size = img.shape[0]
+    channel_num = img.shape[1]
+    height = img.shape[2]
+    width = img.shape[3]
+    num_pixels = height * width
+
+    # batch_size * channel_num * 1 * 1
+    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
+
+    # batch_size * channel_num * num_pixels
+    img_hat = img - mu.expand_as(img)
+    img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
+    # batch_size * num_pixels * channel_num
+    img_hat_transpose = img_hat.transpose([0, 2, 1])
+    # batch_size * channel_num * channel_num
+    covariance = paddle.bmm(img_hat, img_hat_transpose)
+    covariance = covariance / num_pixels
+
+    return mu, covariance
+
+
+def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
+    eps = 1e-5
+    intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
+    union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
+        y_pred_cls * training_mask) + eps
+    loss = 1. - (2 * intersection / union)
+    return loss
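`dice_coefficient` above computes the loss `1 - 2|A∩B| / (|A| + |B| + eps)` over the pixels kept by `training_mask`. A paddle-free sketch on flat Python lists shows the same arithmetic:

```python
def dice_loss(y_true, y_pred, mask, eps=1e-5):
    """Dice loss over flat lists; `mask` zeroes out ignored pixels."""
    inter = sum(t * p * m for t, p, m in zip(y_true, y_pred, mask))
    union = sum(t * m for t, m in zip(y_true, mask)) + \
            sum(p * m for p, m in zip(y_pred, mask)) + eps
    return 1.0 - 2.0 * inter / union

# Perfect prediction -> loss near 0; disjoint prediction -> loss of exactly 1.
good = dice_loss([1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1])
bad = dice_loss([1, 1, 0, 0], [0, 0, 1, 1], [1, 1, 1, 1])
```

The `eps` term keeps the division defined when both masks are empty, at the cost of the "perfect" loss being a tiny positive number rather than exactly zero.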
Rotate/StyleText/utils/sys_funcs.py
ADDED
@@ -0,0 +1,67 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import errno
+import paddle
+
+
+def get_check_global_params(mode):
+    check_params = [
+        'use_gpu', 'max_text_length', 'image_shape',
+        'character_type', 'loss_type'
+    ]
+    if mode == "train_eval":
+        check_params = check_params + [
+            'train_batch_size_per_card', 'test_batch_size_per_card'
+        ]
+    elif mode == "test":
+        check_params = check_params + ['test_batch_size_per_card']
+    return check_params
+
+
+def check_gpu(use_gpu):
+    """
+    Log an error and exit when use_gpu is set to true under the
+    CPU-only build of paddlepaddle.
+    """
+    err = "Config use_gpu cannot be set as true while you are " \
+          "using paddlepaddle cpu version! \nPlease try: \n" \
+          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+          "\t2. Set use_gpu as false in config file to run " \
+          "model on CPU"
+    if use_gpu:
+        try:
+            if not paddle.is_compiled_with_cuda():
+                print(err)
+                sys.exit(1)
+        except Exception:
+            print("Fail to check gpu state.")
+            sys.exit(1)
+
+
+def _mkdir_if_not_exist(path, logger):
+    """
+    mkdir if the path does not exist; ignore the exception raised when
+    multiple processes mkdir at the same time
+    """
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    'be happy if some process has already created {}'.format(
+                        path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))
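`_mkdir_if_not_exist` above guards against the race where several worker processes create the same directory between the `exists` check and `makedirs`. Since Python 3.2 the same behavior is available as a one-liner via `os.makedirs(..., exist_ok=True)`; a sketch of both forms (function name here is illustrative):

```python
import errno
import os
import tempfile

def mkdir_if_not_exist(path):
    """Create `path`, ignoring the EEXIST race when workers collide."""
    try:
        os.makedirs(path)
    except OSError as e:
        # Re-raise anything that is not "already exists as a directory".
        if not (e.errno == errno.EEXIST and os.path.isdir(path)):
            raise

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "output", "SRNet")
    mkdir_if_not_exist(target)
    mkdir_if_not_exist(target)          # second call is a harmless no-op
    os.makedirs(target, exist_ok=True)  # modern one-liner equivalent
    created = os.path.isdir(target)
```

The errno-checking form survives on very old interpreters; new code would normally just use `exist_ok=True`.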
Rotate/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddleocr
+# from .paddleocr import *
+
+# __version__ = paddleocr.VERSION
+# __all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
Rotate/ch_PP-OCRv4_det_infer/inference.pdiparams.info
ADDED
Binary file (23.6 kB)
Rotate/ch_PP-OCRv4_det_infer/inference.pdmodel
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ad68ed2768fe6c41166a5bc64680cc9f445390acb6528da449a4db2f7b90e14
+size 166367
Rotate/ch_ppocr_mobile_v2.0_cls_infer/._inference.pdmodel
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d86f5afbfb8cd933a1d0dbbfd8ff2b93ca3eacc6c45f4590a4a2ee107047f6d2
+size 176
Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1efda1b80e174b4fcb168a035ac96c1af4938892bd86a55f300a6027105d08c
+size 539978
Rotate/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info
ADDED
Binary file (18.5 kB)
Rotate/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml
ADDED
@@ -0,0 +1,98 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec_ppocr_v3_rotnet
+  save_epoch_step: 3
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: true
+  pretrained_model: null
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  max_text_length: 25
+  infer_mode: false
+  use_space_char: true
+  save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+  regularizer:
+    name: L2
+    factor: 1.0e-05
+Architecture:
+  model_type: cls
+  algorithm: CLS
+  Transform: null
+  Backbone:
+    name: MobileNetV1Enhance
+    scale: 0.5
+    last_conv_stride: [1, 2]
+    last_pool_type: avg
+  Neck:
+  Head:
+    name: ClsHead
+    class_dim: 4
+
+Loss:
+  name: ClsLoss
+  main_indicator: acc
+
+PostProcess:
+  name: ClsPostProcess
+
+Metric:
+  name: ClsMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - BaseDataAugmentation:
+    - RandAugment:
+    - SSLRotateResize:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys: ["image", "label"]
+  loader:
+    collate_fn: "SSLRotateCollate"
+    shuffle: true
+    batch_size_per_card: 32
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - SSLRotateResize:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys: ["image", "label"]
+  loader:
+    collate_fn: "SSLRotateCollate"
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 64
+    num_workers: 8
+profiler_options: null
Rotate/configs/cls/cls_mv3.yml
ADDED
@@ -0,0 +1,94 @@
+Global:
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/cls/mv3/
+  save_epoch_step: 3
+  # evaluation is run every 1000 iterations
+  eval_batch_step: [0, 1000]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  label_list: ['0','180']
+
+Architecture:
+  model_type: cls
+  algorithm: CLS
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.35
+    model_name: small
+  Neck:
+  Head:
+    name: ClsHead
+    class_dim: 2
+
+Loss:
+  name: ClsLoss
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: ClsPostProcess
+
+Metric:
+  name: ClsMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/cls
+    label_file_list:
+    - ./train_data/cls/train.txt
+    transforms:
+    - DecodeImage: # load image
+        img_mode: BGR
+        channel_first: False
+    - ClsLabelEncode: # Class handling label
+    - BaseDataAugmentation:
+    - RandAugment:
+    - ClsResizeImg:
+        image_shape: [3, 48, 192]
+    - KeepKeys:
+        keep_keys: ['image', 'label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 512
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/cls
+    label_file_list:
+    - ./train_data/cls/test.txt
+    transforms:
+    - DecodeImage: # load image
+        img_mode: BGR
+        channel_first: False
+    - ClsLabelEncode: # Class handling label
+    - ClsResizeImg:
+        image_shape: [3, 48, 192]
+    - KeepKeys:
+        keep_keys: ['image', 'label'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 512
+    num_workers: 4
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
ADDED
@@ -0,0 +1,206 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_mv3/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+  use_amp: False
+  amp_level: O2
+  amp_dtype: bfloat16
+
+Architecture:
+  name: DistillationModel
+  algorithm: Distillation
+  model_type: det
+  Models:
+    Teacher:
+      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+      freeze_params: true
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Transform:
+      Backbone:
+        name: ResNet_vd
+        layers: 18
+      Neck:
+        name: DBFPN
+        out_channels: 256
+      Head:
+        name: DBHead
+        k: 50
+    Student:
+      pretrained:
+      freeze_params: false
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Backbone:
+        name: MobileNetV3
+        scale: 0.5
+        model_name: large
+        disable_se: True
+      Neck:
+        name: DBFPN
+        out_channels: 96
+      Head:
+        name: DBHead
+        k: 50
+    Student2:
+      pretrained:
+      freeze_params: false
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Transform:
+      Backbone:
+        name: MobileNetV3
+        scale: 0.5
+        model_name: large
+        disable_se: True
+      Neck:
+        name: DBFPN
+        out_channels: 96
+      Head:
+        name: DBHead
+        k: 50
+
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationDilaDBLoss:
+      weight: 1.0
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      - ["Student2", "Teacher"]
+      key: maps
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+  - DistillationDMLLoss:
+      model_name_pairs:
+      - ["Student", "Student2"]
+      maps_name: "thrink_maps"
+      weight: 1.0
+      # act: None
+      model_name_pairs: ["Student", "Student2"]
+      key: maps
+  - DistillationDBLoss:
+      weight: 1.0
+      model_name_list: ["Student", "Student2"]
+      # key: maps
+      # name: DBLoss
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DistillationDBPostProcess
+  model_name: ["Student", "Student2", "Teacher"]
+  # key: maps
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage: # load image
+        img_mode: BGR
+        channel_first: False
+    - DetLabelEncode: # Class handling label
+    - CopyPaste:
|
152 |
+
- IaaAugment:
|
153 |
+
augmenter_args:
|
154 |
+
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
155 |
+
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
156 |
+
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
157 |
+
- EastRandomCropData:
|
158 |
+
size: [960, 960]
|
159 |
+
max_tries: 50
|
160 |
+
keep_ratio: true
|
161 |
+
- MakeBorderMap:
|
162 |
+
shrink_ratio: 0.4
|
163 |
+
thresh_min: 0.3
|
164 |
+
thresh_max: 0.7
|
165 |
+
- MakeShrinkMap:
|
166 |
+
shrink_ratio: 0.4
|
167 |
+
min_text_size: 8
|
168 |
+
- NormalizeImage:
|
169 |
+
scale: 1./255.
|
170 |
+
mean: [0.485, 0.456, 0.406]
|
171 |
+
std: [0.229, 0.224, 0.225]
|
172 |
+
order: 'hwc'
|
173 |
+
- ToCHWImage:
|
174 |
+
- KeepKeys:
|
175 |
+
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
176 |
+
loader:
|
177 |
+
shuffle: True
|
178 |
+
drop_last: False
|
179 |
+
batch_size_per_card: 8
|
180 |
+
num_workers: 4
|
181 |
+
|
182 |
+
Eval:
|
183 |
+
dataset:
|
184 |
+
name: SimpleDataSet
|
185 |
+
data_dir: ./train_data/icdar2015/text_localization/
|
186 |
+
label_file_list:
|
187 |
+
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
188 |
+
transforms:
|
189 |
+
- DecodeImage: # load image
|
190 |
+
img_mode: BGR
|
191 |
+
channel_first: False
|
192 |
+
- DetLabelEncode: # Class handling label
|
193 |
+
- DetResizeForTest:
|
194 |
+
- NormalizeImage:
|
195 |
+
scale: 1./255.
|
196 |
+
mean: [0.485, 0.456, 0.406]
|
197 |
+
std: [0.229, 0.224, 0.225]
|
198 |
+
order: 'hwc'
|
199 |
+
- ToCHWImage:
|
200 |
+
- KeepKeys:
|
201 |
+
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
202 |
+
loader:
|
203 |
+
shuffle: False
|
204 |
+
drop_last: False
|
205 |
+
batch_size_per_card: 1 # must be 1
|
206 |
+
num_workers: 2
|
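The Optimizer block above pairs a cosine learning-rate decay with a two-epoch linear warmup (`learning_rate: 0.001`, `warmup_epoch: 2`). As a rough illustration of that schedule's shape (a minimal sketch, not PaddleOCR's actual scheduler; `lr_at` is a hypothetical helper):

```python
import math

def lr_at(epoch, total_epochs=1200, base_lr=0.001, warmup_epochs=2):
    """Sketch of a warmup + cosine schedule: linear ramp-up for
    `warmup_epochs`, then cosine decay toward zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs  # linear warmup
    # cosine decay over the remaining epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

print(lr_at(0))     # mid-warmup: 0.0005
print(lr_at(2))     # start of cosine phase: full base_lr
print(lr_at(1199))  # near the end: close to 0
```

The real scheduler steps per iteration rather than per epoch, but the envelope is the same.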
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml
ADDED
@@ -0,0 +1,175 @@
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet_vd
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student"]
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
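Several of these configs set `eval_batch_step: [3000, 2000]`: start evaluating once training reaches iteration 3000, then evaluate every 2000 iterations. A minimal sketch of that interpretation (`should_evaluate` is a hypothetical helper, not a PaddleOCR API):

```python
def should_evaluate(global_step, eval_batch_step=(3000, 2000)):
    """eval_batch_step = [start_step, interval]: evaluate every
    `interval` iterations once `start_step` has been reached."""
    start, interval = eval_batch_step
    return global_step >= start and (global_step - start) % interval == 0

steps = [s for s in range(10001) if should_evaluate(s)]
print(steps)  # [3000, 5000, 7000, 9000]
```

With `[0, 400]`, as in the student configs, evaluation simply runs every 400 iterations from the start.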
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml
ADDED
@@ -0,0 +1,178 @@
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Teacher"]
      maps_name: "thrink_maps"
      weight: 1.0
      # act: None
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      # key: maps
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student", "Teacher"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
Rotate/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml
ADDED
@@ -0,0 +1,132 @@
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 400 iterations
  eval_batch_step: [0, 400]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/student.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
    disable_se: True
  Neck:
    name: DBFPN
    out_channels: 96
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
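The `PostProcess` blocks all share `thresh: 0.3`: the first stage of DB-style post-processing binarizes the predicted shrink probability map at that threshold, and only then are contours extracted, scored against `box_thresh`, and expanded by `unclip_ratio`. A toy sketch of just the binarization step (pure Python, hypothetical `binarize` helper, not the real DBPostProcess):

```python
def binarize(prob_map, thresh=0.3):
    """Threshold a 2-D probability map into a 0/1 text mask,
    as the first step of DB post-processing (thresh: 0.3)."""
    return [[1 if p > thresh else 0 for p in row] for row in prob_map]

prob = [[0.1, 0.4], [0.8, 0.2]]
print(binarize(prob))  # [[0, 1], [1, 0]]
```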
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
ADDED
@@ -0,0 +1,226 @@
Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/ch_PP-OCR_v3_det/
  save_epoch_step: 100
  eval_batch_step:
  - 0
  - 400
  cal_metric_during_train: false
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true
  d2s_train_image_shape: [3, -1, -1]
  amp_dtype: bfloat16

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Student:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    Teacher:
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: ResNet_vd
        in_channels: 3
        layers: 50
      Neck:
        name: LKPAN
        out_channels: 256
      Head:
        name: DBHead
        kernel_list: [7,2,2]
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste:
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 960
        - 960
        max_tries: 50
        keep_ratio: true
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 2
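Every pipeline above ends with the same `NormalizeImage` step: scale pixel values to [0, 1] with `1./255.`, then standardize per channel with the ImageNet mean/std, in HWC order. A minimal sketch of the per-pixel arithmetic (hypothetical `normalize_pixel` helper, not the PaddleOCR operator itself):

```python
def normalize_pixel(v, mean, std, scale=1.0 / 255.0):
    """NormalizeImage as configured: x = (v * scale - mean) / std,
    applied channel-wise."""
    return (v * scale - mean) / std

# one pixel, channel by channel (mean/std from the config)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
px = [128, 128, 128]
out = [normalize_pixel(v, m, s) for v, m, s in zip(px, mean, std)]
print([round(x, 3) for x in out])  # [0.074, 0.205, 0.426]
```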
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml
ADDED
@@ -0,0 +1,173 @@
1 |
+
Global:
|
2 |
+
use_gpu: true
|
3 |
+
epoch_num: 1200
|
4 |
+
log_smooth_window: 20
|
5 |
+
print_batch_step: 2
|
6 |
+
save_model_dir: ./output/ch_db_mv3/
|
7 |
+
save_epoch_step: 1200
|
8 |
+
# evaluation is run every 5000 iterations after the 4000th iteration
|
9 |
+
eval_batch_step: [3000, 2000]
|
10 |
+
cal_metric_during_train: False
|
11 |
+
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
12 |
+
checkpoints:
|
13 |
+
save_inference_dir:
|
14 |
+
use_visualdl: False
|
15 |
+
infer_img: doc/imgs_en/img_10.jpg
|
16 |
+
save_res_path: ./output/det_db/predicts_db.txt
|
17 |
+
|
18 |
+
Architecture:
|
19 |
+
name: DistillationModel
|
20 |
+
algorithm: Distillation
|
21 |
+
model_type: det
|
22 |
+
Models:
|
23 |
+
Student:
|
24 |
+
return_all_feats: false
|
25 |
+
model_type: det
|
26 |
+
algorithm: DB
|
27 |
+
Backbone:
|
28 |
+
name: ResNet_vd
|
29 |
+
in_channels: 3
|
30 |
+
layers: 50
|
31 |
+
Neck:
|
32 |
+
name: LKPAN
|
33 |
+
out_channels: 256
|
34 |
+
Head:
|
35 |
+
name: DBHead
|
36 |
+
kernel_list: [7,2,2]
|
37 |
+
k: 50
|
38 |
+
Student2:
|
39 |
+
return_all_feats: false
|
40 |
+
model_type: det
|
41 |
+
algorithm: DB
|
42 |
+
Backbone:
|
43 |
+
name: ResNet_vd
|
44 |
+
in_channels: 3
|
45 |
+
layers: 50
|
46 |
+
Neck:
|
47 |
+
name: LKPAN
|
48 |
+
out_channels: 256
|
49 |
+
Head:
|
50 |
+
name: DBHead
|
51 |
+
kernel_list: [7,2,2]
|
52 |
+
k: 50
|
53 |
+
|
54 |
+
|
55 |
+
Loss:
|
56 |
+
name: CombinedLoss
|
57 |
+
loss_config_list:
|
58 |
+
- DistillationDMLLoss:
|
59 |
+
model_name_pairs:
|
60 |
+
- ["Student", "Student2"]
|
61 |
+
maps_name: "thrink_maps"
|
62 |
+
weight: 1.0
|
63 |
+
# act: None
|
64 |
+
model_name_pairs: ["Student", "Student2"]
|
65 |
+
key: maps
|
66 |
+
- DistillationDBLoss:
|
67 |
+
weight: 1.0
|
68 |
+
model_name_list: ["Student", "Student2"]
|
69 |
+
# key: maps
|
70 |
+
name: DBLoss
|
71 |
+
balance_loss: true
|
72 |
+
main_loss_type: DiceLoss
|
73 |
+
alpha: 5
|
74 |
+
beta: 10
|
75 |
+
ohem_ratio: 3
|
76 |
+
|
77 |
+
|
78 |
+
Optimizer:
|
79 |
+
name: Adam
|
80 |
+
beta1: 0.9
|
81 |
+
beta2: 0.999
|
82 |
+
lr:
|
83 |
+
name: Cosine
|
84 |
+
learning_rate: 0.001
|
85 |
+
warmup_epoch: 2
|
86 |
+
regularizer:
|
87 |
+
name: 'L2'
|
88 |
+
factor: 0
|
89 |
+
|
90 |
+
PostProcess:
|
91 |
+
name: DistillationDBPostProcess
|
92 |
+
+  model_name: ["Student", "Student2"]
+  key: head_out
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: "Student"
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - CopyPaste:
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+          # image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
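These configs are deeply nested key/value trees, and trainers in the PaddleOCR style let you override any leaf from the command line with a dotted path (e.g. `-o Train.loader.batch_size_per_card=16`). A minimal sketch of that addressing scheme, assuming a plain nested dict already loaded from YAML (`get_by_path`/`set_by_path` are illustrative names, not PaddleOCR's API):

```python
# Hypothetical helpers mirroring dotted-path config overrides like
# "-o Global.epoch_num=500"; names are illustrative, not PaddleOCR's own.
def get_by_path(cfg, dotted):
    """Walk a nested dict using a dotted key path and return the leaf."""
    node = cfg
    for key in dotted.split("."):
        node = node[key]
    return node

def set_by_path(cfg, dotted, value):
    """Set a leaf addressed by a dotted key path, creating dicts as needed."""
    keys = dotted.split(".")
    node = cfg
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value

config = {"Train": {"loader": {"batch_size_per_card": 8}}}
set_by_path(config, "Train.loader.batch_size_per_card", 16)
print(get_by_path(config, "Train.loader.batch_size_per_card"))  # 16
```

This is why every option in these files lives under a unique section path (`Global`, `Architecture`, `Train.dataset`, ...): the path doubles as the override key.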
Rotate/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml
ADDED
@@ -0,0 +1,163 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/ch_PP-OCR_V3_det/
+  save_epoch_step: 100
+  eval_batch_step:
+  - 0
+  - 400
+  cal_metric_during_train: false
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+    disable_se: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 960
+        - 960
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 4
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest: null
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2
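The `NormalizeImage` transform that appears in every config above combines two steps: `scale: 1./255.` maps uint8 pixels into [0, 1], then the per-channel `mean`/`std` (the usual ImageNet statistics) standardize each channel. A minimal sketch of that arithmetic for a single pixel:

```python
# Sketch of the NormalizeImage arithmetic: value/255, then per-channel
# (x - mean) / std with the ImageNet statistics from the config.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(value, channel, scale=1.0 / 255.0):
    """Normalize one uint8 pixel value for a given channel index (0=R/B...)."""
    return (value * scale - MEAN[channel]) / STD[channel]

# A fully saturated channel-0 pixel ends up well above 1:
print(round(normalize_pixel(255, 0), 4))  # 2.2489
```

Note `order: 'hwc'` means the statistics are applied while the image is still height-width-channel; `ToCHWImage` transposes to channel-first only afterwards.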
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_cml.yml
ADDED
@@ -0,0 +1,235 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 20
+  save_model_dir: ./output/ch_PP-OCRv4
+  save_epoch_step: 50
+  eval_batch_step:
+  - 0
+  - 1000
+  cal_metric_during_train: true
+  checkpoints: null
+  pretrained_model: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+Architecture:
+  name: DistillationModel
+  algorithm: Distillation
+  model_type: det
+  Models:
+    Student:
+      model_type: det
+      algorithm: DB
+      Transform: null
+      Backbone:
+        name: PPLCNetNew
+        scale: 0.75
+        pretrained: false
+      Neck:
+        name: RSEFPN
+        out_channels: 96
+        shortcut: true
+      Head:
+        name: DBHead
+        k: 50
+    Student2:
+      pretrained: null
+      model_type: det
+      algorithm: DB
+      Transform: null
+      Backbone:
+        name: PPLCNetNew
+        scale: 0.75
+        pretrained: true
+      Neck:
+        name: RSEFPN
+        out_channels: 96
+        shortcut: true
+      Head:
+        name: DBHead
+        k: 50
+    Teacher:
+      pretrained: https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_cml_teacher_pretrained/teacher.pdparams
+      freeze_params: true
+      return_all_feats: false
+      model_type: det
+      algorithm: DB
+      Backbone:
+        name: ResNet_vd
+        in_channels: 3
+        layers: 50
+      Neck:
+        name: LKPAN
+        out_channels: 256
+      Head:
+        name: DBHead
+        kernel_list:
+        - 7
+        - 2
+        - 2
+        k: 50
+Loss:
+  name: CombinedLoss
+  loss_config_list:
+  - DistillationDilaDBLoss:
+      weight: 1.0
+      model_name_pairs:
+      - - Student
+        - Teacher
+      - - Student2
+        - Teacher
+      key: maps
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+  - DistillationDMLLoss:
+      model_name_pairs:
+      - Student
+      - Student2
+      maps_name: thrink_maps
+      weight: 1.0
+      key: maps
+  - DistillationDBLoss:
+      weight: 1.0
+      model_name_list:
+      - Student
+      - Student2
+      balance_loss: true
+      main_loss_type: DiceLoss
+      alpha: 5
+      beta: 10
+      ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+PostProcess:
+  name: DistillationDBPostProcess
+  model_name:
+  - Student
+  key: head_out
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+Metric:
+  name: DistillationMetric
+  base_metric_name: DetMetric
+  main_indicator: hmean
+  key: Student
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: 500
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+        total_epoch: 500
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 16
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest: null
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2
+profiler_options: null
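In the CML config above, `DistillationDMLLoss` ties the two student branches together: each student's probability maps are pushed toward the other's. A common formulation of such a deep-mutual-learning term is a symmetric KL divergence between the two predicted distributions. The sketch below is an illustrative re-derivation of that idea in plain Python, not PaddleOCR's implementation:

```python
import math

# Illustrative deep-mutual-learning (DML) style loss: symmetric KL between
# the probability outputs of two peer models. Assumes inputs are already
# normalized probability vectors; this is a sketch, not PaddleOCR's code.
def kl_div(p, q, eps=1e-10):
    """KL(p || q) with a small epsilon for numerical safety."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def dml_loss(p_student, p_student2):
    """Symmetric KL, averaged over both directions."""
    return 0.5 * (kl_div(p_student, p_student2) + kl_div(p_student2, p_student))

print(dml_loss([0.2, 0.8], [0.2, 0.8]))  # 0.0 for identical distributions
```

Because the term is symmetric, neither student is privileged; the frozen Teacher (via `DistillationDilaDBLoss`) supplies the external supervision instead.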
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
ADDED
@@ -0,0 +1,171 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: &epoch_num 500
+  log_smooth_window: 20
+  print_batch_step: 100
+  save_model_dir: ./output/ch_PP-OCRv4
+  save_epoch_step: 10
+  eval_batch_step:
+  - 0
+  - 1500
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPLCNetV3
+    scale: 0.75
+    det: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001 #(8*8c)
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - CopyPaste: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: *epoch_num
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+        total_epoch: *epoch_num
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2
+profiler_options: null
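The two-element `eval_batch_step` lists above (`[0, 1500]`, `[3000, 2000]`, ...) encode an evaluation schedule: the first value is the iteration after which evaluation starts, the second is the interval between evaluations. A minimal sketch of one plausible trigger condition (the exact modulo arithmetic in real trainers may differ slightly):

```python
# Sketch of how a [start, interval] eval_batch_step schedule can be
# interpreted; the precise trigger in PaddleOCR-style trainers is assumed.
def should_evaluate(global_step, eval_batch_step):
    start, interval = eval_batch_step
    return global_step > start and (global_step - start) % interval == 0

# With eval_batch_step: [0, 1500], evaluation fires at 1500, 3000, 4500, ...
print([s for s in range(0, 5000, 500) if should_evaluate(s, [0, 1500])])
```

So `[0, 1500]` evaluates every 1500 iterations from the beginning, while `[3000, 2000]` skips the noisy early phase and only starts evaluating after iteration 3000.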
Rotate/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml
ADDED
@@ -0,0 +1,172 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: &epoch_num 500
+  log_smooth_window: 20
+  print_batch_step: 100
+  save_model_dir: ./output/ch_PP-OCRv4
+  save_epoch_step: 10
+  eval_batch_step:
+  - 0
+  - 1500
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPHGNet_small_ocr_det.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPHGNet_small
+    det: True
+  Neck:
+    name: LKPAN
+    out_channels: 256
+    intracl: true
+  Head:
+    name: PFHeadLocal
+    k: 50
+    mode: "large"
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001 #(8*8c)
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 1e-6
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - CopyPaste: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: *epoch_num
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+        total_epoch: *epoch_num
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2
+profiler_options: null
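Every config above pairs `name: Cosine` with `warmup_epoch: 2`: the learning rate ramps up linearly for two epochs, then decays along a half cosine toward zero. A self-contained sketch of that schedule in step units (real implementations differ in small details such as the floor value and step offsets):

```python
import math

# Sketch of a cosine schedule with linear warmup, as configured above
# (learning_rate: 0.001, warmup_epoch: 2). Step accounting is simplified.
def lr_at(step, total_steps, warmup_steps, base_lr=0.001):
    if step < warmup_steps:
        return base_lr * step / warmup_steps  # linear warmup from 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))  # cosine decay

print(lr_at(0, 1000, 100), lr_at(100, 1000, 100), lr_at(1000, 1000, 100))
```

Warmup matters most in these distillation setups because the pretrained backbone and the randomly initialized neck/head otherwise receive the same aggressive initial updates.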
Rotate/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml
ADDED
@@ -0,0 +1,132 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_mv3/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+    disable_se: True
+  Neck:
+    name: DBFPN
+    out_channels: 96
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+          # image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
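Another knob shared by every `DBPostProcess`/`DistillationDBPostProcess` block is `unclip_ratio: 1.5`. DB predicts shrunken text kernels, and post-processing expands each detected polygon outward by an offset derived from the DB unclipping rule, d = A × r / L, with polygon area A, perimeter L, and ratio r. A sketch of that arithmetic for an axis-aligned w×h box (production code offsets arbitrary polygons, typically via pyclipper):

```python
# Sketch of the DB unclip rule d = area * ratio / perimeter, shown for an
# axis-aligned w x h rectangle; real post-processing offsets full polygons.
def unclip_offset(w, h, unclip_ratio=1.5):
    area = w * h
    perimeter = 2 * (w + h)
    return area * unclip_ratio / perimeter

def unclip_box(w, h, unclip_ratio=1.5):
    """Expand a w x h box outward by the unclip offset on every side."""
    d = unclip_offset(w, h, unclip_ratio)
    return w + 2 * d, h + 2 * d

print(unclip_box(100, 20))  # (125.0, 45.0)
```

Raising `unclip_ratio` grows the final boxes (helpful when detections clip character edges); lowering it tightens them.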
Rotate/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml
ADDED
@@ -0,0 +1,131 @@
+Global:
+  use_gpu: true
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 2
+  save_model_dir: ./output/ch_db_res18/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations after the 3000th iteration
+  eval_batch_step: [3000, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: ResNet_vd
+    layers: 18
+    disable_se: True
+  Neck:
+    name: DBFPN
+    out_channels: 256
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [960, 960]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
+          thresh_max: 0.7
+      - MakeShrinkMap:
+          shrink_ratio: 0.4
+          min_text_size: 8
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
+  loader:
+    shuffle: True
+    drop_last: False
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+          # image_shape: [736, 1280]
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
Rotate/configs/det/det_mv3_db.yml
ADDED
@@ -0,0 +1,133 @@
+Global:
+  use_gpu: true
+  use_xpu: false
+  use_mlu: false
+  epoch_num: 1200
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/db_mv3/
+  save_epoch_step: 1200
+  # evaluation is run every 2000 iterations
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: False
+  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./output/det_db/predicts_db.txt
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+  Neck:
+    name: DBFPN
+    out_channels: 256
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    learning_rate: 0.001
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/icdar2015/text_localization/
+    label_file_list:
+      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
+    ratio_list: [1.0]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - IaaAugment:
+          augmenter_args:
+            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
+            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
+            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
+      - EastRandomCropData:
+          size: [640, 640]
+          max_tries: 50
+          keep_ratio: true
+      - MakeBorderMap:
+          shrink_ratio: 0.4
+          thresh_min: 0.3
|
88 |
+
thresh_max: 0.7
|
89 |
+
- MakeShrinkMap:
|
90 |
+
shrink_ratio: 0.4
|
91 |
+
min_text_size: 8
|
92 |
+
- NormalizeImage:
|
93 |
+
scale: 1./255.
|
94 |
+
mean: [0.485, 0.456, 0.406]
|
95 |
+
std: [0.229, 0.224, 0.225]
|
96 |
+
order: 'hwc'
|
97 |
+
- ToCHWImage:
|
98 |
+
- KeepKeys:
|
99 |
+
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
100 |
+
loader:
|
101 |
+
shuffle: True
|
102 |
+
drop_last: False
|
103 |
+
batch_size_per_card: 16
|
104 |
+
num_workers: 8
|
105 |
+
use_shared_memory: True
|
106 |
+
|
107 |
+
Eval:
|
108 |
+
dataset:
|
109 |
+
name: SimpleDataSet
|
110 |
+
data_dir: ./train_data/icdar2015/text_localization/
|
111 |
+
label_file_list:
|
112 |
+
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
113 |
+
transforms:
|
114 |
+
- DecodeImage: # load image
|
115 |
+
img_mode: BGR
|
116 |
+
channel_first: False
|
117 |
+
- DetLabelEncode: # Class handling label
|
118 |
+
- DetResizeForTest:
|
119 |
+
image_shape: [736, 1280]
|
120 |
+
- NormalizeImage:
|
121 |
+
scale: 1./255.
|
122 |
+
mean: [0.485, 0.456, 0.406]
|
123 |
+
std: [0.229, 0.224, 0.225]
|
124 |
+
order: 'hwc'
|
125 |
+
- ToCHWImage:
|
126 |
+
- KeepKeys:
|
127 |
+
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
128 |
+
loader:
|
129 |
+
shuffle: False
|
130 |
+
drop_last: False
|
131 |
+
batch_size_per_card: 1 # must be 1
|
132 |
+
num_workers: 8
|
133 |
+
use_shared_memory: True
|
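One subtlety in these configs: `scale: 1./255.` is not a YAML number, so a plain YAML loader returns it as the string `"1./255."`. PaddleOCR's `NormalizeImage` handles this internally; the sketch below only shows, under that assumption, how such a value would need to be converted when reading the config outside the framework.

```python
# A YAML loader hands back the string "1./255." for the `scale` key;
# it must be converted to a float before normalizing pixel values.
raw_scale = "1./255."  # what yaml.safe_load would return for this entry

# Evaluate string scales, pass numeric ones through unchanged.
scale = eval(raw_scale) if isinstance(raw_scale, str) else raw_scale
print(round(scale, 6))  # prints 0.003922
```

Writing `scale: 0.00392156862745098` (or `!!float`-tagged values) would avoid the string round-trip, but the configs above follow the upstream PaddleOCR convention.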