Porjaz committed on
Commit: 23ab61f
Parent: 4985c84

Upload 3 files

Files changed (3)
  1. 1000_unigram.model +3 -0
  2. hyperparams_augment.yaml +266 -0
  3. train_augment.py +308 -0
1000_unigram.model ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9554eb7aea11a6003af9d520f6d2cfdefb32225141ed8602448530b95785d74e
size 257601
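
This file is only a Git LFS pointer; the actual SentencePiece unigram model (about 258 KB) lives in LFS. Once the object has been pulled, it can be loaded the same way the hyperparameters file loads its 5K tokenizer. A minimal sketch (the filename is the one uploaded here; the sample sentence is purely illustrative):

    import sentencepiece as spm

    # Assumes `git lfs pull` has replaced the pointer with the real model file.
    sp = spm.SentencePieceProcessor(model_file="1000_unigram.model")
    print(sp.get_piece_size())              # vocabulary size of the unigram model
    ids = sp.encode_as_ids("пример текст")  # text -> subword ids
    print(sp.decode_ids(ids))               # subword ids -> text
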
hyperparams_augment.yaml ADDED
@@ -0,0 +1,266 @@
# Seed needs to be set at top of yaml, before objects with parameters
# are instantiated
seed: 1994
__set_seed: !apply:torch.manual_seed [!ref <seed>]

skip_training: True

output_folder: !ref output_folder_seq2seq_cv_podcast_arhiv_augmentation_128_emb_5000_vocab
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

lm_folder: LM/output_folder_lm

# Data files
data_folder: "../../data/combined_data/speechbrain_splits"

wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint

# pretrained_tokenizer_path: "Tokenizer/output_folder_cv/1K_subword_unigram" # Use this for the CV model
pretrained_tokenizer_path: "Tokenizer/output_folder_cv_podcast_arhiv/5K_subword_unigram" # Use this for the CV+Podcast+Arhiv model

####################### Training Parameters ####################################

number_of_epochs: 50
number_of_ctc_epochs: 15
# batch_size: 16
# batch_size: 6 # for cv+podcast
batch_size: 6 # for cv+podcast+arhiv
label_smoothing: 0.1
lr: 0.0001
ctc_weight: 0.5

opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

# Dataloader options
num_workers: 4
train_dataloader_opts:
    num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>

valid_dataloader_opts:
    num_workers: !ref <num_workers>
    batch_size: !ref <batch_size>

test_dataloader_opts:
    batch_size: 1

####################### Model Parameters #######################################

dropout: 0.15
wav2vec_output_dim: 1024
emb_size: 128
dec_neurons: 1024
dec_layers: 1

output_neurons: 5000
blank_index: 0
bos_index: 0
eos_index: 0
unk_index: 0

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
valid_beam_size: 10
test_beam_size: 10
using_eos_threshold: True
eos_threshold: 1.5
using_max_attn_shift: True
max_attn_shift: 300
temperature: 1.0
ctc_window_size: 200
temperature_lm: 1.25
# Scoring parameters
ctc_weight_decode: 0.0
coverage_penalty: 1.5
lm_weight: 0.0

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Wav2vec2 encoder
encoder_w2v2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: False
    freeze_feature_extractor: True
    save_path: !ref <wav2vec2_folder>
    output_all_hiddens: False

embedding: !new:speechbrain.nnet.embedding.Embedding
    num_embeddings: !ref <output_neurons>
    embedding_dim: !ref <emb_size>

# Attention-based RNN decoder.
decoder: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
    enc_dim: !ref <wav2vec_output_dim>
    input_size: !ref <emb_size>
    rnn_type: gru
    attn_type: location
    hidden_size: !ref <dec_neurons>
    attn_dim: 512
    num_layers: !ref <dec_layers>
    scaling: 1.0
    channels: 10
    kernel_size: 100
    re_init: True
    dropout: !ref <dropout>

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <wav2vec_output_dim>
    n_neurons: !ref <output_neurons>

seq_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dec_neurons>
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

nll_cost: !name:speechbrain.nnet.losses.nll_loss
    label_smoothing: 0.1

# This is the RNNLM that is used according to the Huggingface repository
# NB: It has to match the pre-trained RNNLM!!
#lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
#    output_neurons: !ref <output_neurons>
#    embedding_dim: !ref <emb_size>
#    activation: !name:torch.nn.LeakyReLU
#    dropout: 0.0
#    rnn_layers: 2
#    rnn_neurons: 2048
#    dnn_blocks: 1
#    dnn_neurons: 512
#    return_hidden: True # For inference

tokenizer: !new:sentencepiece.SentencePieceProcessor
    model_file: !ref <pretrained_tokenizer_path>/5000_unigram.model

modules:
    encoder_w2v2: !ref <encoder_w2v2>
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    ctc_lin: !ref <ctc_lin>
    seq_lin: !ref <seq_lin>
    #lm_model: !ref <lm_model>

model: !new:torch.nn.ModuleList
    - [!ref <encoder_w2v2>, !ref <embedding>, !ref <decoder>, !ref <ctc_lin>, !ref <seq_lin>]

############################## Decoding & optimiser ############################
#coverage_scorer: !new:speechbrain.decoders.scorer.CoverageScorer
#    vocab_size: !ref <output_neurons>
#
#rnnlm_scorer: !new:speechbrain.decoders.scorer.RNNLMScorer
#    language_model: !ref <lm_model>
#    temperature: !ref <temperature_lm>
#
#scorer: !new:speechbrain.decoders.scorer.ScorerBuilder
#    full_scorers: [!ref <rnnlm_scorer>,
#                   !ref <coverage_scorer>]
#    weights:
#        rnnlm: !ref <lm_weight>
#        coverage: !ref <coverage_penalty>


# Search
greedy_search: !new:speechbrain.decoders.S2SRNNGreedySearcher
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    linear: !ref <seq_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_search: !new:speechbrain.decoders.S2SRNNBeamSearcher
    embedding: !ref <embedding>
    decoder: !ref <decoder>
    linear: !ref <seq_lin>
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    eos_threshold: !ref <eos_threshold>
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
    temperature: !ref <temperature>
    #scorer: !ref <scorer>

############################## Augmentations ###################################

# Speed perturbation
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: 16000
    speeds: [95, 100, 105]

# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05

# Time drop: randomly drops a number of temporal chunks.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5

# Augmenter: combines the previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    concat_original: False
    min_augmentations: 1
    max_augmentations: 3
    augment_prob: 0.5
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]


############################## Logging and Pretrainer ##########################

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True


# The pretrainer allows a mapping between pretrained files and instances that
# are declared in the yaml. E.g. here, we will download the file lm.ckpt
# and it will be loaded into "lm", which points to the <lm_model> defined
# above.
#pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
#    collect_in: !ref <lm_folder>
#    loadables:
#        lm: !ref <lm_model>
#    paths:
#        lm: !ref <lm_folder>/save/CKPT+2024-07-19+14-16-05+00/model.ckpt
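
For readers unfamiliar with HyperPyYAML: `!ref` substitutes values defined elsewhere in the file, `!new:` instantiates a class at load time, and `!name:` wraps a callable without calling it. The file can therefore be loaded and inspected on its own, which is a quick way to validate it. A minimal sketch (it assumes the tokenizer path and the wav2vec2 checkpoint referenced above are reachable, since `!new:` objects are built eagerly):

    from hyperpyyaml import load_hyperpyyaml

    with open("hyperparams_augment.yaml") as fin:
        hparams = load_hyperpyyaml(fin)

    print(hparams["batch_size"])         # plain value resolved from the YAML (6)
    print(type(hparams["wav_augment"]))  # an instantiated Augmenter object
    print(type(hparams["opt_class"]))    # a partial-like wrapper around torch.optim.Adam
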
train_augment.py ADDED
@@ -0,0 +1,308 @@
#!/usr/bin/env python3

import logging
import sys
from pathlib import Path
import os

import librosa

import torch
from torch.utils.data import DataLoader
from hyperpyyaml import load_hyperpyyaml

import speechbrain as sb
from speechbrain.utils.distributed import if_main_process, run_on_main

from jiwer import wer, cer

logger = logging.getLogger(__name__)


# Define training procedure
class ASR(sb.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        batch = batch.to(self.device)
        sig, self.sig_lens = batch.sig
        tokens_bos, _ = batch.tokens_bos
        sig, self.sig_lens = sig.to(self.device), self.sig_lens.to(self.device)

        # Add waveform augmentation if specified.
        if stage == sb.Stage.TRAIN:
            sig, self.sig_lens = self.hparams.wav_augment(sig, self.sig_lens)
            # tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)

        # Forward pass
        encoded_outputs = self.modules.encoder_w2v2(sig.detach())
        embedded_tokens = self.modules.embedding(tokens_bos)
        decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)

        # Output layer for seq2seq log-probabilities
        logits = self.modules.seq_lin(decoder_outputs)
        predictions = {"seq_logprobs": self.hparams.log_softmax(logits)}

        if self.is_ctc_active(stage):
            # Output layer for ctc log-probabilities
            ctc_logits = self.modules.ctc_lin(encoded_outputs)
            predictions["ctc_logprobs"] = self.hparams.log_softmax(ctc_logits)
        elif stage == sb.Stage.VALID:
            predictions["tokens"], _, _, _ = self.hparams.greedy_search(encoded_outputs, self.sig_lens)
        elif stage == sb.Stage.TEST:
            predictions["tokens"], _, _, _ = self.hparams.test_search(encoded_outputs, self.sig_lens)

        return predictions

    def is_ctc_active(self, stage):
        """Check if CTC is currently active.

        Arguments
        ---------
        stage : sb.Stage
            Currently executing stage.
        """
        if stage != sb.Stage.TRAIN:
            return False
        current_epoch = self.hparams.epoch_counter.current
        return current_epoch <= self.hparams.number_of_ctc_epochs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""
        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        # if stage == sb.Stage.TRAIN:
        #     (tokens, tokens_lens, tokens_eos, tokens_eos_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens, tokens_eos, tokens_eos_lens)

        loss = self.hparams.nll_cost(log_probabilities=predictions["seq_logprobs"], targets=tokens_eos, length=tokens_eos_lens)

        if self.is_ctc_active(stage):
            # Load tokens without EOS as CTC targets
            loss_ctc = self.hparams.ctc_cost(predictions["ctc_logprobs"], tokens, self.sig_lens, tokens_lens)
            loss *= 1 - self.hparams.ctc_weight
            loss += self.hparams.ctc_weight * loss_ctc

        if stage != sb.Stage.TRAIN:
            # for prediction in predictions["tokens"]:
            #     print(self.hparams.tokenizer.decode_ids(prediction))
            predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions["tokens"]]
            target_words = [words.split(" ") for words in batch.transcript]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                # stats_meta={"epoch": epoch},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]},
                min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            if if_main_process():
                with open(self.hparams.test_wer_file, "w") as w:
                    self.wer_metric.write_stats(w)

    def run_inference(
        self,
        dataset,  # Must be obtained from the dataio_function
        min_key,  # We load the model with the lowest error rate
        loader_kwargs,  # opts for the dataloading
    ):

        # If dataset isn't a Dataloader, we create it.
        if not isinstance(dataset, DataLoader):
            loader_kwargs["ckpt_prefix"] = None
            dataset = self.make_dataloader(
                dataset, sb.Stage.TEST, **loader_kwargs
            )

        self.checkpointer.recover_if_possible(min_key=min_key)
        self.modules.eval()  # We set the model to eval mode (remove dropout etc)

        with torch.no_grad():
            true_labels = []
            pred_labels = []
            # for batch in tqdm(dataset, dynamic_ncols=True):
            for batch in dataset:
                # Make sure that your compute_forward returns the predictions !!!
                # In the case of the template, when stage = TEST, a beam search is applied
                # in compute_forward().
                predictions = self.compute_forward(batch, stage=sb.Stage.TEST)

                pred_batch = []
                predicted_words = []

                predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions["tokens"]]
                for sent in predicted_words:
                    sent = " ".join(sent)
                    pred_batch.append(sent)

                pred_labels.append(pred_batch[0])
                true_labels.append(batch.transcript[0])

                # print("True: ", batch.transcript[0])
                # print("Pred: ", pred_batch[0])
                # with open("predictions/predictions_arhiv.txt", "a") as f:
                #     f.write("True: " + batch.transcript[0] + "\n")
                #     f.write("Pred: " + pred_batch[0] + "\n\n")
                print("True: ", batch.transcript[0])
                print("Pred: ", pred_batch[0])

            print('WER: ', wer(true_labels, pred_labels) * 100)
            print('CER: ', cer(true_labels, pred_labels) * 100)

def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.
    """
    data_folder = hparams["data_folder"]

    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "train_corrected.json"), replacements={"data_root": data_folder})

    train_data = train_data.filtered_sorted(sort_key="duration")
    hparams["train_dataloader_opts"]["shuffle"] = False

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "dev_corrected.json"), replacements={"data_root": data_folder})
    valid_data = valid_data.filtered_sorted(sort_key="duration")

    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "test_arhiv.json"), replacements={"data_root": data_folder})

    datasets = [train_data, valid_data, test_data]

    # We get the tokenizer as we need it to encode the labels when creating
    # mini-batches.
    tokenizer = hparams["tokenizer"]

    # 2. Define audio pipeline:
    @sb.utils.data_pipeline.takes("data_path")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(data_path):
        if "cv-mk" in data_path:
            filename = data_path.split("clips")[1]
            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/CV-18_MK/cv-mk/mk/clips" + filename
        elif "podcast" in data_path:
            filename = data_path.split("segmented_audio")[1]
            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/podcast/audio/segmented_audio" + filename
        elif "arhiv" in data_path:
            filename = data_path.split("segmented_audio")[1]
            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/arhiv/audio/segmented_audio" + filename

        sig, sr = librosa.load(data_path, sr=16000)
        # sig = sb.dataio.dataio.read_audio(wav)
        return sig

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)

    # 3. Define text pipeline:
    @sb.utils.data_pipeline.takes("transcript")
    @sb.utils.data_pipeline.provides("transcript", "tokens_list", "tokens_bos", "tokens_eos", "tokens")
    def text_pipeline(transcript):
        yield transcript
        tokens_list = tokenizer.encode_as_ids(transcript)
        yield tokens_list
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
        yield tokens_bos
        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
        yield tokens_eos
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(datasets, ["id", "sig", "transcript", "tokens_list", "tokens_bos", "tokens_eos", "tokens"])

    return (train_data, valid_data, test_data)

if __name__ == "__main__":
    # CLI:
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # here we create the dataset objects as well as tokenization and encoding
    (train_data, valid_data, test_data) = dataio_prepare(hparams)

    # run_on_main(hparams["pretrainer"].collect_files)
    # hparams["pretrainer"].load_collected()

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # We dynamically add the tokenizer to our brain class.
    # NB: This tokenizer corresponds to the one used for the LM!!
    asr_brain.tokenizer = hparams["tokenizer"]
    train_dataloader_opts = hparams["train_dataloader_opts"]
    valid_dataloader_opts = hparams["valid_dataloader_opts"]

    # Training/validation loop
    if not hparams["skip_training"]:
        print("Training...")
        # Training
        asr_brain.fit(
            asr_brain.hparams.epoch_counter,
            train_data,
            valid_data,
            train_loader_kwargs=train_dataloader_opts,
            valid_loader_kwargs=valid_dataloader_opts,
        )

    else:
        # evaluate
        print("Evaluating")
        asr_brain.run_inference(test_data, "WER", hparams["test_dataloader_opts"])
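
Assuming the usual SpeechBrain entry point implied by `sb.parse_arguments`, the script is launched as `python train_augment.py hyperparams_augment.yaml`, and extra flags such as `--skip_training=False` are parsed as HyperPyYAML overrides to enable the training branch. The `wav_augment` call in `compute_forward` applies the Augmenter defined in the YAML directly to the padded waveform batch; the same pipeline can be exercised in isolation with a minimal sketch like the one below (SpeechBrain ≥ 1.0 API, parameters copied from the YAML, random input in place of real audio):

    import torch
    from speechbrain.augment.time_domain import SpeedPerturb, DropFreq, DropChunk
    from speechbrain.augment.augmenter import Augmenter

    # Same augmentations as in hyperparams_augment.yaml.
    speed_perturb = SpeedPerturb(orig_freq=16000, speeds=[95, 100, 105])
    drop_freq = DropFreq(drop_freq_low=0, drop_freq_high=1,
                         drop_freq_count_low=1, drop_freq_count_high=3,
                         drop_freq_width=0.05)
    drop_chunk = DropChunk(drop_length_low=1000, drop_length_high=2000,
                           drop_count_low=1, drop_count_high=5)

    wav_augment = Augmenter(
        concat_original=False,
        min_augmentations=1,
        max_augmentations=3,
        augment_prob=0.5,
        augmentations=[speed_perturb, drop_freq, drop_chunk],
    )

    # A dummy batch of two one-second waveforms with relative lengths,
    # mirroring how compute_forward calls wav_augment(sig, sig_lens).
    sig = torch.randn(2, 16000)
    lens = torch.tensor([1.0, 0.8])
    sig_aug, lens_aug = wav_augment(sig, lens)
    print(sig_aug.shape, lens_aug)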