Porjaz committed
Commit: 2c32e80
Parent: 2b61326

Update train.py

Files changed (1)
  1. train.py +58 -33
train.py CHANGED

@@ -31,7 +31,6 @@ class ASR(sb.Brain):
         # Add waveform augmentation if specified.
         if stage == sb.Stage.TRAIN:
             sig, self.sig_lens = self.hparams.wav_augment(sig, self.sig_lens)
-            # tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
 
         # Forward pass
         encoded_outputs = self.modules.encoder_w2v2(sig.detach())
@@ -75,9 +74,6 @@ class ASR(sb.Brain):
         tokens_eos, tokens_eos_lens = batch.tokens_eos
         tokens, tokens_lens = batch.tokens
 
-        # if stage == sb.Stage.TRAIN:
-        # (tokens, tokens_lens, tokens_eos, tokens_eos_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens, tokens_eos, tokens_eos_lens)
-
         loss = self.hparams.nll_cost(log_probabilities=predictions["seq_logprobs"], targets=tokens_eos, length=tokens_eos_lens)
 
         if self.is_ctc_active(stage):
@@ -87,8 +83,6 @@ class ASR(sb.Brain):
             loss += self.hparams.ctc_weight * loss_ctc
 
         if stage != sb.Stage.TRAIN:
-            # for prediction in predictions["tokens"]:
-            # print(self.hparams.tokenizer.decode_ids(prediction))
            predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions["tokens"]]
            target_words = [words.split(" ") for words in batch.transcript]
            self.wer_metric.append(ids, predicted_words, target_words)
@@ -118,7 +112,6 @@ class ASR(sb.Brain):
             sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
             self.hparams.train_logger.log_stats(
                 stats_meta={"epoch": epoch, "lr": old_lr},
-                # stats_meta={"epoch": epoch},
                 train_stats=self.train_stats,
                 valid_stats=stage_stats,
             )
@@ -155,7 +148,6 @@ class ASR(sb.Brain):
         with torch.no_grad():
             true_labels = []
             pred_labels = []
-            #for batch in tqdm(dataset, dynamic_ncols=True):
             for batch in dataset:
                 # Make sure that your compute_forward returns the predictions !!!
                 # In the case of the template, when stage = TEST, a beam search is applied
@@ -167,39 +159,84 @@ class ASR(sb.Brain):
 
                 predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions["tokens"]]
                 for sent in predicted_words:
+                    # sent = " ".join(sent)
+                    sent = filter_repetitions(sent, 3)
                     sent = " ".join(sent)
                     pred_batch.append(sent)
 
                 pred_labels.append(pred_batch[0])
                 true_labels.append(batch.transcript[0])
 
-                # print("True: ", batch.transcript[0])
-                # print("Pred: ", pred_batch[0])
-                # with open("predictions/predictions_arhiv.txt", "a") as f:
-                # f.write("True: " + batch.transcript[0] + "\n")
-                # f.write("Pred: " + pred_batch[0] + "\n\n")
-                print("True: ", batch.transcript[0])
-                print("Pred: ", pred_batch[0])
-
             print('WER: ', wer(true_labels, pred_labels) * 100)
             print('CER: ', cer(true_labels, pred_labels) * 100)
 
 
+def filter_repetitions(seq, max_repetition_length):
+    seq = list(seq)
+    output = []
+    max_n = len(seq) // 2
+    for n in range(max_n, 0, -1):
+        max_repetitions = max(max_repetition_length // n, 1)
+        # Don't need to iterate over impossible n values:
+        # len(seq) can change a lot during iteration
+        if (len(seq) <= n*2) or (len(seq) <= max_repetition_length):
+            continue
+        iterator = enumerate(seq)
+        # Fill first buffers:
+        buffers = [[next(iterator)[1]] for _ in range(n)]
+        for seq_index, token in iterator:
+            current_buffer = seq_index % n
+            if token != buffers[current_buffer][-1]:
+                # No repeat, we can flush some tokens
+                buf_len = sum(map(len, buffers))
+                flush_start = (current_buffer - buf_len) % n
+                # Keep n-1 tokens, but possibly mark some for removal
+                for flush_index in range(buf_len - buf_len % n):
+                    if (buf_len - flush_index) > n - 1:
+                        to_flush = buffers[(flush_index + flush_start) % n].pop(0)
+                    else:
+                        to_flush = None
+                    # Here, repetitions get removed:
+                    if (flush_index // n < max_repetitions) and to_flush is not None:
+                        output.append(to_flush)
+                    elif (flush_index // n >= max_repetitions) and to_flush is None:
+                        output.append(to_flush)
+            buffers[current_buffer].append(token)
+        # At the end, final flush
+        current_buffer += 1
+        buf_len = sum(map(len, buffers))
+        flush_start = (current_buffer - buf_len) % n
+        for flush_index in range(buf_len):
+            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
+            # Here, repetitions just get removed:
+            if flush_index // n < max_repetitions:
+                output.append(to_flush)
+        seq = []
+        to_delete = 0
+        for token in output:
+            if token is None:
+                to_delete += 1
+            elif to_delete > 0:
+                to_delete -= 1
+            else:
+                seq.append(token)
+        output = []
+    return seq
+
 def dataio_prepare(hparams):
     """This function prepares the datasets to be used in the brain class.
     It also defines the data processing pipeline through user-defined functions.
     """
     data_folder = hparams["data_folder"]
 
-    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "train_corrected.json"), replacements={"data_root": data_folder})
-
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "train.json"), replacements={"data_root": data_folder})
     train_data = train_data.filtered_sorted(sort_key="duration")
     hparams["train_dataloader_opts"]["shuffle"] = False
 
-    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "dev_corrected.json"), replacements={"data_root": data_folder})
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "dev.json"), replacements={"data_root": data_folder})
     valid_data = valid_data.filtered_sorted(sort_key="duration")
 
-    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "test_arhiv.json"), replacements={"data_root": data_folder})
+    test_data = sb.dataio.dataset.DynamicItemDataset.from_json(json_path=os.path.join(hparams["data_folder"], "test.json"), replacements={"data_root": data_folder})
 
 
     datasets = [train_data, valid_data, test_data]
@@ -212,18 +249,8 @@ def dataio_prepare(hparams):
     @sb.utils.data_pipeline.takes("data_path")
     @sb.utils.data_pipeline.provides("sig")
     def audio_pipeline(data_path):
-        if "cv-mk" in data_path:
-            filename = data_path.split("clips")[1]
-            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/CV-18_MK/cv-mk/mk/clips" + filename
-        elif "podcast" in data_path:
-            filename = data_path.split("segmented_audio")[1]
-            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/podcast/audio/segmented_audio" + filename
-        elif "arhiv" in data_path:
-            filename = data_path.split("segmented_audio")[1]
-            data_path = "/m/triton/scratch/elec/t405-puhe/p/porjazd1/macedonian_asr/data/arhiv/audio/segmented_audio" + filename
-
         sig, sr = librosa.load(data_path, sr=16000)
-        # sig = sb.dataio.dataio.read_audio(wav)
+        # sig = sb.dataio.dataio.read_audio(wav) # alternatively use the SpeechBrain data loading function
         return sig
 
     sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
@@ -270,8 +297,6 @@ if __name__ == "__main__":
    # here we create the datasets objects as well as tokenization and encoding
    (train_data, valid_data, test_data) = dataio_prepare(hparams)
 
-   #run_on_main(hparams["pretrainer"].collect_files)
-   #hparams["pretrainer"].load_collected()
 
    # Trainer initialization
    asr_brain = ASR(
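Note: the main functional change above is the new module-level filter_repetitions helper, which prunes runs of repeated n-grams from a decoded token list before the hypothesis is re-joined and scored; the evaluation loop calls it as filter_repetitions(sent, 3). A minimal usage sketch, meant to be run alongside the definition from the diff (the hypothesis string below is an invented example, not repository data):

    # Decoded hypothesis split on spaces, exactly as in the evaluation loop.
    hypothesis = "the model the model the model keeps repeating itself".split(" ")

    # max_repetition_length=3 mirrors the call in the commit; per the code,
    # at most max(3 // n, 1) consecutive copies of any repeated n-gram survive.
    cleaned = filter_repetitions(hypothesis, 3)
    print(" ".join(cleaned))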
 
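The wer and cer helpers used for the final scores are not imported inside this diff; they are presumably the corpus-level metrics from the jiwer package (an assumption, since the import sits outside the changed lines). A small sanity-check sketch under that assumption:

    # Hypothetical reference/hypothesis pairs; jiwer accepts lists of strings
    # and returns the error rate as a fraction, which the script scales by 100.
    from jiwer import wer, cer

    true_labels = ["good morning everyone", "see you soon"]
    pred_labels = ["good morning everyone", "see you later"]

    print("WER: ", wer(true_labels, pred_labels) * 100)
    print("CER: ", cer(true_labels, pred_labels) * 100)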