Cyril666 committed
Commit 1a827c6 · 1 Parent(s): e5414eb

First model version

configs/pretrain_language_model.yaml DELETED
@@ -1,45 +0,0 @@
-global:
-  name: pretrain-language-model
-  phase: train
-  stage: pretrain-language
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/WikiText-103.csv'],
-    batch_size: 4096
-  }
-  test: {
-    roots: ['data/WikiText-103_eval_d1.csv'],
-    batch_size: 4096
-  }
-
-training:
-  epochs: 80
-  show_iters: 50
-  eval_iters: 6000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [70, 10],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_language.BCNLanguage'
-  language: {
-    num_layers: 4,
-    loss_weight: 1.,
-    use_self_attn: False
-  }
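
All of the deleted configs share this shape, so a few lines of PyYAML are enough to inspect one. A minimal sketch, assuming PyYAML >= 5.1; the repository's own config loader is not part of this commit, and note that the !!python/tuple tag requires an unsafe loader:

    import yaml

    # unsafe_load (not safe_load) because the configs use the !!python/tuple tag
    with open('configs/pretrain_language_model.yaml') as f:
        cfg = yaml.unsafe_load(f)

    print(cfg['optimizer']['args']['betas'])  # (0.9, 0.999)
    print(cfg['model']['name'])               # modules.model_language.BCNLanguage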
 
configs/pretrain_vision_model.yaml DELETED
@@ -1,58 +0,0 @@
-global:
-  name: pretrain-vision-model
-  phase: train
-  stage: pretrain-vision
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 384
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 384
-  }
-  data_aug: True
-  multiscales: False
-  num_workers: 14
-
-training:
-  epochs: 8
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [6, 2],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_vision.BaseVision'
-  checkpoint: ~
-  vision: {
-    loss_weight: 1.,
-    attention: 'position',
-    backbone: 'transformer',
-    backbone_ln: 3,
-  }
 
configs/pretrain_vision_model_sv.yaml DELETED
@@ -1,58 +0,0 @@
-global:
-  name: pretrain-vision-model-sv
-  phase: train
-  stage: pretrain-vision
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 384
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 384
-  }
-  data_aug: True
-  multiscales: False
-  num_workers: 14
-
-training:
-  epochs: 8
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [6, 2],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_vision.BaseVision'
-  checkpoint: ~
-  vision: {
-    loss_weight: 1.,
-    attention: 'attention',
-    backbone: 'transformer',
-    backbone_ln: 2,
-  }
 
configs/template.yaml DELETED
@@ -1,67 +0,0 @@
-global:
-  name: exp
-  phase: train
-  stage: pretrain-vision
-  workdir: /tmp/workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 128
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 128
-  }
-  charset_path: data/charset_36.txt
-  num_workers: 4
-  max_length: 25 # 30
-  image_height: 32
-  image_width: 128
-  case_sensitive: False
-  eval_case_sensitive: False
-  data_aug: True
-  multiscales: False
-  pin_memory: True
-  smooth_label: False
-  smooth_factor: 0.1
-  one_hot_y: True
-  use_sm: False
-
-training:
-  epochs: 6
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 20000
-  start_iters: 0
-  stats_iters: 100000
-
-optimizer:
-  type: Adadelta # Adadelta, Adam
-  true_wd: False
-  wd: 0. # 0.001
-  bn_wd: False
-  args: {
-    # betas: !!python/tuple [0.9, 0.99], # betas=(0.9,0.99) for AdamW
-    # betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  clip_grad: 20
-  lr: [1.0, 1.0, 1.0] # lr: [0.005, 0.005, 0.005]
-  scheduler: {
-    periods: [3, 2, 1],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_abinet.ABINetModel'
-  checkpoint: ~
-  strict: True
 
configs/train_abinet.yaml DELETED
@@ -1,71 +0,0 @@
-global:
-  name: train-abinet
-  phase: train
-  stage: train-super
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 384
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 384
-  }
-  data_aug: True
-  multiscales: False
-  num_workers: 14
-
-training:
-  epochs: 10
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [6, 4],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_abinet_iter.ABINetIterModel'
-  iter_size: 3
-  ensemble: ''
-  use_vision: False
-  vision: {
-    checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
-    loss_weight: 1.,
-    attention: 'position',
-    backbone: 'transformer',
-    backbone_ln: 3,
-  }
-  language: {
-    checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
-    num_layers: 4,
-    loss_weight: 1.,
-    detach: True,
-    use_self_attn: False
-  }
-  alignment: {
-    loss_weight: 1.,
-  }
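
One reading of the scheduler block: the periods appear to sum to the epoch budget, so [6, 4] with gamma 0.1 would run the learning rate at 1e-4 for the first 6 epochs and 1e-5 for the last 4. A minimal sketch of the equivalent milestone schedule in plain PyTorch; this is an interpretation, since the repository's scheduler code is not in this diff:

    import torch

    net = torch.nn.Linear(8, 8)  # stand-in model
    opt = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999))
    # periods [6, 4] read as a single decay boundary after epoch 6
    sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[6], gamma=0.1)
    for epoch in range(10):
        # ... train one epoch ...
        sched.step()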
 
configs/train_abinet_sv.yaml DELETED
@@ -1,71 +0,0 @@
-global:
-  name: train-abinet-sv
-  phase: train
-  stage: train-super
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 384
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 384
-  }
-  data_aug: True
-  multiscales: False
-  num_workers: 14
-
-training:
-  epochs: 10
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [6, 4],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_abinet_iter.ABINetIterModel'
-  iter_size: 3
-  ensemble: ''
-  use_vision: False
-  vision: {
-    checkpoint: workdir/pretrain-vision-model-sv/best-pretrain-vision-model-sv.pth,
-    loss_weight: 1.,
-    attention: 'attention',
-    backbone: 'transformer',
-    backbone_ln: 2,
-  }
-  language: {
-    checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
-    num_layers: 4,
-    loss_weight: 1.,
-    detach: True,
-    use_self_attn: False
-  }
-  alignment: {
-    loss_weight: 1.,
-  }
 
configs/train_abinet_wo_iter.yaml DELETED
@@ -1,68 +0,0 @@
-global:
-  name: train-abinet-wo-iter
-  phase: train
-  stage: train-super
-  workdir: workdir
-  seed: ~
-
-dataset:
-  train: {
-    roots: ['data/training/MJ/MJ_train/',
-            'data/training/MJ/MJ_test/',
-            'data/training/MJ/MJ_valid/',
-            'data/training/ST'],
-    batch_size: 384
-  }
-  test: {
-    roots: ['data/evaluation/IIIT5k_3000',
-            'data/evaluation/SVT',
-            'data/evaluation/SVTP',
-            'data/evaluation/IC13_857',
-            'data/evaluation/IC15_1811',
-            'data/evaluation/CUTE80'],
-    batch_size: 384
-  }
-  data_aug: True
-  multiscales: False
-  num_workers: 14
-
-training:
-  epochs: 10
-  show_iters: 50
-  eval_iters: 3000
-  save_iters: 3000
-
-optimizer:
-  type: Adam
-  true_wd: False
-  wd: 0.0
-  bn_wd: False
-  clip_grad: 20
-  lr: 0.0001
-  args: {
-    betas: !!python/tuple [0.9, 0.999], # for default Adam
-  }
-  scheduler: {
-    periods: [6, 4],
-    gamma: 0.1,
-  }
-
-model:
-  name: 'modules.model_abinet.ABINetModel'
-  vision: {
-    checkpoint: workdir/pretrain-vision-model/best-pretrain-vision-model.pth,
-    loss_weight: 1.,
-    attention: 'position',
-    backbone: 'transformer',
-    backbone_ln: 3,
-  }
-  language: {
-    checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth,
-    num_layers: 4,
-    loss_weight: 1.,
-    detach: True,
-    use_self_attn: False
-  }
-  alignment: {
-    loss_weight: 1.,
-  }
 
data/charset_36.txt DELETED
@@ -1,36 +0,0 @@
-0 a
-1 b
-2 c
-3 d
-4 e
-5 f
-6 g
-7 h
-8 i
-9 j
-10 k
-11 l
-12 m
-13 n
-14 o
-15 p
-16 q
-17 r
-18 s
-19 t
-20 u
-21 v
-22 w
-23 x
-24 y
-25 z
-26 1
-27 2
-28 3
-29 4
-30 5
-31 6
-32 7
-33 8
-34 9
-35 0
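
The charset files map a class index to a character, one pair per line (lowercase letters first, then digits, in the 36-class file). A minimal sketch of a reader for this format, assuming the single-space separator rendered above; the repository's CharsetMapper class (used in the notebooks below) is not part of this diff:

    def read_charset(path):
        """Parse 'index char' lines into an index -> character dict."""
        id_to_char = {}
        with open(path) as f:
            for line in f:
                idx, char = line.rstrip('\n').split(' ')
                id_to_char[int(idx)] = char
        return id_to_char

    charset = read_charset('data/charset_36.txt')
    assert charset[0] == 'a' and charset[25] == 'z' and charset[35] == '0'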
 
data/charset_62.txt DELETED
@@ -1,62 +0,0 @@
-0 0
-1 1
-2 2
-3 3
-4 4
-5 5
-6 6
-7 7
-8 8
-9 9
-10 A
-11 B
-12 C
-13 D
-14 E
-15 F
-16 G
-17 H
-18 I
-19 J
-20 K
-21 L
-22 M
-23 N
-24 O
-25 P
-26 Q
-27 R
-28 S
-29 T
-30 U
-31 V
-32 W
-33 X
-34 Y
-35 Z
-36 a
-37 b
-38 c
-39 d
-40 e
-41 f
-42 g
-43 h
-44 i
-45 j
-46 k
-47 l
-48 m
-49 n
-50 o
-51 p
-52 q
-53 r
-54 s
-55 t
-56 u
-57 v
-58 w
-59 x
-60 y
-61 z
 
docker/Dockerfile DELETED
@@ -1,25 +0,0 @@
-FROM anibali/pytorch:cuda-9.0
-MAINTAINER fangshancheng <fangsc@ustc.edu.cn>
-RUN sudo rm -rf /etc/apt/sources.list.d && \
-    sudo apt update && \
-    sudo apt install -y build-essential vim && \
-    conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/free/ && \
-    conda config --add channels https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ && \
-    conda config --set show_channel_urls yes && \
-    pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ && \
-    pip install torch==1.1.0 torchvision==0.3.0 && \
-    pip install fastai==1.0.60 && \
-    pip install ipdb jupyter ipython lmdb editdistance tensorboardX natsort nltk && \
-    conda uninstall -y --force pillow pil jpeg libtiff libjpeg-turbo && \
-    pip uninstall -y pillow pil jpeg libtiff libjpeg-turbo && \
-    conda install -yc conda-forge libjpeg-turbo && \
-    CFLAGS="${CFLAGS} -mavx2" pip install --no-cache-dir --force-reinstall --no-binary :all: --compile pillow-simd==6.2.2.post1 && \
-    conda install -y jpeg libtiff opencv && \
-    sudo rm -rf /var/lib/apt/lists/* && \
-    sudo rm -rf /tmp/* && \
-    sudo rm -rf ~/.cache && \
-    sudo apt clean all && \
-    conda clean -y -a
-EXPOSE 8888
-ENV LANG C.UTF-8
-ENV LC_ALL C.UTF-8
 
notebooks/dataset-text.ipynb DELETED
@@ -1,159 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "os.chdir('..')\n",
-    "from dataset import *\n",
-    "torch.set_printoptions(sci_mode=False)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Construct dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data = TextDataset('data/Vocabulary_train_v2.csv', is_training=False, smooth_label=True, smooth_factor=0.1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data = DataBunch.create(train_ds=data, valid_ds=None, bs=6)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x, y = data.one_batch(); x, y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x[0].shape, x[1].shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "y[0].shape, y[1].shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x[0].argmax(-1) - y[0].argmax(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x[0].argmax(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "y[0].argmax(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x[0][0,0]"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# test SpellingMutation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "probs = {'pn0': 0., 'pn1': 0., 'pn2': 0., 'pt0': 1.0, 'pt1': 1.0}\n",
-    "charset = CharsetMapper('data/charset_36.txt')\n",
-    "sm = SpellingMutation(charset=charset, **probs)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "sm('*a-aa')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.4"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
 
notebooks/dataset.ipynb DELETED
@@ -1,298 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "os.chdir('..')\n",
-    "from dataset import *"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": false
-   },
-   "outputs": [],
-   "source": [
-    "import logging\n",
-    "from torchvision.transforms import ToPILImage\n",
-    "from torchvision.utils import make_grid\n",
-    "from IPython.display import display\n",
-    "from torch.utils.data import ConcatDataset\n",
-    "charset = CharsetMapper('data/charset_36.txt')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def show_all(dl, iter_size=None):\n",
-    "    if iter_size is None: iter_size = len(dl)\n",
-    "    for i, item in enumerate(dl):\n",
-    "        if i >= iter_size:\n",
-    "            break\n",
-    "        image = item[0]\n",
-    "        label = item[1][0]\n",
-    "        length = item[1][1]\n",
-    "        print(f'iter {i}:', [charset.get_text(label[j][0: length[j]].argmax(-1), padding=False) for j in range(bs)])\n",
-    "        display(ToPILImage()(make_grid(item[0].cpu())))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Construct dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data1 = ImageDataset('data/training/ST', is_training=True);data1 # is_training"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": [
-    "bs=64\n",
-    "data2 = ImageDataBunch.create(train_ds=data1, valid_ds=None, bs=bs, num_workers=1);data2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#data3 = data2.normalize(imagenet_stats);data3\n",
-    "data3 = data2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "show_all(data3.train_dl, 4)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Add dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "kwargs = {'data_aug': False, 'is_training': False}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data1 = ImageDataset('data/evaluation/IIIT5k_3000', **kwargs);data1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data2 = ImageDataset('data/evaluation/SVT', **kwargs);data2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data3 = ConcatDataset([data1, data2])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "bs=64\n",
-    "data4 = ImageDataBunch.create(train_ds=data1, valid_ds=data3, bs=bs, num_workers=1);data4"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "len(data4.train_dl), len(data4.valid_dl)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "show_all(data4.train_dl, 4)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# TEST"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "len(data4.valid_dl)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import time\n",
-    "niter = 1000\n",
-    "start = time.time()\n",
-    "for i, item in enumerate(progress_bar(data4.valid_dl)):\n",
-    "    if i % niter == 0 and i > 0:\n",
-    "        print(i, (time.time() - start) / niter)\n",
-    "        start = time.time()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": [
-    "num = 20\n",
-    "index = 6\n",
-    "plt.figure(figsize=(20, 10))\n",
-    "for i in range(num):\n",
-    "    plt.subplot(num // 4, 4, i+1)\n",
-    "    plt.imshow(data4.train_ds[i][0].data.numpy().transpose(1,2,0))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def show(path, image_key):\n",
-    "    with lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False).begin(write=False) as txn:\n",
-    "        imgbuf = txn.get(image_key.encode()) # image\n",
-    "        buf = six.BytesIO()\n",
-    "        buf.write(imgbuf)\n",
-    "        buf.seek(0)\n",
-    "        with warnings.catch_warnings():\n",
-    "            warnings.simplefilter(\"ignore\", UserWarning) # EXIF warning from TiffPlugin\n",
-    "            x = PIL.Image.open(buf).convert('RGB')\n",
-    "        print(x.size)\n",
-    "        plt.imshow(x)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "image_key = 'image-003118258'\n",
-    "image_key = 'image-002780217'\n",
-    "image_key = 'image-002780218'\n",
-    "path = 'data/CVPR2016'\n",
-    "show(path, image_key)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "image_key = 'image-004668347'\n",
-    "image_key = 'image-006128516'\n",
-    "path = 'data/NIPS2014'\n",
-    "show(path, image_key)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "image_key = 'image-004668347'\n",
-    "image_key = 'image-000002420'\n",
-    "path = 'data/IIIT5K_3000'\n",
-    "show(path, image_key)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.4"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
 
notebooks/prepare_wikitext103.ipynb DELETED
@@ -1,468 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# 82841986 is_char and is_digit"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# 82075350 regrex non-ascii and none-digit"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## 86460763 left"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "import random\n",
-    "import re\n",
-    "import pandas as pd"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "max_length = 25\n",
-    "min_length = 1\n",
-    "root = '../data'\n",
-    "charset = 'abcdefghijklmnopqrstuvwxyz'\n",
-    "digits = '0123456789'"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def is_char(text, ratio=0.5):\n",
-    "    text = text.lower()\n",
-    "    length = max(len(text), 1)\n",
-    "    char_num = sum([t in charset for t in text])\n",
-    "    if char_num < min_length: return False\n",
-    "    if char_num / length < ratio: return False\n",
-    "    return True\n",
-    "\n",
-    "def is_digit(text, ratio=0.5):\n",
-    "    length = max(len(text), 1)\n",
-    "    digit_num = sum([t in digits for t in text])\n",
-    "    if digit_num / length < ratio: return False\n",
-    "    return True"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# generate training dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "with open('/tmp/wikitext-103/wiki.train.tokens', 'r') as file:\n",
-    "    lines = file.readlines()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "inp, gt = [], []\n",
-    "for line in lines:\n",
-    "    token = line.lower().split()\n",
-    "    for text in token:\n",
-    "        text = re.sub('[^0-9a-zA-Z]+', '', text)\n",
-    "        if len(text) < min_length:\n",
-    "            # print('short-text', text)\n",
-    "            continue\n",
-    "        if len(text) > max_length:\n",
-    "            # print('long-text', text)\n",
-    "            continue\n",
-    "        inp.append(text)\n",
-    "        gt.append(text)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "train_voc = os.path.join(root, 'WikiText-103.csv')\n",
-    "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(train_voc, index=None, sep='\\t')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "86460763"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "len(inp)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "['valkyria',\n",
-       " 'chronicles',\n",
-       " 'iii',\n",
-       " 'senj',\n",
-       " 'no',\n",
-       " 'valkyria',\n",
-       " '3',\n",
-       " 'unk',\n",
-       " 'chronicles',\n",
-       " 'japanese',\n",
-       " '3',\n",
-       " 'lit',\n",
-       " 'valkyria',\n",
-       " 'of',\n",
-       " 'the',\n",
-       " 'battlefield',\n",
-       " '3',\n",
-       " 'commonly',\n",
-       " 'referred',\n",
-       " 'to',\n",
-       " 'as',\n",
-       " 'valkyria',\n",
-       " 'chronicles',\n",
-       " 'iii',\n",
-       " 'outside',\n",
-       " 'japan',\n",
-       " 'is',\n",
-       " 'a',\n",
-       " 'tactical',\n",
-       " 'role',\n",
-       " 'playing',\n",
-       " 'video',\n",
-       " 'game',\n",
-       " 'developed',\n",
-       " 'by',\n",
-       " 'sega',\n",
-       " 'and',\n",
-       " 'mediavision',\n",
-       " 'for',\n",
-       " 'the',\n",
-       " 'playstation',\n",
-       " 'portable',\n",
-       " 'released',\n",
-       " 'in',\n",
-       " 'january',\n",
-       " '2011',\n",
-       " 'in',\n",
-       " 'japan',\n",
-       " 'it',\n",
-       " 'is',\n",
-       " 'the',\n",
-       " 'third',\n",
-       " 'game',\n",
-       " 'in',\n",
-       " 'the',\n",
-       " 'valkyria',\n",
-       " 'series',\n",
-       " 'employing',\n",
-       " 'the',\n",
-       " 'same',\n",
-       " 'fusion',\n",
-       " 'of',\n",
-       " 'tactical',\n",
-       " 'and',\n",
-       " 'real',\n",
-       " 'time',\n",
-       " 'gameplay',\n",
-       " 'as',\n",
-       " 'its',\n",
-       " 'predecessors',\n",
-       " 'the',\n",
-       " 'story',\n",
-       " 'runs',\n",
-       " 'parallel',\n",
-       " 'to',\n",
-       " 'the',\n",
-       " 'first',\n",
-       " 'game',\n",
-       " 'and',\n",
-       " 'follows',\n",
-       " 'the',\n",
-       " 'nameless',\n",
-       " 'a',\n",
-       " 'penal',\n",
-       " 'military',\n",
-       " 'unit',\n",
-       " 'serving',\n",
-       " 'the',\n",
-       " 'nation',\n",
-       " 'of',\n",
-       " 'gallia',\n",
-       " 'during',\n",
-       " 'the',\n",
-       " 'second',\n",
-       " 'europan',\n",
-       " 'war',\n",
-       " 'who',\n",
-       " 'perform',\n",
-       " 'secret',\n",
-       " 'black']"
-      ]
-     },
-     "execution_count": 8,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "inp[:100]"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# generate evaluation dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def disturb(word, degree, p=0.3):\n",
-    "    if len(word) // 2 < degree: return word\n",
-    "    if is_digit(word): return word\n",
-    "    if random.random() < p: return word\n",
-    "    else:\n",
-    "        index = list(range(len(word)))\n",
-    "        random.shuffle(index)\n",
-    "        index = index[:degree]\n",
-    "        new_word = []\n",
-    "        for i in range(len(word)):\n",
-    "            if i not in index: \n",
-    "                new_word.append(word[i])\n",
-    "                continue\n",
-    "            if (word[i] not in charset) and (word[i] not in digits):\n",
-    "                # special token\n",
-    "                new_word.append(word[i])\n",
-    "                continue\n",
-    "            op = random.random()\n",
-    "            if op < 0.1: # add\n",
-    "                new_word.append(random.choice(charset))\n",
-    "                new_word.append(word[i])\n",
-    "            elif op < 0.2: continue # remove\n",
-    "            else: new_word.append(random.choice(charset)) # replace\n",
-    "        return ''.join(new_word)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "lines = inp\n",
-    "degree = 1\n",
-    "keep_num = 50000\n",
-    "\n",
-    "random.shuffle(lines)\n",
-    "part_lines = lines[:keep_num]\n",
-    "inp, gt = [], []\n",
-    "\n",
-    "for w in part_lines:\n",
-    "    w = w.strip().lower()\n",
-    "    new_w = disturb(w, degree)\n",
-    "    inp.append(new_w)\n",
-    "    gt.append(w)\n",
-    "    \n",
-    "eval_voc = os.path.join(root, f'WikiText-103_eval_d{degree}.csv')\n",
-    "pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(eval_voc, index=None, sep='\\t')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[('high', 'high'),\n",
-       " ('vctoria', 'victoria'),\n",
-       " ('mains', 'mains'),\n",
-       " ('bi', 'by'),\n",
-       " ('13', '13'),\n",
-       " ('ticnet', 'ticket'),\n",
-       " ('basil', 'basic'),\n",
-       " ('cut', 'cut'),\n",
-       " ('aqarky', 'anarky'),\n",
-       " ('the', 'the'),\n",
-       " ('tqe', 'the'),\n",
-       " ('oc', 'of'),\n",
-       " ('diwpersal', 'dispersal'),\n",
-       " ('traffic', 'traffic'),\n",
-       " ('in', 'in'),\n",
-       " ('the', 'the'),\n",
-       " ('ti', 'to'),\n",
-       " ('professionalms', 'professionals'),\n",
-       " ('747', '747'),\n",
-       " ('in', 'in'),\n",
-       " ('and', 'and'),\n",
-       " ('exezutive', 'executive'),\n",
-       " ('n400', 'n400'),\n",
-       " ('yusic', 'music'),\n",
-       " ('s', 's'),\n",
-       " ('henri', 'henry'),\n",
-       " ('heard', 'heard'),\n",
-       " ('thousand', 'thousand'),\n",
-       " ('to', 'to'),\n",
-       " ('arhy', 'army'),\n",
-       " ('td', 'to'),\n",
-       " ('a', 'a'),\n",
-       " ('oall', 'hall'),\n",
-       " ('qind', 'kind'),\n",
-       " ('od', 'on'),\n",
-       " ('samfria', 'samaria'),\n",
-       " ('driveway', 'driveway'),\n",
-       " ('which', 'which'),\n",
-       " ('wotk', 'work'),\n",
-       " ('ak', 'as'),\n",
-       " ('persona', 'persona'),\n",
-       " ('s', 's'),\n",
-       " ('melbourne', 'melbourne'),\n",
-       " ('apong', 'along'),\n",
-       " ('fas', 'was'),\n",
-       " ('thea', 'then'),\n",
-       " ('permcy', 'percy'),\n",
-       " ('nnd', 'and'),\n",
-       " ('alan', 'alan'),\n",
-       " ('13', '13'),\n",
-       " ('matteos', 'matters'),\n",
-       " ('against', 'against'),\n",
-       " ('nefion', 'nexion'),\n",
-       " ('held', 'held'),\n",
-       " ('negative', 'negative'),\n",
-       " ('gogd', 'good'),\n",
-       " ('the', 'the'),\n",
-       " ('thd', 'the'),\n",
-       " ('groening', 'groening'),\n",
-       " ('tqe', 'the'),\n",
-       " ('cwould', 'would'),\n",
-       " ('fb', 'ft'),\n",
-       " ('uniten', 'united'),\n",
-       " ('kone', 'one'),\n",
-       " ('thiy', 'this'),\n",
-       " ('lanren', 'lauren'),\n",
-       " ('s', 's'),\n",
-       " ('thhe', 'the'),\n",
-       " ('is', 'is'),\n",
-       " ('modep', 'model'),\n",
-       " ('weird', 'weird'),\n",
-       " ('angwer', 'answer'),\n",
-       " ('imprisxnment', 'imprisonment'),\n",
-       " ('marpery', 'margery'),\n",
-       " ('eventuanly', 'eventually'),\n",
-       " ('in', 'in'),\n",
-       " ('donnoa', 'donna'),\n",
-       " ('ik', 'it'),\n",
-       " ('reached', 'reached'),\n",
-       " ('at', 'at'),\n",
-       " ('excxted', 'excited'),\n",
-       " ('ws', 'was'),\n",
-       " ('raes', 'rates'),\n",
-       " ('the', 'the'),\n",
-       " ('firsq', 'first'),\n",
-       " ('concluyed', 'concluded'),\n",
-       " ('recdorded', 'recorded'),\n",
-       " ('fhe', 'the'),\n",
-       " ('uegiment', 'regiment'),\n",
-       " ('a', 'a'),\n",
-       " ('glanes', 'planes'),\n",
-       " ('conyrol', 'control'),\n",
-       " ('thr', 'the'),\n",
-       " ('arrext', 'arrest'),\n",
-       " ('bth', 'both'),\n",
-       " ('forward', 'forward'),\n",
-       " ('allowdd', 'allowed'),\n",
-       " ('revealed', 'revealed'),\n",
-       " ('mayagement', 'management'),\n",
-       " ('normal', 'normal')]"
-      ]
-     },
-     "execution_count": 11,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "list(zip(inp, gt))[:100]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.4"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
 
notebooks/transforms.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
tools/create_lmdb_dataset.py DELETED
@@ -1,87 +0,0 @@
-""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
-
-import fire
-import os
-import lmdb
-import cv2
-
-import numpy as np
-
-
-def checkImageIsValid(imageBin):
-    if imageBin is None:
-        return False
-    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
-    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
-    imgH, imgW = img.shape[0], img.shape[1]
-    if imgH * imgW == 0:
-        return False
-    return True
-
-
-def writeCache(env, cache):
-    with env.begin(write=True) as txn:
-        for k, v in cache.items():
-            txn.put(k, v)
-
-
-def createDataset(inputPath, gtFile, outputPath, checkValid=True):
-    """
-    Create LMDB dataset for training and evaluation.
-    ARGS:
-        inputPath  : input folder path where starts imagePath
-        outputPath : LMDB output path
-        gtFile     : list of image path and label
-        checkValid : if true, check the validity of every image
-    """
-    os.makedirs(outputPath, exist_ok=True)
-    env = lmdb.open(outputPath, map_size=1099511627776)
-    cache = {}
-    cnt = 1
-
-    with open(gtFile, 'r', encoding='utf-8') as data:
-        datalist = data.readlines()
-
-    nSamples = len(datalist)
-    for i in range(nSamples):
-        imagePath, label = datalist[i].strip('\n').split('\t')
-        imagePath = os.path.join(inputPath, imagePath)
-
-        # # only use alphanumeric data
-        # if re.search('[^a-zA-Z0-9]', label):
-        #     continue
-
-        if not os.path.exists(imagePath):
-            print('%s does not exist' % imagePath)
-            continue
-        with open(imagePath, 'rb') as f:
-            imageBin = f.read()
-        if checkValid:
-            try:
-                if not checkImageIsValid(imageBin):
-                    print('%s is not a valid image' % imagePath)
-                    continue
-            except:
-                print('error occured', i)
-                with open(outputPath + '/error_image_log.txt', 'a') as log:
-                    log.write('%s-th image data occured error\n' % str(i))
-                continue
-
-        imageKey = 'image-%09d'.encode() % cnt
-        labelKey = 'label-%09d'.encode() % cnt
-        cache[imageKey] = imageBin
-        cache[labelKey] = label.encode()
-
-        if cnt % 1000 == 0:
-            writeCache(env, cache)
-            cache = {}
-            print('Written %d / %d' % (cnt, nSamples))
-        cnt += 1
-    nSamples = cnt-1
-    cache['num-samples'.encode()] = str(nSamples).encode()
-    writeCache(env, cache)
-    print('Created dataset with %d samples' % nSamples)
-
-
-if __name__ == '__main__':
-    fire.Fire(createDataset)
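
Since the script hands createDataset to fire.Fire, each function argument doubles as a command-line flag, with gtFile expected to hold one tab-separated imagePath/label pair per line. A hedged usage sketch with placeholder paths (the actual data layout is not part of this commit):

    # Equivalent to the CLI call:
    #   python tools/create_lmdb_dataset.py --inputPath data/images \
    #       --gtFile data/gt.txt --outputPath data/train_lmdb
    from create_lmdb_dataset import createDataset  # import path is illustrative

    createDataset(inputPath='data/images', gtFile='data/gt.txt', outputPath='data/train_lmdb')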
 
tools/crop_by_word_bb_syn90k.py DELETED
@@ -1,153 +0,0 @@
-# Crop by word bounding box
-# Locate script with gt.mat
-# $ python crop_by_word_bb.py
-
-import os
-import re
-import cv2
-import scipy.io as sio
-from itertools import chain
-import numpy as np
-import math
-
-mat_contents = sio.loadmat('gt.mat')
-
-image_names = mat_contents['imnames'][0]
-cropped_indx = 0
-start_img_indx = 0
-gt_file = open('gt_oabc.txt', 'a')
-err_file = open('err_oabc.txt', 'a')
-
-for img_indx in range(start_img_indx, len(image_names)):
-
-
-    # Get image name
-    image_name_new = image_names[img_indx][0]
-    # print(image_name_new)
-    image_name = '/home/yxwang/pytorch/dataset/SynthText/img/'+ image_name_new
-    # print('IMAGE : {}.{}'.format(img_indx, image_name))
-    print('evaluating {} image'.format(img_indx), end='\r')
-    # Get text in image
-    txt = mat_contents['txt'][0][img_indx]
-    txt = [re.split(' \n|\n |\n| ', t.strip()) for t in txt]
-    txt = list(chain(*txt))
-    txt = [t for t in txt if len(t) > 0 ]
-    # print(txt) # ['Lines:', 'I', 'lost', 'Kevin', 'will', 'line', 'and', 'and', 'the', '(and', 'the', 'out', 'you', "don't", 'pkg']
-    # assert 1<0
-
-    # Open image
-    #img = Image.open(image_name)
-    img = cv2.imread(image_name, cv2.IMREAD_COLOR)
-    img_height, img_width, _ = img.shape
-
-    # Validation
-    if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2:
-        wordBBlen = 1
-    else:
-        wordBBlen = mat_contents['wordBB'][0][img_indx].shape[-1]
-
-    if wordBBlen == len(txt):
-        # Crop image and save
-        for word_indx in range(len(txt)):
-            # print('txt--',txt)
-            txt_temp = txt[word_indx]
-            len_now = len(txt_temp)
-            # txt_temp = re.sub('[^0-9a-zA-Z]+', '', txt_temp)
-            # print('txt_temp-1-',txt_temp)
-            txt_temp = re.sub('[^a-zA-Z]+', '', txt_temp)
-            # print('txt_temp-2-',txt_temp)
-            if len_now - len(txt_temp) != 0:
-                print('txt_temp-2-', txt_temp)
-
-            if len(np.shape(mat_contents['wordBB'][0][img_indx])) == 2: # only one word (2,4)
-                wordBB = mat_contents['wordBB'][0][img_indx]
-            else: # many words (2,4,num_words)
-                wordBB = mat_contents['wordBB'][0][img_indx][:, :, word_indx]
-
-            if np.shape(wordBB) != (2, 4):
-                err_log = 'malformed box index: {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
-                err_file.write(err_log)
-                # print(err_log)
-                continue
-
-            pts1 = np.float32([[wordBB[0][0], wordBB[1][0]],
-                               [wordBB[0][3], wordBB[1][3]],
-                               [wordBB[0][1], wordBB[1][1]],
-                               [wordBB[0][2], wordBB[1][2]]])
-            height = math.sqrt((wordBB[0][0] - wordBB[0][3])**2 + (wordBB[1][0] - wordBB[1][3])**2)
-            width = math.sqrt((wordBB[0][0] - wordBB[0][1])**2 + (wordBB[1][0] - wordBB[1][1])**2)
-
-            # Coord validation check
-            if (height * width) <= 0:
-                err_log = 'empty file : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
-                err_file.write(err_log)
-                # print(err_log)
-                continue
-            elif (height * width) > (img_height * img_width):
-                err_log = 'too big box : {}\t{}\t{}\n'.format(image_name, txt[word_indx], wordBB)
-                err_file.write(err_log)
-                # print(err_log)
-                continue
-            else:
-                valid = True
-                for i in range(2):
-                    for j in range(4):
-                        if wordBB[i][j] < 0 or wordBB[i][j] > img.shape[1 - i]:
-                            valid = False
-                            break
-                    if not valid:
-                        break
-                if not valid:
-                    err_log = 'invalid coord : {}\t{}\t{}\t{}\t{}\n'.format(
-                        image_name, txt[word_indx], wordBB, (width, height), (img_width, img_height))
-                    err_file.write(err_log)
-                    # print(err_log)
-                    continue
-
-            pts2 = np.float32([[0, 0],
-                               [0, height],
-                               [width, 0],
-                               [width, height]])
-
-            x_min = np.int(round(min(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
-            x_max = np.int(round(max(wordBB[0][0], wordBB[0][1], wordBB[0][2], wordBB[0][3])))
-            y_min = np.int(round(min(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
-            y_max = np.int(round(max(wordBB[1][0], wordBB[1][1], wordBB[1][2], wordBB[1][3])))
-            # print(x_min, x_max, y_min, y_max)
-            # print(img.shape)
-            # assert 1<0
-            if len(img.shape) == 3:
-                img_cropped = img[y_min:y_max:1, x_min:x_max:1, :]
-            else:
-                img_cropped = img[y_min:y_max:1, x_min:x_max:1]
-            dir_name = '/home/yxwang/pytorch/dataset/SynthText/cropped-oabc/{}'.format(image_name_new.split('/')[0])
-            # print('dir_name--',dir_name)
-            if not os.path.exists(dir_name):
-                os.mkdir(dir_name)
-            cropped_file_name = "{}/{}_{}_{}.jpg".format(dir_name, cropped_indx,
-                                                         image_name.split('/')[-1][:-len('.jpg')], word_indx)
-            # print('cropped_file_name--',cropped_file_name)
-            # print('img_cropped--',img_cropped.shape)
-            if img_cropped.shape[0] == 0 or img_cropped.shape[1] == 0:
-                err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][
-                    img_indx], mat_contents['wordBB'][0][img_indx])
-                err_file.write(err_log)
-                # print(err_log)
-                continue
-            # print('img_cropped--',img_cropped)
-
-            # img_cropped.save(cropped_file_name)
-            cv2.imwrite(cropped_file_name, img_cropped)
-            cropped_indx += 1
-            gt_file.write('%s\t%s\n' % (cropped_file_name, txt[word_indx]))
-
-        # if cropped_indx>10:
-        #     assert 1<0
-        # assert 1 < 0
-    else:
-        err_log = 'word_box_mismatch : {}\t{}\t{}\n'.format(image_name, mat_contents['txt'][0][
-            img_indx], mat_contents['wordBB'][0][img_indx])
-        err_file.write(err_log)
-        # print(err_log)
-gt_file.close()
-err_file.close()