tobiasc committed
Commit ad16788
1 Parent(s): 216ab96

Initial commit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +160 -0
  2. app.py +180 -0
  3. config.yaml +266 -0
  4. espnet/__init__.py +8 -0
  5. espnet/asr/__init__.py +1 -0
  6. espnet/asr/asr_mix_utils.py +187 -0
  7. espnet/asr/asr_utils.py +1024 -0
  8. espnet/asr/chainer_backend/__init__.py +1 -0
  9. espnet/asr/chainer_backend/asr.py +575 -0
  10. espnet/asr/pytorch_backend/__init__.py +1 -0
  11. espnet/asr/pytorch_backend/asr.py +1500 -0
  12. espnet/asr/pytorch_backend/asr_init.py +282 -0
  13. espnet/asr/pytorch_backend/asr_mix.py +654 -0
  14. espnet/asr/pytorch_backend/recog.py +152 -0
  15. espnet/bin/__init__.py +1 -0
  16. espnet/bin/asr_align.py +348 -0
  17. espnet/bin/asr_enhance.py +191 -0
  18. espnet/bin/asr_recog.py +363 -0
  19. espnet/bin/asr_train.py +644 -0
  20. espnet/bin/lm_train.py +288 -0
  21. espnet/bin/mt_train.py +480 -0
  22. espnet/bin/mt_trans.py +186 -0
  23. espnet/bin/st_train.py +550 -0
  24. espnet/bin/st_trans.py +183 -0
  25. espnet/bin/tts_decode.py +180 -0
  26. espnet/bin/tts_train.py +359 -0
  27. espnet/bin/vc_decode.py +174 -0
  28. espnet/bin/vc_train.py +368 -0
  29. espnet/lm/__init__.py +1 -0
  30. espnet/lm/chainer_backend/__init__.py +1 -0
  31. espnet/lm/chainer_backend/extlm.py +199 -0
  32. espnet/lm/chainer_backend/lm.py +484 -0
  33. espnet/lm/lm_utils.py +293 -0
  34. espnet/lm/pytorch_backend/__init__.py +1 -0
  35. espnet/lm/pytorch_backend/extlm.py +218 -0
  36. espnet/lm/pytorch_backend/lm.py +410 -0
  37. espnet/mt/__init__.py +1 -0
  38. espnet/mt/mt_utils.py +83 -0
  39. espnet/mt/pytorch_backend/__init__.py +1 -0
  40. espnet/mt/pytorch_backend/mt.py +600 -0
  41. espnet/nets/__init__.py +1 -0
  42. espnet/nets/asr_interface.py +172 -0
  43. espnet/nets/batch_beam_search.py +348 -0
  44. espnet/nets/batch_beam_search_online_sim.py +270 -0
  45. espnet/nets/beam_search.py +512 -0
  46. espnet/nets/beam_search_transducer.py +629 -0
  47. espnet/nets/chainer_backend/__init__.py +1 -0
  48. espnet/nets/chainer_backend/asr_interface.py +29 -0
  49. espnet/nets/chainer_backend/ctc.py +184 -0
  50. espnet/nets/chainer_backend/deterministic_embed_id.py +253 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
app.py ADDED
@@ -0,0 +1,180 @@
1
+ from espnet2.bin.tts_inference import Text2Speech
2
+ import torch
3
+ from parallel_wavegan.utils import download_pretrained_model, load_model
4
+ from phonemizer import phonemize
5
+ from phonemizer.separator import Separator
6
+ import gradio as gr
7
+
8
+ s = Separator(word=None, phone=" ")
9
+ config_path = "config.yaml"
10
+ model_path = "model.pth"
11
+
12
+ vocoder_tag = "ljspeech_parallel_wavegan.v3"
13
+
14
+ vocoder = load_model(download_pretrained_model(vocoder_tag)).to("cpu").eval()
15
+ vocoder.remove_weight_norm()
16
+
17
+ global_styles = {
18
+ "Style 1": torch.load("style1.pt"),
19
+ "Style 2": torch.load("style2.pt"),
20
+ "Style 3": torch.load("style3.pt"),
21
+ "Style 4": torch.load("style4.pt"),
22
+ "Style 5": torch.load("style5.pt"),
23
+ "Style 6": torch.load("style6.pt"),
24
+ }
25
+
26
+
27
+ def inference(text, global_style, alpha, prev_fg_inds, input_fg_inds):
28
+ with torch.no_grad():
29
+ text2speech = Text2Speech(
30
+ config_path,
31
+ model_path,
32
+ device="cpu",
33
+ # Only for Tacotron 2
34
+ threshold=0.5,
35
+ minlenratio=0.0,
36
+ maxlenratio=10.0,
37
+ use_att_constraint=False,
38
+ backward_window=1,
39
+ forward_window=3,
40
+ # Only for FastSpeech & FastSpeech2
41
+ speed_control_alpha=alpha,
42
+ )
43
+ text2speech.spc2wav = None # Disable griffin-lim
44
+
45
+ style_emb = torch.flatten(global_styles[global_style])
46
+
47
+ phoneme_string = phonemize(
48
+ text, language="mb-us1", backend="espeak-mbrola", separator=s
49
+ )
50
+ phonemes = phoneme_string.split(" ")
51
+
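+ # Find the index of the last phoneme whose prosody embedding id the user edited (-1 if there are no edits).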
52
+ max_edit_index = -1
53
+ for i in range(len(input_fg_inds) - 1, -1, -1):
54
+ if input_fg_inds[i] != "":
55
+ max_edit_index = i
56
+ break
57
+
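+ # No manual edits: run the model once and let it predict all fine-grained prosody embedding ids.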
58
+ if max_edit_index == -1:
59
+ _, c, _, _, _, _, _, output_fg_inds = text2speech(
60
+ phoneme_string, ref_embs=style_emb
61
+ )
62
+
63
+ else:
64
+ input_fg_inds_int_list = []
65
+ for i in range(max_edit_index + 1):
66
+ if input_fg_inds[i] != "":
67
+ input_fg_inds_int_list.append(int(input_fg_inds[i]))
68
+ else:
69
+ input_fg_inds_int_list.append(prev_fg_inds[i][1])
70
+ input_fg_inds = input_fg_inds_int_list
71
+
72
+ prev_fg_inds_list = [[[row[1], row[2], row[3]] for row in prev_fg_inds]]
73
+ prev_fg_inds = torch.tensor(prev_fg_inds_list, dtype=torch.int64)
74
+
75
+ fg_inds = torch.tensor(input_fg_inds_int_list).unsqueeze(0)
76
+ _, c, _, _, _, _, _, part_output_fg_inds = text2speech(
77
+ phoneme_string, ref_embs=style_emb, fg_inds=fg_inds
78
+ )
79
+
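+ # Keep the ids up to the edit point and overwrite everything after it with the newly predicted ids.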
80
+ prev_fg_inds[0, max_edit_index + 1 :, :] = part_output_fg_inds[0]
81
+ output_fg_inds = prev_fg_inds
82
+
83
+ output_fg_inds_list = output_fg_inds.tolist()[0]
84
+ padded_phonemes = ["", *phonemes]
85
+ dataframe_values = [
86
+ [phoneme, *fgs]
87
+ for phoneme, fgs in zip(padded_phonemes, output_fg_inds_list)
88
+ ]
89
+ selected_inds = [
90
+ [input_fg_inds[i]] if i < len(input_fg_inds) else [""]
91
+ for i in range(len(padded_phonemes))
92
+ ]
93
+ wav = vocoder.inference(c)
94
+
95
+ return [
96
+ (22050, wav.view(-1).cpu().numpy()),
97
+ dataframe_values,
98
+ selected_inds,
99
+ ]
100
+
101
+
102
+ demo = gr.Blocks()
103
+
104
+ with demo:
105
+ gr.Markdown(
106
+ """
107
+
108
+ # ConEx Demo
109
+
110
+ This demo shows the capabilities of ConEx, a model for **Con**trollable **Ex**pressive speech synthesis.
111
+ ConEx allows you to generate speech in a certain speaking style, and gives you the ability to edit the prosody* of the generated speech at a fine level.
112
+ We proposed ConEx in our paper titled ["Interactive Multi-Level Prosody Control for Expressive Speech Synthesis"](https://jessa.github.io/assets/pdf/cornille2022icassp.pdf), published in the proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP) 2022.
113
+
114
+ To convert text to speech: input some text, choose the desired speaking style, set the duration factor (higher = slower speech), and press "Generate speech".
115
+
116
+ **prosody refers to speech characteristics such as intonation, stress, rhythm*
117
+ """
118
+ )
119
+
120
+ with gr.Row():
121
+ text_input = gr.Textbox(
122
+ label="Input text",
123
+ lines=4,
124
+ placeholder="E.g. I didn't say he stole the money",
125
+ )
126
+
127
+ with gr.Column():
128
+ global_style_dropdown = gr.Dropdown(
129
+ ["Style 1", "Style 2", "Style 3", "Style 4", "Style 5", "Style 6"],
130
+ value="Style 1",
131
+ label="Global speaking style",
132
+ )
133
+ alpha_slider = gr.Slider(
134
+ 0.1, 2, value=1, step=0.1, label="Alpha (duration factor)"
135
+ )
136
+
137
+ audio = gr.Audio()
138
+ with gr.Row():
139
+ button = gr.Button("Generate Speech")
140
+
141
+ gr.Markdown(
142
+ """
143
+
144
+ ### Fine-grained prosody editor
145
+ Once you've generated some speech, the following table will show the id of the prosody embedding used for each phoneme.
146
+ A prosody embedding determines the prosody of the phoneme.
147
+ The table not only shows the prosody embeddings that are used by default (the top predictions), but also the next two most likely prosody embeddings.
148
+
149
+ In order to change the prosody of a phoneme, write a new prosody embedding id in the "Chosen prosody embeddings" column and press "Generate speech" again.
150
+ You can use any number from 0-31, but the 2nd and 3rd predictions are more likely to give a fitting prosody.
151
+ Based on your edit, new prosody embeddings will be generated for the phonemes after the edit.
152
+ Thus, you can iteratively change the prosody by starting from the beginning of the utterance and working your way through the utterance, making edits as you see fit.
153
+ The prosody embeddings before your edit will remain the same as before, and will be copied to the "Chosen prosody embeddings" column.
154
+ """
155
+ )
156
+
157
+ with gr.Row():
158
+ phoneme_preds_df = gr.Dataframe(
159
+ headers=["Phoneme", "🥇 Top pred", "🥈 2nd pred", "🥉 3rd pred"],
160
+ type="array",
161
+ col_count=(4, "static"),
162
+ )
163
+ phoneme_edits_df = gr.Dataframe(
164
+ headers=["Chosen prosody embeddings"], type="array", col_count=(1, "static")
165
+ )
166
+
167
+ button.click(
168
+ inference,
169
+ inputs=[
170
+ text_input,
171
+ global_style_dropdown,
172
+ alpha_slider,
173
+ phoneme_preds_df,
174
+ phoneme_edits_df,
175
+ ],
176
+ outputs=[audio, phoneme_preds_df, phoneme_edits_df],
177
+ )
178
+
179
+
180
+ demo.launch()
config.yaml ADDED
@@ -0,0 +1,266 @@
1
+ config: conf/ar_prior_train.yaml
2
+ print_config: false
3
+ log_level: INFO
4
+ dry_run: false
5
+ iterator_type: sequence
6
+ output_dir: exp/tts_finetune_ar_prior
7
+ ngpu: 1
8
+ seed: 0
9
+ num_workers: 1
10
+ num_att_plot: 3
11
+ dist_backend: nccl
12
+ dist_init_method: env://
13
+ dist_world_size: null
14
+ dist_rank: null
15
+ local_rank: 0
16
+ dist_master_addr: null
17
+ dist_master_port: null
18
+ dist_launcher: null
19
+ multiprocessing_distributed: false
20
+ unused_parameters: false
21
+ sharded_ddp: false
22
+ cudnn_enabled: true
23
+ cudnn_benchmark: false
24
+ cudnn_deterministic: true
25
+ collect_stats: false
26
+ write_collected_feats: false
27
+ max_epoch: 500
28
+ patience: null
29
+ val_scheduler_criterion:
30
+ - valid
31
+ - loss
32
+ early_stopping_criterion:
33
+ - valid
34
+ - loss
35
+ - min
36
+ best_model_criterion:
37
+ - - valid
38
+ - loss
39
+ - min
40
+ - - train
41
+ - loss
42
+ - min
43
+ keep_nbest_models: 5
44
+ grad_clip: 1.0
45
+ grad_clip_type: 2.0
46
+ grad_noise: false
47
+ accum_grad: 8
48
+ no_forward_run: false
49
+ resume: true
50
+ train_dtype: float32
51
+ use_amp: false
52
+ log_interval: null
53
+ use_tensorboard: true
54
+ use_wandb: false
55
+ wandb_project: null
56
+ wandb_id: null
57
+ detect_anomaly: false
58
+ pretrain_path: null
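+ # Fine-tuning setup: initialize from the pretrained checkpoint below and freeze everything except the prosody encoder's AR prior.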
59
+ init_param:
60
+ - /data/leuven/339/vsc33942/espnet-mirror/egs2/acapela_blizzard/tts1/exp/tts_train_raw_phn_none/valid.loss.best.pth:::tts.prosody_encoder.ar_prior
61
+ freeze_param:
62
+ - encoder.,prosody_encoder.ref_encoder.,prosody_encoder.fg_encoder.,prosody_encoder.global_encoder.,prosody_encoder.global_projection.,prosody_encoder.vq_layer.,prosody_encoder.qfg_projection,duration_predictor.,length_regulator,decoder.,feat_out,postnet
63
+ num_iters_per_epoch: 50
64
+ batch_size: 20
65
+ valid_batch_size: null
66
+ batch_bins: 3000000
67
+ valid_batch_bins: null
68
+ train_shape_file:
69
+ - exp/tts_stats_raw_phn_none/train/text_shape.phn
70
+ - exp/tts_stats_raw_phn_none/train/speech_shape
71
+ valid_shape_file:
72
+ - exp/tts_stats_raw_phn_none/valid/text_shape.phn
73
+ - exp/tts_stats_raw_phn_none/valid/speech_shape
74
+ batch_type: numel
75
+ valid_batch_type: null
76
+ fold_length:
77
+ - 150
78
+ - 204800
79
+ sort_in_batch: descending
80
+ sort_batch: descending
81
+ multiple_iterator: false
82
+ chunk_length: 500
83
+ chunk_shift_ratio: 0.5
84
+ num_cache_chunks: 1024
85
+ train_data_path_and_name_and_type:
86
+ - - dump/raw/tr_no_dev/text
87
+ - text
88
+ - text
89
+ - - data/durations/tr_no_dev/durations
90
+ - durations
91
+ - text_int
92
+ - - dump/raw/tr_no_dev/wav.scp
93
+ - speech
94
+ - sound
95
+ valid_data_path_and_name_and_type:
96
+ - - dump/raw/dev/text
97
+ - text
98
+ - text
99
+ - - data/durations/dev/durations
100
+ - durations
101
+ - text_int
102
+ - - dump/raw/dev/wav.scp
103
+ - speech
104
+ - sound
105
+ allow_variable_data_keys: false
106
+ max_cache_size: 0.0
107
+ max_cache_fd: 32
108
+ valid_max_cache_size: null
109
+ optim: adam
110
+ optim_conf:
111
+ lr: 1.0
112
+ scheduler: noamlr
113
+ scheduler_conf:
114
+ model_size: 384
115
+ warmup_steps: 4000
116
+ token_list:
117
+ - <blank>
118
+ - <unk>
119
+ - n
120
+ - '@'
121
+ - t
122
+ - _
123
+ - s
124
+ - I
125
+ - r
126
+ - d
127
+ - l
128
+ - m
129
+ - i
130
+ - '{'
131
+ - z
132
+ - D
133
+ - w
134
+ - r=
135
+ - f
136
+ - v
137
+ - E1
138
+ - b
139
+ - t_h
140
+ - h
141
+ - V
142
+ - u
143
+ - k
144
+ - I1
145
+ - '{1'
146
+ - k_h
147
+ - N
148
+ - EI1
149
+ - V1
150
+ - O1
151
+ - AI
152
+ - H
153
+ - S
154
+ - p_h
155
+ - '@U1'
156
+ - i1
157
+ - g
158
+ - AI1
159
+ - j
160
+ - O
161
+ - p
162
+ - u1
163
+ - r=1
164
+ - tS
165
+ - Or
166
+ - '4'
167
+ - A
168
+ - Or1
169
+ - E
170
+ - dZ
171
+ - T
172
+ - aU1
173
+ - U
174
+ - Er1
175
+ - '@U'
176
+ - U1
177
+ - Ar1
178
+ - Er
179
+ - aU
180
+ - EI
181
+ - ir1
182
+ - l=
183
+ - OI1
184
+ - Ar
185
+ - Ur1
186
+ - n=
187
+ - A1
188
+ - Z
189
+ - '?'
190
+ - ir
191
+ - Ur
192
+ - OI
193
+ - <sos/eos>
194
+ odim: null
195
+ model_conf: {}
196
+ use_preprocessor: true
197
+ token_type: phn
198
+ bpemodel: null
199
+ non_linguistic_symbols: null
200
+ cleaner: null
201
+ g2p: null
202
+ feats_extract: fbank
203
+ feats_extract_conf:
204
+ fs: 22050
205
+ fmin: 80
206
+ fmax: 7600
207
+ n_mels: 80
208
+ hop_length: 256
209
+ n_fft: 1024
210
+ win_length: null
211
+ normalize: global_mvn
212
+ normalize_conf:
213
+ stats_file: feats_stats.npz
214
+ tts: fastespeech
215
+ tts_conf:
216
+ adim: 128
217
+ aheads: 2
218
+ elayers: 4
219
+ eunits: 1536
220
+ dlayers: 4
221
+ dunits: 1536
222
+ positionwise_layer_type: conv1d
223
+ positionwise_conv_kernel_size: 3
224
+ duration_predictor_layers: 2
225
+ duration_predictor_chans: 128
226
+ duration_predictor_kernel_size: 3
227
+ duration_predictor_dropout_rate: 0.2
228
+ postnet_layers: 5
229
+ postnet_filts: 5
230
+ postnet_chans: 256
231
+ use_masking: true
232
+ use_scaled_pos_enc: true
233
+ encoder_normalize_before: true
234
+ decoder_normalize_before: true
235
+ reduction_factor: 1
236
+ init_type: xavier_uniform
237
+ init_enc_alpha: 1.0
238
+ init_dec_alpha: 1.0
239
+ transformer_enc_dropout_rate: 0.2
240
+ transformer_enc_positional_dropout_rate: 0.2
241
+ transformer_enc_attn_dropout_rate: 0.2
242
+ transformer_dec_dropout_rate: 0.2
243
+ transformer_dec_positional_dropout_rate: 0.2
244
+ transformer_dec_attn_dropout_rate: 0.2
245
+ ref_enc_conv_layers: 2
246
+ ref_enc_conv_kernel_size: 3
247
+ ref_enc_conv_stride: 2
248
+ ref_enc_gru_layers: 1
249
+ ref_enc_gru_units: 32
250
+ ref_emb_integration_type: add
251
+ prosody_num_embs: 32
252
+ prosody_hidden_dim: 3
253
+ prosody_emb_integration_type: add
254
+ pitch_extract: null
255
+ pitch_extract_conf: {}
256
+ pitch_normalize: null
257
+ pitch_normalize_conf: {}
258
+ energy_extract: null
259
+ energy_extract_conf: {}
260
+ energy_normalize: null
261
+ energy_normalize_conf: {}
262
+ required:
263
+ - output_dir
264
+ - token_list
265
+ version: 0.9.9
266
+ distributed: false
espnet/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """Initialize espnet package."""
2
+
3
+ import os
4
+
5
+ dirname = os.path.dirname(__file__)
6
+ version_file = os.path.join(dirname, "version.txt")
7
+ with open(version_file, "r") as f:
8
+ __version__ = f.read().strip()
espnet/asr/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/asr/asr_mix_utils.py ADDED
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This script provides utility functions for multi-speaker ASR.
5
+
6
+ Copyright 2017 Johns Hopkins University (Shinji Watanabe)
7
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
8
+
9
+ Most functions can be directly used as in asr_utils.py:
10
+ CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load,
11
+ torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf.
12
+
13
+ """
14
+
15
+ import copy
16
+ import logging
17
+ import os
18
+
19
+ from chainer.training import extension
20
+
21
+ import matplotlib
22
+
23
+ from espnet.asr.asr_utils import parse_hypothesis
24
+
25
+
26
+ matplotlib.use("Agg")
27
+
28
+
29
+ # * -------------------- chainer extension related -------------------- *
30
+ class PlotAttentionReport(extension.Extension):
31
+ """Plot attention reporter.
32
+
33
+ Args:
34
+ att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions):
35
+ Function of attention visualization.
36
+ data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items.
37
+ outdir (str): Directory to save figures.
38
+ converter (espnet.asr.*_backend.asr.CustomConverter):
39
+ CustomConverter object. Function to convert data.
40
+ device (torch.device): The destination device to send tensor.
41
+ reverse (bool): If True, input and output length are reversed.
42
+
43
+ """
44
+
45
+ def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False):
46
+ """Initialize PlotAttentionReport."""
47
+ self.att_vis_fn = att_vis_fn
48
+ self.data = copy.deepcopy(data)
49
+ self.outdir = outdir
50
+ self.converter = converter
51
+ self.device = device
52
+ self.reverse = reverse
53
+ if not os.path.exists(self.outdir):
54
+ os.makedirs(self.outdir)
55
+
56
+ def __call__(self, trainer):
57
+ """Plot and save imaged matrix of att_ws."""
58
+ att_ws_sd = self.get_attention_weights()
59
+ for ns, att_ws in enumerate(att_ws_sd):
60
+ for idx, att_w in enumerate(att_ws):
61
+ filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % (
62
+ self.outdir,
63
+ self.data[idx][0],
64
+ ns + 1,
65
+ )
66
+ att_w = self.get_attention_weight(idx, att_w, ns)
67
+ self._plot_and_save_attention(att_w, filename.format(trainer))
68
+
69
+ def log_attentions(self, logger, step):
70
+ """Add image files of attention matrix to tensorboard."""
71
+ att_ws_sd = self.get_attention_weights()
72
+ for ns, att_ws in enumerate(att_ws_sd):
73
+ for idx, att_w in enumerate(att_ws):
74
+ att_w = self.get_attention_weight(idx, att_w, ns)
75
+ plot = self.draw_attention_plot(att_w)
76
+ logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step)
77
+ plot.clf()
78
+
79
+ def get_attention_weights(self):
80
+ """Return attention weights.
81
+
82
+ Returns:
83
+ att_ws_sd (numpy.ndarray): attention weights (dtype=float). Its shape
85
+ differs depending on the backend:
85
+ * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax). 2)
86
+ other case => (B, Lmax, Tmax).
87
+ * chainer-> attention weights (B, Lmax, Tmax).
88
+
89
+ """
90
+ batch = self.converter([self.converter.transform(self.data)], self.device)
91
+ att_ws_sd = self.att_vis_fn(*batch)
92
+ return att_ws_sd
93
+
94
+ def get_attention_weight(self, idx, att_w, spkr_idx):
95
+ """Transform attention weight in regard to self.reverse."""
96
+ if self.reverse:
97
+ dec_len = int(self.data[idx][1]["input"][0]["shape"][0])
98
+ enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
99
+ else:
100
+ dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
101
+ enc_len = int(self.data[idx][1]["input"][0]["shape"][0])
102
+ if len(att_w.shape) == 3:
103
+ att_w = att_w[:, :dec_len, :enc_len]
104
+ else:
105
+ att_w = att_w[:dec_len, :enc_len]
106
+ return att_w
107
+
108
+ def draw_attention_plot(self, att_w):
109
+ """Visualize attention weights matrix.
110
+
111
+ Args:
112
+ att_w(Tensor): Attention weight matrix.
113
+
114
+ Returns:
115
+ matplotlib.pyplot: pyplot object with attention matrix image.
116
+
117
+ """
118
+ import matplotlib.pyplot as plt
119
+
120
+ if len(att_w.shape) == 3:
121
+ for h, aw in enumerate(att_w, 1):
122
+ plt.subplot(1, len(att_w), h)
123
+ plt.imshow(aw, aspect="auto")
124
+ plt.xlabel("Encoder Index")
125
+ plt.ylabel("Decoder Index")
126
+ else:
127
+ plt.imshow(att_w, aspect="auto")
128
+ plt.xlabel("Encoder Index")
129
+ plt.ylabel("Decoder Index")
130
+ plt.tight_layout()
131
+ return plt
132
+
133
+ def _plot_and_save_attention(self, att_w, filename):
134
+ plt = self.draw_attention_plot(att_w)
135
+ plt.savefig(filename)
136
+ plt.close()
137
+
138
+
139
+ def add_results_to_json(js, nbest_hyps_sd, char_list):
140
+ """Add N-best results to json.
141
+
142
+ Args:
143
+ js (dict[str, Any]): Groundtruth utterance dict.
144
+ nbest_hyps_sd (list[dict[str, Any]]):
145
+ List of hypothesis for multi_speakers (# Utts x # Spkrs).
146
+ char_list (list[str]): List of characters.
147
+
148
+ Returns:
149
+ dict[str, Any]: N-best results added utterance dict.
150
+
151
+ """
152
+ # copy old json info
153
+ new_js = dict()
154
+ new_js["utt2spk"] = js["utt2spk"]
155
+ num_spkrs = len(nbest_hyps_sd)
156
+ new_js["output"] = []
157
+
158
+ for ns in range(num_spkrs):
159
+ tmp_js = []
160
+ nbest_hyps = nbest_hyps_sd[ns]
161
+
162
+ for n, hyp in enumerate(nbest_hyps, 1):
163
+ # parse hypothesis
164
+ rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
165
+
166
+ # copy ground-truth
167
+ out_dic = dict(js["output"][ns].items())
168
+
169
+ # update name
170
+ out_dic["name"] += "[%d]" % n
171
+
172
+ # add recognition results
173
+ out_dic["rec_text"] = rec_text
174
+ out_dic["rec_token"] = rec_token
175
+ out_dic["rec_tokenid"] = rec_tokenid
176
+ out_dic["score"] = score
177
+
178
+ # add to list of N-best result dicts
179
+ tmp_js.append(out_dic)
180
+
181
+ # show 1-best result
182
+ if n == 1:
183
+ logging.info("groundtruth: %s" % out_dic["text"])
184
+ logging.info("prediction : %s" % out_dic["rec_text"])
185
+
186
+ new_js["output"].append(tmp_js)
187
+ return new_js
espnet/asr/asr_utils.py ADDED
@@ -0,0 +1,1024 @@
1
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import argparse
5
+ import copy
6
+ import json
7
+ import logging
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ # * -------------------- training iterator related -------------------- *
17
+
18
+
19
+ class CompareValueTrigger(object):
20
+ """Trigger invoked when key value getting bigger or lower than before.
21
+
22
+ Args:
23
+ key (str) : Key of value.
24
+ compare_fn ((float, float) -> bool) : Function to compare the values.
25
+ trigger (tuple(int, str)) : Trigger that decides the comparison interval.
26
+
27
+ """
28
+
29
+ def __init__(self, key, compare_fn, trigger=(1, "epoch")):
30
+ from chainer import training
31
+
32
+ self._key = key
33
+ self._best_value = None
34
+ self._interval_trigger = training.util.get_trigger(trigger)
35
+ self._init_summary()
36
+ self._compare_fn = compare_fn
37
+
38
+ def __call__(self, trainer):
39
+ """Get value related to the key and compare with current value."""
40
+ observation = trainer.observation
41
+ summary = self._summary
42
+ key = self._key
43
+ if key in observation:
44
+ summary.add({key: observation[key]})
45
+
46
+ if not self._interval_trigger(trainer):
47
+ return False
48
+
49
+ stats = summary.compute_mean()
50
+ value = float(stats[key]) # copy to CPU
51
+ self._init_summary()
52
+
53
+ if self._best_value is None:
54
+ # initialize best value
55
+ self._best_value = value
56
+ return False
57
+ elif self._compare_fn(self._best_value, value):
58
+ return True
59
+ else:
60
+ self._best_value = value
61
+ return False
62
+
63
+ def _init_summary(self):
64
+ import chainer
65
+
66
+ self._summary = chainer.reporter.DictSummary()
67
+
68
+
69
+ try:
70
+ from chainer.training import extension
71
+ except ImportError:
72
+ PlotAttentionReport = None
73
+ else:
74
+
75
+ class PlotAttentionReport(extension.Extension):
76
+ """Plot attention reporter.
77
+
78
+ Args:
79
+ att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
80
+ Function of attention visualization.
81
+ data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
82
+ outdir (str): Directory to save figures.
83
+ converter (espnet.asr.*_backend.asr.CustomConverter):
84
+ Function to convert data.
85
+ device (int | torch.device): Device.
86
+ reverse (bool): If True, input and output length are reversed.
87
+ ikey (str): Key to access input
88
+ (for ASR/ST ikey="input", for MT ikey="output".)
89
+ iaxis (int): Dimension to access input
90
+ (for ASR/ST iaxis=0, for MT iaxis=1.)
91
+ okey (str): Key to access output
92
+ (for ASR/ST okey="input", MT okay="output".)
93
+ oaxis (int): Dimension to access output
94
+ (for ASR/ST oaxis=0, for MT oaxis=0.)
95
+ subsampling_factor (int): subsampling factor in encoder
96
+
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ att_vis_fn,
102
+ data,
103
+ outdir,
104
+ converter,
105
+ transform,
106
+ device,
107
+ reverse=False,
108
+ ikey="input",
109
+ iaxis=0,
110
+ okey="output",
111
+ oaxis=0,
112
+ subsampling_factor=1,
113
+ ):
114
+ self.att_vis_fn = att_vis_fn
115
+ self.data = copy.deepcopy(data)
116
+ self.data_dict = {k: v for k, v in copy.deepcopy(data)}
117
+ # key is utterance ID
118
+ self.outdir = outdir
119
+ self.converter = converter
120
+ self.transform = transform
121
+ self.device = device
122
+ self.reverse = reverse
123
+ self.ikey = ikey
124
+ self.iaxis = iaxis
125
+ self.okey = okey
126
+ self.oaxis = oaxis
127
+ self.factor = subsampling_factor
128
+ if not os.path.exists(self.outdir):
129
+ os.makedirs(self.outdir)
130
+
131
+ def __call__(self, trainer):
132
+ """Plot and save image file of att_ws matrix."""
133
+ att_ws, uttid_list = self.get_attention_weights()
134
+ if isinstance(att_ws, list): # multi-encoder case
135
+ num_encs = len(att_ws) - 1
136
+ # atts
137
+ for i in range(num_encs):
138
+ for idx, att_w in enumerate(att_ws[i]):
139
+ filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
140
+ self.outdir,
141
+ uttid_list[idx],
142
+ i + 1,
143
+ )
144
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
145
+ np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
146
+ self.outdir,
147
+ uttid_list[idx],
148
+ i + 1,
149
+ )
150
+ np.save(np_filename.format(trainer), att_w)
151
+ self._plot_and_save_attention(att_w, filename.format(trainer))
152
+ # han
153
+ for idx, att_w in enumerate(att_ws[num_encs]):
154
+ filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
155
+ self.outdir,
156
+ uttid_list[idx],
157
+ )
158
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
159
+ np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
160
+ self.outdir,
161
+ uttid_list[idx],
162
+ )
163
+ np.save(np_filename.format(trainer), att_w)
164
+ self._plot_and_save_attention(
165
+ att_w, filename.format(trainer), han_mode=True
166
+ )
167
+ else:
168
+ for idx, att_w in enumerate(att_ws):
169
+ filename = "%s/%s.ep.{.updater.epoch}.png" % (
170
+ self.outdir,
171
+ uttid_list[idx],
172
+ )
173
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
174
+ np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
175
+ self.outdir,
176
+ uttid_list[idx],
177
+ )
178
+ np.save(np_filename.format(trainer), att_w)
179
+ self._plot_and_save_attention(att_w, filename.format(trainer))
180
+
181
+ def log_attentions(self, logger, step):
182
+ """Add image files of att_ws matrix to the tensorboard."""
183
+ att_ws, uttid_list = self.get_attention_weights()
184
+ if isinstance(att_ws, list): # multi-encoder case
185
+ num_encs = len(att_ws) - 1
186
+ # atts
187
+ for i in range(num_encs):
188
+ for idx, att_w in enumerate(att_ws[i]):
189
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
190
+ plot = self.draw_attention_plot(att_w)
191
+ logger.add_figure(
192
+ "%s_att%d" % (uttid_list[idx], i + 1),
193
+ plot.gcf(),
194
+ step,
195
+ )
196
+ # han
197
+ for idx, att_w in enumerate(att_ws[num_encs]):
198
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
199
+ plot = self.draw_han_plot(att_w)
200
+ logger.add_figure(
201
+ "%s_han" % (uttid_list[idx]),
202
+ plot.gcf(),
203
+ step,
204
+ )
205
+ else:
206
+ for idx, att_w in enumerate(att_ws):
207
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
208
+ plot = self.draw_attention_plot(att_w)
209
+ logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
210
+
211
+ def get_attention_weights(self):
212
+ """Return attention weights.
213
+
214
+ Returns:
215
+ numpy.ndarray: attention weights. float. Its shape would be
216
+ differ from backend.
217
+ * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
218
+ other case => (B, Lmax, Tmax).
219
+ * chainer-> (B, Lmax, Tmax)
220
+
221
+ """
222
+ return_batch, uttid_list = self.transform(self.data, return_uttid=True)
223
+ batch = self.converter([return_batch], self.device)
224
+ if isinstance(batch, tuple):
225
+ att_ws = self.att_vis_fn(*batch)
226
+ else:
227
+ att_ws = self.att_vis_fn(**batch)
228
+ return att_ws, uttid_list
229
+
230
+ def trim_attention_weight(self, uttid, att_w):
231
+ """Transform attention matrix with regard to self.reverse."""
232
+ if self.reverse:
233
+ enc_key, enc_axis = self.okey, self.oaxis
234
+ dec_key, dec_axis = self.ikey, self.iaxis
235
+ else:
236
+ enc_key, enc_axis = self.ikey, self.iaxis
237
+ dec_key, dec_axis = self.okey, self.oaxis
238
+ dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
239
+ enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
240
+ if self.factor > 1:
241
+ enc_len //= self.factor
242
+ if len(att_w.shape) == 3:
243
+ att_w = att_w[:, :dec_len, :enc_len]
244
+ else:
245
+ att_w = att_w[:dec_len, :enc_len]
246
+ return att_w
247
+
248
+ def draw_attention_plot(self, att_w):
249
+ """Plot the att_w matrix.
250
+
251
+ Returns:
252
+ matplotlib.pyplot: pyplot object with attention matrix image.
253
+
254
+ """
255
+ import matplotlib
256
+
257
+ matplotlib.use("Agg")
258
+ import matplotlib.pyplot as plt
259
+
260
+ plt.clf()
261
+ att_w = att_w.astype(np.float32)
262
+ if len(att_w.shape) == 3:
263
+ for h, aw in enumerate(att_w, 1):
264
+ plt.subplot(1, len(att_w), h)
265
+ plt.imshow(aw, aspect="auto")
266
+ plt.xlabel("Encoder Index")
267
+ plt.ylabel("Decoder Index")
268
+ else:
269
+ plt.imshow(att_w, aspect="auto")
270
+ plt.xlabel("Encoder Index")
271
+ plt.ylabel("Decoder Index")
272
+ plt.tight_layout()
273
+ return plt
274
+
275
+ def draw_han_plot(self, att_w):
276
+ """Plot the att_w matrix for hierarchical attention.
277
+
278
+ Returns:
279
+ matplotlib.pyplot: pyplot object with attention matrix image.
280
+
281
+ """
282
+ import matplotlib
283
+
284
+ matplotlib.use("Agg")
285
+ import matplotlib.pyplot as plt
286
+
287
+ plt.clf()
288
+ if len(att_w.shape) == 3:
289
+ for h, aw in enumerate(att_w, 1):
290
+ legends = []
291
+ plt.subplot(1, len(att_w), h)
292
+ for i in range(aw.shape[1]):
293
+ plt.plot(aw[:, i])
294
+ legends.append("Att{}".format(i))
295
+ plt.ylim([0, 1.0])
296
+ plt.xlim([0, aw.shape[0]])
297
+ plt.grid(True)
298
+ plt.ylabel("Attention Weight")
299
+ plt.xlabel("Decoder Index")
300
+ plt.legend(legends)
301
+ else:
302
+ legends = []
303
+ for i in range(att_w.shape[1]):
304
+ plt.plot(att_w[:, i])
305
+ legends.append("Att{}".format(i))
306
+ plt.ylim([0, 1.0])
307
+ plt.xlim([0, att_w.shape[0]])
308
+ plt.grid(True)
309
+ plt.ylabel("Attention Weight")
310
+ plt.xlabel("Decoder Index")
311
+ plt.legend(legends)
312
+ plt.tight_layout()
313
+ return plt
314
+
315
+ def _plot_and_save_attention(self, att_w, filename, han_mode=False):
316
+ if han_mode:
317
+ plt = self.draw_han_plot(att_w)
318
+ else:
319
+ plt = self.draw_attention_plot(att_w)
320
+ plt.savefig(filename)
321
+ plt.close()
322
+
323
+
324
+ try:
325
+ from chainer.training import extension
326
+ except ImportError:
327
+ PlotCTCReport = None
328
+ else:
329
+
330
+ class PlotCTCReport(extension.Extension):
331
+ """Plot CTC reporter.
332
+
333
+ Args:
334
+ ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
335
+ Function of CTC visualization.
336
+ data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
337
+ outdir (str): Directory to save figures.
338
+ converter (espnet.asr.*_backend.asr.CustomConverter):
339
+ Function to convert data.
340
+ device (int | torch.device): Device.
341
+ reverse (bool): If True, input and output length are reversed.
342
+ ikey (str): Key to access input
343
+ (for ASR/ST ikey="input", for MT ikey="output".)
344
+ iaxis (int): Dimension to access input
345
+ (for ASR/ST iaxis=0, for MT iaxis=1.)
346
+ okey (str): Key to access output
347
+ (for ASR/ST okey="input", MT okay="output".)
348
+ oaxis (int): Dimension to access output
349
+ (for ASR/ST oaxis=0, for MT oaxis=0.)
350
+ subsampling_factor (int): subsampling factor in encoder
351
+
352
+ """
353
+
354
+ def __init__(
355
+ self,
356
+ ctc_vis_fn,
357
+ data,
358
+ outdir,
359
+ converter,
360
+ transform,
361
+ device,
362
+ reverse=False,
363
+ ikey="input",
364
+ iaxis=0,
365
+ okey="output",
366
+ oaxis=0,
367
+ subsampling_factor=1,
368
+ ):
369
+ self.ctc_vis_fn = ctc_vis_fn
370
+ self.data = copy.deepcopy(data)
371
+ self.data_dict = {k: v for k, v in copy.deepcopy(data)}
372
+ # key is utterance ID
373
+ self.outdir = outdir
374
+ self.converter = converter
375
+ self.transform = transform
376
+ self.device = device
377
+ self.reverse = reverse
378
+ self.ikey = ikey
379
+ self.iaxis = iaxis
380
+ self.okey = okey
381
+ self.oaxis = oaxis
382
+ self.factor = subsampling_factor
383
+ if not os.path.exists(self.outdir):
384
+ os.makedirs(self.outdir)
385
+
386
+ def __call__(self, trainer):
387
+ """Plot and save image file of ctc prob."""
388
+ ctc_probs, uttid_list = self.get_ctc_probs()
389
+ if isinstance(ctc_probs, list): # multi-encoder case
390
+ num_encs = len(ctc_probs) - 1
391
+ for i in range(num_encs):
392
+ for idx, ctc_prob in enumerate(ctc_probs[i]):
393
+ filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
394
+ self.outdir,
395
+ uttid_list[idx],
396
+ i + 1,
397
+ )
398
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
399
+ np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
400
+ self.outdir,
401
+ uttid_list[idx],
402
+ i + 1,
403
+ )
404
+ np.save(np_filename.format(trainer), ctc_prob)
405
+ self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
406
+ else:
407
+ for idx, ctc_prob in enumerate(ctc_probs):
408
+ filename = "%s/%s.ep.{.updater.epoch}.png" % (
409
+ self.outdir,
410
+ uttid_list[idx],
411
+ )
412
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
413
+ np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
414
+ self.outdir,
415
+ uttid_list[idx],
416
+ )
417
+ np.save(np_filename.format(trainer), ctc_prob)
418
+ self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
419
+
420
+ def log_ctc_probs(self, logger, step):
421
+ """Add image files of ctc probs to the tensorboard."""
422
+ ctc_probs, uttid_list = self.get_ctc_probs()
423
+ if isinstance(ctc_probs, list): # multi-encoder case
424
+ num_encs = len(ctc_probs) - 1
425
+ for i in range(num_encs):
426
+ for idx, ctc_prob in enumerate(ctc_probs[i]):
427
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
428
+ plot = self.draw_ctc_plot(ctc_prob)
429
+ logger.add_figure(
430
+ "%s_ctc%d" % (uttid_list[idx], i + 1),
431
+ plot.gcf(),
432
+ step,
433
+ )
434
+ else:
435
+ for idx, ctc_prob in enumerate(ctc_probs):
436
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
437
+ plot = self.draw_ctc_plot(ctc_prob)
438
+ logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
439
+
440
+ def get_ctc_probs(self):
441
+ """Return CTC probs.
442
+
443
+ Returns:
444
+ numpy.ndarray: CTC probs. float. Its shape would be
445
+ differ from backend. (B, Tmax, vocab).
446
+
447
+ """
448
+ return_batch, uttid_list = self.transform(self.data, return_uttid=True)
449
+ batch = self.converter([return_batch], self.device)
450
+ if isinstance(batch, tuple):
451
+ probs = self.ctc_vis_fn(*batch)
452
+ else:
453
+ probs = self.ctc_vis_fn(**batch)
454
+ return probs, uttid_list
455
+
456
+ def trim_ctc_prob(self, uttid, prob):
457
+ """Trim CTC posteriors accoding to input lengths."""
458
+ enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
459
+ if self.factor > 1:
460
+ enc_len //= self.factor
461
+ prob = prob[:enc_len]
462
+ return prob
463
+
464
+ def draw_ctc_plot(self, ctc_prob):
465
+ """Plot the ctc_prob matrix.
466
+
467
+ Returns:
468
+ matplotlib.pyplot: pyplot object with CTC prob matrix image.
469
+
470
+ """
471
+ import matplotlib
472
+
473
+ matplotlib.use("Agg")
474
+ import matplotlib.pyplot as plt
475
+
476
+ ctc_prob = ctc_prob.astype(np.float32)
477
+
478
+ plt.clf()
479
+ topk_ids = np.argsort(ctc_prob, axis=1)
480
+ n_frames, vocab = ctc_prob.shape
481
+ times_probs = np.arange(n_frames)
482
+
483
+ plt.figure(figsize=(20, 8))
484
+
485
+ # NOTE: index 0 is reserved for blank
486
+ for idx in set(topk_ids.reshape(-1).tolist()):
487
+ if idx == 0:
488
+ plt.plot(
489
+ times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
490
+ )
491
+ else:
492
+ plt.plot(times_probs, ctc_prob[:, idx])
493
+ plt.xlabel(u"Input [frame]", fontsize=12)
494
+ plt.ylabel("Posteriors", fontsize=12)
495
+ plt.xticks(list(range(0, int(n_frames) + 1, 10)))
496
+ plt.yticks(list(range(0, 2, 1)))
497
+ plt.tight_layout()
498
+ return plt
499
+
500
+ def _plot_and_save_ctc(self, ctc_prob, filename):
501
+ plt = self.draw_ctc_plot(ctc_prob)
502
+ plt.savefig(filename)
503
+ plt.close()
504
+
505
+
506
+ def restore_snapshot(model, snapshot, load_fn=None):
507
+ """Extension to restore snapshot.
508
+
509
+ Returns:
510
+ An extension function.
511
+
512
+ """
513
+ import chainer
514
+ from chainer import training
515
+
516
+ if load_fn is None:
517
+ load_fn = chainer.serializers.load_npz
518
+
519
+ @training.make_extension(trigger=(1, "epoch"))
520
+ def restore_snapshot(trainer):
521
+ _restore_snapshot(model, snapshot, load_fn)
522
+
523
+ return restore_snapshot
524
+
525
+
526
+ def _restore_snapshot(model, snapshot, load_fn=None):
527
+ if load_fn is None:
528
+ import chainer
529
+
530
+ load_fn = chainer.serializers.load_npz
531
+
532
+ load_fn(snapshot, model)
533
+ logging.info("restored from " + str(snapshot))
534
+
535
+
536
+ def adadelta_eps_decay(eps_decay):
537
+ """Extension to perform adadelta eps decay.
538
+
539
+ Args:
540
+ eps_decay (float): Decay rate of eps.
541
+
542
+ Returns:
543
+ An extension function.
544
+
545
+ """
546
+ from chainer import training
547
+
548
+ @training.make_extension(trigger=(1, "epoch"))
549
+ def adadelta_eps_decay(trainer):
550
+ _adadelta_eps_decay(trainer, eps_decay)
551
+
552
+ return adadelta_eps_decay
553
+
554
+
555
+ def _adadelta_eps_decay(trainer, eps_decay):
556
+ optimizer = trainer.updater.get_optimizer("main")
557
+ # for chainer
558
+ if hasattr(optimizer, "eps"):
559
+ current_eps = optimizer.eps
560
+ setattr(optimizer, "eps", current_eps * eps_decay)
561
+ logging.info("adadelta eps decayed to " + str(optimizer.eps))
562
+ # pytorch
563
+ else:
564
+ for p in optimizer.param_groups:
565
+ p["eps"] *= eps_decay
566
+ logging.info("adadelta eps decayed to " + str(p["eps"]))
567
+
568
+
569
+ def adam_lr_decay(eps_decay):
570
+ """Extension to perform adam lr decay.
571
+
572
+ Args:
573
+ eps_decay (float): Decay rate of lr.
574
+
575
+ Returns:
576
+ An extension function.
577
+
578
+ """
579
+ from chainer import training
580
+
581
+ @training.make_extension(trigger=(1, "epoch"))
582
+ def adam_lr_decay(trainer):
583
+ _adam_lr_decay(trainer, eps_decay)
584
+
585
+ return adam_lr_decay
586
+
587
+
588
+ def _adam_lr_decay(trainer, eps_decay):
589
+ optimizer = trainer.updater.get_optimizer("main")
590
+ # for chainer
591
+ if hasattr(optimizer, "lr"):
592
+ current_lr = optimizer.lr
593
+ setattr(optimizer, "lr", current_lr * eps_decay)
594
+ logging.info("adam lr decayed to " + str(optimizer.lr))
595
+ # pytorch
596
+ else:
597
+ for p in optimizer.param_groups:
598
+ p["lr"] *= eps_decay
599
+ logging.info("adam lr decayed to " + str(p["lr"]))
600
+
601
+
602
+ def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
603
+ """Extension to take snapshot of the trainer for pytorch.
604
+
605
+ Returns:
606
+ An extension function.
607
+
608
+ """
609
+ from chainer.training import extension
610
+
611
+ @extension.make_extension(trigger=(1, "epoch"), priority=-100)
612
+ def torch_snapshot(trainer):
613
+ _torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
614
+
615
+ return torch_snapshot
616
+
617
+
618
+ def _torch_snapshot_object(trainer, target, filename, savefun):
619
+ from chainer.serializers import DictionarySerializer
620
+
621
+ # make snapshot_dict dictionary
622
+ s = DictionarySerializer()
623
+ s.save(trainer)
624
+ if hasattr(trainer.updater.model, "model"):
625
+ # (for TTS)
626
+ if hasattr(trainer.updater.model.model, "module"):
627
+ model_state_dict = trainer.updater.model.model.module.state_dict()
628
+ else:
629
+ model_state_dict = trainer.updater.model.model.state_dict()
630
+ else:
631
+ # (for ASR)
632
+ if hasattr(trainer.updater.model, "module"):
633
+ model_state_dict = trainer.updater.model.module.state_dict()
634
+ else:
635
+ model_state_dict = trainer.updater.model.state_dict()
636
+ snapshot_dict = {
637
+ "trainer": s.target,
638
+ "model": model_state_dict,
639
+ "optimizer": trainer.updater.get_optimizer("main").state_dict(),
640
+ }
641
+
642
+ # save snapshot dictionary
643
+ fn = filename.format(trainer)
644
+ prefix = "tmp" + fn
645
+ tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
646
+ tmppath = os.path.join(tmpdir, fn)
647
+ try:
648
+ savefun(snapshot_dict, tmppath)
649
+ shutil.move(tmppath, os.path.join(trainer.out, fn))
650
+ finally:
651
+ shutil.rmtree(tmpdir)
652
+
653
+
654
+ def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
655
+ """Adds noise from a standard normal distribution to the gradients.
656
+
657
+ The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
658
+ `sigma` goes to zero (no noise) with more iterations.
659
+
660
+ Args:
661
+ model (torch.nn.model): Model.
662
+ iteration (int): Number of iterations.
663
+ duration (int) {100, 1000}:
664
+ Interval (in iterations) controlling how often `sigma` changes.
665
+ eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
666
+ scale_factor (float) {0.55}: The scale of `sigma`.
667
+ """
668
+ interval = (iteration // duration) + 1
669
+ sigma = eta / interval ** scale_factor
670
+ for param in model.parameters():
671
+ if param.grad is not None:
672
+ _shape = param.grad.size()
673
+ noise = sigma * torch.randn(_shape).to(param.device)
674
+ param.grad += noise
675
+
676
+
677
+ # * -------------------- general -------------------- *
678
+ def get_model_conf(model_path, conf_path=None):
679
+ """Get model config information by reading a model config file (model.json).
680
+
681
+ Args:
682
+ model_path (str): Model path.
683
+ conf_path (str): Optional model config path.
684
+
685
+ Returns:
686
+ list[int, int, dict[str, Any]]: Config information loaded from json file.
687
+
688
+ """
689
+ if conf_path is None:
690
+ model_conf = os.path.dirname(model_path) + "/model.json"
691
+ else:
692
+ model_conf = conf_path
693
+ with open(model_conf, "rb") as f:
694
+ logging.info("reading a config file from " + model_conf)
695
+ confs = json.load(f)
696
+ if isinstance(confs, dict):
697
+ # for lm
698
+ args = confs
699
+ return argparse.Namespace(**args)
700
+ else:
701
+ # for asr, tts, mt
702
+ idim, odim, args = confs
703
+ return idim, odim, argparse.Namespace(**args)
704
+
705
+
706
+ def chainer_load(path, model):
707
+ """Load chainer model parameters.
708
+
709
+ Args:
710
+ path (str): Model path or snapshot file path to be loaded.
711
+ model (chainer.Chain): Chainer model.
712
+
713
+ """
714
+ import chainer
715
+
716
+ if "snapshot" in os.path.basename(path):
717
+ chainer.serializers.load_npz(path, model, path="updater/model:main/")
718
+ else:
719
+ chainer.serializers.load_npz(path, model)
720
+
721
+
722
+ def torch_save(path, model):
723
+ """Save torch model states.
724
+
725
+ Args:
726
+ path (str): Model path to be saved.
727
+ model (torch.nn.Module): Torch model.
728
+
729
+ """
730
+ if hasattr(model, "module"):
731
+ torch.save(model.module.state_dict(), path)
732
+ else:
733
+ torch.save(model.state_dict(), path)
734
+
735
+
736
+ def snapshot_object(target, filename):
737
+ """Returns a trainer extension to take snapshots of a given object.
738
+
739
+ Args:
740
+ target (model): Object to serialize.
741
+ filename (str): Name of the file into which the object is serialized. It can
742
+ be a format string, where the trainer object is passed to
743
+ the :meth: `str.format` method. For example,
744
+ ``'snapshot_{.updater.iteration}'`` is converted to
745
+ ``'snapshot_10000'`` at the 10,000th iteration.
746
+
747
+ Returns:
748
+ An extension function.
749
+
750
+ """
751
+ from chainer.training import extension
752
+
753
+ @extension.make_extension(trigger=(1, "epoch"), priority=-100)
754
+ def snapshot_object(trainer):
755
+ torch_save(os.path.join(trainer.out, filename.format(trainer)), target)
756
+
757
+ return snapshot_object
758
+
759
+
760
+ def torch_load(path, model):
761
+ """Load torch model states.
762
+
763
+ Args:
764
+ path (str): Model path or snapshot file path to be loaded.
765
+ model (torch.nn.Module): Torch model.
766
+
767
+ """
768
+ if "snapshot" in os.path.basename(path):
769
+ model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
770
+ "model"
771
+ ]
772
+ else:
773
+ model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
774
+
775
+ if hasattr(model, "module"):
776
+ model.module.load_state_dict(model_state_dict)
777
+ else:
778
+ model.load_state_dict(model_state_dict)
779
+
780
+ del model_state_dict
781
+
782
+
783
+ def torch_resume(snapshot_path, trainer):
784
+ """Resume from snapshot for pytorch.
785
+
786
+ Args:
787
+ snapshot_path (str): Snapshot file path.
788
+ trainer (chainer.training.Trainer): Chainer's trainer instance.
789
+
790
+ """
791
+ from chainer.serializers import NpzDeserializer
792
+
793
+ # load snapshot
794
+ snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
795
+
796
+ # restore trainer states
797
+ d = NpzDeserializer(snapshot_dict["trainer"])
798
+ d.load(trainer)
799
+
800
+ # restore model states
801
+ if hasattr(trainer.updater.model, "model"):
802
+ # (for TTS model)
803
+ if hasattr(trainer.updater.model.model, "module"):
804
+ trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
805
+ else:
806
+ trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
807
+ else:
808
+ # (for ASR model)
809
+ if hasattr(trainer.updater.model, "module"):
810
+ trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
811
+ else:
812
+ trainer.updater.model.load_state_dict(snapshot_dict["model"])
813
+
814
+ # restore optimizer states
815
+ trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
816
+
817
+ # delete opened snapshot
818
+ del snapshot_dict
819
+
820
+
821
+ # * ------------------ recognition related ------------------ *
822
+ def parse_hypothesis(hyp, char_list):
823
+ """Parse hypothesis.
824
+
825
+ Args:
826
+ hyp (list[dict[str, Any]]): Recognition hypothesis.
827
+ char_list (list[str]): List of characters.
828
+
829
+ Returns:
830
+ tuple(str, str, str, float)
831
+
832
+ """
833
+ # remove sos and get results
834
+ tokenid_as_list = list(map(int, hyp["yseq"][1:]))
835
+ token_as_list = [char_list[idx] for idx in tokenid_as_list]
836
+ score = float(hyp["score"])
837
+
838
+ # convert to string
839
+ tokenid = " ".join([str(idx) for idx in tokenid_as_list])
840
+ token = " ".join(token_as_list)
841
+ text = "".join(token_as_list).replace("<space>", " ")
842
+
843
+ return text, token, tokenid, score
844
+
845
+
846
+ def add_results_to_json(js, nbest_hyps, char_list):
847
+ """Add N-best results to json.
848
+
849
+ Args:
850
+ js (dict[str, Any]): Groundtruth utterance dict.
851
+ nbest_hyps (list[dict[str, Any]]):
852
+ List of N-best hypotheses.
853
+ char_list (list[str]): List of characters.
854
+
855
+ Returns:
856
+ dict[str, Any]: N-best results added utterance dict.
857
+
858
+ """
859
+ # copy old json info
860
+ new_js = dict()
861
+ new_js["utt2spk"] = js["utt2spk"]
862
+ new_js["output"] = []
863
+
864
+ for n, hyp in enumerate(nbest_hyps, 1):
865
+ # parse hypothesis
866
+ rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
867
+
868
+ # copy ground-truth
869
+ if len(js["output"]) > 0:
870
+ out_dic = dict(js["output"][0].items())
871
+ else:
872
+ # for no reference case (e.g., speech translation)
873
+ out_dic = {"name": ""}
874
+
875
+ # update name
876
+ out_dic["name"] += "[%d]" % n
877
+
878
+ # add recognition results
879
+ out_dic["rec_text"] = rec_text
880
+ out_dic["rec_token"] = rec_token
881
+ out_dic["rec_tokenid"] = rec_tokenid
882
+ out_dic["score"] = score
883
+
884
+ # add to list of N-best result dicts
885
+ new_js["output"].append(out_dic)
886
+
887
+ # show 1-best result
888
+ if n == 1:
889
+ if "text" in out_dic.keys():
890
+ logging.info("groundtruth: %s" % out_dic["text"])
891
+ logging.info("prediction : %s" % out_dic["rec_text"])
892
+
893
+ return new_js
894
+
895
+
896
+ def plot_spectrogram(
897
+ plt,
898
+ spec,
899
+ mode="db",
900
+ fs=None,
901
+ frame_shift=None,
902
+ bottom=True,
903
+ left=True,
904
+ right=True,
905
+ top=False,
906
+ labelbottom=True,
907
+ labelleft=True,
908
+ labelright=True,
909
+ labeltop=False,
910
+ cmap="inferno",
911
+ ):
912
+ """Plot spectrogram using matplotlib.
913
+
914
+ Args:
915
+ plt (matplotlib.pyplot): pyplot object.
916
+ spec (numpy.ndarray): Input stft (Freq, Time)
917
+ mode (str): db or linear.
918
+ fs (int): Sample frequency. To convert y-axis to kHz unit.
919
+ frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
920
+        bottom (bool): Whether to draw the respective ticks.
921
+ left (bool):
922
+ right (bool):
923
+ top (bool):
924
+        labelbottom (bool): Whether to draw the respective tick labels.
925
+ labelleft (bool):
926
+ labelright (bool):
927
+ labeltop (bool):
928
+ cmap (str): Colormap defined in matplotlib.
929
+
930
+ """
931
+ spec = np.abs(spec)
932
+ if mode == "db":
933
+ x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
934
+ elif mode == "linear":
935
+ x = spec
936
+ else:
937
+ raise ValueError(mode)
938
+
939
+ if fs is not None:
940
+ ytop = fs / 2000
941
+ ylabel = "kHz"
942
+ else:
943
+ ytop = x.shape[0]
944
+ ylabel = "bin"
945
+
946
+ if frame_shift is not None and fs is not None:
947
+ xtop = x.shape[1] * frame_shift / fs
948
+ xlabel = "s"
949
+ else:
950
+ xtop = x.shape[1]
951
+ xlabel = "frame"
952
+
953
+ extent = (0, xtop, 0, ytop)
954
+ plt.imshow(x[::-1], cmap=cmap, extent=extent)
955
+
956
+ if labelbottom:
957
+ plt.xlabel("time [{}]".format(xlabel))
958
+ if labelleft:
959
+ plt.ylabel("freq [{}]".format(ylabel))
960
+ plt.colorbar().set_label("{}".format(mode))
961
+
962
+ plt.tick_params(
963
+ bottom=bottom,
964
+ left=left,
965
+ right=right,
966
+ top=top,
967
+ labelbottom=labelbottom,
968
+ labelleft=labelleft,
969
+ labelright=labelright,
970
+ labeltop=labeltop,
971
+ )
972
+ plt.axis("auto")
973
+
974
+
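A hedged usage sketch (the sampling rate, frame shift and random input are illustrative): spec is a (freq, time) magnitude STFT and frame_shift is given in samples, so 160 samples at fs=16000 corresponds to a 10 ms shift.

    import numpy as np
    import matplotlib.pyplot as plt

    spec = np.abs(np.random.randn(257, 300)).astype(np.float32)  # placeholder STFT magnitude
    plot_spectrogram(plt, spec, mode="db", fs=16000, frame_shift=160)
    plt.savefig("spectrogram.png")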
975
+ # * ------------------ multi-encoder related ------------------ *
976
+ def format_mulenc_args(args):
977
+ """Format args for multi-encoder setup.
978
+
979
+    It deals with the following situations (when args.num_encs=2):
980
+ 1. args.elayers = None -> args.elayers = [4, 4];
981
+ 2. args.elayers = 4 -> args.elayers = [4, 4];
982
+ 3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].
983
+
984
+ """
985
+ # default values when None is assigned.
986
+ default_dict = {
987
+ "etype": "blstmp",
988
+ "elayers": 4,
989
+ "eunits": 300,
990
+ "subsample": "1",
991
+ "dropout_rate": 0.0,
992
+ "atype": "dot",
993
+ "adim": 320,
994
+ "awin": 5,
995
+ "aheads": 4,
996
+ "aconv_chans": -1,
997
+ "aconv_filts": 100,
998
+ }
999
+ for k in default_dict.keys():
1000
+ if isinstance(vars(args)[k], list):
1001
+ if len(vars(args)[k]) != args.num_encs:
1002
+ logging.warning(
1003
+ "Length mismatch {}: Convert {} to {}.".format(
1004
+ k, vars(args)[k], vars(args)[k][: args.num_encs]
1005
+ )
1006
+ )
1007
+ vars(args)[k] = vars(args)[k][: args.num_encs]
1008
+ else:
1009
+ if not vars(args)[k]:
1010
+ # assign default value if it is None
1011
+ vars(args)[k] = default_dict[k]
1012
+ logging.warning(
1013
+ "{} is not specified, use default value {}.".format(
1014
+ k, default_dict[k]
1015
+ )
1016
+ )
1017
+ # duplicate
1018
+ logging.warning(
1019
+ "Type mismatch {}: Convert {} to {}.".format(
1020
+ k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
1021
+ )
1022
+ )
1023
+ vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
1024
+ return args
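A hedged worked example of the three cases listed in the docstring above (the Namespace is illustrative and only sets the fields that appear in default_dict):

    import argparse

    args = argparse.Namespace(
        num_encs=2, etype=None, elayers=[4, 4, 4], eunits=300, subsample="1",
        dropout_rate=0.0, atype="dot", adim=320, awin=5, aheads=4,
        aconv_chans=-1, aconv_filts=100,
    )
    args = format_mulenc_args(args)
    # args.etype   -> ["blstmp", "blstmp"]  (None replaced by the default, then duplicated)
    # args.elayers -> [4, 4]                (list truncated to num_encs)
    # args.eunits  -> [300, 300]            (scalar duplicated once per encoder)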
espnet/asr/chainer_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/asr/chainer_backend/asr.py ADDED
@@ -0,0 +1,575 @@
1
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """Training/decoding definition for the speech recognition task."""
5
+
6
+ import json
7
+ import logging
8
+ import os
9
+ import six
10
+
11
+ # chainer related
12
+ import chainer
13
+
14
+ from chainer import training
15
+
16
+ from chainer.datasets import TransformDataset
17
+ from chainer.training import extensions
18
+
19
+ # espnet related
20
+ from espnet.asr.asr_utils import adadelta_eps_decay
21
+ from espnet.asr.asr_utils import add_results_to_json
22
+ from espnet.asr.asr_utils import chainer_load
23
+ from espnet.asr.asr_utils import CompareValueTrigger
24
+ from espnet.asr.asr_utils import get_model_conf
25
+ from espnet.asr.asr_utils import restore_snapshot
26
+ from espnet.nets.asr_interface import ASRInterface
27
+ from espnet.utils.deterministic_utils import set_deterministic_chainer
28
+ from espnet.utils.dynamic_import import dynamic_import
29
+ from espnet.utils.io_utils import LoadInputsAndTargets
30
+ from espnet.utils.training.batchfy import make_batchset
31
+ from espnet.utils.training.evaluator import BaseEvaluator
32
+ from espnet.utils.training.iterators import ShufflingEnabler
33
+ from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator
34
+ from espnet.utils.training.iterators import ToggleableShufflingSerialIterator
35
+ from espnet.utils.training.train_utils import check_early_stop
36
+ from espnet.utils.training.train_utils import set_early_stop
37
+
38
+ # rnnlm
39
+ import espnet.lm.chainer_backend.extlm as extlm_chainer
40
+ import espnet.lm.chainer_backend.lm as lm_chainer
41
+
42
+ # numpy related
43
+ import matplotlib
44
+
45
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
46
+ from tensorboardX import SummaryWriter
47
+
48
+ matplotlib.use("Agg")
49
+
50
+
51
+ def train(args):
52
+ """Train with the given args.
53
+
54
+ Args:
55
+ args (namespace): The program arguments.
56
+
57
+ """
58
+ # display chainer version
59
+ logging.info("chainer version = " + chainer.__version__)
60
+
61
+ set_deterministic_chainer(args)
62
+
63
+ # check cuda and cudnn availability
64
+ if not chainer.cuda.available:
65
+ logging.warning("cuda is not available")
66
+ if not chainer.cuda.cudnn_enabled:
67
+ logging.warning("cudnn is not available")
68
+
69
+ # get input and output dimension info
70
+ with open(args.valid_json, "rb") as f:
71
+ valid_json = json.load(f)["utts"]
72
+ utts = list(valid_json.keys())
73
+ idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
74
+ odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
75
+ logging.info("#input dims : " + str(idim))
76
+ logging.info("#output dims: " + str(odim))
77
+
78
+ # specify attention, CTC, hybrid mode
79
+ if args.mtlalpha == 1.0:
80
+ mtl_mode = "ctc"
81
+ logging.info("Pure CTC mode")
82
+ elif args.mtlalpha == 0.0:
83
+ mtl_mode = "att"
84
+ logging.info("Pure attention mode")
85
+ else:
86
+ mtl_mode = "mtl"
87
+ logging.info("Multitask learning mode")
88
+
89
+ # specify model architecture
90
+ logging.info("import model module: " + args.model_module)
91
+ model_class = dynamic_import(args.model_module)
92
+ model = model_class(idim, odim, args, flag_return=False)
93
+ assert isinstance(model, ASRInterface)
94
+ total_subsampling_factor = model.get_total_subsampling_factor()
95
+
96
+ # write model config
97
+ if not os.path.exists(args.outdir):
98
+ os.makedirs(args.outdir)
99
+ model_conf = args.outdir + "/model.json"
100
+ with open(model_conf, "wb") as f:
101
+ logging.info("writing a model config file to " + model_conf)
102
+ f.write(
103
+ json.dumps(
104
+ (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
105
+ ).encode("utf_8")
106
+ )
107
+ for key in sorted(vars(args).keys()):
108
+ logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
109
+
110
+ # Set gpu
111
+ ngpu = args.ngpu
112
+ if ngpu == 1:
113
+ gpu_id = 0
114
+ # Make a specified GPU current
115
+ chainer.cuda.get_device_from_id(gpu_id).use()
116
+ model.to_gpu() # Copy the model to the GPU
117
+ logging.info("single gpu calculation.")
118
+ elif ngpu > 1:
119
+ gpu_id = 0
120
+ devices = {"main": gpu_id}
121
+ for gid in six.moves.xrange(1, ngpu):
122
+ devices["sub_%d" % gid] = gid
123
+ logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
124
+ logging.warning(
125
+ "batch size is automatically increased (%d -> %d)"
126
+ % (args.batch_size, args.batch_size * args.ngpu)
127
+ )
128
+ else:
129
+ gpu_id = -1
130
+ logging.info("cpu calculation")
131
+
132
+ # Setup an optimizer
133
+ if args.opt == "adadelta":
134
+ optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
135
+ elif args.opt == "adam":
136
+ optimizer = chainer.optimizers.Adam()
137
+ elif args.opt == "noam":
138
+ optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
139
+ else:
140
+ raise NotImplementedError("args.opt={}".format(args.opt))
141
+
142
+ optimizer.setup(model)
143
+ optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))
144
+
145
+ # Setup a converter
146
+ converter = model.custom_converter(subsampling_factor=model.subsample[0])
147
+
148
+ # read json data
149
+ with open(args.train_json, "rb") as f:
150
+ train_json = json.load(f)["utts"]
151
+ with open(args.valid_json, "rb") as f:
152
+ valid_json = json.load(f)["utts"]
153
+
154
+ # set up training iterator and updater
155
+ load_tr = LoadInputsAndTargets(
156
+ mode="asr",
157
+ load_output=True,
158
+ preprocess_conf=args.preprocess_conf,
159
+ preprocess_args={"train": True}, # Switch the mode of preprocessing
160
+ )
161
+ load_cv = LoadInputsAndTargets(
162
+ mode="asr",
163
+ load_output=True,
164
+ preprocess_conf=args.preprocess_conf,
165
+ preprocess_args={"train": False}, # Switch the mode of preprocessing
166
+ )
167
+
168
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
169
+ accum_grad = args.accum_grad
170
+ if ngpu <= 1:
171
+ # make minibatch list (variable length)
172
+ train = make_batchset(
173
+ train_json,
174
+ args.batch_size,
175
+ args.maxlen_in,
176
+ args.maxlen_out,
177
+ args.minibatches,
178
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
179
+ shortest_first=use_sortagrad,
180
+ count=args.batch_count,
181
+ batch_bins=args.batch_bins,
182
+ batch_frames_in=args.batch_frames_in,
183
+ batch_frames_out=args.batch_frames_out,
184
+ batch_frames_inout=args.batch_frames_inout,
185
+ iaxis=0,
186
+ oaxis=0,
187
+ )
188
+        # hack to set the batchsize argument to 1;
189
+        # the actual batchsize is contained in each minibatch list
190
+ if args.n_iter_processes > 0:
191
+ train_iters = [
192
+ ToggleableShufflingMultiprocessIterator(
193
+ TransformDataset(train, load_tr),
194
+ batch_size=1,
195
+ n_processes=args.n_iter_processes,
196
+ n_prefetch=8,
197
+ maxtasksperchild=20,
198
+ shuffle=not use_sortagrad,
199
+ )
200
+ ]
201
+ else:
202
+ train_iters = [
203
+ ToggleableShufflingSerialIterator(
204
+ TransformDataset(train, load_tr),
205
+ batch_size=1,
206
+ shuffle=not use_sortagrad,
207
+ )
208
+ ]
209
+
210
+ # set up updater
211
+ updater = model.custom_updater(
212
+ train_iters[0],
213
+ optimizer,
214
+ converter=converter,
215
+ device=gpu_id,
216
+ accum_grad=accum_grad,
217
+ )
218
+ else:
219
+ if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
220
+ raise NotImplementedError(
221
+ "--batch-count 'bin' and 'frame' are not implemented "
222
+ "in chainer multi gpu"
223
+ )
224
+ # set up minibatches
225
+ train_subsets = []
226
+ for gid in six.moves.xrange(ngpu):
227
+ # make subset
228
+ train_json_subset = {
229
+ k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
230
+ }
231
+ # make minibatch list (variable length)
232
+ train_subsets += [
233
+ make_batchset(
234
+ train_json_subset,
235
+ args.batch_size,
236
+ args.maxlen_in,
237
+ args.maxlen_out,
238
+ args.minibatches,
239
+ )
240
+ ]
241
+
242
+ # each subset must have same length for MultiprocessParallelUpdater
243
+ maxlen = max([len(train_subset) for train_subset in train_subsets])
244
+ for train_subset in train_subsets:
245
+ if maxlen != len(train_subset):
246
+ for i in six.moves.xrange(maxlen - len(train_subset)):
247
+ train_subset += [train_subset[i]]
248
+
249
+        # hack to set the batchsize argument to 1;
250
+        # the actual batchsize is contained in each minibatch list
251
+ if args.n_iter_processes > 0:
252
+ train_iters = [
253
+ ToggleableShufflingMultiprocessIterator(
254
+ TransformDataset(train_subsets[gid], load_tr),
255
+ batch_size=1,
256
+ n_processes=args.n_iter_processes,
257
+ n_prefetch=8,
258
+ maxtasksperchild=20,
259
+ shuffle=not use_sortagrad,
260
+ )
261
+ for gid in six.moves.xrange(ngpu)
262
+ ]
263
+ else:
264
+ train_iters = [
265
+ ToggleableShufflingSerialIterator(
266
+ TransformDataset(train_subsets[gid], load_tr),
267
+ batch_size=1,
268
+ shuffle=not use_sortagrad,
269
+ )
270
+ for gid in six.moves.xrange(ngpu)
271
+ ]
272
+
273
+ # set up updater
274
+ updater = model.custom_parallel_updater(
275
+ train_iters, optimizer, converter=converter, devices=devices
276
+ )
277
+
278
+ # Set up a trainer
279
+ trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
280
+
281
+ if use_sortagrad:
282
+ trainer.extend(
283
+ ShufflingEnabler(train_iters),
284
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
285
+ )
286
+ if args.opt == "noam":
287
+ from espnet.nets.chainer_backend.transformer.training import VaswaniRule
288
+
289
+ trainer.extend(
290
+ VaswaniRule(
291
+ "alpha",
292
+ d=args.adim,
293
+ warmup_steps=args.transformer_warmup_steps,
294
+ scale=args.transformer_lr,
295
+ ),
296
+ trigger=(1, "iteration"),
297
+ )
298
+ # Resume from a snapshot
299
+ if args.resume:
300
+ chainer.serializers.load_npz(args.resume, trainer)
301
+
302
+ # set up validation iterator
303
+ valid = make_batchset(
304
+ valid_json,
305
+ args.batch_size,
306
+ args.maxlen_in,
307
+ args.maxlen_out,
308
+ args.minibatches,
309
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
310
+ count=args.batch_count,
311
+ batch_bins=args.batch_bins,
312
+ batch_frames_in=args.batch_frames_in,
313
+ batch_frames_out=args.batch_frames_out,
314
+ batch_frames_inout=args.batch_frames_inout,
315
+ iaxis=0,
316
+ oaxis=0,
317
+ )
318
+
319
+ if args.n_iter_processes > 0:
320
+ valid_iter = chainer.iterators.MultiprocessIterator(
321
+ TransformDataset(valid, load_cv),
322
+ batch_size=1,
323
+ repeat=False,
324
+ shuffle=False,
325
+ n_processes=args.n_iter_processes,
326
+ n_prefetch=8,
327
+ maxtasksperchild=20,
328
+ )
329
+ else:
330
+ valid_iter = chainer.iterators.SerialIterator(
331
+ TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
332
+ )
333
+
334
+ # Evaluate the model with the test dataset for each epoch
335
+ trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))
336
+
337
+ # Save attention weight each epoch
338
+ if args.num_save_attention > 0 and args.mtlalpha != 1.0:
339
+ data = sorted(
340
+ list(valid_json.items())[: args.num_save_attention],
341
+ key=lambda x: int(x[1]["input"][0]["shape"][1]),
342
+ reverse=True,
343
+ )
344
+ if hasattr(model, "module"):
345
+ att_vis_fn = model.module.calculate_all_attentions
346
+ plot_class = model.module.attention_plot_class
347
+ else:
348
+ att_vis_fn = model.calculate_all_attentions
349
+ plot_class = model.attention_plot_class
350
+ logging.info("Using custom PlotAttentionReport")
351
+ att_reporter = plot_class(
352
+ att_vis_fn,
353
+ data,
354
+ args.outdir + "/att_ws",
355
+ converter=converter,
356
+ transform=load_cv,
357
+ device=gpu_id,
358
+ subsampling_factor=total_subsampling_factor,
359
+ )
360
+ trainer.extend(att_reporter, trigger=(1, "epoch"))
361
+ else:
362
+ att_reporter = None
363
+
364
+ # Take a snapshot for each specified epoch
365
+ trainer.extend(
366
+ extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
367
+ trigger=(1, "epoch"),
368
+ )
369
+
370
+ # Make a plot for training and validation values
371
+ trainer.extend(
372
+ extensions.PlotReport(
373
+ [
374
+ "main/loss",
375
+ "validation/main/loss",
376
+ "main/loss_ctc",
377
+ "validation/main/loss_ctc",
378
+ "main/loss_att",
379
+ "validation/main/loss_att",
380
+ ],
381
+ "epoch",
382
+ file_name="loss.png",
383
+ )
384
+ )
385
+ trainer.extend(
386
+ extensions.PlotReport(
387
+ ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
388
+ )
389
+ )
390
+
391
+ # Save best models
392
+ trainer.extend(
393
+ extensions.snapshot_object(model, "model.loss.best"),
394
+ trigger=training.triggers.MinValueTrigger("validation/main/loss"),
395
+ )
396
+ if mtl_mode != "ctc":
397
+ trainer.extend(
398
+ extensions.snapshot_object(model, "model.acc.best"),
399
+ trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
400
+ )
401
+
402
+ # epsilon decay in the optimizer
403
+ if args.opt == "adadelta":
404
+ if args.criterion == "acc" and mtl_mode != "ctc":
405
+ trainer.extend(
406
+ restore_snapshot(model, args.outdir + "/model.acc.best"),
407
+ trigger=CompareValueTrigger(
408
+ "validation/main/acc",
409
+ lambda best_value, current_value: best_value > current_value,
410
+ ),
411
+ )
412
+ trainer.extend(
413
+ adadelta_eps_decay(args.eps_decay),
414
+ trigger=CompareValueTrigger(
415
+ "validation/main/acc",
416
+ lambda best_value, current_value: best_value > current_value,
417
+ ),
418
+ )
419
+ elif args.criterion == "loss":
420
+ trainer.extend(
421
+ restore_snapshot(model, args.outdir + "/model.loss.best"),
422
+ trigger=CompareValueTrigger(
423
+ "validation/main/loss",
424
+ lambda best_value, current_value: best_value < current_value,
425
+ ),
426
+ )
427
+ trainer.extend(
428
+ adadelta_eps_decay(args.eps_decay),
429
+ trigger=CompareValueTrigger(
430
+ "validation/main/loss",
431
+ lambda best_value, current_value: best_value < current_value,
432
+ ),
433
+ )
434
+
435
+ # Write a log of evaluation statistics for each epoch
436
+ trainer.extend(
437
+ extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
438
+ )
439
+ report_keys = [
440
+ "epoch",
441
+ "iteration",
442
+ "main/loss",
443
+ "main/loss_ctc",
444
+ "main/loss_att",
445
+ "validation/main/loss",
446
+ "validation/main/loss_ctc",
447
+ "validation/main/loss_att",
448
+ "main/acc",
449
+ "validation/main/acc",
450
+ "elapsed_time",
451
+ ]
452
+ if args.opt == "adadelta":
453
+ trainer.extend(
454
+ extensions.observe_value(
455
+ "eps", lambda trainer: trainer.updater.get_optimizer("main").eps
456
+ ),
457
+ trigger=(args.report_interval_iters, "iteration"),
458
+ )
459
+ report_keys.append("eps")
460
+ trainer.extend(
461
+ extensions.PrintReport(report_keys),
462
+ trigger=(args.report_interval_iters, "iteration"),
463
+ )
464
+
465
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
466
+
467
+ set_early_stop(trainer, args)
468
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
469
+ writer = SummaryWriter(args.tensorboard_dir)
470
+ trainer.extend(
471
+ TensorboardLogger(writer, att_reporter),
472
+ trigger=(args.report_interval_iters, "iteration"),
473
+ )
474
+
475
+ # Run the training
476
+ trainer.run()
477
+ check_early_stop(trainer, args.epochs)
478
+
479
+
480
+ def recog(args):
481
+ """Decode with the given args.
482
+
483
+ Args:
484
+ args (namespace): The program arguments.
485
+
486
+ """
487
+ # display chainer version
488
+ logging.info("chainer version = " + chainer.__version__)
489
+
490
+ set_deterministic_chainer(args)
491
+
492
+ # read training config
493
+ idim, odim, train_args = get_model_conf(args.model, args.model_conf)
494
+
495
+ for key in sorted(vars(args).keys()):
496
+ logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
497
+
498
+ # specify model architecture
499
+ logging.info("reading model parameters from " + args.model)
500
+ # To be compatible with v.0.3.0 models
501
+ if hasattr(train_args, "model_module"):
502
+ model_module = train_args.model_module
503
+ else:
504
+ model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
505
+ model_class = dynamic_import(model_module)
506
+ model = model_class(idim, odim, train_args)
507
+ assert isinstance(model, ASRInterface)
508
+ chainer_load(args.model, model)
509
+
510
+ # read rnnlm
511
+ if args.rnnlm:
512
+ rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
513
+ rnnlm = lm_chainer.ClassifierWithState(
514
+ lm_chainer.RNNLM(
515
+ len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
516
+ )
517
+ )
518
+ chainer_load(args.rnnlm, rnnlm)
519
+ else:
520
+ rnnlm = None
521
+
522
+ if args.word_rnnlm:
523
+ rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
524
+ word_dict = rnnlm_args.char_list_dict
525
+ char_dict = {x: i for i, x in enumerate(train_args.char_list)}
526
+ word_rnnlm = lm_chainer.ClassifierWithState(
527
+ lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
528
+ )
529
+ chainer_load(args.word_rnnlm, word_rnnlm)
530
+
531
+ if rnnlm is not None:
532
+ rnnlm = lm_chainer.ClassifierWithState(
533
+ extlm_chainer.MultiLevelLM(
534
+ word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
535
+ )
536
+ )
537
+ else:
538
+ rnnlm = lm_chainer.ClassifierWithState(
539
+ extlm_chainer.LookAheadWordLM(
540
+ word_rnnlm.predictor, word_dict, char_dict
541
+ )
542
+ )
543
+
544
+ # read json data
545
+ with open(args.recog_json, "rb") as f:
546
+ js = json.load(f)["utts"]
547
+
548
+ load_inputs_and_targets = LoadInputsAndTargets(
549
+ mode="asr",
550
+ load_output=False,
551
+ sort_in_input_length=False,
552
+ preprocess_conf=train_args.preprocess_conf
553
+ if args.preprocess_conf is None
554
+ else args.preprocess_conf,
555
+ preprocess_args={"train": False}, # Switch the mode of preprocessing
556
+ )
557
+
558
+ # decode each utterance
559
+ new_js = {}
560
+ with chainer.no_backprop_mode():
561
+ for idx, name in enumerate(js.keys(), 1):
562
+ logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
563
+ batch = [(name, js[name])]
564
+ feat = load_inputs_and_targets(batch)[0][0]
565
+ nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
566
+ new_js[name] = add_results_to_json(
567
+ js[name], nbest_hyps, train_args.char_list
568
+ )
569
+
570
+ with open(args.result_label, "wb") as f:
571
+ f.write(
572
+ json.dumps(
573
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
574
+ ).encode("utf_8")
575
+ )
espnet/asr/pytorch_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/asr/pytorch_backend/asr.py ADDED
@@ -0,0 +1,1500 @@
1
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """Training/decoding definition for the speech recognition task."""
5
+
6
+ import copy
7
+ import json
8
+ import logging
9
+ import math
10
+ import os
11
+ import sys
12
+
13
+ from chainer import reporter as reporter_module
14
+ from chainer import training
15
+ from chainer.training import extensions
16
+ from chainer.training.updater import StandardUpdater
17
+ import numpy as np
18
+ from tensorboardX import SummaryWriter
19
+ import torch
20
+ from torch.nn.parallel import data_parallel
21
+
22
+ from espnet.asr.asr_utils import adadelta_eps_decay
23
+ from espnet.asr.asr_utils import add_results_to_json
24
+ from espnet.asr.asr_utils import CompareValueTrigger
25
+ from espnet.asr.asr_utils import format_mulenc_args
26
+ from espnet.asr.asr_utils import get_model_conf
27
+ from espnet.asr.asr_utils import plot_spectrogram
28
+ from espnet.asr.asr_utils import restore_snapshot
29
+ from espnet.asr.asr_utils import snapshot_object
30
+ from espnet.asr.asr_utils import torch_load
31
+ from espnet.asr.asr_utils import torch_resume
32
+ from espnet.asr.asr_utils import torch_snapshot
33
+ from espnet.asr.pytorch_backend.asr_init import freeze_modules
34
+ from espnet.asr.pytorch_backend.asr_init import load_trained_model
35
+ from espnet.asr.pytorch_backend.asr_init import load_trained_modules
36
+ import espnet.lm.pytorch_backend.extlm as extlm_pytorch
37
+ from espnet.nets.asr_interface import ASRInterface
38
+ from espnet.nets.beam_search_transducer import BeamSearchTransducer
39
+ from espnet.nets.pytorch_backend.e2e_asr import pad_list
40
+ import espnet.nets.pytorch_backend.lm.default as lm_pytorch
41
+ from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
42
+ from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
43
+ from espnet.transform.spectrogram import IStft
44
+ from espnet.transform.transformation import Transformation
45
+ from espnet.utils.cli_writers import file_writer_helper
46
+ from espnet.utils.dataset import ChainerDataLoader
47
+ from espnet.utils.dataset import TransformDataset
48
+ from espnet.utils.deterministic_utils import set_deterministic_pytorch
49
+ from espnet.utils.dynamic_import import dynamic_import
50
+ from espnet.utils.io_utils import LoadInputsAndTargets
51
+ from espnet.utils.training.batchfy import make_batchset
52
+ from espnet.utils.training.evaluator import BaseEvaluator
53
+ from espnet.utils.training.iterators import ShufflingEnabler
54
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
55
+ from espnet.utils.training.train_utils import check_early_stop
56
+ from espnet.utils.training.train_utils import set_early_stop
57
+
58
+ import matplotlib
59
+
60
+ matplotlib.use("Agg")
61
+
62
+ if sys.version_info[0] == 2:
63
+ from itertools import izip_longest as zip_longest
64
+ else:
65
+ from itertools import zip_longest as zip_longest
66
+
67
+
68
+ def _recursive_to(xs, device):
69
+ if torch.is_tensor(xs):
70
+ return xs.to(device)
71
+ if isinstance(xs, tuple):
72
+ return tuple(_recursive_to(x, device) for x in xs)
73
+ return xs
74
+
75
+
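A hedged illustration (the tensors are made up): tensors and tuples are moved to the target device recursively, while anything else is returned unchanged.

    batch = (torch.zeros(2, 120, 83), torch.tensor([120, 80]), torch.zeros(2, 7).long())
    batch = _recursive_to(batch, torch.device("cpu"))  # or torch.device("cuda")
    # every tensor in the tuple is now on the requested device; a dict or list
    # passed here would come back unchanged, since only tensors and tuples are handled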
76
+ class CustomEvaluator(BaseEvaluator):
77
+ """Custom Evaluator for Pytorch.
78
+
79
+ Args:
80
+ model (torch.nn.Module): The model to evaluate.
81
+ iterator (chainer.dataset.Iterator) : The train iterator.
82
+
83
+ target (link | dict[str, link]) :Link object or a dictionary of
84
+ links to evaluate. If this is just a link object, the link is
85
+ registered by the name ``'main'``.
86
+
87
+ device (torch.device): The device used.
88
+ ngpu (int): The number of GPUs.
89
+
90
+ """
91
+
92
+ def __init__(self, model, iterator, target, device, ngpu=None):
93
+ super(CustomEvaluator, self).__init__(iterator, target)
94
+ self.model = model
95
+ self.device = device
96
+ if ngpu is not None:
97
+ self.ngpu = ngpu
98
+ elif device.type == "cpu":
99
+ self.ngpu = 0
100
+ else:
101
+ self.ngpu = 1
102
+
103
+ # The core part of the update routine can be customized by overriding
104
+ def evaluate(self):
105
+ """Main evaluate routine for CustomEvaluator."""
106
+ iterator = self._iterators["main"]
107
+
108
+ if self.eval_hook:
109
+ self.eval_hook(self)
110
+
111
+ if hasattr(iterator, "reset"):
112
+ iterator.reset()
113
+ it = iterator
114
+ else:
115
+ it = copy.copy(iterator)
116
+
117
+ summary = reporter_module.DictSummary()
118
+
119
+ self.model.eval()
120
+ with torch.no_grad():
121
+ for batch in it:
122
+ x = _recursive_to(batch, self.device)
123
+ observation = {}
124
+ with reporter_module.report_scope(observation):
125
+ # read scp files
126
+ # x: original json with loaded features
127
+ # will be converted to chainer variable later
128
+ if self.ngpu == 0:
129
+ self.model(*x)
130
+ else:
131
+ # apex does not support torch.nn.DataParallel
132
+ data_parallel(self.model, x, range(self.ngpu))
133
+
134
+ summary.add(observation)
135
+ self.model.train()
136
+
137
+ return summary.compute_mean()
138
+
139
+
140
+ class CustomUpdater(StandardUpdater):
141
+ """Custom Updater for Pytorch.
142
+
143
+ Args:
144
+ model (torch.nn.Module): The model to update.
145
+ grad_clip_threshold (float): The gradient clipping value to use.
146
+ train_iter (chainer.dataset.Iterator): The training iterator.
147
+ optimizer (torch.optim.optimizer): The training optimizer.
148
+
149
+ device (torch.device): The device to use.
150
+ ngpu (int): The number of gpus to use.
151
+ use_apex (bool): The flag to use Apex in backprop.
152
+
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ model,
158
+ grad_clip_threshold,
159
+ train_iter,
160
+ optimizer,
161
+ device,
162
+ ngpu,
163
+ grad_noise=False,
164
+ accum_grad=1,
165
+ use_apex=False,
166
+ ):
167
+ super(CustomUpdater, self).__init__(train_iter, optimizer)
168
+ self.model = model
169
+ self.grad_clip_threshold = grad_clip_threshold
170
+ self.device = device
171
+ self.ngpu = ngpu
172
+ self.accum_grad = accum_grad
173
+ self.forward_count = 0
174
+ self.grad_noise = grad_noise
175
+ self.iteration = 0
176
+ self.use_apex = use_apex
177
+
178
+ # The core part of the update routine can be customized by overriding.
179
+ def update_core(self):
180
+ """Main update routine of the CustomUpdater."""
181
+ # When we pass one iterator and optimizer to StandardUpdater.__init__,
182
+ # they are automatically named 'main'.
183
+ train_iter = self.get_iterator("main")
184
+ optimizer = self.get_optimizer("main")
185
+ epoch = train_iter.epoch
186
+
187
+ # Get the next batch (a list of json files)
188
+ batch = train_iter.next()
189
+        # self.iteration += 1  # incrementing here could trigger an early report;
190
+        # the counter is updated automatically elsewhere.
191
+ x = _recursive_to(batch, self.device)
192
+ is_new_epoch = train_iter.epoch != epoch
193
+ # When the last minibatch in the current epoch is given,
194
+ # gradient accumulation is turned off in order to evaluate the model
195
+ # on the validation set in every epoch.
196
+ # see details in https://github.com/espnet/espnet/pull/1388
197
+
198
+ # Compute the loss at this time step and accumulate it
199
+ if self.ngpu == 0:
200
+ loss = self.model(*x).mean() / self.accum_grad
201
+ else:
202
+ # apex does not support torch.nn.DataParallel
203
+ loss = (
204
+ data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
205
+ )
206
+ if self.use_apex:
207
+ from apex import amp
208
+
209
+ # NOTE: for a compatibility with noam optimizer
210
+ opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
211
+ with amp.scale_loss(loss, opt) as scaled_loss:
212
+ scaled_loss.backward()
213
+ else:
214
+ loss.backward()
215
+ # gradient noise injection
216
+ if self.grad_noise:
217
+ from espnet.asr.asr_utils import add_gradient_noise
218
+
219
+ add_gradient_noise(
220
+ self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55
221
+ )
222
+
223
+ # update parameters
224
+ self.forward_count += 1
225
+ if not is_new_epoch and self.forward_count != self.accum_grad:
226
+ return
227
+ self.forward_count = 0
228
+ # compute the gradient norm to check if it is normal or not
229
+ grad_norm = torch.nn.utils.clip_grad_norm_(
230
+ self.model.parameters(), self.grad_clip_threshold
231
+ )
232
+ logging.info("grad norm={}".format(grad_norm))
233
+ if math.isnan(grad_norm):
234
+ logging.warning("grad norm is nan. Do not update model.")
235
+ else:
236
+ optimizer.step()
237
+ optimizer.zero_grad()
238
+
239
+ def update(self):
240
+ self.update_core()
241
+ # #iterations with accum_grad > 1
242
+ # Ref.: https://github.com/espnet/espnet/issues/777
243
+ if self.forward_count == 0:
244
+ self.iteration += 1
245
+
246
+
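A hedged worked illustration of the accumulation logic above (the numbers are made up): with accum_grad = 4 and batch_size = 32, update_core() backpropagates four minibatches before a single optimizer.step(), so the effective batch size is roughly 4 * 32 = 128. self.iteration then advances once per four calls to update(), because forward_count only wraps to 0 when the optimizer is actually stepped (or when an epoch boundary forces an early step).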
247
+ class CustomConverter(object):
248
+ """Custom batch converter for Pytorch.
249
+
250
+ Args:
251
+ subsampling_factor (int): The subsampling factor.
252
+ dtype (torch.dtype): Data type to convert.
253
+
254
+ """
255
+
256
+ def __init__(self, subsampling_factor=1, dtype=torch.float32):
257
+ """Construct a CustomConverter object."""
258
+ self.subsampling_factor = subsampling_factor
259
+ self.ignore_id = -1
260
+ self.dtype = dtype
261
+
262
+ def __call__(self, batch, device=torch.device("cpu")):
263
+ """Transform a batch and send it to a device.
264
+
265
+ Args:
266
+ batch (list): The batch to transform.
267
+ device (torch.device): The device to send to.
268
+
269
+ Returns:
270
+ tuple(torch.Tensor, torch.Tensor, torch.Tensor)
271
+
272
+ """
273
+ # batch should be located in list
274
+ assert len(batch) == 1
275
+ xs, ys = batch[0]
276
+
277
+ # perform subsampling
278
+ if self.subsampling_factor > 1:
279
+ xs = [x[:: self.subsampling_factor, :] for x in xs]
280
+
281
+ # get batch of lengths of input sequences
282
+ ilens = np.array([x.shape[0] for x in xs])
283
+
284
+ # perform padding and convert to tensor
285
+ # currently only support real number
286
+ if xs[0].dtype.kind == "c":
287
+ xs_pad_real = pad_list(
288
+ [torch.from_numpy(x.real).float() for x in xs], 0
289
+ ).to(device, dtype=self.dtype)
290
+ xs_pad_imag = pad_list(
291
+ [torch.from_numpy(x.imag).float() for x in xs], 0
292
+ ).to(device, dtype=self.dtype)
293
+ # Note(kamo):
294
+ # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
295
+ # Don't create ComplexTensor and give it E2E here
296
+ # because torch.nn.DataParellel can't handle it.
297
+ xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
298
+ else:
299
+ xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
300
+ device, dtype=self.dtype
301
+ )
302
+
303
+ ilens = torch.from_numpy(ilens).to(device)
304
+ # NOTE: this is for multi-output (e.g., speech translation)
305
+ ys_pad = pad_list(
306
+ [
307
+ torch.from_numpy(
308
+ np.array(y[0][:]) if isinstance(y, tuple) else y
309
+ ).long()
310
+ for y in ys
311
+ ],
312
+ self.ignore_id,
313
+ ).to(device)
314
+
315
+ return xs_pad, ilens, ys_pad
316
+
317
+
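A hedged usage sketch with made-up shapes (in training the converter is wrapped by ChainerDataLoader, as done in train() further below):

    import numpy as np

    xs = [np.random.randn(120, 83).astype(np.float32),
          np.random.randn(80, 83).astype(np.float32)]
    ys = [np.array([3, 4, 5, 6, 7, 8, 9]), np.array([3, 4, 5, 6, 7])]
    converter = CustomConverter(subsampling_factor=1, dtype=torch.float32)
    xs_pad, ilens, ys_pad = converter([(xs, ys)])
    # xs_pad: (2, 120, 83) zero-padded along time, ilens: tensor([120, 80]),
    # ys_pad: (2, 7) padded with ignore_id = -1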
318
+ class CustomConverterMulEnc(object):
319
+ """Custom batch converter for Pytorch in multi-encoder case.
320
+
321
+ Args:
322
+ subsampling_factors (list): List of subsampling factors for each encoder.
323
+ dtype (torch.dtype): Data type to convert.
324
+
325
+ """
326
+
327
+    def __init__(self, subsampling_factors=[1, 1], dtype=torch.float32):
328
+        """Initialize the converter."""
329
+        self.subsampling_factors = subsampling_factors
330
+        self.ignore_id = -1
331
+        self.dtype = dtype
332
+        self.num_encs = len(subsampling_factors)
333
+
334
+ def __call__(self, batch, device=torch.device("cpu")):
335
+ """Transform a batch and send it to a device.
336
+
337
+ Args:
338
+ batch (list): The batch to transform.
339
+ device (torch.device): The device to send to.
340
+
341
+ Returns:
342
+ tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor)
343
+
344
+ """
345
+ # batch should be located in list
346
+ assert len(batch) == 1
347
+ xs_list = batch[0][: self.num_encs]
348
+ ys = batch[0][-1]
349
+
350
+ # perform subsampling
351
+        if np.sum(self.subsampling_factors) > self.num_encs:
352
+ xs_list = [
353
+ [x[:: self.subsampling_factors[i], :] for x in xs_list[i]]
354
+ for i in range(self.num_encs)
355
+ ]
356
+
357
+ # get batch of lengths of input sequences
358
+ ilens_list = [
359
+ np.array([x.shape[0] for x in xs_list[i]]) for i in range(self.num_encs)
360
+ ]
361
+
362
+ # perform padding and convert to tensor
363
+ # currently only support real number
364
+ xs_list_pad = [
365
+ pad_list([torch.from_numpy(x).float() for x in xs_list[i]], 0).to(
366
+ device, dtype=self.dtype
367
+ )
368
+ for i in range(self.num_encs)
369
+ ]
370
+
371
+ ilens_list = [
372
+ torch.from_numpy(ilens_list[i]).to(device) for i in range(self.num_encs)
373
+ ]
374
+ # NOTE: this is for multi-task learning (e.g., speech translation)
375
+ ys_pad = pad_list(
376
+ [
377
+ torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
378
+ for y in ys
379
+ ],
380
+ self.ignore_id,
381
+ ).to(device)
382
+
383
+ return xs_list_pad, ilens_list, ys_pad
384
+
385
+
386
+ def train(args):
387
+ """Train with the given args.
388
+
389
+ Args:
390
+ args (namespace): The program arguments.
391
+
392
+ """
393
+ set_deterministic_pytorch(args)
394
+ if args.num_encs > 1:
395
+ args = format_mulenc_args(args)
396
+
397
+ # check cuda availability
398
+ if not torch.cuda.is_available():
399
+ logging.warning("cuda is not available")
400
+
401
+ # get input and output dimension info
402
+ with open(args.valid_json, "rb") as f:
403
+ valid_json = json.load(f)["utts"]
404
+ utts = list(valid_json.keys())
405
+ idim_list = [
406
+ int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
407
+ ]
408
+ odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
409
+ for i in range(args.num_encs):
410
+ logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
411
+ logging.info("#output dims: " + str(odim))
412
+
413
+ # specify attention, CTC, hybrid mode
414
+ if "transducer" in args.model_module:
415
+ if (
416
+ getattr(args, "etype", False) == "custom"
417
+ or getattr(args, "dtype", False) == "custom"
418
+ ):
419
+ mtl_mode = "custom_transducer"
420
+ else:
421
+ mtl_mode = "transducer"
422
+ logging.info("Pure transducer mode")
423
+ elif args.mtlalpha == 1.0:
424
+ mtl_mode = "ctc"
425
+ logging.info("Pure CTC mode")
426
+ elif args.mtlalpha == 0.0:
427
+ mtl_mode = "att"
428
+ logging.info("Pure attention mode")
429
+ else:
430
+ mtl_mode = "mtl"
431
+ logging.info("Multitask learning mode")
432
+
433
+ if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
434
+ model = load_trained_modules(idim_list[0], odim, args)
435
+ else:
436
+ model_class = dynamic_import(args.model_module)
437
+ model = model_class(
438
+ idim_list[0] if args.num_encs == 1 else idim_list, odim, args
439
+ )
440
+ assert isinstance(model, ASRInterface)
441
+ total_subsampling_factor = model.get_total_subsampling_factor()
442
+
443
+ logging.info(
444
+ " Total parameter of the model = "
445
+ + str(sum(p.numel() for p in model.parameters()))
446
+ )
447
+
448
+ if args.rnnlm is not None:
449
+ rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
450
+ rnnlm = lm_pytorch.ClassifierWithState(
451
+ lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
452
+ )
453
+ torch_load(args.rnnlm, rnnlm)
454
+ model.rnnlm = rnnlm
455
+
456
+ # write model config
457
+ if not os.path.exists(args.outdir):
458
+ os.makedirs(args.outdir)
459
+ model_conf = args.outdir + "/model.json"
460
+ with open(model_conf, "wb") as f:
461
+ logging.info("writing a model config file to " + model_conf)
462
+ f.write(
463
+ json.dumps(
464
+ (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
465
+ indent=4,
466
+ ensure_ascii=False,
467
+ sort_keys=True,
468
+ ).encode("utf_8")
469
+ )
470
+ for key in sorted(vars(args).keys()):
471
+ logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
472
+
473
+ reporter = model.reporter
474
+
475
+ # check the use of multi-gpu
476
+ if args.ngpu > 1:
477
+ if args.batch_size != 0:
478
+ logging.warning(
479
+ "batch size is automatically increased (%d -> %d)"
480
+ % (args.batch_size, args.batch_size * args.ngpu)
481
+ )
482
+ args.batch_size *= args.ngpu
483
+ if args.num_encs > 1:
484
+ # TODO(ruizhili): implement data parallel for multi-encoder setup.
485
+ raise NotImplementedError(
486
+ "Data parallel is not supported for multi-encoder setup."
487
+ )
488
+
489
+ # set torch device
490
+ device = torch.device("cuda" if args.ngpu > 0 else "cpu")
491
+ if args.train_dtype in ("float16", "float32", "float64"):
492
+ dtype = getattr(torch, args.train_dtype)
493
+ else:
494
+ dtype = torch.float32
495
+ model = model.to(device=device, dtype=dtype)
496
+
497
+ if args.freeze_mods:
498
+ model, model_params = freeze_modules(model, args.freeze_mods)
499
+ else:
500
+ model_params = model.parameters()
501
+
502
+ logging.warning(
503
+ "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
504
+ sum(p.numel() for p in model.parameters()),
505
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
506
+ sum(p.numel() for p in model.parameters() if p.requires_grad)
507
+ * 100.0
508
+ / sum(p.numel() for p in model.parameters()),
509
+ )
510
+ )
511
+
512
+ # Setup an optimizer
513
+ if args.opt == "adadelta":
514
+ optimizer = torch.optim.Adadelta(
515
+ model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay
516
+ )
517
+ elif args.opt == "adam":
518
+ optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
519
+ elif args.opt == "noam":
520
+ from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
521
+
522
+ # For transformer-transducer, adim declaration is within the block definition.
523
+        # Thus, we need to retrieve the most dominant value (d_hidden) for the Noam scheduler.
524
+ if hasattr(args, "enc_block_arch") or hasattr(args, "dec_block_arch"):
525
+ adim = model.most_dom_dim
526
+ else:
527
+ adim = args.adim
528
+
529
+ optimizer = get_std_opt(
530
+ model_params, adim, args.transformer_warmup_steps, args.transformer_lr
531
+ )
532
+ else:
533
+ raise NotImplementedError("unknown optimizer: " + args.opt)
534
+
535
+ # setup apex.amp
536
+ if args.train_dtype in ("O0", "O1", "O2", "O3"):
537
+ try:
538
+ from apex import amp
539
+ except ImportError as e:
540
+ logging.error(
541
+ f"You need to install apex for --train-dtype {args.train_dtype}. "
542
+ "See https://github.com/NVIDIA/apex#linux"
543
+ )
544
+ raise e
545
+ if args.opt == "noam":
546
+ model, optimizer.optimizer = amp.initialize(
547
+ model, optimizer.optimizer, opt_level=args.train_dtype
548
+ )
549
+ else:
550
+ model, optimizer = amp.initialize(
551
+ model, optimizer, opt_level=args.train_dtype
552
+ )
553
+ use_apex = True
554
+
555
+ from espnet.nets.pytorch_backend.ctc import CTC
556
+
557
+ amp.register_float_function(CTC, "loss_fn")
558
+ amp.init()
559
+ logging.warning("register ctc as float function")
560
+ else:
561
+ use_apex = False
562
+
563
+ # FIXME: TOO DIRTY HACK
564
+ setattr(optimizer, "target", reporter)
565
+ setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
566
+
567
+ # Setup a converter
568
+ if args.num_encs == 1:
569
+ converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
570
+ else:
571
+ converter = CustomConverterMulEnc(
572
+ [i[0] for i in model.subsample_list], dtype=dtype
573
+ )
574
+
575
+ # read json data
576
+ with open(args.train_json, "rb") as f:
577
+ train_json = json.load(f)["utts"]
578
+ with open(args.valid_json, "rb") as f:
579
+ valid_json = json.load(f)["utts"]
580
+
581
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
582
+ # make minibatch list (variable length)
583
+ train = make_batchset(
584
+ train_json,
585
+ args.batch_size,
586
+ args.maxlen_in,
587
+ args.maxlen_out,
588
+ args.minibatches,
589
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
590
+ shortest_first=use_sortagrad,
591
+ count=args.batch_count,
592
+ batch_bins=args.batch_bins,
593
+ batch_frames_in=args.batch_frames_in,
594
+ batch_frames_out=args.batch_frames_out,
595
+ batch_frames_inout=args.batch_frames_inout,
596
+ iaxis=0,
597
+ oaxis=0,
598
+ )
599
+ valid = make_batchset(
600
+ valid_json,
601
+ args.batch_size,
602
+ args.maxlen_in,
603
+ args.maxlen_out,
604
+ args.minibatches,
605
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
606
+ count=args.batch_count,
607
+ batch_bins=args.batch_bins,
608
+ batch_frames_in=args.batch_frames_in,
609
+ batch_frames_out=args.batch_frames_out,
610
+ batch_frames_inout=args.batch_frames_inout,
611
+ iaxis=0,
612
+ oaxis=0,
613
+ )
614
+
615
+ load_tr = LoadInputsAndTargets(
616
+ mode="asr",
617
+ load_output=True,
618
+ preprocess_conf=args.preprocess_conf,
619
+ preprocess_args={"train": True}, # Switch the mode of preprocessing
620
+ )
621
+ load_cv = LoadInputsAndTargets(
622
+ mode="asr",
623
+ load_output=True,
624
+ preprocess_conf=args.preprocess_conf,
625
+ preprocess_args={"train": False}, # Switch the mode of preprocessing
626
+ )
627
+    # hack to set the batchsize argument to 1;
628
+    # the actual batchsize is contained in each minibatch list.
629
+    # the default collate function converts numpy arrays to pytorch tensors;
630
+    # we instead use a trivial collate function that simply unwraps the single-element list
631
+ train_iter = ChainerDataLoader(
632
+ dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
633
+ batch_size=1,
634
+ num_workers=args.n_iter_processes,
635
+ shuffle=not use_sortagrad,
636
+ collate_fn=lambda x: x[0],
637
+ )
638
+ valid_iter = ChainerDataLoader(
639
+ dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
640
+ batch_size=1,
641
+ shuffle=False,
642
+ collate_fn=lambda x: x[0],
643
+ num_workers=args.n_iter_processes,
644
+ )
645
+
646
+ # Set up a trainer
647
+ updater = CustomUpdater(
648
+ model,
649
+ args.grad_clip,
650
+ {"main": train_iter},
651
+ optimizer,
652
+ device,
653
+ args.ngpu,
654
+ args.grad_noise,
655
+ args.accum_grad,
656
+ use_apex=use_apex,
657
+ )
658
+ trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
659
+
660
+ if use_sortagrad:
661
+ trainer.extend(
662
+ ShufflingEnabler([train_iter]),
663
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
664
+ )
665
+
666
+ # Resume from a snapshot
667
+ if args.resume:
668
+ logging.info("resumed from %s" % args.resume)
669
+ torch_resume(args.resume, trainer)
670
+
671
+ # Evaluate the model with the test dataset for each epoch
672
+ if args.save_interval_iters > 0:
673
+ trainer.extend(
674
+ CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
675
+ trigger=(args.save_interval_iters, "iteration"),
676
+ )
677
+ else:
678
+ trainer.extend(
679
+ CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
680
+ )
681
+
682
+ # Save attention weight each epoch
683
+ is_attn_plot = (
684
+ "transformer" in args.model_module
685
+ or "conformer" in args.model_module
686
+ or mtl_mode in ["att", "mtl", "custom_transducer"]
687
+ )
688
+
689
+ if args.num_save_attention > 0 and is_attn_plot:
690
+ data = sorted(
691
+ list(valid_json.items())[: args.num_save_attention],
692
+ key=lambda x: int(x[1]["input"][0]["shape"][1]),
693
+ reverse=True,
694
+ )
695
+ if hasattr(model, "module"):
696
+ att_vis_fn = model.module.calculate_all_attentions
697
+ plot_class = model.module.attention_plot_class
698
+ else:
699
+ att_vis_fn = model.calculate_all_attentions
700
+ plot_class = model.attention_plot_class
701
+ att_reporter = plot_class(
702
+ att_vis_fn,
703
+ data,
704
+ args.outdir + "/att_ws",
705
+ converter=converter,
706
+ transform=load_cv,
707
+ device=device,
708
+ subsampling_factor=total_subsampling_factor,
709
+ )
710
+ trainer.extend(att_reporter, trigger=(1, "epoch"))
711
+ else:
712
+ att_reporter = None
713
+
714
+ # Save CTC prob at each epoch
715
+ if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0:
716
+ # NOTE: sort it by output lengths
717
+ data = sorted(
718
+ list(valid_json.items())[: args.num_save_ctc],
719
+ key=lambda x: int(x[1]["output"][0]["shape"][0]),
720
+ reverse=True,
721
+ )
722
+ if hasattr(model, "module"):
723
+ ctc_vis_fn = model.module.calculate_all_ctc_probs
724
+ plot_class = model.module.ctc_plot_class
725
+ else:
726
+ ctc_vis_fn = model.calculate_all_ctc_probs
727
+ plot_class = model.ctc_plot_class
728
+ ctc_reporter = plot_class(
729
+ ctc_vis_fn,
730
+ data,
731
+ args.outdir + "/ctc_prob",
732
+ converter=converter,
733
+ transform=load_cv,
734
+ device=device,
735
+ subsampling_factor=total_subsampling_factor,
736
+ )
737
+ trainer.extend(ctc_reporter, trigger=(1, "epoch"))
738
+ else:
739
+ ctc_reporter = None
740
+
741
+ # Make a plot for training and validation values
742
+ if args.num_encs > 1:
743
+ report_keys_loss_ctc = [
744
+ "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
745
+ ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
746
+ report_keys_cer_ctc = [
747
+ "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
748
+ ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]
749
+
750
+ if hasattr(model, "is_rnnt"):
751
+ trainer.extend(
752
+ extensions.PlotReport(
753
+ [
754
+ "main/loss",
755
+ "validation/main/loss",
756
+ "main/loss_trans",
757
+ "validation/main/loss_trans",
758
+ "main/loss_ctc",
759
+ "validation/main/loss_ctc",
760
+ "main/loss_lm",
761
+ "validation/main/loss_lm",
762
+ "main/loss_aux_trans",
763
+ "validation/main/loss_aux_trans",
764
+ "main/loss_aux_symm_kl",
765
+ "validation/main/loss_aux_symm_kl",
766
+ ],
767
+ "epoch",
768
+ file_name="loss.png",
769
+ )
770
+ )
771
+ else:
772
+ trainer.extend(
773
+ extensions.PlotReport(
774
+ [
775
+ "main/loss",
776
+ "validation/main/loss",
777
+ "main/loss_ctc",
778
+ "validation/main/loss_ctc",
779
+ "main/loss_att",
780
+ "validation/main/loss_att",
781
+ ]
782
+ + ([] if args.num_encs == 1 else report_keys_loss_ctc),
783
+ "epoch",
784
+ file_name="loss.png",
785
+ )
786
+ )
787
+
788
+ trainer.extend(
789
+ extensions.PlotReport(
790
+ ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
791
+ )
792
+ )
793
+ trainer.extend(
794
+ extensions.PlotReport(
795
+ ["main/cer_ctc", "validation/main/cer_ctc"]
796
+ + ([] if args.num_encs == 1 else report_keys_loss_ctc),
797
+ "epoch",
798
+ file_name="cer.png",
799
+ )
800
+ )
801
+
802
+ # Save best models
803
+ trainer.extend(
804
+ snapshot_object(model, "model.loss.best"),
805
+ trigger=training.triggers.MinValueTrigger("validation/main/loss"),
806
+ )
807
+ if mtl_mode not in ["ctc", "transducer", "custom_transducer"]:
808
+ trainer.extend(
809
+ snapshot_object(model, "model.acc.best"),
810
+ trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
811
+ )
812
+
813
+ # save snapshot which contains model and optimizer states
814
+ if args.save_interval_iters > 0:
815
+ trainer.extend(
816
+ torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
817
+ trigger=(args.save_interval_iters, "iteration"),
818
+ )
819
+
820
+ # save snapshot at every epoch - for model averaging
821
+ trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
822
+
823
+ # epsilon decay in the optimizer
824
+ if args.opt == "adadelta":
825
+ if args.criterion == "acc" and mtl_mode != "ctc":
826
+ trainer.extend(
827
+ restore_snapshot(
828
+ model, args.outdir + "/model.acc.best", load_fn=torch_load
829
+ ),
830
+ trigger=CompareValueTrigger(
831
+ "validation/main/acc",
832
+ lambda best_value, current_value: best_value > current_value,
833
+ ),
834
+ )
835
+ trainer.extend(
836
+ adadelta_eps_decay(args.eps_decay),
837
+ trigger=CompareValueTrigger(
838
+ "validation/main/acc",
839
+ lambda best_value, current_value: best_value > current_value,
840
+ ),
841
+ )
842
+ elif args.criterion == "loss":
843
+ trainer.extend(
844
+ restore_snapshot(
845
+ model, args.outdir + "/model.loss.best", load_fn=torch_load
846
+ ),
847
+ trigger=CompareValueTrigger(
848
+ "validation/main/loss",
849
+ lambda best_value, current_value: best_value < current_value,
850
+ ),
851
+ )
852
+ trainer.extend(
853
+ adadelta_eps_decay(args.eps_decay),
854
+ trigger=CompareValueTrigger(
855
+ "validation/main/loss",
856
+ lambda best_value, current_value: best_value < current_value,
857
+ ),
858
+ )
859
+ # NOTE: In some cases, it may take more than one epoch for the model's loss
860
+ # to escape from a local minimum.
861
+ # Thus, restore_snapshot extension is not used here.
862
+ # see details in https://github.com/espnet/espnet/pull/2171
863
+ elif args.criterion == "loss_eps_decay_only":
864
+ trainer.extend(
865
+ adadelta_eps_decay(args.eps_decay),
866
+ trigger=CompareValueTrigger(
867
+ "validation/main/loss",
868
+ lambda best_value, current_value: best_value < current_value,
869
+ ),
870
+ )
871
+
872
+ # Write a log of evaluation statistics for each epoch
873
+ trainer.extend(
874
+ extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
875
+ )
876
+
877
+ if hasattr(model, "is_rnnt"):
878
+ report_keys = [
879
+ "epoch",
880
+ "iteration",
881
+ "main/loss",
882
+ "main/loss_trans",
883
+ "main/loss_ctc",
884
+ "main/loss_lm",
885
+ "main/loss_aux_trans",
886
+ "main/loss_aux_symm_kl",
887
+ "validation/main/loss",
888
+ "validation/main/loss_trans",
889
+ "validation/main/loss_ctc",
890
+ "validation/main/loss_lm",
891
+ "validation/main/loss_aux_trans",
892
+ "validation/main/loss_aux_symm_kl",
893
+ "elapsed_time",
894
+ ]
895
+ else:
896
+ report_keys = [
897
+ "epoch",
898
+ "iteration",
899
+ "main/loss",
900
+ "main/loss_ctc",
901
+ "main/loss_att",
902
+ "validation/main/loss",
903
+ "validation/main/loss_ctc",
904
+ "validation/main/loss_att",
905
+ "main/acc",
906
+ "validation/main/acc",
907
+ "main/cer_ctc",
908
+ "validation/main/cer_ctc",
909
+ "elapsed_time",
910
+ ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)
911
+
912
+ if args.opt == "adadelta":
913
+ trainer.extend(
914
+ extensions.observe_value(
915
+ "eps",
916
+ lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
917
+ "eps"
918
+ ],
919
+ ),
920
+ trigger=(args.report_interval_iters, "iteration"),
921
+ )
922
+ report_keys.append("eps")
923
+ if args.report_cer:
924
+ report_keys.append("validation/main/cer")
925
+ if args.report_wer:
926
+ report_keys.append("validation/main/wer")
927
+ trainer.extend(
928
+ extensions.PrintReport(report_keys),
929
+ trigger=(args.report_interval_iters, "iteration"),
930
+ )
931
+
932
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
933
+ set_early_stop(trainer, args)
934
+
935
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
936
+ trainer.extend(
937
+ TensorboardLogger(
938
+ SummaryWriter(args.tensorboard_dir),
939
+ att_reporter=att_reporter,
940
+ ctc_reporter=ctc_reporter,
941
+ ),
942
+ trigger=(args.report_interval_iters, "iteration"),
943
+ )
944
+ # Run the training
945
+ trainer.run()
946
+ check_early_stop(trainer, args.epochs)
947
+
948
+
949
+ def recog(args):
950
+ """Decode with the given args.
951
+
952
+ Args:
953
+ args (namespace): The program arguments.
954
+
955
+ """
956
+ set_deterministic_pytorch(args)
957
+ model, train_args = load_trained_model(args.model, training=False)
958
+ assert isinstance(model, ASRInterface)
959
+ model.recog_args = args
960
+
961
+ if args.streaming_mode and "transformer" in train_args.model_module:
962
+ raise NotImplementedError("streaming mode for transformer is not implemented")
963
+ logging.info(
964
+ " Total parameter of the model = "
965
+ + str(sum(p.numel() for p in model.parameters()))
966
+ )
967
+
968
+ # read rnnlm
969
+ if args.rnnlm:
970
+ rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
971
+ if getattr(rnnlm_args, "model_module", "default") != "default":
972
+ raise ValueError(
973
+ "use '--api v2' option to decode with non-default language model"
974
+ )
975
+ rnnlm = lm_pytorch.ClassifierWithState(
976
+ lm_pytorch.RNNLM(
977
+ len(train_args.char_list),
978
+ rnnlm_args.layer,
979
+ rnnlm_args.unit,
980
+ getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
981
+ )
982
+ )
983
+ torch_load(args.rnnlm, rnnlm)
984
+ rnnlm.eval()
985
+ else:
986
+ rnnlm = None
987
+
988
+ if args.word_rnnlm:
989
+ rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
990
+ word_dict = rnnlm_args.char_list_dict
991
+ char_dict = {x: i for i, x in enumerate(train_args.char_list)}
992
+ word_rnnlm = lm_pytorch.ClassifierWithState(
993
+ lm_pytorch.RNNLM(
994
+ len(word_dict),
995
+ rnnlm_args.layer,
996
+ rnnlm_args.unit,
997
+ getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
998
+ )
999
+ )
1000
+ torch_load(args.word_rnnlm, word_rnnlm)
1001
+ word_rnnlm.eval()
1002
+
1003
+ if rnnlm is not None:
1004
+ rnnlm = lm_pytorch.ClassifierWithState(
1005
+ extlm_pytorch.MultiLevelLM(
1006
+ word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
1007
+ )
1008
+ )
1009
+ else:
1010
+ rnnlm = lm_pytorch.ClassifierWithState(
1011
+ extlm_pytorch.LookAheadWordLM(
1012
+ word_rnnlm.predictor, word_dict, char_dict
1013
+ )
1014
+ )
1015
+
1016
+ # gpu
1017
+ if args.ngpu == 1:
1018
+ gpu_id = list(range(args.ngpu))
1019
+ logging.info("gpu id: " + str(gpu_id))
1020
+ model.cuda()
1021
+ if rnnlm:
1022
+ rnnlm.cuda()
1023
+
1024
+ # read json data
1025
+ with open(args.recog_json, "rb") as f:
1026
+ js = json.load(f)["utts"]
1027
+ new_js = {}
1028
+
1029
+ load_inputs_and_targets = LoadInputsAndTargets(
1030
+ mode="asr",
1031
+ load_output=False,
1032
+ sort_in_input_length=False,
1033
+ preprocess_conf=train_args.preprocess_conf
1034
+ if args.preprocess_conf is None
1035
+ else args.preprocess_conf,
1036
+ preprocess_args={"train": False},
1037
+ )
1038
+
1039
+ # load transducer beam search
1040
+ if hasattr(model, "is_rnnt"):
1041
+ if hasattr(model, "dec"):
1042
+ trans_decoder = model.dec
1043
+ else:
1044
+ trans_decoder = model.decoder
1045
+ joint_network = model.joint_network
1046
+
1047
+ beam_search_transducer = BeamSearchTransducer(
1048
+ decoder=trans_decoder,
1049
+ joint_network=joint_network,
1050
+ beam_size=args.beam_size,
1051
+ nbest=args.nbest,
1052
+ lm=rnnlm,
1053
+ lm_weight=args.lm_weight,
1054
+ search_type=args.search_type,
1055
+ max_sym_exp=args.max_sym_exp,
1056
+ u_max=args.u_max,
1057
+ nstep=args.nstep,
1058
+ prefix_alpha=args.prefix_alpha,
1059
+ score_norm=args.score_norm,
1060
+ )
1061
+
1062
+ if args.batchsize == 0:
1063
+ with torch.no_grad():
1064
+ for idx, name in enumerate(js.keys(), 1):
1065
+ logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
1066
+ batch = [(name, js[name])]
1067
+ feat = load_inputs_and_targets(batch)
1068
+ feat = (
1069
+ feat[0][0]
1070
+ if args.num_encs == 1
1071
+ else [feat[idx][0] for idx in range(model.num_encs)]
1072
+ )
1073
+ if args.streaming_mode == "window" and args.num_encs == 1:
1074
+ logging.info(
1075
+ "Using streaming recognizer with window size %d frames",
1076
+ args.streaming_window,
1077
+ )
1078
+ se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
1079
+ for i in range(0, feat.shape[0], args.streaming_window):
1080
+ logging.info(
1081
+ "Feeding frames %d - %d", i, i + args.streaming_window
1082
+ )
1083
+ se2e.accept_input(feat[i : i + args.streaming_window])
1084
+ logging.info("Running offline attention decoder")
1085
+ se2e.decode_with_attention_offline()
1086
+ logging.info("Offline attention decoder finished")
1087
+ nbest_hyps = se2e.retrieve_recognition()
1088
+ elif args.streaming_mode == "segment" and args.num_encs == 1:
1089
+ logging.info(
1090
+ "Using streaming recognizer with threshold value %d",
1091
+ args.streaming_min_blank_dur,
1092
+ )
1093
+ nbest_hyps = []
1094
+ for n in range(args.nbest):
1095
+ nbest_hyps.append({"yseq": [], "score": 0.0})
1096
+ se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
1097
+ r = np.prod(model.subsample)
1098
+ for i in range(0, feat.shape[0], r):
1099
+ hyps = se2e.accept_input(feat[i : i + r])
1100
+ if hyps is not None:
1101
+ text = "".join(
1102
+ [
1103
+ train_args.char_list[int(x)]
1104
+ for x in hyps[0]["yseq"][1:-1]
1105
+ if int(x) != -1
1106
+ ]
1107
+ )
1108
+ text = text.replace(
1109
+ "\u2581", " "
1110
+ ).strip() # for SentencePiece
1111
+ text = text.replace(model.space, " ")
1112
+ text = text.replace(model.blank, "")
1113
+ logging.info(text)
1114
+ for n in range(args.nbest):
1115
+ nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
1116
+ nbest_hyps[n]["score"] += hyps[n]["score"]
1117
+ elif hasattr(model, "is_rnnt"):
1118
+ nbest_hyps = model.recognize(feat, beam_search_transducer)
1119
+ else:
1120
+ nbest_hyps = model.recognize(
1121
+ feat, args, train_args.char_list, rnnlm
1122
+ )
1123
+ new_js[name] = add_results_to_json(
1124
+ js[name], nbest_hyps, train_args.char_list
1125
+ )
1126
+
1127
+ else:
1128
+
1129
+ def grouper(n, iterable, fillvalue=None):
1130
+ kargs = [iter(iterable)] * n
1131
+ return zip_longest(*kargs, fillvalue=fillvalue)
1132
+
1133
+ # sort data if batchsize > 1
1134
+ keys = list(js.keys())
1135
+ if args.batchsize > 1:
1136
+ feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
1137
+ sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
1138
+ keys = [keys[i] for i in sorted_index]
1139
+
1140
+ with torch.no_grad():
1141
+ for names in grouper(args.batchsize, keys, None):
1142
+ names = [name for name in names if name]
1143
+ batch = [(name, js[name]) for name in names]
1144
+ feats = (
1145
+ load_inputs_and_targets(batch)[0]
1146
+ if args.num_encs == 1
1147
+ else load_inputs_and_targets(batch)
1148
+ )
1149
+ if args.streaming_mode == "window" and args.num_encs == 1:
1150
+ raise NotImplementedError
1151
+ elif args.streaming_mode == "segment" and args.num_encs == 1:
1152
+ if args.batchsize > 1:
1153
+ raise NotImplementedError
1154
+ feat = feats[0]
1155
+ nbest_hyps = []
1156
+ for n in range(args.nbest):
1157
+ nbest_hyps.append({"yseq": [], "score": 0.0})
1158
+ se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
1159
+ r = np.prod(model.subsample)
1160
+ for i in range(0, feat.shape[0], r):
1161
+ hyps = se2e.accept_input(feat[i : i + r])
1162
+ if hyps is not None:
1163
+ text = "".join(
1164
+ [
1165
+ train_args.char_list[int(x)]
1166
+ for x in hyps[0]["yseq"][1:-1]
1167
+ if int(x) != -1
1168
+ ]
1169
+ )
1170
+ text = text.replace(
1171
+ "\u2581", " "
1172
+ ).strip() # for SentencePiece
1173
+ text = text.replace(model.space, " ")
1174
+ text = text.replace(model.blank, "")
1175
+ logging.info(text)
1176
+ for n in range(args.nbest):
1177
+ nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
1178
+ nbest_hyps[n]["score"] += hyps[n]["score"]
1179
+ nbest_hyps = [nbest_hyps]
1180
+ else:
1181
+ nbest_hyps = model.recognize_batch(
1182
+ feats, args, train_args.char_list, rnnlm=rnnlm
1183
+ )
1184
+
1185
+ for i, nbest_hyp in enumerate(nbest_hyps):
1186
+ name = names[i]
1187
+ new_js[name] = add_results_to_json(
1188
+ js[name], nbest_hyp, train_args.char_list
1189
+ )
1190
+
1191
+ with open(args.result_label, "wb") as f:
1192
+ f.write(
1193
+ json.dumps(
1194
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
1195
+ ).encode("utf_8")
1196
+ )
1197
+
1198
+
1199
+ def enhance(args):
1200
+ """Dumping enhanced speech and mask.
1201
+
1202
+ Args:
1203
+ args (namespace): The program arguments.
1204
+ """
1205
+ set_deterministic_pytorch(args)
1206
+ # read training config
1207
+ idim, odim, train_args = get_model_conf(args.model, args.model_conf)
1208
+
1209
+ # TODO(ruizhili): implement enhance for multi-encoder model
1210
+ assert args.num_encs == 1, "number of encoders should be 1 ({} is given)".format(
1211
+ args.num_encs
1212
+ )
1213
+
1214
+ # load trained model parameters
1215
+ logging.info("reading model parameters from " + args.model)
1216
+ model_class = dynamic_import(train_args.model_module)
1217
+ model = model_class(idim, odim, train_args)
1218
+ assert isinstance(model, ASRInterface)
1219
+ torch_load(args.model, model)
1220
+ model.recog_args = args
1221
+
1222
+ # gpu
1223
+ if args.ngpu == 1:
1224
+ gpu_id = list(range(args.ngpu))
1225
+ logging.info("gpu id: " + str(gpu_id))
1226
+ model.cuda()
1227
+
1228
+ # read json data
1229
+ with open(args.recog_json, "rb") as f:
1230
+ js = json.load(f)["utts"]
1231
+
1232
+ load_inputs_and_targets = LoadInputsAndTargets(
1233
+ mode="asr",
1234
+ load_output=False,
1235
+ sort_in_input_length=False,
1236
+ preprocess_conf=None, # Apply pre_process in outer func
1237
+ )
1238
+ if args.batchsize == 0:
1239
+ args.batchsize = 1
1240
+
1241
+ # Creates writers for outputs from the network
1242
+ if args.enh_wspecifier is not None:
1243
+ enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype)
1244
+ else:
1245
+ enh_writer = None
1246
+
1247
+ # Creates a Transformation instance
1248
+ preprocess_conf = (
1249
+ train_args.preprocess_conf
1250
+ if args.preprocess_conf is None
1251
+ else args.preprocess_conf
1252
+ )
1253
+ if preprocess_conf is not None:
1254
+ logging.info(f"Use preprocessing: {preprocess_conf}")
1255
+ transform = Transformation(preprocess_conf)
1256
+ else:
1257
+ transform = None
1258
+
1259
+ # Creates a IStft instance
1260
+ istft = None
1261
+ frame_shift = args.istft_n_shift # Used for plot the spectrogram
1262
+ if args.apply_istft:
1263
+ if preprocess_conf is not None:
1264
+ # Read the conffile and find stft setting
1265
+ with open(preprocess_conf) as f:
1266
+ # Json format: e.g.
1267
+ # {"process": [{"type": "stft",
1268
+ # "win_length": 400,
1269
+ # "n_fft": 512, "n_shift": 160,
1270
+ # "window": "han"},
1271
+ # {"type": "foo", ...}, ...]}
1272
+ conf = json.load(f)
1273
+ assert "process" in conf, conf
1274
+ # Find stft setting
1275
+ for p in conf["process"]:
1276
+ if p["type"] == "stft":
1277
+ istft = IStft(
1278
+ win_length=p["win_length"],
1279
+ n_shift=p["n_shift"],
1280
+ window=p.get("window", "hann"),
1281
+ )
1282
+ logging.info(
1283
+ "stft is found in {}. "
1284
+ "Setting istft config from it\n{}".format(
1285
+ preprocess_conf, istft
1286
+ )
1287
+ )
1288
+ frame_shift = p["n_shift"]
1289
+ break
1290
+ if istft is None:
1291
+ # Set from command line arguments
1292
+ istft = IStft(
1293
+ win_length=args.istft_win_length,
1294
+ n_shift=args.istft_n_shift,
1295
+ window=args.istft_window,
1296
+ )
1297
+ logging.info(
1298
+ "Setting istft config from the command line args\n{}".format(istft)
1299
+ )
1300
+
1301
+ # sort data
1302
+ keys = list(js.keys())
1303
+ feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
1304
+ sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
1305
+ keys = [keys[i] for i in sorted_index]
1306
+
1307
+ def grouper(n, iterable, fillvalue=None):
1308
+ kargs = [iter(iterable)] * n
1309
+ return zip_longest(*kargs, fillvalue=fillvalue)
1310
+
1311
+ num_images = 0
1312
+ if not os.path.exists(args.image_dir):
1313
+ os.makedirs(args.image_dir)
1314
+
1315
+ for names in grouper(args.batchsize, keys, None):
1316
+ batch = [(name, js[name]) for name in names]
1317
+
1318
+ # May be in time region: (Batch, [Time, Channel])
1319
+ org_feats = load_inputs_and_targets(batch)[0]
1320
+ if transform is not None:
1321
+ # May be in time-freq region: (Batch, [Time, Channel, Freq])
1322
+ feats = transform(org_feats, train=False)
1323
+ else:
1324
+ feats = org_feats
1325
+
1326
+ with torch.no_grad():
1327
+ enhanced, mask, ilens = model.enhance(feats)
1328
+
1329
+ for idx, name in enumerate(names):
1330
+ # Assuming mask, feats : [Batch, Time, Channel, Freq]
1331
+ # enhanced : [Batch, Time, Freq]
1332
+ enh = enhanced[idx][: ilens[idx]]
1333
+ mas = mask[idx][: ilens[idx]]
1334
+ feat = feats[idx]
1335
+
1336
+ # Plot spectrogram
1337
+ if args.image_dir is not None and num_images < args.num_images:
1338
+ import matplotlib.pyplot as plt
1339
+
1340
+ num_images += 1
1341
+ ref_ch = 0
1342
+
1343
+ plt.figure(figsize=(20, 10))
1344
+ plt.subplot(4, 1, 1)
1345
+ plt.title("Mask [ref={}ch]".format(ref_ch))
1346
+ plot_spectrogram(
1347
+ plt,
1348
+ mas[:, ref_ch].T,
1349
+ fs=args.fs,
1350
+ mode="linear",
1351
+ frame_shift=frame_shift,
1352
+ bottom=False,
1353
+ labelbottom=False,
1354
+ )
1355
+
1356
+ plt.subplot(4, 1, 2)
1357
+ plt.title("Noisy speech [ref={}ch]".format(ref_ch))
1358
+ plot_spectrogram(
1359
+ plt,
1360
+ feat[:, ref_ch].T,
1361
+ fs=args.fs,
1362
+ mode="db",
1363
+ frame_shift=frame_shift,
1364
+ bottom=False,
1365
+ labelbottom=False,
1366
+ )
1367
+
1368
+ plt.subplot(4, 1, 3)
1369
+ plt.title("Masked speech [ref={}ch]".format(ref_ch))
1370
+ plot_spectrogram(
1371
+ plt,
1372
+ (feat[:, ref_ch] * mas[:, ref_ch]).T,
1373
+ frame_shift=frame_shift,
1374
+ fs=args.fs,
1375
+ mode="db",
1376
+ bottom=False,
1377
+ labelbottom=False,
1378
+ )
1379
+
1380
+ plt.subplot(4, 1, 4)
1381
+ plt.title("Enhanced speech")
1382
+ plot_spectrogram(
1383
+ plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift
1384
+ )
1385
+
1386
+ plt.savefig(os.path.join(args.image_dir, name + ".png"))
1387
+ plt.clf()
1388
+
1389
+ # Write enhanced wave files
1390
+ if enh_writer is not None:
1391
+ if istft is not None:
1392
+ enh = istft(enh)
1393
+ else:
1394
+ enh = enh
1395
+
1396
+ if args.keep_length:
1397
+ if len(org_feats[idx]) < len(enh):
1398
+ # Truncate the frames added by stft padding
1399
+ enh = enh[: len(org_feats[idx])]
1400
+ elif len(org_feats[idx]) > len(enh):
1401
+ padwidth = [(0, (len(org_feats[idx]) - len(enh)))] + [
1402
+ (0, 0)
1403
+ ] * (enh.ndim - 1)
1404
+ enh = np.pad(enh, padwidth, mode="constant")
1405
+
1406
+ if args.enh_filetype in ("sound", "sound.hdf5"):
1407
+ enh_writer[name] = (args.fs, enh)
1408
+ else:
1409
+ # Hint: To dump stft_signal, mask or etc,
1410
+ # enh_filetype='hdf5' might be convenient.
1411
+ enh_writer[name] = enh
1412
+
1413
+ if num_images >= args.num_images and enh_writer is None:
1414
+ logging.info("Breaking the process.")
1415
+ break
1416
+
1417
+
1418
+ def ctc_align(args):
1419
+ """CTC forced alignments with the given args.
1420
+
1421
+ Args:
1422
+ args (namespace): The program arguments.
1423
+ """
1424
+
1425
+ def add_alignment_to_json(js, alignment, char_list):
1426
+ """Add N-best results to json.
1427
+
1428
+ Args:
1429
+ js (dict[str, Any]): Groundtruth utterance dict.
1430
+ alignment (list[int]): Frame-level CTC alignment (one token index per frame).
1431
+ char_list (list[str]): List of characters.
1432
+
1433
+ Returns:
1434
+ dict[str, Any]: Utterance dict with the CTC alignment added.
1435
+
1436
+ """
1437
+ # build a new dict holding only the alignment (the original json info is not copied)
1438
+ new_js = dict()
1439
+ new_js["ctc_alignment"] = []
1440
+
1441
+ alignment_tokens = []
1442
+ for idx, a in enumerate(alignment):
1443
+ alignment_tokens.append(char_list[a])
1444
+ alignment_tokens = " ".join(alignment_tokens)
1445
+
1446
+ new_js["ctc_alignment"] = alignment_tokens
1447
+
1448
+ return new_js
1449
+
1450
+ set_deterministic_pytorch(args)
1451
+ model, train_args = load_trained_model(args.model)
1452
+ assert isinstance(model, ASRInterface)
1453
+ model.eval()
1454
+
1455
+ load_inputs_and_targets = LoadInputsAndTargets(
1456
+ mode="asr",
1457
+ load_output=True,
1458
+ sort_in_input_length=False,
1459
+ preprocess_conf=train_args.preprocess_conf
1460
+ if args.preprocess_conf is None
1461
+ else args.preprocess_conf,
1462
+ preprocess_args={"train": False},
1463
+ )
1464
+
1465
+ if args.ngpu > 1:
1466
+ raise NotImplementedError("only single GPU decoding is supported")
1467
+ if args.ngpu == 1:
1468
+ device = "cuda"
1469
+ else:
1470
+ device = "cpu"
1471
+ dtype = getattr(torch, args.dtype)
1472
+ logging.info(f"Decoding device={device}, dtype={dtype}")
1473
+ model.to(device=device, dtype=dtype).eval()
1474
+
1475
+ # read json data
1476
+ with open(args.align_json, "rb") as f:
1477
+ js = json.load(f)["utts"]
1478
+ new_js = {}
1479
+ if args.batchsize == 0:
1480
+ with torch.no_grad():
1481
+ for idx, name in enumerate(js.keys(), 1):
1482
+ logging.info("(%d/%d) aligning " + name, idx, len(js.keys()))
1483
+ batch = [(name, js[name])]
1484
+ feat, label = load_inputs_and_targets(batch)
1485
+ feat = feat[0]
1486
+ label = label[0]
1487
+ enc = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0)
1488
+ alignment = model.ctc.forced_align(enc, label)
1489
+ new_js[name] = add_alignment_to_json(
1490
+ js[name], alignment, train_args.char_list
1491
+ )
1492
+ else:
1493
+ raise NotImplementedError("Align_batch is not implemented.")
1494
+
1495
+ with open(args.result_label, "wb") as f:
1496
+ f.write(
1497
+ json.dumps(
1498
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
1499
+ ).encode("utf_8")
1500
+ )
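For orientation, an entry written by `add_alignment_to_json` above is roughly of the following shape; the utterance id and tokens are placeholders, not output from this repository:

# "some_utterance_id": {
#     "ctc_alignment": "<blank> <blank> H E L L O <blank> ..."
# }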
espnet/asr/pytorch_backend/asr_init.py ADDED
@@ -0,0 +1,282 @@
1
+ """Finetuning methods."""
2
+
3
+ import logging
4
+ import os
5
+ import torch
6
+
7
+ from collections import OrderedDict
8
+
9
+ from espnet.asr.asr_utils import get_model_conf
10
+ from espnet.asr.asr_utils import torch_load
11
+ from espnet.nets.asr_interface import ASRInterface
12
+ from espnet.nets.mt_interface import MTInterface
13
+ from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
14
+ from espnet.nets.tts_interface import TTSInterface
15
+ from espnet.utils.dynamic_import import dynamic_import
16
+
17
+
18
+ def freeze_modules(model, modules):
19
+ """Freeze model parameters according to modules list.
20
+
21
+ Args:
22
+ model (torch.nn.Module): main model to update
23
+ modules (list): specified module list for freezing
24
+
25
+ Return:
26
+ model (torch.nn.Module): updated model
27
+ model_params (filter): filtered model parameters
28
+
29
+ """
30
+ for mod, param in model.named_parameters():
31
+ if any(mod.startswith(m) for m in modules):
32
+ logging.info(f"freezing {mod}, it will not be updated.")
33
+ param.requires_grad = False
34
+
35
+ model_params = filter(lambda x: x.requires_grad, model.parameters())
36
+
37
+ return model, model_params
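A minimal usage sketch, assuming hypothetical module prefixes; the returned `model_params` filter is what would then be handed to the optimizer:

# Hypothetical usage (module prefixes are placeholders):
model, model_params = freeze_modules(model, ["enc.embed", "enc.encoders.0."])
optimizer = torch.optim.Adadelta(model_params, rho=0.95, eps=1e-8)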
38
+
39
+
40
+ def transfer_verification(model_state_dict, partial_state_dict, modules):
41
+ """Verify tuples (key, shape) for input model modules match specified modules.
42
+
43
+ Args:
44
+ model_state_dict (OrderedDict): the initial model state_dict
45
+ partial_state_dict (OrderedDict): the trained model state_dict
46
+ modules (list): specified module list for transfer
47
+
48
+ Return:
49
+ (boolean): allow transfer
50
+
51
+ """
52
+ modules_model = []
53
+ partial_modules = []
54
+
55
+ for key_p, value_p in partial_state_dict.items():
56
+ if any(key_p.startswith(m) for m in modules):
57
+ partial_modules += [(key_p, value_p.shape)]
58
+
59
+ for key_m, value_m in model_state_dict.items():
60
+ if any(key_m.startswith(m) for m in modules):
61
+ modules_model += [(key_m, value_m.shape)]
62
+
63
+ len_match = len(modules_model) == len(partial_modules)
64
+
65
+ module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted(
66
+ partial_modules, key=lambda x: (x[0], x[1])
67
+ )
68
+
69
+ return len_match and module_match
70
+
71
+
72
+ def get_partial_state_dict(model_state_dict, modules):
73
+ """Create state_dict with specified modules matching input model modules.
74
+
75
+ Note that get_lm_state_dict is used instead if an LM is specified.
76
+
77
+ Args:
78
+ model_state_dict (OrderedDict): trained model state_dict
79
+ modules (list): specified module list for transfer
80
+
81
+ Return:
82
+ new_state_dict (OrderedDict): the updated state_dict
83
+
84
+ """
85
+ new_state_dict = OrderedDict()
86
+
87
+ for key, value in model_state_dict.items():
88
+ if any(key.startswith(m) for m in modules):
89
+ new_state_dict[key] = value
90
+
91
+ return new_state_dict
92
+
93
+
94
+ def get_lm_state_dict(lm_state_dict):
95
+ """Create compatible ASR decoder state dict from LM state dict.
96
+
97
+ Args:
98
+ lm_state_dict (OrderedDict): pre-trained LM state_dict
99
+
100
+ Return:
101
+ new_state_dict (OrderedDict): LM state_dict with updated keys
102
+
103
+ """
104
+ new_state_dict = OrderedDict()
105
+
106
+ for key, value in list(lm_state_dict.items()):
107
+ if key == "predictor.embed.weight":
108
+ new_state_dict["dec.embed.weight"] = value
109
+ elif key.startswith("predictor.rnn."):
110
+ _split = key.split(".")
111
+
112
+ new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
113
+ new_state_dict[new_key] = value
114
+
115
+ return new_state_dict
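As an illustration of the key renaming above, assuming LM keys of the form `predictor.rnn.<layer>.<param>` (names are hypothetical):

# "predictor.embed.weight"     -> "dec.embed.weight"
# "predictor.rnn.0.weight_ih"  -> "dec.decoder.0.weight_ih_l0"
# "predictor.rnn.0.bias_hh"    -> "dec.decoder.0.bias_hh_l0"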
116
+
117
+
118
+ def filter_modules(model_state_dict, modules):
119
+ """Filter non-matched modules in module_state_dict.
120
+
121
+ Args:
122
+ model_state_dict (OrderedDict): trained model state_dict
123
+ modules (list): specified module list for transfer
124
+
125
+ Return:
126
+ new_mods (list): the update module list
127
+
128
+ """
129
+ new_mods = []
130
+ incorrect_mods = []
131
+
132
+ mods_model = list(model_state_dict.keys())
133
+ for mod in modules:
134
+ if any(key.startswith(mod) for key in mods_model):
135
+ new_mods += [mod]
136
+ else:
137
+ incorrect_mods += [mod]
138
+
139
+ if incorrect_mods:
140
+ logging.warning(
141
+ "module(s) %s don't match or (partially match) "
142
+ "available modules in model.",
143
+ incorrect_mods,
144
+ )
145
+ logging.warning("for information, the existing modules in model are:")
146
+ logging.warning("%s", mods_model)
147
+
148
+ return new_mods
149
+
150
+
151
+ def load_trained_model(model_path, training=True):
152
+ """Load the trained model for recognition.
153
+
154
+ Args:
155
+ model_path (str): Path to model.***.best
156
+
157
+ """
158
+ idim, odim, train_args = get_model_conf(
159
+ model_path, os.path.join(os.path.dirname(model_path), "model.json")
160
+ )
161
+
162
+ logging.warning("reading model parameters from " + model_path)
163
+
164
+ if hasattr(train_args, "model_module"):
165
+ model_module = train_args.model_module
166
+ else:
167
+ model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
168
+ # CTC Loss is not needed, default to builtin to prevent import errors
169
+ if hasattr(train_args, "ctc_type"):
170
+ train_args.ctc_type = "builtin"
171
+
172
+ model_class = dynamic_import(model_module)
173
+
174
+ if "transducer" in model_module:
175
+ model = model_class(idim, odim, train_args, training=training)
176
+ custom_torch_load(model_path, model, training=training)
177
+ else:
178
+ model = model_class(idim, odim, train_args)
179
+ torch_load(model_path, model)
180
+
181
+ return model, train_args
182
+
183
+
184
+ def get_trained_model_state_dict(model_path):
185
+ """Extract the trained model state dict for pre-initialization.
186
+
187
+ Args:
188
+ model_path (str): Path to model.***.best
189
+
190
+ Return:
191
+ model.state_dict() (OrderedDict): the loaded model state_dict
192
+ (bool): Boolean defining whether the model is an LM
193
+
194
+ """
195
+ conf_path = os.path.join(os.path.dirname(model_path), "model.json")
196
+ if "rnnlm" in model_path:
197
+ logging.warning("reading model parameters from %s", model_path)
198
+
199
+ return get_lm_state_dict(torch.load(model_path))
200
+
201
+ idim, odim, args = get_model_conf(model_path, conf_path)
202
+
203
+ logging.warning("reading model parameters from " + model_path)
204
+
205
+ if hasattr(args, "model_module"):
206
+ model_module = args.model_module
207
+ else:
208
+ model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
209
+
210
+ model_class = dynamic_import(model_module)
211
+ model = model_class(idim, odim, args)
212
+ torch_load(model_path, model)
213
+ assert (
214
+ isinstance(model, MTInterface)
215
+ or isinstance(model, ASRInterface)
216
+ or isinstance(model, TTSInterface)
217
+ )
218
+
219
+ return model.state_dict()
220
+
221
+
222
+ def load_trained_modules(idim, odim, args, interface=ASRInterface):
223
+ """Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
224
+
225
+ Args:
226
+ idim (int): initial input dimension.
227
+ odim (int): initial output dimension.
228
+ args (Namespace): The initial model arguments.
229
+ interface (Interface): ASRInterface or STInterface or TTSInterface.
230
+
231
+ Return:
232
+ model (torch.nn.Module): The model with pretrained modules.
233
+
234
+ """
235
+
236
+ def print_new_keys(state_dict, modules, model_path):
237
+ logging.warning("loading %s from model: %s", modules, model_path)
238
+
239
+ for k in state_dict.keys():
240
+ logging.warning("override %s" % k)
241
+
242
+ enc_model_path = args.enc_init
243
+ dec_model_path = args.dec_init
244
+ enc_modules = args.enc_init_mods
245
+ dec_modules = args.dec_init_mods
246
+
247
+ model_class = dynamic_import(args.model_module)
248
+ main_model = model_class(idim, odim, args)
249
+ assert isinstance(main_model, interface)
250
+
251
+ main_state_dict = main_model.state_dict()
252
+
253
+ logging.warning("model(s) found for pre-initialization")
254
+ for model_path, modules in [
255
+ (enc_model_path, enc_modules),
256
+ (dec_model_path, dec_modules),
257
+ ]:
258
+ if model_path is not None:
259
+ if os.path.isfile(model_path):
260
+ model_state_dict = get_trained_model_state_dict(model_path)
261
+
262
+ modules = filter_modules(model_state_dict, modules)
263
+
264
+ partial_state_dict = get_partial_state_dict(model_state_dict, modules)
265
+
266
+ if partial_state_dict:
267
+ if transfer_verification(
268
+ main_state_dict, partial_state_dict, modules
269
+ ):
270
+ print_new_keys(partial_state_dict, modules, model_path)
271
+ main_state_dict.update(partial_state_dict)
272
+ else:
273
+ logging.warning(
274
+ f"modules {modules} in model {model_path} "
275
+ f"don't match your training config",
276
+ )
277
+ else:
278
+ logging.warning("model was not found : %s", model_path)
279
+
280
+ main_model.load_state_dict(main_state_dict)
281
+
282
+ return main_model
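A hypothetical transfer-initialization setup using the arguments consumed above; the paths and module prefixes are placeholders, not values from this repository:

# args.enc_init, args.enc_init_mods = "exp/asr_pretrained/model.acc.best", ["enc."]
# args.dec_init, args.dec_init_mods = "exp/lm_pretrained/rnnlm.model.best", ["dec.embed.", "dec.decoder."]
# model = load_trained_modules(idim, odim, args)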
espnet/asr/pytorch_backend/asr_mix.py ADDED
@@ -0,0 +1,654 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This script is used for multi-speaker speech recognition.
5
+
6
+ Copyright 2017 Johns Hopkins University (Shinji Watanabe)
7
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
8
+ """
9
+ import json
10
+ import logging
11
+ import os
12
+
13
+ # chainer related
14
+ from chainer import training
15
+ from chainer.training import extensions
16
+ from itertools import zip_longest
17
+ import numpy as np
18
+ from tensorboardX import SummaryWriter
19
+ import torch
20
+
21
+ from espnet.asr.asr_mix_utils import add_results_to_json
22
+ from espnet.asr.asr_utils import adadelta_eps_decay
23
+
24
+ from espnet.asr.asr_utils import CompareValueTrigger
25
+ from espnet.asr.asr_utils import get_model_conf
26
+ from espnet.asr.asr_utils import restore_snapshot
27
+ from espnet.asr.asr_utils import snapshot_object
28
+ from espnet.asr.asr_utils import torch_load
29
+ from espnet.asr.asr_utils import torch_resume
30
+ from espnet.asr.asr_utils import torch_snapshot
31
+ from espnet.asr.pytorch_backend.asr import CustomEvaluator
32
+ from espnet.asr.pytorch_backend.asr import CustomUpdater
33
+ from espnet.asr.pytorch_backend.asr import load_trained_model
34
+ import espnet.lm.pytorch_backend.extlm as extlm_pytorch
35
+ from espnet.nets.asr_interface import ASRInterface
36
+ from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list
37
+ import espnet.nets.pytorch_backend.lm.default as lm_pytorch
38
+ from espnet.utils.dataset import ChainerDataLoader
39
+ from espnet.utils.dataset import TransformDataset
40
+ from espnet.utils.deterministic_utils import set_deterministic_pytorch
41
+ from espnet.utils.dynamic_import import dynamic_import
42
+ from espnet.utils.io_utils import LoadInputsAndTargets
43
+ from espnet.utils.training.batchfy import make_batchset
44
+ from espnet.utils.training.iterators import ShufflingEnabler
45
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
46
+ from espnet.utils.training.train_utils import check_early_stop
47
+ from espnet.utils.training.train_utils import set_early_stop
48
+
49
+ import matplotlib
50
+
51
+ matplotlib.use("Agg")
52
+
53
+
54
+ class CustomConverter(object):
55
+ """Custom batch converter for Pytorch.
56
+
57
+ Args:
58
+ subsampling_factor (int): The subsampling factor.
59
+ dtype (torch.dtype): Data type to convert.
60
+
61
+ """
62
+
63
+ def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkrs=2):
64
+ """Initialize the converter."""
65
+ self.subsampling_factor = subsampling_factor
66
+ self.ignore_id = -1
67
+ self.dtype = dtype
68
+ self.num_spkrs = num_spkrs
69
+
70
+ def __call__(self, batch, device=torch.device("cpu")):
71
+ """Transform a batch and send it to a device.
72
+
73
+ Args:
74
+ batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to transform.
75
+ device (torch.device): The device to send to.
76
+
77
+ Returns:
78
+ tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch.
79
+
80
+ """
81
+ # batch should be located in list
82
+ assert len(batch) == 1
83
+ xs, ys = batch[0][0], batch[0][-self.num_spkrs :]
84
+
85
+ # perform subsampling
86
+ if self.subsampling_factor > 1:
87
+ xs = [x[:: self.subsampling_factor, :] for x in xs]
88
+
89
+ # get batch of lengths of input sequences
90
+ ilens = np.array([x.shape[0] for x in xs])
91
+
92
+ # perform padding and convert to tensor
93
+ # currently only support real number
94
+ if xs[0].dtype.kind == "c":
95
+ xs_pad_real = pad_list(
96
+ [torch.from_numpy(x.real).float() for x in xs], 0
97
+ ).to(device, dtype=self.dtype)
98
+ xs_pad_imag = pad_list(
99
+ [torch.from_numpy(x.imag).float() for x in xs], 0
100
+ ).to(device, dtype=self.dtype)
101
+ # Note(kamo):
102
+ # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
103
+ # Don't create ComplexTensor and give it to E2E here
104
+ # because torch.nn.DataParallel can't handle it.
105
+ xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
106
+ else:
107
+ xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
108
+ device, dtype=self.dtype
109
+ )
110
+
111
+ ilens = torch.from_numpy(ilens).to(device)
112
+ if not isinstance(ys[0], np.ndarray):
113
+ ys_pad = []
114
+ for i in range(len(ys)): # speakers
115
+ ys_pad += [torch.from_numpy(y).long() for y in ys[i]]
116
+ ys_pad = pad_list(ys_pad, self.ignore_id)
117
+ ys_pad = (
118
+ ys_pad.view(self.num_spkrs, -1, ys_pad.size(1))
119
+ .transpose(0, 1)
120
+ .to(device)
121
+ ) # (B, num_spkrs, Tmax)
122
+ else:
123
+ ys_pad = pad_list(
124
+ [torch.from_numpy(y).long() for y in ys], self.ignore_id
125
+ ).to(device)
126
+
127
+ return xs_pad, ilens, ys_pad
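Sketch of how this converter is driven by the data loaders set up below; shapes are illustrative and assume two speakers:

# converter = CustomConverter(subsampling_factor=4, num_spkrs=2)
# xs_pad, ilens, ys_pad = converter([load_tr(data)], device=torch.device("cpu"))
# xs_pad: (B, Tmax, D) padded float features, ilens: (B,) input lengths,
# ys_pad: (B, num_spkrs, Lmax) label ids padded with -1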
128
+
129
+
130
+ def train(args):
131
+ """Train with the given args.
132
+
133
+ Args:
134
+ args (namespace): The program arguments.
135
+
136
+ """
137
+ set_deterministic_pytorch(args)
138
+
139
+ # check cuda availability
140
+ if not torch.cuda.is_available():
141
+ logging.warning("cuda is not available")
142
+
143
+ # get input and output dimension info
144
+ with open(args.valid_json, "rb") as f:
145
+ valid_json = json.load(f)["utts"]
146
+ utts = list(valid_json.keys())
147
+ idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
148
+ odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
149
+ logging.info("#input dims : " + str(idim))
150
+ logging.info("#output dims: " + str(odim))
151
+
152
+ # specify attention, CTC, hybrid mode
153
+ if args.mtlalpha == 1.0:
154
+ mtl_mode = "ctc"
155
+ logging.info("Pure CTC mode")
156
+ elif args.mtlalpha == 0.0:
157
+ mtl_mode = "att"
158
+ logging.info("Pure attention mode")
159
+ else:
160
+ mtl_mode = "mtl"
161
+ logging.info("Multitask learning mode")
162
+
163
+ # specify model architecture
164
+ model_class = dynamic_import(args.model_module)
165
+ model = model_class(idim, odim, args)
166
+ assert isinstance(model, ASRInterface)
167
+ subsampling_factor = model.subsample[0]
168
+
169
+ if args.rnnlm is not None:
170
+ rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
171
+ rnnlm = lm_pytorch.ClassifierWithState(
172
+ lm_pytorch.RNNLM(
173
+ len(args.char_list),
174
+ rnnlm_args.layer,
175
+ rnnlm_args.unit,
176
+ getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
177
+ )
178
+ )
179
+ torch_load(args.rnnlm, rnnlm)
180
+ model.rnnlm = rnnlm
181
+
182
+ # write model config
183
+ if not os.path.exists(args.outdir):
184
+ os.makedirs(args.outdir)
185
+ model_conf = args.outdir + "/model.json"
186
+ with open(model_conf, "wb") as f:
187
+ logging.info("writing a model config file to " + model_conf)
188
+ f.write(
189
+ json.dumps(
190
+ (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
191
+ ).encode("utf_8")
192
+ )
193
+ for key in sorted(vars(args).keys()):
194
+ logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
195
+
196
+ reporter = model.reporter
197
+
198
+ # check the use of multi-gpu
199
+ if args.ngpu > 1:
200
+ if args.batch_size != 0:
201
+ logging.warning(
202
+ "batch size is automatically increased (%d -> %d)"
203
+ % (args.batch_size, args.batch_size * args.ngpu)
204
+ )
205
+ args.batch_size *= args.ngpu
206
+
207
+ # set torch device
208
+ device = torch.device("cuda" if args.ngpu > 0 else "cpu")
209
+ if args.train_dtype in ("float16", "float32", "float64"):
210
+ dtype = getattr(torch, args.train_dtype)
211
+ else:
212
+ dtype = torch.float32
213
+ model = model.to(device=device, dtype=dtype)
214
+
215
+ logging.warning(
216
+ "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
217
+ sum(p.numel() for p in model.parameters()),
218
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
219
+ sum(p.numel() for p in model.parameters() if p.requires_grad)
220
+ * 100.0
221
+ / sum(p.numel() for p in model.parameters()),
222
+ )
223
+ )
224
+
225
+ # Setup an optimizer
226
+ if args.opt == "adadelta":
227
+ optimizer = torch.optim.Adadelta(
228
+ model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
229
+ )
230
+ elif args.opt == "adam":
231
+ optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
232
+ elif args.opt == "noam":
233
+ from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
234
+
235
+ optimizer = get_std_opt(
236
+ model.parameters(),
237
+ args.adim,
238
+ args.transformer_warmup_steps,
239
+ args.transformer_lr,
240
+ )
241
+ else:
242
+ raise NotImplementedError("unknown optimizer: " + args.opt)
243
+
244
+ # setup apex.amp
245
+ if args.train_dtype in ("O0", "O1", "O2", "O3"):
246
+ try:
247
+ from apex import amp
248
+ except ImportError as e:
249
+ logging.error(
250
+ f"You need to install apex for --train-dtype {args.train_dtype}. "
251
+ "See https://github.com/NVIDIA/apex#linux"
252
+ )
253
+ raise e
254
+ if args.opt == "noam":
255
+ model, optimizer.optimizer = amp.initialize(
256
+ model, optimizer.optimizer, opt_level=args.train_dtype
257
+ )
258
+ else:
259
+ model, optimizer = amp.initialize(
260
+ model, optimizer, opt_level=args.train_dtype
261
+ )
262
+ use_apex = True
263
+ else:
264
+ use_apex = False
265
+
266
+ # FIXME: TOO DIRTY HACK
267
+ setattr(optimizer, "target", reporter)
268
+ setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
269
+
270
+ # Setup a converter
271
+ converter = CustomConverter(
272
+ subsampling_factor=subsampling_factor, dtype=dtype, num_spkrs=args.num_spkrs
273
+ )
274
+
275
+ # read json data
276
+ with open(args.train_json, "rb") as f:
277
+ train_json = json.load(f)["utts"]
278
+ with open(args.valid_json, "rb") as f:
279
+ valid_json = json.load(f)["utts"]
280
+
281
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
282
+ # make minibatch list (variable length)
283
+ train = make_batchset(
284
+ train_json,
285
+ args.batch_size,
286
+ args.maxlen_in,
287
+ args.maxlen_out,
288
+ args.minibatches,
289
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
290
+ shortest_first=use_sortagrad,
291
+ count=args.batch_count,
292
+ batch_bins=args.batch_bins,
293
+ batch_frames_in=args.batch_frames_in,
294
+ batch_frames_out=args.batch_frames_out,
295
+ batch_frames_inout=args.batch_frames_inout,
296
+ iaxis=0,
297
+ oaxis=-1,
298
+ )
299
+ valid = make_batchset(
300
+ valid_json,
301
+ args.batch_size,
302
+ args.maxlen_in,
303
+ args.maxlen_out,
304
+ args.minibatches,
305
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
306
+ count=args.batch_count,
307
+ batch_bins=args.batch_bins,
308
+ batch_frames_in=args.batch_frames_in,
309
+ batch_frames_out=args.batch_frames_out,
310
+ batch_frames_inout=args.batch_frames_inout,
311
+ iaxis=0,
312
+ oaxis=-1,
313
+ )
314
+
315
+ load_tr = LoadInputsAndTargets(
316
+ mode="asr",
317
+ load_output=True,
318
+ preprocess_conf=args.preprocess_conf,
319
+ preprocess_args={"train": True}, # Switch the mode of preprocessing
320
+ )
321
+ load_cv = LoadInputsAndTargets(
322
+ mode="asr",
323
+ load_output=True,
324
+ preprocess_conf=args.preprocess_conf,
325
+ preprocess_args={"train": False}, # Switch the mode of preprocessing
326
+ )
327
+ # hack to force the batchsize argument to 1:
328
+ # the actual batch size is contained in each list element;
329
+ # the default collate function converts numpy arrays to pytorch tensors,
330
+ # so an empty collate function that simply returns the list is used instead
331
+ train_iter = {
332
+ "main": ChainerDataLoader(
333
+ dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
334
+ batch_size=1,
335
+ num_workers=args.n_iter_processes,
336
+ shuffle=True,
337
+ collate_fn=lambda x: x[0],
338
+ )
339
+ }
340
+ valid_iter = {
341
+ "main": ChainerDataLoader(
342
+ dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
343
+ batch_size=1,
344
+ shuffle=False,
345
+ collate_fn=lambda x: x[0],
346
+ num_workers=args.n_iter_processes,
347
+ )
348
+ }
349
+
350
+ # Set up a trainer
351
+ updater = CustomUpdater(
352
+ model,
353
+ args.grad_clip,
354
+ train_iter,
355
+ optimizer,
356
+ device,
357
+ args.ngpu,
358
+ args.grad_noise,
359
+ args.accum_grad,
360
+ use_apex=use_apex,
361
+ )
362
+ trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
363
+
364
+ if use_sortagrad:
365
+ trainer.extend(
366
+ ShufflingEnabler([train_iter]),
367
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
368
+ )
369
+
370
+ # Resume from a snapshot
371
+ if args.resume:
372
+ logging.info("resumed from %s" % args.resume)
373
+ torch_resume(args.resume, trainer)
374
+
375
+ # Evaluate the model with the test dataset for each epoch
376
+ trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))
377
+
378
+ # Save attention weight each epoch
379
+ if args.num_save_attention > 0 and args.mtlalpha != 1.0:
380
+ data = sorted(
381
+ list(valid_json.items())[: args.num_save_attention],
382
+ key=lambda x: int(x[1]["input"][0]["shape"][1]),
383
+ reverse=True,
384
+ )
385
+ if hasattr(model, "module"):
386
+ att_vis_fn = model.module.calculate_all_attentions
387
+ plot_class = model.module.attention_plot_class
388
+ else:
389
+ att_vis_fn = model.calculate_all_attentions
390
+ plot_class = model.attention_plot_class
391
+ att_reporter = plot_class(
392
+ att_vis_fn,
393
+ data,
394
+ args.outdir + "/att_ws",
395
+ converter=converter,
396
+ transform=load_cv,
397
+ device=device,
398
+ )
399
+ trainer.extend(att_reporter, trigger=(1, "epoch"))
400
+ else:
401
+ att_reporter = None
402
+
403
+ # Make a plot for training and validation values
404
+ trainer.extend(
405
+ extensions.PlotReport(
406
+ [
407
+ "main/loss",
408
+ "validation/main/loss",
409
+ "main/loss_ctc",
410
+ "validation/main/loss_ctc",
411
+ "main/loss_att",
412
+ "validation/main/loss_att",
413
+ ],
414
+ "epoch",
415
+ file_name="loss.png",
416
+ )
417
+ )
418
+ trainer.extend(
419
+ extensions.PlotReport(
420
+ ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
421
+ )
422
+ )
423
+ trainer.extend(
424
+ extensions.PlotReport(
425
+ ["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png"
426
+ )
427
+ )
428
+
429
+ # Save best models
430
+ trainer.extend(
431
+ snapshot_object(model, "model.loss.best"),
432
+ trigger=training.triggers.MinValueTrigger("validation/main/loss"),
433
+ )
434
+ if mtl_mode != "ctc":
435
+ trainer.extend(
436
+ snapshot_object(model, "model.acc.best"),
437
+ trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
438
+ )
439
+
440
+ # save snapshot which contains model and optimizer states
441
+ trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
442
+
443
+ # epsilon decay in the optimizer
444
+ if args.opt == "adadelta":
445
+ if args.criterion == "acc" and mtl_mode != "ctc":
446
+ trainer.extend(
447
+ restore_snapshot(
448
+ model, args.outdir + "/model.acc.best", load_fn=torch_load
449
+ ),
450
+ trigger=CompareValueTrigger(
451
+ "validation/main/acc",
452
+ lambda best_value, current_value: best_value > current_value,
453
+ ),
454
+ )
455
+ trainer.extend(
456
+ adadelta_eps_decay(args.eps_decay),
457
+ trigger=CompareValueTrigger(
458
+ "validation/main/acc",
459
+ lambda best_value, current_value: best_value > current_value,
460
+ ),
461
+ )
462
+ elif args.criterion == "loss":
463
+ trainer.extend(
464
+ restore_snapshot(
465
+ model, args.outdir + "/model.loss.best", load_fn=torch_load
466
+ ),
467
+ trigger=CompareValueTrigger(
468
+ "validation/main/loss",
469
+ lambda best_value, current_value: best_value < current_value,
470
+ ),
471
+ )
472
+ trainer.extend(
473
+ adadelta_eps_decay(args.eps_decay),
474
+ trigger=CompareValueTrigger(
475
+ "validation/main/loss",
476
+ lambda best_value, current_value: best_value < current_value,
477
+ ),
478
+ )
479
+
480
+ # Write a log of evaluation statistics for each epoch
481
+ trainer.extend(
482
+ extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
483
+ )
484
+ report_keys = [
485
+ "epoch",
486
+ "iteration",
487
+ "main/loss",
488
+ "main/loss_ctc",
489
+ "main/loss_att",
490
+ "validation/main/loss",
491
+ "validation/main/loss_ctc",
492
+ "validation/main/loss_att",
493
+ "main/acc",
494
+ "validation/main/acc",
495
+ "main/cer_ctc",
496
+ "validation/main/cer_ctc",
497
+ "elapsed_time",
498
+ ]
499
+ if args.opt == "adadelta":
500
+ trainer.extend(
501
+ extensions.observe_value(
502
+ "eps",
503
+ lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
504
+ "eps"
505
+ ],
506
+ ),
507
+ trigger=(args.report_interval_iters, "iteration"),
508
+ )
509
+ report_keys.append("eps")
510
+ if args.report_cer:
511
+ report_keys.append("validation/main/cer")
512
+ if args.report_wer:
513
+ report_keys.append("validation/main/wer")
514
+ trainer.extend(
515
+ extensions.PrintReport(report_keys),
516
+ trigger=(args.report_interval_iters, "iteration"),
517
+ )
518
+
519
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
520
+ set_early_stop(trainer, args)
521
+
522
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
523
+ trainer.extend(
524
+ TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
525
+ trigger=(args.report_interval_iters, "iteration"),
526
+ )
527
+ # Run the training
528
+ trainer.run()
529
+ check_early_stop(trainer, args.epochs)
530
+
531
+
532
+ def recog(args):
533
+ """Decode with the given args.
534
+
535
+ Args:
536
+ args (namespace): The program arguments.
537
+
538
+ """
539
+ set_deterministic_pytorch(args)
540
+ model, train_args = load_trained_model(args.model)
541
+ assert isinstance(model, ASRInterface)
542
+ model.recog_args = args
543
+
544
+ # read rnnlm
545
+ if args.rnnlm:
546
+ rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
547
+ if getattr(rnnlm_args, "model_module", "default") != "default":
548
+ raise ValueError(
549
+ "use '--api v2' option to decode with non-default language model"
550
+ )
551
+ rnnlm = lm_pytorch.ClassifierWithState(
552
+ lm_pytorch.RNNLM(
553
+ len(train_args.char_list),
554
+ rnnlm_args.layer,
555
+ rnnlm_args.unit,
556
+ getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
557
+ )
558
+ )
559
+ torch_load(args.rnnlm, rnnlm)
560
+ rnnlm.eval()
561
+ else:
562
+ rnnlm = None
563
+
564
+ if args.word_rnnlm:
565
+ rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
566
+ word_dict = rnnlm_args.char_list_dict
567
+ char_dict = {x: i for i, x in enumerate(train_args.char_list)}
568
+ word_rnnlm = lm_pytorch.ClassifierWithState(
569
+ lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
570
+ )
571
+ torch_load(args.word_rnnlm, word_rnnlm)
572
+ word_rnnlm.eval()
573
+
574
+ if rnnlm is not None:
575
+ rnnlm = lm_pytorch.ClassifierWithState(
576
+ extlm_pytorch.MultiLevelLM(
577
+ word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
578
+ )
579
+ )
580
+ else:
581
+ rnnlm = lm_pytorch.ClassifierWithState(
582
+ extlm_pytorch.LookAheadWordLM(
583
+ word_rnnlm.predictor, word_dict, char_dict
584
+ )
585
+ )
586
+
587
+ # gpu
588
+ if args.ngpu == 1:
589
+ gpu_id = list(range(args.ngpu))
590
+ logging.info("gpu id: " + str(gpu_id))
591
+ model.cuda()
592
+ if rnnlm:
593
+ rnnlm.cuda()
594
+
595
+ # read json data
596
+ with open(args.recog_json, "rb") as f:
597
+ js = json.load(f)["utts"]
598
+ new_js = {}
599
+
600
+ load_inputs_and_targets = LoadInputsAndTargets(
601
+ mode="asr",
602
+ load_output=False,
603
+ sort_in_input_length=False,
604
+ preprocess_conf=train_args.preprocess_conf
605
+ if args.preprocess_conf is None
606
+ else args.preprocess_conf,
607
+ preprocess_args={"train": False},
608
+ )
609
+
610
+ if args.batchsize == 0:
611
+ with torch.no_grad():
612
+ for idx, name in enumerate(js.keys(), 1):
613
+ logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
614
+ batch = [(name, js[name])]
615
+ feat = load_inputs_and_targets(batch)[0][0]
616
+ nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
617
+ new_js[name] = add_results_to_json(
618
+ js[name], nbest_hyps, train_args.char_list
619
+ )
620
+
621
+ else:
622
+
623
+ def grouper(n, iterable, fillvalue=None):
624
+ kargs = [iter(iterable)] * n
625
+ return zip_longest(*kargs, fillvalue=fillvalue)
626
+
627
+ # sort data if batchsize > 1
628
+ keys = list(js.keys())
629
+ if args.batchsize > 1:
630
+ feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
631
+ sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
632
+ keys = [keys[i] for i in sorted_index]
633
+
634
+ with torch.no_grad():
635
+ for names in grouper(args.batchsize, keys, None):
636
+ names = [name for name in names if name]
637
+ batch = [(name, js[name]) for name in names]
638
+ feats = load_inputs_and_targets(batch)[0]
639
+ nbest_hyps = model.recognize_batch(
640
+ feats, args, train_args.char_list, rnnlm=rnnlm
641
+ )
642
+
643
+ for i, name in enumerate(names):
644
+ nbest_hyp = [hyp[i] for hyp in nbest_hyps]
645
+ new_js[name] = add_results_to_json(
646
+ js[name], nbest_hyp, train_args.char_list
647
+ )
648
+
649
+ with open(args.result_label, "wb") as f:
650
+ f.write(
651
+ json.dumps(
652
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
653
+ ).encode("utf_8")
654
+ )
espnet/asr/pytorch_backend/recog.py ADDED
@@ -0,0 +1,152 @@
1
+ """V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
2
+
3
+ import json
4
+ import logging
5
+
6
+ import torch
7
+
8
+ from espnet.asr.asr_utils import add_results_to_json
9
+ from espnet.asr.asr_utils import get_model_conf
10
+ from espnet.asr.asr_utils import torch_load
11
+ from espnet.asr.pytorch_backend.asr import load_trained_model
12
+ from espnet.nets.asr_interface import ASRInterface
13
+ from espnet.nets.batch_beam_search import BatchBeamSearch
14
+ from espnet.nets.beam_search import BeamSearch
15
+ from espnet.nets.lm_interface import dynamic_import_lm
16
+ from espnet.nets.scorer_interface import BatchScorerInterface
17
+ from espnet.nets.scorers.length_bonus import LengthBonus
18
+ from espnet.utils.deterministic_utils import set_deterministic_pytorch
19
+ from espnet.utils.io_utils import LoadInputsAndTargets
20
+
21
+
22
+ def recog_v2(args):
23
+ """Decode with custom models that implements ScorerInterface.
24
+
25
+ Notes:
26
+ The previous backend espnet.asr.pytorch_backend.asr.recog
27
+ only supports E2E and RNNLM
28
+
29
+ Args:
30
+ args (namespace): The program arguments.
31
+ See py:func:`espnet.bin.asr_recog.get_parser` for details
32
+
33
+ """
34
+ logging.warning("experimental API for custom LMs is selected by --api v2")
35
+ if args.batchsize > 1:
36
+ raise NotImplementedError("multi-utt batch decoding is not implemented")
37
+ if args.streaming_mode is not None:
38
+ raise NotImplementedError("streaming mode is not implemented")
39
+ if args.word_rnnlm:
40
+ raise NotImplementedError("word LM is not implemented")
41
+
42
+ set_deterministic_pytorch(args)
43
+ model, train_args = load_trained_model(args.model)
44
+ assert isinstance(model, ASRInterface)
45
+ model.eval()
46
+
47
+ load_inputs_and_targets = LoadInputsAndTargets(
48
+ mode="asr",
49
+ load_output=False,
50
+ sort_in_input_length=False,
51
+ preprocess_conf=train_args.preprocess_conf
52
+ if args.preprocess_conf is None
53
+ else args.preprocess_conf,
54
+ preprocess_args={"train": False},
55
+ )
56
+
57
+ if args.rnnlm:
58
+ lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
59
+ # NOTE: for a compatibility with less than 0.5.0 version models
60
+ lm_model_module = getattr(lm_args, "model_module", "default")
61
+ lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
62
+ lm = lm_class(len(train_args.char_list), lm_args)
63
+ torch_load(args.rnnlm, lm)
64
+ lm.eval()
65
+ else:
66
+ lm = None
67
+
68
+ if args.ngram_model:
69
+ from espnet.nets.scorers.ngram import NgramFullScorer
70
+ from espnet.nets.scorers.ngram import NgramPartScorer
71
+
72
+ if args.ngram_scorer == "full":
73
+ ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
74
+ else:
75
+ ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
76
+ else:
77
+ ngram = None
78
+
79
+ scorers = model.scorers()
80
+ scorers["lm"] = lm
81
+ scorers["ngram"] = ngram
82
+ scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
83
+ weights = dict(
84
+ decoder=1.0 - args.ctc_weight,
85
+ ctc=args.ctc_weight,
86
+ lm=args.lm_weight,
87
+ ngram=args.ngram_weight,
88
+ length_bonus=args.penalty,
89
+ )
90
+ beam_search = BeamSearch(
91
+ beam_size=args.beam_size,
92
+ vocab_size=len(train_args.char_list),
93
+ weights=weights,
94
+ scorers=scorers,
95
+ sos=model.sos,
96
+ eos=model.eos,
97
+ token_list=train_args.char_list,
98
+ pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
99
+ )
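Roughly, these weights combine the scorers linearly during the search; for a partial hypothesis y the accumulated score behaves like the sketch below (a simplification, not the exact implementation):

# score(y) = (1 - ctc_weight) * log p_dec(y | x) + ctc_weight * log p_ctc(y | x)
#            + lm_weight * log p_lm(y) + ngram_weight * log p_ngram(y)
#            + penalty * len(y)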
100
+ # TODO(karita): make all scorers batchfied
101
+ if args.batchsize == 1:
102
+ non_batch = [
103
+ k
104
+ for k, v in beam_search.full_scorers.items()
105
+ if not isinstance(v, BatchScorerInterface)
106
+ ]
107
+ if len(non_batch) == 0:
108
+ beam_search.__class__ = BatchBeamSearch
109
+ logging.info("BatchBeamSearch implementation is selected.")
110
+ else:
111
+ logging.warning(
112
+ f"As non-batch scorers {non_batch} are found, "
113
+ f"fall back to non-batch implementation."
114
+ )
115
+
116
+ if args.ngpu > 1:
117
+ raise NotImplementedError("only single GPU decoding is supported")
118
+ if args.ngpu == 1:
119
+ device = "cuda"
120
+ else:
121
+ device = "cpu"
122
+ dtype = getattr(torch, args.dtype)
123
+ logging.info(f"Decoding device={device}, dtype={dtype}")
124
+ model.to(device=device, dtype=dtype).eval()
125
+ beam_search.to(device=device, dtype=dtype).eval()
126
+
127
+ # read json data
128
+ with open(args.recog_json, "rb") as f:
129
+ js = json.load(f)["utts"]
130
+ new_js = {}
131
+ with torch.no_grad():
132
+ for idx, name in enumerate(js.keys(), 1):
133
+ logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
134
+ batch = [(name, js[name])]
135
+ feat = load_inputs_and_targets(batch)[0][0]
136
+ enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
137
+ nbest_hyps = beam_search(
138
+ x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
139
+ )
140
+ nbest_hyps = [
141
+ h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
142
+ ]
143
+ new_js[name] = add_results_to_json(
144
+ js[name], nbest_hyps, train_args.char_list
145
+ )
146
+
147
+ with open(args.result_label, "wb") as f:
148
+ f.write(
149
+ json.dumps(
150
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
151
+ ).encode("utf_8")
152
+ )
espnet/bin/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/bin/asr_align.py ADDED
@@ -0,0 +1,348 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2020 Johns Hopkins University (Xuankai Chang)
5
+ # 2020, Technische Universität München; Dominik Winkelbauer, Ludwig Kürzinger
6
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
7
+
8
+ """
9
+ This program performs CTC segmentation to align utterances within audio files.
10
+
11
+ Inputs:
12
+ `--data-json`:
13
+ A json containing list of utterances and audio files
14
+ `--model`:
15
+ An already trained ASR model
16
+
17
+ Output:
18
+ `--output`:
19
+ A plain `segments` file with utterance positions in the audio files.
20
+
21
+ Selected parameters:
22
+ `--min-window-size`:
23
+ Minimum window size considered for a single utterance. The current default value
24
+ should be OK in most cases. Larger values might give better results; too large
25
+ values cause IndexErrors.
26
+ `--subsampling-factor`:
27
+ If the encoder sub-samples its input, the number of frames at the CTC layer is
28
+ reduced by this factor.
29
+ `--frame-duration`:
30
+ This is the non-overlapping duration of a single frame in milliseconds (the
31
+ inverse of frames per millisecond).
32
+ `--set-blank`:
33
+ In the rare case that the blank token does not have index 0 in the character
34
+ dictionary, this parameter sets the index of the blank token.
35
+ `--gratis-blank`:
36
+ Sets the transition cost for blank tokens to zero. Useful if there are longer
37
+ unrelated segments between segments.
38
+ `--replace-spaces-with-blanks`:
39
+ Spaces are replaced with blanks. Helps to model pauses between words. May
40
+ increase length of ground truth. May lead to misaligned segments when combined
41
+ with the option `--gratis-blank`.
42
+ """
43
+
44
+ import configargparse
45
+ import logging
46
+ import os
47
+ import sys
48
+
49
+ # imports for inference
50
+ from espnet.asr.pytorch_backend.asr_init import load_trained_model
51
+ from espnet.nets.asr_interface import ASRInterface
52
+ from espnet.utils.io_utils import LoadInputsAndTargets
53
+ import json
54
+ import torch
55
+
56
+ # imports for CTC segmentation
57
+ from ctc_segmentation import ctc_segmentation
58
+ from ctc_segmentation import CtcSegmentationParameters
59
+ from ctc_segmentation import determine_utterance_segments
60
+ from ctc_segmentation import prepare_text
61
+
62
+
63
+ # NOTE: you need this func to generate our sphinx doc
64
+ def get_parser():
65
+ """Get default arguments."""
66
+ parser = configargparse.ArgumentParser(
67
+ description="Align text to audio using CTC segmentation."
68
+ "using a pre-trained speech recognition model.",
69
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
70
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
71
+ )
72
+ # general configuration
73
+ parser.add("--config", is_config_file=True, help="Decoding config file path.")
74
+ parser.add_argument(
75
+ "--ngpu", type=int, default=0, help="Number of GPUs (max. 1 is supported)"
76
+ )
77
+ parser.add_argument(
78
+ "--dtype",
79
+ choices=("float16", "float32", "float64"),
80
+ default="float32",
81
+ help="Float precision (only available in --api v2)",
82
+ )
83
+ parser.add_argument(
84
+ "--backend",
85
+ type=str,
86
+ default="pytorch",
87
+ choices=["pytorch"],
88
+ help="Backend library",
89
+ )
90
+ parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
91
+ parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
92
+ parser.add_argument(
93
+ "--preprocess-conf",
94
+ type=str,
95
+ default=None,
96
+ help="The configuration file for the pre-processing",
97
+ )
98
+ # task related
99
+ parser.add_argument(
100
+ "--data-json", type=str, help="Json of recognition data for audio and text"
101
+ )
102
+ parser.add_argument("--utt-text", type=str, help="Text separated into utterances")
103
+ # model (parameter) related
104
+ parser.add_argument(
105
+ "--model", type=str, required=True, help="Model file parameters to read"
106
+ )
107
+ parser.add_argument(
108
+ "--model-conf", type=str, default=None, help="Model config file"
109
+ )
110
+ parser.add_argument(
111
+ "--num-encs", default=1, type=int, help="Number of encoders in the model."
112
+ )
113
+ # ctc-segmentation related
114
+ parser.add_argument(
115
+ "--subsampling-factor",
116
+ type=int,
117
+ default=None,
118
+ help="Subsampling factor."
119
+ " If the encoder sub-samples its input, the number of frames at the CTC layer"
120
+ " is reduced by this factor. For example, a BLSTMP with subsampling 1_2_2_1_1"
121
+ " has a subsampling factor of 4.",
122
+ )
123
+ parser.add_argument(
124
+ "--frame-duration",
125
+ type=int,
126
+ default=None,
127
+ help="Non-overlapping duration of a single frame in milliseconds.",
128
+ )
129
+ parser.add_argument(
130
+ "--min-window-size",
131
+ type=int,
132
+ default=None,
133
+ help="Minimum window size considered for utterance.",
134
+ )
135
+ parser.add_argument(
136
+ "--max-window-size",
137
+ type=int,
138
+ default=None,
139
+ help="Maximum window size considered for utterance.",
140
+ )
141
+ parser.add_argument(
142
+ "--use-dict-blank",
143
+ type=int,
144
+ default=None,
145
+ help="DEPRECATED.",
146
+ )
147
+ parser.add_argument(
148
+ "--set-blank",
149
+ type=int,
150
+ default=None,
151
+ help="Index of model dictionary for blank token (default: 0).",
152
+ )
153
+ parser.add_argument(
154
+ "--gratis-blank",
155
+ type=int,
156
+ default=None,
157
+ help="Set the transition cost of the blank token to zero. Audio sections"
158
+ " labeled with blank tokens can then be skipped without penalty. Useful"
159
+ " if there are unrelated audio segments between utterances.",
160
+ )
161
+ parser.add_argument(
162
+ "--replace-spaces-with-blanks",
163
+ type=int,
164
+ default=None,
165
+ help="Fill blanks in between words to better model pauses between words."
166
+ " Segments can be misaligned if this option is combined with --gratis-blank."
167
+ " May increase length of ground truth.",
168
+ )
169
+ parser.add_argument(
170
+ "--scoring-length",
171
+ type=int,
172
+ default=None,
173
+ help="Changes partitioning length L for calculation of the confidence score.",
174
+ )
175
+ parser.add_argument(
176
+ "--output",
177
+ type=configargparse.FileType("w"),
178
+ required=True,
179
+ help="Output segments file",
180
+ )
181
+ return parser
182
+
183
+
184
+ def main(args):
185
+ """Run the main decoding function."""
186
+ parser = get_parser()
187
+ args, extra = parser.parse_known_args(args)
188
+ # logging info
189
+ if args.verbose == 1:
190
+ logging.basicConfig(
191
+ level=logging.INFO,
192
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
193
+ )
194
+ elif args.verbose == 2:
195
+ logging.basicConfig(
196
+ level=logging.DEBUG,
197
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
198
+ )
199
+ else:
200
+ logging.basicConfig(
201
+ level=logging.WARN,
202
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
203
+ )
204
+ logging.warning("Skip DEBUG/INFO messages")
205
+ if args.ngpu == 0 and args.dtype == "float16":
206
+ raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
207
+ # check CUDA_VISIBLE_DEVICES
208
+ device = "cpu"
209
+ if args.ngpu == 1:
210
+ device = "cuda"
211
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
212
+ if cvd is None:
213
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
214
+ elif args.ngpu > 1:
215
+ logging.error("Decoding only supports ngpu=1.")
216
+ sys.exit(1)
217
+ # display PYTHONPATH
218
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
219
+ # recog
220
+ logging.info("backend = " + args.backend)
221
+ if args.backend == "pytorch":
222
+ ctc_align(args, device)
223
+ else:
224
+ raise ValueError("Only pytorch is supported.")
225
+ sys.exit(0)
226
+
227
+
228
+ def ctc_align(args, device):
229
+ """ESPnet-specific interface for CTC segmentation.
230
+
231
+ Parses configuration, infers the CTC posterior probabilities,
232
+ and then aligns start and end of utterances using CTC segmentation.
233
+ Results are written to the output file given in the args.
234
+
235
+ :param args: given configuration
236
+ :param device: for inference; one of ['cuda', 'cpu']
237
+ :return: 0 on success
238
+ """
239
+ model, train_args = load_trained_model(args.model)
240
+ assert isinstance(model, ASRInterface)
241
+ load_inputs_and_targets = LoadInputsAndTargets(
242
+ mode="asr",
243
+ load_output=True,
244
+ sort_in_input_length=False,
245
+ preprocess_conf=train_args.preprocess_conf
246
+ if args.preprocess_conf is None
247
+ else args.preprocess_conf,
248
+ preprocess_args={"train": False},
249
+ )
250
+ logging.info(f"Decoding device={device}")
251
+ # Warn for nets with high memory consumption on long audio files
252
+ if hasattr(model, "enc"):
253
+ encoder_module = model.enc.__class__.__module__
254
+ elif hasattr(model, "encoder"):
255
+ encoder_module = model.encoder.__class__.__module__
256
+ else:
257
+ encoder_module = "Unknown"
258
+ logging.info(f"Encoder module: {encoder_module}")
259
+ logging.info(f"CTC module: {model.ctc.__class__.__module__}")
260
+ if "rnn" not in encoder_module:
261
+ logging.warning("No BLSTM model detected; memory consumption may be high.")
262
+ model.to(device=device).eval()
263
+ # read audio and text json data
264
+ with open(args.data_json, "rb") as f:
265
+ js = json.load(f)["utts"]
266
+ with open(args.utt_text, "r", encoding="utf-8") as f:
267
+ lines = f.readlines()
268
+ i = 0
269
+ text = {}
270
+ segment_names = {}
271
+ for name in js.keys():
272
+ text_per_audio = []
273
+ segment_names_per_audio = []
274
+ while i < len(lines) and lines[i].startswith(name):
275
+ text_per_audio.append(lines[i][lines[i].find(" ") + 1 :])
276
+ segment_names_per_audio.append(lines[i][: lines[i].find(" ")])
277
+ i += 1
278
+ text[name] = text_per_audio
279
+ segment_names[name] = segment_names_per_audio
280
+ # apply configuration
281
+ config = CtcSegmentationParameters()
282
+ if args.subsampling_factor is not None:
283
+ config.subsampling_factor = args.subsampling_factor
284
+ if args.frame_duration is not None:
285
+ config.frame_duration_ms = args.frame_duration
286
+ if args.min_window_size is not None:
287
+ config.min_window_size = args.min_window_size
288
+ if args.max_window_size is not None:
289
+ config.max_window_size = args.max_window_size
290
+ config.char_list = train_args.char_list
291
+ if args.use_dict_blank is not None:
292
+ logging.warning(
293
+ "The option --use-dict-blank is deprecated. If needed,"
294
+ " use --set-blank instead."
295
+ )
296
+ if args.set_blank is not None:
297
+ config.blank = args.set_blank
298
+ if args.replace_spaces_with_blanks is not None:
299
+ if args.replace_spaces_with_blanks:
300
+ config.replace_spaces_with_blanks = True
301
+ else:
302
+ config.replace_spaces_with_blanks = False
303
+ if args.gratis_blank:
304
+ config.blank_transition_cost_zero = True
305
+ if config.blank_transition_cost_zero and args.replace_spaces_with_blanks:
306
+ logging.error(
307
+ "Blanks are inserted between words, and also the transition cost of blank"
308
+ " is zero. This configuration may lead to misalignments!"
309
+ )
310
+ if args.scoring_length is not None:
311
+ config.score_min_mean_over_L = args.scoring_length
312
+ logging.info(
313
+ f"Frame timings: {config.frame_duration_ms}ms * {config.subsampling_factor}"
314
+ )
315
+ # Iterate over audio files to decode and align
316
+ for idx, name in enumerate(js.keys(), 1):
317
+ logging.info("(%d/%d) Aligning " + name, idx, len(js.keys()))
318
+ batch = [(name, js[name])]
319
+ feat, label = load_inputs_and_targets(batch)
320
+ feat = feat[0]
321
+ with torch.no_grad():
322
+ # Encode input frames
323
+ enc_output = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0)
324
+ # Apply ctc layer to obtain log character probabilities
325
+ lpz = model.ctc.log_softmax(enc_output)[0].cpu().numpy()
326
+ # Prepare the text for aligning
327
+ ground_truth_mat, utt_begin_indices = prepare_text(config, text[name])
328
+ # Align using CTC segmentation
329
+ timings, char_probs, state_list = ctc_segmentation(
330
+ config, lpz, ground_truth_mat
331
+ )
332
+ logging.debug(f"state_list = {state_list}")
333
+ # Obtain list of utterances with time intervals and confidence score
334
+ segments = determine_utterance_segments(
335
+ config, utt_begin_indices, char_probs, timings, text[name]
336
+ )
337
+ # Write to "segments" file
338
+ for i, boundary in enumerate(segments):
339
+ utt_segment = (
340
+ f"{segment_names[name][i]} {name} {boundary[0]:.2f}"
341
+ f" {boundary[1]:.2f} {boundary[2]:.9f}\n"
342
+ )
343
+ args.output.write(utt_segment)
344
+ return 0
345
+
346
+
347
+ if __name__ == "__main__":
348
+ main(sys.argv[1:])
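A note on the output of asr_align.py: each line that ctc_align() writes to the `segments` file follows the f-string above, i.e. "<utterance-name> <audio-name> <start-sec> <end-sec> <confidence>". A small, hypothetical reader for that format (not part of the commit):

    # Parse a `segments` file produced by ctc_align(); "segments" is a hypothetical path.
    def read_segments(path="segments"):
        entries = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                utt, audio, start, end, score = line.split()
                entries.append((utt, audio, float(start), float(end), float(score)))
        return entries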
espnet/bin/asr_enhance.py ADDED
@@ -0,0 +1,191 @@
1
+ #!/usr/bin/env python3
2
+ import configargparse
3
+ from distutils.util import strtobool
4
+ import logging
5
+ import os
6
+ import random
7
+ import sys
8
+
9
+ import numpy as np
10
+
11
+ from espnet.asr.pytorch_backend.asr import enhance
12
+
13
+
14
+ # NOTE: you need this func to generate our sphinx doc
15
+ def get_parser():
16
+ parser = configargparse.ArgumentParser(
17
+ description="Enhance noisy speech for speech recognition",
18
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
19
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
20
+ )
21
+ # general configuration
22
+ parser.add("--config", is_config_file=True, help="config file path")
23
+ parser.add(
24
+ "--config2",
25
+ is_config_file=True,
26
+ help="second config file path that overwrites the settings in `--config`.",
27
+ )
28
+ parser.add(
29
+ "--config3",
30
+ is_config_file=True,
31
+ help="third config file path that overwrites the settings "
32
+ "in `--config` and `--config2`.",
33
+ )
34
+
35
+ parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
36
+ parser.add_argument(
37
+ "--backend",
38
+ default="chainer",
39
+ type=str,
40
+ choices=["chainer", "pytorch"],
41
+ help="Backend library",
42
+ )
43
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
44
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
45
+ parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
46
+ parser.add_argument(
47
+ "--batchsize",
48
+ default=1,
49
+ type=int,
50
+ help="Batch size for beam search (0: means no batch processing)",
51
+ )
52
+ parser.add_argument(
53
+ "--preprocess-conf",
54
+ type=str,
55
+ default=None,
56
+ help="The configuration file for the pre-processing",
57
+ )
58
+ # task related
59
+ parser.add_argument(
60
+ "--recog-json", type=str, help="Filename of recognition data (json)"
61
+ )
62
+ # model (parameter) related
63
+ parser.add_argument(
64
+ "--model", type=str, required=True, help="Model file parameters to read"
65
+ )
66
+ parser.add_argument(
67
+ "--model-conf", type=str, default=None, help="Model config file"
68
+ )
69
+
70
+ # Outputs configuration
71
+ parser.add_argument(
72
+ "--enh-wspecifier",
73
+ type=str,
74
+ default=None,
75
+ help="Specify the output way for enhanced speech."
76
+ "e.g. ark,scp:outdir,wav.scp",
77
+ )
78
+ parser.add_argument(
79
+ "--enh-filetype",
80
+ type=str,
81
+ default="sound",
82
+ choices=["mat", "hdf5", "sound.hdf5", "sound"],
83
+ help="Specify the file format for enhanced speech. "
84
+ '"mat" is the matrix format in kaldi',
85
+ )
86
+ parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
87
+ parser.add_argument(
88
+ "--keep-length",
89
+ type=strtobool,
90
+ default=True,
91
+ help="Adjust the output length to match " "with the input for enhanced speech",
92
+ )
93
+ parser.add_argument(
94
+ "--image-dir", type=str, default=None, help="The directory saving the images."
95
+ )
96
+ parser.add_argument(
97
+ "--num-images",
98
+ type=int,
99
+ default=20,
100
+ help="The number of images files to be saved. "
101
+ "If negative, all samples are to be saved.",
102
+ )
103
+
104
+ # IStft
105
+ parser.add_argument(
106
+ "--apply-istft",
107
+ type=strtobool,
108
+ default=True,
109
+ help="Apply istft to the output from the network",
110
+ )
111
+ parser.add_argument(
112
+ "--istft-win-length",
113
+ type=int,
114
+ default=512,
115
+ help="The window length for istft. "
116
+ "This option is ignored "
117
+ "if stft is found in the preprocess-conf",
118
+ )
119
+ parser.add_argument(
120
+ "--istft-n-shift",
121
+ type=int,
122
+ default=256,
123
+ help="The window type for istft. "
124
+ "This option is ignored "
125
+ "if stft is found in the preprocess-conf",
126
+ )
127
+ parser.add_argument(
128
+ "--istft-window",
129
+ type=str,
130
+ default="hann",
131
+ help="The window type for istft. "
132
+ "This option is ignored "
133
+ "if stft is found in the preprocess-conf",
134
+ )
135
+ return parser
136
+
137
+
138
+ def main(args):
139
+ parser = get_parser()
140
+ args = parser.parse_args(args)
141
+
142
+ # logging info
143
+ if args.verbose == 1:
144
+ logging.basicConfig(
145
+ level=logging.INFO,
146
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
147
+ )
148
+ elif args.verbose == 2:
149
+ logging.basicConfig(
150
+ level=logging.DEBUG,
151
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
152
+ )
153
+ else:
154
+ logging.basicConfig(
155
+ level=logging.WARN,
156
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
157
+ )
158
+ logging.warning("Skip DEBUG/INFO messages")
159
+
160
+ # check CUDA_VISIBLE_DEVICES
161
+ if args.ngpu > 0:
162
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
163
+ if cvd is None:
164
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
165
+ elif args.ngpu != len(cvd.split(",")):
166
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
167
+ sys.exit(1)
168
+
169
+ # TODO(kamo): support of multiple GPUs
170
+ if args.ngpu > 1:
171
+ logging.error("The program only supports ngpu=1.")
172
+ sys.exit(1)
173
+
174
+ # display PYTHONPATH
175
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
176
+
177
+ # seed setting
178
+ random.seed(args.seed)
179
+ np.random.seed(args.seed)
180
+ logging.info("set random seed = %d" % args.seed)
181
+
182
+ # recog
183
+ logging.info("backend = " + args.backend)
184
+ if args.backend == "pytorch":
185
+ enhance(args)
186
+ else:
187
+ raise ValueError("Only pytorch is supported.")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main(sys.argv[1:])
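Like the other scripts under espnet/bin, asr_enhance.py exposes get_parser() separately from main(), so the argument set can be inspected or driven programmatically. A hedged usage sketch (file names are hypothetical, not from this commit):

    from espnet.bin.asr_enhance import get_parser

    # Build the parser and fill in the two task-specific options; everything else keeps
    # the defaults defined above (fs=16000, enh-filetype="sound", istft-window="hann").
    args = get_parser().parse_args(
        ["--model", "model.acc.best", "--recog-json", "data.json"]
    )
    print(args.fs, args.enh_filetype, args.istft_window)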
espnet/bin/asr_recog.py ADDED
@@ -0,0 +1,363 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """End-to-end speech recognition model decoding script."""
8
+
9
+ import configargparse
10
+ import logging
11
+ import os
12
+ import random
13
+ import sys
14
+
15
+ import numpy as np
16
+
17
+ from espnet.utils.cli_utils import strtobool
18
+
19
+ # NOTE: you need this func to generate our sphinx doc
20
+
21
+
22
+ def get_parser():
23
+ """Get default arguments."""
24
+ parser = configargparse.ArgumentParser(
25
+ description="Transcribe text from speech using "
26
+ "a speech recognition model on one CPU or GPU",
27
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
28
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
29
+ )
30
+ # general configuration
31
+ parser.add("--config", is_config_file=True, help="Config file path")
32
+ parser.add(
33
+ "--config2",
34
+ is_config_file=True,
35
+ help="Second config file path that overwrites the settings in `--config`",
36
+ )
37
+ parser.add(
38
+ "--config3",
39
+ is_config_file=True,
40
+ help="Third config file path that overwrites the settings "
41
+ "in `--config` and `--config2`",
42
+ )
43
+
44
+ parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
45
+ parser.add_argument(
46
+ "--dtype",
47
+ choices=("float16", "float32", "float64"),
48
+ default="float32",
49
+ help="Float precision (only available in --api v2)",
50
+ )
51
+ parser.add_argument(
52
+ "--backend",
53
+ type=str,
54
+ default="chainer",
55
+ choices=["chainer", "pytorch"],
56
+ help="Backend library",
57
+ )
58
+ parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
59
+ parser.add_argument("--seed", type=int, default=1, help="Random seed")
60
+ parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
61
+ parser.add_argument(
62
+ "--batchsize",
63
+ type=int,
64
+ default=1,
65
+ help="Batch size for beam search (0: means no batch processing)",
66
+ )
67
+ parser.add_argument(
68
+ "--preprocess-conf",
69
+ type=str,
70
+ default=None,
71
+ help="The configuration file for the pre-processing",
72
+ )
73
+ parser.add_argument(
74
+ "--api",
75
+ default="v1",
76
+ choices=["v1", "v2"],
77
+ help="Beam search APIs "
78
+ "v1: Default API. It only supports the ASRInterface.recognize method "
79
+ "and DefaultRNNLM. "
80
+ "v2: Experimental API. It supports any models that implements ScorerInterface.",
81
+ )
82
+ # task related
83
+ parser.add_argument(
84
+ "--recog-json", type=str, help="Filename of recognition data (json)"
85
+ )
86
+ parser.add_argument(
87
+ "--result-label",
88
+ type=str,
89
+ required=True,
90
+ help="Filename of result label data (json)",
91
+ )
92
+ # model (parameter) related
93
+ parser.add_argument(
94
+ "--model", type=str, required=True, help="Model file parameters to read"
95
+ )
96
+ parser.add_argument(
97
+ "--model-conf", type=str, default=None, help="Model config file"
98
+ )
99
+ parser.add_argument(
100
+ "--num-spkrs",
101
+ type=int,
102
+ default=1,
103
+ choices=[1, 2],
104
+ help="Number of speakers in the speech",
105
+ )
106
+ parser.add_argument(
107
+ "--num-encs", default=1, type=int, help="Number of encoders in the model."
108
+ )
109
+ # search related
110
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
111
+ parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
112
+ parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
113
+ parser.add_argument(
114
+ "--maxlenratio",
115
+ type=float,
116
+ default=0.0,
117
+ help="""Input length ratio to obtain max output length.
118
+ If maxlenratio=0.0 (default), it uses an end-detect function
119
+ to automatically find maximum hypothesis lengths""",
120
+ )
121
+ parser.add_argument(
122
+ "--minlenratio",
123
+ type=float,
124
+ default=0.0,
125
+ help="Input length ratio to obtain min output length",
126
+ )
127
+ parser.add_argument(
128
+ "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
129
+ )
130
+ parser.add_argument(
131
+ "--weights-ctc-dec",
132
+ type=float,
133
+ action="append",
134
+ help="ctc weight assigned to each encoder during decoding."
135
+ "[in multi-encoder mode only]",
136
+ )
137
+ parser.add_argument(
138
+ "--ctc-window-margin",
139
+ type=int,
140
+ default=0,
141
+ help="""Use CTC window with margin parameter to accelerate
142
+ CTC/attention decoding especially on GPU. Smaller margin
143
+ makes decoding faster, but may increase search errors.
144
+ If margin=0 (default), this function is disabled""",
145
+ )
146
+ # transducer related
147
+ parser.add_argument(
148
+ "--search-type",
149
+ type=str,
150
+ default="default",
151
+ choices=["default", "nsc", "tsd", "alsd"],
152
+ help="""Type of beam search implementation to use during inference.
153
+ Can be either: default beam search, n-step constrained beam search ("nsc"),
154
+ time-synchronous decoding ("tsd") or alignment-length synchronous decoding
155
+ ("alsd").
156
+ Additional associated parameters: "nstep" + "prefix-alpha" (for nsc),
157
+ "max-sym-exp" (for tsd) and "u-max" (for alsd)""",
158
+ )
159
+ parser.add_argument(
160
+ "--nstep",
161
+ type=int,
162
+ default=1,
163
+ help="Number of expansion steps allowed in NSC beam search.",
164
+ )
165
+ parser.add_argument(
166
+ "--prefix-alpha",
167
+ type=int,
168
+ default=2,
169
+ help="Length prefix difference allowed in NSC beam search.",
170
+ )
171
+ parser.add_argument(
172
+ "--max-sym-exp",
173
+ type=int,
174
+ default=2,
175
+ help="Number of symbol expansions allowed in TSD decoding.",
176
+ )
177
+ parser.add_argument(
178
+ "--u-max",
179
+ type=int,
180
+ default=400,
181
+ help="Length prefix difference allowed in ALSD beam search.",
182
+ )
183
+ parser.add_argument(
184
+ "--score-norm",
185
+ type=strtobool,
186
+ nargs="?",
187
+ default=True,
188
+ help="Normalize transducer scores by length",
189
+ )
190
+ # rnnlm related
191
+ parser.add_argument(
192
+ "--rnnlm", type=str, default=None, help="RNNLM model file to read"
193
+ )
194
+ parser.add_argument(
195
+ "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
196
+ )
197
+ parser.add_argument(
198
+ "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
199
+ )
200
+ parser.add_argument(
201
+ "--word-rnnlm-conf",
202
+ type=str,
203
+ default=None,
204
+ help="Word RNNLM model config file to read",
205
+ )
206
+ parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
207
+ parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
208
+ # ngram related
209
+ parser.add_argument(
210
+ "--ngram-model", type=str, default=None, help="ngram model file to read"
211
+ )
212
+ parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
213
+ parser.add_argument(
214
+ "--ngram-scorer",
215
+ type=str,
216
+ default="part",
217
+ choices=("full", "part"),
218
+ help="""if the ngram is set as a part scorer, similar with CTC scorer,
219
+ the ngram scorer only scores the top-K hypotheses.
220
+ If the ngram is set as a full scorer, it scores all hypotheses.
221
+ The decoding speed of the part scorer is much faster than the full one.""",
222
+ )
223
+ # streaming related
224
+ parser.add_argument(
225
+ "--streaming-mode",
226
+ type=str,
227
+ default=None,
228
+ choices=["window", "segment"],
229
+ help="""Use streaming recognizer for inference.
230
+ `--batchsize` must be set to 0 to enable this mode""",
231
+ )
232
+ parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
233
+ parser.add_argument(
234
+ "--streaming-min-blank-dur",
235
+ type=int,
236
+ default=10,
237
+ help="Minimum blank duration threshold",
238
+ )
239
+ parser.add_argument(
240
+ "--streaming-onset-margin", type=int, default=1, help="Onset margin"
241
+ )
242
+ parser.add_argument(
243
+ "--streaming-offset-margin", type=int, default=1, help="Offset margin"
244
+ )
245
+ # non-autoregressive related
246
+ # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
247
+ parser.add_argument(
248
+ "--maskctc-n-iterations",
249
+ type=int,
250
+ default=10,
251
+ help="Number of decoding iterations."
252
+ "For Mask CTC, set 0 to predict 1 mask/iter.",
253
+ )
254
+ parser.add_argument(
255
+ "--maskctc-probability-threshold",
256
+ type=float,
257
+ default=0.999,
258
+ help="Threshold probability for CTC output",
259
+ )
260
+
261
+ return parser
262
+
263
+
264
+ def main(args):
265
+ """Run the main decoding function."""
266
+ parser = get_parser()
267
+ args = parser.parse_args(args)
268
+
269
+ if args.ngpu == 0 and args.dtype == "float16":
270
+ raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
271
+
272
+ # logging info
273
+ if args.verbose == 1:
274
+ logging.basicConfig(
275
+ level=logging.INFO,
276
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
277
+ )
278
+ elif args.verbose == 2:
279
+ logging.basicConfig(
280
+ level=logging.DEBUG,
281
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
282
+ )
283
+ else:
284
+ logging.basicConfig(
285
+ level=logging.WARN,
286
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
287
+ )
288
+ logging.warning("Skip DEBUG/INFO messages")
289
+
290
+ # check CUDA_VISIBLE_DEVICES
291
+ if args.ngpu > 0:
292
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
293
+ if cvd is None:
294
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
295
+ elif args.ngpu != len(cvd.split(",")):
296
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
297
+ sys.exit(1)
298
+
299
+ # TODO(mn5k): support of multiple GPUs
300
+ if args.ngpu > 1:
301
+ logging.error("The program only supports ngpu=1.")
302
+ sys.exit(1)
303
+
304
+ # display PYTHONPATH
305
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
306
+
307
+ # seed setting
308
+ random.seed(args.seed)
309
+ np.random.seed(args.seed)
310
+ logging.info("set random seed = %d" % args.seed)
311
+
312
+ # validate rnn options
313
+ if args.rnnlm is not None and args.word_rnnlm is not None:
314
+ logging.error(
315
+ "It seems that both --rnnlm and --word-rnnlm are specified. "
316
+ "Please use either option."
317
+ )
318
+ sys.exit(1)
319
+
320
+ # recog
321
+ logging.info("backend = " + args.backend)
322
+ if args.num_spkrs == 1:
323
+ if args.backend == "chainer":
324
+ from espnet.asr.chainer_backend.asr import recog
325
+
326
+ recog(args)
327
+ elif args.backend == "pytorch":
328
+ if args.num_encs == 1:
329
+ # Experimental API that supports custom LMs
330
+ if args.api == "v2":
331
+ from espnet.asr.pytorch_backend.recog import recog_v2
332
+
333
+ recog_v2(args)
334
+ else:
335
+ from espnet.asr.pytorch_backend.asr import recog
336
+
337
+ if args.dtype != "float32":
338
+ raise NotImplementedError(
339
+ f"`--dtype {args.dtype}` is only available with `--api v2`"
340
+ )
341
+ recog(args)
342
+ else:
343
+ if args.api == "v2":
344
+ raise NotImplementedError(
345
+ f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
346
+ )
347
+ else:
348
+ from espnet.asr.pytorch_backend.asr import recog
349
+
350
+ recog(args)
351
+ else:
352
+ raise ValueError("Only chainer and pytorch are supported.")
353
+ elif args.num_spkrs == 2:
354
+ if args.backend == "pytorch":
355
+ from espnet.asr.pytorch_backend.asr_mix import recog
356
+
357
+ recog(args)
358
+ else:
359
+ raise ValueError("Only pytorch is supported.")
360
+
361
+
362
+ if __name__ == "__main__":
363
+ main(sys.argv[1:])
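The main() dispatch above is the piece that decides between the classic recognizer and the recog_v2() path shown at the top of this diff: --api v2 with the pytorch backend and a single encoder routes to the ScorerInterface-based beam search. A hedged invocation sketch (paths are hypothetical):

    from espnet.bin.asr_recog import get_parser

    args = get_parser().parse_args(
        [
            "--model", "model.acc.best",      # trained ASR model to load
            "--recog-json", "data.json",      # utterances to decode
            "--result-label", "result.json",  # decoded hypotheses are written here
            "--backend", "pytorch",
            "--api", "v2",                    # experimental ScorerInterface API
            "--beam-size", "10",
            "--ctc-weight", "0.3",
        ]
    )
    print(args.api, args.beam_size, args.ctc_weight)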
espnet/bin/asr_train.py ADDED
@@ -0,0 +1,644 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2017 Tomoki Hayashi (Nagoya University)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Automatic speech recognition model training script."""
8
+
9
+ import logging
10
+ import os
11
+ import random
12
+ import subprocess
13
+ import sys
14
+
15
+ from distutils.version import LooseVersion
16
+
17
+ import configargparse
18
+ import numpy as np
19
+ import torch
20
+
21
+ from espnet import __version__
22
+ from espnet.utils.cli_utils import strtobool
23
+ from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
24
+
25
+ is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
26
+
27
+
28
+ # NOTE: you need this func to generate our sphinx doc
29
+ def get_parser(parser=None, required=True):
30
+ """Get default arguments."""
31
+ if parser is None:
32
+ parser = configargparse.ArgumentParser(
33
+ description="Train an automatic speech recognition (ASR) model on one CPU, "
34
+ "one or multiple GPUs",
35
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
36
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
37
+ )
38
+ # general configuration
39
+ parser.add("--config", is_config_file=True, help="config file path")
40
+ parser.add(
41
+ "--config2",
42
+ is_config_file=True,
43
+ help="second config file path that overwrites the settings in `--config`.",
44
+ )
45
+ parser.add(
46
+ "--config3",
47
+ is_config_file=True,
48
+ help="third config file path that overwrites the settings in "
49
+ "`--config` and `--config2`.",
50
+ )
51
+
52
+ parser.add_argument(
53
+ "--ngpu",
54
+ default=None,
55
+ type=int,
56
+ help="Number of GPUs. If not given, use all visible devices",
57
+ )
58
+ parser.add_argument(
59
+ "--train-dtype",
60
+ default="float32",
61
+ choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
62
+ help="Data type for training (only pytorch backend). "
63
+ "O0,O1,.. flags require apex. "
64
+ "See https://nvidia.github.io/apex/amp.html#opt-levels",
65
+ )
66
+ parser.add_argument(
67
+ "--backend",
68
+ default="chainer",
69
+ type=str,
70
+ choices=["chainer", "pytorch"],
71
+ help="Backend library",
72
+ )
73
+ parser.add_argument(
74
+ "--outdir", type=str, required=required, help="Output directory"
75
+ )
76
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
77
+ parser.add_argument("--dict", required=required, help="Dictionary")
78
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
79
+ parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
80
+ parser.add_argument(
81
+ "--resume",
82
+ "-r",
83
+ default="",
84
+ nargs="?",
85
+ help="Resume the training from snapshot",
86
+ )
87
+ parser.add_argument(
88
+ "--minibatches",
89
+ "-N",
90
+ type=int,
91
+ default="-1",
92
+ help="Process only N minibatches (for debug)",
93
+ )
94
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
95
+ parser.add_argument(
96
+ "--tensorboard-dir",
97
+ default=None,
98
+ type=str,
99
+ nargs="?",
100
+ help="Tensorboard log dir path",
101
+ )
102
+ parser.add_argument(
103
+ "--report-interval-iters",
104
+ default=100,
105
+ type=int,
106
+ help="Report interval iterations",
107
+ )
108
+ parser.add_argument(
109
+ "--save-interval-iters",
110
+ default=0,
111
+ type=int,
112
+ help="Save snapshot interval iterations",
113
+ )
114
+ # task related
115
+ parser.add_argument(
116
+ "--train-json",
117
+ type=str,
118
+ default=None,
119
+ help="Filename of train label data (json)",
120
+ )
121
+ parser.add_argument(
122
+ "--valid-json",
123
+ type=str,
124
+ default=None,
125
+ help="Filename of validation label data (json)",
126
+ )
127
+ # network architecture
128
+ parser.add_argument(
129
+ "--model-module",
130
+ type=str,
131
+ default=None,
132
+ help="model defined module (default: espnet.nets.xxx_backend.e2e_asr:E2E)",
133
+ )
134
+ # encoder
135
+ parser.add_argument(
136
+ "--num-encs", default=1, type=int, help="Number of encoders in the model."
137
+ )
138
+ # loss related
139
+ parser.add_argument(
140
+ "--ctc_type",
141
+ default="warpctc",
142
+ type=str,
143
+ choices=["builtin", "warpctc", "gtnctc", "cudnnctc"],
144
+ help="Type of CTC implementation to calculate loss.",
145
+ )
146
+ parser.add_argument(
147
+ "--mtlalpha",
148
+ default=0.5,
149
+ type=float,
150
+ help="Multitask learning coefficient, "
151
+ "alpha: alpha*ctc_loss + (1-alpha)*att_loss ",
152
+ )
153
+ parser.add_argument(
154
+ "--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
155
+ )
156
+ # recognition options to compute CER/WER
157
+ parser.add_argument(
158
+ "--report-cer",
159
+ default=False,
160
+ action="store_true",
161
+ help="Compute CER on development set",
162
+ )
163
+ parser.add_argument(
164
+ "--report-wer",
165
+ default=False,
166
+ action="store_true",
167
+ help="Compute WER on development set",
168
+ )
169
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
170
+ parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
171
+ parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty")
172
+ parser.add_argument(
173
+ "--maxlenratio",
174
+ default=0.0,
175
+ type=float,
176
+ help="""Input length ratio to obtain max output length.
177
+ If maxlenratio=0.0 (default), it uses an end-detect function
178
+ to automatically find maximum hypothesis lengths""",
179
+ )
180
+ parser.add_argument(
181
+ "--minlenratio",
182
+ default=0.0,
183
+ type=float,
184
+ help="Input length ratio to obtain min output length",
185
+ )
186
+ parser.add_argument(
187
+ "--ctc-weight", default=0.3, type=float, help="CTC weight in joint decoding"
188
+ )
189
+ parser.add_argument(
190
+ "--rnnlm", type=str, default=None, help="RNNLM model file to read"
191
+ )
192
+ parser.add_argument(
193
+ "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
194
+ )
195
+ parser.add_argument("--lm-weight", default=0.1, type=float, help="RNNLM weight.")
196
+ parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
197
+ parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
198
+ # minibatch related
199
+ parser.add_argument(
200
+ "--sortagrad",
201
+ default=0,
202
+ type=int,
203
+ nargs="?",
204
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
205
+ )
206
+ parser.add_argument(
207
+ "--batch-count",
208
+ default="auto",
209
+ choices=BATCH_COUNT_CHOICES,
210
+ help="How to count batch_size. "
211
+ "The default (auto) will find how to count by args.",
212
+ )
213
+ parser.add_argument(
214
+ "--batch-size",
215
+ "--batch-seqs",
216
+ "-b",
217
+ default=0,
218
+ type=int,
219
+ help="Maximum seqs in a minibatch (0 to disable)",
220
+ )
221
+ parser.add_argument(
222
+ "--batch-bins",
223
+ default=0,
224
+ type=int,
225
+ help="Maximum bins in a minibatch (0 to disable)",
226
+ )
227
+ parser.add_argument(
228
+ "--batch-frames-in",
229
+ default=0,
230
+ type=int,
231
+ help="Maximum input frames in a minibatch (0 to disable)",
232
+ )
233
+ parser.add_argument(
234
+ "--batch-frames-out",
235
+ default=0,
236
+ type=int,
237
+ help="Maximum output frames in a minibatch (0 to disable)",
238
+ )
239
+ parser.add_argument(
240
+ "--batch-frames-inout",
241
+ default=0,
242
+ type=int,
243
+ help="Maximum input+output frames in a minibatch (0 to disable)",
244
+ )
245
+ parser.add_argument(
246
+ "--maxlen-in",
247
+ "--batch-seq-maxlen-in",
248
+ default=800,
249
+ type=int,
250
+ metavar="ML",
251
+ help="When --batch-count=seq, "
252
+ "batch size is reduced if the input sequence length > ML.",
253
+ )
254
+ parser.add_argument(
255
+ "--maxlen-out",
256
+ "--batch-seq-maxlen-out",
257
+ default=150,
258
+ type=int,
259
+ metavar="ML",
260
+ help="When --batch-count=seq, "
261
+ "batch size is reduced if the output sequence length > ML",
262
+ )
263
+ parser.add_argument(
264
+ "--n-iter-processes",
265
+ default=0,
266
+ type=int,
267
+ help="Number of processes of iterator",
268
+ )
269
+ parser.add_argument(
270
+ "--preprocess-conf",
271
+ type=str,
272
+ default=None,
273
+ nargs="?",
274
+ help="The configuration file for the pre-processing",
275
+ )
276
+ # optimization related
277
+ parser.add_argument(
278
+ "--opt",
279
+ default="adadelta",
280
+ type=str,
281
+ choices=["adadelta", "adam", "noam"],
282
+ help="Optimizer",
283
+ )
284
+ parser.add_argument(
285
+ "--accum-grad", default=1, type=int, help="Number of gradient accumuration"
286
+ )
287
+ parser.add_argument(
288
+ "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
289
+ )
290
+ parser.add_argument(
291
+ "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
292
+ )
293
+ parser.add_argument(
294
+ "--weight-decay", default=0.0, type=float, help="Weight decay ratio"
295
+ )
296
+ parser.add_argument(
297
+ "--criterion",
298
+ default="acc",
299
+ type=str,
300
+ choices=["loss", "loss_eps_decay_only", "acc"],
301
+ help="Criterion to perform epsilon decay",
302
+ )
303
+ parser.add_argument(
304
+ "--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
305
+ )
306
+ parser.add_argument(
307
+ "--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
308
+ )
309
+ parser.add_argument(
310
+ "--early-stop-criterion",
311
+ default="validation/main/acc",
312
+ type=str,
313
+ nargs="?",
314
+ help="Value to monitor to trigger an early stopping of the training",
315
+ )
316
+ parser.add_argument(
317
+ "--patience",
318
+ default=3,
319
+ type=int,
320
+ nargs="?",
321
+ help="Number of epochs to wait without improvement "
322
+ "before stopping the training",
323
+ )
324
+ parser.add_argument(
325
+ "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
326
+ )
327
+ parser.add_argument(
328
+ "--num-save-attention",
329
+ default=3,
330
+ type=int,
331
+ help="Number of samples of attention to be saved",
332
+ )
333
+ parser.add_argument(
334
+ "--num-save-ctc",
335
+ default=3,
336
+ type=int,
337
+ help="Number of samples of CTC probability to be saved",
338
+ )
339
+ parser.add_argument(
340
+ "--grad-noise",
341
+ type=strtobool,
342
+ default=False,
343
+ help="The flag to switch to use noise injection to gradients during training",
344
+ )
345
+ # asr_mix related
346
+ parser.add_argument(
347
+ "--num-spkrs",
348
+ default=1,
349
+ type=int,
350
+ choices=[1, 2],
351
+ help="Number of speakers in the speech.",
352
+ )
353
+ # decoder related
354
+ parser.add_argument(
355
+ "--context-residual",
356
+ default=False,
357
+ type=strtobool,
358
+ nargs="?",
359
+ help="The flag to switch to use context vector residual in the decoder network",
360
+ )
361
+ # finetuning related
362
+ parser.add_argument(
363
+ "--enc-init",
364
+ default=None,
365
+ type=str,
366
+ help="Pre-trained ASR model to initialize encoder.",
367
+ )
368
+ parser.add_argument(
369
+ "--enc-init-mods",
370
+ default="enc.enc.",
371
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
372
+ help="List of encoder modules to initialize, separated by a comma.",
373
+ )
374
+ parser.add_argument(
375
+ "--dec-init",
376
+ default=None,
377
+ type=str,
378
+ help="Pre-trained ASR, MT or LM model to initialize decoder.",
379
+ )
380
+ parser.add_argument(
381
+ "--dec-init-mods",
382
+ default="att.,dec.",
383
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
384
+ help="List of decoder modules to initialize, separated by a comma.",
385
+ )
386
+ parser.add_argument(
387
+ "--freeze-mods",
388
+ default=None,
389
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
390
+ help="List of modules to freeze, separated by a comma.",
391
+ )
392
+ # front end related
393
+ parser.add_argument(
394
+ "--use-frontend",
395
+ type=strtobool,
396
+ default=False,
397
+ help="The flag to switch to use frontend system.",
398
+ )
399
+
400
+ # WPE related
401
+ parser.add_argument(
402
+ "--use-wpe",
403
+ type=strtobool,
404
+ default=False,
405
+ help="Apply Weighted Prediction Error",
406
+ )
407
+ parser.add_argument(
408
+ "--wtype",
409
+ default="blstmp",
410
+ type=str,
411
+ choices=[
412
+ "lstm",
413
+ "blstm",
414
+ "lstmp",
415
+ "blstmp",
416
+ "vgglstmp",
417
+ "vggblstmp",
418
+ "vgglstm",
419
+ "vggblstm",
420
+ "gru",
421
+ "bgru",
422
+ "grup",
423
+ "bgrup",
424
+ "vgggrup",
425
+ "vggbgrup",
426
+ "vgggru",
427
+ "vggbgru",
428
+ ],
429
+ help="Type of encoder network architecture "
430
+ "of the mask estimator for WPE. "
431
+ "",
432
+ )
433
+ parser.add_argument("--wlayers", type=int, default=2, help="")
434
+ parser.add_argument("--wunits", type=int, default=300, help="")
435
+ parser.add_argument("--wprojs", type=int, default=300, help="")
436
+ parser.add_argument("--wdropout-rate", type=float, default=0.0, help="")
437
+ parser.add_argument("--wpe-taps", type=int, default=5, help="")
438
+ parser.add_argument("--wpe-delay", type=int, default=3, help="")
439
+ parser.add_argument(
440
+ "--use-dnn-mask-for-wpe",
441
+ type=strtobool,
442
+ default=False,
443
+ help="Use DNN to estimate the power spectrogram. "
444
+ "This option is experimental.",
445
+ )
446
+ # Beamformer related
447
+ parser.add_argument("--use-beamformer", type=strtobool, default=True, help="")
448
+ parser.add_argument(
449
+ "--btype",
450
+ default="blstmp",
451
+ type=str,
452
+ choices=[
453
+ "lstm",
454
+ "blstm",
455
+ "lstmp",
456
+ "blstmp",
457
+ "vgglstmp",
458
+ "vggblstmp",
459
+ "vgglstm",
460
+ "vggblstm",
461
+ "gru",
462
+ "bgru",
463
+ "grup",
464
+ "bgrup",
465
+ "vgggrup",
466
+ "vggbgrup",
467
+ "vgggru",
468
+ "vggbgru",
469
+ ],
470
+ help="Type of encoder network architecture "
471
+ "of the mask estimator for Beamformer.",
472
+ )
473
+ parser.add_argument("--blayers", type=int, default=2, help="")
474
+ parser.add_argument("--bunits", type=int, default=300, help="")
475
+ parser.add_argument("--bprojs", type=int, default=300, help="")
476
+ parser.add_argument("--badim", type=int, default=320, help="")
477
+ parser.add_argument(
478
+ "--bnmask",
479
+ type=int,
480
+ default=2,
481
+ help="Number of beamforming masks, " "default is 2 for [speech, noise].",
482
+ )
483
+ parser.add_argument(
484
+ "--ref-channel",
485
+ type=int,
486
+ default=-1,
487
+ help="The reference channel used for beamformer. "
488
+ "By default, the channel is estimated by DNN.",
489
+ )
490
+ parser.add_argument("--bdropout-rate", type=float, default=0.0, help="")
491
+ # Feature transform: Normalization
492
+ parser.add_argument(
493
+ "--stats-file",
494
+ type=str,
495
+ default=None,
496
+ help="The stats file for the feature normalization",
497
+ )
498
+ parser.add_argument(
499
+ "--apply-uttmvn",
500
+ type=strtobool,
501
+ default=True,
502
+ help="Apply utterance level mean " "variance normalization.",
503
+ )
504
+ parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="")
505
+ parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="")
506
+ # Feature transform: Fbank
507
+ parser.add_argument(
508
+ "--fbank-fs",
509
+ type=int,
510
+ default=16000,
511
+ help="The sample frequency used for " "the mel-fbank creation.",
512
+ )
513
+ parser.add_argument(
514
+ "--n-mels", type=int, default=80, help="The number of mel-frequency bins."
515
+ )
516
+ parser.add_argument("--fbank-fmin", type=float, default=0.0, help="")
517
+ parser.add_argument("--fbank-fmax", type=float, default=None, help="")
518
+ return parser
519
+
520
+
521
+ def main(cmd_args):
522
+ """Run the main training function."""
523
+ parser = get_parser()
524
+ args, _ = parser.parse_known_args(cmd_args)
525
+ if args.backend == "chainer" and args.train_dtype != "float32":
526
+ raise NotImplementedError(
527
+ f"chainer backend does not support --train-dtype {args.train_dtype}."
528
+ "Use --dtype float32."
529
+ )
530
+ if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
531
+ raise ValueError(
532
+ f"--train-dtype {args.train_dtype} does not support the CPU backend."
533
+ )
534
+
535
+ from espnet.utils.dynamic_import import dynamic_import
536
+
537
+ if args.model_module is None:
538
+ if args.num_spkrs == 1:
539
+ model_module = "espnet.nets." + args.backend + "_backend.e2e_asr:E2E"
540
+ else:
541
+ model_module = "espnet.nets." + args.backend + "_backend.e2e_asr_mix:E2E"
542
+ else:
543
+ model_module = args.model_module
544
+ model_class = dynamic_import(model_module)
545
+ model_class.add_arguments(parser)
546
+
547
+ args = parser.parse_args(cmd_args)
548
+ args.model_module = model_module
549
+ if "chainer_backend" in args.model_module:
550
+ args.backend = "chainer"
551
+ if "pytorch_backend" in args.model_module:
552
+ args.backend = "pytorch"
553
+
554
+ # add version info in args
555
+ args.version = __version__
556
+
557
+ # logging info
558
+ if args.verbose > 0:
559
+ logging.basicConfig(
560
+ level=logging.INFO,
561
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
562
+ )
563
+ else:
564
+ logging.basicConfig(
565
+ level=logging.WARN,
566
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
567
+ )
568
+ logging.warning("Skip DEBUG/INFO messages")
569
+
570
+ # If --ngpu is not given,
571
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
572
+ # 2. if nvidia-smi exists, use all devices
573
+ # 3. else ngpu=0
574
+ if args.ngpu is None:
575
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
576
+ if cvd is not None:
577
+ ngpu = len(cvd.split(","))
578
+ else:
579
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
580
+ try:
581
+ p = subprocess.run(
582
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
583
+ )
584
+ except (subprocess.CalledProcessError, FileNotFoundError):
585
+ ngpu = 0
586
+ else:
587
+ ngpu = len(p.stderr.decode().split("\n")) - 1
588
+ else:
589
+ if is_torch_1_2_plus and args.ngpu != 1:
590
+ logging.debug(
591
+ "There are some bugs with multi-GPU processing in PyTorch 1.2+"
592
+ + " (see https://github.com/pytorch/pytorch/issues/21108)"
593
+ )
594
+ ngpu = args.ngpu
595
+ logging.info(f"ngpu: {ngpu}")
596
+
597
+ # display PYTHONPATH
598
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
599
+
600
+ # set random seed
601
+ logging.info("random seed = %d" % args.seed)
602
+ random.seed(args.seed)
603
+ np.random.seed(args.seed)
604
+
605
+ # load dictionary for debug log
606
+ if args.dict is not None:
607
+ with open(args.dict, "rb") as f:
608
+ dictionary = f.readlines()
609
+ char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
610
+ char_list.insert(0, "<blank>")
611
+ char_list.append("<eos>")
612
+ # for non-autoregressive maskctc model
613
+ if "maskctc" in args.model_module:
614
+ char_list.append("<mask>")
615
+ args.char_list = char_list
616
+ else:
617
+ args.char_list = None
618
+
619
+ # train
620
+ logging.info("backend = " + args.backend)
621
+
622
+ if args.num_spkrs == 1:
623
+ if args.backend == "chainer":
624
+ from espnet.asr.chainer_backend.asr import train
625
+
626
+ train(args)
627
+ elif args.backend == "pytorch":
628
+ from espnet.asr.pytorch_backend.asr import train
629
+
630
+ train(args)
631
+ else:
632
+ raise ValueError("Only chainer and pytorch are supported.")
633
+ else:
634
+ # FIXME(kamo): Support --model-module
635
+ if args.backend == "pytorch":
636
+ from espnet.asr.pytorch_backend.asr_mix import train
637
+
638
+ train(args)
639
+ else:
640
+ raise ValueError("Only pytorch is supported.")
641
+
642
+
643
+ if __name__ == "__main__":
644
+ main(sys.argv[1:])
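One detail of main() worth calling out: the character list used for debug logging and decoding is rebuilt from the --dict file, whose lines have the form "<token> <index>". A minimal sketch of that construction (mirrors the code above; "units.txt" is a hypothetical path):

    # Read the dictionary and rebuild char_list exactly as main() does.
    with open("units.txt", "rb") as f:
        char_list = [entry.decode("utf-8").split(" ")[0] for entry in f.readlines()]
    char_list.insert(0, "<blank>")  # CTC blank token is always index 0
    char_list.append("<eos>")       # end-of-sentence token is appended last
    # Mask-CTC models additionally append "<mask>" (see the maskctc branch above).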
espnet/bin/lm_train.py ADDED
@@ -0,0 +1,288 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ # This code is ported from the following implementation written in Torch.
7
+ # https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
8
+
9
+ """Language model training script."""
10
+
11
+ import logging
12
+ import os
13
+ import random
14
+ import subprocess
15
+ import sys
16
+
17
+ import configargparse
18
+ import numpy as np
19
+
20
+ from espnet import __version__
21
+ from espnet.nets.lm_interface import dynamic_import_lm
22
+ from espnet.optimizer.factory import dynamic_import_optimizer
23
+ from espnet.scheduler.scheduler import dynamic_import_scheduler
24
+
25
+
26
+ # NOTE: you need this func to generate our sphinx doc
27
+ def get_parser(parser=None, required=True):
28
+ """Get parser."""
29
+ if parser is None:
30
+ parser = configargparse.ArgumentParser(
31
+ description="Train a new language model on one CPU or one GPU",
32
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
33
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
34
+ )
35
+ # general configuration
36
+ parser.add("--config", is_config_file=True, help="config file path")
37
+ parser.add(
38
+ "--config2",
39
+ is_config_file=True,
40
+ help="second config file path that overwrites the settings in `--config`.",
41
+ )
42
+ parser.add(
43
+ "--config3",
44
+ is_config_file=True,
45
+ help="third config file path that overwrites the settings "
46
+ "in `--config` and `--config2`.",
47
+ )
48
+
49
+ parser.add_argument(
50
+ "--ngpu",
51
+ default=None,
52
+ type=int,
53
+ help="Number of GPUs. If not given, use all visible devices",
54
+ )
55
+ parser.add_argument(
56
+ "--train-dtype",
57
+ default="float32",
58
+ choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
59
+ help="Data type for training (only pytorch backend). "
60
+ "O0,O1,.. flags require apex. "
61
+ "See https://nvidia.github.io/apex/amp.html#opt-levels",
62
+ )
63
+ parser.add_argument(
64
+ "--backend",
65
+ default="chainer",
66
+ type=str,
67
+ choices=["chainer", "pytorch"],
68
+ help="Backend library",
69
+ )
70
+ parser.add_argument(
71
+ "--outdir", type=str, required=required, help="Output directory"
72
+ )
73
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
74
+ parser.add_argument("--dict", type=str, required=required, help="Dictionary")
75
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
76
+ parser.add_argument(
77
+ "--resume",
78
+ "-r",
79
+ default="",
80
+ nargs="?",
81
+ help="Resume the training from snapshot",
82
+ )
83
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
84
+ parser.add_argument(
85
+ "--tensorboard-dir",
86
+ default=None,
87
+ type=str,
88
+ nargs="?",
89
+ help="Tensorboard log dir path",
90
+ )
91
+ parser.add_argument(
92
+ "--report-interval-iters",
93
+ default=100,
94
+ type=int,
95
+ help="Report interval iterations",
96
+ )
97
+ # task related
98
+ parser.add_argument(
99
+ "--train-label",
100
+ type=str,
101
+ required=required,
102
+ help="Filename of train label data",
103
+ )
104
+ parser.add_argument(
105
+ "--valid-label",
106
+ type=str,
107
+ required=required,
108
+ help="Filename of validation label data",
109
+ )
110
+ parser.add_argument("--test-label", type=str, help="Filename of test label data")
111
+ parser.add_argument(
112
+ "--dump-hdf5-path",
113
+ type=str,
114
+ default=None,
115
+ help="Path to dump a preprocessed dataset as hdf5",
116
+ )
117
+ # training configuration
118
+ parser.add_argument("--opt", default="sgd", type=str, help="Optimizer")
119
+ parser.add_argument(
120
+ "--sortagrad",
121
+ default=0,
122
+ type=int,
123
+ nargs="?",
124
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
125
+ )
126
+ parser.add_argument(
127
+ "--batchsize",
128
+ "-b",
129
+ type=int,
130
+ default=300,
131
+ help="Number of examples in each mini-batch",
132
+ )
133
+ parser.add_argument(
134
+ "--accum-grad", type=int, default=1, help="Number of gradient accumulation"
135
+ )
136
+ parser.add_argument(
137
+ "--epoch",
138
+ "-e",
139
+ type=int,
140
+ default=20,
141
+ help="Number of sweeps over the dataset to train",
142
+ )
143
+ parser.add_argument(
144
+ "--early-stop-criterion",
145
+ default="validation/main/loss",
146
+ type=str,
147
+ nargs="?",
148
+ help="Value to monitor to trigger an early stopping of the training",
149
+ )
150
+ parser.add_argument(
151
+ "--patience",
152
+ default=3,
153
+ type=int,
154
+ nargs="?",
155
+ help="Number of epochs "
156
+ "to wait without improvement before stopping the training",
157
+ )
158
+ parser.add_argument(
159
+ "--schedulers",
160
+ default=None,
161
+ action="append",
162
+ type=lambda kv: kv.split("="),
163
+ help="optimizer schedulers, you can configure params like:"
164
+ " <optimizer-param>-<scheduler-name>-<scheduler-param>"
165
+ ' e.g., "--schedulers lr=noam --lr-noam-warmup 1000".',
166
+ )
167
+ parser.add_argument(
168
+ "--gradclip",
169
+ "-c",
170
+ type=float,
171
+ default=5,
172
+ help="Gradient norm threshold to clip",
173
+ )
174
+ parser.add_argument(
175
+ "--maxlen",
176
+ type=int,
177
+ default=40,
178
+ help="Batch size is reduced if the input sequence length > ML",
179
+ )
180
+ parser.add_argument(
181
+ "--model-module",
182
+ type=str,
183
+ default="default",
184
+ help="model defined module "
185
+ "(default: espnet.nets.xxx_backend.lm.default:DefaultRNNLM)",
186
+ )
187
+ return parser
188
+
189
+
190
+ def main(cmd_args):
191
+ """Train LM."""
192
+ parser = get_parser()
193
+ args, _ = parser.parse_known_args(cmd_args)
194
+ if args.backend == "chainer" and args.train_dtype != "float32":
195
+ raise NotImplementedError(
196
+ f"chainer backend does not support --train-dtype {args.train_dtype}."
197
+ "Use --dtype float32."
198
+ )
199
+ if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
200
+ raise ValueError(
201
+ f"--train-dtype {args.train_dtype} does not support the CPU backend."
202
+ )
203
+
204
+ # parse arguments dynamically
205
+ model_class = dynamic_import_lm(args.model_module, args.backend)
206
+ model_class.add_arguments(parser)
207
+ if args.schedulers is not None:
208
+ for k, v in args.schedulers:
209
+ scheduler_class = dynamic_import_scheduler(v)
210
+ scheduler_class.add_arguments(k, parser)
211
+
212
+ opt_class = dynamic_import_optimizer(args.opt, args.backend)
213
+ opt_class.add_arguments(parser)
214
+
215
+ args = parser.parse_args(cmd_args)
216
+
217
+ # add version info in args
218
+ args.version = __version__
219
+
220
+ # logging info
221
+ if args.verbose > 0:
222
+ logging.basicConfig(
223
+ level=logging.INFO,
224
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
225
+ )
226
+ else:
227
+ logging.basicConfig(
228
+ level=logging.WARN,
229
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
230
+ )
231
+ logging.warning("Skip DEBUG/INFO messages")
232
+
233
+ # If --ngpu is not given,
234
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
235
+ # 2. if nvidia-smi exists, use all devices
236
+ # 3. else ngpu=0
237
+ if args.ngpu is None:
238
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
239
+ if cvd is not None:
240
+ ngpu = len(cvd.split(","))
241
+ else:
242
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
243
+ try:
244
+ p = subprocess.run(
245
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
246
+ )
247
+ except (subprocess.CalledProcessError, FileNotFoundError):
248
+ ngpu = 0
249
+ else:
250
+ ngpu = len(p.stderr.decode().split("\n")) - 1
251
+ args.ngpu = ngpu
252
+ else:
253
+ ngpu = args.ngpu
254
+ logging.info(f"ngpu: {ngpu}")
255
+
256
+ # display PYTHONPATH
257
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
258
+
259
+ # seed setting
260
+ nseed = args.seed
261
+ random.seed(nseed)
262
+ np.random.seed(nseed)
263
+
264
+ # load dictionary
265
+ with open(args.dict, "rb") as f:
266
+ dictionary = f.readlines()
267
+ char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
268
+ char_list.insert(0, "<blank>")
269
+ char_list.append("<eos>")
270
+ args.char_list_dict = {x: i for i, x in enumerate(char_list)}
271
+ args.n_vocab = len(char_list)
272
+
273
+ # train
274
+ logging.info("backend = " + args.backend)
275
+ if args.backend == "chainer":
276
+ from espnet.lm.chainer_backend.lm import train
277
+
278
+ train(args)
279
+ elif args.backend == "pytorch":
280
+ from espnet.lm.pytorch_backend.lm import train
281
+
282
+ train(args)
283
+ else:
284
+ raise ValueError("Only chainer and pytorch are supported.")
285
+
286
+
287
+ if __name__ == "__main__":
288
+ main(sys.argv[1:])
espnet/bin/mt_train.py ADDED
@@ -0,0 +1,480 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Neural machine translation model training script."""
8
+
9
+ import logging
10
+ import os
11
+ import random
12
+ import subprocess
13
+ import sys
14
+
15
+ from distutils.version import LooseVersion
16
+
17
+ import configargparse
18
+ import numpy as np
19
+ import torch
20
+
21
+ from espnet import __version__
22
+ from espnet.utils.cli_utils import strtobool
23
+ from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
24
+
25
+ is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
26
+
27
+
28
+ # NOTE: you need this func to generate our sphinx doc
29
+ def get_parser(parser=None, required=True):
30
+ """Get default arguments."""
31
+ if parser is None:
32
+ parser = configargparse.ArgumentParser(
33
+ description="Train a neural machine translation (NMT) model on one CPU, "
34
+ "one or multiple GPUs",
35
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
36
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
37
+ )
38
+ # general configuration
39
+ parser.add("--config", is_config_file=True, help="config file path")
40
+ parser.add(
41
+ "--config2",
42
+ is_config_file=True,
43
+ help="second config file path that overwrites the settings in `--config`.",
44
+ )
45
+ parser.add(
46
+ "--config3",
47
+ is_config_file=True,
48
+ help="third config file path that overwrites the settings "
49
+ "in `--config` and `--config2`.",
50
+ )
51
+
52
+ parser.add_argument(
53
+ "--ngpu",
54
+ default=None,
55
+ type=int,
56
+ help="Number of GPUs. If not given, use all visible devices",
57
+ )
58
+ parser.add_argument(
59
+ "--train-dtype",
60
+ default="float32",
61
+ choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
62
+ help="Data type for training (only pytorch backend). "
63
+ "O0,O1,.. flags require apex. "
64
+ "See https://nvidia.github.io/apex/amp.html#opt-levels",
65
+ )
66
+ parser.add_argument(
67
+ "--backend",
68
+ default="chainer",
69
+ type=str,
70
+ choices=["chainer", "pytorch"],
71
+ help="Backend library",
72
+ )
73
+ parser.add_argument(
74
+ "--outdir", type=str, required=required, help="Output directory"
75
+ )
76
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
77
+ parser.add_argument(
78
+ "--dict", required=required, help="Dictionary for source/target languages"
79
+ )
80
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
81
+ parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
82
+ parser.add_argument(
83
+ "--resume",
84
+ "-r",
85
+ default="",
86
+ nargs="?",
87
+ help="Resume the training from snapshot",
88
+ )
89
+ parser.add_argument(
90
+ "--minibatches",
91
+ "-N",
92
+ type=int,
93
+ default="-1",
94
+ help="Process only N minibatches (for debug)",
95
+ )
96
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
97
+ parser.add_argument(
98
+ "--tensorboard-dir",
99
+ default=None,
100
+ type=str,
101
+ nargs="?",
102
+ help="Tensorboard log dir path",
103
+ )
104
+ parser.add_argument(
105
+ "--report-interval-iters",
106
+ default=100,
107
+ type=int,
108
+ help="Report interval iterations",
109
+ )
110
+ parser.add_argument(
111
+ "--save-interval-iters",
112
+ default=0,
113
+ type=int,
114
+ help="Save snapshot interval iterations",
115
+ )
116
+ # task related
117
+ parser.add_argument(
118
+ "--train-json",
119
+ type=str,
120
+ default=None,
121
+ help="Filename of train label data (json)",
122
+ )
123
+ parser.add_argument(
124
+ "--valid-json",
125
+ type=str,
126
+ default=None,
127
+ help="Filename of validation label data (json)",
128
+ )
129
+ # network architecture
130
+ parser.add_argument(
131
+ "--model-module",
132
+ type=str,
133
+ default=None,
134
+ help="model defined module (default: espnet.nets.xxx_backend.e2e_mt:E2E)",
135
+ )
136
+ # loss related
137
+ parser.add_argument(
138
+ "--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
139
+ )
140
+ # translations options to compute BLEU
141
+ parser.add_argument(
142
+ "--report-bleu",
143
+ default=True,
144
+ action="store_true",
145
+ help="Compute BLEU on development set",
146
+ )
147
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
148
+ parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
149
+ parser.add_argument("--penalty", default=0.0, type=float, help="Insertion penalty")
150
+ parser.add_argument(
151
+ "--maxlenratio",
152
+ default=0.0,
153
+ type=float,
154
+ help="""Input length ratio to obtain max output length.
155
+ If maxlenratio=0.0 (default), it uses an end-detect function
156
+ to automatically find maximum hypothesis lengths""",
157
+ )
158
+ parser.add_argument(
159
+ "--minlenratio",
160
+ default=0.0,
161
+ type=float,
162
+ help="Input length ratio to obtain min output length",
163
+ )
164
+ parser.add_argument(
165
+ "--rnnlm", type=str, default=None, help="RNNLM model file to read"
166
+ )
167
+ parser.add_argument(
168
+ "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
169
+ )
170
+ parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.")
171
+ parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
172
+ parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
173
+ # minibatch related
174
+ parser.add_argument(
175
+ "--sortagrad",
176
+ default=0,
177
+ type=int,
178
+ nargs="?",
179
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
180
+ )
181
+ parser.add_argument(
182
+ "--batch-count",
183
+ default="auto",
184
+ choices=BATCH_COUNT_CHOICES,
185
+ help="How to count batch_size. "
186
+ "The default (auto) will find how to count by args.",
187
+ )
188
+ parser.add_argument(
189
+ "--batch-size",
190
+ "--batch-seqs",
191
+ "-b",
192
+ default=0,
193
+ type=int,
194
+ help="Maximum seqs in a minibatch (0 to disable)",
195
+ )
196
+ parser.add_argument(
197
+ "--batch-bins",
198
+ default=0,
199
+ type=int,
200
+ help="Maximum bins in a minibatch (0 to disable)",
201
+ )
202
+ parser.add_argument(
203
+ "--batch-frames-in",
204
+ default=0,
205
+ type=int,
206
+ help="Maximum input frames in a minibatch (0 to disable)",
207
+ )
208
+ parser.add_argument(
209
+ "--batch-frames-out",
210
+ default=0,
211
+ type=int,
212
+ help="Maximum output frames in a minibatch (0 to disable)",
213
+ )
214
+ parser.add_argument(
215
+ "--batch-frames-inout",
216
+ default=0,
217
+ type=int,
218
+ help="Maximum input+output frames in a minibatch (0 to disable)",
219
+ )
220
+ parser.add_argument(
221
+ "--maxlen-in",
222
+ "--batch-seq-maxlen-in",
223
+ default=100,
224
+ type=int,
225
+ metavar="ML",
226
+ help="When --batch-count=seq, "
227
+ "batch size is reduced if the input sequence length > ML.",
228
+ )
229
+ parser.add_argument(
230
+ "--maxlen-out",
231
+ "--batch-seq-maxlen-out",
232
+ default=100,
233
+ type=int,
234
+ metavar="ML",
235
+ help="When --batch-count=seq, "
236
+ "batch size is reduced if the output sequence length > ML",
237
+ )
238
+ parser.add_argument(
239
+ "--n-iter-processes",
240
+ default=0,
241
+ type=int,
242
+ help="Number of processes of iterator",
243
+ )
244
+ # optimization related
245
+ parser.add_argument(
246
+ "--opt",
247
+ default="adadelta",
248
+ type=str,
249
+ choices=["adadelta", "adam", "noam"],
250
+ help="Optimizer",
251
+ )
252
+ parser.add_argument(
253
+ "--accum-grad", default=1, type=int, help="Number of gradient accumulation"
254
+ )
255
+ parser.add_argument(
256
+ "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
257
+ )
258
+ parser.add_argument(
259
+ "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
260
+ )
261
+ parser.add_argument(
262
+ "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
263
+ )
264
+ parser.add_argument(
265
+ "--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate"
266
+ )
267
+ parser.add_argument(
268
+ "--weight-decay", default=0.0, type=float, help="Weight decay ratio"
269
+ )
270
+ parser.add_argument(
271
+ "--criterion",
272
+ default="acc",
273
+ type=str,
274
+ choices=["loss", "acc"],
275
+ help="Criterion to perform epsilon decay",
276
+ )
277
+ parser.add_argument(
278
+ "--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
279
+ )
280
+ parser.add_argument(
281
+ "--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
282
+ )
283
+ parser.add_argument(
284
+ "--early-stop-criterion",
285
+ default="validation/main/acc",
286
+ type=str,
287
+ nargs="?",
288
+ help="Value to monitor to trigger an early stopping of the training",
289
+ )
290
+ parser.add_argument(
291
+ "--patience",
292
+ default=3,
293
+ type=int,
294
+ nargs="?",
295
+ help="Number of epochs to wait "
296
+ "without improvement before stopping the training",
297
+ )
298
+ parser.add_argument(
299
+ "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
300
+ )
301
+ parser.add_argument(
302
+ "--num-save-attention",
303
+ default=3,
304
+ type=int,
305
+ help="Number of samples of attention to be saved",
306
+ )
307
+ # decoder related
308
+ parser.add_argument(
309
+ "--context-residual",
310
+ default=False,
311
+ type=strtobool,
312
+ nargs="?",
313
+ help="The flag to switch to use context vector residual in the decoder network",
314
+ )
315
+ parser.add_argument(
316
+ "--tie-src-tgt-embedding",
317
+ default=False,
318
+ type=strtobool,
319
+ nargs="?",
320
+ help="Tie parameters of source embedding and target embedding.",
321
+ )
322
+ parser.add_argument(
323
+ "--tie-classifier",
324
+ default=False,
325
+ type=strtobool,
326
+ nargs="?",
327
+ help="Tie parameters of target embedding and output projection layer.",
328
+ )
329
+ # finetuning related
330
+ parser.add_argument(
331
+ "--enc-init",
332
+ default=None,
333
+ type=str,
334
+ nargs="?",
335
+ help="Pre-trained ASR model to initialize encoder.",
336
+ )
337
+ parser.add_argument(
338
+ "--enc-init-mods",
339
+ default="enc.enc.",
340
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
341
+ help="List of encoder modules to initialize, separated by a comma.",
342
+ )
343
+ parser.add_argument(
344
+ "--dec-init",
345
+ default=None,
346
+ type=str,
347
+ nargs="?",
348
+ help="Pre-trained ASR, MT or LM model to initialize decoder.",
349
+ )
350
+ parser.add_argument(
351
+ "--dec-init-mods",
352
+ default="att., dec.",
353
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
354
+ help="List of decoder modules to initialize, separated by a comma.",
355
+ )
356
+ # multilingual related
357
+ parser.add_argument(
358
+ "--multilingual",
359
+ default=False,
360
+ type=strtobool,
361
+ help="Prepend target language ID to the source sentence. "
362
+ "Both source/target language IDs must be prepended in the pre-processing stage.",
363
+ )
364
+ parser.add_argument(
365
+ "--replace-sos",
366
+ default=False,
367
+ type=strtobool,
368
+ help="Replace <sos> in the decoder with a target language ID "
369
+ "(the first token in the target sequence)",
370
+ )
371
+
372
+ return parser
373
+
374
+
375
+ def main(cmd_args):
376
+ """Run the main training function."""
377
+ parser = get_parser()
378
+ args, _ = parser.parse_known_args(cmd_args)
379
+ if args.backend == "chainer" and args.train_dtype != "float32":
380
+ raise NotImplementedError(
381
+ f"chainer backend does not support --train-dtype {args.train_dtype}."
382
+ " Use --train-dtype float32."
383
+ )
384
+ if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
385
+ raise ValueError(
386
+ f"--train-dtype {args.train_dtype} does not support the CPU backend."
387
+ )
388
+
389
+ from espnet.utils.dynamic_import import dynamic_import
390
+
391
+ if args.model_module is None:
392
+ model_module = "espnet.nets." + args.backend + "_backend.e2e_mt:E2E"
393
+ else:
394
+ model_module = args.model_module
395
+ model_class = dynamic_import(model_module)
396
+ model_class.add_arguments(parser)
397
+
398
+ args = parser.parse_args(cmd_args)
399
+ args.model_module = model_module
400
+ if "chainer_backend" in args.model_module:
401
+ args.backend = "chainer"
402
+ if "pytorch_backend" in args.model_module:
403
+ args.backend = "pytorch"
404
+
405
+ # add version info in args
406
+ args.version = __version__
407
+
408
+ # logging info
409
+ if args.verbose > 0:
410
+ logging.basicConfig(
411
+ level=logging.INFO,
412
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
413
+ )
414
+ else:
415
+ logging.basicConfig(
416
+ level=logging.WARN,
417
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
418
+ )
419
+ logging.warning("Skip DEBUG/INFO messages")
420
+
421
+ # If --ngpu is not given,
422
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
423
+ # 2. if nvidia-smi exists, use all devices
424
+ # 3. else ngpu=0
425
+ if args.ngpu is None:
426
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
427
+ if cvd is not None:
428
+ ngpu = len(cvd.split(","))
429
+ else:
430
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
431
+ try:
432
+ p = subprocess.run(
433
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
434
+ )
435
+ except (subprocess.CalledProcessError, FileNotFoundError):
436
+ ngpu = 0
437
+ else:
438
+ ngpu = len(p.stderr.decode().split("\n")) - 1
439
+ args.ngpu = ngpu
440
+ else:
441
+ if is_torch_1_2_plus and args.ngpu != 1:
442
+ logging.debug(
443
+ "There are some bugs with multi-GPU processing in PyTorch 1.2+"
444
+ + " (see https://github.com/pytorch/pytorch/issues/21108)"
445
+ )
446
+ ngpu = args.ngpu
447
+ logging.info(f"ngpu: {ngpu}")
448
+
449
+ # display PYTHONPATH
450
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
451
+
452
+ # set random seed
453
+ logging.info("random seed = %d" % args.seed)
454
+ random.seed(args.seed)
455
+ np.random.seed(args.seed)
456
+
457
+ # load dictionary for debug log
458
+ if args.dict is not None:
459
+ with open(args.dict, "rb") as f:
460
+ dictionary = f.readlines()
461
+ char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
462
+ char_list.insert(0, "<blank>")
463
+ char_list.append("<eos>")
464
+ args.char_list = char_list
465
+ else:
466
+ args.char_list = None
467
+
468
+ # train
469
+ logging.info("backend = " + args.backend)
470
+
471
+ if args.backend == "pytorch":
472
+ from espnet.mt.pytorch_backend.mt import train
473
+
474
+ train(args)
475
+ else:
476
+ raise ValueError("Only pytorch is supported.")
477
+
478
+
479
+ if __name__ == "__main__":
480
+ main(sys.argv[1:])
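A minimal usage sketch for the training entry point above (the file paths and flag values are illustrative placeholders, not part of this commit): main() accepts an argv-style list, so the flags defined in get_parser() can also be passed programmatically instead of via the command line.

    # Sketch only: placeholder paths, pytorch backend, CPU-only run.
    from espnet.bin.mt_train import main

    main([
        "--backend", "pytorch",
        "--ngpu", "0",
        "--outdir", "exp/mt_train/results",      # required
        "--dict", "data/lang_1spm/units.txt",    # required
        "--train-json", "dump/train/data.json",
        "--valid-json", "dump/dev/data.json",
        "--opt", "adam",
        "--epochs", "30",
        "--batch-size", "32",
    ])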
espnet/bin/mt_trans.py ADDED
@@ -0,0 +1,186 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Neural machine translation model decoding script."""
8
+
9
+ import configargparse
10
+ import logging
11
+ import os
12
+ import random
13
+ import sys
14
+
15
+ import numpy as np
16
+
17
+
18
+ # NOTE: you need this func to generate our sphinx doc
19
+ def get_parser():
20
+ """Get default arguments."""
21
+ parser = configargparse.ArgumentParser(
22
+ description="Translate text using a neural machine translation "
23
+ "(NMT) model on one CPU or GPU",
24
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
25
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
26
+ )
27
+ # general configuration
28
+ parser.add("--config", is_config_file=True, help="Config file path")
29
+ parser.add(
30
+ "--config2",
31
+ is_config_file=True,
32
+ help="Second config file path that overwrites the settings in `--config`",
33
+ )
34
+ parser.add(
35
+ "--config3",
36
+ is_config_file=True,
37
+ help="Third config file path "
38
+ "that overwrites the settings in `--config` and `--config2`",
39
+ )
40
+
41
+ parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
42
+ parser.add_argument(
43
+ "--dtype",
44
+ choices=("float16", "float32", "float64"),
45
+ default="float32",
46
+ help="Float precision (only available in --api v2)",
47
+ )
48
+ parser.add_argument(
49
+ "--backend",
50
+ type=str,
51
+ default="chainer",
52
+ choices=["chainer", "pytorch"],
53
+ help="Backend library",
54
+ )
55
+ parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
56
+ parser.add_argument("--seed", type=int, default=1, help="Random seed")
57
+ parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
58
+ parser.add_argument(
59
+ "--batchsize",
60
+ type=int,
61
+ default=1,
62
+ help="Batch size for beam search (0: means no batch processing)",
63
+ )
64
+ parser.add_argument(
65
+ "--preprocess-conf",
66
+ type=str,
67
+ default=None,
68
+ help="The configuration file for the pre-processing",
69
+ )
70
+ parser.add_argument(
71
+ "--api",
72
+ default="v1",
73
+ choices=["v1", "v2"],
74
+ help="Beam search APIs "
75
+ "v1: Default API. It only supports "
76
+ "the ASRInterface.recognize method and DefaultRNNLM. "
77
+ "v2: Experimental API. "
78
+ "It supports any models that implements ScorerInterface.",
79
+ )
80
+ # task related
81
+ parser.add_argument(
82
+ "--trans-json", type=str, help="Filename of translation data (json)"
83
+ )
84
+ parser.add_argument(
85
+ "--result-label",
86
+ type=str,
87
+ required=True,
88
+ help="Filename of result label data (json)",
89
+ )
90
+ # model (parameter) related
91
+ parser.add_argument(
92
+ "--model", type=str, required=True, help="Model file parameters to read"
93
+ )
94
+ parser.add_argument(
95
+ "--model-conf", type=str, default=None, help="Model config file"
96
+ )
97
+ # search related
98
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
99
+ parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
100
+ parser.add_argument("--penalty", type=float, default=0.1, help="Insertion penalty")
101
+ parser.add_argument(
102
+ "--maxlenratio",
103
+ type=float,
104
+ default=3.0,
105
+ help="""Input length ratio to obtain max output length.
106
+ If maxlenratio=0.0 (default), it uses a end-detect function
107
+ to automatically find maximum hypothesis lengths""",
108
+ )
109
+ parser.add_argument(
110
+ "--minlenratio",
111
+ type=float,
112
+ default=0.0,
113
+ help="Input length ratio to obtain min output length",
114
+ )
115
+ # multilingual related
116
+ parser.add_argument(
117
+ "--tgt-lang",
118
+ default=False,
119
+ type=str,
120
+ help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
121
+ )
122
+ return parser
123
+
124
+
125
+ def main(args):
126
+ """Run the main decoding function."""
127
+ parser = get_parser()
128
+ args = parser.parse_args(args)
129
+
130
+ # logging info
131
+ if args.verbose == 1:
132
+ logging.basicConfig(
133
+ level=logging.INFO,
134
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
135
+ )
136
+ elif args.verbose == 2:
137
+ logging.basicConfig(
138
+ level=logging.DEBUG,
139
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
140
+ )
141
+ else:
142
+ logging.basicConfig(
143
+ level=logging.WARN,
144
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
145
+ )
146
+ logging.warning("Skip DEBUG/INFO messages")
147
+
148
+ # check CUDA_VISIBLE_DEVICES
149
+ if args.ngpu > 0:
150
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
151
+ if cvd is None:
152
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
153
+ elif args.ngpu != len(cvd.split(",")):
154
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
155
+ sys.exit(1)
156
+
157
+ # TODO(mn5k): support of multiple GPUs
158
+ if args.ngpu > 1:
159
+ logging.error("The program only supports ngpu=1.")
160
+ sys.exit(1)
161
+
162
+ # display PYTHONPATH
163
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
164
+
165
+ # seed setting
166
+ random.seed(args.seed)
167
+ np.random.seed(args.seed)
168
+ logging.info("set random seed = %d" % args.seed)
169
+
170
+ # trans
171
+ logging.info("backend = " + args.backend)
172
+ if args.backend == "pytorch":
173
+ # Experimental API that supports custom LMs
174
+ from espnet.mt.pytorch_backend.mt import trans
175
+
176
+ if args.dtype != "float32":
177
+ raise NotImplementedError(
178
+ f"`--dtype {args.dtype}` is only available with `--api v2`"
179
+ )
180
+ trans(args)
181
+ else:
182
+ raise ValueError("Only pytorch is supported.")
183
+
184
+
185
+ if __name__ == "__main__":
186
+ main(sys.argv[1:])
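A minimal decoding sketch for the entry point above (placeholder paths, not from this commit): --model and --result-label are the only required flags, and --trans-json points at the json file listing the utterances to translate.

    # Sketch only: placeholder paths; ngpu=0 keeps decoding on the CPU.
    from espnet.bin.mt_trans import main

    main([
        "--backend", "pytorch",
        "--ngpu", "0",
        "--model", "exp/mt_train/results/model.acc.best",        # required
        "--result-label", "exp/mt_train/decode_test/data.json",  # required
        "--trans-json", "dump/test/data.json",
        "--beam-size", "4",
        "--verbose", "1",
    ])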
espnet/bin/st_train.py ADDED
@@ -0,0 +1,550 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """End-to-end speech translation model training script."""
8
+
9
+ from distutils.version import LooseVersion
10
+ import logging
11
+ import os
12
+ import random
13
+ import subprocess
14
+ import sys
15
+
16
+ import configargparse
17
+ import numpy as np
18
+ import torch
19
+
20
+ from espnet import __version__
21
+ from espnet.utils.cli_utils import strtobool
22
+ from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
23
+
24
+ is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2")
25
+
26
+
27
+ # NOTE: you need this func to generate our sphinx doc
28
+ def get_parser(parser=None, required=True):
29
+ """Get default arguments."""
30
+ if parser is None:
31
+ parser = configargparse.ArgumentParser(
32
+ description="Train a speech translation (ST) model on one CPU, "
33
+ "one or multiple GPUs",
34
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
35
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
36
+ )
37
+ # general configuration
38
+ parser.add("--config", is_config_file=True, help="config file path")
39
+ parser.add(
40
+ "--config2",
41
+ is_config_file=True,
42
+ help="second config file path that overwrites the settings in `--config`.",
43
+ )
44
+ parser.add(
45
+ "--config3",
46
+ is_config_file=True,
47
+ help="third config file path that overwrites the settings "
48
+ "in `--config` and `--config2`.",
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--ngpu",
53
+ default=None,
54
+ type=int,
55
+ help="Number of GPUs. If not given, use all visible devices",
56
+ )
57
+ parser.add_argument(
58
+ "--train-dtype",
59
+ default="float32",
60
+ choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
61
+ help="Data type for training (only pytorch backend). "
62
+ "O0,O1,.. flags require apex. "
63
+ "See https://nvidia.github.io/apex/amp.html#opt-levels",
64
+ )
65
+ parser.add_argument(
66
+ "--backend",
67
+ default="chainer",
68
+ type=str,
69
+ choices=["chainer", "pytorch"],
70
+ help="Backend library",
71
+ )
72
+ parser.add_argument(
73
+ "--outdir", type=str, required=required, help="Output directory"
74
+ )
75
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
76
+ parser.add_argument("--dict", required=required, help="Dictionary")
77
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
78
+ parser.add_argument("--debugdir", type=str, help="Output directory for debugging")
79
+ parser.add_argument(
80
+ "--resume",
81
+ "-r",
82
+ default="",
83
+ nargs="?",
84
+ help="Resume the training from snapshot",
85
+ )
86
+ parser.add_argument(
87
+ "--minibatches",
88
+ "-N",
89
+ type=int,
90
+ default="-1",
91
+ help="Process only N minibatches (for debug)",
92
+ )
93
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
94
+ parser.add_argument(
95
+ "--tensorboard-dir",
96
+ default=None,
97
+ type=str,
98
+ nargs="?",
99
+ help="Tensorboard log dir path",
100
+ )
101
+ parser.add_argument(
102
+ "--report-interval-iters",
103
+ default=100,
104
+ type=int,
105
+ help="Report interval iterations",
106
+ )
107
+ parser.add_argument(
108
+ "--save-interval-iters",
109
+ default=0,
110
+ type=int,
111
+ help="Save snapshot interval iterations",
112
+ )
113
+ # task related
114
+ parser.add_argument(
115
+ "--train-json",
116
+ type=str,
117
+ default=None,
118
+ help="Filename of train label data (json)",
119
+ )
120
+ parser.add_argument(
121
+ "--valid-json",
122
+ type=str,
123
+ default=None,
124
+ help="Filename of validation label data (json)",
125
+ )
126
+ # network architecture
127
+ parser.add_argument(
128
+ "--model-module",
129
+ type=str,
130
+ default=None,
131
+ help="model defined module (default: espnet.nets.xxx_backend.e2e_st:E2E)",
132
+ )
133
+ # loss related
134
+ parser.add_argument(
135
+ "--ctc_type",
136
+ default="warpctc",
137
+ type=str,
138
+ choices=["builtin", "warpctc", "gtnctc", "cudnnctc"],
139
+ help="Type of CTC implementation to calculate loss.",
140
+ )
141
+ parser.add_argument(
142
+ "--mtlalpha",
143
+ default=0.0,
144
+ type=float,
145
+ help="Multitask learning coefficient, alpha: \
146
+ alpha*ctc_loss + (1-alpha)*att_loss",
147
+ )
148
+ parser.add_argument(
149
+ "--asr-weight",
150
+ default=0.0,
151
+ type=float,
152
+ help="Multitask learning coefficient for ASR task, weight: "
153
+ " asr_weight*(alpha*ctc_loss + (1-alpha)*att_loss)"
154
+ " + (1-asr_weight-mt_weight)*st_loss",
155
+ )
156
+ parser.add_argument(
157
+ "--mt-weight",
158
+ default=0.0,
159
+ type=float,
160
+ help="Multitask learning coefficient for MT task, weight: \
161
+ mt_weight*mt_loss + (1-mt_weight-asr_weight)*st_loss",
162
+ )
163
+ parser.add_argument(
164
+ "--lsm-weight", default=0.0, type=float, help="Label smoothing weight"
165
+ )
166
+ # recognition options to compute CER/WER
167
+ parser.add_argument(
168
+ "--report-cer",
169
+ default=False,
170
+ action="store_true",
171
+ help="Compute CER on development set",
172
+ )
173
+ parser.add_argument(
174
+ "--report-wer",
175
+ default=False,
176
+ action="store_true",
177
+ help="Compute WER on development set",
178
+ )
179
+ # translations options to compute BLEU
180
+ parser.add_argument(
181
+ "--report-bleu",
182
+ default=True,
183
+ action="store_true",
184
+ help="Compute BLEU on development set",
185
+ )
186
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
187
+ parser.add_argument("--beam-size", type=int, default=4, help="Beam size")
188
+ parser.add_argument("--penalty", default=0.0, type=float, help="Insertion penalty")
189
+ parser.add_argument(
190
+ "--maxlenratio",
191
+ default=0.0,
192
+ type=float,
193
+ help="""Input length ratio to obtain max output length.
194
+ If maxlenratio=0.0 (default), it uses an end-detect function
195
+ to automatically find maximum hypothesis lengths""",
196
+ )
197
+ parser.add_argument(
198
+ "--minlenratio",
199
+ default=0.0,
200
+ type=float,
201
+ help="Input length ratio to obtain min output length",
202
+ )
203
+ parser.add_argument(
204
+ "--rnnlm", type=str, default=None, help="RNNLM model file to read"
205
+ )
206
+ parser.add_argument(
207
+ "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
208
+ )
209
+ parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.")
210
+ parser.add_argument("--sym-space", default="<space>", type=str, help="Space symbol")
211
+ parser.add_argument("--sym-blank", default="<blank>", type=str, help="Blank symbol")
212
+ # minibatch related
213
+ parser.add_argument(
214
+ "--sortagrad",
215
+ default=0,
216
+ type=int,
217
+ nargs="?",
218
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
219
+ )
220
+ parser.add_argument(
221
+ "--batch-count",
222
+ default="auto",
223
+ choices=BATCH_COUNT_CHOICES,
224
+ help="How to count batch_size. "
225
+ "The default (auto) will find how to count by args.",
226
+ )
227
+ parser.add_argument(
228
+ "--batch-size",
229
+ "--batch-seqs",
230
+ "-b",
231
+ default=0,
232
+ type=int,
233
+ help="Maximum seqs in a minibatch (0 to disable)",
234
+ )
235
+ parser.add_argument(
236
+ "--batch-bins",
237
+ default=0,
238
+ type=int,
239
+ help="Maximum bins in a minibatch (0 to disable)",
240
+ )
241
+ parser.add_argument(
242
+ "--batch-frames-in",
243
+ default=0,
244
+ type=int,
245
+ help="Maximum input frames in a minibatch (0 to disable)",
246
+ )
247
+ parser.add_argument(
248
+ "--batch-frames-out",
249
+ default=0,
250
+ type=int,
251
+ help="Maximum output frames in a minibatch (0 to disable)",
252
+ )
253
+ parser.add_argument(
254
+ "--batch-frames-inout",
255
+ default=0,
256
+ type=int,
257
+ help="Maximum input+output frames in a minibatch (0 to disable)",
258
+ )
259
+ parser.add_argument(
260
+ "--maxlen-in",
261
+ "--batch-seq-maxlen-in",
262
+ default=800,
263
+ type=int,
264
+ metavar="ML",
265
+ help="When --batch-count=seq, batch size is reduced "
266
+ "if the input sequence length > ML.",
267
+ )
268
+ parser.add_argument(
269
+ "--maxlen-out",
270
+ "--batch-seq-maxlen-out",
271
+ default=150,
272
+ type=int,
273
+ metavar="ML",
274
+ help="When --batch-count=seq, "
275
+ "batch size is reduced if the output sequence length > ML",
276
+ )
277
+ parser.add_argument(
278
+ "--n-iter-processes",
279
+ default=0,
280
+ type=int,
281
+ help="Number of processes of iterator",
282
+ )
283
+ parser.add_argument(
284
+ "--preprocess-conf",
285
+ type=str,
286
+ default=None,
287
+ nargs="?",
288
+ help="The configuration file for the pre-processing",
289
+ )
290
+ # optimization related
291
+ parser.add_argument(
292
+ "--opt",
293
+ default="adadelta",
294
+ type=str,
295
+ choices=["adadelta", "adam", "noam"],
296
+ help="Optimizer",
297
+ )
298
+ parser.add_argument(
299
+ "--accum-grad", default=1, type=int, help="Number of gradient accumulation"
300
+ )
301
+ parser.add_argument(
302
+ "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
303
+ )
304
+ parser.add_argument(
305
+ "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon"
306
+ )
307
+ parser.add_argument(
308
+ "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
309
+ )
310
+ parser.add_argument(
311
+ "--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate"
312
+ )
313
+ parser.add_argument(
314
+ "--weight-decay", default=0.0, type=float, help="Weight decay ratio"
315
+ )
316
+ parser.add_argument(
317
+ "--criterion",
318
+ default="acc",
319
+ type=str,
320
+ choices=["loss", "acc"],
321
+ help="Criterion to perform epsilon decay",
322
+ )
323
+ parser.add_argument(
324
+ "--threshold", default=1e-4, type=float, help="Threshold to stop iteration"
325
+ )
326
+ parser.add_argument(
327
+ "--epochs", "-e", default=30, type=int, help="Maximum number of epochs"
328
+ )
329
+ parser.add_argument(
330
+ "--early-stop-criterion",
331
+ default="validation/main/acc",
332
+ type=str,
333
+ nargs="?",
334
+ help="Value to monitor to trigger an early stopping of the training",
335
+ )
336
+ parser.add_argument(
337
+ "--patience",
338
+ default=3,
339
+ type=int,
340
+ nargs="?",
341
+ help="Number of epochs to wait "
342
+ "without improvement before stopping the training",
343
+ )
344
+ parser.add_argument(
345
+ "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
346
+ )
347
+ parser.add_argument(
348
+ "--num-save-attention",
349
+ default=3,
350
+ type=int,
351
+ help="Number of samples of attention to be saved",
352
+ )
353
+ parser.add_argument(
354
+ "--num-save-ctc",
355
+ default=3,
356
+ type=int,
357
+ help="Number of samples of CTC probability to be saved",
358
+ )
359
+ parser.add_argument(
360
+ "--grad-noise",
361
+ type=strtobool,
362
+ default=False,
363
+ help="The flag to switch to use noise injection to gradients during training",
364
+ )
365
+ # speech translation related
366
+ parser.add_argument(
367
+ "--context-residual",
368
+ default=False,
369
+ type=strtobool,
370
+ nargs="?",
371
+ help="The flag to switch to use context vector residual in the decoder network",
372
+ )
373
+ # finetuning related
374
+ parser.add_argument(
375
+ "--enc-init",
376
+ default=None,
377
+ type=str,
378
+ nargs="?",
379
+ help="Pre-trained ASR model to initialize encoder.",
380
+ )
381
+ parser.add_argument(
382
+ "--enc-init-mods",
383
+ default="enc.enc.",
384
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
385
+ help="List of encoder modules to initialize, separated by a comma.",
386
+ )
387
+ parser.add_argument(
388
+ "--dec-init",
389
+ default=None,
390
+ type=str,
391
+ nargs="?",
392
+ help="Pre-trained ASR, MT or LM model to initialize decoder.",
393
+ )
394
+ parser.add_argument(
395
+ "--dec-init-mods",
396
+ default="att., dec.",
397
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
398
+ help="List of decoder modules to initialize, separated by a comma.",
399
+ )
400
+ # multilingual related
401
+ parser.add_argument(
402
+ "--multilingual",
403
+ default=False,
404
+ type=strtobool,
405
+ help="Prepend target language ID to the source sentence. "
406
+ " Both source/target language IDs must be prepended in the pre-processing stage.",
407
+ )
408
+ parser.add_argument(
409
+ "--replace-sos",
410
+ default=False,
411
+ type=strtobool,
412
+ help="Replace <sos> in the decoder with a target language ID \
413
+ (the first token in the target sequence)",
414
+ )
415
+ # Feature transform: Normalization
416
+ parser.add_argument(
417
+ "--stats-file",
418
+ type=str,
419
+ default=None,
420
+ help="The stats file for the feature normalization",
421
+ )
422
+ parser.add_argument(
423
+ "--apply-uttmvn",
424
+ type=strtobool,
425
+ default=True,
426
+ help="Apply utterance level mean " "variance normalization.",
427
+ )
428
+ parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="")
429
+ parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="")
430
+ # Feature transform: Fbank
431
+ parser.add_argument(
432
+ "--fbank-fs",
433
+ type=int,
434
+ default=16000,
435
+ help="The sample frequency used for " "the mel-fbank creation.",
436
+ )
437
+ parser.add_argument(
438
+ "--n-mels", type=int, default=80, help="The number of mel-frequency bins."
439
+ )
440
+ parser.add_argument("--fbank-fmin", type=float, default=0.0, help="")
441
+ parser.add_argument("--fbank-fmax", type=float, default=None, help="")
442
+ return parser
443
+
444
+
445
+ def main(cmd_args):
446
+ """Run the main training function."""
447
+ parser = get_parser()
448
+ args, _ = parser.parse_known_args(cmd_args)
449
+ if args.backend == "chainer" and args.train_dtype != "float32":
450
+ raise NotImplementedError(
451
+ f"chainer backend does not support --train-dtype {args.train_dtype}."
452
+ " Use --train-dtype float32."
453
+ )
454
+ if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
455
+ raise ValueError(
456
+ f"--train-dtype {args.train_dtype} does not support the CPU backend."
457
+ )
458
+
459
+ from espnet.utils.dynamic_import import dynamic_import
460
+
461
+ if args.model_module is None:
462
+ model_module = "espnet.nets." + args.backend + "_backend.e2e_st:E2E"
463
+ else:
464
+ model_module = args.model_module
465
+ model_class = dynamic_import(model_module)
466
+ model_class.add_arguments(parser)
467
+
468
+ args = parser.parse_args(cmd_args)
469
+ args.model_module = model_module
470
+ if "chainer_backend" in args.model_module:
471
+ args.backend = "chainer"
472
+ if "pytorch_backend" in args.model_module:
473
+ args.backend = "pytorch"
474
+
475
+ # add version info in args
476
+ args.version = __version__
477
+
478
+ # logging info
479
+ if args.verbose > 0:
480
+ logging.basicConfig(
481
+ level=logging.INFO,
482
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
483
+ )
484
+ else:
485
+ logging.basicConfig(
486
+ level=logging.WARN,
487
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
488
+ )
489
+ logging.warning("Skip DEBUG/INFO messages")
490
+
491
+ # If --ngpu is not given,
492
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
493
+ # 2. if nvidia-smi exists, use all devices
494
+ # 3. else ngpu=0
495
+ if args.ngpu is None:
496
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
497
+ if cvd is not None:
498
+ ngpu = len(cvd.split(","))
499
+ else:
500
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
501
+ try:
502
+ p = subprocess.run(
503
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
504
+ )
505
+ except (subprocess.CalledProcessError, FileNotFoundError):
506
+ ngpu = 0
507
+ else:
508
+ ngpu = len(p.stderr.decode().split("\n")) - 1
509
+ args.ngpu = ngpu
510
+ else:
511
+ if is_torch_1_2_plus and args.ngpu != 1:
512
+ logging.debug(
513
+ "There are some bugs with multi-GPU processing in PyTorch 1.2+"
514
+ + " (see https://github.com/pytorch/pytorch/issues/21108)"
515
+ )
516
+ ngpu = args.ngpu
517
+ logging.info(f"ngpu: {ngpu}")
518
+
519
+ # display PYTHONPATH
520
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
521
+
522
+ # set random seed
523
+ logging.info("random seed = %d" % args.seed)
524
+ random.seed(args.seed)
525
+ np.random.seed(args.seed)
526
+
527
+ # load dictionary for debug log
528
+ if args.dict is not None:
529
+ with open(args.dict, "rb") as f:
530
+ dictionary = f.readlines()
531
+ char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
532
+ char_list.insert(0, "<blank>")
533
+ char_list.append("<eos>")
534
+ args.char_list = char_list
535
+ else:
536
+ args.char_list = None
537
+
538
+ # train
539
+ logging.info("backend = " + args.backend)
540
+
541
+ if args.backend == "pytorch":
542
+ from espnet.st.pytorch_backend.st import train
543
+
544
+ train(args)
545
+ else:
546
+ raise ValueError("Only pytorch is supported.")
547
+
548
+
549
+ if __name__ == "__main__":
550
+ main(sys.argv[1:])
espnet/bin/st_trans.py ADDED
@@ -0,0 +1,183 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """End-to-end speech translation model decoding script."""
8
+
9
+ import logging
10
+ import os
11
+ import random
12
+ import sys
13
+
14
+ import configargparse
15
+ import numpy as np
16
+
17
+
18
+ # NOTE: you need this func to generate our sphinx doc
19
+ def get_parser():
20
+ """Get default arguments."""
21
+ parser = configargparse.ArgumentParser(
22
+ description="Translate text from speech using a speech translation "
23
+ "model on one CPU or GPU",
24
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
25
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
26
+ )
27
+ # general configuration
28
+ parser.add("--config", is_config_file=True, help="Config file path")
29
+ parser.add(
30
+ "--config2",
31
+ is_config_file=True,
32
+ help="Second config file path that overwrites the settings in `--config`",
33
+ )
34
+ parser.add(
35
+ "--config3",
36
+ is_config_file=True,
37
+ help="Third config file path that overwrites "
38
+ "the settings in `--config` and `--config2`",
39
+ )
40
+
41
+ parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
42
+ parser.add_argument(
43
+ "--dtype",
44
+ choices=("float16", "float32", "float64"),
45
+ default="float32",
46
+ help="Float precision (only available in --api v2)",
47
+ )
48
+ parser.add_argument(
49
+ "--backend",
50
+ type=str,
51
+ default="chainer",
52
+ choices=["chainer", "pytorch"],
53
+ help="Backend library",
54
+ )
55
+ parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
56
+ parser.add_argument("--seed", type=int, default=1, help="Random seed")
57
+ parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
58
+ parser.add_argument(
59
+ "--batchsize",
60
+ type=int,
61
+ default=1,
62
+ help="Batch size for beam search (0: means no batch processing)",
63
+ )
64
+ parser.add_argument(
65
+ "--preprocess-conf",
66
+ type=str,
67
+ default=None,
68
+ help="The configuration file for the pre-processing",
69
+ )
70
+ parser.add_argument(
71
+ "--api",
72
+ default="v1",
73
+ choices=["v1", "v2"],
74
+ help="Beam search APIs "
75
+ "v1: Default API. "
76
+ "It only supports the ASRInterface.recognize method and DefaultRNNLM. "
77
+ "v2: Experimental API. "
78
+ "It supports any models that implements ScorerInterface.",
79
+ )
80
+ # task related
81
+ parser.add_argument(
82
+ "--trans-json", type=str, help="Filename of translation data (json)"
83
+ )
84
+ parser.add_argument(
85
+ "--result-label",
86
+ type=str,
87
+ required=True,
88
+ help="Filename of result label data (json)",
89
+ )
90
+ # model (parameter) related
91
+ parser.add_argument(
92
+ "--model", type=str, required=True, help="Model file parameters to read"
93
+ )
94
+ # search related
95
+ parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
96
+ parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
97
+ parser.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
98
+ parser.add_argument(
99
+ "--maxlenratio",
100
+ type=float,
101
+ default=0.0,
102
+ help="""Input length ratio to obtain max output length.
103
+ If maxlenratio=0.0 (default), it uses an end-detect function
104
+ to automatically find maximum hypothesis lengths""",
105
+ )
106
+ parser.add_argument(
107
+ "--minlenratio",
108
+ type=float,
109
+ default=0.0,
110
+ help="Input length ratio to obtain min output length",
111
+ )
112
+ # multilingual related
113
+ parser.add_argument(
114
+ "--tgt-lang",
115
+ default=False,
116
+ type=str,
117
+ help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
118
+ )
119
+ return parser
120
+
121
+
122
+ def main(args):
123
+ """Run the main decoding function."""
124
+ parser = get_parser()
125
+ args = parser.parse_args(args)
126
+
127
+ # logging info
128
+ if args.verbose == 1:
129
+ logging.basicConfig(
130
+ level=logging.INFO,
131
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
132
+ )
133
+ elif args.verbose == 2:
134
+ logging.basicConfig(
135
+ level=logging.DEBUG,
136
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
137
+ )
138
+ else:
139
+ logging.basicConfig(
140
+ level=logging.WARN,
141
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
142
+ )
143
+ logging.warning("Skip DEBUG/INFO messages")
144
+
145
+ # check CUDA_VISIBLE_DEVICES
146
+ if args.ngpu > 0:
147
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
148
+ if cvd is None:
149
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
150
+ elif args.ngpu != len(cvd.split(",")):
151
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
152
+ sys.exit(1)
153
+
154
+ # TODO(mn5k): support of multiple GPUs
155
+ if args.ngpu > 1:
156
+ logging.error("The program only supports ngpu=1.")
157
+ sys.exit(1)
158
+
159
+ # display PYTHONPATH
160
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
161
+
162
+ # seed setting
163
+ random.seed(args.seed)
164
+ np.random.seed(args.seed)
165
+ logging.info("set random seed = %d" % args.seed)
166
+
167
+ # trans
168
+ logging.info("backend = " + args.backend)
169
+ if args.backend == "pytorch":
170
+ # Experimental API that supports custom LMs
171
+ from espnet.st.pytorch_backend.st import trans
172
+
173
+ if args.dtype != "float32":
174
+ raise NotImplementedError(
175
+ f"`--dtype {args.dtype}` is only available with `--api v2`"
176
+ )
177
+ trans(args)
178
+ else:
179
+ raise ValueError("Only pytorch is supported.")
180
+
181
+
182
+ if __name__ == "__main__":
183
+ main(sys.argv[1:])
espnet/bin/tts_decode.py ADDED
@@ -0,0 +1,180 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Nagoya University (Tomoki Hayashi)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ """TTS decoding script."""
7
+
8
+ import configargparse
9
+ import logging
10
+ import os
11
+ import platform
12
+ import subprocess
13
+ import sys
14
+
15
+ from espnet.utils.cli_utils import strtobool
16
+
17
+
18
+ # NOTE: you need this func to generate our sphinx doc
19
+ def get_parser():
20
+ """Get parser of decoding arguments."""
21
+ parser = configargparse.ArgumentParser(
22
+ description="Synthesize speech from text using a TTS model on one CPU",
23
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
24
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
25
+ )
26
+ # general configuration
27
+ parser.add("--config", is_config_file=True, help="config file path")
28
+ parser.add(
29
+ "--config2",
30
+ is_config_file=True,
31
+ help="second config file path that overwrites the settings in `--config`.",
32
+ )
33
+ parser.add(
34
+ "--config3",
35
+ is_config_file=True,
36
+ help="third config file path that overwrites "
37
+ "the settings in `--config` and `--config2`.",
38
+ )
39
+
40
+ parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
41
+ parser.add_argument(
42
+ "--backend",
43
+ default="pytorch",
44
+ type=str,
45
+ choices=["chainer", "pytorch"],
46
+ help="Backend library",
47
+ )
48
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
49
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
50
+ parser.add_argument("--out", type=str, required=True, help="Output filename")
51
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
52
+ parser.add_argument(
53
+ "--preprocess-conf",
54
+ type=str,
55
+ default=None,
56
+ help="The configuration file for the pre-processing",
57
+ )
58
+ # task related
59
+ parser.add_argument(
60
+ "--json", type=str, required=True, help="Filename of train label data (json)"
61
+ )
62
+ parser.add_argument(
63
+ "--model", type=str, required=True, help="Model file parameters to read"
64
+ )
65
+ parser.add_argument(
66
+ "--model-conf", type=str, default=None, help="Model config file"
67
+ )
68
+ # decoding related
69
+ parser.add_argument(
70
+ "--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding"
71
+ )
72
+ parser.add_argument(
73
+ "--minlenratio", type=float, default=0, help="Minimum length ratio in decoding"
74
+ )
75
+ parser.add_argument(
76
+ "--threshold", type=float, default=0.5, help="Threshold value in decoding"
77
+ )
78
+ parser.add_argument(
79
+ "--use-att-constraint",
80
+ type=strtobool,
81
+ default=False,
82
+ help="Whether to use the attention constraint",
83
+ )
84
+ parser.add_argument(
85
+ "--backward-window",
86
+ type=int,
87
+ default=1,
88
+ help="Backward window size in the attention constraint",
89
+ )
90
+ parser.add_argument(
91
+ "--forward-window",
92
+ type=int,
93
+ default=3,
94
+ help="Forward window size in the attention constraint",
95
+ )
96
+ parser.add_argument(
97
+ "--fastspeech-alpha",
98
+ type=float,
99
+ default=1.0,
100
+ help="Alpha to change the speed for FastSpeech",
101
+ )
102
+ # save related
103
+ parser.add_argument(
104
+ "--save-durations",
105
+ default=False,
106
+ type=strtobool,
107
+ help="Whether to save durations converted from attentions",
108
+ )
109
+ parser.add_argument(
110
+ "--save-focus-rates",
111
+ default=False,
112
+ type=strtobool,
113
+ help="Whether to save focus rates of attentions",
114
+ )
115
+ return parser
116
+
117
+
118
+ def main(args):
119
+ """Run decoding."""
120
+ parser = get_parser()
121
+ args = parser.parse_args(args)
122
+
123
+ # logging info
124
+ if args.verbose > 0:
125
+ logging.basicConfig(
126
+ level=logging.INFO,
127
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
128
+ )
129
+ else:
130
+ logging.basicConfig(
131
+ level=logging.WARN,
132
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
133
+ )
134
+ logging.warning("Skip DEBUG/INFO messages")
135
+
136
+ # check CUDA_VISIBLE_DEVICES
137
+ if args.ngpu > 0:
138
+ # python 2 case
139
+ if platform.python_version_tuple()[0] == "2":
140
+ if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
141
+ cvd = subprocess.check_output(
142
+ ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
143
+ ).strip()
144
+ logging.info("CLSP: use gpu" + cvd)
145
+ os.environ["CUDA_VISIBLE_DEVICES"] = cvd
146
+ # python 3 case
147
+ else:
148
+ if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode():
149
+ cvd = (
150
+ subprocess.check_output(
151
+ ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
152
+ )
153
+ .decode()
154
+ .strip()
155
+ )
156
+ logging.info("CLSP: use gpu" + cvd)
157
+ os.environ["CUDA_VISIBLE_DEVICES"] = cvd
158
+
159
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
160
+ if cvd is None:
161
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
162
+ elif args.ngpu != len(cvd.split(",")):
163
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
164
+ sys.exit(1)
165
+
166
+ # display PYTHONPATH
167
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
168
+
169
+ # extract
170
+ logging.info("backend = " + args.backend)
171
+ if args.backend == "pytorch":
172
+ from espnet.tts.pytorch_backend.tts import decode
173
+
174
+ decode(args)
175
+ else:
176
+ raise NotImplementedError("Only pytorch is supported.")
177
+
178
+
179
+ if __name__ == "__main__":
180
+ main(sys.argv[1:])
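The script above is a thin CLI wrapper: get_parser() defines the flags and main() dispatches to espnet.tts.pytorch_backend.tts.decode. A minimal sketch of driving it programmatically follows; it is not part of the commit, and the json/model/output paths are placeholders.

    from espnet.bin.tts_decode import main

    # Equivalent to invoking the script from a shell; only flags defined in
    # get_parser() above are used. All paths are hypothetical examples.
    main([
        "--json", "dump/test/data.json",        # decoding json (placeholder)
        "--model", "exp/tts/model.loss.best",   # trained model (placeholder)
        "--out", "exp/tts/outputs/feats",       # output filename (placeholder)
        "--verbose", "1",
    ])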
espnet/bin/tts_train.py ADDED
@@ -0,0 +1,359 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Nagoya University (Tomoki Hayashi)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ """Text-to-speech model training script."""
7
+
8
+ import logging
9
+ import os
10
+ import random
11
+ import subprocess
12
+ import sys
13
+
14
+ import configargparse
15
+ import numpy as np
16
+
17
+ from espnet import __version__
18
+ from espnet.nets.tts_interface import TTSInterface
19
+ from espnet.utils.cli_utils import strtobool
20
+ from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
21
+
22
+
23
+ # NOTE: you need this func to generate our sphinx doc
24
+ def get_parser():
25
+ """Get parser of training arguments."""
26
+ parser = configargparse.ArgumentParser(
27
+ description="Train a new text-to-speech (TTS) model on one CPU, "
28
+ "one or multiple GPUs",
29
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
30
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
31
+ )
32
+
33
+ # general configuration
34
+ parser.add("--config", is_config_file=True, help="config file path")
35
+ parser.add(
36
+ "--config2",
37
+ is_config_file=True,
38
+ help="second config file path that overwrites the settings in `--config`.",
39
+ )
40
+ parser.add(
41
+ "--config3",
42
+ is_config_file=True,
43
+ help="third config file path that overwrites "
44
+ "the settings in `--config` and `--config2`.",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--ngpu",
49
+ default=None,
50
+ type=int,
51
+ help="Number of GPUs. If not given, use all visible devices",
52
+ )
53
+ parser.add_argument(
54
+ "--backend",
55
+ default="pytorch",
56
+ type=str,
57
+ choices=["chainer", "pytorch"],
58
+ help="Backend library",
59
+ )
60
+ parser.add_argument("--outdir", type=str, required=True, help="Output directory")
61
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
62
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
63
+ parser.add_argument(
64
+ "--resume",
65
+ "-r",
66
+ default="",
67
+ type=str,
68
+ nargs="?",
69
+ help="Resume the training from snapshot",
70
+ )
71
+ parser.add_argument(
72
+ "--minibatches",
73
+ "-N",
74
+ type=int,
75
+ default="-1",
76
+ help="Process only N minibatches (for debug)",
77
+ )
78
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
79
+ parser.add_argument(
80
+ "--tensorboard-dir",
81
+ default=None,
82
+ type=str,
83
+ nargs="?",
84
+ help="Tensorboard log directory path",
85
+ )
86
+ parser.add_argument(
87
+ "--eval-interval-epochs", default=1, type=int, help="Evaluation interval epochs"
88
+ )
89
+ parser.add_argument(
90
+ "--save-interval-epochs", default=1, type=int, help="Save interval epochs"
91
+ )
92
+ parser.add_argument(
93
+ "--report-interval-iters",
94
+ default=100,
95
+ type=int,
96
+ help="Report interval iterations",
97
+ )
98
+ # task related
99
+ parser.add_argument(
100
+ "--train-json", type=str, required=True, help="Filename of training json"
101
+ )
102
+ parser.add_argument(
103
+ "--valid-json", type=str, required=True, help="Filename of validation json"
104
+ )
105
+ # network architecture
106
+ parser.add_argument(
107
+ "--model-module",
108
+ type=str,
109
+ default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
110
+ help="model defined module",
111
+ )
112
+ # minibatch related
113
+ parser.add_argument(
114
+ "--sortagrad",
115
+ default=0,
116
+ type=int,
117
+ nargs="?",
118
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
119
+ )
120
+ parser.add_argument(
121
+ "--batch-sort-key",
122
+ default="shuffle",
123
+ type=str,
124
+ choices=["shuffle", "output", "input"],
125
+ nargs="?",
126
+ help='Batch sorting key. "shuffle" only works with --batch-count "seq".',
127
+ )
128
+ parser.add_argument(
129
+ "--batch-count",
130
+ default="auto",
131
+ choices=BATCH_COUNT_CHOICES,
132
+ help="How to count batch_size. "
133
+ "The default (auto) will find how to count by args.",
134
+ )
135
+ parser.add_argument(
136
+ "--batch-size",
137
+ "--batch-seqs",
138
+ "-b",
139
+ default=0,
140
+ type=int,
141
+ help="Maximum seqs in a minibatch (0 to disable)",
142
+ )
143
+ parser.add_argument(
144
+ "--batch-bins",
145
+ default=0,
146
+ type=int,
147
+ help="Maximum bins in a minibatch (0 to disable)",
148
+ )
149
+ parser.add_argument(
150
+ "--batch-frames-in",
151
+ default=0,
152
+ type=int,
153
+ help="Maximum input frames in a minibatch (0 to disable)",
154
+ )
155
+ parser.add_argument(
156
+ "--batch-frames-out",
157
+ default=0,
158
+ type=int,
159
+ help="Maximum output frames in a minibatch (0 to disable)",
160
+ )
161
+ parser.add_argument(
162
+ "--batch-frames-inout",
163
+ default=0,
164
+ type=int,
165
+ help="Maximum input+output frames in a minibatch (0 to disable)",
166
+ )
167
+ parser.add_argument(
168
+ "--maxlen-in",
169
+ "--batch-seq-maxlen-in",
170
+ default=100,
171
+ type=int,
172
+ metavar="ML",
173
+ help="When --batch-count=seq, "
174
+ "batch size is reduced if the input sequence length > ML.",
175
+ )
176
+ parser.add_argument(
177
+ "--maxlen-out",
178
+ "--batch-seq-maxlen-out",
179
+ default=200,
180
+ type=int,
181
+ metavar="ML",
182
+ help="When --batch-count=seq, "
183
+ "batch size is reduced if the output sequence length > ML",
184
+ )
185
+ parser.add_argument(
186
+ "--num-iter-processes",
187
+ default=0,
188
+ type=int,
189
+ help="Number of processes of iterator",
190
+ )
191
+ parser.add_argument(
192
+ "--preprocess-conf",
193
+ type=str,
194
+ default=None,
195
+ help="The configuration file for the pre-processing",
196
+ )
197
+ parser.add_argument(
198
+ "--use-speaker-embedding",
199
+ default=False,
200
+ type=strtobool,
201
+ help="Whether to use speaker embedding",
202
+ )
203
+ parser.add_argument(
204
+ "--use-second-target",
205
+ default=False,
206
+ type=strtobool,
207
+ help="Whether to use second target",
208
+ )
209
+ # optimization related
210
+ parser.add_argument(
211
+ "--opt", default="adam", type=str, choices=["adam", "noam"], help="Optimizer"
212
+ )
213
+ parser.add_argument(
214
+ "--accum-grad", default=1, type=int, help="Number of gradient accumuration"
215
+ )
216
+ parser.add_argument(
217
+ "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
218
+ )
219
+ parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
220
+ parser.add_argument(
221
+ "--weight-decay",
222
+ default=1e-6,
223
+ type=float,
224
+ help="Weight decay coefficient for optimizer",
225
+ )
226
+ parser.add_argument(
227
+ "--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
228
+ )
229
+ parser.add_argument(
230
+ "--early-stop-criterion",
231
+ default="validation/main/loss",
232
+ type=str,
233
+ nargs="?",
234
+ help="Value to monitor to trigger an early stopping of the training",
235
+ )
236
+ parser.add_argument(
237
+ "--patience",
238
+ default=3,
239
+ type=int,
240
+ nargs="?",
241
+ help="Number of epochs to wait "
242
+ "without improvement before stopping the training",
243
+ )
244
+ parser.add_argument(
245
+ "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
246
+ )
247
+ parser.add_argument(
248
+ "--num-save-attention",
249
+ default=5,
250
+ type=int,
251
+ help="Number of samples of attention to be saved",
252
+ )
253
+ parser.add_argument(
254
+ "--keep-all-data-on-mem",
255
+ default=False,
256
+ type=strtobool,
257
+ help="Whether to keep all data on memory",
258
+ )
259
+ # finetuning related
260
+ parser.add_argument(
261
+ "--enc-init",
262
+ default=None,
263
+ type=str,
264
+ help="Pre-trained TTS model path to initialize encoder.",
265
+ )
266
+ parser.add_argument(
267
+ "--enc-init-mods",
268
+ default="enc.",
269
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
270
+ help="List of encoder modules to initialize, separated by a comma.",
271
+ )
272
+ parser.add_argument(
273
+ "--dec-init",
274
+ default=None,
275
+ type=str,
276
+ help="Pre-trained TTS model path to initialize decoder.",
277
+ )
278
+ parser.add_argument(
279
+ "--dec-init-mods",
280
+ default="dec.",
281
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
282
+ help="List of decoder modules to initialize, separated by a comma.",
283
+ )
284
+ parser.add_argument(
285
+ "--freeze-mods",
286
+ default=None,
287
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
288
+ help="List of modules to freeze (not to train), separated by a comma.",
289
+ )
290
+
291
+ return parser
292
+
293
+
294
+ def main(cmd_args):
295
+ """Run training."""
296
+ parser = get_parser()
297
+ args, _ = parser.parse_known_args(cmd_args)
298
+
299
+ from espnet.utils.dynamic_import import dynamic_import
300
+
301
+ model_class = dynamic_import(args.model_module)
302
+ assert issubclass(model_class, TTSInterface)
303
+ model_class.add_arguments(parser)
304
+ args = parser.parse_args(cmd_args)
305
+
306
+ # add version info in args
307
+ args.version = __version__
308
+
309
+ # logging info
310
+ if args.verbose > 0:
311
+ logging.basicConfig(
312
+ level=logging.INFO,
313
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
314
+ )
315
+ else:
316
+ logging.basicConfig(
317
+ level=logging.WARN,
318
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
319
+ )
320
+ logging.warning("Skip DEBUG/INFO messages")
321
+
322
+ # If --ngpu is not given,
323
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
324
+ # 2. if nvidia-smi exists, use all devices
325
+ # 3. else ngpu=0
326
+ if args.ngpu is None:
327
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
328
+ if cvd is not None:
329
+ ngpu = len(cvd.split(","))
330
+ else:
331
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
332
+ try:
333
+ p = subprocess.run(
334
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
335
+ )
336
+ except (subprocess.CalledProcessError, FileNotFoundError):
337
+ ngpu = 0
338
+ else:
339
+ ngpu = len(p.stdout.decode().split("\n")) - 1
340
+ args.ngpu = ngpu
341
+ else:
342
+ ngpu = args.ngpu
343
+ logging.info(f"ngpu: {ngpu}")
344
+
345
+ # set random seed
346
+ logging.info("random seed = %d" % args.seed)
347
+ random.seed(args.seed)
348
+ np.random.seed(args.seed)
349
+
350
+ if args.backend == "pytorch":
351
+ from espnet.tts.pytorch_backend.tts import train
352
+
353
+ train(args)
354
+ else:
355
+ raise NotImplementedError("Only pytorch is supported.")
356
+
357
+
358
+ if __name__ == "__main__":
359
+ main(sys.argv[1:])
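For reference, the training parser defined above can be exercised on its own; this is a minimal sketch, not part of the commit, and the directory/json paths are placeholders.

    from espnet.bin.tts_train import get_parser

    parser = get_parser()
    # parse_known_args() mirrors main(): model-specific flags are only added
    # later via model_class.add_arguments(parser) before the final parse_args().
    args, _ = parser.parse_known_args([
        "--outdir", "exp/tts_demo",               # placeholder output directory
        "--train-json", "dump/train/data.json",   # placeholder training json
        "--valid-json", "dump/dev/data.json",     # placeholder validation json
    ])
    print(args.model_module, args.batch_count, args.epochs)  # defaults from above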
espnet/bin/vc_decode.py ADDED
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 Nagoya University (Wen-Chin Huang)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ """VC decoding script."""
7
+
8
+ import configargparse
9
+ import logging
10
+ import os
11
+ import platform
12
+ import subprocess
13
+ import sys
14
+
15
+ from espnet.utils.cli_utils import strtobool
16
+
17
+
18
+ # NOTE: you need this func to generate our sphinx doc
19
+ def get_parser():
20
+ """Get parser of decoding arguments."""
21
+ parser = configargparse.ArgumentParser(
22
+ description="Converting speech using a VC model on one CPU",
23
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
24
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
25
+ )
26
+ # general configuration
27
+ parser.add("--config", is_config_file=True, help="config file path")
28
+ parser.add(
29
+ "--config2",
30
+ is_config_file=True,
31
+ help="second config file path that overwrites the settings in `--config`.",
32
+ )
33
+ parser.add(
34
+ "--config3",
35
+ is_config_file=True,
36
+ help="third config file path that overwrites the settings "
37
+ "in `--config` and `--config2`.",
38
+ )
39
+
40
+ parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
41
+ parser.add_argument(
42
+ "--backend",
43
+ default="pytorch",
44
+ type=str,
45
+ choices=["chainer", "pytorch"],
46
+ help="Backend library",
47
+ )
48
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
49
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
50
+ parser.add_argument("--out", type=str, required=True, help="Output filename")
51
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
52
+ parser.add_argument(
53
+ "--preprocess-conf",
54
+ type=str,
55
+ default=None,
56
+ help="The configuration file for the pre-processing",
57
+ )
58
+ # task related
59
+ parser.add_argument(
60
+ "--json", type=str, required=True, help="Filename of train label data (json)"
61
+ )
62
+ parser.add_argument(
63
+ "--model", type=str, required=True, help="Model file parameters to read"
64
+ )
65
+ parser.add_argument(
66
+ "--model-conf", type=str, default=None, help="Model config file"
67
+ )
68
+ # decoding related
69
+ parser.add_argument(
70
+ "--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding"
71
+ )
72
+ parser.add_argument(
73
+ "--minlenratio", type=float, default=0, help="Minimum length ratio in decoding"
74
+ )
75
+ parser.add_argument(
76
+ "--threshold", type=float, default=0.5, help="Threshold value in decoding"
77
+ )
78
+ parser.add_argument(
79
+ "--use-att-constraint",
80
+ type=strtobool,
81
+ default=False,
82
+ help="Whether to use the attention constraint",
83
+ )
84
+ parser.add_argument(
85
+ "--backward-window",
86
+ type=int,
87
+ default=1,
88
+ help="Backward window size in the attention constraint",
89
+ )
90
+ parser.add_argument(
91
+ "--forward-window",
92
+ type=int,
93
+ default=3,
94
+ help="Forward window size in the attention constraint",
95
+ )
96
+ # save related
97
+ parser.add_argument(
98
+ "--save-durations",
99
+ default=False,
100
+ type=strtobool,
101
+ help="Whether to save durations converted from attentions",
102
+ )
103
+ parser.add_argument(
104
+ "--save-focus-rates",
105
+ default=False,
106
+ type=strtobool,
107
+ help="Whether to save focus rates of attentions",
108
+ )
109
+ return parser
110
+
111
+
112
+ def main(args):
113
+ """Run deocding."""
114
+ parser = get_parser()
115
+ args = parser.parse_args(args)
116
+
117
+ # logging info
118
+ if args.verbose > 0:
119
+ logging.basicConfig(
120
+ level=logging.INFO,
121
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
122
+ )
123
+ else:
124
+ logging.basicConfig(
125
+ level=logging.WARN,
126
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
127
+ )
128
+ logging.warning("Skip DEBUG/INFO messages")
129
+
130
+ # check CUDA_VISIBLE_DEVICES
131
+ if args.ngpu > 0:
132
+ # python 2 case
133
+ if platform.python_version_tuple()[0] == "2":
134
+ if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
135
+ cvd = subprocess.check_output(
136
+ ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
137
+ ).strip()
138
+ logging.info("CLSP: use gpu" + cvd)
139
+ os.environ["CUDA_VISIBLE_DEVICES"] = cvd
140
+ # python 3 case
141
+ else:
142
+ if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode():
143
+ cvd = (
144
+ subprocess.check_output(
145
+ ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]
146
+ )
147
+ .decode()
148
+ .strip()
149
+ )
150
+ logging.info("CLSP: use gpu" + cvd)
151
+ os.environ["CUDA_VISIBLE_DEVICES"] = cvd
152
+
153
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
154
+ if cvd is None:
155
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
156
+ elif args.ngpu != len(cvd.split(",")):
157
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
158
+ sys.exit(1)
159
+
160
+ # display PYTHONPATH
161
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
162
+
163
+ # extract
164
+ logging.info("backend = " + args.backend)
165
+ if args.backend == "pytorch":
166
+ from espnet.vc.pytorch_backend.vc import decode
167
+
168
+ decode(args)
169
+ else:
170
+ raise NotImplementedError("Only pytorch is supported.")
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main(sys.argv[1:])
espnet/bin/vc_train.py ADDED
@@ -0,0 +1,368 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 Nagoya University (Wen-Chin Huang)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ """Voice conversion model training script."""
7
+
8
+ import logging
9
+ import os
10
+ import random
11
+ import subprocess
12
+ import sys
13
+
14
+ import configargparse
15
+ import numpy as np
16
+
17
+ from espnet import __version__
18
+ from espnet.nets.tts_interface import TTSInterface
19
+ from espnet.utils.cli_utils import strtobool
20
+ from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
21
+
22
+
23
+ # NOTE: you need this func to generate our sphinx doc
24
+ def get_parser():
25
+ """Get parser of training arguments."""
26
+ parser = configargparse.ArgumentParser(
27
+ description="Train a new voice conversion (VC) model on one CPU, "
28
+ "one or multiple GPUs",
29
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
30
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
31
+ )
32
+
33
+ # general configuration
34
+ parser.add("--config", is_config_file=True, help="config file path")
35
+ parser.add(
36
+ "--config2",
37
+ is_config_file=True,
38
+ help="second config file path that overwrites the settings in `--config`.",
39
+ )
40
+ parser.add(
41
+ "--config3",
42
+ is_config_file=True,
43
+ help="third config file path that overwrites the settings "
44
+ "in `--config` and `--config2`.",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--ngpu",
49
+ default=None,
50
+ type=int,
51
+ help="Number of GPUs. If not given, use all visible devices",
52
+ )
53
+ parser.add_argument(
54
+ "--backend",
55
+ default="pytorch",
56
+ type=str,
57
+ choices=["chainer", "pytorch"],
58
+ help="Backend library",
59
+ )
60
+ parser.add_argument("--outdir", type=str, required=True, help="Output directory")
61
+ parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
62
+ parser.add_argument("--seed", default=1, type=int, help="Random seed")
63
+ parser.add_argument(
64
+ "--resume",
65
+ "-r",
66
+ default="",
67
+ type=str,
68
+ nargs="?",
69
+ help="Resume the training from snapshot",
70
+ )
71
+ parser.add_argument(
72
+ "--minibatches",
73
+ "-N",
74
+ type=int,
75
+ default="-1",
76
+ help="Process only N minibatches (for debug)",
77
+ )
78
+ parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
79
+ parser.add_argument(
80
+ "--tensorboard-dir",
81
+ default=None,
82
+ type=str,
83
+ nargs="?",
84
+ help="Tensorboard log directory path",
85
+ )
86
+ parser.add_argument(
87
+ "--eval-interval-epochs",
88
+ default=100,
89
+ type=int,
90
+ help="Evaluation interval epochs",
91
+ )
92
+ parser.add_argument(
93
+ "--save-interval-epochs", default=1, type=int, help="Save interval epochs"
94
+ )
95
+ parser.add_argument(
96
+ "--report-interval-iters",
97
+ default=10,
98
+ type=int,
99
+ help="Report interval iterations",
100
+ )
101
+ # task related
102
+ parser.add_argument("--srcspk", type=str, help="Source speaker")
103
+ parser.add_argument("--trgspk", type=str, help="Target speaker")
104
+ parser.add_argument(
105
+ "--train-json", type=str, required=True, help="Filename of training json"
106
+ )
107
+ parser.add_argument(
108
+ "--valid-json", type=str, required=True, help="Filename of validation json"
109
+ )
110
+
111
+ # network architecture
112
+ parser.add_argument(
113
+ "--model-module",
114
+ type=str,
115
+ default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
116
+ help="model defined module",
117
+ )
118
+ # minibatch related
119
+ parser.add_argument(
120
+ "--sortagrad",
121
+ default=0,
122
+ type=int,
123
+ nargs="?",
124
+ help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
125
+ )
126
+ parser.add_argument(
127
+ "--batch-sort-key",
128
+ default="shuffle",
129
+ type=str,
130
+ choices=["shuffle", "output", "input"],
131
+ nargs="?",
132
+ help='Batch sorting key. "shuffle" only works with --batch-count "seq".',
133
+ )
134
+ parser.add_argument(
135
+ "--batch-count",
136
+ default="auto",
137
+ choices=BATCH_COUNT_CHOICES,
138
+ help="How to count batch_size. "
139
+ "The default (auto) will find how to count by args.",
140
+ )
141
+ parser.add_argument(
142
+ "--batch-size",
143
+ "--batch-seqs",
144
+ "-b",
145
+ default=0,
146
+ type=int,
147
+ help="Maximum seqs in a minibatch (0 to disable)",
148
+ )
149
+ parser.add_argument(
150
+ "--batch-bins",
151
+ default=0,
152
+ type=int,
153
+ help="Maximum bins in a minibatch (0 to disable)",
154
+ )
155
+ parser.add_argument(
156
+ "--batch-frames-in",
157
+ default=0,
158
+ type=int,
159
+ help="Maximum input frames in a minibatch (0 to disable)",
160
+ )
161
+ parser.add_argument(
162
+ "--batch-frames-out",
163
+ default=0,
164
+ type=int,
165
+ help="Maximum output frames in a minibatch (0 to disable)",
166
+ )
167
+ parser.add_argument(
168
+ "--batch-frames-inout",
169
+ default=0,
170
+ type=int,
171
+ help="Maximum input+output frames in a minibatch (0 to disable)",
172
+ )
173
+ parser.add_argument(
174
+ "--maxlen-in",
175
+ "--batch-seq-maxlen-in",
176
+ default=100,
177
+ type=int,
178
+ metavar="ML",
179
+ help="When --batch-count=seq, "
180
+ "batch size is reduced if the input sequence length > ML.",
181
+ )
182
+ parser.add_argument(
183
+ "--maxlen-out",
184
+ "--batch-seq-maxlen-out",
185
+ default=200,
186
+ type=int,
187
+ metavar="ML",
188
+ help="When --batch-count=seq, "
189
+ "batch size is reduced if the output sequence length > ML",
190
+ )
191
+ parser.add_argument(
192
+ "--num-iter-processes",
193
+ default=0,
194
+ type=int,
195
+ help="Number of processes of iterator",
196
+ )
197
+ parser.add_argument(
198
+ "--preprocess-conf",
199
+ type=str,
200
+ default=None,
201
+ help="The configuration file for the pre-processing",
202
+ )
203
+ parser.add_argument(
204
+ "--use-speaker-embedding",
205
+ default=False,
206
+ type=strtobool,
207
+ help="Whether to use speaker embedding",
208
+ )
209
+ parser.add_argument(
210
+ "--use-second-target",
211
+ default=False,
212
+ type=strtobool,
213
+ help="Whether to use second target",
214
+ )
215
+ # optimization related
216
+ parser.add_argument(
217
+ "--opt",
218
+ default="adam",
219
+ type=str,
220
+ choices=["adam", "noam", "lamb"],
221
+ help="Optimizer",
222
+ )
223
+ parser.add_argument(
224
+ "--accum-grad", default=1, type=int, help="Number of gradient accumuration"
225
+ )
226
+ parser.add_argument(
227
+ "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
228
+ )
229
+ parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
230
+ parser.add_argument(
231
+ "--weight-decay",
232
+ default=1e-6,
233
+ type=float,
234
+ help="Weight decay coefficient for optimizer",
235
+ )
236
+ parser.add_argument(
237
+ "--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
238
+ )
239
+ parser.add_argument(
240
+ "--early-stop-criterion",
241
+ default="validation/main/loss",
242
+ type=str,
243
+ nargs="?",
244
+ help="Value to monitor to trigger an early stopping of the training",
245
+ )
246
+ parser.add_argument(
247
+ "--patience",
248
+ default=3,
249
+ type=int,
250
+ nargs="?",
251
+ help="Number of epochs to wait without improvement "
252
+ "before stopping the training",
253
+ )
254
+ parser.add_argument(
255
+ "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
256
+ )
257
+ parser.add_argument(
258
+ "--num-save-attention",
259
+ default=5,
260
+ type=int,
261
+ help="Number of samples of attention to be saved",
262
+ )
263
+ parser.add_argument(
264
+ "--keep-all-data-on-mem",
265
+ default=False,
266
+ type=strtobool,
267
+ help="Whether to keep all data on memory",
268
+ )
269
+
270
+ parser.add_argument(
271
+ "--enc-init",
272
+ default=None,
273
+ type=str,
274
+ help="Pre-trained model path to initialize encoder.",
275
+ )
276
+ parser.add_argument(
277
+ "--enc-init-mods",
278
+ default="enc.",
279
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
280
+ help="List of encoder modules to initialize, separated by a comma.",
281
+ )
282
+ parser.add_argument(
283
+ "--dec-init",
284
+ default=None,
285
+ type=str,
286
+ help="Pre-trained model path to initialize decoder.",
287
+ )
288
+ parser.add_argument(
289
+ "--dec-init-mods",
290
+ default="dec.",
291
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
292
+ help="List of decoder modules to initialize, separated by a comma.",
293
+ )
294
+ parser.add_argument(
295
+ "--freeze-mods",
296
+ default=None,
297
+ type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
298
+ help="List of modules to freeze (not to train), separated by a comma.",
299
+ )
300
+
301
+ return parser
302
+
303
+
304
+ def main(cmd_args):
305
+ """Run training."""
306
+ parser = get_parser()
307
+ args, _ = parser.parse_known_args(cmd_args)
308
+
309
+ from espnet.utils.dynamic_import import dynamic_import
310
+
311
+ model_class = dynamic_import(args.model_module)
312
+ assert issubclass(model_class, TTSInterface)
313
+ model_class.add_arguments(parser)
314
+ args = parser.parse_args(cmd_args)
315
+
316
+ # add version info in args
317
+ args.version = __version__
318
+
319
+ # logging info
320
+ if args.verbose > 0:
321
+ logging.basicConfig(
322
+ level=logging.INFO,
323
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
324
+ )
325
+ else:
326
+ logging.basicConfig(
327
+ level=logging.WARN,
328
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
329
+ )
330
+ logging.warning("Skip DEBUG/INFO messages")
331
+
332
+ # If --ngpu is not given,
333
+ # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
334
+ # 2. if nvidia-smi exists, use all devices
335
+ # 3. else ngpu=0
336
+ if args.ngpu is None:
337
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
338
+ if cvd is not None:
339
+ ngpu = len(cvd.split(","))
340
+ else:
341
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
342
+ try:
343
+ p = subprocess.run(
344
+ ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
345
+ )
346
+ except (subprocess.CalledProcessError, FileNotFoundError):
347
+ ngpu = 0
348
+ else:
349
+ ngpu = len(p.stdout.decode().split("\n")) - 1
350
+ else:
351
+ ngpu = args.ngpu
352
+ logging.info(f"ngpu: {ngpu}")
353
+
354
+ # set random seed
355
+ logging.info("random seed = %d" % args.seed)
356
+ random.seed(args.seed)
357
+ np.random.seed(args.seed)
358
+
359
+ if args.backend == "pytorch":
360
+ from espnet.vc.pytorch_backend.vc import train
361
+
362
+ train(args)
363
+ else:
364
+ raise NotImplementedError("Only pytorch is supported.")
365
+
366
+
367
+ if __name__ == "__main__":
368
+ main(sys.argv[1:])
espnet/lm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/lm/chainer_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/lm/chainer_backend/extlm.py ADDED
@@ -0,0 +1,199 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+
7
+ import math
8
+
9
+ import chainer
10
+ import chainer.functions as F
11
+ from espnet.lm.lm_utils import make_lexical_tree
12
+
13
+
14
+ # Definition of a multi-level (subword/word) language model
15
+ class MultiLevelLM(chainer.Chain):
16
+ logzero = -10000000000.0
17
+ zero = 1.0e-10
18
+
19
+ def __init__(
20
+ self,
21
+ wordlm,
22
+ subwordlm,
23
+ word_dict,
24
+ subword_dict,
25
+ subwordlm_weight=0.8,
26
+ oov_penalty=1.0,
27
+ open_vocab=True,
28
+ ):
29
+ super(MultiLevelLM, self).__init__()
30
+ self.wordlm = wordlm
31
+ self.subwordlm = subwordlm
32
+ self.word_eos = word_dict["<eos>"]
33
+ self.word_unk = word_dict["<unk>"]
34
+ self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
35
+ self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
36
+ self.space = subword_dict["<space>"]
37
+ self.eos = subword_dict["<eos>"]
38
+ self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
39
+ self.log_oov_penalty = math.log(oov_penalty)
40
+ self.open_vocab = open_vocab
41
+ self.subword_dict_size = len(subword_dict)
42
+ self.subwordlm_weight = subwordlm_weight
43
+ self.normalized = True
44
+
45
+ def __call__(self, state, x):
46
+ # update state with input label x
47
+ if state is None: # make initial states and log-prob vectors
48
+ wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
49
+ wlm_logprobs = F.log_softmax(z_wlm).data
50
+ clm_state, z_clm = self.subwordlm(None, x)
51
+ log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
52
+ new_node = self.lexroot
53
+ clm_logprob = 0.0
54
+ xi = self.space
55
+ else:
56
+ clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
57
+ xi = int(x)
58
+ if xi == self.space: # inter-word transition
59
+ if node is not None and node[1] >= 0: # check if the node is word end
60
+ w = self.xp.full(1, node[1], "i")
61
+ else: # this node is not a word end, which means <unk>
62
+ w = self.xp_word_unk
63
+ # update wordlm state and log-prob vector
64
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
65
+ wlm_logprobs = F.log_softmax(z_wlm).data
66
+ new_node = self.lexroot # move to the tree root
67
+ clm_logprob = 0.0
68
+ elif node is not None and xi in node[0]: # intra-word transition
69
+ new_node = node[0][xi]
70
+ clm_logprob += log_y[0, xi]
71
+ elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
72
+ new_node = None
73
+ clm_logprob += log_y[0, xi]
74
+ else: # if open_vocab flag is disabled, return 0 probabilities
75
+ log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
76
+ return (clm_state, wlm_state, None, log_y, 0.0), log_y
77
+
78
+ clm_state, z_clm = self.subwordlm(clm_state, x)
79
+ log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
80
+
81
+ # apply word-level probabilities for <space> and <eos> labels
82
+ if xi != self.space:
83
+ if new_node is not None and new_node[1] >= 0: # if new node is word end
84
+ wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
85
+ else:
86
+ wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
87
+ log_y[:, self.space] = wlm_logprob
88
+ log_y[:, self.eos] = wlm_logprob
89
+ else:
90
+ log_y[:, self.space] = self.logzero
91
+ log_y[:, self.eos] = self.logzero
92
+
93
+ return (clm_state, wlm_state, wlm_logprobs, new_node, log_y, clm_logprob), log_y
94
+
95
+ def final(self, state):
96
+ clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
97
+ if node is not None and node[1] >= 0: # check if the node is word end
98
+ w = self.xp.full(1, node[1], "i")
99
+ else: # this node is not a word end, which means <unk>
100
+ w = self.xp_word_unk
101
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
102
+ return F.log_softmax(z_wlm).data[:, self.word_eos]
103
+
104
+
105
+ # Definition of a look-ahead word language model
106
+ class LookAheadWordLM(chainer.Chain):
107
+ logzero = -10000000000.0
108
+ zero = 1.0e-10
109
+
110
+ def __init__(
111
+ self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
112
+ ):
113
+ super(LookAheadWordLM, self).__init__()
114
+ self.wordlm = wordlm
115
+ self.word_eos = word_dict["<eos>"]
116
+ self.word_unk = word_dict["<unk>"]
117
+ self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
118
+ self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
119
+ self.space = subword_dict["<space>"]
120
+ self.eos = subword_dict["<eos>"]
121
+ self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
122
+ self.oov_penalty = oov_penalty
123
+ self.open_vocab = open_vocab
124
+ self.subword_dict_size = len(subword_dict)
125
+ self.normalized = True
126
+
127
+ def __call__(self, state, x):
128
+ # update state with input label x
129
+ if state is None: # make initial states and cumulative probability vector
130
+ wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
131
+ cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
132
+ new_node = self.lexroot
133
+ xi = self.space
134
+ else:
135
+ wlm_state, cumsum_probs, node = state
136
+ xi = int(x)
137
+ if xi == self.space: # inter-word transition
138
+ if node is not None and node[1] >= 0: # check if the node is word end
139
+ w = self.xp.full(1, node[1], "i")
140
+ else: # this node is not a word end, which means <unk>
141
+ w = self.xp_word_unk
142
+ # update wordlm state and cumulative probability vector
143
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
144
+ cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
145
+ new_node = self.lexroot # move to the tree root
146
+ elif node is not None and xi in node[0]: # intra-word transition
147
+ new_node = node[0][xi]
148
+ elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
149
+ new_node = None
150
+ else: # if open_vocab flag is disabled, return 0 probabilities
151
+ log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
152
+ return (wlm_state, None, None), log_y
153
+
154
+ if new_node is not None:
155
+ succ, wid, wids = new_node
156
+ # compute parent node probability
157
+ sum_prob = (
158
+ (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
159
+ if wids is not None
160
+ else 1.0
161
+ )
162
+ if sum_prob < self.zero:
163
+ log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
164
+ return (wlm_state, cumsum_probs, new_node), log_y
165
+ # set <unk> probability as a default value
166
+ unk_prob = (
167
+ cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
168
+ )
169
+ y = self.xp.full(
170
+ (1, self.subword_dict_size), unk_prob * self.oov_penalty, "f"
171
+ )
172
+ # compute transition probabilities to child nodes
173
+ for cid, nd in succ.items():
174
+ y[:, cid] = (
175
+ cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
176
+ ) / sum_prob
177
+ # apply word-level probabilities for <space> and <eos> labels
178
+ if wid >= 0:
179
+ wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
180
+ y[:, self.space] = wlm_prob
181
+ y[:, self.eos] = wlm_prob
182
+ elif xi == self.space:
183
+ y[:, self.space] = self.zero
184
+ y[:, self.eos] = self.zero
185
+ log_y = self.xp.log(
186
+ self.xp.clip(y, self.zero, None)
187
+ ) # clip to avoid log(0)
188
+ else: # if no path in the tree, transition probability is one
189
+ log_y = self.xp.zeros((1, self.subword_dict_size), "f")
190
+ return (wlm_state, cumsum_probs, new_node), log_y
191
+
192
+ def final(self, state):
193
+ wlm_state, cumsum_probs, node = state
194
+ if node is not None and node[1] >= 0: # check if the node is word end
195
+ w = self.xp.full(1, node[1], "i")
196
+ else: # this node is not a word end, which means <unk>
197
+ w = self.xp_word_unk
198
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
199
+ return F.log_softmax(z_wlm).data[:, self.word_eos]
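Both classes above walk the lexical prefix tree returned by make_lexical_tree (defined in espnet/lm/lm_utils.py, not shown in this excerpt). From the way nodes are unpacked (succ, wid, wids = new_node), a node holds a successor map keyed by subword id, a word id that is non-negative only at word ends, and an optional index pair used to slice the word-level cumulative probability vector. The hand-built toy structure below only illustrates that layout; the real tree comes from make_lexical_tree and all ids are arbitrary placeholders.

    # Illustration of the node layout consumed by MultiLevelLM / LookAheadWordLM:
    # (successors, word_id_or_negative, cumsum_index_pair_or_None).
    leaf_ab = ({}, 5, (4, 5))            # ends the word with (placeholder) id 5
    node_a = ({2: leaf_ab}, 4, (3, 5))   # subword id 2 continues the prefix to "ab"
    lexroot = ({1: node_a}, -1, None)    # the root is not a word end

    # Intra-word transition as in __call__: follow the child for subword id 1.
    child = lexroot[0][1]
    assert child[1] >= 0                 # "a" itself ends a word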
espnet/lm/chainer_backend/lm.py ADDED
@@ -0,0 +1,484 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ # This code is ported from the following implementation written in Torch.
7
+ # https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
8
+
9
+
10
+ import copy
11
+ import json
12
+ import logging
13
+ import numpy as np
14
+ import six
15
+
16
+ import chainer
17
+ from chainer.dataset import convert
18
+ import chainer.functions as F
19
+ import chainer.links as L
20
+
21
+ # for classifier link
22
+ from chainer.functions.loss import softmax_cross_entropy
23
+ from chainer import link
24
+ from chainer import reporter
25
+ from chainer import training
26
+ from chainer.training import extensions
27
+
28
+ from espnet.lm.lm_utils import compute_perplexity
29
+ from espnet.lm.lm_utils import count_tokens
30
+ from espnet.lm.lm_utils import MakeSymlinkToBestModel
31
+ from espnet.lm.lm_utils import ParallelSentenceIterator
32
+ from espnet.lm.lm_utils import read_tokens
33
+
34
+ import espnet.nets.chainer_backend.deterministic_embed_id as DL
35
+ from espnet.nets.lm_interface import LMInterface
36
+ from espnet.optimizer.factory import dynamic_import_optimizer
37
+ from espnet.scheduler.chainer import ChainerScheduler
38
+ from espnet.scheduler.scheduler import dynamic_import_scheduler
39
+
40
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
41
+ from tensorboardX import SummaryWriter
42
+
43
+ from espnet.utils.deterministic_utils import set_deterministic_chainer
44
+ from espnet.utils.training.evaluator import BaseEvaluator
45
+ from espnet.utils.training.iterators import ShufflingEnabler
46
+ from espnet.utils.training.train_utils import check_early_stop
47
+ from espnet.utils.training.train_utils import set_early_stop
48
+
49
+
50
+ # TODO(karita): reimplement RNNLM with new interface
51
+ class DefaultRNNLM(LMInterface, link.Chain):
52
+ """Default RNNLM wrapper to compute reduce framewise loss values.
53
+
54
+ Args:
55
+ n_vocab (int): The size of the vocabulary
56
+ args (argparse.Namespace): configurations. see `add_arguments`
57
+ """
58
+
59
+ @staticmethod
60
+ def add_arguments(parser):
61
+ parser.add_argument(
62
+ "--type",
63
+ type=str,
64
+ default="lstm",
65
+ nargs="?",
66
+ choices=["lstm", "gru"],
67
+ help="Which type of RNN to use",
68
+ )
69
+ parser.add_argument(
70
+ "--layer", "-l", type=int, default=2, help="Number of hidden layers"
71
+ )
72
+ parser.add_argument(
73
+ "--unit", "-u", type=int, default=650, help="Number of hidden units"
74
+ )
75
+ return parser
76
+
77
+
78
+ class ClassifierWithState(link.Chain):
79
+ """A wrapper for a chainer RNNLM
80
+
81
+ :param link.Chain predictor : The RNNLM
82
+ :param function lossfun: The loss function to use
83
+ :param int/str label_key:
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ predictor,
89
+ lossfun=softmax_cross_entropy.softmax_cross_entropy,
90
+ label_key=-1,
91
+ ):
92
+ if not (isinstance(label_key, (int, str))):
93
+ raise TypeError("label_key must be int or str, but is %s" % type(label_key))
94
+
95
+ super(ClassifierWithState, self).__init__()
96
+ self.lossfun = lossfun
97
+ self.y = None
98
+ self.loss = None
99
+ self.label_key = label_key
100
+
101
+ with self.init_scope():
102
+ self.predictor = predictor
103
+
104
+ def __call__(self, state, *args, **kwargs):
105
+ """Computes the loss value for an input and label pair.
106
+
107
+ It also computes accuracy and stores it to the attribute.
108
+ When ``label_key`` is ``int``, the corresponding element in ``args``
109
+ is treated as ground truth labels. And when it is ``str``, the
110
+ element in ``kwargs`` is used.
111
+ The all elements of ``args`` and ``kwargs`` except the groundtruth
112
+ labels are features.
113
+ It feeds features to the predictor and compare the result
114
+ with ground truth labels.
115
+
116
+ :param state : The LM state
117
+ :param list[chainer.Variable] args : Input minibatch
118
+ :param dict[chainer.Variable] kwargs : Input minibatch
119
+ :return loss value
120
+ :rtype chainer.Variable
121
+ """
122
+
123
+ if isinstance(self.label_key, int):
124
+ if not (-len(args) <= self.label_key < len(args)):
125
+ msg = "Label key %d is out of bounds" % self.label_key
126
+ raise ValueError(msg)
127
+ t = args[self.label_key]
128
+ if self.label_key == -1:
129
+ args = args[:-1]
130
+ else:
131
+ args = args[: self.label_key] + args[self.label_key + 1 :]
132
+ elif isinstance(self.label_key, str):
133
+ if self.label_key not in kwargs:
134
+ msg = 'Label key "%s" is not found' % self.label_key
135
+ raise ValueError(msg)
136
+ t = kwargs[self.label_key]
137
+ del kwargs[self.label_key]
138
+
139
+ self.y = None
140
+ self.loss = None
141
+ state, self.y = self.predictor(state, *args, **kwargs)
142
+ self.loss = self.lossfun(self.y, t)
143
+ return state, self.loss
144
+
145
+ def predict(self, state, x):
146
+ """Predict log probabilities for given state and input x using the predictor
147
+
148
+ :param state : the state
149
+ :param x : the input
150
+ :return a tuple (state, log prob vector)
151
+ :rtype cupy/numpy array
152
+ """
153
+ if hasattr(self.predictor, "normalized") and self.predictor.normalized:
154
+ return self.predictor(state, x)
155
+ else:
156
+ state, z = self.predictor(state, x)
157
+ return state, F.log_softmax(z).data
158
+
159
+ def final(self, state):
160
+ """Predict final log probabilities for given state using the predictor
161
+
162
+ :param state : the state
163
+ :return log probability vector
164
+ :rtype cupy/numpy array
165
+
166
+ """
167
+ if hasattr(self.predictor, "final"):
168
+ return self.predictor.final(state)
169
+ else:
170
+ return 0.0
171
+
172
+
173
+ # Definition of a recurrent net for language modeling
174
+ class RNNLM(chainer.Chain):
175
+ """A chainer RNNLM
176
+
177
+ :param int n_vocab: The size of the vocabulary
178
+ :param int n_layers: The number of layers to create
179
+ :param int n_units: The number of units per layer
180
+ :param str typ: The RNN type
181
+ """
182
+
183
+ def __init__(self, n_vocab, n_layers, n_units, typ="lstm"):
184
+ super(RNNLM, self).__init__()
185
+ with self.init_scope():
186
+ self.embed = DL.EmbedID(n_vocab, n_units)
187
+ self.rnn = (
188
+ chainer.ChainList(
189
+ *[L.StatelessLSTM(n_units, n_units) for _ in range(n_layers)]
190
+ )
191
+ if typ == "lstm"
192
+ else chainer.ChainList(
193
+ *[L.StatelessGRU(n_units, n_units) for _ in range(n_layers)]
194
+ )
195
+ )
196
+ self.lo = L.Linear(n_units, n_vocab)
197
+
198
+ for param in self.params():
199
+ param.data[...] = np.random.uniform(-0.1, 0.1, param.data.shape)
200
+ self.n_layers = n_layers
201
+ self.n_units = n_units
202
+ self.typ = typ
203
+
204
+ def __call__(self, state, x):
205
+ if state is None:
206
+ if self.typ == "lstm":
207
+ state = {"c": [None] * self.n_layers, "h": [None] * self.n_layers}
208
+ else:
209
+ state = {"h": [None] * self.n_layers}
210
+
211
+ h = [None] * self.n_layers
212
+ emb = self.embed(x)
213
+ if self.typ == "lstm":
214
+ c = [None] * self.n_layers
215
+ c[0], h[0] = self.rnn[0](state["c"][0], state["h"][0], F.dropout(emb))
216
+ for n in six.moves.range(1, self.n_layers):
217
+ c[n], h[n] = self.rnn[n](
218
+ state["c"][n], state["h"][n], F.dropout(h[n - 1])
219
+ )
220
+ state = {"c": c, "h": h}
221
+ else:
222
+ if state["h"][0] is None:
223
+ xp = self.xp
224
+ with chainer.backends.cuda.get_device_from_id(self._device_id):
225
+ state["h"][0] = chainer.Variable(
226
+ xp.zeros((emb.shape[0], self.n_units), dtype=emb.dtype)
227
+ )
228
+ h[0] = self.rnn[0](state["h"][0], F.dropout(emb))
229
+ for n in six.moves.range(1, self.n_layers):
230
+ if state["h"][n] is None:
231
+ xp = self.xp
232
+ with chainer.backends.cuda.get_device_from_id(self._device_id):
233
+ state["h"][n] = chainer.Variable(
234
+ xp.zeros(
235
+ (h[n - 1].shape[0], self.n_units), dtype=h[n - 1].dtype
236
+ )
237
+ )
238
+ h[n] = self.rnn[n](state["h"][n], F.dropout(h[n - 1]))
239
+ state = {"h": h}
240
+ y = self.lo(F.dropout(h[-1]))
241
+ return state, y
242
+
243
+
244
+ class BPTTUpdater(training.updaters.StandardUpdater):
245
+ """An updater for a chainer LM
246
+
247
+ :param chainer.dataset.Iterator train_iter : The train iterator
248
+ :param optimizer: The optimizer to use
249
+ :param schedulers: The list of learning rate schedulers
250
+ :param int device : The device id
251
+ :param int accum_grad : The number of gradient accumulation steps
252
+ """
253
+
254
+ def __init__(self, train_iter, optimizer, schedulers, device, accum_grad):
255
+ super(BPTTUpdater, self).__init__(train_iter, optimizer, device=device)
256
+ self.scheduler = ChainerScheduler(schedulers, optimizer)
257
+ self.accum_grad = accum_grad
258
+
259
+ # The core part of the update routine can be customized by overriding.
260
+ def update_core(self):
261
+ # When we pass one iterator and optimizer to StandardUpdater.__init__,
262
+ # they are automatically named 'main'.
263
+ train_iter = self.get_iterator("main")
264
+ optimizer = self.get_optimizer("main")
265
+
266
+ count = 0
267
+ sum_loss = 0
268
+ optimizer.target.cleargrads() # Clear the parameter gradients
269
+ for _ in range(self.accum_grad):
270
+ # Progress the dataset iterator for sentences at each iteration.
271
+ batch = train_iter.__next__()
272
+ x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1))
273
+ # Concatenate the token IDs to matrices and send them to the device
274
+ # self.converter does this job
275
+ # (it is chainer.dataset.concat_examples by default)
276
+ xp = chainer.backends.cuda.get_array_module(x)
277
+ loss = 0
278
+ state = None
279
+ batch_size, sequence_length = x.shape
280
+ for i in six.moves.range(sequence_length):
281
+ # Compute the loss at this time step and accumulate it
282
+ state, loss_batch = optimizer.target(
283
+ state, chainer.Variable(x[:, i]), chainer.Variable(t[:, i])
284
+ )
285
+ non_zeros = xp.count_nonzero(x[:, i])
286
+ loss += loss_batch * non_zeros
287
+ count += int(non_zeros)
288
+ # backward
289
+ loss /= batch_size * self.accum_grad # normalized by batch size
290
+ sum_loss += float(loss.data)
291
+ loss.backward() # Backprop
292
+ loss.unchain_backward() # Truncate the graph
293
+
294
+ reporter.report({"loss": sum_loss}, optimizer.target)
295
+ reporter.report({"count": count}, optimizer.target)
296
+ # update
297
+ optimizer.update() # Update the parameters
298
+ self.scheduler.step(self.iteration)
299
+
300
+
301
+ class LMEvaluator(BaseEvaluator):
302
+ """A custom evaluator for a chainer LM
303
+
304
+ :param chainer.dataset.Iterator val_iter : The validation iterator
305
+ :param eval_model : The model to evaluate
306
+ :param int device : The device id to use
307
+ """
308
+
309
+ def __init__(self, val_iter, eval_model, device):
310
+ super(LMEvaluator, self).__init__(val_iter, eval_model, device=device)
311
+
312
+ def evaluate(self):
313
+ val_iter = self.get_iterator("main")
314
+ target = self.get_target("main")
315
+ loss = 0
316
+ count = 0
317
+ for batch in copy.copy(val_iter):
318
+ x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1))
319
+ xp = chainer.backends.cuda.get_array_module(x)
320
+ state = None
321
+ for i in six.moves.range(len(x[0])):
322
+ state, loss_batch = target(state, x[:, i], t[:, i])
323
+ non_zeros = xp.count_nonzero(x[:, i])
324
+ loss += loss_batch.data * non_zeros
325
+ count += int(non_zeros)
326
+ # report validation loss
327
+ observation = {}
328
+ with reporter.report_scope(observation):
329
+ reporter.report({"loss": float(loss / count)}, target)
330
+ return observation
331
+
332
+
333
+ def train(args):
334
+ """Train with the given args
335
+
336
+ :param Namespace args: The program arguments
337
+ """
338
+ # TODO(karita): support this
339
+ if args.model_module != "default":
340
+ raise NotImplementedError("chainer backend does not support --model-module")
341
+
342
+ # display chainer version
343
+ logging.info("chainer version = " + chainer.__version__)
344
+
345
+ set_deterministic_chainer(args)
346
+
347
+ # check cuda and cudnn availability
348
+ if not chainer.cuda.available:
349
+ logging.warning("cuda is not available")
350
+ if not chainer.cuda.cudnn_enabled:
351
+ logging.warning("cudnn is not available")
352
+
353
+ # get special label ids
354
+ unk = args.char_list_dict["<unk>"]
355
+ eos = args.char_list_dict["<eos>"]
356
+ # read tokens as a sequence of sentences
357
+ train = read_tokens(args.train_label, args.char_list_dict)
358
+ val = read_tokens(args.valid_label, args.char_list_dict)
359
+ # count tokens
360
+ n_train_tokens, n_train_oovs = count_tokens(train, unk)
361
+ n_val_tokens, n_val_oovs = count_tokens(val, unk)
362
+ logging.info("#vocab = " + str(args.n_vocab))
363
+ logging.info("#sentences in the training data = " + str(len(train)))
364
+ logging.info("#tokens in the training data = " + str(n_train_tokens))
365
+ logging.info(
366
+ "oov rate in the training data = %.2f %%"
367
+ % (n_train_oovs / n_train_tokens * 100)
368
+ )
369
+ logging.info("#sentences in the validation data = " + str(len(val)))
370
+ logging.info("#tokens in the validation data = " + str(n_val_tokens))
371
+ logging.info(
372
+ "oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
373
+ )
374
+
375
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
376
+
377
+ # Create the dataset iterators
378
+ train_iter = ParallelSentenceIterator(
379
+ train,
380
+ args.batchsize,
381
+ max_length=args.maxlen,
382
+ sos=eos,
383
+ eos=eos,
384
+ shuffle=not use_sortagrad,
385
+ )
386
+ val_iter = ParallelSentenceIterator(
387
+ val, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
388
+ )
389
+ epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
390
+ logging.info("#iterations per epoch = %d" % epoch_iters)
391
+ logging.info("#total iterations = " + str(args.epoch * epoch_iters))
392
+ # Prepare an RNNLM model
393
+ rnn = RNNLM(args.n_vocab, args.layer, args.unit, args.type)
394
+ model = ClassifierWithState(rnn)
395
+ if args.ngpu > 1:
396
+ logging.warning("currently, multi-gpu is not supported. use single gpu.")
397
+ if args.ngpu > 0:
398
+ # Make the specified GPU current
399
+ gpu_id = 0
400
+ chainer.cuda.get_device_from_id(gpu_id).use()
401
+ model.to_gpu()
402
+ else:
403
+ gpu_id = -1
404
+
405
+ # Save model conf to json
406
+ model_conf = args.outdir + "/model.json"
407
+ with open(model_conf, "wb") as f:
408
+ logging.info("writing a model config file to " + model_conf)
409
+ f.write(
410
+ json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
411
+ "utf_8"
412
+ )
413
+ )
414
+
415
+ # Set up an optimizer
416
+ opt_class = dynamic_import_optimizer(args.opt, args.backend)
417
+ optimizer = opt_class.from_args(model, args)
418
+ if args.schedulers is None:
419
+ schedulers = []
420
+ else:
421
+ schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
422
+
423
+ optimizer.setup(model)
424
+ optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
425
+
426
+ updater = BPTTUpdater(train_iter, optimizer, schedulers, gpu_id, args.accum_grad)
427
+ trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
428
+ trainer.extend(LMEvaluator(val_iter, model, device=gpu_id))
429
+ trainer.extend(
430
+ extensions.LogReport(
431
+ postprocess=compute_perplexity,
432
+ trigger=(args.report_interval_iters, "iteration"),
433
+ )
434
+ )
435
+ trainer.extend(
436
+ extensions.PrintReport(
437
+ ["epoch", "iteration", "perplexity", "val_perplexity", "elapsed_time"]
438
+ ),
439
+ trigger=(args.report_interval_iters, "iteration"),
440
+ )
441
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
442
+ trainer.extend(extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"))
443
+ trainer.extend(extensions.snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
444
+ # MEMO(Hori): wants to use MinValueTrigger, but it seems to fail in resuming
445
+ trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
446
+
447
+ if use_sortagrad:
448
+ trainer.extend(
449
+ ShufflingEnabler([train_iter]),
450
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
451
+ )
452
+
453
+ if args.resume:
454
+ logging.info("resumed from %s" % args.resume)
455
+ chainer.serializers.load_npz(args.resume, trainer)
456
+
457
+ set_early_stop(trainer, args, is_lm=True)
458
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
459
+ writer = SummaryWriter(args.tensorboard_dir)
460
+ trainer.extend(
461
+ TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
462
+ )
463
+
464
+ trainer.run()
465
+ check_early_stop(trainer, args.epoch)
466
+
467
+ # compute perplexity for test set
468
+ if args.test_label:
469
+ logging.info("test the best model")
470
+ chainer.serializers.load_npz(args.outdir + "/rnnlm.model.best", model)
471
+ test = read_tokens(args.test_label, args.char_list_dict)
472
+ n_test_tokens, n_test_oovs = count_tokens(test, unk)
473
+ logging.info("#sentences in the test data = " + str(len(test)))
474
+ logging.info("#tokens in the test data = " + str(n_test_tokens))
475
+ logging.info(
476
+ "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
477
+ )
478
+ test_iter = ParallelSentenceIterator(
479
+ test, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
480
+ )
481
+ evaluator = LMEvaluator(test_iter, model, device=gpu_id)
482
+ with chainer.using_config("train", False):
483
+ result = evaluator()
484
+ logging.info("test perplexity: " + str(np.exp(float(result["main/loss"]))))
espnet/lm/lm_utils.py ADDED
@@ -0,0 +1,293 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ # This code is ported from the following implementation written in Torch.
7
+ # https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
8
+
9
+ import chainer
10
+ import h5py
11
+ import logging
12
+ import numpy as np
13
+ import os
14
+ import random
15
+ import six
16
+ from tqdm import tqdm
17
+
18
+ from chainer.training import extension
19
+
20
+
21
+ def load_dataset(path, label_dict, outdir=None):
22
+ """Load and save HDF5 that contains a dataset and stats for LM
23
+
24
+ Args:
25
+ path (str): The path of an input text dataset file
26
+ label_dict (dict[str, int]):
27
+ dictionary that maps token label string to its ID number
28
+ outdir (str): The path of an output dir
29
+
30
+ Returns:
31
+ tuple[list[np.ndarray], int, int]: Tuple of
32
+ token IDs in np.int32 converted by `read_tokens`,
33
+ the number of tokens by `count_tokens`,
34
+ and the number of OOVs by `count_tokens`
35
+ """
36
+ if outdir is not None:
37
+ os.makedirs(outdir, exist_ok=True)
38
+ filename = outdir + "/" + os.path.basename(path) + ".h5"
39
+ if os.path.exists(filename):
40
+ logging.info(f"loading binary dataset: {filename}")
41
+ f = h5py.File(filename, "r")
42
+ return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
43
+ else:
44
+ logging.info("skip dump/load HDF5 because the output dir is not specified")
45
+ logging.info(f"reading text dataset: {path}")
46
+ ret = read_tokens(path, label_dict)
47
+ n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
48
+ if outdir is not None:
49
+ logging.info(f"saving binary dataset: {filename}")
50
+ with h5py.File(filename, "w") as f:
51
+ # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data
52
+ data = f.create_dataset(
53
+ "data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
54
+ )
55
+ data[:] = ret
56
+ f["n_tokens"] = n_tokens
57
+ f["n_oovs"] = n_oovs
58
+ return ret, n_tokens, n_oovs
59
+
60
+
61
+ def read_tokens(filename, label_dict):
62
+ """Read tokens as a sequence of sentences
63
+
64
+ :param str filename : The name of the input file
65
+ :param dict label_dict : dictionary that maps token label string to its ID number
66
+ :return list of ID sequences
67
+ :rtype list
68
+ """
69
+
70
+ data = []
71
+ unk = label_dict["<unk>"]
72
+ for ln in tqdm(open(filename, "r", encoding="utf-8")):
73
+ data.append(
74
+ np.array(
75
+ [label_dict.get(label, unk) for label in ln.split()], dtype=np.int32
76
+ )
77
+ )
78
+ return data
79
+
80
+
81
+ def count_tokens(data, unk_id=None):
82
+ """Count tokens and oovs in token ID sequences.
83
+
84
+ Args:
85
+ data (list[np.ndarray]): list of token ID sequences
86
+ unk_id (int): ID of unknown token
87
+
88
+ Returns:
89
+ tuple: tuple of number of token occurrences and number of oov tokens
90
+
91
+ """
92
+
93
+ n_tokens = 0
94
+ n_oovs = 0
95
+ for sentence in data:
96
+ n_tokens += len(sentence)
97
+ if unk_id is not None:
98
+ n_oovs += np.count_nonzero(sentence == unk_id)
99
+ return n_tokens, n_oovs
100
+
101
+
102
+ def compute_perplexity(result):
103
+ """Computes and add the perplexity to the LogReport
104
+
105
+ :param dict result: The current observations
106
+ """
107
+ # Routine to rewrite the result dictionary of LogReport to add perplexity values
108
+ result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
109
+ if "validation/main/loss" in result:
110
+ result["val_perplexity"] = np.exp(result["validation/main/loss"])
111
+
112
+
113
+ class ParallelSentenceIterator(chainer.dataset.Iterator):
114
+ """Dataset iterator to create a batch of sentences.
115
+
116
+ This iterator returns a pair of sentences, where one token is shifted
117
+ between the sentences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'.
118
+ Sentence batches are made in order of longer sentences, and then
119
+ randomly shuffled.
120
+ """
121
+
122
+ def __init__(
123
+ self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
124
+ ):
125
+ self.dataset = dataset
126
+ self.batch_size = batch_size # batch size
127
+ # Number of completed sweeps over the dataset. In this case, it is
128
+ # incremented if every word is visited at least once after the last
129
+ # increment.
130
+ self.epoch = 0
131
+ # True if the epoch is incremented at the last iteration.
132
+ self.is_new_epoch = False
133
+ self.repeat = repeat
134
+ length = len(dataset)
135
+ self.batch_indices = []
136
+ # make mini-batches
137
+ if batch_size > 1:
138
+ indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
139
+ bs = 0
140
+ while bs < length:
141
+ be = min(bs + batch_size, length)
142
+ # batch size is automatically reduced if the sentence length
143
+ # is larger than max_length
144
+ if max_length > 0:
145
+ sent_length = len(dataset[indices[bs]])
146
+ be = min(
147
+ be, bs + max(batch_size // (sent_length // max_length + 1), 1)
148
+ )
149
+ self.batch_indices.append(np.array(indices[bs:be]))
150
+ bs = be
151
+ if shuffle:
152
+ # shuffle batches
153
+ random.shuffle(self.batch_indices)
154
+ else:
155
+ self.batch_indices = [np.array([i]) for i in six.moves.range(length)]
156
+
157
+ # NOTE: this is not a count of parameter updates. It is just a count of
158
+ # calls of ``__next__``.
159
+ self.iteration = 0
160
+ self.sos = sos
161
+ self.eos = eos
162
+ # use -1 instead of None internally
163
+ self._previous_epoch_detail = -1.0
164
+
165
+ def __next__(self):
166
+ # This iterator returns a list representing a mini-batch. Each item
167
+ # indicates a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
168
+ # represented by token IDs.
169
+ n_batches = len(self.batch_indices)
170
+ if not self.repeat and self.iteration >= n_batches:
171
+ # If not self.repeat, this iterator stops at the end of the first
172
+ # epoch (i.e., when all words are visited once).
173
+ raise StopIteration
174
+
175
+ batch = []
176
+ for idx in self.batch_indices[self.iteration % n_batches]:
177
+ batch.append(
178
+ (
179
+ np.append([self.sos], self.dataset[idx]),
180
+ np.append(self.dataset[idx], [self.eos]),
181
+ )
182
+ )
183
+
184
+ self._previous_epoch_detail = self.epoch_detail
185
+ self.iteration += 1
186
+
187
+ epoch = self.iteration // n_batches
188
+ self.is_new_epoch = self.epoch < epoch
189
+ if self.is_new_epoch:
190
+ self.epoch = epoch
191
+
192
+ return batch
193
+
194
+ def start_shuffle(self):
195
+ random.shuffle(self.batch_indices)
196
+
197
+ @property
198
+ def epoch_detail(self):
199
+ # Floating point version of epoch.
200
+ return self.iteration / len(self.batch_indices)
201
+
202
+ @property
203
+ def previous_epoch_detail(self):
204
+ if self._previous_epoch_detail < 0:
205
+ return None
206
+ return self._previous_epoch_detail
207
+
208
+ def serialize(self, serializer):
209
+ # It is important to serialize the state to be recovered on resume.
210
+ self.iteration = serializer("iteration", self.iteration)
211
+ self.epoch = serializer("epoch", self.epoch)
212
+ try:
213
+ self._previous_epoch_detail = serializer(
214
+ "previous_epoch_detail", self._previous_epoch_detail
215
+ )
216
+ except KeyError:
217
+ # guess previous_epoch_detail for older version
218
+ self._previous_epoch_detail = self.epoch + (
219
+ self.current_position - 1
220
+ ) / len(self.batch_indices)
221
+ if self.epoch_detail > 0:
222
+ self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
223
+ else:
224
+ self._previous_epoch_detail = -1.0
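As the __next__ method above shows, each element of a minibatch is an (input, target) pair built from the same sentence, with the sos symbol prepended to the input and the eos symbol appended to the target. A toy illustration with invented token IDs:

    import numpy as np

    sentence = np.array([5, 8, 3], dtype=np.int32)  # hypothetical IDs for 'w1 w2 w3'
    sos = eos = 2                                   # the trainers pass the <eos> id for both

    x = np.append([sos], sentence)  # input:  [2 5 8 3]  ~ '<sos> w1 w2 w3'
    t = np.append(sentence, [eos])  # target: [5 8 3 2]  ~ 'w1 w2 w3 <eos>'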
225
+
226
+
227
+ class MakeSymlinkToBestModel(extension.Extension):
228
+ """Extension that makes a symbolic link to the best model
229
+
230
+ :param str key: Key of value
231
+ :param str prefix: Prefix of model files and link target
232
+ :param str suffix: Suffix of link target
233
+ """
234
+
235
+ def __init__(self, key, prefix="model", suffix="best"):
236
+ super(MakeSymlinkToBestModel, self).__init__()
237
+ self.best_model = -1
238
+ self.min_loss = 0.0
239
+ self.key = key
240
+ self.prefix = prefix
241
+ self.suffix = suffix
242
+
243
+ def __call__(self, trainer):
244
+ observation = trainer.observation
245
+ if self.key in observation:
246
+ loss = observation[self.key]
247
+ if self.best_model == -1 or loss < self.min_loss:
248
+ self.min_loss = loss
249
+ self.best_model = trainer.updater.epoch
250
+ src = "%s.%d" % (self.prefix, self.best_model)
251
+ dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
252
+ if os.path.lexists(dest):
253
+ os.remove(dest)
254
+ os.symlink(src, dest)
255
+ logging.info("best model is " + src)
256
+
257
+ def serialize(self, serializer):
258
+ if isinstance(serializer, chainer.serializer.Serializer):
259
+ serializer("_best_model", self.best_model)
260
+ serializer("_min_loss", self.min_loss)
261
+ serializer("_key", self.key)
262
+ serializer("_prefix", self.prefix)
263
+ serializer("_suffix", self.suffix)
264
+ else:
265
+ self.best_model = serializer("_best_model", -1)
266
+ self.min_loss = serializer("_min_loss", 0.0)
267
+ self.key = serializer("_key", "")
268
+ self.prefix = serializer("_prefix", "model")
269
+ self.suffix = serializer("_suffix", "best")
270
+
271
+
272
+ # TODO(Hori): currently this only works with character-to-word level LMs;
273
+ # it needs to support arbitrary subword-to-word mappings.
274
+ def make_lexical_tree(word_dict, subword_dict, word_unk):
275
+ """Make a lexical tree to compute word-level probabilities"""
276
+ # node [dict(subword_id -> node), word_id, word_set[start-1, end]]
277
+ root = [{}, -1, None]
278
+ for w, wid in word_dict.items():
279
+ if wid > 0 and wid != word_unk: # skip <blank> and <unk>
280
+ if True in [c not in subword_dict for c in w]: # skip unknown subword
281
+ continue
282
+ succ = root[0] # get successors from root node
283
+ for i, c in enumerate(w):
284
+ cid = subword_dict[c]
285
+ if cid not in succ: # if next node does not exist, make a new node
286
+ succ[cid] = [{}, -1, (wid - 1, wid)]
287
+ else:
288
+ prev = succ[cid][2]
289
+ succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
290
+ if i == len(w) - 1: # if word end, set word id
291
+ succ[cid][1] = wid
292
+ succ = succ[cid][0] # move to the child successors
293
+ return root
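make_lexical_tree builds a prefix tree over the word vocabulary keyed by subword IDs; each node stores its children, the word ID if a word ends there, and the (start-1, end) interval of word IDs reachable below it. A small illustrative call with toy dictionaries (all IDs below are invented for the example):

    # toy vocabularies: word-level and subword-level dictionaries
    word_dict = {"<blank>": 0, "<unk>": 1, "ab": 2, "ac": 3}
    subword_dict = {"a": 0, "b": 1, "c": 2}

    root = make_lexical_tree(word_dict, subword_dict, word_unk=1)
    node_a = root[0][subword_dict["a"]]
    assert node_a[2] == (1, 3)                   # word-id interval covering "ab" and "ac"
    assert node_a[0][subword_dict["b"]][1] == 2  # "ab" ends a word at this child node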
espnet/lm/pytorch_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/lm/pytorch_backend/extlm.py ADDED
@@ -0,0 +1,218 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from espnet.lm.lm_utils import make_lexical_tree
14
+ from espnet.nets.pytorch_backend.nets_utils import to_device
15
+
16
+
17
+ # Definition of a multi-level (subword/word) language model
18
+ class MultiLevelLM(nn.Module):
19
+ logzero = -10000000000.0
20
+ zero = 1.0e-10
21
+
22
+ def __init__(
23
+ self,
24
+ wordlm,
25
+ subwordlm,
26
+ word_dict,
27
+ subword_dict,
28
+ subwordlm_weight=0.8,
29
+ oov_penalty=1.0,
30
+ open_vocab=True,
31
+ ):
32
+ super(MultiLevelLM, self).__init__()
33
+ self.wordlm = wordlm
34
+ self.subwordlm = subwordlm
35
+ self.word_eos = word_dict["<eos>"]
36
+ self.word_unk = word_dict["<unk>"]
37
+ self.var_word_eos = torch.LongTensor([self.word_eos])
38
+ self.var_word_unk = torch.LongTensor([self.word_unk])
39
+ self.space = subword_dict["<space>"]
40
+ self.eos = subword_dict["<eos>"]
41
+ self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
42
+ self.log_oov_penalty = math.log(oov_penalty)
43
+ self.open_vocab = open_vocab
44
+ self.subword_dict_size = len(subword_dict)
45
+ self.subwordlm_weight = subwordlm_weight
46
+ self.normalized = True
47
+
48
+ def forward(self, state, x):
49
+ # update state with input label x
50
+ if state is None: # make initial states and log-prob vectors
51
+ self.var_word_eos = to_device(x, self.var_word_eos)
52
+ self.var_word_unk = to_device(x, self.var_word_unk)
53
+ wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
54
+ wlm_logprobs = F.log_softmax(z_wlm, dim=1)
55
+ clm_state, z_clm = self.subwordlm(None, x)
56
+ log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
57
+ new_node = self.lexroot
58
+ clm_logprob = 0.0
59
+ xi = self.space
60
+ else:
61
+ clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
62
+ xi = int(x)
63
+ if xi == self.space: # inter-word transition
64
+ if node is not None and node[1] >= 0: # check if the node is word end
65
+ w = to_device(x, torch.LongTensor([node[1]]))
66
+ else: # this node is not a word end, which means <unk>
67
+ w = self.var_word_unk
68
+ # update wordlm state and log-prob vector
69
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
70
+ wlm_logprobs = F.log_softmax(z_wlm, dim=1)
71
+ new_node = self.lexroot # move to the tree root
72
+ clm_logprob = 0.0
73
+ elif node is not None and xi in node[0]: # intra-word transition
74
+ new_node = node[0][xi]
75
+ clm_logprob += log_y[0, xi]
76
+ elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
77
+ new_node = None
78
+ clm_logprob += log_y[0, xi]
79
+ else: # if open_vocab flag is disabled, return 0 probabilities
80
+ log_y = to_device(
81
+ x, torch.full((1, self.subword_dict_size), self.logzero)
82
+ )
83
+ return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y
84
+
85
+ clm_state, z_clm = self.subwordlm(clm_state, x)
86
+ log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
87
+
88
+ # apply word-level probabilities for <space> and <eos> labels
89
+ if xi != self.space:
90
+ if new_node is not None and new_node[1] >= 0: # if new node is word end
91
+ wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
92
+ else:
93
+ wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
94
+ log_y[:, self.space] = wlm_logprob
95
+ log_y[:, self.eos] = wlm_logprob
96
+ else:
97
+ log_y[:, self.space] = self.logzero
98
+ log_y[:, self.eos] = self.logzero
99
+
100
+ return (
101
+ (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
102
+ log_y,
103
+ )
104
+
105
+ def final(self, state):
106
+ clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
107
+ if node is not None and node[1] >= 0: # check if the node is word end
108
+ w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
109
+ else: # this node is not a word end, which means <unk>
110
+ w = self.var_word_unk
111
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
112
+ return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
113
+
114
+
115
+ # Definition of a look-ahead word language model
116
+ class LookAheadWordLM(nn.Module):
117
+ logzero = -10000000000.0
118
+ zero = 1.0e-10
119
+
120
+ def __init__(
121
+ self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
122
+ ):
123
+ super(LookAheadWordLM, self).__init__()
124
+ self.wordlm = wordlm
125
+ self.word_eos = word_dict["<eos>"]
126
+ self.word_unk = word_dict["<unk>"]
127
+ self.var_word_eos = torch.LongTensor([self.word_eos])
128
+ self.var_word_unk = torch.LongTensor([self.word_unk])
129
+ self.space = subword_dict["<space>"]
130
+ self.eos = subword_dict["<eos>"]
131
+ self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
132
+ self.oov_penalty = oov_penalty
133
+ self.open_vocab = open_vocab
134
+ self.subword_dict_size = len(subword_dict)
135
+ self.zero_tensor = torch.FloatTensor([self.zero])
136
+ self.normalized = True
137
+
138
+ def forward(self, state, x):
139
+ # update state with input label x
140
+ if state is None: # make initial states and cumulative probability vector
141
+ self.var_word_eos = to_device(x, self.var_word_eos)
142
+ self.var_word_unk = to_device(x, self.var_word_unk)
143
+ self.zero_tensor = to_device(x, self.zero_tensor)
144
+ wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
145
+ cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
146
+ new_node = self.lexroot
147
+ xi = self.space
148
+ else:
149
+ wlm_state, cumsum_probs, node = state
150
+ xi = int(x)
151
+ if xi == self.space: # inter-word transition
152
+ if node is not None and node[1] >= 0: # check if the node is word end
153
+ w = to_device(x, torch.LongTensor([node[1]]))
154
+ else: # this node is not a word end, which means <unk>
155
+ w = self.var_word_unk
156
+ # update wordlm state and cumulative probability vector
157
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
158
+ cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
159
+ new_node = self.lexroot # move to the tree root
160
+ elif node is not None and xi in node[0]: # intra-word transition
161
+ new_node = node[0][xi]
162
+ elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
163
+ new_node = None
164
+ else: # if open_vocab flag is disabled, return 0 probabilities
165
+ log_y = to_device(
166
+ x, torch.full((1, self.subword_dict_size), self.logzero)
167
+ )
168
+ return (wlm_state, None, None), log_y
169
+
170
+ if new_node is not None:
171
+ succ, wid, wids = new_node
172
+ # compute parent node probability
173
+ sum_prob = (
174
+ (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
175
+ if wids is not None
176
+ else 1.0
177
+ )
178
+ if sum_prob < self.zero:
179
+ log_y = to_device(
180
+ x, torch.full((1, self.subword_dict_size), self.logzero)
181
+ )
182
+ return (wlm_state, cumsum_probs, new_node), log_y
183
+ # set <unk> probability as a default value
184
+ unk_prob = (
185
+ cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
186
+ )
187
+ y = to_device(
188
+ x,
189
+ torch.full(
190
+ (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
191
+ ),
192
+ )
193
+ # compute transition probabilities to child nodes
194
+ for cid, nd in succ.items():
195
+ y[:, cid] = (
196
+ cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
197
+ ) / sum_prob
198
+ # apply word-level probabilities for <space> and <eos> labels
199
+ if wid >= 0:
200
+ wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
201
+ y[:, self.space] = wlm_prob
202
+ y[:, self.eos] = wlm_prob
203
+ elif xi == self.space:
204
+ y[:, self.space] = self.zero
205
+ y[:, self.eos] = self.zero
206
+ log_y = torch.log(torch.max(y, self.zero_tensor)) # clip to avoid log(0)
207
+ else: # if no path in the tree, transition probability is one
208
+ log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
209
+ return (wlm_state, cumsum_probs, new_node), log_y
210
+
211
+ def final(self, state):
212
+ wlm_state, cumsum_probs, node = state
213
+ if node is not None and node[1] >= 0: # check if the node is word end
214
+ w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
215
+ else: # this node is not a word end, which means <unk>
216
+ w = self.var_word_unk
217
+ wlm_state, z_wlm = self.wordlm(wlm_state, w)
218
+ return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
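LookAheadWordLM above converts word-level softmax outputs into subword transition probabilities by differencing a cumulative-probability vector over the word-ID intervals stored in the lexical tree. A toy illustration of that differencing step (the probabilities are invented):

    import torch

    word_probs = torch.tensor([[0.1, 0.2, 0.3, 0.25, 0.15]])  # hypothetical 5-word softmax output
    cumsum_probs = torch.cumsum(word_probs, dim=1)

    wids = (1, 3)  # the (start-1, end) interval stored in some tree node
    sum_prob = cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]
    print(float(sum_prob))  # ~0.55 == 0.3 + 0.25, the mass of word ids 2 and 3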
espnet/lm/pytorch_backend/lm.py ADDED
@@ -0,0 +1,410 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # This code is ported from the following implementation written in Torch.
5
+ # https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
6
+
7
+ """LM training in pytorch."""
8
+
9
+ import copy
10
+ import json
11
+ import logging
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.parallel import data_parallel
17
+
18
+ from chainer import Chain
19
+ from chainer.dataset import convert
20
+ from chainer import reporter
21
+ from chainer import training
22
+ from chainer.training import extensions
23
+
24
+ from espnet.lm.lm_utils import count_tokens
25
+ from espnet.lm.lm_utils import load_dataset
26
+ from espnet.lm.lm_utils import MakeSymlinkToBestModel
27
+ from espnet.lm.lm_utils import ParallelSentenceIterator
28
+ from espnet.lm.lm_utils import read_tokens
29
+ from espnet.nets.lm_interface import dynamic_import_lm
30
+ from espnet.nets.lm_interface import LMInterface
31
+ from espnet.optimizer.factory import dynamic_import_optimizer
32
+ from espnet.scheduler.pytorch import PyTorchScheduler
33
+ from espnet.scheduler.scheduler import dynamic_import_scheduler
34
+
35
+ from espnet.asr.asr_utils import snapshot_object
36
+ from espnet.asr.asr_utils import torch_load
37
+ from espnet.asr.asr_utils import torch_resume
38
+ from espnet.asr.asr_utils import torch_snapshot
39
+
40
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
41
+ from tensorboardX import SummaryWriter
42
+
43
+ from espnet.utils.deterministic_utils import set_deterministic_pytorch
44
+ from espnet.utils.training.evaluator import BaseEvaluator
45
+ from espnet.utils.training.iterators import ShufflingEnabler
46
+ from espnet.utils.training.train_utils import check_early_stop
47
+ from espnet.utils.training.train_utils import set_early_stop
48
+
49
+
50
+ def compute_perplexity(result):
51
+ """Compute and add the perplexity to the LogReport.
52
+
53
+ :param dict result: The current observations
54
+ """
55
+ # Routine to rewrite the result dictionary of LogReport to add perplexity values
56
+ result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
57
+ if "validation/main/nll" in result:
58
+ result["val_perplexity"] = np.exp(
59
+ result["validation/main/nll"] / result["validation/main/count"]
60
+ )
61
+
62
+
63
+ class Reporter(Chain):
64
+ """Dummy module to use chainer's trainer."""
65
+
66
+ def report(self, loss):
67
+ """Report nothing."""
68
+ pass
69
+
70
+
71
+ def concat_examples(batch, device=None, padding=None):
72
+ """Concat examples in minibatch.
73
+
74
+ :param np.ndarray batch: The batch to concatenate
75
+ :param int device: The device to send to
76
+ :param Tuple[int,int] padding: The padding to use
77
+ :return: (inputs, targets)
78
+ :rtype: (torch.Tensor, torch.Tensor)
79
+ """
80
+ x, t = convert.concat_examples(batch, padding=padding)
81
+ x = torch.from_numpy(x)
82
+ t = torch.from_numpy(t)
83
+ if device is not None and device >= 0:
84
+ x = x.cuda(device)
85
+ t = t.cuda(device)
86
+ return x, t
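Later in this module, concat_examples is called with padding=(0, -100): inputs are padded with 0 and targets with -100, which coincides with the default ignore_index of PyTorch's cross-entropy, so a loss computed that way skips padded positions without extra masking. A toy check of that interaction (shapes and values invented):

    import torch
    import torch.nn.functional as F

    t = torch.tensor([[3, 1, 2], [4, 0, -100]])  # second row padded with -100
    logits = torch.randn(2, 3, 5)                # (batch, length, vocab)

    # positions whose target is -100 are ignored by F.cross_entropy by default
    loss = F.cross_entropy(logits.view(-1, 5), t.view(-1))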
87
+
88
+
89
+ class BPTTUpdater(training.StandardUpdater):
90
+ """An updater for a pytorch LM."""
91
+
92
+ def __init__(
93
+ self,
94
+ train_iter,
95
+ model,
96
+ optimizer,
97
+ schedulers,
98
+ device,
99
+ gradclip=None,
100
+ use_apex=False,
101
+ accum_grad=1,
102
+ ):
103
+ """Initialize class.
104
+
105
+ Args:
106
+ train_iter (chainer.dataset.Iterator): The train iterator
107
+ model (LMInterface) : The model to update
108
+ optimizer (torch.optim.Optimizer): The optimizer for training
109
+ schedulers (espnet.scheduler.scheduler.SchedulerInterface):
110
+ The schedulers of `optimizer`
111
+ device (int): The device id
112
+ gradclip (float): The gradient clipping value to use
113
+ use_apex (bool): The flag to use Apex in backprop.
114
+ accum_grad (int): The number of gradient accumulation.
115
+
116
+ """
117
+ super(BPTTUpdater, self).__init__(train_iter, optimizer)
118
+ self.model = model
119
+ self.device = device
120
+ self.gradclip = gradclip
121
+ self.use_apex = use_apex
122
+ self.scheduler = PyTorchScheduler(schedulers, optimizer)
123
+ self.accum_grad = accum_grad
124
+
125
+ # The core part of the update routine can be customized by overriding.
126
+ def update_core(self):
127
+ """Update the model."""
128
+ # When we pass one iterator and optimizer to StandardUpdater.__init__,
129
+ # they are automatically named 'main'.
130
+ train_iter = self.get_iterator("main")
131
+ optimizer = self.get_optimizer("main")
132
+ # Progress the dataset iterator for sentences at each iteration.
133
+ self.model.zero_grad() # Clear the parameter gradients
134
+ accum = {"loss": 0.0, "nll": 0.0, "count": 0}
135
+ for _ in range(self.accum_grad):
136
+ batch = train_iter.__next__()
137
+ # Concatenate the token IDs to matrices and send them to the device
138
+ # self.converter does this job
139
+ # (it is chainer.dataset.concat_examples by default)
140
+ x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
141
+ if self.device[0] == -1:
142
+ loss, nll, count = self.model(x, t)
143
+ else:
144
+ # apex does not support torch.nn.DataParallel
145
+ loss, nll, count = data_parallel(self.model, (x, t), self.device)
146
+
147
+ # backward
148
+ loss = loss.mean() / self.accum_grad
149
+ if self.use_apex:
150
+ from apex import amp
151
+
152
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
153
+ scaled_loss.backward()
154
+ else:
155
+ loss.backward() # Backprop
156
+ # accumulate stats
157
+ accum["loss"] += float(loss)
158
+ accum["nll"] += float(nll.sum())
159
+ accum["count"] += int(count.sum())
160
+
161
+ for k, v in accum.items():
162
+ reporter.report({k: v}, optimizer.target)
163
+ if self.gradclip is not None:
164
+ nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
165
+ optimizer.step() # Update the parameters
166
+ self.scheduler.step(n_iter=self.iteration)
167
+
168
+
169
+ class LMEvaluator(BaseEvaluator):
170
+ """A custom evaluator for a pytorch LM."""
171
+
172
+ def __init__(self, val_iter, eval_model, reporter, device):
173
+ """Initialize class.
174
+
175
+ :param chainer.dataset.Iterator val_iter : The validation iterator
176
+ :param LMInterface eval_model : The model to evaluate
177
+ :param chainer.Reporter reporter : The observations reporter
178
+ :param int device : The device id to use
179
+
180
+ """
181
+ super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
182
+ self.model = eval_model
183
+ self.device = device
184
+
185
+ def evaluate(self):
186
+ """Evaluate the model."""
187
+ val_iter = self.get_iterator("main")
188
+ loss = 0
189
+ nll = 0
190
+ count = 0
191
+ self.model.eval()
192
+ with torch.no_grad():
193
+ for batch in copy.copy(val_iter):
194
+ x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
195
+ if self.device[0] == -1:
196
+ l, n, c = self.model(x, t)
197
+ else:
198
+ # apex does not support torch.nn.DataParallel
199
+ l, n, c = data_parallel(self.model, (x, t), self.device)
200
+ loss += float(l.sum())
201
+ nll += float(n.sum())
202
+ count += int(c.sum())
203
+ self.model.train()
204
+ # report validation loss
205
+ observation = {}
206
+ with reporter.report_scope(observation):
207
+ reporter.report({"loss": loss}, self.model.reporter)
208
+ reporter.report({"nll": nll}, self.model.reporter)
209
+ reporter.report({"count": count}, self.model.reporter)
210
+ return observation
211
+
212
+
213
+ def train(args):
214
+ """Train with the given args.
215
+
216
+ :param Namespace args: The program arguments
217
+ :param type model_class: LMInterface class for training
218
+ """
219
+ model_class = dynamic_import_lm(args.model_module, args.backend)
220
+ assert issubclass(model_class, LMInterface), "model should implement LMInterface"
221
+ # display torch version
222
+ logging.info("torch version = " + torch.__version__)
223
+
224
+ set_deterministic_pytorch(args)
225
+
226
+ # check cuda and cudnn availability
227
+ if not torch.cuda.is_available():
228
+ logging.warning("cuda is not available")
229
+
230
+ # get special label ids
231
+ unk = args.char_list_dict["<unk>"]
232
+ eos = args.char_list_dict["<eos>"]
233
+ # read tokens as a sequence of sentences
234
+ val, n_val_tokens, n_val_oovs = load_dataset(
235
+ args.valid_label, args.char_list_dict, args.dump_hdf5_path
236
+ )
237
+ train, n_train_tokens, n_train_oovs = load_dataset(
238
+ args.train_label, args.char_list_dict, args.dump_hdf5_path
239
+ )
240
+ logging.info("#vocab = " + str(args.n_vocab))
241
+ logging.info("#sentences in the training data = " + str(len(train)))
242
+ logging.info("#tokens in the training data = " + str(n_train_tokens))
243
+ logging.info(
244
+ "oov rate in the training data = %.2f %%"
245
+ % (n_train_oovs / n_train_tokens * 100)
246
+ )
247
+ logging.info("#sentences in the validation data = " + str(len(val)))
248
+ logging.info("#tokens in the validation data = " + str(n_val_tokens))
249
+ logging.info(
250
+ "oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
251
+ )
252
+
253
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
254
+ # Create the dataset iterators
255
+ batch_size = args.batchsize * max(args.ngpu, 1)
256
+ if batch_size * args.accum_grad > args.batchsize:
257
+ logging.info(
258
+ f"batch size is automatically increased "
259
+ f"({args.batchsize} -> {batch_size * args.accum_grad})"
260
+ )
261
+ train_iter = ParallelSentenceIterator(
262
+ train,
263
+ batch_size,
264
+ max_length=args.maxlen,
265
+ sos=eos,
266
+ eos=eos,
267
+ shuffle=not use_sortagrad,
268
+ )
269
+ val_iter = ParallelSentenceIterator(
270
+ val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
271
+ )
272
+ epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
273
+ logging.info("#iterations per epoch = %d" % epoch_iters)
274
+ logging.info("#total iterations = " + str(args.epoch * epoch_iters))
275
+ # Prepare an RNNLM model
276
+ if args.train_dtype in ("float16", "float32", "float64"):
277
+ dtype = getattr(torch, args.train_dtype)
278
+ else:
279
+ dtype = torch.float32
280
+ model = model_class(args.n_vocab, args).to(dtype=dtype)
281
+ if args.ngpu > 0:
282
+ model.to("cuda")
283
+ gpu_id = list(range(args.ngpu))
284
+ else:
285
+ gpu_id = [-1]
286
+
287
+ # Save model conf to json
288
+ model_conf = args.outdir + "/model.json"
289
+ with open(model_conf, "wb") as f:
290
+ logging.info("writing a model config file to " + model_conf)
291
+ f.write(
292
+ json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
293
+ "utf_8"
294
+ )
295
+ )
296
+
297
+ logging.warning(
298
+ "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
299
+ sum(p.numel() for p in model.parameters()),
300
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
301
+ sum(p.numel() for p in model.parameters() if p.requires_grad)
302
+ * 100.0
303
+ / sum(p.numel() for p in model.parameters()),
304
+ )
305
+ )
306
+
307
+ # Set up an optimizer
308
+ opt_class = dynamic_import_optimizer(args.opt, args.backend)
309
+ optimizer = opt_class.from_args(model.parameters(), args)
310
+ if args.schedulers is None:
311
+ schedulers = []
312
+ else:
313
+ schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
314
+
315
+ # setup apex.amp
316
+ if args.train_dtype in ("O0", "O1", "O2", "O3"):
317
+ try:
318
+ from apex import amp
319
+ except ImportError as e:
320
+ logging.error(
321
+ f"You need to install apex for --train-dtype {args.train_dtype}. "
322
+ "See https://github.com/NVIDIA/apex#linux"
323
+ )
324
+ raise e
325
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
326
+ use_apex = True
327
+ else:
328
+ use_apex = False
329
+
330
+ # FIXME: TOO DIRTY HACK
331
+ reporter = Reporter()
332
+ setattr(model, "reporter", reporter)
333
+ setattr(optimizer, "target", reporter)
334
+ setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
335
+
336
+ updater = BPTTUpdater(
337
+ train_iter,
338
+ model,
339
+ optimizer,
340
+ schedulers,
341
+ gpu_id,
342
+ gradclip=args.gradclip,
343
+ use_apex=use_apex,
344
+ accum_grad=args.accum_grad,
345
+ )
346
+ trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
347
+ trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
348
+ trainer.extend(
349
+ extensions.LogReport(
350
+ postprocess=compute_perplexity,
351
+ trigger=(args.report_interval_iters, "iteration"),
352
+ )
353
+ )
354
+ trainer.extend(
355
+ extensions.PrintReport(
356
+ [
357
+ "epoch",
358
+ "iteration",
359
+ "main/loss",
360
+ "perplexity",
361
+ "val_perplexity",
362
+ "elapsed_time",
363
+ ]
364
+ ),
365
+ trigger=(args.report_interval_iters, "iteration"),
366
+ )
367
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
368
+ # Save best models
369
+ trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
370
+ trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
371
+ # T.Hori: MinValueTrigger should be used, but it fails when resuming
372
+ trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
373
+
374
+ if use_sortagrad:
375
+ trainer.extend(
376
+ ShufflingEnabler([train_iter]),
377
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
378
+ )
379
+ if args.resume:
380
+ logging.info("resumed from %s" % args.resume)
381
+ torch_resume(args.resume, trainer)
382
+
383
+ set_early_stop(trainer, args, is_lm=True)
384
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
385
+ writer = SummaryWriter(args.tensorboard_dir)
386
+ trainer.extend(
387
+ TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
388
+ )
389
+
390
+ trainer.run()
391
+ check_early_stop(trainer, args.epoch)
392
+
393
+ # compute perplexity for test set
394
+ if args.test_label:
395
+ logging.info("test the best model")
396
+ torch_load(args.outdir + "/rnnlm.model.best", model)
397
+ test = read_tokens(args.test_label, args.char_list_dict)
398
+ n_test_tokens, n_test_oovs = count_tokens(test, unk)
399
+ logging.info("#sentences in the test data = " + str(len(test)))
400
+ logging.info("#tokens in the test data = " + str(n_test_tokens))
401
+ logging.info(
402
+ "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
403
+ )
404
+ test_iter = ParallelSentenceIterator(
405
+ test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
406
+ )
407
+ evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
408
+ result = evaluator()
409
+ compute_perplexity(result)
410
+ logging.info(f"test perplexity: {result['perplexity']}")
espnet/mt/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/mt/mt_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Utility funcitons for the text translation task."""
8
+
9
+ import logging
10
+
11
+
12
+ # * ------------------ recognition related ------------------ *
13
+ def parse_hypothesis(hyp, char_list):
14
+ """Parse hypothesis.
15
+
16
+ :param list hyp: recognition hypothesis
17
+ :param list char_list: list of characters
18
+ :return: recognition text string
19
+ :return: recognition token string
20
+ :return: recognition tokenid string
21
+ """
22
+ # remove sos and get results
23
+ tokenid_as_list = list(map(int, hyp["yseq"][1:]))
24
+ token_as_list = [char_list[idx] for idx in tokenid_as_list]
25
+ score = float(hyp["score"])
26
+
27
+ # convert to string
28
+ tokenid = " ".join([str(idx) for idx in tokenid_as_list])
29
+ token = " ".join(token_as_list)
30
+ text = "".join(token_as_list).replace("<space>", " ")
31
+
32
+ return text, token, tokenid, score
33
+
34
+
35
+ def add_results_to_json(js, nbest_hyps, char_list):
36
+ """Add N-best results to json.
37
+
38
+ :param dict js: groundtruth utterance dict
39
+ :param list nbest_hyps: list of hypothesis
40
+ :param list char_list: list of characters
41
+ :return: N-best results added utterance dict
42
+ """
43
+ # copy old json info
44
+ new_js = dict()
45
+ if "utt2spk" in js.keys():
46
+ new_js["utt2spk"] = js["utt2spk"]
47
+ new_js["output"] = []
48
+
49
+ for n, hyp in enumerate(nbest_hyps, 1):
50
+ # parse hypothesis
51
+ rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
52
+
53
+ # copy ground-truth
54
+ if len(js["output"]) > 0:
55
+ out_dic = dict(js["output"][0].items())
56
+ else:
57
+ out_dic = {"name": ""}
58
+
59
+ # update name
60
+ out_dic["name"] += "[%d]" % n
61
+
62
+ # add recognition results
63
+ out_dic["rec_text"] = rec_text
64
+ out_dic["rec_token"] = rec_token
65
+ out_dic["rec_tokenid"] = rec_tokenid
66
+ out_dic["score"] = score
67
+
68
+ # add source reference
69
+ out_dic["text_src"] = js["output"][1]["text"]
70
+ out_dic["token_src"] = js["output"][1]["token"]
71
+ out_dic["tokenid_src"] = js["output"][1]["tokenid"]
72
+
73
+ # add to list of N-best result dicts
74
+ new_js["output"].append(out_dic)
75
+
76
+ # show 1-best result
77
+ if n == 1:
78
+ if "text" in out_dic.keys():
79
+ logging.info("groundtruth: %s" % out_dic["text"])
80
+ logging.info("prediction : %s" % out_dic["rec_text"])
81
+ logging.info("source : %s" % out_dic["token_src"])
82
+
83
+ return new_js
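parse_hypothesis simply maps the token IDs of a hypothesis back through char_list and joins them, replacing <space> markers with real spaces. A toy call with an invented character list and hypothesis:

    char_list = ["<blank>", "<unk>", "h", "i", "<space>", "a", "l", "<eos>"]
    hyp = {"yseq": [7, 2, 3, 4, 5, 6, 6], "score": -1.23}  # the leading symbol is stripped

    text, token, tokenid, score = parse_hypothesis(hyp, char_list)
    # text    -> "hi all"
    # token   -> "h i <space> a l l"
    # tokenid -> "2 3 4 5 6 6"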
espnet/mt/pytorch_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/mt/pytorch_backend/mt.py ADDED
@@ -0,0 +1,600 @@
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2019 Kyoto University (Hirofumi Inaguma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Training/decoding definition for the text translation task."""
8
+
9
+ import json
10
+ import logging
11
+ import os
12
+ import sys
13
+
14
+ from chainer import training
15
+ from chainer.training import extensions
16
+ import numpy as np
17
+ from tensorboardX import SummaryWriter
18
+ import torch
19
+
20
+ from espnet.asr.asr_utils import adadelta_eps_decay
21
+ from espnet.asr.asr_utils import adam_lr_decay
22
+ from espnet.asr.asr_utils import add_results_to_json
23
+ from espnet.asr.asr_utils import CompareValueTrigger
24
+ from espnet.asr.asr_utils import restore_snapshot
25
+ from espnet.asr.asr_utils import snapshot_object
26
+ from espnet.asr.asr_utils import torch_load
27
+ from espnet.asr.asr_utils import torch_resume
28
+ from espnet.asr.asr_utils import torch_snapshot
29
+ from espnet.nets.mt_interface import MTInterface
30
+ from espnet.nets.pytorch_backend.e2e_asr import pad_list
31
+ from espnet.utils.dataset import ChainerDataLoader
32
+ from espnet.utils.dataset import TransformDataset
33
+ from espnet.utils.deterministic_utils import set_deterministic_pytorch
34
+ from espnet.utils.dynamic_import import dynamic_import
35
+ from espnet.utils.io_utils import LoadInputsAndTargets
36
+ from espnet.utils.training.batchfy import make_batchset
37
+ from espnet.utils.training.iterators import ShufflingEnabler
38
+ from espnet.utils.training.tensorboard_logger import TensorboardLogger
39
+ from espnet.utils.training.train_utils import check_early_stop
40
+ from espnet.utils.training.train_utils import set_early_stop
41
+
42
+ from espnet.asr.pytorch_backend.asr import CustomEvaluator
43
+ from espnet.asr.pytorch_backend.asr import CustomUpdater
44
+ from espnet.asr.pytorch_backend.asr import load_trained_model
45
+
46
+ import matplotlib
47
+
48
+ matplotlib.use("Agg")
49
+
50
+ if sys.version_info[0] == 2:
51
+ from itertools import izip_longest as zip_longest
52
+ else:
53
+ from itertools import zip_longest as zip_longest
54
+
55
+
56
+ class CustomConverter(object):
57
+ """Custom batch converter for Pytorch."""
58
+
59
+ def __init__(self):
60
+ """Construct a CustomConverter object."""
61
+ self.ignore_id = -1
62
+ self.pad = 0
63
+ # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
64
+ # in ASR. However,
65
+ # blank labels are not used in NMT. To keep the vocabulary size,
66
+ # we use index:0 for padding instead of adding one more class.
67
+
68
+ def __call__(self, batch, device=torch.device("cpu")):
69
+ """Transform a batch and send it to a device.
70
+
71
+ Args:
72
+ batch (list): The batch to transform.
73
+ device (torch.device): The device to send to.
74
+
75
+ Returns:
76
+ tuple(torch.Tensor, torch.Tensor, torch.Tensor)
77
+
78
+ """
79
+ # batch should be located in list
80
+ assert len(batch) == 1
81
+ xs, ys = batch[0]
82
+
83
+ # get batch of lengths of input sequences
84
+ ilens = np.array([x.shape[0] for x in xs])
85
+
86
+ # perform padding and convert to tensor
87
+ xs_pad = pad_list([torch.from_numpy(x).long() for x in xs], self.pad).to(device)
88
+ ilens = torch.from_numpy(ilens).to(device)
89
+ ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], self.ignore_id).to(
90
+ device
91
+ )
92
+
93
+ return xs_pad, ilens, ys_pad
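pad_list, imported above from e2e_asr, right-pads a list of variable-length tensors to a common length with the given fill value. A small sketch of the expected behaviour, assuming a pad_list(xs, pad_value) signature:

    import torch

    xs = [torch.tensor([4, 7, 9]), torch.tensor([5, 2])]  # hypothetical token-id sequences
    xs_pad = pad_list([x.long() for x in xs], 0)          # pad inputs with 0, as in CustomConverter
    # expected (right-padded to the longest sequence):
    # tensor([[4, 7, 9],
    #         [5, 2, 0]])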
94
+
95
+
96
+ def train(args):
97
+ """Train with the given args.
98
+
99
+ Args:
100
+ args (namespace): The program arguments.
101
+
102
+ """
103
+ set_deterministic_pytorch(args)
104
+
105
+ # check cuda availability
106
+ if not torch.cuda.is_available():
107
+ logging.warning("cuda is not available")
108
+
109
+ # get input and output dimension info
110
+ with open(args.valid_json, "rb") as f:
111
+ valid_json = json.load(f)["utts"]
112
+ utts = list(valid_json.keys())
113
+ idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
114
+ odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
115
+ logging.info("#input dims : " + str(idim))
116
+ logging.info("#output dims: " + str(odim))
117
+
118
+ # specify model architecture
119
+ model_class = dynamic_import(args.model_module)
120
+ model = model_class(idim, odim, args)
121
+ assert isinstance(model, MTInterface)
122
+
123
+ # write model config
124
+ if not os.path.exists(args.outdir):
125
+ os.makedirs(args.outdir)
126
+ model_conf = args.outdir + "/model.json"
127
+ with open(model_conf, "wb") as f:
128
+ logging.info("writing a model config file to " + model_conf)
129
+ f.write(
130
+ json.dumps(
131
+ (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
132
+ ).encode("utf_8")
133
+ )
134
+ for key in sorted(vars(args).keys()):
135
+ logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
136
+
137
+ reporter = model.reporter
138
+
139
+ # check the use of multi-gpu
140
+ if args.ngpu > 1:
141
+ if args.batch_size != 0:
142
+ logging.warning(
143
+ "batch size is automatically increased (%d -> %d)"
144
+ % (args.batch_size, args.batch_size * args.ngpu)
145
+ )
146
+ args.batch_size *= args.ngpu
147
+
148
+ # set torch device
149
+ device = torch.device("cuda" if args.ngpu > 0 else "cpu")
150
+ if args.train_dtype in ("float16", "float32", "float64"):
151
+ dtype = getattr(torch, args.train_dtype)
152
+ else:
153
+ dtype = torch.float32
154
+ model = model.to(device=device, dtype=dtype)
155
+
156
+ logging.warning(
157
+ "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
158
+ sum(p.numel() for p in model.parameters()),
159
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
160
+ sum(p.numel() for p in model.parameters() if p.requires_grad)
161
+ * 100.0
162
+ / sum(p.numel() for p in model.parameters()),
163
+ )
164
+ )
165
+
166
+ # Setup an optimizer
167
+ if args.opt == "adadelta":
168
+ optimizer = torch.optim.Adadelta(
169
+ model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
170
+ )
171
+ elif args.opt == "adam":
172
+ optimizer = torch.optim.Adam(
173
+ model.parameters(), lr=args.lr, weight_decay=args.weight_decay
174
+ )
175
+ elif args.opt == "noam":
176
+ from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
177
+
178
+ optimizer = get_std_opt(
179
+ model.parameters(),
180
+ args.adim,
181
+ args.transformer_warmup_steps,
182
+ args.transformer_lr,
183
+ )
184
+ else:
185
+ raise NotImplementedError("unknown optimizer: " + args.opt)
186
+
187
+ # setup apex.amp
188
+ if args.train_dtype in ("O0", "O1", "O2", "O3"):
189
+ try:
190
+ from apex import amp
191
+ except ImportError as e:
192
+ logging.error(
193
+ f"You need to install apex for --train-dtype {args.train_dtype}. "
194
+ "See https://github.com/NVIDIA/apex#linux"
195
+ )
196
+ raise e
197
+ if args.opt == "noam":
198
+ model, optimizer.optimizer = amp.initialize(
199
+ model, optimizer.optimizer, opt_level=args.train_dtype
200
+ )
201
+ else:
202
+ model, optimizer = amp.initialize(
203
+ model, optimizer, opt_level=args.train_dtype
204
+ )
205
+ use_apex = True
206
+ else:
207
+ use_apex = False
208
+
209
+ # FIXME: TOO DIRTY HACK
210
+ setattr(optimizer, "target", reporter)
211
+ setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
212
+
213
+ # Setup a converter
214
+ converter = CustomConverter()
215
+
216
+ # read json data
217
+ with open(args.train_json, "rb") as f:
218
+ train_json = json.load(f)["utts"]
219
+ with open(args.valid_json, "rb") as f:
220
+ valid_json = json.load(f)["utts"]
221
+
222
+ use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
223
+ # make minibatch list (variable length)
224
+ train = make_batchset(
225
+ train_json,
226
+ args.batch_size,
227
+ args.maxlen_in,
228
+ args.maxlen_out,
229
+ args.minibatches,
230
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
231
+ shortest_first=use_sortagrad,
232
+ count=args.batch_count,
233
+ batch_bins=args.batch_bins,
234
+ batch_frames_in=args.batch_frames_in,
235
+ batch_frames_out=args.batch_frames_out,
236
+ batch_frames_inout=args.batch_frames_inout,
237
+ mt=True,
238
+ iaxis=1,
239
+ oaxis=0,
240
+ )
241
+ valid = make_batchset(
242
+ valid_json,
243
+ args.batch_size,
244
+ args.maxlen_in,
245
+ args.maxlen_out,
246
+ args.minibatches,
247
+ min_batch_size=args.ngpu if args.ngpu > 1 else 1,
248
+ count=args.batch_count,
249
+ batch_bins=args.batch_bins,
250
+ batch_frames_in=args.batch_frames_in,
251
+ batch_frames_out=args.batch_frames_out,
252
+ batch_frames_inout=args.batch_frames_inout,
253
+ mt=True,
254
+ iaxis=1,
255
+ oaxis=0,
256
+ )
257
+
258
+ load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
259
+ load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
260
+ # hack to set the batchsize argument to 1:
261
+ # the actual batch size is carried inside the list,
262
+ # the default collate function converts numpy arrays to pytorch tensors,
263
+ # so we use a pass-through collate function that returns the list as-is
264
+ train_iter = ChainerDataLoader(
265
+ dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
266
+ batch_size=1,
267
+ num_workers=args.n_iter_processes,
268
+ shuffle=not use_sortagrad,
269
+ collate_fn=lambda x: x[0],
270
+ )
271
+ valid_iter = ChainerDataLoader(
272
+ dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
273
+ batch_size=1,
274
+ shuffle=False,
275
+ collate_fn=lambda x: x[0],
276
+ num_workers=args.n_iter_processes,
277
+ )
278
+
279
+ # Set up a trainer
280
+ updater = CustomUpdater(
281
+ model,
282
+ args.grad_clip,
283
+ {"main": train_iter},
284
+ optimizer,
285
+ device,
286
+ args.ngpu,
287
+ False,
288
+ args.accum_grad,
289
+ use_apex=use_apex,
290
+ )
291
+ trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
292
+
293
+ if use_sortagrad:
294
+ trainer.extend(
295
+ ShufflingEnabler([train_iter]),
296
+ trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
297
+ )
298
+
299
+ # Resume from a snapshot
300
+ if args.resume:
301
+ logging.info("resumed from %s" % args.resume)
302
+ torch_resume(args.resume, trainer)
303
+
304
+ # Evaluate the model with the test dataset for each epoch
305
+ if args.save_interval_iters > 0:
306
+ trainer.extend(
307
+ CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
308
+ trigger=(args.save_interval_iters, "iteration"),
309
+ )
310
+ else:
311
+ trainer.extend(
312
+ CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
313
+ )
314
+
315
+ # Save attention weight each epoch
316
+ if args.num_save_attention > 0:
317
+ # NOTE: sort it by output lengths
318
+ data = sorted(
319
+ list(valid_json.items())[: args.num_save_attention],
320
+ key=lambda x: int(x[1]["output"][0]["shape"][0]),
321
+ reverse=True,
322
+ )
323
+ if hasattr(model, "module"):
324
+ att_vis_fn = model.module.calculate_all_attentions
325
+ plot_class = model.module.attention_plot_class
326
+ else:
327
+ att_vis_fn = model.calculate_all_attentions
328
+ plot_class = model.attention_plot_class
329
+ att_reporter = plot_class(
330
+ att_vis_fn,
331
+ data,
332
+ args.outdir + "/att_ws",
333
+ converter=converter,
334
+ transform=load_cv,
335
+ device=device,
336
+ ikey="output",
337
+ iaxis=1,
338
+ )
339
+ trainer.extend(att_reporter, trigger=(1, "epoch"))
340
+ else:
341
+ att_reporter = None
342
+
343
+ # Make a plot for training and validation values
344
+ trainer.extend(
345
+ extensions.PlotReport(
346
+ ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png"
347
+ )
348
+ )
349
+ trainer.extend(
350
+ extensions.PlotReport(
351
+ ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
352
+ )
353
+ )
354
+ trainer.extend(
355
+ extensions.PlotReport(
356
+ ["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png"
357
+ )
358
+ )
359
+ trainer.extend(
360
+ extensions.PlotReport(
361
+ ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
362
+ )
363
+ )
364
+
365
+ # Save best models
366
+ trainer.extend(
367
+ snapshot_object(model, "model.loss.best"),
368
+ trigger=training.triggers.MinValueTrigger("validation/main/loss"),
369
+ )
370
+ trainer.extend(
371
+ snapshot_object(model, "model.acc.best"),
372
+ trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
373
+ )
374
+
375
+ # save snapshot which contains model and optimizer states
376
+ if args.save_interval_iters > 0:
377
+ trainer.extend(
378
+ torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
379
+ trigger=(args.save_interval_iters, "iteration"),
380
+ )
381
+ else:
382
+ trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
383
+
384
+ # epsilon decay in the optimizer
385
+ if args.opt == "adadelta":
386
+ if args.criterion == "acc":
387
+ trainer.extend(
388
+ restore_snapshot(
389
+ model, args.outdir + "/model.acc.best", load_fn=torch_load
390
+ ),
391
+ trigger=CompareValueTrigger(
392
+ "validation/main/acc",
393
+ lambda best_value, current_value: best_value > current_value,
394
+ ),
395
+ )
396
+ trainer.extend(
397
+ adadelta_eps_decay(args.eps_decay),
398
+ trigger=CompareValueTrigger(
399
+ "validation/main/acc",
400
+ lambda best_value, current_value: best_value > current_value,
401
+ ),
402
+ )
403
+ elif args.criterion == "loss":
404
+ trainer.extend(
405
+ restore_snapshot(
406
+ model, args.outdir + "/model.loss.best", load_fn=torch_load
407
+ ),
408
+ trigger=CompareValueTrigger(
409
+ "validation/main/loss",
410
+ lambda best_value, current_value: best_value < current_value,
411
+ ),
412
+ )
413
+ trainer.extend(
414
+ adadelta_eps_decay(args.eps_decay),
415
+ trigger=CompareValueTrigger(
416
+ "validation/main/loss",
417
+ lambda best_value, current_value: best_value < current_value,
418
+ ),
419
+ )
420
+ elif args.opt == "adam":
421
+ if args.criterion == "acc":
422
+ trainer.extend(
423
+ restore_snapshot(
424
+ model, args.outdir + "/model.acc.best", load_fn=torch_load
425
+ ),
426
+ trigger=CompareValueTrigger(
427
+ "validation/main/acc",
428
+ lambda best_value, current_value: best_value > current_value,
429
+ ),
430
+ )
431
+ trainer.extend(
432
+ adam_lr_decay(args.lr_decay),
433
+ trigger=CompareValueTrigger(
434
+ "validation/main/acc",
435
+ lambda best_value, current_value: best_value > current_value,
436
+ ),
437
+ )
438
+ elif args.criterion == "loss":
439
+ trainer.extend(
440
+ restore_snapshot(
441
+ model, args.outdir + "/model.loss.best", load_fn=torch_load
442
+ ),
443
+ trigger=CompareValueTrigger(
444
+ "validation/main/loss",
445
+ lambda best_value, current_value: best_value < current_value,
446
+ ),
447
+ )
448
+ trainer.extend(
449
+ adam_lr_decay(args.lr_decay),
450
+ trigger=CompareValueTrigger(
451
+ "validation/main/loss",
452
+ lambda best_value, current_value: best_value < current_value,
453
+ ),
454
+ )
455
+
456
+ # Write a log of evaluation statistics for each epoch
457
+ trainer.extend(
458
+ extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
459
+ )
460
+ report_keys = [
461
+ "epoch",
462
+ "iteration",
463
+ "main/loss",
464
+ "validation/main/loss",
465
+ "main/acc",
466
+ "validation/main/acc",
467
+ "main/ppl",
468
+ "validation/main/ppl",
469
+ "elapsed_time",
470
+ ]
471
+ if args.opt == "adadelta":
472
+ trainer.extend(
473
+ extensions.observe_value(
474
+ "eps",
475
+ lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
476
+ "eps"
477
+ ],
478
+ ),
479
+ trigger=(args.report_interval_iters, "iteration"),
480
+ )
481
+ report_keys.append("eps")
482
+ elif args.opt in ["adam", "noam"]:
483
+ trainer.extend(
484
+ extensions.observe_value(
485
+ "lr",
486
+ lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
487
+ "lr"
488
+ ],
489
+ ),
490
+ trigger=(args.report_interval_iters, "iteration"),
491
+ )
492
+ report_keys.append("lr")
493
+ if args.report_bleu:
494
+ report_keys.append("main/bleu")
495
+ report_keys.append("validation/main/bleu")
496
+ trainer.extend(
497
+ extensions.PrintReport(report_keys),
498
+ trigger=(args.report_interval_iters, "iteration"),
499
+ )
500
+
501
+ trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
502
+ set_early_stop(trainer, args)
503
+
504
+ if args.tensorboard_dir is not None and args.tensorboard_dir != "":
505
+ trainer.extend(
506
+ TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
507
+ trigger=(args.report_interval_iters, "iteration"),
508
+ )
509
+ # Run the training
510
+ trainer.run()
511
+ check_early_stop(trainer, args.epochs)
512
+
513
+
514
+ def trans(args):
515
+ """Decode with the given args.
516
+
517
+ Args:
518
+ args (namespace): The program arguments.
519
+
520
+ """
521
+ set_deterministic_pytorch(args)
522
+ model, train_args = load_trained_model(args.model)
523
+ assert isinstance(model, MTInterface)
524
+ model.trans_args = args
525
+
526
+ # gpu
527
+ if args.ngpu == 1:
528
+ gpu_id = list(range(args.ngpu))
529
+ logging.info("gpu id: " + str(gpu_id))
530
+ model.cuda()
531
+
532
+ # read json data
533
+ with open(args.trans_json, "rb") as f:
534
+ js = json.load(f)["utts"]
535
+ new_js = {}
536
+
537
+ # remove empty utterances
538
+ if train_args.multilingual:
539
+ js = {
540
+ k: v
541
+ for k, v in js.items()
542
+ if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1
543
+ }
544
+ else:
545
+ js = {
546
+ k: v
547
+ for k, v in js.items()
548
+ if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0
549
+ }
550
+
551
+ if args.batchsize == 0:
552
+ with torch.no_grad():
553
+ for idx, name in enumerate(js.keys(), 1):
554
+ logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
555
+ feat = [js[name]["output"][1]["tokenid"].split()]
556
+ nbest_hyps = model.translate(feat, args, train_args.char_list)
557
+ new_js[name] = add_results_to_json(
558
+ js[name], nbest_hyps, train_args.char_list
559
+ )
560
+
561
+ else:
562
+
563
+ def grouper(n, iterable, fillvalue=None):
564
+ kargs = [iter(iterable)] * n
565
+ return zip_longest(*kargs, fillvalue=fillvalue)
566
+
567
+ # sort data
568
+ keys = list(js.keys())
569
+ feat_lens = [js[key]["output"][1]["shape"][0] for key in keys]
570
+ sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
571
+ keys = [keys[i] for i in sorted_index]
572
+
573
+ with torch.no_grad():
574
+ for names in grouper(args.batchsize, keys, None):
575
+ names = [name for name in names if name]
576
+ feats = [
577
+ np.fromiter(
578
+ map(int, js[name]["output"][1]["tokenid"].split()),
579
+ dtype=np.int64,
580
+ )
581
+ for name in names
582
+ ]
583
+ nbest_hyps = model.translate_batch(
584
+ feats,
585
+ args,
586
+ train_args.char_list,
587
+ )
588
+
589
+ for i, nbest_hyp in enumerate(nbest_hyps):
590
+ name = names[i]
591
+ new_js[name] = add_results_to_json(
592
+ js[name], nbest_hyp, train_args.char_list
593
+ )
594
+
595
+ with open(args.result_label, "wb") as f:
596
+ f.write(
597
+ json.dumps(
598
+ {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
599
+ ).encode("utf_8")
600
+ )
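For reference, a minimal sketch (not part of this commit) of what the `grouper` helper inside `trans` yields; the utterance keys below are made up.

from itertools import zip_longest

def grouper(n, iterable, fillvalue=None):
    kargs = [iter(iterable)] * n
    return zip_longest(*kargs, fillvalue=fillvalue)

# With batchsize=2 and three utterance keys, the last group is padded with None,
# which trans() filters out before decoding.
print(list(grouper(2, ["utt1", "utt2", "utt3"])))
# [('utt1', 'utt2'), ('utt3', None)]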
espnet/nets/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/nets/asr_interface.py ADDED
@@ -0,0 +1,172 @@
1
+ """ASR Interface module."""
2
+ import argparse
3
+
4
+ from espnet.bin.asr_train import get_parser
5
+ from espnet.utils.dynamic_import import dynamic_import
6
+ from espnet.utils.fill_missing_args import fill_missing_args
7
+
8
+
9
+ class ASRInterface:
10
+ """ASR Interface for ESPnet model implementation."""
11
+
12
+ @staticmethod
13
+ def add_arguments(parser):
14
+ """Add arguments to parser."""
15
+ return parser
16
+
17
+ @classmethod
18
+ def build(cls, idim: int, odim: int, **kwargs):
19
+ """Initialize this class with python-level args.
20
+
21
+ Args:
22
+ idim (int): The dimension of the input features.
23
+ odim (int): The size of the output vocabulary.
24
+
25
+ Returns:
26
+ ASRInterface: A new instance of ASRInterface.
27
+
28
+ """
29
+
30
+ def wrap(parser):
31
+ return get_parser(parser, required=False)
32
+
33
+ args = argparse.Namespace(**kwargs)
34
+ args = fill_missing_args(args, wrap)
35
+ args = fill_missing_args(args, cls.add_arguments)
36
+ return cls(idim, odim, args)
37
+
38
+ def forward(self, xs, ilens, ys):
39
+ """Compute loss for training.
40
+
41
+ :param xs:
42
+ For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
43
+ For chainer, list of source sequences chainer.Variable
44
+ :param ilens: batch of lengths of source sequences (B)
45
+ For pytorch, torch.Tensor
46
+ For chainer, list of int
47
+ :param ys:
48
+ For pytorch, batch of padded target sequences torch.Tensor (B, Lmax)
49
+ For chainer, list of target sequences chainer.Variable
50
+ :return: loss value
51
+ :rtype: torch.Tensor for pytorch, chainer.Variable for chainer
52
+ """
53
+ raise NotImplementedError("forward method is not implemented")
54
+
55
+ def recognize(self, x, recog_args, char_list=None, rnnlm=None):
56
+ """Recognize x for evaluation.
57
+
58
+ :param ndarray x: input acoustic feature (B, T, D) or (T, D)
59
+ :param namespace recog_args: argument namespace containing options
60
+ :param list char_list: list of characters
61
+ :param torch.nn.Module rnnlm: language model module
62
+ :return: N-best decoding results
63
+ :rtype: list
64
+ """
65
+ raise NotImplementedError("recognize method is not implemented")
66
+
67
+ def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
68
+ """Beam search implementation for batch.
69
+
70
+ :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
71
+ :param namespace recog_args: argument namespace containing options
72
+ :param list char_list: list of characters
73
+ :param torch.nn.Module rnnlm: language model module
74
+ :return: N-best decoding results
75
+ :rtype: list
76
+ """
77
+ raise NotImplementedError("Batch decoding is not supported yet.")
78
+
79
+ def calculate_all_attentions(self, xs, ilens, ys):
80
+ """Caluculate attention.
81
+
82
+ :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
83
+ :param ndarray ilens: batch of lengths of input sequences (B)
84
+ :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
85
+ :return: attention weights (B, Lmax, Tmax)
86
+ :rtype: float ndarray
87
+ """
88
+ raise NotImplementedError("calculate_all_attentions method is not implemented")
89
+
90
+ def calculate_all_ctc_probs(self, xs, ilens, ys):
91
+ """Caluculate CTC probability.
92
+
93
+ :param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
94
+ :param ndarray ilens: batch of lengths of input sequences (B)
95
+ :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
96
+ :return: CTC probabilities (B, Tmax, vocab)
97
+ :rtype: float ndarray
98
+ """
99
+ raise NotImplementedError("calculate_all_ctc_probs method is not implemented")
100
+
101
+ @property
102
+ def attention_plot_class(self):
103
+ """Get attention plot class."""
104
+ from espnet.asr.asr_utils import PlotAttentionReport
105
+
106
+ return PlotAttentionReport
107
+
108
+ @property
109
+ def ctc_plot_class(self):
110
+ """Get CTC plot class."""
111
+ from espnet.asr.asr_utils import PlotCTCReport
112
+
113
+ return PlotCTCReport
114
+
115
+ def get_total_subsampling_factor(self):
116
+ """Get total subsampling factor."""
117
+ raise NotImplementedError(
118
+ "get_total_subsampling_factor method is not implemented"
119
+ )
120
+
121
+ def encode(self, feat):
122
+ """Encode feature in `beam_search` (optional).
123
+
124
+ Args:
125
+ x (numpy.ndarray): input feature (T, D)
126
+ Returns:
127
+ torch.Tensor for pytorch, chainer.Variable for chainer:
128
+ encoded feature (T, D)
129
+
130
+ """
131
+ raise NotImplementedError("encode method is not implemented")
132
+
133
+ def scorers(self):
134
+ """Get scorers for `beam_search` (optional).
135
+
136
+ Returns:
137
+ dict[str, ScorerInterface]: dict of `ScorerInterface` objects
138
+
139
+ """
140
+ raise NotImplementedError("decoders method is not implemented")
141
+
142
+
143
+ predefined_asr = {
144
+ "pytorch": {
145
+ "rnn": "espnet.nets.pytorch_backend.e2e_asr:E2E",
146
+ "transducer": "espnet.nets.pytorch_backend.e2e_asr_transducer:E2E",
147
+ "transformer": "espnet.nets.pytorch_backend.e2e_asr_transformer:E2E",
148
+ "conformer": "espnet.nets.pytorch_backend.e2e_asr_conformer:E2E",
149
+ },
150
+ "chainer": {
151
+ "rnn": "espnet.nets.chainer_backend.e2e_asr:E2E",
152
+ "transformer": "espnet.nets.chainer_backend.e2e_asr_transformer:E2E",
153
+ },
154
+ }
155
+
156
+
157
+ def dynamic_import_asr(module, backend):
158
+ """Import ASR models dynamically.
159
+
160
+ Args:
161
+ module (str): module_name:class_name or alias in `predefined_asr`
162
+ backend (str): NN backend. e.g., pytorch, chainer
163
+
164
+ Returns:
165
+ type: ASR class
166
+
167
+ """
168
+ model_class = dynamic_import(module, predefined_asr.get(backend, dict()))
169
+ assert issubclass(
170
+ model_class, ASRInterface
171
+ ), f"{module} does not implement ASRInterface"
172
+ return model_class
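As a rough usage sketch (assuming an installed espnet package; the idim/odim values below are illustrative only), the aliases in `predefined_asr` can be resolved and built like this:

from espnet.nets.asr_interface import ASRInterface, dynamic_import_asr

# Resolve the "transformer" alias for the pytorch backend and build a model
# through ASRInterface.build; 83 and 5002 are hypothetical feature/vocab sizes.
model_class = dynamic_import_asr("transformer", "pytorch")
assert issubclass(model_class, ASRInterface)
model = model_class.build(idim=83, odim=5002)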
espnet/nets/batch_beam_search.py ADDED
@@ -0,0 +1,348 @@
1
+ """Parallel beam search module."""
2
+
3
+ import logging
4
+ from typing import Any
5
+ from typing import Dict
6
+ from typing import List
7
+ from typing import NamedTuple
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from espnet.nets.beam_search import BeamSearch
14
+ from espnet.nets.beam_search import Hypothesis
15
+
16
+
17
+ class BatchHypothesis(NamedTuple):
18
+ """Batchfied/Vectorized hypothesis data type."""
19
+
20
+ yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen)
21
+ score: torch.Tensor = torch.tensor([]) # (batch,)
22
+ length: torch.Tensor = torch.tensor([]) # (batch,)
23
+ scores: Dict[str, torch.Tensor] = dict() # values: (batch,)
24
+ states: Dict[str, Dict] = dict()
25
+
26
+ def __len__(self) -> int:
27
+ """Return a batch size."""
28
+ return len(self.length)
29
+
30
+
31
+ class BatchBeamSearch(BeamSearch):
32
+ """Batch beam search implementation."""
33
+
34
+ def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
35
+ """Convert list to batch."""
36
+ if len(hyps) == 0:
37
+ return BatchHypothesis()
38
+ return BatchHypothesis(
39
+ yseq=pad_sequence(
40
+ [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
41
+ ),
42
+ length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
43
+ score=torch.tensor([h.score for h in hyps]),
44
+ scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
45
+ states={k: [h.states[k] for h in hyps] for k in self.scorers},
46
+ )
47
+
48
+ def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
49
+ return BatchHypothesis(
50
+ yseq=hyps.yseq[ids],
51
+ score=hyps.score[ids],
52
+ length=hyps.length[ids],
53
+ scores={k: v[ids] for k, v in hyps.scores.items()},
54
+ states={
55
+ k: [self.scorers[k].select_state(v, i) for i in ids]
56
+ for k, v in hyps.states.items()
57
+ },
58
+ )
59
+
60
+ def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
61
+ return Hypothesis(
62
+ yseq=hyps.yseq[i, : hyps.length[i]],
63
+ score=hyps.score[i],
64
+ scores={k: v[i] for k, v in hyps.scores.items()},
65
+ states={
66
+ k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
67
+ },
68
+ )
69
+
70
+ def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
71
+ """Revert batch to list."""
72
+ return [
73
+ Hypothesis(
74
+ yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
75
+ score=batch_hyps.score[i],
76
+ scores={k: batch_hyps.scores[k][i] for k in self.scorers},
77
+ states={
78
+ k: v.select_state(batch_hyps.states[k], i)
79
+ for k, v in self.scorers.items()
80
+ },
81
+ )
82
+ for i in range(len(batch_hyps.length))
83
+ ]
84
+
85
+ def batch_beam(
86
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
87
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
88
+ """Batch-compute topk full token ids and partial token ids.
89
+
90
+ Args:
91
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
92
+ Its shape is `(n_beam, self.vocab_size)`.
93
+ ids (torch.Tensor): The partial token ids to compute topk.
94
+ Its shape is `(n_beam, self.pre_beam_size)`.
95
+
96
+ Returns:
97
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ The topk full (prev_hyp, new_token) ids
99
+ and partial (prev_hyp, new_token) ids.
100
+ Their shapes are all `(self.beam_size,)`
101
+
102
+ """
103
+ top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
104
+ # Because of the flatten above, `top_ids` is organized as:
105
+ # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
106
+ # where V is `self.n_vocab` and K is `self.beam_size`
107
+ prev_hyp_ids = top_ids // self.n_vocab
108
+ new_token_ids = top_ids % self.n_vocab
109
+ return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
110
+
111
+ def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
112
+ """Get an initial hypothesis data.
113
+
114
+ Args:
115
+ x (torch.Tensor): The encoder output feature
116
+
117
+ Returns:
118
+ Hypothesis: The initial hypothesis.
119
+
120
+ """
121
+ init_states = dict()
122
+ init_scores = dict()
123
+ for k, d in self.scorers.items():
124
+ init_states[k] = d.batch_init_state(x)
125
+ init_scores[k] = 0.0
126
+ return self.batchfy(
127
+ [
128
+ Hypothesis(
129
+ score=0.0,
130
+ scores=init_scores,
131
+ states=init_states,
132
+ yseq=torch.tensor([self.sos], device=x.device),
133
+ )
134
+ ]
135
+ )
136
+
137
+ def score_full(
138
+ self, hyp: BatchHypothesis, x: torch.Tensor
139
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
140
+ """Score new hypothesis by `self.full_scorers`.
141
+
142
+ Args:
143
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
144
+ x (torch.Tensor): Corresponding input feature
145
+
146
+ Returns:
147
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
148
+ score dict of `hyp` that has string keys of `self.full_scorers`
149
+ and tensor score values of shape: `(self.n_vocab,)`,
150
+ and state dict that has string keys
151
+ and state values of `self.full_scorers`
152
+
153
+ """
154
+ scores = dict()
155
+ states = dict()
156
+ for k, d in self.full_scorers.items():
157
+ scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
158
+ return scores, states
159
+
160
+ def score_partial(
161
+ self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
162
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
163
+ """Score new hypothesis by `self.full_scorers`.
164
+
165
+ Args:
166
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
167
+ ids (torch.Tensor): 2D tensor of new partial tokens to score
168
+ x (torch.Tensor): Corresponding input feature
169
+
170
+ Returns:
171
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
172
+ score dict of `hyp` that has string keys of `self.part_scorers`
173
+ and tensor score values of shape: `(self.n_vocab,)`,
174
+ and state dict that has string keys
175
+ and state values of `self.part_scorers`
176
+
177
+ """
178
+ scores = dict()
179
+ states = dict()
180
+ for k, d in self.part_scorers.items():
181
+ scores[k], states[k] = d.batch_score_partial(
182
+ hyp.yseq, ids, hyp.states[k], x
183
+ )
184
+ return scores, states
185
+
186
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
187
+ """Merge states for new hypothesis.
188
+
189
+ Args:
190
+ states: states of `self.full_scorers`
191
+ part_states: states of `self.part_scorers`
192
+ part_idx (int): The new token id for `part_scores`
193
+
194
+ Returns:
195
+ Dict[str, torch.Tensor]: The new score dict.
196
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
197
+ Its values are states of the scorers.
198
+
199
+ """
200
+ new_states = dict()
201
+ for k, v in states.items():
202
+ new_states[k] = v
203
+ for k, v in part_states.items():
204
+ new_states[k] = v
205
+ return new_states
206
+
207
+ def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
208
+ """Search new tokens for running hypotheses and encoded speech x.
209
+
210
+ Args:
211
+ running_hyps (BatchHypothesis): Running hypotheses on beam
212
+ x (torch.Tensor): Encoded speech feature (T, D)
213
+
214
+ Returns:
215
+ BatchHypothesis: Best sorted hypotheses
216
+
217
+ """
218
+ n_batch = len(running_hyps)
219
+ part_ids = None # no pre-beam
220
+ # batch scoring
221
+ weighted_scores = torch.zeros(
222
+ n_batch, self.n_vocab, dtype=x.dtype, device=x.device
223
+ )
224
+ scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
225
+ for k in self.full_scorers:
226
+ weighted_scores += self.weights[k] * scores[k]
227
+ # partial scoring
228
+ if self.do_pre_beam:
229
+ pre_beam_scores = (
230
+ weighted_scores
231
+ if self.pre_beam_score_key == "full"
232
+ else scores[self.pre_beam_score_key]
233
+ )
234
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
235
+ # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
236
+ # full-size score matrices, which has non-zero scores for part_ids and zeros
237
+ # for others.
238
+ part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
239
+ for k in self.part_scorers:
240
+ weighted_scores += self.weights[k] * part_scores[k]
241
+ # add previous hyp scores
242
+ weighted_scores += running_hyps.score.to(
243
+ dtype=x.dtype, device=x.device
244
+ ).unsqueeze(1)
245
+
246
+ # TODO(karita): do not use list. use batch instead
247
+ # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
248
+ # update hyps
249
+ best_hyps = []
250
+ prev_hyps = self.unbatchfy(running_hyps)
251
+ for (
252
+ full_prev_hyp_id,
253
+ full_new_token_id,
254
+ part_prev_hyp_id,
255
+ part_new_token_id,
256
+ ) in zip(*self.batch_beam(weighted_scores, part_ids)):
257
+ prev_hyp = prev_hyps[full_prev_hyp_id]
258
+ best_hyps.append(
259
+ Hypothesis(
260
+ score=weighted_scores[full_prev_hyp_id, full_new_token_id],
261
+ yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
262
+ scores=self.merge_scores(
263
+ prev_hyp.scores,
264
+ {k: v[full_prev_hyp_id] for k, v in scores.items()},
265
+ full_new_token_id,
266
+ {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
267
+ part_new_token_id,
268
+ ),
269
+ states=self.merge_states(
270
+ {
271
+ k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
272
+ for k, v in states.items()
273
+ },
274
+ {
275
+ k: self.part_scorers[k].select_state(
276
+ v, part_prev_hyp_id, part_new_token_id
277
+ )
278
+ for k, v in part_states.items()
279
+ },
280
+ part_new_token_id,
281
+ ),
282
+ )
283
+ )
284
+ return self.batchfy(best_hyps)
285
+
286
+ def post_process(
287
+ self,
288
+ i: int,
289
+ maxlen: int,
290
+ maxlenratio: float,
291
+ running_hyps: BatchHypothesis,
292
+ ended_hyps: List[Hypothesis],
293
+ ) -> BatchHypothesis:
294
+ """Perform post-processing of beam search iterations.
295
+
296
+ Args:
297
+ i (int): The length of hypothesis tokens.
298
+ maxlen (int): The maximum length of tokens in beam search.
299
+ maxlenratio (int): The maximum length ratio in beam search.
300
+ running_hyps (BatchHypothesis): The running hypotheses in beam search.
301
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
302
+
303
+ Returns:
304
+ BatchHypothesis: The new running hypotheses.
305
+
306
+ """
307
+ n_batch = running_hyps.yseq.shape[0]
308
+ logging.debug(f"the number of running hypothes: {n_batch}")
309
+ if self.token_list is not None:
310
+ logging.debug(
311
+ "best hypo: "
312
+ + "".join(
313
+ [
314
+ self.token_list[x]
315
+ for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
316
+ ]
317
+ )
318
+ )
319
+ # add eos in the final loop to avoid that there are no ended hyps
320
+ if i == maxlen - 1:
321
+ logging.info("adding <eos> in the last position in the loop")
322
+ yseq_eos = torch.cat(
323
+ (
324
+ running_hyps.yseq,
325
+ torch.full(
326
+ (n_batch, 1),
327
+ self.eos,
328
+ device=running_hyps.yseq.device,
329
+ dtype=torch.int64,
330
+ ),
331
+ ),
332
+ 1,
333
+ )
334
+ running_hyps.yseq.resize_as_(yseq_eos)
335
+ running_hyps.yseq[:] = yseq_eos
336
+ running_hyps.length[:] = yseq_eos.shape[1]
337
+
338
+ # add ended hypotheses to a final list, and remove them from the current hypotheses
339
+ # (this will be a problem, number of hyps < beam)
340
+ is_eos = (
341
+ running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
342
+ == self.eos
343
+ )
344
+ for b in torch.nonzero(is_eos).view(-1):
345
+ hyp = self._select(running_hyps, b)
346
+ ended_hyps.append(hyp)
347
+ remained_ids = torch.nonzero(is_eos == 0).view(-1)
348
+ return self._batch_select(running_hyps, remained_ids)
espnet/nets/batch_beam_search_online_sim.py ADDED
@@ -0,0 +1,270 @@
1
+ """Parallel beam search module for online simulation."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ import yaml
8
+
9
+ import torch
10
+
11
+ from espnet.nets.batch_beam_search import BatchBeamSearch
12
+ from espnet.nets.beam_search import Hypothesis
13
+ from espnet.nets.e2e_asr_common import end_detect
14
+
15
+
16
+ class BatchBeamSearchOnlineSim(BatchBeamSearch):
17
+ """Online beam search implementation.
18
+
19
+ This simulates streaming decoding.
20
+ It requires encoded features of the entire utterance and
21
+ extracts block by block from it as it should be done
22
+ in streaming processing.
23
+ This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
24
+ WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
25
+ (https://arxiv.org/abs/2006.14941).
26
+ """
27
+
28
+ def set_streaming_config(self, asr_config: str):
29
+ """Set config file for streaming decoding.
30
+
31
+ Args:
32
+ asr_config (str): The config file for asr training
33
+
34
+ """
35
+ train_config_file = Path(asr_config)
36
+ self.block_size = None
37
+ self.hop_size = None
38
+ self.look_ahead = None
39
+ config = None
40
+ with train_config_file.open("r", encoding="utf-8") as f:
41
+ args = yaml.safe_load(f)
42
+ if "encoder_conf" in args.keys():
43
+ if "block_size" in args["encoder_conf"].keys():
44
+ self.block_size = args["encoder_conf"]["block_size"]
45
+ if "hop_size" in args["encoder_conf"].keys():
46
+ self.hop_size = args["encoder_conf"]["hop_size"]
47
+ if "look_ahead" in args["encoder_conf"].keys():
48
+ self.look_ahead = args["encoder_conf"]["look_ahead"]
49
+ elif "config" in args.keys():
50
+ config = args["config"]
51
+ if config is None:
52
+ logging.info(
53
+ "Cannot find config file for streaming decoding: "
54
+ + "apply batch beam search instead."
55
+ )
56
+ return
57
+ if (
58
+ self.block_size is None or self.hop_size is None or self.look_ahead is None
59
+ ) and config is not None:
60
+ config_file = Path(config)
61
+ with config_file.open("r", encoding="utf-8") as f:
62
+ args = yaml.safe_load(f)
63
+ if "encoder_conf" in args.keys():
64
+ enc_args = args["encoder_conf"]
65
+ if enc_args and "block_size" in enc_args:
66
+ self.block_size = enc_args["block_size"]
67
+ if enc_args and "hop_size" in enc_args:
68
+ self.hop_size = enc_args["hop_size"]
69
+ if enc_args and "look_ahead" in enc_args:
70
+ self.look_ahead = enc_args["look_ahead"]
71
+
72
+ def set_block_size(self, block_size: int):
73
+ """Set block size for streaming decoding.
74
+
75
+ Args:
76
+ block_size (int): The block size of encoder
77
+ """
78
+ self.block_size = block_size
79
+
80
+ def set_hop_size(self, hop_size: int):
81
+ """Set hop size for streaming decoding.
82
+
83
+ Args:
84
+ hop_size (int): The hop size of encoder
85
+ """
86
+ self.hop_size = hop_size
87
+
88
+ def set_look_ahead(self, look_ahead: int):
89
+ """Set look ahead size for streaming decoding.
90
+
91
+ Args:
92
+ look_ahead (int): The look ahead size of encoder
93
+ """
94
+ self.look_ahead = look_ahead
95
+
96
+ def forward(
97
+ self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
98
+ ) -> List[Hypothesis]:
99
+ """Perform beam search.
100
+
101
+ Args:
102
+ x (torch.Tensor): Encoded speech feature (T, D)
103
+ maxlenratio (float): Input length ratio to obtain max output length.
104
+ If maxlenratio=0.0 (default), it uses an end-detect function
105
+ to automatically find maximum hypothesis lengths
106
+ minlenratio (float): Input length ratio to obtain min output length.
107
+
108
+ Returns:
109
+ list[Hypothesis]: N-best decoding results
110
+
111
+ """
112
+ self.conservative = True # always true
113
+
114
+ if self.block_size and self.hop_size and self.look_ahead:
115
+ cur_end_frame = int(self.block_size - self.look_ahead)
116
+ else:
117
+ cur_end_frame = x.shape[0]
118
+ process_idx = 0
119
+ if cur_end_frame < x.shape[0]:
120
+ h = x.narrow(0, 0, cur_end_frame)
121
+ else:
122
+ h = x
123
+
124
+ # set length bounds
125
+ if maxlenratio == 0:
126
+ maxlen = x.shape[0]
127
+ else:
128
+ maxlen = max(1, int(maxlenratio * x.size(0)))
129
+ minlen = int(minlenratio * x.size(0))
130
+ logging.info("decoder input length: " + str(x.shape[0]))
131
+ logging.info("max output length: " + str(maxlen))
132
+ logging.info("min output length: " + str(minlen))
133
+
134
+ # main loop of prefix search
135
+ running_hyps = self.init_hyp(h)
136
+ prev_hyps = []
137
+ ended_hyps = []
138
+ prev_repeat = False
139
+
140
+ continue_decode = True
141
+
142
+ while continue_decode:
143
+ move_to_next_block = False
144
+ if cur_end_frame < x.shape[0]:
145
+ h = x.narrow(0, 0, cur_end_frame)
146
+ else:
147
+ h = x
148
+
149
+ # extend states for ctc
150
+ self.extend(h, running_hyps)
151
+
152
+ while process_idx < maxlen:
153
+ logging.debug("position " + str(process_idx))
154
+ best = self.search(running_hyps, h)
155
+
156
+ if process_idx == maxlen - 1:
157
+ # end decoding
158
+ running_hyps = self.post_process(
159
+ process_idx, maxlen, maxlenratio, best, ended_hyps
160
+ )
161
+ n_batch = best.yseq.shape[0]
162
+ local_ended_hyps = []
163
+ is_local_eos = (
164
+ best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
165
+ )
166
+ for i in range(is_local_eos.shape[0]):
167
+ if is_local_eos[i]:
168
+ hyp = self._select(best, i)
169
+ local_ended_hyps.append(hyp)
170
+ # NOTE(tsunoo): check repetitions here
171
+ # This is an implicit implementation of
172
+ # Eq (11) in https://arxiv.org/abs/2006.14941
173
+ # A flag prev_repeat is used instead of using set
174
+ elif (
175
+ not prev_repeat
176
+ and best.yseq[i, -1] in best.yseq[i, :-1]
177
+ and cur_end_frame < x.shape[0]
178
+ ):
179
+ move_to_next_block = True
180
+ prev_repeat = True
181
+ if maxlenratio == 0.0 and end_detect(
182
+ [lh.asdict() for lh in local_ended_hyps], process_idx
183
+ ):
184
+ logging.info(f"end detected at {process_idx}")
185
+ continue_decode = False
186
+ break
187
+ if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
188
+ move_to_next_block = True
189
+
190
+ if move_to_next_block:
191
+ if (
192
+ self.hop_size
193
+ and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
194
+ < x.shape[0]
195
+ ):
196
+ cur_end_frame += int(self.hop_size)
197
+ else:
198
+ cur_end_frame = x.shape[0]
199
+ logging.debug("Going to next block: %d", cur_end_frame)
200
+ if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
201
+ running_hyps = prev_hyps
202
+ process_idx -= 1
203
+ prev_hyps = []
204
+ break
205
+
206
+ prev_repeat = False
207
+ prev_hyps = running_hyps
208
+ running_hyps = self.post_process(
209
+ process_idx, maxlen, maxlenratio, best, ended_hyps
210
+ )
211
+
212
+ if cur_end_frame >= x.shape[0]:
213
+ for hyp in local_ended_hyps:
214
+ ended_hyps.append(hyp)
215
+
216
+ if len(running_hyps) == 0:
217
+ logging.info("no hypothesis. Finish decoding.")
218
+ continue_decode = False
219
+ break
220
+ else:
221
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
222
+ # increment number
223
+ process_idx += 1
224
+
225
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
226
+ # check the number of hypotheses reaching to eos
227
+ if len(nbest_hyps) == 0:
228
+ logging.warning(
229
+ "there is no N-best results, perform recognition "
230
+ "again with smaller minlenratio."
231
+ )
232
+ return (
233
+ []
234
+ if minlenratio < 0.1
235
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
236
+ )
237
+
238
+ # report the best result
239
+ best = nbest_hyps[0]
240
+ for k, v in best.scores.items():
241
+ logging.info(
242
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
243
+ )
244
+ logging.info(f"total log probability: {best.score:.2f}")
245
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
246
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
247
+ if self.token_list is not None:
248
+ logging.info(
249
+ "best hypo: "
250
+ + "".join([self.token_list[x] for x in best.yseq[1:-1]])
251
+ + "\n"
252
+ )
253
+ return nbest_hyps
254
+
255
+ def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
256
+ """Extend probabilities and states with more encoded chunks.
257
+
258
+ Args:
259
+ x (torch.Tensor): The extended encoder output feature
260
+ hyps (Hypothesis): Current list of hypotheses
261
+
262
+ Returns:
263
+ Hypothesis: The extended hypothesis
264
+
265
+ """
266
+ for k, d in self.scorers.items():
267
+ if hasattr(d, "extend_prob"):
268
+ d.extend_prob(x)
269
+ if hasattr(d, "extend_state"):
270
+ hyps.states[k] = d.extend_state(hyps.states[k])
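A minimal sketch, assuming `scorers`, `weights`, `char_list`, and an encoded feature `enc_out` are built elsewhere, of how the streaming simulation above can be driven; the block/hop/look-ahead sizes are made-up values.

# Hypothetical setup of the online-simulation search defined above.
search = BatchBeamSearchOnlineSim(
    scorers=scorers,
    weights=weights,
    beam_size=10,
    vocab_size=len(char_list),
    sos=len(char_list) - 1,
    eos=len(char_list) - 1,
    token_list=char_list,
)
search.set_block_size(40)
search.set_hop_size(16)
search.set_look_ahead(16)
nbest = search(enc_out)  # enc_out: encoded speech feature (T, D)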
espnet/nets/beam_search.py ADDED
@@ -0,0 +1,512 @@
1
+ """Beam search module."""
2
+
3
+ from itertools import chain
4
+ import logging
5
+ from typing import Any
6
+ from typing import Dict
7
+ from typing import List
8
+ from typing import NamedTuple
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+
14
+ from espnet.nets.e2e_asr_common import end_detect
15
+ from espnet.nets.scorer_interface import PartialScorerInterface
16
+ from espnet.nets.scorer_interface import ScorerInterface
17
+
18
+
19
+ class Hypothesis(NamedTuple):
20
+ """Hypothesis data type."""
21
+
22
+ yseq: torch.Tensor
23
+ score: Union[float, torch.Tensor] = 0
24
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
25
+ states: Dict[str, Any] = dict()
26
+
27
+ def asdict(self) -> dict:
28
+ """Convert data to JSON-friendly dict."""
29
+ return self._replace(
30
+ yseq=self.yseq.tolist(),
31
+ score=float(self.score),
32
+ scores={k: float(v) for k, v in self.scores.items()},
33
+ )._asdict()
34
+
35
+
36
+ class BeamSearch(torch.nn.Module):
37
+ """Beam search implementation."""
38
+
39
+ def __init__(
40
+ self,
41
+ scorers: Dict[str, ScorerInterface],
42
+ weights: Dict[str, float],
43
+ beam_size: int,
44
+ vocab_size: int,
45
+ sos: int,
46
+ eos: int,
47
+ token_list: List[str] = None,
48
+ pre_beam_ratio: float = 1.5,
49
+ pre_beam_score_key: str = None,
50
+ ):
51
+ """Initialize beam search.
52
+
53
+ Args:
54
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
55
+ e.g., Decoder, CTCPrefixScorer, LM
56
+ The scorer will be ignored if it is `None`
57
+ weights (dict[str, float]): Dict of weights for each scorer
58
+ The scorer will be ignored if its weight is 0
59
+ beam_size (int): The number of hypotheses kept during search
60
+ vocab_size (int): The size of the vocabulary
61
+ sos (int): Start of sequence id
62
+ eos (int): End of sequence id
63
+ token_list (list[str]): List of tokens for debug log
64
+ pre_beam_score_key (str): key of scores to perform pre-beam search
65
+ pre_beam_ratio (float): beam size in the pre-beam search
66
+ will be `int(pre_beam_ratio * beam_size)`
67
+
68
+ """
69
+ super().__init__()
70
+ # set scorers
71
+ self.weights = weights
72
+ self.scorers = dict()
73
+ self.full_scorers = dict()
74
+ self.part_scorers = dict()
75
+ # this module dict is required for recursive cast
76
+ # `self.to(device, dtype)` in `recog.py`
77
+ self.nn_dict = torch.nn.ModuleDict()
78
+ for k, v in scorers.items():
79
+ w = weights.get(k, 0)
80
+ if w == 0 or v is None:
81
+ continue
82
+ assert isinstance(
83
+ v, ScorerInterface
84
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
85
+ self.scorers[k] = v
86
+ if isinstance(v, PartialScorerInterface):
87
+ self.part_scorers[k] = v
88
+ else:
89
+ self.full_scorers[k] = v
90
+ if isinstance(v, torch.nn.Module):
91
+ self.nn_dict[k] = v
92
+
93
+ # set configurations
94
+ self.sos = sos
95
+ self.eos = eos
96
+ self.token_list = token_list
97
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
98
+ self.beam_size = beam_size
99
+ self.n_vocab = vocab_size
100
+ if (
101
+ pre_beam_score_key is not None
102
+ and pre_beam_score_key != "full"
103
+ and pre_beam_score_key not in self.full_scorers
104
+ ):
105
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
106
+ self.pre_beam_score_key = pre_beam_score_key
107
+ self.do_pre_beam = (
108
+ self.pre_beam_score_key is not None
109
+ and self.pre_beam_size < self.n_vocab
110
+ and len(self.part_scorers) > 0
111
+ )
112
+
113
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
114
+ """Get an initial hypothesis data.
115
+
116
+ Args:
117
+ x (torch.Tensor): The encoder output feature
118
+
119
+ Returns:
120
+ Hypothesis: The initial hypothesis.
121
+
122
+ """
123
+ init_states = dict()
124
+ init_scores = dict()
125
+ for k, d in self.scorers.items():
126
+ init_states[k] = d.init_state(x)
127
+ init_scores[k] = 0.0
128
+ return [
129
+ Hypothesis(
130
+ score=0.0,
131
+ scores=init_scores,
132
+ states=init_states,
133
+ yseq=torch.tensor([self.sos], device=x.device),
134
+ )
135
+ ]
136
+
137
+ @staticmethod
138
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
139
+ """Append new token to prefix tokens.
140
+
141
+ Args:
142
+ xs (torch.Tensor): The prefix token
143
+ x (int): The new token to append
144
+
145
+ Returns:
146
+ torch.Tensor: New tensor containing xs + [x], with xs.dtype and xs.device
147
+
148
+ """
149
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
150
+ return torch.cat((xs, x))
151
+
152
+ def score_full(
153
+ self, hyp: Hypothesis, x: torch.Tensor
154
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
155
+ """Score new hypothesis by `self.full_scorers`.
156
+
157
+ Args:
158
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
159
+ x (torch.Tensor): Corresponding input feature
160
+
161
+ Returns:
162
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
163
+ score dict of `hyp` that has string keys of `self.full_scorers`
164
+ and tensor score values of shape: `(self.n_vocab,)`,
165
+ and state dict that has string keys
166
+ and state values of `self.full_scorers`
167
+
168
+ """
169
+ scores = dict()
170
+ states = dict()
171
+ for k, d in self.full_scorers.items():
172
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
173
+ return scores, states
174
+
175
+ def score_partial(
176
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
177
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
178
+ """Score new hypothesis by `self.part_scorers`.
179
+
180
+ Args:
181
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
182
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
183
+ x (torch.Tensor): Corresponding input feature
184
+
185
+ Returns:
186
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
187
+ score dict of `hyp` that has string keys of `self.part_scorers`
188
+ and tensor score values of shape: `(len(ids),)`,
189
+ and state dict that has string keys
190
+ and state values of `self.part_scorers`
191
+
192
+ """
193
+ scores = dict()
194
+ states = dict()
195
+ for k, d in self.part_scorers.items():
196
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
197
+ return scores, states
198
+
199
+ def beam(
200
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
201
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
202
+ """Compute topk full token ids and partial token ids.
203
+
204
+ Args:
205
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
206
+ Its shape is `(self.n_vocab,)`.
207
+ ids (torch.Tensor): The partial token ids to compute topk
208
+
209
+ Returns:
210
+ Tuple[torch.Tensor, torch.Tensor]:
211
+ The topk full token ids and partial token ids.
212
+ Their shapes are `(self.beam_size,)`
213
+
214
+ """
215
+ # no pre beam performed
216
+ if weighted_scores.size(0) == ids.size(0):
217
+ top_ids = weighted_scores.topk(self.beam_size)[1]
218
+ return top_ids, top_ids
219
+
220
+ # mask pruned in pre-beam not to select in topk
221
+ tmp = weighted_scores[ids]
222
+ weighted_scores[:] = -float("inf")
223
+ weighted_scores[ids] = tmp
224
+ top_ids = weighted_scores.topk(self.beam_size)[1]
225
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
226
+ return top_ids, local_ids
227
+
228
+ @staticmethod
229
+ def merge_scores(
230
+ prev_scores: Dict[str, float],
231
+ next_full_scores: Dict[str, torch.Tensor],
232
+ full_idx: int,
233
+ next_part_scores: Dict[str, torch.Tensor],
234
+ part_idx: int,
235
+ ) -> Dict[str, torch.Tensor]:
236
+ """Merge scores for new hypothesis.
237
+
238
+ Args:
239
+ prev_scores (Dict[str, float]):
240
+ The previous hypothesis scores by `self.scorers`
241
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
242
+ full_idx (int): The next token id for `next_full_scores`
243
+ next_part_scores (Dict[str, torch.Tensor]):
244
+ scores of partial tokens by `self.part_scorers`
245
+ part_idx (int): The new token id for `next_part_scores`
246
+
247
+ Returns:
248
+ Dict[str, torch.Tensor]: The new score dict.
249
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
250
+ Its values are scalar tensors by the scorers.
251
+
252
+ """
253
+ new_scores = dict()
254
+ for k, v in next_full_scores.items():
255
+ new_scores[k] = prev_scores[k] + v[full_idx]
256
+ for k, v in next_part_scores.items():
257
+ new_scores[k] = prev_scores[k] + v[part_idx]
258
+ return new_scores
259
+
260
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
261
+ """Merge states for new hypothesis.
262
+
263
+ Args:
264
+ states: states of `self.full_scorers`
265
+ part_states: states of `self.part_scorers`
266
+ part_idx (int): The new token id for `part_scores`
267
+
268
+ Returns:
269
+ Dict[str, torch.Tensor]: The new score dict.
270
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
271
+ Its values are states of the scorers.
272
+
273
+ """
274
+ new_states = dict()
275
+ for k, v in states.items():
276
+ new_states[k] = v
277
+ for k, d in self.part_scorers.items():
278
+ new_states[k] = d.select_state(part_states[k], part_idx)
279
+ return new_states
280
+
281
+ def search(
282
+ self, running_hyps: List[Hypothesis], x: torch.Tensor
283
+ ) -> List[Hypothesis]:
284
+ """Search new tokens for running hypotheses and encoded speech x.
285
+
286
+ Args:
287
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
288
+ x (torch.Tensor): Encoded speech feature (T, D)
289
+
290
+ Returns:
291
+ List[Hypothesis]: Best sorted hypotheses
292
+
293
+ """
294
+ best_hyps = []
295
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
296
+ for hyp in running_hyps:
297
+ # scoring
298
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
299
+ scores, states = self.score_full(hyp, x)
300
+ for k in self.full_scorers:
301
+ weighted_scores += self.weights[k] * scores[k]
302
+ # partial scoring
303
+ if self.do_pre_beam:
304
+ pre_beam_scores = (
305
+ weighted_scores
306
+ if self.pre_beam_score_key == "full"
307
+ else scores[self.pre_beam_score_key]
308
+ )
309
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
310
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
311
+ for k in self.part_scorers:
312
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
313
+ # add previous hyp score
314
+ weighted_scores += hyp.score
315
+
316
+ # update hyps
317
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
318
+ # will be (2 x beam at most)
319
+ best_hyps.append(
320
+ Hypothesis(
321
+ score=weighted_scores[j],
322
+ yseq=self.append_token(hyp.yseq, j),
323
+ scores=self.merge_scores(
324
+ hyp.scores, scores, j, part_scores, part_j
325
+ ),
326
+ states=self.merge_states(states, part_states, part_j),
327
+ )
328
+ )
329
+
330
+ # sort and prune 2 x beam -> beam
331
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
332
+ : min(len(best_hyps), self.beam_size)
333
+ ]
334
+ return best_hyps
335
+
336
+ def forward(
337
+ self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
338
+ ) -> List[Hypothesis]:
339
+ """Perform beam search.
340
+
341
+ Args:
342
+ x (torch.Tensor): Encoded speech feature (T, D)
343
+ maxlenratio (float): Input length ratio to obtain max output length.
344
+ If maxlenratio=0.0 (default), it uses an end-detect function
345
+ to automatically find maximum hypothesis lengths
346
+ minlenratio (float): Input length ratio to obtain min output length.
347
+
348
+ Returns:
349
+ list[Hypothesis]: N-best decoding results
350
+
351
+ """
352
+ # set length bounds
353
+ if maxlenratio == 0:
354
+ maxlen = x.shape[0]
355
+ else:
356
+ maxlen = max(1, int(maxlenratio * x.size(0)))
357
+ minlen = int(minlenratio * x.size(0))
358
+ logging.info("decoder input length: " + str(x.shape[0]))
359
+ logging.info("max output length: " + str(maxlen))
360
+ logging.info("min output length: " + str(minlen))
361
+
362
+ # main loop of prefix search
363
+ running_hyps = self.init_hyp(x)
364
+ ended_hyps = []
365
+ for i in range(maxlen):
366
+ logging.debug("position " + str(i))
367
+ best = self.search(running_hyps, x)
368
+ # post process of one iteration
369
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
370
+ # end detection
371
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
372
+ logging.info(f"end detected at {i}")
373
+ break
374
+ if len(running_hyps) == 0:
375
+ logging.info("no hypothesis. Finish decoding.")
376
+ break
377
+ else:
378
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
379
+
380
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
381
+ # check the number of hypotheses reaching to eos
382
+ if len(nbest_hyps) == 0:
383
+ logging.warning(
384
+ "there is no N-best results, perform recognition "
385
+ "again with smaller minlenratio."
386
+ )
387
+ return (
388
+ []
389
+ if minlenratio < 0.1
390
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
391
+ )
392
+
393
+ # report the best result
394
+ best = nbest_hyps[0]
395
+ for k, v in best.scores.items():
396
+ logging.info(
397
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
398
+ )
399
+ logging.info(f"total log probability: {best.score:.2f}")
400
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
401
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
402
+ if self.token_list is not None:
403
+ logging.info(
404
+ "best hypo: "
405
+ + "".join([self.token_list[x] for x in best.yseq[1:-1]])
406
+ + "\n"
407
+ )
408
+ return nbest_hyps
409
+
410
+ def post_process(
411
+ self,
412
+ i: int,
413
+ maxlen: int,
414
+ maxlenratio: float,
415
+ running_hyps: List[Hypothesis],
416
+ ended_hyps: List[Hypothesis],
417
+ ) -> List[Hypothesis]:
418
+ """Perform post-processing of beam search iterations.
419
+
420
+ Args:
421
+ i (int): The length of hypothesis tokens.
422
+ maxlen (int): The maximum length of tokens in beam search.
423
+ maxlenratio (int): The maximum length ratio in beam search.
424
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
425
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
426
+
427
+ Returns:
428
+ List[Hypothesis]: The new running hypotheses.
429
+
430
+ """
431
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
432
+ if self.token_list is not None:
433
+ logging.debug(
434
+ "best hypo: "
435
+ + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
436
+ )
437
+ # add eos in the final loop to avoid that there are no ended hyps
438
+ if i == maxlen - 1:
439
+ logging.info("adding <eos> in the last position in the loop")
440
+ running_hyps = [
441
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
442
+ for h in running_hyps
443
+ ]
444
+
445
+ # add ended hypotheses to a final list, and remove them from the current hypotheses
446
+ # (this will be a problem, number of hyps < beam)
447
+ remained_hyps = []
448
+ for hyp in running_hyps:
449
+ if hyp.yseq[-1] == self.eos:
450
+ # e.g., Word LM needs to add final <eos> score
451
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
452
+ s = d.final_score(hyp.states[k])
453
+ hyp.scores[k] += s
454
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
455
+ ended_hyps.append(hyp)
456
+ else:
457
+ remained_hyps.append(hyp)
458
+ return remained_hyps
459
+
460
+
461
+ def beam_search(
462
+ x: torch.Tensor,
463
+ sos: int,
464
+ eos: int,
465
+ beam_size: int,
466
+ vocab_size: int,
467
+ scorers: Dict[str, ScorerInterface],
468
+ weights: Dict[str, float],
469
+ token_list: List[str] = None,
470
+ maxlenratio: float = 0.0,
471
+ minlenratio: float = 0.0,
472
+ pre_beam_ratio: float = 1.5,
473
+ pre_beam_score_key: str = "full",
474
+ ) -> list:
475
+ """Perform beam search with scorers.
476
+
477
+ Args:
478
+ x (torch.Tensor): Encoded speech feature (T, D)
479
+ sos (int): Start of sequence id
480
+ eos (int): End of sequence id
481
+ beam_size (int): The number of hypotheses kept during search
482
+ vocab_size (int): The size of the vocabulary
483
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
484
+ e.g., Decoder, CTCPrefixScorer, LM
485
+ The scorer will be ignored if it is `None`
486
+ weights (dict[str, float]): Dict of weights for each scorer
487
+ The scorer will be ignored if its weight is 0
488
+ token_list (list[str]): List of tokens for debug log
489
+ maxlenratio (float): Input length ratio to obtain max output length.
490
+ If maxlenratio=0.0 (default), it uses an end-detect function
491
+ to automatically find maximum hypothesis lengths
492
+ minlenratio (float): Input length ratio to obtain min output length.
493
+ pre_beam_score_key (str): key of scores to perform pre-beam search
494
+ pre_beam_ratio (float): beam size in the pre-beam search
495
+ will be `int(pre_beam_ratio * beam_size)`
496
+
497
+ Returns:
498
+ list: N-best decoding results
499
+
500
+ """
501
+ ret = BeamSearch(
502
+ scorers,
503
+ weights,
504
+ beam_size=beam_size,
505
+ vocab_size=vocab_size,
506
+ pre_beam_ratio=pre_beam_ratio,
507
+ pre_beam_score_key=pre_beam_score_key,
508
+ sos=sos,
509
+ eos=eos,
510
+ token_list=token_list,
511
+ ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
512
+ return [h.asdict() for h in ret]
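A hedged example of calling the functional `beam_search` wrapper defined above; `decoder`, `ctc_scorer`, `enc_out`, and `char_list` are assumed to be created elsewhere, and the weights are arbitrary.

# Hypothetical call; returns a list of Hypothesis.asdict() results sorted by score.
nbest = beam_search(
    x=enc_out,  # encoded speech feature (T, D)
    sos=len(char_list) - 1,
    eos=len(char_list) - 1,
    beam_size=10,
    vocab_size=len(char_list),
    scorers={"decoder": decoder, "ctc": ctc_scorer},
    weights={"decoder": 0.7, "ctc": 0.3},
    token_list=char_list,
)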
espnet/nets/beam_search_transducer.py ADDED
@@ -0,0 +1,629 @@
1
+ """Search algorithms for transducer models."""
2
+
3
+ from typing import List
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state
10
+ from espnet.nets.pytorch_backend.transducer.utils import init_lm_state
11
+ from espnet.nets.pytorch_backend.transducer.utils import is_prefix
12
+ from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps
13
+ from espnet.nets.pytorch_backend.transducer.utils import select_lm_state
14
+ from espnet.nets.pytorch_backend.transducer.utils import substract
15
+ from espnet.nets.transducer_decoder_interface import Hypothesis
16
+ from espnet.nets.transducer_decoder_interface import NSCHypothesis
17
+ from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
18
+
19
+
20
+ class BeamSearchTransducer:
21
+ """Beam search implementation for transducer."""
22
+
23
+ def __init__(
24
+ self,
25
+ decoder: Union[TransducerDecoderInterface, torch.nn.Module],
26
+ joint_network: torch.nn.Module,
27
+ beam_size: int,
28
+ lm: torch.nn.Module = None,
29
+ lm_weight: float = 0.1,
30
+ search_type: str = "default",
31
+ max_sym_exp: int = 2,
32
+ u_max: int = 50,
33
+ nstep: int = 1,
34
+ prefix_alpha: int = 1,
35
+ score_norm: bool = True,
36
+ nbest: int = 1,
37
+ ):
38
+ """Initialize transducer beam search.
39
+
40
+ Args:
41
+ decoder: Decoder class to use
42
+ joint_network: Joint Network class
43
+ beam_size: Number of hypotheses kept during search
44
+ lm: LM class to use
45
+ lm_weight: lm weight for soft fusion
46
+ search_type: type of algorithm to use for search
47
+ max_sym_exp: number of maximum symbol expansions at each time step ("tsd")
48
+ u_max: maximum output sequence length ("alsd")
49
+ nstep: number of maximum expansion steps at each time step ("nsc")
50
+ prefix_alpha: maximum prefix length in prefix search ("nsc")
51
+ score_norm: normalize final scores by length ("default")
52
+ nbest: number of returned final hypothesis
53
+ """
54
+ self.decoder = decoder
55
+ self.joint_network = joint_network
56
+
57
+ self.beam_size = beam_size
58
+ self.hidden_size = decoder.dunits
59
+ self.vocab_size = decoder.odim
60
+ self.blank = decoder.blank
61
+
62
+ if self.beam_size <= 1:
63
+ self.search_algorithm = self.greedy_search
64
+ elif search_type == "default":
65
+ self.search_algorithm = self.default_beam_search
66
+ elif search_type == "tsd":
67
+ self.search_algorithm = self.time_sync_decoding
68
+ elif search_type == "alsd":
69
+ self.search_algorithm = self.align_length_sync_decoding
70
+ elif search_type == "nsc":
71
+ self.search_algorithm = self.nsc_beam_search
72
+ else:
73
+ raise NotImplementedError
74
+
75
+ self.lm = lm
76
+ self.lm_weight = lm_weight
77
+
78
+ if lm is not None:
79
+ self.use_lm = True
80
+ self.is_wordlm = hasattr(lm.predictor, "wordlm")
81
+ self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
82
+ self.lm_layers = len(self.lm_predictor.rnn)
83
+ else:
84
+ self.use_lm = False
85
+
86
+ self.max_sym_exp = max_sym_exp
87
+ self.u_max = u_max
88
+ self.nstep = nstep
89
+ self.prefix_alpha = prefix_alpha
90
+ self.score_norm = score_norm
91
+
92
+ self.nbest = nbest
93
+
94
+ def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]:
95
+ """Perform beam search.
96
+
97
+ Args:
98
+ h: Encoded speech features (T_max, D_enc)
99
+
100
+ Returns:
101
+ nbest_hyps: N-best decoding results
102
+
103
+ """
104
+ self.decoder.set_device(h.device)
105
+
106
+ if not hasattr(self.decoder, "decoders"):
107
+ self.decoder.set_data_type(h.dtype)
108
+
109
+ nbest_hyps = self.search_algorithm(h)
110
+
111
+ return nbest_hyps
112
+
113
+ def sort_nbest(
114
+ self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
115
+ ) -> Union[List[Hypothesis], List[NSCHypothesis]]:
116
+ """Sort hypotheses by score or score given sequence length.
117
+
118
+ Args:
119
+ hyps: list of hypotheses
120
+
121
+ Return:
122
+ hyps: sorted list of hypotheses
123
+
124
+ """
125
+ if self.score_norm:
126
+ hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
127
+ else:
128
+ hyps.sort(key=lambda x: x.score, reverse=True)
129
+
130
+ return hyps[: self.nbest]
131
+
132
+ def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
133
+ """Greedy search implementation for transformer-transducer.
134
+
135
+ Args:
136
+ h: Encoded speech features (T_max, D_enc)
137
+
138
+ Returns:
139
+ hyp: 1-best decoding results
140
+
141
+ """
142
+ dec_state = self.decoder.init_state(1)
143
+
144
+ hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
145
+ cache = {}
146
+
147
+ y, state, _ = self.decoder.score(hyp, cache)
148
+
149
+ for i, hi in enumerate(h):
150
+ ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
151
+ logp, pred = torch.max(ytu, dim=-1)
152
+
153
+ if pred != self.blank:
154
+ hyp.yseq.append(int(pred))
155
+ hyp.score += float(logp)
156
+
157
+ hyp.dec_state = state
158
+
159
+ y, state, _ = self.decoder.score(hyp, cache)
160
+
161
+ return [hyp]
162
+
163
+ def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
164
+ """Beam search implementation.
165
+
166
+ Args:
167
+ h: Encoded speech features (T_max, D_enc)
168
+
169
+ Returns:
170
+ nbest_hyps: N-best decoding results
171
+
172
+ """
173
+ beam = min(self.beam_size, self.vocab_size)
174
+ beam_k = min(beam, (self.vocab_size - 1))
175
+
176
+ dec_state = self.decoder.init_state(1)
177
+
178
+ kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
179
+ cache = {}
180
+
181
+ for hi in h:
182
+ hyps = kept_hyps
183
+ kept_hyps = []
184
+
185
+ while True:
186
+ max_hyp = max(hyps, key=lambda x: x.score)
187
+ hyps.remove(max_hyp)
188
+
189
+ y, state, lm_tokens = self.decoder.score(max_hyp, cache)
190
+
191
+ ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
192
+ top_k = ytu[1:].topk(beam_k, dim=-1)
193
+
194
+ kept_hyps.append(
195
+ Hypothesis(
196
+ score=(max_hyp.score + float(ytu[0:1])),
197
+ yseq=max_hyp.yseq[:],
198
+ dec_state=max_hyp.dec_state,
199
+ lm_state=max_hyp.lm_state,
200
+ )
201
+ )
202
+
203
+ if self.use_lm:
204
+ lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
205
+ else:
206
+ lm_state = max_hyp.lm_state
207
+
208
+ for logp, k in zip(*top_k):
209
+ score = max_hyp.score + float(logp)
210
+
211
+ if self.use_lm:
212
+ score += self.lm_weight * lm_scores[0][k + 1]
213
+
214
+ hyps.append(
215
+ Hypothesis(
216
+ score=score,
217
+ yseq=max_hyp.yseq[:] + [int(k + 1)],
218
+ dec_state=state,
219
+ lm_state=lm_state,
220
+ )
221
+ )
222
+
223
+ hyps_max = float(max(hyps, key=lambda x: x.score).score)
224
+ kept_most_prob = sorted(
225
+ [hyp for hyp in kept_hyps if hyp.score > hyps_max],
226
+ key=lambda x: x.score,
227
+ )
228
+ if len(kept_most_prob) >= beam:
229
+ kept_hyps = kept_most_prob
230
+ break
231
+
232
+ return self.sort_nbest(kept_hyps)
233
+
234
+ def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
235
+ """Time synchronous beam search implementation.
236
+
237
+ Based on https://ieeexplore.ieee.org/document/9053040
238
+
239
+ Args:
240
+ h: Encoded speech features (T_max, D_enc)
241
+
242
+ Returns:
243
+ nbest_hyps: N-best decoding results
244
+
245
+ """
246
+ beam = min(self.beam_size, self.vocab_size)
247
+
248
+ beam_state = self.decoder.init_state(beam)
249
+
250
+ B = [
251
+ Hypothesis(
252
+ yseq=[self.blank],
253
+ score=0.0,
254
+ dec_state=self.decoder.select_state(beam_state, 0),
255
+ )
256
+ ]
257
+ cache = {}
258
+
259
+ if self.use_lm and not self.is_wordlm:
260
+ B[0].lm_state = init_lm_state(self.lm_predictor)
261
+
262
+ for hi in h:
263
+ A = []
264
+ C = B
265
+
266
+ h_enc = hi.unsqueeze(0)
267
+
268
+ for v in range(self.max_sym_exp):
269
+ D = []
270
+
271
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
272
+ C,
273
+ beam_state,
274
+ cache,
275
+ self.use_lm,
276
+ )
277
+
278
+ beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
279
+ beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
280
+
281
+ seq_A = [h.yseq for h in A]
282
+
283
+ for i, hyp in enumerate(C):
284
+ if hyp.yseq not in seq_A:
285
+ A.append(
286
+ Hypothesis(
287
+ score=(hyp.score + float(beam_logp[i, 0])),
288
+ yseq=hyp.yseq[:],
289
+ dec_state=hyp.dec_state,
290
+ lm_state=hyp.lm_state,
291
+ )
292
+ )
293
+ else:
294
+ dict_pos = seq_A.index(hyp.yseq)
295
+
296
+ A[dict_pos].score = np.logaddexp(
297
+ A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
298
+ )
299
+
300
+ if v < (self.max_sym_exp - 1):
301
+ if self.use_lm:
302
+ beam_lm_states = create_lm_batch_state(
303
+ [c.lm_state for c in C], self.lm_layers, self.is_wordlm
304
+ )
305
+
306
+ beam_lm_states, beam_lm_scores = self.lm.buff_predict(
307
+ beam_lm_states, beam_lm_tokens, len(C)
308
+ )
309
+
310
+ for i, hyp in enumerate(C):
311
+ for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
312
+ new_hyp = Hypothesis(
313
+ score=(hyp.score + float(logp)),
314
+ yseq=(hyp.yseq + [int(k)]),
315
+ dec_state=self.decoder.select_state(beam_state, i),
316
+ lm_state=hyp.lm_state,
317
+ )
318
+
319
+ if self.use_lm:
320
+ new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
321
+
322
+ new_hyp.lm_state = select_lm_state(
323
+ beam_lm_states, i, self.lm_layers, self.is_wordlm
324
+ )
325
+
326
+ D.append(new_hyp)
327
+
328
+ C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
329
+
330
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
331
+
332
+ return self.sort_nbest(B)
333
+
334
+ def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
335
+ """Alignment-length synchronous beam search implementation.
336
+
337
+ Based on https://ieeexplore.ieee.org/document/9053040
338
+
339
+ Args:
340
+ h: Encoded speech features (T_max, D_enc)
341
+
342
+ Returns:
343
+ nbest_hyps: N-best decoding results
344
+
345
+ """
346
+ beam = min(self.beam_size, self.vocab_size)
347
+
348
+ h_length = int(h.size(0))
349
+ u_max = min(self.u_max, (h_length - 1))
350
+
351
+ beam_state = self.decoder.init_state(beam)
352
+
353
+ B = [
354
+ Hypothesis(
355
+ yseq=[self.blank],
356
+ score=0.0,
357
+ dec_state=self.decoder.select_state(beam_state, 0),
358
+ )
359
+ ]
360
+ final = []
361
+ cache = {}
362
+
363
+ if self.use_lm and not self.is_wordlm:
364
+ B[0].lm_state = init_lm_state(self.lm_predictor)
365
+
366
+ for i in range(h_length + u_max):
367
+ A = []
368
+
369
+ B_ = []
370
+ h_states = []
371
+ for hyp in B:
372
+ u = len(hyp.yseq) - 1
373
+ t = i - u + 1
374
+
375
+ if t > (h_length - 1):
376
+ continue
377
+
378
+ B_.append(hyp)
379
+ h_states.append((t, h[t]))
380
+
381
+ if B_:
382
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
383
+ B_,
384
+ beam_state,
385
+ cache,
386
+ self.use_lm,
387
+ )
388
+
389
+ h_enc = torch.stack([h[1] for h in h_states])
390
+
391
+ beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
392
+ beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
393
+
394
+ if self.use_lm:
395
+ beam_lm_states = create_lm_batch_state(
396
+ [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
397
+ )
398
+
399
+ beam_lm_states, beam_lm_scores = self.lm.buff_predict(
400
+ beam_lm_states, beam_lm_tokens, len(B_)
401
+ )
402
+
403
+ for i, hyp in enumerate(B_):
404
+ new_hyp = Hypothesis(
405
+ score=(hyp.score + float(beam_logp[i, 0])),
406
+ yseq=hyp.yseq[:],
407
+ dec_state=hyp.dec_state,
408
+ lm_state=hyp.lm_state,
409
+ )
410
+
411
+ A.append(new_hyp)
412
+
413
+ if h_states[i][0] == (h_length - 1):
414
+ final.append(new_hyp)
415
+
416
+ for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
417
+ new_hyp = Hypothesis(
418
+ score=(hyp.score + float(logp)),
419
+ yseq=(hyp.yseq[:] + [int(k)]),
420
+ dec_state=self.decoder.select_state(beam_state, i),
421
+ lm_state=hyp.lm_state,
422
+ )
423
+
424
+ if self.use_lm:
425
+ new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
426
+
427
+ new_hyp.lm_state = select_lm_state(
428
+ beam_lm_states, i, self.lm_layers, self.is_wordlm
429
+ )
430
+
431
+ A.append(new_hyp)
432
+
433
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
434
+ B = recombine_hyps(B)
435
+
436
+ if final:
437
+ return self.sort_nbest(final)
438
+ else:
439
+ return B
440
+
441
+ def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
442
+ """N-step constrained beam search implementation.
443
+
444
+ Based on and modified from https://arxiv.org/pdf/2002.03577.pdf.
445
+ Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
446
+ until further modifications.
447
+
448
+ Note: the algorithm is not in its "complete" form but works almost as
449
+ intended.
450
+
451
+ Args:
452
+ h: Encoded speech features (T_max, D_enc)
453
+
454
+ Returns:
455
+ nbest_hyps: N-best decoding results
456
+
457
+ """
458
+ beam = min(self.beam_size, self.vocab_size)
459
+ beam_k = min(beam, (self.vocab_size - 1))
460
+
461
+ beam_state = self.decoder.init_state(beam)
462
+
463
+ init_tokens = [
464
+ NSCHypothesis(
465
+ yseq=[self.blank],
466
+ score=0.0,
467
+ dec_state=self.decoder.select_state(beam_state, 0),
468
+ )
469
+ ]
470
+
471
+ cache = {}
472
+
473
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
474
+ init_tokens,
475
+ beam_state,
476
+ cache,
477
+ self.use_lm,
478
+ )
479
+
480
+ state = self.decoder.select_state(beam_state, 0)
481
+
482
+ if self.use_lm:
483
+ beam_lm_states, beam_lm_scores = self.lm.buff_predict(
484
+ None, beam_lm_tokens, 1
485
+ )
486
+ lm_state = select_lm_state(
487
+ beam_lm_states, 0, self.lm_layers, self.is_wordlm
488
+ )
489
+ lm_scores = beam_lm_scores[0]
490
+ else:
491
+ lm_state = None
492
+ lm_scores = None
493
+
494
+ kept_hyps = [
495
+ NSCHypothesis(
496
+ yseq=[self.blank],
497
+ score=0.0,
498
+ dec_state=state,
499
+ y=[beam_y[0]],
500
+ lm_state=lm_state,
501
+ lm_scores=lm_scores,
502
+ )
503
+ ]
504
+
505
+ for hi in h:
506
+ hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
507
+ kept_hyps = []
508
+
509
+ h_enc = hi.unsqueeze(0)
510
+
511
+ for j, hyp_j in enumerate(hyps[:-1]):
512
+ for hyp_i in hyps[(j + 1) :]:
513
+ curr_id = len(hyp_j.yseq)
514
+ next_id = len(hyp_i.yseq)
515
+
516
+ if (
517
+ is_prefix(hyp_j.yseq, hyp_i.yseq)
518
+ and (curr_id - next_id) <= self.prefix_alpha
519
+ ):
520
+ ytu = torch.log_softmax(
521
+ self.joint_network(hi, hyp_i.y[-1]), dim=-1
522
+ )
523
+
524
+ curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])
525
+
526
+ for k in range(next_id, (curr_id - 1)):
527
+ ytu = torch.log_softmax(
528
+ self.joint_network(hi, hyp_j.y[k]), dim=-1
529
+ )
530
+
531
+ curr_score += float(ytu[hyp_j.yseq[k + 1]])
532
+
533
+ hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
534
+
535
+ S = []
536
+ V = []
537
+ for n in range(self.nstep):
538
+ beam_y = torch.stack([hyp.y[-1] for hyp in hyps])
539
+
540
+ beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
541
+ beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
542
+
543
+ for i, hyp in enumerate(hyps):
544
+ S.append(
545
+ NSCHypothesis(
546
+ yseq=hyp.yseq[:],
547
+ score=hyp.score + float(beam_logp[i, 0:1]),
548
+ y=hyp.y[:],
549
+ dec_state=hyp.dec_state,
550
+ lm_state=hyp.lm_state,
551
+ lm_scores=hyp.lm_scores,
552
+ )
553
+ )
554
+
555
+ for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
556
+ score = hyp.score + float(logp)
557
+
558
+ if self.use_lm:
559
+ score += self.lm_weight * float(hyp.lm_scores[k])
560
+
561
+ V.append(
562
+ NSCHypothesis(
563
+ yseq=hyp.yseq[:] + [int(k)],
564
+ score=score,
565
+ y=hyp.y[:],
566
+ dec_state=hyp.dec_state,
567
+ lm_state=hyp.lm_state,
568
+ lm_scores=hyp.lm_scores,
569
+ )
570
+ )
571
+
572
+ V.sort(key=lambda x: x.score, reverse=True)
573
+ V = substract(V, hyps)[:beam]
574
+
575
+ beam_state = self.decoder.create_batch_states(
576
+ beam_state,
577
+ [v.dec_state for v in V],
578
+ [v.yseq for v in V],
579
+ )
580
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
581
+ V,
582
+ beam_state,
583
+ cache,
584
+ self.use_lm,
585
+ )
586
+
587
+ if self.use_lm:
588
+ beam_lm_states = create_lm_batch_state(
589
+ [v.lm_state for v in V], self.lm_layers, self.is_wordlm
590
+ )
591
+ beam_lm_states, beam_lm_scores = self.lm.buff_predict(
592
+ beam_lm_states, beam_lm_tokens, len(V)
593
+ )
594
+
595
+ if n < (self.nstep - 1):
596
+ for i, v in enumerate(V):
597
+ v.y.append(beam_y[i])
598
+
599
+ v.dec_state = self.decoder.select_state(beam_state, i)
600
+
601
+ if self.use_lm:
602
+ v.lm_state = select_lm_state(
603
+ beam_lm_states, i, self.lm_layers, self.is_wordlm
604
+ )
605
+ v.lm_scores = beam_lm_scores[i]
606
+
607
+ hyps = V[:]
608
+ else:
609
+ beam_logp = torch.log_softmax(
610
+ self.joint_network(h_enc, beam_y), dim=-1
611
+ )
612
+
613
+ for i, v in enumerate(V):
614
+ if self.nstep != 1:
615
+ v.score += float(beam_logp[i, 0])
616
+
617
+ v.y.append(beam_y[i])
618
+
619
+ v.dec_state = self.decoder.select_state(beam_state, i)
620
+
621
+ if self.use_lm:
622
+ v.lm_state = select_lm_state(
623
+ beam_lm_states, i, self.lm_layers, self.is_wordlm
624
+ )
625
+ v.lm_scores = beam_lm_scores[i]
626
+
627
+ kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
628
+
629
+ return self.sort_nbest(kept_hyps)
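A minimal usage sketch of the class above; `model`, its `dec`, `joint_network` and `encode` attributes, and `speech` are illustrative assumptions, not part of this commit, and the keyword names simply follow the attributes set in __init__.

    # Hedged sketch: driving BeamSearchTransducer on one encoded utterance.
    beam_search = BeamSearchTransducer(
        decoder=model.dec,                  # prediction network (assumed attribute)
        joint_network=model.joint_network,  # joint network (assumed attribute)
        beam_size=5,
        lm=None,                            # optionally an external LM wrapper
        search_type="default",              # or "tsd", "alsd", "nsc" (see dispatch above)
        nbest=2,
    )

    h = model.encode(speech)                # encoder output of shape (T_max, D_enc), assumed helper
    nbest_hyps = beam_search(h)             # n-best Hypothesis list, best first
    best_tokens = nbest_hyps[0].yseq[1:]    # drop the leading blank symbol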
espnet/nets/chainer_backend/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/nets/chainer_backend/asr_interface.py ADDED
@@ -0,0 +1,29 @@
1
+ """ASR Interface module."""
2
+ import chainer
3
+
4
+ from espnet.nets.asr_interface import ASRInterface
5
+
6
+
7
+ class ChainerASRInterface(ASRInterface, chainer.Chain):
8
+ """ASR Interface for ESPnet model implementation."""
9
+
10
+ @staticmethod
11
+ def custom_converter(*args, **kw):
12
+ """Get customconverter of the model (Chainer only)."""
13
+ raise NotImplementedError("custom converter method is not implemented")
14
+
15
+ @staticmethod
16
+ def custom_updater(*args, **kw):
17
+ """Get custom_updater of the model (Chainer only)."""
18
+ raise NotImplementedError("custom updater method is not implemented")
19
+
20
+ @staticmethod
21
+ def custom_parallel_updater(*args, **kw):
22
+ """Get custom_parallel_updater of the model (Chainer only)."""
23
+ raise NotImplementedError("custom parallel updater method is not implemented")
24
+
25
+ def get_total_subsampling_factor(self):
26
+ """Get total subsampling factor."""
27
+ raise NotImplementedError(
28
+ "get_total_subsampling_factor method is not implemented"
29
+ )
espnet/nets/chainer_backend/ctc.py ADDED
@@ -0,0 +1,184 @@
1
+ import logging
2
+
3
+ import chainer
4
+ from chainer import cuda
5
+ import chainer.functions as F
6
+ import chainer.links as L
7
+ import numpy as np
8
+
9
+
10
+ class CTC(chainer.Chain):
11
+ """Chainer implementation of ctc layer.
12
+
13
+ Args:
14
+ odim (int): The output dimension.
15
+ eprojs (int | None): Dimension of input vectors from encoder.
16
+ dropout_rate (float): Dropout rate.
17
+
18
+ """
19
+
20
+ def __init__(self, odim, eprojs, dropout_rate):
21
+ super(CTC, self).__init__()
22
+ self.dropout_rate = dropout_rate
23
+ self.loss = None
24
+
25
+ with self.init_scope():
26
+ self.ctc_lo = L.Linear(eprojs, odim)
27
+
28
+ def __call__(self, hs, ys):
29
+ """CTC forward.
30
+
31
+ Args:
32
+ hs (list of chainer.Variable | N-dimension array):
33
+ Input variable from encoder.
34
+ ys (list of chainer.Variable | N-dimension array):
35
+ Input variable of decoder.
36
+
37
+ Returns:
38
+ chainer.Variable: A variable holding a scalar value of the CTC loss.
39
+
40
+ """
41
+ self.loss = None
42
+ ilens = [x.shape[0] for x in hs]
43
+ olens = [x.shape[0] for x in ys]
44
+
45
+ # zero padding for hs
46
+ y_hat = self.ctc_lo(
47
+ F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
48
+ )
49
+ y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim
50
+
51
+ # zero padding for ys
52
+ y_true = F.pad_sequence(ys, padding=-1) # batch x olen
53
+
54
+ # get length info
55
+ input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32))
56
+ label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32))
57
+ logging.info(
58
+ self.__class__.__name__ + " input lengths: " + str(input_length.data)
59
+ )
60
+ logging.info(
61
+ self.__class__.__name__ + " output lengths: " + str(label_length.data)
62
+ )
63
+
64
+ # get ctc loss
65
+ self.loss = F.connectionist_temporal_classification(
66
+ y_hat, y_true, 0, input_length, label_length
67
+ )
68
+ logging.info("ctc loss:" + str(self.loss.data))
69
+
70
+ return self.loss
71
+
72
+ def log_softmax(self, hs):
73
+ """Log_softmax of frame activations.
74
+
75
+ Args:
76
+ hs (list of chainer.Variable | N-dimension array):
77
+ Input variable from encoder.
78
+
79
+ Returns:
80
+ chainer.Variable: An n-dimension float array.
81
+
82
+ """
83
+ y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
84
+ return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)
85
+
86
+
87
+ class WarpCTC(chainer.Chain):
88
+ """Chainer implementation of warp-ctc layer.
89
+
90
+ Args:
91
+ odim (int): The output dimension.
92
+ eprojs (int | None): Dimension of input vectors from encoder.
93
+ dropout_rate (float): Dropout rate.
94
+
95
+ """
96
+
97
+ def __init__(self, odim, eprojs, dropout_rate):
98
+ super(WarpCTC, self).__init__()
99
+ self.dropout_rate = dropout_rate
100
+ self.loss = None
101
+
102
+ with self.init_scope():
103
+ self.ctc_lo = L.Linear(eprojs, odim)
104
+
105
+ def __call__(self, hs, ys):
106
+ """Core function of the Warp-CTC layer.
107
+
108
+ Args:
109
+ hs (iterable of chainer.Variable | N-dimension array):
110
+ Input variable from encoder.
111
+ ys (iterable of chainer.Variable | N-dimension array):
112
+ Input variable of decoder.
113
+
114
+ Returns:
115
+ chainer.Variable: A variable holding a scalar value of the CTC loss.
116
+
117
+ """
118
+ self.loss = None
119
+ ilens = [x.shape[0] for x in hs]
120
+ olens = [x.shape[0] for x in ys]
121
+
122
+ # zero padding for hs
123
+ y_hat = self.ctc_lo(
124
+ F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
125
+ )
126
+ y_hat = y_hat.transpose(1, 0, 2) # batch x frames x hdim
127
+
128
+ # get length info
129
+ logging.info(self.__class__.__name__ + " input lengths: " + str(ilens))
130
+ logging.info(self.__class__.__name__ + " output lengths: " + str(olens))
131
+
132
+ # get ctc loss
133
+ from chainer_ctc.warpctc import ctc as warp_ctc
134
+
135
+ self.loss = warp_ctc(y_hat, ilens, [cuda.to_cpu(y.data) for y in ys])[0]
136
+ logging.info("ctc loss:" + str(self.loss.data))
137
+
138
+ return self.loss
139
+
140
+ def log_softmax(self, hs):
141
+ """Log_softmax of frame activations.
142
+
143
+ Args:
144
+ hs (list of chainer.Variable | N-dimension array):
145
+ Input variable from encoder.
146
+
147
+ Returns:
148
+ chainer.Variable: An n-dimension float array.
149
+
150
+ """
151
+ y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
152
+ return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)
153
+
154
+ def argmax(self, hs_pad):
155
+ """argmax of frame activations
156
+
157
+ :param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs)
158
+ :return: argmax applied 2d tensor (B, Tmax)
159
+ :rtype: chainer.Variable
160
+ """
161
+ return F.argmax(self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2), axis=-1)
162
+
163
+
164
+ def ctc_for(args, odim):
165
+ """Return the CTC layer corresponding to the args.
166
+
167
+ Args:
168
+ args (Namespace): The program arguments.
169
+ odim (int): The output dimension.
170
+
171
+ Returns:
172
+ The CTC module.
173
+
174
+ """
175
+ ctc_type = args.ctc_type
176
+ if ctc_type == "builtin":
177
+ logging.info("Using chainer CTC implementation")
178
+ ctc = CTC(odim, args.eprojs, args.dropout_rate)
179
+ elif ctc_type == "warpctc":
180
+ logging.info("Using warpctc CTC implementation")
181
+ ctc = WarpCTC(odim, args.eprojs, args.dropout_rate)
182
+ else:
183
+ raise ValueError('ctc_type must be "builtin" or "warpctc": {}'.format(ctc_type))
184
+ return ctc
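A minimal sketch of exercising the built-in CTC layer defined above on dummy CPU data; it assumes Chainer and NumPy are installed, and all shapes and label values are made up for illustration.

    import numpy as np
    from espnet.nets.chainer_backend.ctc import CTC

    ctc = CTC(odim=5, eprojs=4, dropout_rate=0.0)   # id 0 is reserved for the blank

    # two utterances with 10 and 8 encoder frames of dimension eprojs=4
    hs = [np.random.randn(10, 4).astype(np.float32),
          np.random.randn(8, 4).astype(np.float32)]
    # label sequences use ids 1..odim-1, since 0 is the blank
    ys = [np.array([1, 2, 3], dtype=np.int32),
          np.array([4, 2], dtype=np.int32)]

    loss = ctc(hs, ys)               # chainer.Variable holding the scalar CTC loss
    log_probs = ctc.log_softmax(hs)  # (batch, T_max, odim) frame-wise log-probabilities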
espnet/nets/chainer_backend/deterministic_embed_id.py ADDED
@@ -0,0 +1,253 @@
1
+ import numpy
2
+ import six
3
+
4
+ import chainer
5
+ from chainer import cuda
6
+ from chainer import function_node
7
+ from chainer.initializers import normal
8
+
9
+ # from chainer.functions.connection import embed_id
10
+ from chainer import link
11
+ from chainer.utils import type_check
12
+ from chainer import variable
13
+
14
+ """Deterministic EmbedID link and function
15
+
16
+ copied from chainer/links/connection/embed_id.py
17
+ and chainer/functions/connection/embed_id.py,
18
+ and modified not to use atomicAdd operation
19
+ """
20
+
21
+
22
+ class EmbedIDFunction(function_node.FunctionNode):
23
+ def __init__(self, ignore_label=None):
24
+ self.ignore_label = ignore_label
25
+ self._w_shape = None
26
+
27
+ def check_type_forward(self, in_types):
28
+ type_check.expect(in_types.size() == 2)
29
+ x_type, w_type = in_types
30
+ type_check.expect(
31
+ x_type.dtype.kind == "i",
32
+ x_type.ndim >= 1,
33
+ )
34
+ type_check.expect(w_type.dtype == numpy.float32, w_type.ndim == 2)
35
+
36
+ def forward(self, inputs):
37
+ self.retain_inputs((0,))
38
+ x, W = inputs
39
+ self._w_shape = W.shape
40
+
41
+ if not type_check.same_types(*inputs):
42
+ raise ValueError(
43
+ "numpy and cupy must not be used together\n"
44
+ "type(W): {0}, type(x): {1}".format(type(W), type(x))
45
+ )
46
+
47
+ xp = cuda.get_array_module(*inputs)
48
+ if chainer.is_debug():
49
+ valid_x = xp.logical_and(0 <= x, x < len(W))
50
+ if self.ignore_label is not None:
51
+ valid_x = xp.logical_or(valid_x, x == self.ignore_label)
52
+ if not valid_x.all():
53
+ raise ValueError(
54
+ "Each not ignored `x` value need to satisfy" "`0 <= x < len(W)`"
55
+ )
56
+
57
+ if self.ignore_label is not None:
58
+ mask = x == self.ignore_label
59
+ return (xp.where(mask[..., None], 0, W[xp.where(mask, 0, x)]),)
60
+
61
+ return (W[x],)
62
+
63
+ def backward(self, indexes, grad_outputs):
64
+ inputs = self.get_retained_inputs()
65
+ gW = EmbedIDGrad(self._w_shape, self.ignore_label).apply(inputs + grad_outputs)[
66
+ 0
67
+ ]
68
+ return None, gW
69
+
70
+
71
+ class EmbedIDGrad(function_node.FunctionNode):
72
+ def __init__(self, w_shape, ignore_label=None):
73
+ self.w_shape = w_shape
74
+ self.ignore_label = ignore_label
75
+ self._gy_shape = None
76
+
77
+ def forward(self, inputs):
78
+ self.retain_inputs((0,))
79
+ xp = cuda.get_array_module(*inputs)
80
+ x, gy = inputs
81
+ self._gy_shape = gy.shape
82
+ gW = xp.zeros(self.w_shape, dtype=gy.dtype)
83
+
84
+ if xp is numpy:
85
+ # It is equivalent to `numpy.add.at(gW, x, gy)` but ufunc.at is
86
+ # too slow.
87
+ for ix, igy in six.moves.zip(x.ravel(), gy.reshape(x.size, -1)):
88
+ if ix == self.ignore_label:
89
+ continue
90
+ gW[ix] += igy
91
+ else:
92
+ """
93
+ # original code based on cuda elementwise method
94
+ if self.ignore_label is None:
95
+ cuda.elementwise(
96
+ 'T gy, S x, S n_out', 'raw T gW',
97
+ 'ptrdiff_t w_ind[] = {x, i % n_out};'
98
+ 'atomicAdd(&gW[w_ind], gy)',
99
+ 'embed_id_bwd')(
100
+ gy, xp.expand_dims(x, -1), gW.shape[1], gW)
101
+ else:
102
+ cuda.elementwise(
103
+ 'T gy, S x, S n_out, S ignore', 'raw T gW',
104
+ '''
105
+ if (x != ignore) {
106
+ ptrdiff_t w_ind[] = {x, i % n_out};
107
+ atomicAdd(&gW[w_ind], gy);
108
+ }
109
+ ''',
110
+ 'embed_id_bwd_ignore_label')(
111
+ gy, xp.expand_dims(x, -1), gW.shape[1],
112
+ self.ignore_label, gW)
113
+ """
114
+ # EmbedID gradient alternative without atomicAdd, which simply
115
+ # creates a one-hot vector and applies dot product
116
+ xi = xp.zeros((x.size, len(gW)), dtype=numpy.float32)
117
+ idx = xp.arange(x.size, dtype=numpy.int32) * len(gW) + x.ravel()
118
+ xi.ravel()[idx] = 1.0
119
+ if self.ignore_label is not None:
120
+ xi[:, self.ignore_label] = 0.0
121
+ gW = xi.T.dot(gy.reshape(x.size, -1)).astype(gW.dtype, copy=False)
122
+
123
+ return (gW,)
124
+
125
+ def backward(self, indexes, grads):
126
+ xp = cuda.get_array_module(*grads)
127
+ x = self.get_retained_inputs()[0].data
128
+ ggW = grads[0]
129
+
130
+ if self.ignore_label is not None:
131
+ mask = x == self.ignore_label
132
+ # To prevent index out of bounds, we need to check if ignore_label
133
+ # is inside of W.
134
+ if not (0 <= self.ignore_label < self.w_shape[1]):
135
+ x = xp.where(mask, 0, x)
136
+
137
+ ggy = ggW[x]
138
+
139
+ if self.ignore_label is not None:
140
+ mask, zero, _ = xp.broadcast_arrays(
141
+ mask[..., None], xp.zeros((), "f"), ggy.data
142
+ )
143
+ ggy = chainer.functions.where(mask, zero, ggy)
144
+ return None, ggy
145
+
146
+
147
+ def embed_id(x, W, ignore_label=None):
148
+ r"""Efficient linear function for one-hot input.
149
+
150
+ This function implements so called *word embeddings*. It takes two
151
+ arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer
152
+ vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d`
153
+ float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th
154
+ column is the ``x[i]``-th column of ``W``.
155
+ This function is only differentiable on the input ``W``.
156
+
157
+ Args:
158
+ x (chainer.Variable | np.ndarray): Batch vectors of IDs. Each
159
+ element must be signed integer.
160
+ W (chainer.Variable | np.ndarray): Distributed representation
161
+ of each ID (a.k.a. word embeddings).
162
+ ignore_label (int): If ignore_label is an int value, i-th column
163
+ of return value is filled with 0.
164
+
165
+ Returns:
166
+ chainer.Variable: Embedded variable.
167
+
168
+
169
+ .. rubric:: :class:`~chainer.links.EmbedID`
170
+
171
+ Examples:
172
+
173
+ >>> x = np.array([2, 1]).astype('i')
174
+ >>> x
175
+ array([2, 1], dtype=int32)
176
+ >>> W = np.array([[0, 0, 0],
177
+ ... [1, 1, 1],
178
+ ... [2, 2, 2]]).astype('f')
179
+ >>> W
180
+ array([[ 0., 0., 0.],
181
+ [ 1., 1., 1.],
182
+ [ 2., 2., 2.]], dtype=float32)
183
+ >>> F.embed_id(x, W).data
184
+ array([[ 2., 2., 2.],
185
+ [ 1., 1., 1.]], dtype=float32)
186
+ >>> F.embed_id(x, W, ignore_label=1).data
187
+ array([[ 2., 2., 2.],
188
+ [ 0., 0., 0.]], dtype=float32)
189
+
190
+ """
191
+ return EmbedIDFunction(ignore_label=ignore_label).apply((x, W))[0]
192
+
193
+
194
+ class EmbedID(link.Link):
195
+ """Efficient linear layer for one-hot input.
196
+
197
+ This is a link that wraps the :func:`~chainer.functions.embed_id` function.
198
+ This link holds the ID (word) embedding matrix ``W`` as a parameter.
199
+
200
+ Args:
201
+ in_size (int): Number of different identifiers (a.k.a. vocabulary size).
202
+ out_size (int): Output dimension.
203
+ initialW (Initializer): Initializer to initialize the weight.
204
+ ignore_label (int): If `ignore_label` is an int value, i-th column of
205
+ return value is filled with 0.
206
+
207
+ .. rubric:: :func:`~chainer.functions.embed_id`
208
+
209
+ Attributes:
210
+ W (~chainer.Variable): Embedding parameter matrix.
211
+
212
+ Examples:
213
+
214
+ >>> W = np.array([[0, 0, 0],
215
+ ... [1, 1, 1],
216
+ ... [2, 2, 2]]).astype('f')
217
+ >>> W
218
+ array([[ 0., 0., 0.],
219
+ [ 1., 1., 1.],
220
+ [ 2., 2., 2.]], dtype=float32)
221
+ >>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W)
222
+ >>> x = np.array([2, 1]).astype('i')
223
+ >>> x
224
+ array([2, 1], dtype=int32)
225
+ >>> y = l(x)
226
+ >>> y.data
227
+ array([[ 2., 2., 2.],
228
+ [ 1., 1., 1.]], dtype=float32)
229
+
230
+ """
231
+
232
+ ignore_label = None
233
+
234
+ def __init__(self, in_size, out_size, initialW=None, ignore_label=None):
235
+ super(EmbedID, self).__init__()
236
+ self.ignore_label = ignore_label
237
+
238
+ with self.init_scope():
239
+ if initialW is None:
240
+ initialW = normal.Normal(1.0)
241
+ self.W = variable.Parameter(initialW, (in_size, out_size))
242
+
243
+ def __call__(self, x):
244
+ """Extracts the word embedding of given IDs.
245
+
246
+ Args:
247
+ x (chainer.Variable): Batch vectors of IDs.
248
+
249
+ Returns:
250
+ chainer.Variable: Batch of corresponding embeddings.
251
+
252
+ """
253
+ return embed_id(x, self.W, ignore_label=self.ignore_label)
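A minimal sketch of the deterministic EmbedID defined above; it runs on CPU with NumPy inputs and only assumes Chainer and NumPy are installed. It shows the ignore_label behavior and the atomicAdd-free gradient accumulation.

    import numpy as np
    import chainer
    from espnet.nets.chainer_backend.deterministic_embed_id import EmbedID

    layer = EmbedID(in_size=3, out_size=2, ignore_label=-1)
    x = np.array([[0, 2, -1]], dtype=np.int32)    # the -1 entry is ignored (zero vector)
    y = layer(x)                                  # embeddings of shape (1, 3, 2)

    layer.cleargrads()
    chainer.functions.sum(y).backward()
    print(layer.W.grad)   # rows 0 and 2 accumulate ones; row 1 gets no gradient,
                          # and the ignored entry contributes nothing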