anonymoussubmitter222 committed
Commit
ed25c49
1 Parent(s): 8a51838

working version, needs better readme

Files changed (43)
  1. .ipynb_checkpoints/transcribe-checkpoint.ipynb +913 -0
  2. EnglishCV/train_en_with_wav2vec.yaml +3 -3
  3. EnglishCV/train_with_wav2vec.py +2 -2
  4. TunisianASR/results/14epoch_tunisian/1234/env.log +347 -0
  5. TunisianASR/{semi_wavlm_large_tunisian_ctc → results/14epoch_tunisian}/1234/hyperparams.yaml +14 -14
  6. TunisianASR/results/14epoch_tunisian/1234/log.txt +359 -0
  7. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/CKPT.yaml +2 -2
  8. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/brain.ckpt +1 -1
  9. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/counter.ckpt +1 -1
  10. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/dataloader-TRAIN.ckpt +1 -1
  11. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/model.ckpt +1 -1
  12. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/modelopt.ckpt +2 -2
  13. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/scheduler_model.ckpt +1 -1
  14. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/scheduler_wav2vec.ckpt +1 -1
  15. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/wav2vec2.ckpt +1 -1
  16. TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/wav2vec_opt.ckpt +1 -1
  17. TunisianASR/{semi_wavlm_large_tunisian_ctc → results/14epoch_tunisian}/1234/save/label_encoder.txt +0 -0
  18. TunisianASR/results/14epoch_tunisian/1234/train_with_wav2vec.py +399 -0
  19. TunisianASR/results/14epoch_tunisian/<seed>/copy_of_wavlm_tun.py +761 -0
  20. TunisianASR/results/14epoch_tunisian/<seed>/ctc_lin.py +756 -0
  21. TunisianASR/results/14epoch_tunisian/<seed>/env.log +97 -0
  22. TunisianASR/results/14epoch_tunisian/<seed>/hyperparams.yaml +200 -0
  23. TunisianASR/results/14epoch_tunisian/<seed>/log.txt +0 -0
  24. TunisianASR/{train_semi.yaml → semi_trained.yaml} +23 -6
  25. __pycache__/cv_train.cpython-38.pyc +0 -0
  26. app.py +768 -0
  27. TunisianASR/outdomain.arpa → arpas/pluslanguages_everything.arpa +2 -2
  28. asr-wav2vec2-commonvoice-fr/hyperparams.yaml +1 -1
  29. cs.yaml +1 -1
  30. cv_train.py +388 -0
  31. pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt +1 -0
  32. pretrained_models/asr-wav2vec2-commonvoice-fr/custom.py +1 -0
  33. pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml +1 -0
  34. pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt +1 -0
  35. pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt +1 -0
  36. requirements.txt +17 -0
  37. results/non_semi_final_stac/env.log +347 -0
  38. results/non_semi_final_stac/hyperparams.yaml +3 -4
  39. results/non_semi_final_stac/log.txt +0 -0
  40. transcribe.ipynb +915 -0
  41. wav2vec2-FR-7K-large +1 -0
  42. wav2vec2-large-lv60 +1 -0
  43. wavlm-large +1 -0
.ipynb_checkpoints/transcribe-checkpoint.ipynb ADDED
@@ -0,0 +1,913 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 22,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n",
13
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n",
14
+ "speechbrain.core - Beginning experiment!\n",
15
+ "speechbrain.core - Experiment folder: TunisianASR/semi_wavlm_large_tunisian_ctc/1234\n",
16
+ "speechbrain.pretrained.fetching - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml.\n",
17
+ "speechbrain.pretrained.fetching - Fetch custom.py: Linking to local file in /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py.\n",
18
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen.\n",
19
+ "speechbrain.pretrained.fetching - Fetch wav2vec2.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt.\n",
20
+ "speechbrain.pretrained.fetching - Fetch asr.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt.\n",
21
+ "speechbrain.pretrained.fetching - Fetch tokenizer.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt.\n",
22
+ "speechbrain.utils.parameter_transfer - Loading pretrained files for: wav2vec2, asr, tokenizer\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Some weights of the model checkpoint at wav2vec2-large-lv60/ were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight']\n",
30
+ "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
31
+ "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
32
+ ]
33
+ },
34
+ {
35
+ "name": "stdout",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n",
39
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
40
+ "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n",
41
+ "speechbrain.core - 314.4M trainable parameters in ASRCV\n",
42
+ "speechbrain.utils.checkpoints - Loading a checkpoint from EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00\n",
43
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
44
+ "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n",
45
+ "speechbrain.core - 314.4M trainable parameters in ASR\n",
46
+ "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "import os\n",
52
+ "import sys\n",
53
+ "import torch\n",
54
+ "import logging\n",
55
+ "import speechbrain as sb\n",
56
+ "from speechbrain.utils.distributed import run_on_main\n",
57
+ "from hyperpyyaml import load_hyperpyyaml\n",
58
+ "from pathlib import Path\n",
59
+ "import torchaudio.transforms as T\n",
60
+ "from cv_train import ASRCV\n",
61
+ "import torchaudio\n",
62
+ "import numpy as np\n",
63
+ "import kenlm\n",
64
+ "from pyctcdecode import build_ctcdecoder\n",
65
+ "import re\n",
66
+ "from torch.nn.utils.rnn import pad_sequence\n",
67
+ "import torch.optim as optim\n",
68
+ "import torch.nn as nn\n",
69
+ "\n",
70
+ "\n",
71
+ "# Commented out IPython magic to ensure Python compatibility.\n",
72
+ "hparams_file, run_opts, overrides = sb.parse_arguments([\"TunisianASR/train_semi.yaml\"])\n",
73
+ "\n",
74
+ "# If distributed_launch=True then\n",
75
+ "# create ddp_group with the right communication protocol\n",
76
+ "sb.utils.distributed.ddp_init_group(run_opts)\n",
77
+ "\n",
78
+ "with open(hparams_file) as fin:\n",
79
+ " hparams = load_hyperpyyaml(fin, overrides)\n",
80
+ "\n",
81
+ "# Create experiment directory\n",
82
+ "sb.create_experiment_directory(\n",
83
+ " experiment_directory=hparams[\"output_folder\"],\n",
84
+ " hyperparams_to_save=hparams_file,\n",
85
+ " overrides=overrides,\n",
86
+ ")\n",
87
+ "# Dataset prep (parsing Librispeech)\n",
88
+ "\n",
89
+ "def dataio_prepare(hparams):\n",
90
+ " \"\"\"This function prepares the datasets to be used in the brain class.\n",
91
+ " It also defines the data processing pipeline through user-defined functions.\"\"\"\n",
92
+ "\n",
93
+ " # 1. Define datasets\n",
94
+ " data_folder = hparams[\"data_folder\"]\n",
95
+ "\n",
96
+ " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
97
+ " csv_path=hparams[\"train_csv\"], replacements={\"data_root\": data_folder},\n",
98
+ " )\n",
99
+ "\n",
100
+ " if hparams[\"sorting\"] == \"ascending\":\n",
101
+ " # we sort training data to speed up training and get better results.\n",
102
+ " train_data = train_data.filtered_sorted(\n",
103
+ " sort_key=\"duration\",\n",
104
+ " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n",
105
+ " )\n",
106
+ " # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
107
+ " hparams[\"dataloader_options\"][\"shuffle\"] = False\n",
108
+ "\n",
109
+ " elif hparams[\"sorting\"] == \"descending\":\n",
110
+ " train_data = train_data.filtered_sorted(\n",
111
+ " sort_key=\"duration\",\n",
112
+ " reverse=True,\n",
113
+ " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n",
114
+ " )\n",
115
+ " # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
116
+ " hparams[\"dataloader_options\"][\"shuffle\"] = False\n",
117
+ "\n",
118
+ " elif hparams[\"sorting\"] == \"random\":\n",
119
+ " pass\n",
120
+ "\n",
121
+ " else:\n",
122
+ " raise NotImplementedError(\n",
123
+ " \"sorting must be random, ascending or descending\"\n",
124
+ " )\n",
125
+ "\n",
126
+ " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
127
+ " csv_path=hparams[\"valid_csv\"], replacements={\"data_root\": data_folder},\n",
128
+ " )\n",
129
+ " # We also sort the validation data so it is faster to validate\n",
130
+ " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
131
+ " test_datasets = {}\n",
132
+ " for csv_file in hparams[\"test_csv\"]:\n",
133
+ " name = Path(csv_file).stem\n",
134
+ " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
135
+ " csv_path=csv_file, replacements={\"data_root\": data_folder}\n",
136
+ " )\n",
137
+ " test_datasets[name] = test_datasets[name].filtered_sorted(\n",
138
+ " sort_key=\"duration\"\n",
139
+ " )\n",
140
+ "\n",
141
+ " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
142
+ "\n",
143
+ "\n",
144
+ " # 2. Define audio pipeline:\n",
145
+ " @sb.utils.data_pipeline.takes(\"wav\")\n",
146
+ " @sb.utils.data_pipeline.provides(\"sig\")\n",
147
+ " def audio_pipeline(wav):\n",
148
+ " info = torchaudio.info(wav)\n",
149
+ " sig = sb.dataio.dataio.read_audio(wav)\n",
150
+ " if len(sig.shape)>1 :\n",
151
+ " sig = torch.mean(sig, dim=1)\n",
152
+ " resampled = torchaudio.transforms.Resample(\n",
153
+ " info.sample_rate, hparams[\"sample_rate\"],\n",
154
+ " )(sig)\n",
155
+ " return resampled\n",
156
+ "\n",
157
+ " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
158
+ " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
159
+ "\n",
160
+ " # 3. Define text pipeline:\n",
161
+ " @sb.utils.data_pipeline.takes(\"wrd\")\n",
162
+ " @sb.utils.data_pipeline.provides(\n",
163
+ " \"wrd\", \"char_list\", \"tokens_list\", \"tokens\"\n",
164
+ " )\n",
165
+ " def text_pipeline(wrd):\n",
166
+ " yield wrd\n",
167
+ " char_list = list(wrd)\n",
168
+ " yield char_list\n",
169
+ " tokens_list = label_encoder.encode_sequence(char_list)\n",
170
+ " yield tokens_list\n",
171
+ " tokens = torch.LongTensor(tokens_list)\n",
172
+ " yield tokens\n",
173
+ "\n",
174
+ " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
175
+ " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
176
+ " special_labels = {\n",
177
+ " \"blank_label\": hparams[\"blank_index\"],\n",
178
+ " \"unk_label\": hparams[\"unk_index\"]\n",
179
+ " }\n",
180
+ " label_encoder.load_or_create(\n",
181
+ " path=lab_enc_file,\n",
182
+ " from_didatasets=[train_data],\n",
183
+ " output_key=\"char_list\",\n",
184
+ " special_labels=special_labels,\n",
185
+ " sequence_input=True,\n",
186
+ " )\n",
187
+ "\n",
188
+ " # 4. Set output:\n",
189
+ " sb.dataio.dataset.set_output_keys(\n",
190
+ " datasets, [\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],\n",
191
+ " )\n",
192
+ " return train_data, valid_data,test_datasets, label_encoder\n",
193
+ "\n",
194
+ "class ASR(sb.core.Brain):\n",
195
+ " def compute_forward(self, batch, stage):\n",
196
+ " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
197
+ "\n",
198
+ " batch = batch.to(self.device)\n",
199
+ " wavs, wav_lens = batch.sig\n",
200
+ " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
201
+ "\n",
202
+ " if stage == sb.Stage.TRAIN:\n",
203
+ " if hasattr(self.hparams, \"augmentation\"):\n",
204
+ " wavs = self.hparams.augmentation(wavs, wav_lens)\n",
205
+ "\n",
206
+ " # Forward pass\n",
207
+ " feats = self.modules.wav2vec2(wavs, wav_lens)\n",
208
+ " x = self.modules.enc(feats)\n",
209
+ " logits = self.modules.ctc_lin(x)\n",
210
+ " p_ctc = self.hparams.log_softmax(logits)\n",
211
+ "\n",
212
+ " return p_ctc, wav_lens\n",
213
+ "\n",
214
+ " def custom_encode(self,wavs,wav_lens) :\n",
215
+ " wavs = wavs.to(\"cpu\")\n",
216
+ " if(wav_lens is not None): wav_lens.to(self.device)\n",
217
+ "\n",
218
+ " feats = self.modules.wav2vec2(wavs, wav_lens)\n",
219
+ " x = self.modules.enc(feats)\n",
220
+ " logits = self.modules.ctc_lin(x)\n",
221
+ " p_ctc = self.hparams.log_softmax(logits)\n",
222
+ "\n",
223
+ " return feats,p_ctc\n",
224
+ "\n",
225
+ "\n",
226
+ "\n",
227
+ " def compute_objectives(self, predictions, batch, stage):\n",
228
+ " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n",
229
+ "\n",
230
+ " p_ctc, wav_lens = predictions\n",
231
+ "\n",
232
+ " ids = batch.id\n",
233
+ " tokens, tokens_lens = batch.tokens\n",
234
+ "\n",
235
+ " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
236
+ "\n",
237
+ " if stage != sb.Stage.TRAIN:\n",
238
+ " predicted_tokens = sb.decoders.ctc_greedy_decode(\n",
239
+ " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n",
240
+ " )\n",
241
+ " # Decode token terms to words\n",
242
+ " if self.hparams.use_language_modelling:\n",
243
+ " predicted_words = []\n",
244
+ " for logs in p_ctc:\n",
245
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
246
+ " predicted_words.append(text.split(\" \"))\n",
247
+ " else:\n",
248
+ " predicted_words = [\n",
249
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
250
+ " for utt_seq in predicted_tokens\n",
251
+ " ]\n",
252
+ " # Convert indices to words\n",
253
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
254
+ "\n",
255
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
256
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
257
+ "\n",
258
+ " return loss\n",
259
+ "\n",
260
+ " def fit_batch(self, batch):\n",
261
+ " \"\"\"Train the parameters given a single batch in input\"\"\"\n",
262
+ " should_step = self.step % self.grad_accumulation_factor == 0\n",
263
+ " # Managing automatic mixed precision\n",
264
+ " # TOFIX: CTC fine-tuning currently is unstable\n",
265
+ " # This is certainly due to CTC being done in fp16 instead of fp32\n",
266
+ " if self.auto_mix_prec:\n",
267
+ " with torch.cuda.amp.autocast():\n",
268
+ " with self.no_sync():\n",
269
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
270
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
271
+ " with self.no_sync(not should_step):\n",
272
+ " self.scaler.scale(\n",
273
+ " loss / self.grad_accumulation_factor\n",
274
+ " ).backward()\n",
275
+ " if should_step:\n",
276
+ "\n",
277
+ " if not self.hparams.wav2vec2.freeze:\n",
278
+ " self.scaler.unscale_(self.wav2vec_optimizer)\n",
279
+ " self.scaler.unscale_(self.model_optimizer)\n",
280
+ " if self.check_gradients(loss):\n",
281
+ " if not self.hparams.wav2vec2.freeze:\n",
282
+ " if self.optimizer_step >= self.hparams.warmup_steps:\n",
283
+ " self.scaler.step(self.wav2vec_optimizer)\n",
284
+ " self.scaler.step(self.model_optimizer)\n",
285
+ " self.scaler.update()\n",
286
+ " self.zero_grad()\n",
287
+ " self.optimizer_step += 1\n",
288
+ " else:\n",
289
+ " # This is mandatory because HF models have a weird behavior with DDP\n",
290
+ " # on the forward pass\n",
291
+ " with self.no_sync():\n",
292
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
293
+ "\n",
294
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
295
+ "\n",
296
+ " with self.no_sync(not should_step):\n",
297
+ " (loss / self.grad_accumulation_factor).backward()\n",
298
+ " if should_step:\n",
299
+ " if self.check_gradients(loss):\n",
300
+ " if not self.hparams.wav2vec2.freeze:\n",
301
+ " if self.optimizer_step >= self.hparams.warmup_steps:\n",
302
+ " self.wav2vec_optimizer.step()\n",
303
+ " self.model_optimizer.step()\n",
304
+ " self.zero_grad()\n",
305
+ " self.optimizer_step += 1\n",
306
+ "\n",
307
+ " self.on_fit_batch_end(batch, outputs, loss, should_step)\n",
308
+ " return loss.detach().cpu()\n",
309
+ "\n",
310
+ " def evaluate_batch(self, batch, stage):\n",
311
+ " \"\"\"Computations needed for validation/test batches\"\"\"\n",
312
+ " predictions = self.compute_forward(batch, stage=stage)\n",
313
+ " with torch.no_grad():\n",
314
+ " loss = self.compute_objectives(predictions, batch, stage=stage)\n",
315
+ " return loss.detach()\n",
316
+ "\n",
317
+ " def on_stage_start(self, stage, epoch):\n",
318
+ " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
319
+ " if stage != sb.Stage.TRAIN:\n",
320
+ " self.cer_metric = self.hparams.cer_computer()\n",
321
+ " self.wer_metric = self.hparams.error_rate_computer()\n",
322
+ "\n",
323
+ " def on_stage_end(self, stage, stage_loss, epoch):\n",
324
+ " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
325
+ " # Compute/store important stats\n",
326
+ " stage_stats = {\"loss\": stage_loss}\n",
327
+ " if stage == sb.Stage.TRAIN:\n",
328
+ " self.train_stats = stage_stats\n",
329
+ " else:\n",
330
+ " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
331
+ " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
332
+ "\n",
333
+ " # Perform end-of-iteration things, like annealing, logging, etc.\n",
334
+ " if stage == sb.Stage.VALID:\n",
335
+ " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n",
336
+ " stage_stats[\"loss\"]\n",
337
+ " )\n",
338
+ " old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(\n",
339
+ " stage_stats[\"loss\"]\n",
340
+ " )\n",
341
+ " sb.nnet.schedulers.update_learning_rate(\n",
342
+ " self.model_optimizer, new_lr_model\n",
343
+ " )\n",
344
+ " if not self.hparams.wav2vec2.freeze:\n",
345
+ " sb.nnet.schedulers.update_learning_rate(\n",
346
+ " self.wav2vec_optimizer, new_lr_wav2vec\n",
347
+ " )\n",
348
+ " self.hparams.train_logger.log_stats(\n",
349
+ " stats_meta={\n",
350
+ " \"epoch\": epoch,\n",
351
+ " \"lr_model\": old_lr_model,\n",
352
+ " \"lr_wav2vec\": old_lr_wav2vec,\n",
353
+ " },\n",
354
+ " train_stats=self.train_stats,\n",
355
+ " valid_stats=stage_stats,\n",
356
+ " )\n",
357
+ " self.checkpointer.save_and_keep_only(\n",
358
+ " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n",
359
+ " )\n",
360
+ " elif stage == sb.Stage.TEST:\n",
361
+ " self.hparams.train_logger.log_stats(\n",
362
+ " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n",
363
+ " test_stats=stage_stats,\n",
364
+ " )\n",
365
+ " with open(self.hparams.wer_file, \"w\") as w:\n",
366
+ " self.wer_metric.write_stats(w)\n",
367
+ "\n",
368
+ " def init_optimizers(self):\n",
369
+ " \"Initializes the wav2vec2 optimizer and model optimizer\"\n",
370
+ "\n",
371
+ " # If the wav2vec encoder is unfrozen, we create the optimizer\n",
372
+ " if not self.hparams.wav2vec2.freeze:\n",
373
+ " self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(\n",
374
+ " self.modules.wav2vec2.parameters()\n",
375
+ " )\n",
376
+ " if self.checkpointer is not None:\n",
377
+ " self.checkpointer.add_recoverable(\n",
378
+ " \"wav2vec_opt\", self.wav2vec_optimizer\n",
379
+ " )\n",
380
+ "\n",
381
+ " self.model_optimizer = self.hparams.model_opt_class(\n",
382
+ " self.hparams.model.parameters()\n",
383
+ " )\n",
384
+ "\n",
385
+ " if self.checkpointer is not None:\n",
386
+ " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n",
387
+ "\n",
388
+ " def zero_grad(self, set_to_none=False):\n",
389
+ " if not self.hparams.wav2vec2.freeze:\n",
390
+ " self.wav2vec_optimizer.zero_grad(set_to_none)\n",
391
+ " self.model_optimizer.zero_grad(set_to_none)\n",
392
+ "\n",
393
+ "\n",
394
+ "from speechbrain.pretrained import EncoderASR,EncoderDecoderASR\n",
395
+ "french_asr_model = EncoderASR.from_hparams(source=\"asr-wav2vec2-commonvoice-fr\", savedir=\"pretrained_models/asr-wav2vec2-commonvoice-fr\").cuda()\n",
396
+ "french_asr_model.to(\"cpu\")\n",
397
+ "cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments([\"EnglishCV/train_en_with_wav2vec.yaml\"])\n",
398
+ "with open(cvhparams_file) as cvfin:\n",
399
+ " cvhparams = load_hyperpyyaml(cvfin, cvoverrides)\n",
400
+ "english_asr_model = ASRCV(\n",
401
+ " modules=cvhparams[\"modules\"],\n",
402
+ " hparams=cvhparams,\n",
403
+ " run_opts=cvrun_opts,\n",
404
+ " checkpointer=cvhparams[\"checkpointer\"],\n",
405
+ " )\n",
406
+ "english_asr_model.modules.to(\"cpu\")\n",
407
+ "english_asr_model.checkpointer.recover_if_possible()\n",
408
+ "asr_brain = ASR(\n",
409
+ " modules=hparams[\"modules\"],\n",
410
+ " hparams=hparams,\n",
411
+ " run_opts=run_opts,\n",
412
+ " checkpointer=hparams[\"checkpointer\"],\n",
413
+ ")\n",
414
+ "asr_brain.modules.to(\"cpu\")\n",
415
+ "asr_brain.checkpointer.recover_if_possible()\n",
416
+ "asr_brain.modules.eval()\n",
417
+ "english_asr_model.modules.eval()\n",
418
+ "french_asr_model.mods.eval()\n",
419
+ "asr_brain.modules.to(\"cpu\")\n",
420
+ "\n",
421
+ "# Commented out IPython magic to ensure Python compatibility.\n",
422
+ "# %ls\n",
423
+ "\n",
424
+ "#UTILS FUNCTIONS\n",
425
+ "def get_size_dimensions(arr):\n",
426
+ " size_dimensions = []\n",
427
+ " while isinstance(arr, list):\n",
428
+ " size_dimensions.append(len(arr))\n",
429
+ " arr = arr[0]\n",
430
+ " return size_dimensions\n",
431
+ "\n",
432
+ "def scale_array(batch,n):\n",
433
+ " scaled_batch = []\n",
434
+ "\n",
435
+ " for array in batch:\n",
436
+ " if(n < len(array)): raise ValueError(\"Cannot scale Array down\")\n",
437
+ "\n",
438
+ " repeat = round(n/len(array))+1\n",
439
+ " scaled_length_array= []\n",
440
+ "\n",
441
+ " for i in array:\n",
442
+ " for j in range(repeat) :\n",
443
+ " if(len(scaled_length_array) == n): break\n",
444
+ " scaled_length_array.append(i)\n",
445
+ "\n",
446
+ " scaled_batch.append(scaled_length_array)\n",
447
+ "\n",
448
+ " return torch.tensor(scaled_batch)\n",
449
+ "\n",
450
+ "\n",
451
+ "def load_paths(wavs_path):\n",
452
+ " waveforms = []\n",
453
+ " for path in wavs_path :\n",
454
+ " waveform, _ = torchaudio.load(path)\n",
455
+ " waveforms.append(waveform.squeeze(0))\n",
456
+ "# normalize array length to the longest array by padding with 0's\n",
457
+ " padded_arrays = pad_sequence(waveforms, batch_first=True)\n",
458
+ " return torch.tensor(padded_arrays)\n",
459
+ "\n",
460
+ "\n",
461
+ "\n",
462
+ "device = 'cuda'\n",
463
+ "verbose = 0\n",
464
+ "#FLOW LEVEL FUNCTIONS\n",
465
+ "def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):\n",
466
+ "\n",
467
+ "\n",
468
+ " post1 = post1.to(device)\n",
469
+ " post2 = post2.to(device)\n",
470
+ " post3 = post3.to(device)\n",
471
+ " embeddings1 = embeddings1.to(device)\n",
472
+ " embeddings2 = embeddings2.to(device)\n",
473
+ " embeddings3 = embeddings3.to(device)\n",
474
+ "\n",
475
+ " posteriograms_merged = torch.cat((post1,post2,post3),dim=2)\n",
476
+ " embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)\n",
477
+ "\n",
478
+ " if(verbose !=0):\n",
479
+ " print('MERGED POST ',posteriograms_merged.shape)\n",
480
+ " print('MERGED emb ',embeddings_merged.shape)\n",
481
+ "\n",
482
+ " return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)\n",
483
+ "\n",
484
+ "def decode(model,wavs,wav_lens):\n",
485
+ "\n",
486
+ " with torch.no_grad():\n",
487
+ " wav_lens = wav_lens.to(model.device)\n",
488
+ " encoder_out = model.encode_batch(wavs, wav_lens)\n",
489
+ " predictions = model.decoding_function(encoder_out, wav_lens)\n",
490
+ " return predictions\n",
491
+ "\n",
492
+ "def middle_layer(batch, lens):\n",
493
+ "\n",
494
+ " tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)\n",
495
+ "\n",
496
+ " fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)\n",
497
+ " fr_posteriogram =french_asr_model.encode_batch(batch,lens)\n",
498
+ " en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)\n",
499
+ " x = english_asr_model.modules.enc(en_embeddings)\n",
500
+ " en_posteriogram = english_asr_model.modules.ctc_lin(x)\n",
501
+ " #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)\n",
502
+ " if(verbose !=0):\n",
503
+ " print('[EMBEDDINGS] FR:',fr_embeddings.shape, \"EN:\",en_embeddings.shape, \"TN:\", tn_embeddings.shape)\n",
504
+ " print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, \"EN:\",en_posteriogram.shape,\"TN:\",tn_posteriogram.shape)\n",
505
+ "\n",
506
+ "\n",
507
+ " bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)\n",
508
+ " return bilangual_sample\n",
509
+ "\n",
510
+ "class Mixer(sb.core.Brain):\n",
511
+ "\n",
512
+ " def compute_forward(self, batch, stage):\n",
513
+ " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
514
+ " wavs, wav_lens = batch.sig\n",
515
+ " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
516
+ "\n",
517
+ " if stage == sb.Stage.TRAIN:\n",
518
+ " if hasattr(self.hparams, \"augmentation\"):\n",
519
+ " wavs = self.hparams.augmentation(wavs, wav_lens)\n",
520
+ "\n",
521
+ " multi_langual_feats = middle_layer(wavs, wav_lens)\n",
522
+ " multi_langual_feats= multi_langual_feats.to(device)\n",
523
+ " feats, _ = self.modules.enc(multi_langual_feats)\n",
524
+ " logits = self.modules.ctc_lin(feats)\n",
525
+ " p_ctc = self.hparams.log_softmax(logits)\n",
526
+ " \n",
527
+ " if stage!= sb.Stage.TRAIN:\n",
528
+ " p_tokens = sb.decoders.ctc_greedy_decode(\n",
529
+ " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n",
530
+ " )\n",
531
+ " else : \n",
532
+ " p_tokens = None\n",
533
+ " return p_ctc, wav_lens, p_tokens\n",
534
+ " \n",
535
+ " \n",
536
+ " def treat_wav(self,sig):\n",
537
+ " multi_langual_feats = middle_layer(sig.to(\"cpu\"), torch.tensor([1]).to(\"cpu\"))\n",
538
+ " multi_langual_feats= multi_langual_feats.to(device)\n",
539
+ " feats, _ = self.modules.enc(multi_langual_feats)\n",
540
+ " logits = self.modules.ctc_lin(feats)\n",
541
+ " p_ctc = self.hparams.log_softmax(logits)\n",
542
+ " predicted_words =[]\n",
543
+ " for logs in p_ctc:\n",
544
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
545
+ " predicted_words.append(text.split(\" \"))\n",
546
+ " return \" \".join(predicted_words[0])\n",
547
+ " \n",
548
+ "\n",
549
+ " def compute_objectives(self, predictions, batch, stage):\n",
550
+ " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n",
551
+ "\n",
552
+ " p_ctc, wav_lens , predicted_tokens= predictions\n",
553
+ "\n",
554
+ " ids = batch.id\n",
555
+ " tokens, tokens_lens = batch.tokens\n",
556
+ "\n",
557
+ " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
558
+ "\n",
559
+ "\n",
560
+ " if stage == sb.Stage.VALID:\n",
561
+ " predicted_words = [\n",
562
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
563
+ " for utt_seq in predicted_tokens\n",
564
+ " ]\n",
565
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
566
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
567
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
568
+ " if stage ==sb.Stage.TEST : \n",
569
+ " if self.hparams.language_modelling:\n",
570
+ " predicted_words = []\n",
571
+ " for logs in p_ctc:\n",
572
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
573
+ " predicted_words.append(text.split(\" \"))\n",
574
+ " else : \n",
575
+ " predicted_words = [\n",
576
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
577
+ " for utt_seq in predicted_tokens\n",
578
+ " ]\n",
579
+ "\n",
580
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
581
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
582
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
583
+ "\n",
584
+ " return loss\n",
585
+ "\n",
586
+ " def fit_batch(self, batch):\n",
587
+ " \"\"\"Train the parameters given a single batch in input\"\"\"\n",
588
+ " should_step = self.step % self.grad_accumulation_factor == 0\n",
589
+ " # Managing automatic mixed precision\n",
590
+ " # TOFIX: CTC fine-tuning currently is unstable\n",
591
+ " # This is certainly due to CTC being done in fp16 instead of fp32\n",
592
+ " if self.auto_mix_prec:\n",
593
+ " with torch.cuda.amp.autocast():\n",
594
+ " with self.no_sync():\n",
595
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
596
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
597
+ " with self.no_sync(not should_step):\n",
598
+ " self.scaler.scale(\n",
599
+ " loss / self.grad_accumulation_factor\n",
600
+ " ).backward()\n",
601
+ " if should_step:\n",
602
+ "\n",
603
+ "\n",
604
+ " self.scaler.unscale_(self.model_optimizer)\n",
605
+ " if self.check_gradients(loss):\n",
606
+ " self.scaler.step(self.model_optimizer)\n",
607
+ " self.scaler.update()\n",
608
+ " self.zero_grad()\n",
609
+ " self.optimizer_step += 1\n",
610
+ " else:\n",
611
+ " # This is mandatory because HF models have a weird behavior with DDP\n",
612
+ " # on the forward pass\n",
613
+ " with self.no_sync():\n",
614
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
615
+ "\n",
616
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
617
+ "\n",
618
+ " with self.no_sync(not should_step):\n",
619
+ " (loss / self.grad_accumulation_factor).backward()\n",
620
+ " if should_step:\n",
621
+ " if self.check_gradients(loss):\n",
622
+ " self.model_optimizer.step()\n",
623
+ " self.zero_grad()\n",
624
+ " self.optimizer_step += 1\n",
625
+ "\n",
626
+ " self.on_fit_batch_end(batch, outputs, loss, should_step)\n",
627
+ " return loss.detach().cpu()\n",
628
+ "\n",
629
+ " def evaluate_batch(self, batch, stage):\n",
630
+ " \"\"\"Computations needed for validation/test batches\"\"\"\n",
631
+ " predictions = self.compute_forward(batch, stage=stage)\n",
632
+ " with torch.no_grad():\n",
633
+ " loss = self.compute_objectives(predictions, batch, stage=stage)\n",
634
+ " return loss.detach()\n",
635
+ "\n",
636
+ " def on_stage_start(self, stage, epoch):\n",
637
+ " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
638
+ " if stage != sb.Stage.TRAIN:\n",
639
+ " self.cer_metric = self.hparams.cer_computer()\n",
640
+ " self.wer_metric = self.hparams.error_rate_computer()\n",
641
+ "\n",
642
+ " def on_stage_end(self, stage, stage_loss, epoch):\n",
643
+ " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
644
+ " # Compute/store important stats\n",
645
+ " stage_stats = {\"loss\": stage_loss}\n",
646
+ " if stage == sb.Stage.TRAIN:\n",
647
+ " self.train_stats = stage_stats\n",
648
+ " else:\n",
649
+ " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
650
+ " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
651
+ "\n",
652
+ " # Perform end-of-iteration things, like annealing, logging, etc.\n",
653
+ " if stage == sb.Stage.VALID:\n",
654
+ " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n",
655
+ " stage_stats[\"loss\"]\n",
656
+ " )\n",
657
+ " sb.nnet.schedulers.update_learning_rate(\n",
658
+ " self.model_optimizer, new_lr_model\n",
659
+ " )\n",
660
+ " self.hparams.train_logger.log_stats(\n",
661
+ " stats_meta={\n",
662
+ " \"epoch\": epoch,\n",
663
+ " \"lr_model\": old_lr_model,\n",
664
+ " },\n",
665
+ " train_stats=self.train_stats,\n",
666
+ " valid_stats=stage_stats,\n",
667
+ " )\n",
668
+ " self.checkpointer.save_and_keep_only(\n",
669
+ " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n",
670
+ " )\n",
671
+ " elif stage == sb.Stage.TEST:\n",
672
+ " self.hparams.train_logger.log_stats(\n",
673
+ " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n",
674
+ " test_stats=stage_stats,\n",
675
+ " )\n",
676
+ " with open(self.hparams.wer_file, \"w\") as w:\n",
677
+ " self.wer_metric.write_stats(w)\n",
678
+ "\n",
679
+ " def init_optimizers(self):\n",
680
+ "\n",
681
+ " self.model_optimizer = self.hparams.model_opt_class(\n",
682
+ " self.hparams.model.parameters()\n",
683
+ " )\n",
684
+ "\n",
685
+ " if self.checkpointer is not None:\n",
686
+ " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n",
687
+ "\n",
688
+ " def zero_grad(self, set_to_none=False):\n",
689
+ "\n",
690
+ " self.model_optimizer.zero_grad(set_to_none)\n",
691
+ "\n",
692
+ "\n"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "metadata": {},
699
+ "outputs": [
700
+ {
701
+ "name": "stdout",
702
+ "output_type": "stream",
703
+ "text": [
704
+ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n",
705
+ "speechbrain.core - Beginning experiment!\n",
706
+ "speechbrain.core - Experiment folder: results/non_semi_final_stac\n",
707
+ "speechbrain.dataio.encoder - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.\n",
708
+ "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n",
709
+ "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n",
710
+ "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n",
711
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
712
+ "speechbrain.core - 60.1M trainable parameters in Mixer\n",
713
+ "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n",
714
+ "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n",
715
+ "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n",
716
+ "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00\n"
717
+ ]
718
+ },
719
+ {
720
+ "name": "stderr",
721
+ "output_type": "stream",
722
+ "text": [
723
+ "<ipython-input-26-84a6e2d9fce8>:119: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
724
+ " inputs=[gr.Audio(source=\"microphone\", type='filepath', label = \"record\", optional = True),\n",
725
+ "<ipython-input-26-84a6e2d9fce8>:120: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
726
+ " gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)]\n"
727
+ ]
728
+ },
729
+ {
730
+ "name": "stdout",
731
+ "output_type": "stream",
732
+ "text": [
733
+ "Running on local URL: http://127.0.0.1:7860\n",
734
+ "\n",
735
+ "To create a public link, set `share=True` in `launch()`.\n"
736
+ ]
737
+ },
738
+ {
739
+ "data": {
740
+ "text/html": [
741
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
742
+ ],
743
+ "text/plain": [
744
+ "<IPython.core.display.HTML object>"
745
+ ]
746
+ },
747
+ "metadata": {},
748
+ "output_type": "display_data"
749
+ },
750
+ {
751
+ "name": "stderr",
752
+ "output_type": "stream",
753
+ "text": [
754
+ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n",
755
+ " warnings.warn(warning.format(data.dtype))\n"
756
+ ]
757
+ },
758
+ {
759
+ "name": "stdout",
760
+ "output_type": "stream",
761
+ "text": [
762
+ "tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0075, -0.0042, -0.0031]])\n",
763
+ "tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0518e-05,\n",
764
+ " -3.0518e-05, 0.0000e+00]])\n"
765
+ ]
766
+ }
767
+ ],
768
+ "source": [
769
+ "hparams_file, run_opts, overrides = sb.parse_arguments([\"cs.yaml\"])\n",
770
+ "\n",
771
+ "# If distributed_launch=True then\n",
772
+ "# create ddp_group with the right communication protocol\n",
773
+ "sb.utils.distributed.ddp_init_group(run_opts)\n",
774
+ "\n",
775
+ "with open(hparams_file) as fin:\n",
776
+ " hparams = load_hyperpyyaml(fin, overrides)\n",
777
+ "\n",
778
+ "# Create experiment directory\n",
779
+ "sb.create_experiment_directory(\n",
780
+ " experiment_directory=hparams[\"output_folder\"],\n",
781
+ " hyperparams_to_save=hparams_file,\n",
782
+ " overrides=overrides,\n",
783
+ ")\n",
784
+ "def read_labels_file(labels_file):\n",
785
+ " with open(labels_file, \"r\",encoding=\"utf-8\") as lf:\n",
786
+ " lines = lf.read().splitlines()\n",
787
+ " division = \"===\"\n",
788
+ " numbers = {}\n",
789
+ " for line in lines :\n",
790
+ " if division in line :\n",
791
+ " break\n",
792
+ " string, number = line.split(\"=>\")\n",
793
+ " number = int(number)\n",
794
+ " string = string[1:-2]\n",
795
+ " numbers[number] = string\n",
796
+ " return [numbers[x] for x in range(len(numbers))]\n",
797
+ "\n",
798
+ "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
799
+ "\n",
800
+ "lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
801
+ "special_labels = {\n",
802
+ " \"blank_label\": hparams[\"blank_index\"],\n",
803
+ " \"unk_label\": hparams[\"unk_index\"]\n",
804
+ "}\n",
805
+ "label_encoder.load_or_create(\n",
806
+ " path=lab_enc_file,\n",
807
+ " from_didatasets=[[]],\n",
808
+ " output_key=\"char_list\",\n",
809
+ " special_labels=special_labels,\n",
810
+ " sequence_input=True,\n",
811
+ ")\n",
812
+ "\n",
813
+ "\n",
814
+ "labels = read_labels_file(os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\"))\n",
815
+ "labels = [\"\"] + labels[1:-1] + [\"1\"] \n",
816
+ "if hparams[\"language_modelling\"]:\n",
817
+ " decoder = build_ctcdecoder(\n",
818
+ " labels,\n",
819
+ " kenlm_model_path=hparams[\"ngram_lm_path\"], # either .arpa or .bin file\n",
820
+ " alpha=0.5, # tuned on a val set\n",
821
+ " beta=1, # tuned on a val set\n",
822
+ " )\n",
823
+ "\n",
824
+ "\n",
825
+ "\n",
826
+ "\n",
827
+ "mixer = Mixer(\n",
828
+ " modules=hparams[\"modules\"],\n",
829
+ " hparams=hparams,\n",
830
+ " run_opts=run_opts,\n",
831
+ " checkpointer=hparams[\"checkpointer\"],\n",
832
+ ")\n",
833
+ "mixer.tokenizer = label_encoder\n",
834
+ "\n",
835
+ "\n",
836
+ "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
837
+ "\n",
838
+ "\n",
839
+ "# We dynamically add the tokenizer to our brain class.\n",
840
+ "# NB: This tokenizer corresponds to the one used for the LM!!\n",
841
+ "\n",
842
+ "decoder = build_ctcdecoder(\n",
843
+ " labels,\n",
844
+ " kenlm_model_path= \"arpas/everything.arpa\", # either .arpa or .bin file\n",
845
+ " alpha=0.5, # tuned on a val set\n",
846
+ " beta=1, # tuned on a val set\n",
847
+ ")\n",
848
+ "\n",
849
+ "run_opts[\"device\"]=\"cpu\"\n",
850
+ "\n",
851
+ "\n",
852
+ "device = \"cpu\"\n",
853
+ "mixer.device= \"cpu\"\n",
854
+ "mixer.modules.to(\"cpu\")\n",
855
+ "\n",
856
+ "from enum import Enum, auto\n",
857
+ "class Stage(Enum):\n",
858
+ " TRAIN = auto()\n",
859
+ " VALID = auto()\n",
860
+ " TEST = auto()\n",
861
+ "\n",
862
+ "asr_brain.on_evaluate_start()\n",
863
+ "asr_brain.modules.eval()\n",
864
+ "\n",
865
+ "\n",
866
+ "import gradio as gr\n",
867
+ "\n",
868
+ "def treat_wav_file(file_mic,file_upload ,asr=mixer, device=\"cpu\") :\n",
869
+ " if (file_mic is not None) and (file_upload is not None):\n",
870
+ " warn_output = \"WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\\n\"\n",
871
+ " wav = file_mic\n",
872
+ " elif (file_mic is None) and (file_upload is None):\n",
873
+ " return \"ERROR: You have to either use the microphone or upload an audio file\"\n",
874
+ " elif file_mic is not None:\n",
875
+ " wav = file_mic\n",
876
+ " else:\n",
877
+ " wav = file_upload\n",
878
+ " sig, sr = torchaudio.load(wav)\n",
879
+ " tensor_wav = sig.to(device)\n",
880
+ " resampled = torchaudio.functional.resample( tensor_wav, sr, 16000)\n",
881
+ " sentence = asr.treat_wav(resampled)\n",
882
+ " return sentence\n",
883
+ "\n",
884
+ "gr.Interface(\n",
885
+ " fn=treat_wav_file, \n",
886
+ " inputs=[gr.Audio(source=\"microphone\", type='filepath', label = \"record\", optional = True), \n",
887
+ " gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)]\n",
888
+ " ,outputs=\"text\").launch(share= False, debug = True)\n"
889
+ ]
890
+ }
891
+ ],
892
+ "metadata": {
893
+ "kernelspec": {
894
+ "display_name": "Python 3",
895
+ "language": "python",
896
+ "name": "python3"
897
+ },
898
+ "language_info": {
899
+ "codemirror_mode": {
900
+ "name": "ipython",
901
+ "version": 3
902
+ },
903
+ "file_extension": ".py",
904
+ "mimetype": "text/x-python",
905
+ "name": "python",
906
+ "nbconvert_exporter": "python",
907
+ "pygments_lexer": "ipython3",
908
+ "version": "3.8.5"
909
+ }
910
+ },
911
+ "nbformat": 4,
912
+ "nbformat_minor": 5
913
+ }
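Note on the notebook above: its Mixer Brain consumes features produced by middle_layer/merge_strategy, which simply concatenate the French, English and Tunisian encoders' embeddings and CTC posteriorgrams along the feature axis before the mixer's own enc and ctc_lin layers. Below is a minimal, self-contained sketch of that merge with random tensors standing in for the three models; the embedding and vocabulary sizes are illustrative assumptions, not values taken from this repository.

import torch

batch, frames = 1, 50                       # one utterance, 50 encoder frames
fr_emb = torch.randn(batch, frames, 1024)   # stand-in: French wav2vec2 embeddings
en_emb = torch.randn(batch, frames, 1024)   # stand-in: English wav2vec2 embeddings
tn_emb = torch.randn(batch, frames, 1024)   # stand-in: Tunisian WavLM embeddings
fr_post = torch.randn(batch, frames, 76)    # stand-in CTC posteriorgrams
en_post = torch.randn(batch, frames, 29)    # (vocabulary sizes are made up
tn_post = torch.randn(batch, frames, 40)    #  for this example)

# merge_strategy: concatenate posteriorgrams and embeddings along the feature axis
posts = torch.cat((fr_post, en_post, tn_post), dim=2)
embs = torch.cat((fr_emb, en_emb, tn_emb), dim=2)
mixed = torch.cat((posts, embs), dim=2)     # what the Mixer's enc/ctc_lin receive
print(mixed.shape)                          # torch.Size([1, 50, 3217])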
EnglishCV/train_en_with_wav2vec.yaml CHANGED
@@ -7,13 +7,13 @@
7
  # Seed needs to be set at top of yaml, before objects with parameters are made
8
  seed: 1234
9
  __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
- output_folder: !ref results/wav2vec2_ctc_en/<seed>
11
  wer_file: !ref <output_folder>/wer.txt
12
  save_folder: !ref <output_folder>/save
13
  train_log: !ref <output_folder>/train_log.txt
14
 
15
  # URL for the biggest Fairseq english wav2vec2 model.
16
- wav2vec2_hub: facebook/wav2vec2-large-lv60
17
  wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
18
 
19
  # Data files
@@ -109,7 +109,7 @@ enc: !new:speechbrain.nnet.containers.Sequential
109
  activation3: !new:torch.nn.LeakyReLU
110
 
111
  wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
112
- source: /gpfsscratch/rech/nou/uzn19yk/wav2vec2-large-lv60/
113
  output_norm: True
114
  freeze: !ref <freeze_wav2vec>
115
  freeze_feature_extractor: !ref <freeze_feature_extractor>
 
7
  # Seed needs to be set at top of yaml, before objects with parameters are made
8
  seed: 1234
9
  __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref EnglishCV/results/wav2vec2_ctc_en/<seed>
11
  wer_file: !ref <output_folder>/wer.txt
12
  save_folder: !ref <output_folder>/save
13
  train_log: !ref <output_folder>/train_log.txt
14
 
15
  # URL for the biggest Fairseq english wav2vec2 model.
16
+ wav2vec2_hub: wav2vec2-large-lv60/
17
  wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
18
 
19
  # Data files
 
109
  activation3: !new:torch.nn.LeakyReLU
110
 
111
  wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
112
+ source: wav2vec2-large-lv60/
113
  output_norm: True
114
  freeze: !ref <freeze_wav2vec>
115
  freeze_feature_extractor: !ref <freeze_feature_extractor>
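The two edits above redirect output_folder and the wav2vec2 source from a hard-coded cluster path to repository-local paths. For reference, the same fields can also be redirected at load time through hyperpyyaml overrides instead of editing the file; a minimal, self-contained sketch follows, where the inline YAML only mimics the fields touched by this diff and the hub identifier is illustrative.

import io
from hyperpyyaml import load_hyperpyyaml

# Inline YAML mirroring only the edited fields (not the full recipe file)
yaml_snippet = io.StringIO("""
seed: 1234
output_folder: !ref EnglishCV/results/wav2vec2_ctc_en/<seed>
save_folder: !ref <output_folder>/save
wav2vec2_hub: wav2vec2-large-lv60/
""")

# Override wav2vec2_hub at load time rather than editing the YAML;
# the HuggingFace hub id below is only an example value.
hparams = load_hyperpyyaml(
    yaml_snippet, {"wav2vec2_hub": "facebook/wav2vec2-large-lv60"}
)
print(hparams["save_folder"])   # EnglishCV/results/wav2vec2_ctc_en/1234/save
print(hparams["wav2vec2_hub"])  # facebook/wav2vec2-large-lv60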
EnglishCV/train_with_wav2vec.py CHANGED
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
37
 
38
 
39
  # Define training procedure
40
- class ASR(sb.core.Brain):
41
  def compute_forward(self, batch, stage):
42
  """Forward computations from the waveform batches to the output probabilities."""
43
 
@@ -360,7 +360,7 @@ if __name__ == "__main__":
360
  train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
 
362
  # Trainer initialization
363
- asr_brain = ASR(
364
  modules=hparams["modules"],
365
  hparams=hparams,
366
  run_opts=run_opts,
 
37
 
38
 
39
  # Define training procedure
40
+ class ASRCV(sb.core.Brain):
41
  def compute_forward(self, batch, stage):
42
  """Forward computations from the waveform batches to the output probabilities."""
43
 
 
360
  train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
 
362
  # Trainer initialization
363
+ asr_brain = ASRCV(
364
  modules=hparams["modules"],
365
  hparams=hparams,
366
  run_opts=run_opts,
TunisianASR/results/14epoch_tunisian/1234/env.log ADDED
@@ -0,0 +1,347 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.8.5 (default, Sep 4 2020, 07:30:14)
5
+ [GCC 7.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ absl-py==1.2.0
9
+ aiohttp==3.8.1
10
+ aiosignal==1.2.0
11
+ alabaster==0.7.12
12
+ anaconda-client==1.7.2
13
+ anaconda-navigator==1.10.0
14
+ anaconda-project==0.8.3
15
+ antlr4-python3-runtime==4.9.3
16
+ appdirs==1.4.4
17
+ argh==0.26.2
18
+ argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828493937/work
19
+ asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
20
+ astroid @ file:///tmp/build/80754af9/astroid_1592495912941/work
21
+ astropy==4.0.2
22
+ async-generator==1.10
23
+ async-timeout==4.0.2
24
+ atomicwrites==1.4.0
25
+ attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
26
+ audioread==2.1.9
27
+ autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
28
+ Babel @ file:///tmp/build/80754af9/babel_1605108370292/work
29
+ backcall==0.2.0
30
+ backports.functools-lru-cache==1.6.1
31
+ backports.shutil-get-terminal-size==1.0.0
32
+ backports.tempfile==1.0
33
+ backports.weakref==1.0.post1
34
+ beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
35
+ bitarray @ file:///tmp/build/80754af9/bitarray_1605065113847/work
36
+ bkcharts==0.2
37
+ black==22.12.0
38
+ bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
39
+ bokeh @ file:///tmp/build/80754af9/bokeh_1603297833684/work
40
+ boto==2.49.0
41
+ boto3==1.28.43
42
+ botocore==1.31.43
43
+ Bottleneck==1.3.2
44
+ bpemb==0.3.4
45
+ brotlipy==0.7.0
46
+ cachetools==5.2.0
47
+ certifi==2020.6.20
48
+ cffi @ file:///tmp/build/80754af9/cffi_1600699146221/work
49
+ chardet==3.0.4
50
+ charset-normalizer==2.0.12
51
+ click==8.1.3
52
+ cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
53
+ clyent==1.2.2
54
+ colorama @ file:///tmp/build/80754af9/colorama_1603211150991/work
55
+ coloredlogs==15.0.1
56
+ conda==4.9.2
57
+ conda-build==3.20.5
58
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
59
+ conda-verify==3.4.2
60
+ conllu==4.5.3
61
+ contextlib2==0.6.0.post1
62
+ cryptography @ file:///tmp/build/80754af9/cryptography_1601046815590/work
63
+ cycler==0.10.0
64
+ Cython @ file:///tmp/build/80754af9/cython_1594831566883/work
65
+ cytoolz==0.11.0
66
+ dask @ file:///tmp/build/80754af9/dask-core_1602083700509/work
67
+ datasets==1.18.3
68
+ decorator==4.4.2
69
+ defusedxml==0.6.0
70
+ Deprecated==1.2.14
71
+ diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
72
+ dill==0.3.4
73
+ distributed @ file:///tmp/build/80754af9/distributed_1605066520644/work
74
+ docutils==0.16
75
+ easyocr==1.2.1
76
+ einops==0.3.0
77
+ entrypoints==0.3
78
+ et-xmlfile==1.0.1
79
+ farasapy==0.0.14
80
+ fastcache==1.1.0
81
+ ffmpeg-python==0.2.0
82
+ filelock==3.0.12
83
+ flair==0.12.2
84
+ flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
85
+ Flask==1.1.2
86
+ flatbuffers==22.9.24
87
+ frozenlist==1.3.0
88
+ fsspec==2022.3.0
89
+ ftfy==6.1.1
90
+ future==0.18.2
91
+ gdown==4.4.0
92
+ gensim==4.1.2
93
+ gevent @ file:///tmp/build/80754af9/gevent_1601397537062/work
94
+ glob2==0.7
95
+ gmpy2==2.0.8
96
+ google-auth==2.12.0
97
+ google-auth-oauthlib==0.4.6
98
+ greenlet @ file:///tmp/build/80754af9/greenlet_1600874013538/work
99
+ grpcio==1.49.1
100
+ h5py==2.10.0
101
+ HeapDict==1.0.1
102
+ html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
103
+ huggingface-hub==0.16.4
104
+ humanfriendly==10.0
105
+ hyperopt==0.2.7
106
+ idna @ file:///tmp/build/80754af9/idna_1593446292537/work
107
+ imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
108
+ imagesize==1.2.0
109
+ imhist==0.0.4
110
+ importlib-metadata==5.0.0
111
+ imWatermark==0.0.2
112
+ iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
113
+ install==1.3.5
114
+ intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
115
+ invisible-watermark==0.1.5
116
+ ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
117
+ ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work
118
+ ipython-genutils==0.2.0
119
+ ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
120
+ isort @ file:///tmp/build/80754af9/isort_1602603989581/work
121
+ itsdangerous==1.1.0
122
+ Janome==0.5.0
123
+ jdcal==1.4.1
124
+ jedi @ file:///tmp/build/80754af9/jedi_1592841866100/work
125
+ jeepney @ file:///tmp/build/80754af9/jeepney_1605069705079/work
126
+ Jinja2==2.11.2
127
+ jiwer==2.3.0
128
+ jmespath==1.0.1
129
+ joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
130
+ json5==0.9.5
131
+ jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
132
+ jupyter==1.0.0
133
+ jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
134
+ jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
135
+ jupyter-core==4.6.3
136
+ jupyterlab==2.2.6
137
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
138
+ jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
139
+ keyring @ file:///tmp/build/80754af9/keyring_1601490835422/work
140
+ kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014535162/work
141
+ langdetect==1.0.9
142
+ lazy-object-proxy==1.4.3
143
+ libarchive-c==2.9
144
+ librosa==0.9.1
145
+ llvmlite==0.34.0
146
+ locket==0.2.0
147
+ lxml @ file:///tmp/build/80754af9/lxml_1603216285000/work
148
+ Markdown==3.4.1
149
+ MarkupSafe==1.1.1
150
+ matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603378225747/work
151
+ mccabe==0.6.1
152
+ mido==1.2.10
153
+ mistune==0.8.4
154
+ mkl-fft==1.2.0
155
+ mkl-random==1.1.1
156
+ mkl-service==2.3.0
157
+ mock==4.0.2
158
+ more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
159
+ mpld3==0.3
160
+ mpmath==1.1.0
161
+ msgpack==1.0.0
162
+ multidict==6.0.2
163
+ multipledispatch==0.6.0
164
+ multiprocess==0.70.12.2
165
+ mypy-extensions==0.4.3
166
+ navigator-updater==0.2.1
167
+ nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
168
+ nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
169
+ nbformat @ file:///tmp/build/80754af9/nbformat_1602783287752/work
170
+ nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1605115881283/work
171
+ networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
172
+ nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
173
+ nose==1.3.7
174
+ notebook @ file:///tmp/build/80754af9/notebook_1601501575118/work
175
+ numba @ file:///tmp/build/80754af9/numba_1600100669015/work
176
+ numexpr==2.7.1
177
+ numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work
178
+ numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
179
+ oauthlib==3.2.1
180
+ olefile==0.46
181
+ omegaconf==2.2.3
182
+ onnx==1.12.0
183
+ onnxruntime==1.12.1
184
+ opencv-python==4.4.0.46
185
+ openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
186
+ packaging==20.9
187
+ pandas @ file:///tmp/build/80754af9/pandas_1602088120436/work
188
+ pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
189
+ parso==0.7.0
190
+ partd==1.1.0
191
+ path @ file:///tmp/build/80754af9/path_1598376507494/work
192
+ pathlib2==2.3.5
193
+ pathspec==0.10.3
194
+ pathtools==0.1.2
195
+ patsy==0.5.1
196
+ pep8==1.7.1
197
+ pexpect==4.8.0
198
+ pickleshare==0.7.5
199
+ Pillow @ file:///tmp/build/80754af9/pillow_1603822255246/work
200
+ pkginfo==1.6.1
201
+ platformdirs==2.6.0
202
+ pluggy==0.13.1
203
+ ply==3.11
204
+ pooch==1.6.0
205
+ pptree==3.1
206
+ pretty-midi==0.2.9
207
+ prometheus-client==0.8.0
208
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
209
+ protobuf==3.19.6
210
+ psutil @ file:///tmp/build/80754af9/psutil_1598370257551/work
211
+ ptyprocess==0.6.0
212
+ py @ file:///tmp/build/80754af9/py_1593446248552/work
213
+ py-espeak-ng==0.1.8
214
+ py4j==0.10.9.7
215
+ PyArabic==0.6.15
216
+ pyarrow==7.0.0
217
+ pyasn1==0.4.8
218
+ pyasn1-modules==0.2.8
219
+ pycodestyle==2.6.0
220
+ pycosat==0.6.3
221
+ pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
222
+ pycurl==7.43.0.6
223
+ pyDeprecate==0.3.1
224
+ pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
225
+ pyflakes==2.2.0
226
+ Pygments @ file:///tmp/build/80754af9/pygments_1604103097372/work
227
+ pylint @ file:///tmp/build/80754af9/pylint_1598623985952/work
228
+ pyodbc===4.0.0-unsupported
229
+ pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
230
+ pyparsing==2.4.7
231
+ pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
232
+ PySocks==1.7.1
233
+ pytest==0.0.0
234
+ python-bidi==0.4.2
235
+ python-crfsuite==0.9.7
236
+ python-dateutil==2.8.1
237
+ python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
238
+ python-language-server @ file:///tmp/build/80754af9/python-language-server_1600454544709/work
239
+ python-Levenshtein==0.12.2
240
+ pytorch-lightning==1.4.2
241
+ pytorch-revgrad==0.2.0
242
+ pytz==2020.1
243
+ PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
244
+ pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
245
+ PyYAML==5.3.1
246
+ pyzmq==19.0.2
247
+ QDarkStyle==2.8.1
248
+ QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
249
+ qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
250
+ QtPy==1.9.0
251
+ regex @ file:///tmp/build/80754af9/regex_1602786672676/work
252
+ requests @ file:///tmp/build/80754af9/requests_1592841827918/work
253
+ requests-oauthlib==1.3.1
254
+ resampy==0.2.2
255
+ rope @ file:///tmp/build/80754af9/rope_1602264064449/work
256
+ rsa==4.9
257
+ Rtree==0.9.4
258
+ ruamel-yaml==0.15.87
259
+ s3transfer==0.6.2
260
+ sacremoses==0.0.49
261
+ safetensors==0.3.3
262
+ scikit-image==0.17.2
263
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376899566/work
264
+ scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work
265
+ seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
266
+ SecretStorage==3.1.2
267
+ segtok==1.5.11
268
+ Send2Trash==1.5.0
269
+ sentencepiece==0.1.97
270
+ simplegeneric==0.8.1
271
+ singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
272
+ sip==4.19.13
273
+ six @ file:///tmp/build/80754af9/six_1605205327372/work
274
+ smart-open==5.2.1
275
+ snowballstemmer==2.0.0
276
+ sortedcollections==1.2.1
277
+ sortedcontainers==2.2.2
278
+ SoundFile==0.10.3.post1
279
+ soupsieve==2.0.1
280
+ sphfile==1.0.3
281
+ Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
282
+ sphinxcontrib-applehelp==1.0.2
283
+ sphinxcontrib-devhelp==1.0.2
284
+ sphinxcontrib-htmlhelp==1.0.3
285
+ sphinxcontrib-jsmath==1.0.1
286
+ sphinxcontrib-qthelp==1.0.3
287
+ sphinxcontrib-serializinghtml==1.1.4
288
+ sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
289
+ spyder @ file:///tmp/build/80754af9/spyder_1599056981321/work
290
+ spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1599056754858/work
291
+ SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1603397987316/work
292
+ sqlitedict==2.1.0
293
+ statsmodels @ file:///tmp/build/80754af9/statsmodels_1602280205159/work
294
+ sympy @ file:///tmp/build/80754af9/sympy_1605119542615/work
295
+ tables==3.6.1
296
+ tabulate==0.9.0
297
+ tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
298
+ tensorboard==2.10.1
299
+ tensorboard-data-server==0.6.1
300
+ tensorboard-plugin-wit==1.8.1
301
+ terminado==0.9.1
302
+ testpath==0.4.4
303
+ threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
304
+ tifffile==2020.10.1
305
+ tkseem==0.0.3
306
+ tokenizers==0.13.3
307
+ toml @ file:///tmp/build/80754af9/toml_1592853716807/work
308
+ tomli==2.0.1
309
+ toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
310
+ torch==1.11.0
311
+ torchaudio==0.11.0
312
+ torchmetrics==0.6.0
313
+ torchvision==0.8.2
314
+ tornado==6.0.4
315
+ tqdm==4.64.0
316
+ traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
317
+ transformer-smaller-training-vocab==0.3.1
318
+ transformers==4.33.1
319
+ typing-extensions==4.4.0
320
+ ujson @ file:///tmp/build/80754af9/ujson_1602523317881/work
321
+ unicodecsv==0.14.1
322
+ urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work
323
+ watchdog @ file:///tmp/build/80754af9/watchdog_1593447344699/work
324
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
325
+ webencodings==0.5.1
326
+ Werkzeug==1.0.1
327
+ widgetsnbextension==3.5.1
328
+ Wikipedia-API==0.6.0
329
+ wrapt==1.11.2
330
+ wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594753850195/work
331
+ xlrd==1.2.0
332
+ XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
333
+ xlwt==1.3.0
334
+ xmltodict==0.12.0
335
+ xxhash==3.0.0
336
+ yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
337
+ yarl==1.7.2
338
+ zict==2.0.0
339
+ zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
340
+ zope.event==4.5.0
341
+ zope.interface @ file:///tmp/build/80754af9/zope.interface_1602002420968/work
342
+ ==============================
343
+ Git revision:
344
+ 8a51838
345
+ ==============================
346
+ CUDA version:
347
+ 11.7
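The listing above (installed packages, Git revision, CUDA version) is the environment snapshot that SpeechBrain writes when an experiment starts. A minimal sketch of the call that produces such an env.log; the folder and YAML names here are illustrative, not the project's exact values:

import speechbrain as sb

# create_experiment_directory() copies the hyperparams file into the experiment
# folder and dumps the installed packages, the current git revision and the CUDA
# version into <experiment_directory>/env.log, i.e. the file reproduced above.
sb.create_experiment_directory(
    experiment_directory="results/example_run",   # illustrative path
    hyperparams_to_save="train.yaml",              # illustrative hparams file
    overrides={},
)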
TunisianASR/{semi_wavlm_large_tunisian_ctc → results/14epoch_tunisian}/1234/hyperparams.yaml RENAMED
@@ -1,5 +1,5 @@
1
- # Generated 2023-09-08 from:
2
- # /gpfsdsstore/projects/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml
3
  # yamllint disable
4
  # ################################
5
  # Model: wav2vec2 + DNN + CTC
@@ -10,13 +10,13 @@
10
  # Seed needs to be set at top of yaml, before objects with parameters are made
11
  seed: 1234
12
  __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
- output_folder: results/semi_wavlm_large_tunisian_ctc/1234
14
- wer_file: results/semi_wavlm_large_tunisian_ctc/1234/wer.txt
15
- save_folder: results/semi_wavlm_large_tunisian_ctc/1234/save
16
- train_log: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
17
 
18
  # URL for the biggest LeBenchmark wav2vec french.
19
- wav2vec2_folder: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
20
 
21
  # Data files
22
  data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
@@ -25,7 +25,7 @@ dev_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/dev.tsv # Standard Com
25
  test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
26
  accented_letters: true
27
  language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
28
- train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train_enhanced.csv
29
  valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
30
  test_csv:
31
  - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
@@ -44,7 +44,7 @@ avoid_if_shorter_than: 1.2
44
 
45
 
46
  # Training parameters
47
- number_of_epochs: 12
48
  lr: 1.0
49
  lr_wav2vec: 0.0001
50
  sorting: ascending
@@ -91,7 +91,7 @@ unk_index: 1
91
  #
92
  epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
93
 
94
- limit: 12
95
 
96
  augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
97
  sample_rate: 16000
@@ -120,11 +120,11 @@ enc: &id002 !new:speechbrain.nnet.containers.Sequential
120
  activation3: !new:torch.nn.LeakyReLU
121
 
122
  wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
123
- source: /gpfsstore/rech/nou/uzn19yk/wavlm/
124
  output_norm: false
125
  freeze: false
126
  freeze_feature_extractor: true
127
- save_path: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
128
 
129
  #####
130
  # Uncomment this block if you prefer to use a Fairseq pretrained model instead
@@ -178,7 +178,7 @@ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
178
  patient: 0
179
 
180
  checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
- checkpoints_dir: results/semi_wavlm_large_tunisian_ctc/1234/save
182
  recoverables:
183
  wav2vec2: *id001
184
  model: *id004
@@ -186,7 +186,7 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
186
  scheduler_wav2vec: *id006
187
  counter: *id007
188
  train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
189
- save_file: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
190
 
191
  error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
192
 
 
1
+ # Generated 2023-09-20 from:
2
+ # /home/salah/Code_Switched_Tunisian_Speech_Recognition/TunisianASR/semi_trained.yaml
3
  # yamllint disable
4
  # ################################
5
  # Model: wav2vec2 + DNN + CTC
 
10
  # Seed needs to be set at top of yaml, before objects with parameters are made
11
  seed: 1234
12
  __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: TunisianASR/results/14epoch_tunisian/1234/
14
+ wer_file: TunisianASR/results/14epoch_tunisian/1234//wer.txt
15
+ save_folder: TunisianASR/results/14epoch_tunisian/1234//save
16
+ train_log: TunisianASR/results/14epoch_tunisian/1234//train_log.txt
17
 
18
  # URL for the biggest LeBenchmark wav2vec french.
19
+ wav2vec2_folder: TunisianASR/results/14epoch_tunisian/1234//save/wav2vec2_checkpoint
20
 
21
  # Data files
22
  data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
 
25
  test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
26
  accented_letters: true
27
  language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
28
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train.csv
29
  valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
30
  test_csv:
31
  - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
 
44
 
45
 
46
  # Training parameters
47
+ number_of_epochs: 14
48
  lr: 1.0
49
  lr_wav2vec: 0.0001
50
  sorting: ascending
 
91
  #
92
  epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
93
 
94
+ limit: 14
95
 
96
  augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
97
  sample_rate: 16000
 
120
  activation3: !new:torch.nn.LeakyReLU
121
 
122
  wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
123
+ source: wavlm-large/
124
  output_norm: false
125
  freeze: false
126
  freeze_feature_extractor: true
127
+ save_path: TunisianASR/results/14epoch_tunisian/1234//save/wav2vec2_checkpoint
128
 
129
  #####
130
  # Uncomment this block if you prefer to use a Fairseq pretrained model instead
 
178
  patient: 0
179
 
180
  checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
+ checkpoints_dir: TunisianASR/results/14epoch_tunisian/1234//save
182
  recoverables:
183
  wav2vec2: *id001
184
  model: *id004
 
186
  scheduler_wav2vec: *id006
187
  counter: *id007
188
  train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
189
+ save_file: TunisianASR/results/14epoch_tunisian/1234//train_log.txt
190
 
191
  error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
192
 
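The renamed hyperparams.yaml above routes every output path to TunisianASR/results/14epoch_tunisian/1234/, raises the epoch limit from 12 to 14, and points the wav2vec2 block at a local wavlm-large/ checkpoint. A short, hedged sketch of how such a HyperPyYAML file is loaded with command-line overrides; the script name and the override value are illustrative:

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# e.g. python example_script.py TunisianASR/semi_trained.yaml --number_of_epochs=14
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)   # overrides replace fields such as number_of_epochs

print(hparams["output_folder"], hparams["number_of_epochs"])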
TunisianASR/results/14epoch_tunisian/1234/log.txt ADDED
@@ -0,0 +1,359 @@
1
+ 2023-09-20 16:23:38,106 - speechbrain.core - INFO - Beginning experiment!
2
+ 2023-09-20 16:23:38,106 - speechbrain.core - INFO - Experiment folder: TunisianASR/results/14epoch_tunisian/1234/
3
+ 2023-09-20 16:23:39,287 - speechbrain.utils.superpowers - DEBUG - absl-py==1.2.0
4
+ aiohttp==3.8.1
5
+ aiosignal==1.2.0
6
+ alabaster==0.7.12
7
+ anaconda-client==1.7.2
8
+ anaconda-navigator==1.10.0
9
+ anaconda-project==0.8.3
10
+ antlr4-python3-runtime==4.9.3
11
+ appdirs==1.4.4
12
+ argh==0.26.2
13
+ argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828493937/work
14
+ asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
15
+ astroid @ file:///tmp/build/80754af9/astroid_1592495912941/work
16
+ astropy==4.0.2
17
+ async-generator==1.10
18
+ async-timeout==4.0.2
19
+ atomicwrites==1.4.0
20
+ attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
21
+ audioread==2.1.9
22
+ autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
23
+ Babel @ file:///tmp/build/80754af9/babel_1605108370292/work
24
+ backcall==0.2.0
25
+ backports.functools-lru-cache==1.6.1
26
+ backports.shutil-get-terminal-size==1.0.0
27
+ backports.tempfile==1.0
28
+ backports.weakref==1.0.post1
29
+ beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
30
+ bitarray @ file:///tmp/build/80754af9/bitarray_1605065113847/work
31
+ bkcharts==0.2
32
+ black==22.12.0
33
+ bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
34
+ bokeh @ file:///tmp/build/80754af9/bokeh_1603297833684/work
35
+ boto==2.49.0
36
+ boto3==1.28.43
37
+ botocore==1.31.43
38
+ Bottleneck==1.3.2
39
+ bpemb==0.3.4
40
+ brotlipy==0.7.0
41
+ cachetools==5.2.0
42
+ certifi==2020.6.20
43
+ cffi @ file:///tmp/build/80754af9/cffi_1600699146221/work
44
+ chardet==3.0.4
45
+ charset-normalizer==2.0.12
46
+ click==8.1.3
47
+ cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
48
+ clyent==1.2.2
49
+ colorama @ file:///tmp/build/80754af9/colorama_1603211150991/work
50
+ coloredlogs==15.0.1
51
+ conda==4.9.2
52
+ conda-build==3.20.5
53
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
54
+ conda-verify==3.4.2
55
+ conllu==4.5.3
56
+ contextlib2==0.6.0.post1
57
+ cryptography @ file:///tmp/build/80754af9/cryptography_1601046815590/work
58
+ cycler==0.10.0
59
+ Cython @ file:///tmp/build/80754af9/cython_1594831566883/work
60
+ cytoolz==0.11.0
61
+ dask @ file:///tmp/build/80754af9/dask-core_1602083700509/work
62
+ datasets==1.18.3
63
+ decorator==4.4.2
64
+ defusedxml==0.6.0
65
+ Deprecated==1.2.14
66
+ diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
67
+ dill==0.3.4
68
+ distributed @ file:///tmp/build/80754af9/distributed_1605066520644/work
69
+ docutils==0.16
70
+ easyocr==1.2.1
71
+ einops==0.3.0
72
+ entrypoints==0.3
73
+ et-xmlfile==1.0.1
74
+ farasapy==0.0.14
75
+ fastcache==1.1.0
76
+ ffmpeg-python==0.2.0
77
+ filelock==3.0.12
78
+ flair==0.12.2
79
+ flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
80
+ Flask==1.1.2
81
+ flatbuffers==22.9.24
82
+ frozenlist==1.3.0
83
+ fsspec==2022.3.0
84
+ ftfy==6.1.1
85
+ future==0.18.2
86
+ gdown==4.4.0
87
+ gensim==4.1.2
88
+ gevent @ file:///tmp/build/80754af9/gevent_1601397537062/work
89
+ glob2==0.7
90
+ gmpy2==2.0.8
91
+ google-auth==2.12.0
92
+ google-auth-oauthlib==0.4.6
93
+ greenlet @ file:///tmp/build/80754af9/greenlet_1600874013538/work
94
+ grpcio==1.49.1
95
+ h5py==2.10.0
96
+ HeapDict==1.0.1
97
+ html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
98
+ huggingface-hub==0.16.4
99
+ humanfriendly==10.0
100
+ hyperopt==0.2.7
101
+ idna @ file:///tmp/build/80754af9/idna_1593446292537/work
102
+ imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
103
+ imagesize==1.2.0
104
+ imhist==0.0.4
105
+ importlib-metadata==5.0.0
106
+ imWatermark==0.0.2
107
+ iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
108
+ install==1.3.5
109
+ intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
110
+ invisible-watermark==0.1.5
111
+ ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
112
+ ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work
113
+ ipython-genutils==0.2.0
114
+ ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
115
+ isort @ file:///tmp/build/80754af9/isort_1602603989581/work
116
+ itsdangerous==1.1.0
117
+ Janome==0.5.0
118
+ jdcal==1.4.1
119
+ jedi @ file:///tmp/build/80754af9/jedi_1592841866100/work
120
+ jeepney @ file:///tmp/build/80754af9/jeepney_1605069705079/work
121
+ Jinja2==2.11.2
122
+ jiwer==2.3.0
123
+ jmespath==1.0.1
124
+ joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
125
+ json5==0.9.5
126
+ jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
127
+ jupyter==1.0.0
128
+ jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
129
+ jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
130
+ jupyter-core==4.6.3
131
+ jupyterlab==2.2.6
132
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
133
+ jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
134
+ keyring @ file:///tmp/build/80754af9/keyring_1601490835422/work
135
+ kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014535162/work
136
+ langdetect==1.0.9
137
+ lazy-object-proxy==1.4.3
138
+ libarchive-c==2.9
139
+ librosa==0.9.1
140
+ llvmlite==0.34.0
141
+ locket==0.2.0
142
+ lxml @ file:///tmp/build/80754af9/lxml_1603216285000/work
143
+ Markdown==3.4.1
144
+ MarkupSafe==1.1.1
145
+ matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603378225747/work
146
+ mccabe==0.6.1
147
+ mido==1.2.10
148
+ mistune==0.8.4
149
+ mkl-fft==1.2.0
150
+ mkl-random==1.1.1
151
+ mkl-service==2.3.0
152
+ mock==4.0.2
153
+ more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
154
+ mpld3==0.3
155
+ mpmath==1.1.0
156
+ msgpack==1.0.0
157
+ multidict==6.0.2
158
+ multipledispatch==0.6.0
159
+ multiprocess==0.70.12.2
160
+ mypy-extensions==0.4.3
161
+ navigator-updater==0.2.1
162
+ nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
163
+ nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
164
+ nbformat @ file:///tmp/build/80754af9/nbformat_1602783287752/work
165
+ nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1605115881283/work
166
+ networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
167
+ nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
168
+ nose==1.3.7
169
+ notebook @ file:///tmp/build/80754af9/notebook_1601501575118/work
170
+ numba @ file:///tmp/build/80754af9/numba_1600100669015/work
171
+ numexpr==2.7.1
172
+ numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work
173
+ numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
174
+ oauthlib==3.2.1
175
+ olefile==0.46
176
+ omegaconf==2.2.3
177
+ onnx==1.12.0
178
+ onnxruntime==1.12.1
179
+ opencv-python==4.4.0.46
180
+ openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
181
+ packaging==20.9
182
+ pandas @ file:///tmp/build/80754af9/pandas_1602088120436/work
183
+ pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
184
+ parso==0.7.0
185
+ partd==1.1.0
186
+ path @ file:///tmp/build/80754af9/path_1598376507494/work
187
+ pathlib2==2.3.5
188
+ pathspec==0.10.3
189
+ pathtools==0.1.2
190
+ patsy==0.5.1
191
+ pep8==1.7.1
192
+ pexpect==4.8.0
193
+ pickleshare==0.7.5
194
+ Pillow @ file:///tmp/build/80754af9/pillow_1603822255246/work
195
+ pkginfo==1.6.1
196
+ platformdirs==2.6.0
197
+ pluggy==0.13.1
198
+ ply==3.11
199
+ pooch==1.6.0
200
+ pptree==3.1
201
+ pretty-midi==0.2.9
202
+ prometheus-client==0.8.0
203
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
204
+ protobuf==3.19.6
205
+ psutil @ file:///tmp/build/80754af9/psutil_1598370257551/work
206
+ ptyprocess==0.6.0
207
+ py @ file:///tmp/build/80754af9/py_1593446248552/work
208
+ py-espeak-ng==0.1.8
209
+ py4j==0.10.9.7
210
+ PyArabic==0.6.15
211
+ pyarrow==7.0.0
212
+ pyasn1==0.4.8
213
+ pyasn1-modules==0.2.8
214
+ pycodestyle==2.6.0
215
+ pycosat==0.6.3
216
+ pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
217
+ pycurl==7.43.0.6
218
+ pyDeprecate==0.3.1
219
+ pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
220
+ pyflakes==2.2.0
221
+ Pygments @ file:///tmp/build/80754af9/pygments_1604103097372/work
222
+ pylint @ file:///tmp/build/80754af9/pylint_1598623985952/work
223
+ pyodbc===4.0.0-unsupported
224
+ pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
225
+ pyparsing==2.4.7
226
+ pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
227
+ PySocks==1.7.1
228
+ pytest==0.0.0
229
+ python-bidi==0.4.2
230
+ python-crfsuite==0.9.7
231
+ python-dateutil==2.8.1
232
+ python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
233
+ python-language-server @ file:///tmp/build/80754af9/python-language-server_1600454544709/work
234
+ python-Levenshtein==0.12.2
235
+ pytorch-lightning==1.4.2
236
+ pytorch-revgrad==0.2.0
237
+ pytz==2020.1
238
+ PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
239
+ pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
240
+ PyYAML==5.3.1
241
+ pyzmq==19.0.2
242
+ QDarkStyle==2.8.1
243
+ QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
244
+ qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
245
+ QtPy==1.9.0
246
+ regex @ file:///tmp/build/80754af9/regex_1602786672676/work
247
+ requests @ file:///tmp/build/80754af9/requests_1592841827918/work
248
+ requests-oauthlib==1.3.1
249
+ resampy==0.2.2
250
+ rope @ file:///tmp/build/80754af9/rope_1602264064449/work
251
+ rsa==4.9
252
+ Rtree==0.9.4
253
+ ruamel-yaml==0.15.87
254
+ s3transfer==0.6.2
255
+ sacremoses==0.0.49
256
+ safetensors==0.3.3
257
+ scikit-image==0.17.2
258
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376899566/work
259
+ scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work
260
+ seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
261
+ SecretStorage==3.1.2
262
+ segtok==1.5.11
263
+ Send2Trash==1.5.0
264
+ sentencepiece==0.1.97
265
+ simplegeneric==0.8.1
266
+ singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
267
+ sip==4.19.13
268
+ six @ file:///tmp/build/80754af9/six_1605205327372/work
269
+ smart-open==5.2.1
270
+ snowballstemmer==2.0.0
271
+ sortedcollections==1.2.1
272
+ sortedcontainers==2.2.2
273
+ SoundFile==0.10.3.post1
274
+ soupsieve==2.0.1
275
+ sphfile==1.0.3
276
+ Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
277
+ sphinxcontrib-applehelp==1.0.2
278
+ sphinxcontrib-devhelp==1.0.2
279
+ sphinxcontrib-htmlhelp==1.0.3
280
+ sphinxcontrib-jsmath==1.0.1
281
+ sphinxcontrib-qthelp==1.0.3
282
+ sphinxcontrib-serializinghtml==1.1.4
283
+ sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
284
+ spyder @ file:///tmp/build/80754af9/spyder_1599056981321/work
285
+ spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1599056754858/work
286
+ SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1603397987316/work
287
+ sqlitedict==2.1.0
288
+ statsmodels @ file:///tmp/build/80754af9/statsmodels_1602280205159/work
289
+ sympy @ file:///tmp/build/80754af9/sympy_1605119542615/work
290
+ tables==3.6.1
291
+ tabulate==0.9.0
292
+ tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
293
+ tensorboard==2.10.1
294
+ tensorboard-data-server==0.6.1
295
+ tensorboard-plugin-wit==1.8.1
296
+ terminado==0.9.1
297
+ testpath==0.4.4
298
+ threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
299
+ tifffile==2020.10.1
300
+ tkseem==0.0.3
301
+ tokenizers==0.13.3
302
+ toml @ file:///tmp/build/80754af9/toml_1592853716807/work
303
+ tomli==2.0.1
304
+ toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
305
+ torch==1.11.0
306
+ torchaudio==0.11.0
307
+ torchmetrics==0.6.0
308
+ torchvision==0.8.2
309
+ tornado==6.0.4
310
+ tqdm==4.64.0
311
+ traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
312
+ transformer-smaller-training-vocab==0.3.1
313
+ transformers==4.33.1
314
+ typing-extensions==4.4.0
315
+ ujson @ file:///tmp/build/80754af9/ujson_1602523317881/work
316
+ unicodecsv==0.14.1
317
+ urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work
318
+ watchdog @ file:///tmp/build/80754af9/watchdog_1593447344699/work
319
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
320
+ webencodings==0.5.1
321
+ Werkzeug==1.0.1
322
+ widgetsnbextension==3.5.1
323
+ Wikipedia-API==0.6.0
324
+ wrapt==1.11.2
325
+ wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594753850195/work
326
+ xlrd==1.2.0
327
+ XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
328
+ xlwt==1.3.0
329
+ xmltodict==0.12.0
330
+ xxhash==3.0.0
331
+ yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
332
+ yarl==1.7.2
333
+ zict==2.0.0
334
+ zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
335
+ zope.event==4.5.0
336
+ zope.interface @ file:///tmp/build/80754af9/zope.interface_1602002420968/work
337
+
338
+
339
+ 2023-09-20 16:23:39,866 - speechbrain.utils.superpowers - DEBUG - 8a51838
340
+
341
+
342
+ 2023-09-20 16:23:39,869 - speechbrain.pretrained.fetching - INFO - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml.
343
+ 2023-09-20 16:23:39,871 - speechbrain.pretrained.fetching - INFO - Fetch custom.py: Linking to local file in /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py.
344
+ 2023-09-20 16:23:47,958 - speechbrain.lobes.models.huggingface_wav2vec - WARNING - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen.
345
+ 2023-09-20 16:23:48,018 - speechbrain.utils.parameter_transfer - DEBUG - Collecting files (or symlinks) for pretraining in pretrained_models/asr-wav2vec2-commonvoice-fr.
346
+ 2023-09-20 16:23:48,023 - speechbrain.pretrained.fetching - INFO - Fetch wav2vec2.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt.
347
+ 2023-09-20 16:23:48,025 - speechbrain.pretrained.fetching - INFO - Fetch asr.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt.
348
+ 2023-09-20 16:23:48,028 - speechbrain.pretrained.fetching - INFO - Fetch tokenizer.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt.
349
+ 2023-09-20 16:23:48,029 - speechbrain.utils.parameter_transfer - INFO - Loading pretrained files for: wav2vec2, asr, tokenizer
350
+ 2023-09-20 16:23:56,361 - speechbrain.lobes.models.huggingface_wav2vec - WARNING - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.
351
+ 2023-09-20 16:23:56,366 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
352
+ 2023-09-20 16:23:56,366 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
353
+ 2023-09-20 16:23:56,529 - speechbrain.core - INFO - 314.4M trainable parameters in ASRCV
354
+ 2023-09-20 16:23:57,316 - speechbrain.utils.checkpoints - INFO - Loading a checkpoint from EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00
355
+ 2023-09-20 16:23:59,928 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
356
+ 2023-09-20 16:23:59,940 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
357
+ 2023-09-20 16:24:00,139 - speechbrain.core - INFO - 314.4M trainable parameters in ASR
358
+ 2023-09-20 16:24:00,967 - speechbrain.utils.checkpoints - INFO - Loading a checkpoint from TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00
359
+ 2023-09-20 16:24:49,007 - speechbrain.utils.distributed - INFO - distributed_launch flag is disabled, this experiment will be executed without DDP.
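The fetch messages at the end of this log show the pretrainer reusing local copies of the French CommonVoice wav2vec2 model before the Tunisian checkpoint is restored. A hedged sketch of loading that pretrained French ASR directly for inference, assuming the standard speechbrain.pretrained interface; the audio path is a placeholder:

from speechbrain.pretrained import EncoderASR

asr_model = EncoderASR.from_hparams(
    source="asr-wav2vec2-commonvoice-fr",                      # local folder holding hyperparams.yaml
    savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",   # where the symlinks in the log are collected
)
print(asr_model.transcribe_file("example.wav"))                # example.wav is a placeholder file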
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/CKPT.yaml RENAMED
@@ -1,4 +1,4 @@
1
  # yamllint disable
2
- WER: 27.83210816487267
3
  end-of-epoch: true
4
- unixtime: 1693868963.5220973
 
1
  # yamllint disable
2
+ WER: 26.88369650826989
3
  end-of-epoch: true
4
+ unixtime: 1691019518.289327
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/brain.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3947a24e8dff5a14299b9cf2fe66ffb4d738cb88717de7f0cf7e8547a76e9776
3
  size 51
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c991c52635ebf5f1d342ff11f149ab3000260e4b08bf1b4356e5134002a60feb
3
  size 51
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/counter.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918
3
  size 2
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8527a891e224136950ff32ca212b45bc93f69fbb801c3b1ebedac52775f99e61
3
  size 2
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/dataloader-TRAIN.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b363886c229e536bd3c84e0c3e89312d70e00422578e076a62df1b45c9390793
3
  size 5
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646ecafa8b16fbb513bf9ddc56ba5e34c8818c0c8a7858871698ef9d15ddea68
3
  size 5
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/model.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bc1dbeca1e1f1340b08d8ebea6e492f474708dddbbe8cabbcdde5ee9660704f2
3
  size 12814446
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cb0cabac5780ffeb4b9850d30e8cd10f748896ddb04aa963d29512463f9b65c
3
  size 12814446
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/modelopt.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3af1791eb9a5bfbfc087d2c10b94634df24cad3ac503ce9ba280a3ecc4737781
3
- size 25575663
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20f6aecbdbc179aeac4a305431e9f4d17a3436a4aa8426d20f066d6c99c7b449
3
+ size 25575599
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/scheduler_model.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c275ab9245b440d1586f72058d9edaac1a2fb3e7a52712aa9a9ad022b99a1c0d
3
  size 639
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a25325a19660b044c1edc3405e8a298702a93f2f569b0548f8066bb50e8e3c8
3
  size 639
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/scheduler_wav2vec.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a88187f7882dc3e10c108f1b7abfbd819285b34bded4e88e91c4ff699c1bb5d2
3
  size 643
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b238edab7c400d8ad289eb44c8912bc0d1d2144f2ac59e48b9ab736dc4ef5f79
3
  size 643
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/wav2vec2.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:788267bd25ef37623715fa21a975090e5e316fff05971375cd3f62e5160f0743
3
  size 1262005979
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9f5cd05dd7941f51ce7f19acd406b3eb562de4bc4c6ed818f709563a8308e8d
3
  size 1262005979
TunisianASR/{semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00 → results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00}/wav2vec_opt.ckpt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:efa967fdd8067be7d88c18cd197980c9c91f344a3dff2b2518b8381c49f28b1e
3
  size 2490361859
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d191b5b044b5dd0414bafb8c762083c4d344e3c81de807ef6a092dba4a383dd
3
  size 2490361859
TunisianASR/{semi_wavlm_large_tunisian_ctc → results/14epoch_tunisian}/1234/save/label_encoder.txt RENAMED
File without changes
TunisianASR/results/14epoch_tunisian/1234/train_with_wav2vec.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from pathlib import Path
7
+ import os
8
+ import torchaudio
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
11
+ from speechbrain.utils.data_utils import undo_padding
12
+ from speechbrain.utils.distributed import run_on_main
13
+
14
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
15
+ The system employs a wav2vec2 encoder and a CTC decoder.
16
+ Decoding is performed with greedy decoding (will be extended to beam search).
17
+
18
+ To run this recipe, do the following:
19
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
20
+
21
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
22
+ The wav2vec2 model is pretrained following the model given in the hparams file.
23
+ It may be dependent on the language.
24
+
25
+ The neural network is trained with CTC on sub-word units estimated with
26
+ Byte Pair Encoding (BPE).
27
+
28
+ The experiment file is flexible enough to support a large variety of
29
+ different systems. By properly changing the parameter files, you can try
30
+ different encoders, decoders, tokens (e.g., characters instead of BPE),
31
+ training languages (all CommonVoice languages), and many
32
+ other possible variations.
33
+
34
+ Authors
35
+ * Titouan Parcollet 2021
36
+ """
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ # Define training procedure
42
+ class ASR(sb.core.Brain):
43
+ def compute_forward(self, batch, stage):
44
+ """Forward computations from the waveform batches to the output probabilities."""
45
+
46
+ batch = batch.to(self.device)
47
+ wavs, wav_lens = batch.sig
48
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens, tokens_lens = batch.tokens
68
+
69
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
70
+
71
+ if stage != sb.Stage.TRAIN:
72
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
73
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
74
+ )
75
+ # Decode token terms to words
76
+ if self.hparams.use_language_modelling:
77
+ predicted_words = []
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+ else:
82
+ predicted_words = [
83
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
84
+ for utt_seq in predicted_tokens
85
+ ]
86
+ # Convert indices to words
87
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
88
+
89
+ self.wer_metric.append(ids, predicted_words, target_words)
90
+ self.cer_metric.append(ids, predicted_words, target_words)
91
+
92
+ return loss
93
+
94
+ def fit_batch(self, batch):
95
+ """Train the parameters given a single batch in input"""
96
+ should_step = self.step % self.grad_accumulation_factor == 0
97
+ # Managing automatic mixed precision
98
+ # TOFIX: CTC fine-tuning currently is unstable
99
+ # This is certainly due to CTC being done in fp16 instead of fp32
100
+ if self.auto_mix_prec:
101
+ with torch.cuda.amp.autocast():
102
+ with self.no_sync():
103
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
104
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
105
+ with self.no_sync(not should_step):
106
+ self.scaler.scale(
107
+ loss / self.grad_accumulation_factor
108
+ ).backward()
109
+ if should_step:
110
+
111
+ if not self.hparams.wav2vec2.freeze:
112
+ self.scaler.unscale_(self.wav2vec_optimizer)
113
+ self.scaler.unscale_(self.model_optimizer)
114
+ if self.check_gradients(loss):
115
+ if not self.hparams.wav2vec2.freeze:
116
+ if self.optimizer_step >= self.hparams.warmup_steps:
117
+ self.scaler.step(self.wav2vec_optimizer)
118
+ self.scaler.step(self.model_optimizer)
119
+ self.scaler.update()
120
+ self.zero_grad()
121
+ self.optimizer_step += 1
122
+ else:
123
+ # This is mandatory because HF models have a weird behavior with DDP
124
+ # on the forward pass
125
+ with self.no_sync():
126
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
127
+
128
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
129
+
130
+ with self.no_sync(not should_step):
131
+ (loss / self.grad_accumulation_factor).backward()
132
+ if should_step:
133
+ if self.check_gradients(loss):
134
+ if not self.hparams.wav2vec2.freeze:
135
+ if self.optimizer_step >= self.hparams.warmup_steps:
136
+ self.wav2vec_optimizer.step()
137
+ self.model_optimizer.step()
138
+ self.zero_grad()
139
+ self.optimizer_step += 1
140
+
141
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
142
+ return loss.detach().cpu()
143
+
144
+ def evaluate_batch(self, batch, stage):
145
+ """Computations needed for validation/test batches"""
146
+ predictions = self.compute_forward(batch, stage=stage)
147
+ with torch.no_grad():
148
+ loss = self.compute_objectives(predictions, batch, stage=stage)
149
+ return loss.detach()
150
+
151
+ def on_stage_start(self, stage, epoch):
152
+ """Gets called at the beginning of each epoch"""
153
+ if stage != sb.Stage.TRAIN:
154
+ self.cer_metric = self.hparams.cer_computer()
155
+ self.wer_metric = self.hparams.error_rate_computer()
156
+
157
+ def on_stage_end(self, stage, stage_loss, epoch):
158
+ """Gets called at the end of an epoch."""
159
+ # Compute/store important stats
160
+ stage_stats = {"loss": stage_loss}
161
+ if stage == sb.Stage.TRAIN:
162
+ self.train_stats = stage_stats
163
+ else:
164
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
165
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
166
+
167
+ # Perform end-of-iteration things, like annealing, logging, etc.
168
+ if stage == sb.Stage.VALID:
169
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
170
+ stage_stats["loss"]
171
+ )
172
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
173
+ stage_stats["loss"]
174
+ )
175
+ sb.nnet.schedulers.update_learning_rate(
176
+ self.model_optimizer, new_lr_model
177
+ )
178
+ if not self.hparams.wav2vec2.freeze:
179
+ sb.nnet.schedulers.update_learning_rate(
180
+ self.wav2vec_optimizer, new_lr_wav2vec
181
+ )
182
+ self.hparams.train_logger.log_stats(
183
+ stats_meta={
184
+ "epoch": epoch,
185
+ "lr_model": old_lr_model,
186
+ "lr_wav2vec": old_lr_wav2vec,
187
+ },
188
+ train_stats=self.train_stats,
189
+ valid_stats=stage_stats,
190
+ )
191
+ self.checkpointer.save_and_keep_only(
192
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
193
+ )
194
+ elif stage == sb.Stage.TEST:
195
+ self.hparams.train_logger.log_stats(
196
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
197
+ test_stats=stage_stats,
198
+ )
199
+ with open(self.hparams.wer_file, "w") as w:
200
+ self.wer_metric.write_stats(w)
201
+
202
+ def init_optimizers(self):
203
+ "Initializes the wav2vec2 optimizer and model optimizer"
204
+
205
+ # If the wav2vec encoder is unfrozen, we create the optimizer
206
+ if not self.hparams.wav2vec2.freeze:
207
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
208
+ self.modules.wav2vec2.parameters()
209
+ )
210
+ if self.checkpointer is not None:
211
+ self.checkpointer.add_recoverable(
212
+ "wav2vec_opt", self.wav2vec_optimizer
213
+ )
214
+
215
+ self.model_optimizer = self.hparams.model_opt_class(
216
+ self.hparams.model.parameters()
217
+ )
218
+
219
+ if self.checkpointer is not None:
220
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
221
+
222
+ def zero_grad(self, set_to_none=False):
223
+ if not self.hparams.wav2vec2.freeze:
224
+ self.wav2vec_optimizer.zero_grad(set_to_none)
225
+ self.model_optimizer.zero_grad(set_to_none)
226
+
227
+
228
+ # Define custom data procedure
229
+ def dataio_prepare(hparams):
230
+ """This function prepares the datasets to be used in the brain class.
231
+ It also defines the data processing pipeline through user-defined functions."""
232
+
233
+ # 1. Define datasets
234
+ data_folder = hparams["data_folder"]
235
+
236
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
237
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
238
+ )
239
+
240
+ if hparams["sorting"] == "ascending":
241
+ # we sort training data to speed up training and get better results.
242
+ train_data = train_data.filtered_sorted(
243
+ sort_key="duration",
244
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
245
+ )
246
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
247
+ hparams["dataloader_options"]["shuffle"] = False
248
+
249
+ elif hparams["sorting"] == "descending":
250
+ train_data = train_data.filtered_sorted(
251
+ sort_key="duration",
252
+ reverse=True,
253
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
254
+ )
255
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
256
+ hparams["dataloader_options"]["shuffle"] = False
257
+
258
+ elif hparams["sorting"] == "random":
259
+ pass
260
+
261
+ else:
262
+ raise NotImplementedError(
263
+ "sorting must be random, ascending or descending"
264
+ )
265
+
266
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
267
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
268
+ )
269
+ # We also sort the validation data so it is faster to validate
270
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
271
+ test_datasets = {}
272
+ for csv_file in hparams["test_csv"]:
273
+ name = Path(csv_file).stem
274
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
275
+ csv_path=csv_file, replacements={"data_root": data_folder}
276
+ )
277
+ test_datasets[name] = test_datasets[name].filtered_sorted(
278
+ sort_key="duration"
279
+ )
280
+
281
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
282
+
283
+
284
+ # 2. Define audio pipeline:
285
+ @sb.utils.data_pipeline.takes("wav")
286
+ @sb.utils.data_pipeline.provides("sig")
287
+ def audio_pipeline(wav):
288
+ info = torchaudio.info(wav)
289
+ sig = sb.dataio.dataio.read_audio(wav)
290
+ resampled = torchaudio.transforms.Resample(
291
+ info.sample_rate, hparams["sample_rate"],
292
+ )(sig)
293
+ return resampled
294
+
295
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
296
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
297
+
298
+ # 3. Define text pipeline:
299
+ @sb.utils.data_pipeline.takes("wrd")
300
+ @sb.utils.data_pipeline.provides(
301
+ "wrd", "char_list", "tokens_list", "tokens"
302
+ )
303
+ def text_pipeline(wrd):
304
+ yield wrd
305
+ char_list = list(wrd)
306
+ yield char_list
307
+ tokens_list = label_encoder.encode_sequence(char_list)
308
+ yield tokens_list
309
+ tokens = torch.LongTensor(tokens_list)
310
+ yield tokens
311
+
312
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
313
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
314
+ special_labels = {
315
+ "blank_label": hparams["blank_index"],
316
+ "unk_label": hparams["unk_index"]
317
+ }
318
+ label_encoder.load_or_create(
319
+ path=lab_enc_file,
320
+ from_didatasets=[train_data],
321
+ output_key="char_list",
322
+ special_labels=special_labels,
323
+ sequence_input=True,
324
+ )
325
+
326
+ # 4. Set output:
327
+ sb.dataio.dataset.set_output_keys(
328
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
329
+ )
330
+ return train_data, valid_data,test_datasets, label_encoder
331
+
332
+
333
+ if __name__ == "__main__":
334
+
335
+ # Load hyperparameters file with command-line overrides
336
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
337
+ with open(hparams_file) as fin:
338
+ hparams = load_hyperpyyaml(fin, overrides)
339
+
340
+ # If --distributed_launch then
341
+ # create ddp_group with the right communication protocol
342
+ sb.utils.distributed.ddp_init_group(run_opts)
343
+
344
+
345
+ # Create experiment directory
346
+ sb.create_experiment_directory(
347
+ experiment_directory=hparams["output_folder"],
348
+ hyperparams_to_save=hparams_file,
349
+ overrides=overrides,
350
+ )
351
+
352
+ # Due to DDP, we do the preparation ONLY on the main python process
353
+ # Defining tokenizer and loading it
354
+ # Create the datasets objects as well as tokenization and encoding :-D
355
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)
356
+ if hparams["use_language_modelling"]:
357
+ print("using langauge_modeeling")
358
+ from pyctcdecode import build_ctcdecoder
359
+ ind2lab = label_encoder.ind2lab
360
+ print(ind2lab)
361
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
362
+ labels = [""] + labels[1:-1] + ["1"]
363
+ # Replace the <blank> token with a blank character, needed for PyCTCdecode
364
+ print(labels)
365
+ decoder = build_ctcdecoder(
366
+ labels,
367
+ kenlm_model_path=hparams["ngram_lm_path"], # .arpa or .bin
368
+ alpha=0.5, # Default by KenLM
369
+ beta=1.0, # Default by KenLM
370
+ )
371
+ # Trainer initialization
372
+ asr_brain = ASR(
373
+ modules=hparams["modules"],
374
+ hparams=hparams,
375
+ run_opts=run_opts,
376
+ checkpointer=hparams["checkpointer"],
377
+ )
378
+
379
+ # Adding objects to trainer.
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Training
383
+ asr_brain.fit(
384
+ asr_brain.hparams.epoch_counter,
385
+ train_data,
386
+ valid_data,
387
+ train_loader_kwargs=hparams["dataloader_options"],
388
+ valid_loader_kwargs=hparams["test_dataloader_options"],
389
+ )
390
+
391
+ # Test
392
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
393
+ asr_brain.hparams.wer_file = os.path.join(
394
+ hparams["output_folder"], "wer_{}.txt".format(k)
395
+ )
396
+ asr_brain.evaluate(
397
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
398
+ )
399
+
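When use_language_modelling is enabled, the script above builds a pyctcdecode beam-search decoder over the label-encoder characters and an n-gram language model, then decodes the CTC posteriors with it. A self-contained sketch of that decoding path; the label set and the fake posteriors are illustrative only:

import numpy as np
from pyctcdecode import build_ctcdecoder

# Index 0 must be the CTC blank, represented here by the empty string, as in the script.
labels = [""] + list("abcdefghijklmnopqrstuvwxyz '") + ["1"]
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=None,   # the script passes hparams["ngram_lm_path"] (.arpa or .bin) here
    alpha=0.5,               # LM weight, as set in the script
    beta=1.0,                # word-insertion bonus, as set in the script
)

# Stand-in for p_ctc of one utterance: (time, vocab) log-probabilities.
log_probs = np.log(np.random.dirichlet(np.ones(len(labels)), size=200)).astype(np.float32)
print(decoder.decode(log_probs))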
TunisianASR/results/14epoch_tunisian/<seed>/copy_of_wavlm_tun.py ADDED
@@ -0,0 +1,761 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import logging
7
+ import speechbrain as sb
8
+ from speechbrain.utils.distributed import run_on_main
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from pathlib import Path
11
+ import torchaudio.transforms as T
12
+ import torchaudio
13
+ import numpy as np
14
+ import kenlm
15
+ from pyctcdecode import build_ctcdecoder
16
+ import re
17
+
18
+ # Commented out IPython magic to ensure Python compatibility.
19
+ # %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
20
+
21
+ hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])
22
+
23
+ # If distributed_launch=True then
24
+ # create ddp_group with the right communication protocol
25
+ sb.utils.distributed.ddp_init_group(run_opts)
26
+
27
+ with open(hparams_file) as fin:
28
+ hparams = load_hyperpyyaml(fin, overrides)
29
+
30
+ # Create experiment directory
31
+ sb.create_experiment_directory(
32
+ experiment_directory=hparams["output_folder"],
33
+ hyperparams_to_save=hparams_file,
34
+ overrides=overrides,
35
+ )
36
+ """
37
+ def read_labels_file(labels_file):
38
+ with open(labels_file, "r",encoding="utf-8") as lf:
39
+ lines = lf.read().splitlines()
40
+ division = "==="
41
+ numbers = {}
42
+ for line in lines :
43
+ if division in line :
44
+ break
45
+ string, number = line.split("=>")
46
+ number = int(number)
47
+ string = string[1:-2]
48
+ numbers[number] = string
49
+ return [numbers[x] for x in range(len(numbers))]
50
+
51
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
52
+ labels = [""] + labels[1:-1] + ["1"]
53
+
54
+ # Dataset prep (parsing Librispeech)
55
+ """
56
+
57
+ def dataio_prepare(hparams):
58
+ """This function prepares the datasets to be used in the brain class.
59
+ It also defines the data processing pipeline through user-defined functions."""
60
+
61
+ # 1. Define datasets
62
+ data_folder = hparams["data_folder"]
63
+
64
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
65
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
66
+ )
67
+
68
+ if hparams["sorting"] == "ascending":
69
+ # we sort training data to speed up training and get better results.
70
+ train_data = train_data.filtered_sorted(
71
+ sort_key="duration",
72
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
73
+ )
74
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
75
+ hparams["dataloader_options"]["shuffle"] = False
76
+
77
+ elif hparams["sorting"] == "descending":
78
+ train_data = train_data.filtered_sorted(
79
+ sort_key="duration",
80
+ reverse=True,
81
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
82
+ )
83
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
84
+ hparams["dataloader_options"]["shuffle"] = False
85
+
86
+ elif hparams["sorting"] == "random":
87
+ pass
88
+
89
+ else:
90
+ raise NotImplementedError(
91
+ "sorting must be random, ascending or descending"
92
+ )
93
+
94
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
95
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
96
+ )
97
+ # We also sort the validation data so it is faster to validate
98
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
99
+ test_datasets = {}
100
+ for csv_file in hparams["test_csv"]:
101
+ name = Path(csv_file).stem
102
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
103
+ csv_path=csv_file, replacements={"data_root": data_folder}
104
+ )
105
+ test_datasets[name] = test_datasets[name].filtered_sorted(
106
+ sort_key="duration"
107
+ )
108
+
109
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
110
+
111
+
112
+ # 2. Define audio pipeline:
113
+ @sb.utils.data_pipeline.takes("wav")
114
+ @sb.utils.data_pipeline.provides("sig")
115
+ def audio_pipeline(wav):
116
+ info = torchaudio.info(wav)
117
+ sig = sb.dataio.dataio.read_audio(wav)
118
+ if len(sig.shape)>1 :
119
+ sig = torch.mean(sig, dim=1)
120
+ resampled = torchaudio.transforms.Resample(
121
+ info.sample_rate, hparams["sample_rate"],
122
+ )(sig)
123
+ return resampled
124
+
125
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
126
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
127
+
128
+ # 3. Define text pipeline:
129
+ @sb.utils.data_pipeline.takes("wrd")
130
+ @sb.utils.data_pipeline.provides(
131
+ "wrd", "char_list", "tokens_list", "tokens"
132
+ )
133
+ def text_pipeline(wrd):
134
+ yield wrd
135
+ char_list = list(wrd)
136
+ yield char_list
137
+ tokens_list = label_encoder.encode_sequence(char_list)
138
+ yield tokens_list
139
+ tokens = torch.LongTensor(tokens_list)
140
+ yield tokens
141
+
142
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
143
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
144
+ special_labels = {
145
+ "blank_label": hparams["blank_index"],
146
+ "unk_label": hparams["unk_index"]
147
+ }
148
+ label_encoder.load_or_create(
149
+ path=lab_enc_file,
150
+ from_didatasets=[train_data],
151
+ output_key="char_list",
152
+ special_labels=special_labels,
153
+ sequence_input=True,
154
+ )
155
+
156
+ # 4. Set output:
157
+ sb.dataio.dataset.set_output_keys(
158
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
159
+ )
160
+ return train_data, valid_data,test_datasets, label_encoder
161
+
162
+ class ASR(sb.core.Brain):
163
+ def compute_forward(self, batch, stage):
164
+ """Forward computations from the waveform batches to the output probabilities."""
165
+
166
+ batch = batch.to(self.device)
167
+ wavs, wav_lens = batch.sig
168
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
169
+
170
+ if stage == sb.Stage.TRAIN:
171
+ if hasattr(self.hparams, "augmentation"):
172
+ wavs = self.hparams.augmentation(wavs, wav_lens)
173
+
174
+ # Forward pass
175
+ feats = self.modules.wav2vec2(wavs, wav_lens)
176
+ x = self.modules.enc(feats)
177
+ logits = self.modules.ctc_lin(x)
178
+ p_ctc = self.hparams.log_softmax(logits)
179
+
180
+ return p_ctc, wav_lens
181
+
182
+ def custom_encode(self,wavs,wav_lens) :
183
+ wavs = wavs.to(self.device)
184
+ if wav_lens is not None: wav_lens = wav_lens.to(self.device)  # .to() is not in-place; keep the returned tensor
185
+
186
+ feats = self.modules.wav2vec2(wavs, wav_lens)
187
+ x = self.modules.enc(feats)
188
+ logits = self.modules.ctc_lin(x)
189
+ p_ctc = self.hparams.log_softmax(logits)
190
+
191
+ return feats,p_ctc
192
+
193
+
194
+
195
+ def compute_objectives(self, predictions, batch, stage):
196
+ """Computes the loss (CTC) given predictions and targets."""
197
+
198
+ p_ctc, wav_lens = predictions
199
+
200
+ ids = batch.id
201
+ tokens, tokens_lens = batch.tokens
202
+
203
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
204
+
205
+ if stage != sb.Stage.TRAIN:
206
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
207
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
208
+ )
209
+ # Decode token terms to words
210
+ if self.hparams.use_language_modelling:
211
+ predicted_words = []
212
+ for logs in p_ctc:
213
+ text = decoder.decode(logs.detach().cpu().numpy())
214
+ predicted_words.append(text.split(" "))
215
+ else:
216
+ predicted_words = [
217
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
218
+ for utt_seq in predicted_tokens
219
+ ]
220
+ # Convert indices to words
221
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
222
+
223
+ self.wer_metric.append(ids, predicted_words, target_words)
224
+ self.cer_metric.append(ids, predicted_words, target_words)
225
+
226
+ return loss
227
+
228
+ def fit_batch(self, batch):
229
+ """Train the parameters given a single batch in input"""
230
+ should_step = self.step % self.grad_accumulation_factor == 0
231
+ # Managing automatic mixed precision
232
+ # TOFIX: CTC fine-tuning currently is unstable
233
+ # This is certainly due to CTC being done in fp16 instead of fp32
234
+ if self.auto_mix_prec:
235
+ with torch.cuda.amp.autocast():
236
+ with self.no_sync():
237
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
238
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
239
+ with self.no_sync(not should_step):
240
+ self.scaler.scale(
241
+ loss / self.grad_accumulation_factor
242
+ ).backward()
243
+ if should_step:
244
+
245
+ if not self.hparams.wav2vec2.freeze:
246
+ self.scaler.unscale_(self.wav2vec_optimizer)
247
+ self.scaler.unscale_(self.model_optimizer)
248
+ if self.check_gradients(loss):
249
+ if not self.hparams.wav2vec2.freeze:
250
+ if self.optimizer_step >= self.hparams.warmup_steps:
251
+ self.scaler.step(self.wav2vec_optimizer)
252
+ self.scaler.step(self.model_optimizer)
253
+ self.scaler.update()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+ else:
257
+ # This is mandatory because HF models have a weird behavior with DDP
258
+ # on the forward pass
259
+ with self.no_sync():
260
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
261
+
262
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
263
+
264
+ with self.no_sync(not should_step):
265
+ (loss / self.grad_accumulation_factor).backward()
266
+ if should_step:
267
+ if self.check_gradients(loss):
268
+ if not self.hparams.wav2vec2.freeze:
269
+ if self.optimizer_step >= self.hparams.warmup_steps:
270
+ self.wav2vec_optimizer.step()
271
+ self.model_optimizer.step()
272
+ self.zero_grad()
273
+ self.optimizer_step += 1
274
+
275
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
276
+ return loss.detach().cpu()
277
+
278
+ def evaluate_batch(self, batch, stage):
279
+ """Computations needed for validation/test batches"""
280
+ predictions = self.compute_forward(batch, stage=stage)
281
+ with torch.no_grad():
282
+ loss = self.compute_objectives(predictions, batch, stage=stage)
283
+ return loss.detach()
284
+
285
+ def on_stage_start(self, stage, epoch):
286
+ """Gets called at the beginning of each epoch"""
287
+ if stage != sb.Stage.TRAIN:
288
+ self.cer_metric = self.hparams.cer_computer()
289
+ self.wer_metric = self.hparams.error_rate_computer()
290
+
291
+ def on_stage_end(self, stage, stage_loss, epoch):
292
+ """Gets called at the end of an epoch."""
293
+ # Compute/store important stats
294
+ stage_stats = {"loss": stage_loss}
295
+ if stage == sb.Stage.TRAIN:
296
+ self.train_stats = stage_stats
297
+ else:
298
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
299
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
300
+
301
+ # Perform end-of-iteration things, like annealing, logging, etc.
302
+ if stage == sb.Stage.VALID:
303
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
304
+ stage_stats["loss"]
305
+ )
306
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
307
+ stage_stats["loss"]
308
+ )
309
+ sb.nnet.schedulers.update_learning_rate(
310
+ self.model_optimizer, new_lr_model
311
+ )
312
+ if not self.hparams.wav2vec2.freeze:
313
+ sb.nnet.schedulers.update_learning_rate(
314
+ self.wav2vec_optimizer, new_lr_wav2vec
315
+ )
316
+ self.hparams.train_logger.log_stats(
317
+ stats_meta={
318
+ "epoch": epoch,
319
+ "lr_model": old_lr_model,
320
+ "lr_wav2vec": old_lr_wav2vec,
321
+ },
322
+ train_stats=self.train_stats,
323
+ valid_stats=stage_stats,
324
+ )
325
+ self.checkpointer.save_and_keep_only(
326
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
327
+ )
328
+ elif stage == sb.Stage.TEST:
329
+ self.hparams.train_logger.log_stats(
330
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
331
+ test_stats=stage_stats,
332
+ )
333
+ with open(self.hparams.wer_file, "w") as w:
334
+ self.wer_metric.write_stats(w)
335
+
336
+ def init_optimizers(self):
337
+ "Initializes the wav2vec2 optimizer and model optimizer"
338
+
339
+ # If the wav2vec encoder is unfrozen, we create the optimizer
340
+ if not self.hparams.wav2vec2.freeze:
341
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
342
+ self.modules.wav2vec2.parameters()
343
+ )
344
+ if self.checkpointer is not None:
345
+ self.checkpointer.add_recoverable(
346
+ "wav2vec_opt", self.wav2vec_optimizer
347
+ )
348
+
349
+ self.model_optimizer = self.hparams.model_opt_class(
350
+ self.hparams.model.parameters()
351
+ )
352
+
353
+ if self.checkpointer is not None:
354
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
355
+
356
+ def zero_grad(self, set_to_none=False):
357
+ if not self.hparams.wav2vec2.freeze:
358
+ self.wav2vec_optimizer.zero_grad(set_to_none)
359
+ self.model_optimizer.zero_grad(set_to_none)
360
+
361
+
362
+ """
363
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
364
+
365
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
366
+ hparams
367
+ )
368
+
369
+
370
+ # We dynamically add the tokenizer to our brain class.
371
+ # NB: This tokenizer corresponds to the one used for the LM!!
372
+ decoder = build_ctcdecoder(
373
+ labels,
374
+ kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa", # either .arpa or .bin file
375
+ alpha=0.5, # tuned on a val set
376
+ beta=1, # tuned on a val set
377
+ )
378
+ """
379
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
380
+ french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
381
+ french_asr_model.mods.eval()
382
+ #french_asr_model = "r"
383
+
384
+ english_asr_model = EncoderDecoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-en", savedir="pretrained_models/asr-wav2vec2-commonvoice-en/").cuda()
385
+ english_asr_model.mods.eval()
386
+
387
+ asr_brain = ASR(
388
+ modules=hparams["modules"],
389
+ hparams=hparams,
390
+ run_opts=run_opts,
391
+ checkpointer=hparams["checkpointer"],
392
+ )
393
+ asr_brain.checkpointer.recover_if_possible()
394
+ asr_brain.modules.eval()
395
+ """
396
+ asr_brain.tokenizer = label_encoder
397
+
398
+ # Testing
399
+ real = True
400
+ if real :
401
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
402
+ asr_brain.hparams.wer_file = os.path.join(
403
+ hparams["output_folder"], "wer_{}.txt".format(k)
404
+ )
405
+ asr_brain.evaluate(
406
+ test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
407
+ )
408
+ """
409
+
410
+ """
411
+ from torch.nn.utils.rnn import pad_sequence
412
+ def load_paths(wavs_path):
413
+ waveforms = []
414
+ for path in wavs_path :
415
+ waveform, _ = torchaudio.load(path)
416
+ waveforms.append(waveform.squeeze(0))
417
+ # normalize array lengths to the longest array by padding with 0's
418
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
419
+ return torch.tensor(padded_arrays)
420
+
421
+ waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
422
+ embeddings, posteriogram = asr_brain.custom_encode(waveform,None)
423
+ print(embeddings.shape)
424
+ print(posteriogram.shape)
425
+ """
426
+
427
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
428
+ import torchaudio
429
+ import speechbrain as sb
430
+ import torch
431
+ from torch.nn.utils.rnn import pad_sequence
432
+ import torch
433
+ import speechbrain as sb
434
+ import numpy as np
435
+ import torch.optim as optim
436
+ import torch.nn as nn
437
+
438
+ # Commented out IPython magic to ensure Python compatibility.
439
+ # %ls
440
+
441
+ # UTILS FUNCTIONS
442
+ def get_size_dimensions(arr):
443
+ size_dimensions = []
444
+ while isinstance(arr, list):
445
+ size_dimensions.append(len(arr))
446
+ arr = arr[0]
447
+ return size_dimensions
448
+
449
+ def scale_array(batch,n):
450
+ scaled_batch = []
451
+
452
+ for array in batch:
453
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
454
+
455
+ repeat = round(n/len(array))+1
456
+ scaled_length_array= []
457
+
458
+ for i in array:
459
+ for j in range(repeat) :
460
+ if(len(scaled_length_array) == n): break
461
+ scaled_length_array.append(i)
462
+
463
+ scaled_batch.append(scaled_length_array)
464
+
465
+ return torch.tensor(scaled_batch)
466
+
467
+
468
+ def load_paths(wavs_path):
469
+ waveforms = []
470
+ for path in wavs_path :
471
+ waveform, _ = torchaudio.load(path)
472
+ waveforms.append(waveform.squeeze(0))
473
+ # normalize array lengths to the longest array by padding with 0's
474
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
475
+ return torch.tensor(padded_arrays)
476
+
477
+
478
+
479
+ def word_to_vec(input_string):
480
+ mapping= {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,' ':55}
481
+
482
+ numbers = [mapping[word] for word in input_string if word in mapping]
483
+ return numbers
484
+
485
+ device = 'cuda'
486
+ verbose = 0
487
+ #FLOW LEVEL FUNCTIONS
488
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
489
+
490
+
491
+ post1 = post1.to(device)
492
+ post2 = post2.to(device)
493
+ post3 = post3.to(device)
494
+ embeddings1 = embeddings1.to(device)
495
+ embeddings2 = embeddings2.to(device)
496
+ embeddings3 = embeddings3.to(device)
497
+
498
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
499
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
500
+
501
+ if(verbose !=0):
502
+ print('MERGED POST ',posteriograms_merged.shape)
503
+ print('MERGED emb ',embeddings_merged.shape)
504
+
505
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
506
+
507
+ def decode(model,wavs,wav_lens):
508
+
509
+ with torch.no_grad():
510
+ wav_lens = wav_lens.to(model.device)
511
+ encoder_out = model.encode_batch(wavs, wav_lens)
512
+ predictions = model.decoding_function(encoder_out, wav_lens)
513
+ return predictions
514
+
515
+ def middle_layer(batch, lens):
516
+
517
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
518
+
519
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
520
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
521
+
522
+ en_embeddings = english_asr_model.encode_batch(batch, lens)
523
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
524
+ en_posteriogram = en_embeddings
525
+
526
+ if(verbose !=0):
527
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
528
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
529
+
530
+
531
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
532
+ return bilangual_sample
533
+
534
+ class Mixer(sb.core.Brain):
535
+
536
+ def compute_forward(self, batch, stage):
537
+ """Forward computations from the waveform batches to the output probabilities."""
538
+ wavs, wav_lens = batch.sig
539
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
540
+
541
+ if stage == sb.Stage.TRAIN:
542
+ if hasattr(self.hparams, "augmentation"):
543
+ wavs = self.hparams.augmentation(wavs, wav_lens)
544
+
545
+ multi_langual_feats = middle_layer(wavs, wav_lens)
546
+ multi_langual_feats= multi_langual_feats.to(device)
547
+ feats, _ = self.modules.enc(multi_langual_feats)
548
+ logits = self.modules.ctc_lin(feats)
549
+ p_ctc = self.hparams.log_softmax(logits)
550
+
551
+ if stage!= sb.Stage.TRAIN:
552
+ p_tokens = sb.decoders.ctc_greedy_decode(
553
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
554
+ )
555
+ else :
556
+ p_tokens = None
557
+ return p_ctc, wav_lens, p_tokens
558
+
559
+ def compute_objectives(self, predictions, batch, stage):
560
+ """Computes the loss (CTC) given predictions and targets."""
561
+
562
+ p_ctc, wav_lens , predicted_tokens= predictions
563
+
564
+ ids = batch.id
565
+ tokens, tokens_lens = batch.tokens
566
+
567
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
568
+
569
+
570
+ if stage != sb.Stage.TRAIN:
571
+ predicted_words = [
572
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
573
+ for utt_seq in predicted_tokens
574
+ ]
575
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
576
+ self.wer_metric.append(ids, predicted_words, target_words)
577
+ self.cer_metric.append(ids, predicted_words, target_words)
578
+
579
+ return loss
580
+
581
+ def fit_batch(self, batch):
582
+ """Train the parameters given a single batch in input"""
583
+ should_step = self.step % self.grad_accumulation_factor == 0
584
+ # Managing automatic mixed precision
585
+ # TOFIX: CTC fine-tuning currently is unstable
586
+ # This is certainly due to CTC being done in fp16 instead of fp32
587
+ if self.auto_mix_prec:
588
+ with torch.cuda.amp.autocast():
589
+ with self.no_sync():
590
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
591
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
592
+ with self.no_sync(not should_step):
593
+ self.scaler.scale(
594
+ loss / self.grad_accumulation_factor
595
+ ).backward()
596
+ if should_step:
597
+
598
+
599
+ self.scaler.unscale_(self.model_optimizer)
600
+ if self.check_gradients(loss):
601
+ self.scaler.step(self.model_optimizer)
602
+ self.scaler.update()
603
+ self.zero_grad()
604
+ self.optimizer_step += 1
605
+ else:
606
+ # This is mandatory because HF models have a weird behavior with DDP
607
+ # on the forward pass
608
+ with self.no_sync():
609
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
610
+
611
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
612
+
613
+ with self.no_sync(not should_step):
614
+ (loss / self.grad_accumulation_factor).backward()
615
+ if should_step:
616
+ if self.check_gradients(loss):
617
+ self.model_optimizer.step()
618
+ self.zero_grad()
619
+ self.optimizer_step += 1
620
+
621
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
622
+ return loss.detach().cpu()
623
+
624
+ def evaluate_batch(self, batch, stage):
625
+ """Computations needed for validation/test batches"""
626
+ predictions = self.compute_forward(batch, stage=stage)
627
+ with torch.no_grad():
628
+ loss = self.compute_objectives(predictions, batch, stage=stage)
629
+ return loss.detach()
630
+
631
+ def on_stage_start(self, stage, epoch):
632
+ """Gets called at the beginning of each epoch"""
633
+ if stage != sb.Stage.TRAIN:
634
+ self.cer_metric = self.hparams.cer_computer()
635
+ self.wer_metric = self.hparams.error_rate_computer()
636
+
637
+ def on_stage_end(self, stage, stage_loss, epoch):
638
+ """Gets called at the end of an epoch."""
639
+ # Compute/store important stats
640
+ stage_stats = {"loss": stage_loss}
641
+ if stage == sb.Stage.TRAIN:
642
+ self.train_stats = stage_stats
643
+ else:
644
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
645
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
646
+
647
+ # Perform end-of-iteration things, like annealing, logging, etc.
648
+ if stage == sb.Stage.VALID:
649
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
650
+ stage_stats["loss"]
651
+ )
652
+ sb.nnet.schedulers.update_learning_rate(
653
+ self.model_optimizer, new_lr_model
654
+ )
655
+ self.hparams.train_logger.log_stats(
656
+ stats_meta={
657
+ "epoch": epoch,
658
+ "lr_model": old_lr_model,
659
+ },
660
+ train_stats=self.train_stats,
661
+ valid_stats=stage_stats,
662
+ )
663
+ self.checkpointer.save_and_keep_only(
664
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
665
+ )
666
+ elif stage == sb.Stage.TEST:
667
+ self.hparams.train_logger.log_stats(
668
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
669
+ test_stats=stage_stats,
670
+ )
671
+ with open(self.hparams.wer_file, "w") as w:
672
+ self.wer_metric.write_stats(w)
673
+
674
+ def init_optimizers(self):
675
+
676
+ self.model_optimizer = self.hparams.model_opt_class(
677
+ self.hparams.model.parameters()
678
+ )
679
+
680
+ if self.checkpointer is not None:
681
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
682
+
683
+ def zero_grad(self, set_to_none=False):
684
+
685
+ self.model_optimizer.zero_grad(set_to_none)
686
+
687
+
688
+ hparams_file, run_opts, overrides = sb.parse_arguments([sys.argv[1]])
689
+
690
+ # If distributed_launch=True then
691
+ # create ddp_group with the right communication protocol
692
+ sb.utils.distributed.ddp_init_group(run_opts)
693
+
694
+ with open(hparams_file) as fin:
695
+ hparams = load_hyperpyyaml(fin, overrides)
696
+
697
+ # Create experiment directory
698
+ sb.create_experiment_directory(
699
+ experiment_directory=hparams["output_folder"],
700
+ hyperparams_to_save=hparams_file,
701
+ overrides=overrides,
702
+ )
703
+ """
704
+ def read_labels_file(labels_file):
705
+ with open(labels_file, "r",encoding="utf-8") as lf:
706
+ lines = lf.read().splitlines()
707
+ division = "==="
708
+ numbers = {}
709
+ for line in lines :
710
+ if division in line :
711
+ break
712
+ string, number = line.split("=>")
713
+ number = int(number)
714
+ string = string[1:-2]
715
+ numbers[number] = string
716
+ return [numbers[x] for x in range(len(numbers))]
717
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
718
+ labels = [""] + labels[1:-1] + ["1"]
719
+
720
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
721
+ """
722
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
723
+ hparams
724
+ )
725
+
726
+
727
+
728
+
729
+ """
730
+ decoder = build_ctcdecoder(
731
+ labels,
732
+ kenlm_model_path="/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/lm_data/arpas/indomain.arpa", # either .arpa or .bin file
733
+ alpha=0.5, # tuned on a val set
734
+ beta=1, # tuned on a val set
735
+ )
736
+ """
737
+ mixer = Mixer(
738
+ modules=hparams["modules"],
739
+ hparams=hparams,
740
+ run_opts=run_opts,
741
+ checkpointer=hparams["checkpointer"],
742
+ )
743
+ mixer.tokenizer = label_encoder
744
+
745
+
746
+ mixer.fit(
747
+ mixer.hparams.epoch_counter,
748
+ train_data,
749
+ valid_data,
750
+ train_loader_kwargs=hparams["dataloader_options"],
751
+ valid_loader_kwargs=hparams["test_dataloader_options"],
752
+ )
753
+
754
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
755
+ mixer.hparams.wer_file = os.path.join(
756
+ hparams["output_folder"], "wer_{}.txt".format(k)
757
+ )
758
+ mixer.evaluate(
759
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
760
+ )
761
+
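The script above fuses the three systems by concatenating their posteriorgrams and their encoder embeddings along the feature axis (merge_strategy) before the Mixer's own enc/ctc_lin layers. Below is a minimal, standalone shape sketch of that fusion step; every tensor size is an illustrative assumption (the real widths depend on the wav2vec2/WavLM checkpoints and on output_neurons), not a value taken from this commit:

import torch

batch, frames = 2, 50                      # assumed batch size and frame count
fr_emb = torch.randn(batch, frames, 1024)  # assumed French wav2vec2 feature width
en_emb = torch.randn(batch, frames, 1024)  # assumed English encoder feature width
tn_emb = torch.randn(batch, frames, 1024)  # assumed Tunisian WavLM feature width
fr_post = torch.randn(batch, frames, 77)   # assumed French CTC vocabulary size
en_post = en_emb                           # this script reuses the English embeddings as its posteriorgram
tn_post = torch.randn(batch, frames, 40)   # output_neurons in the Tunisian yaml

posteriograms_merged = torch.cat((fr_post, en_post, tn_post), dim=2)
embeddings_merged = torch.cat((fr_emb, en_emb, tn_emb), dim=2)
fused = torch.cat((posteriograms_merged, embeddings_merged), dim=2)
print(fused.shape)  # the last dimension must match the input_shape of the Mixer's enc module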
TunisianASR/results/14epoch_tunisian/<seed>/ctc_lin.py ADDED
@@ -0,0 +1,756 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import torch
6
+ import logging
7
+ import speechbrain as sb
8
+ from speechbrain.utils.distributed import run_on_main
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from pathlib import Path
11
+ import torchaudio.transforms as T
12
+ from cv_train import ASRCV
13
+ import torchaudio
14
+ import numpy as np
15
+ import kenlm
16
+ from pyctcdecode import build_ctcdecoder
17
+ import re
18
+
19
+ # Commented out IPython magic to ensure Python compatibility.
20
+ # %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
21
+ #hparams_file, run_opts, overrides = sb.parse_arguments(["/gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml"])
22
+ hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])
23
+
24
+ # If distributed_launch=True then
25
+ # create ddp_group with the right communication protocol
26
+ sb.utils.distributed.ddp_init_group(run_opts)
27
+
28
+ with open(hparams_file) as fin:
29
+ hparams = load_hyperpyyaml(fin, overrides)
30
+
31
+ # Create experiment directory
32
+ sb.create_experiment_directory(
33
+ experiment_directory=hparams["output_folder"],
34
+ hyperparams_to_save=hparams_file,
35
+ overrides=overrides,
36
+ )
37
+ # Dataset prep (parsing the CSV manifests)
38
+
39
+ def dataio_prepare(hparams):
40
+ """This function prepares the datasets to be used in the brain class.
41
+ It also defines the data processing pipeline through user-defined functions."""
42
+
43
+ # 1. Define datasets
44
+ data_folder = hparams["data_folder"]
45
+
46
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
47
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
48
+ )
49
+
50
+ if hparams["sorting"] == "ascending":
51
+ # we sort training data to speed up training and get better results.
52
+ train_data = train_data.filtered_sorted(
53
+ sort_key="duration",
54
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
55
+ )
56
+ # when sorting, do not shuffle in the dataloader! otherwise it is pointless
57
+ hparams["dataloader_options"]["shuffle"] = False
58
+
59
+ elif hparams["sorting"] == "descending":
60
+ train_data = train_data.filtered_sorted(
61
+ sort_key="duration",
62
+ reverse=True,
63
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
64
+ )
65
+ # when sorting, do not shuffle in the dataloader! otherwise it is pointless
66
+ hparams["dataloader_options"]["shuffle"] = False
67
+
68
+ elif hparams["sorting"] == "random":
69
+ pass
70
+
71
+ else:
72
+ raise NotImplementedError(
73
+ "sorting must be random, ascending or descending"
74
+ )
75
+
76
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
77
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
78
+ )
79
+ # We also sort the validation data so it is faster to validate
80
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
81
+ test_datasets = {}
82
+ for csv_file in hparams["test_csv"]:
83
+ name = Path(csv_file).stem
84
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
85
+ csv_path=csv_file, replacements={"data_root": data_folder}
86
+ )
87
+ test_datasets[name] = test_datasets[name].filtered_sorted(
88
+ sort_key="duration"
89
+ )
90
+
91
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
92
+
93
+
94
+ # 2. Define audio pipeline:
95
+ @sb.utils.data_pipeline.takes("wav")
96
+ @sb.utils.data_pipeline.provides("sig")
97
+ def audio_pipeline(wav):
98
+ info = torchaudio.info(wav)
99
+ sig = sb.dataio.dataio.read_audio(wav)
100
+ if len(sig.shape)>1 :
101
+ sig = torch.mean(sig, dim=1)
102
+ resampled = torchaudio.transforms.Resample(
103
+ info.sample_rate, hparams["sample_rate"],
104
+ )(sig)
105
+ return resampled
106
+
107
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
108
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
109
+
110
+ # 3. Define text pipeline:
111
+ @sb.utils.data_pipeline.takes("wrd")
112
+ @sb.utils.data_pipeline.provides(
113
+ "wrd", "char_list", "tokens_list", "tokens"
114
+ )
115
+ def text_pipeline(wrd):
116
+ yield wrd
117
+ char_list = list(wrd)
118
+ yield char_list
119
+ tokens_list = label_encoder.encode_sequence(char_list)
120
+ yield tokens_list
121
+ tokens = torch.LongTensor(tokens_list)
122
+ yield tokens
123
+
124
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
125
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
126
+ special_labels = {
127
+ "blank_label": hparams["blank_index"],
128
+ "unk_label": hparams["unk_index"]
129
+ }
130
+ label_encoder.load_or_create(
131
+ path=lab_enc_file,
132
+ from_didatasets=[train_data],
133
+ output_key="char_list",
134
+ special_labels=special_labels,
135
+ sequence_input=True,
136
+ )
137
+
138
+ # 4. Set output:
139
+ sb.dataio.dataset.set_output_keys(
140
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
141
+ )
142
+ return train_data, valid_data,test_datasets, label_encoder
143
+
144
+ class ASR(sb.core.Brain):
145
+ def compute_forward(self, batch, stage):
146
+ """Forward computations from the waveform batches to the output probabilities."""
147
+
148
+ batch = batch.to(self.device)
149
+ wavs, wav_lens = batch.sig
150
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
151
+
152
+ if stage == sb.Stage.TRAIN:
153
+ if hasattr(self.hparams, "augmentation"):
154
+ wavs = self.hparams.augmentation(wavs, wav_lens)
155
+
156
+ # Forward pass
157
+ feats = self.modules.wav2vec2(wavs, wav_lens)
158
+ x = self.modules.enc(feats)
159
+ logits = self.modules.ctc_lin(x)
160
+ p_ctc = self.hparams.log_softmax(logits)
161
+
162
+ return p_ctc, wav_lens
163
+
164
+ def custom_encode(self,wavs,wav_lens) :
165
+ wavs = wavs.to(self.device)
166
+ if(wav_lens is not None): wav_lens.to(self.device)
167
+
168
+ feats = self.modules.wav2vec2(wavs, wav_lens)
169
+ x = self.modules.enc(feats)
170
+ logits = self.modules.ctc_lin(x)
171
+ p_ctc = self.hparams.log_softmax(logits)
172
+
173
+ return feats,p_ctc
174
+
175
+
176
+
177
+ def compute_objectives(self, predictions, batch, stage):
178
+ """Computes the loss (CTC) given predictions and targets."""
179
+
180
+ p_ctc, wav_lens = predictions
181
+
182
+ ids = batch.id
183
+ tokens, tokens_lens = batch.tokens
184
+
185
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
186
+
187
+ if stage != sb.Stage.TRAIN:
188
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
189
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
190
+ )
191
+ # Decode token terms to words
192
+ if self.hparams.use_language_modelling:
193
+ predicted_words = []
194
+ for logs in p_ctc:
195
+ text = decoder.decode(logs.detach().cpu().numpy())
196
+ predicted_words.append(text.split(" "))
197
+ else:
198
+ predicted_words = [
199
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
200
+ for utt_seq in predicted_tokens
201
+ ]
202
+ # Convert indices to words
203
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
204
+
205
+ self.wer_metric.append(ids, predicted_words, target_words)
206
+ self.cer_metric.append(ids, predicted_words, target_words)
207
+
208
+ return loss
209
+
210
+ def fit_batch(self, batch):
211
+ """Train the parameters given a single batch in input"""
212
+ should_step = self.step % self.grad_accumulation_factor == 0
213
+ # Managing automatic mixed precision
214
+ # TOFIX: CTC fine-tuning currently is unstable
215
+ # This is certainly due to CTC being done in fp16 instead of fp32
216
+ if self.auto_mix_prec:
217
+ with torch.cuda.amp.autocast():
218
+ with self.no_sync():
219
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
220
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
221
+ with self.no_sync(not should_step):
222
+ self.scaler.scale(
223
+ loss / self.grad_accumulation_factor
224
+ ).backward()
225
+ if should_step:
226
+
227
+ if not self.hparams.wav2vec2.freeze:
228
+ self.scaler.unscale_(self.wav2vec_optimizer)
229
+ self.scaler.unscale_(self.model_optimizer)
230
+ if self.check_gradients(loss):
231
+ if not self.hparams.wav2vec2.freeze:
232
+ if self.optimizer_step >= self.hparams.warmup_steps:
233
+ self.scaler.step(self.wav2vec_optimizer)
234
+ self.scaler.step(self.model_optimizer)
235
+ self.scaler.update()
236
+ self.zero_grad()
237
+ self.optimizer_step += 1
238
+ else:
239
+ # This is mandatory because HF models have a weird behavior with DDP
240
+ # on the forward pass
241
+ with self.no_sync():
242
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
243
+
244
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
245
+
246
+ with self.no_sync(not should_step):
247
+ (loss / self.grad_accumulation_factor).backward()
248
+ if should_step:
249
+ if self.check_gradients(loss):
250
+ if not self.hparams.wav2vec2.freeze:
251
+ if self.optimizer_step >= self.hparams.warmup_steps:
252
+ self.wav2vec_optimizer.step()
253
+ self.model_optimizer.step()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+
257
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
258
+ return loss.detach().cpu()
259
+
260
+ def evaluate_batch(self, batch, stage):
261
+ """Computations needed for validation/test batches"""
262
+ predictions = self.compute_forward(batch, stage=stage)
263
+ with torch.no_grad():
264
+ loss = self.compute_objectives(predictions, batch, stage=stage)
265
+ return loss.detach()
266
+
267
+ def on_stage_start(self, stage, epoch):
268
+ """Gets called at the beginning of each epoch"""
269
+ if stage != sb.Stage.TRAIN:
270
+ self.cer_metric = self.hparams.cer_computer()
271
+ self.wer_metric = self.hparams.error_rate_computer()
272
+
273
+ def on_stage_end(self, stage, stage_loss, epoch):
274
+ """Gets called at the end of an epoch."""
275
+ # Compute/store important stats
276
+ stage_stats = {"loss": stage_loss}
277
+ if stage == sb.Stage.TRAIN:
278
+ self.train_stats = stage_stats
279
+ else:
280
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
281
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
282
+
283
+ # Perform end-of-iteration things, like annealing, logging, etc.
284
+ if stage == sb.Stage.VALID:
285
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
286
+ stage_stats["loss"]
287
+ )
288
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
289
+ stage_stats["loss"]
290
+ )
291
+ sb.nnet.schedulers.update_learning_rate(
292
+ self.model_optimizer, new_lr_model
293
+ )
294
+ if not self.hparams.wav2vec2.freeze:
295
+ sb.nnet.schedulers.update_learning_rate(
296
+ self.wav2vec_optimizer, new_lr_wav2vec
297
+ )
298
+ self.hparams.train_logger.log_stats(
299
+ stats_meta={
300
+ "epoch": epoch,
301
+ "lr_model": old_lr_model,
302
+ "lr_wav2vec": old_lr_wav2vec,
303
+ },
304
+ train_stats=self.train_stats,
305
+ valid_stats=stage_stats,
306
+ )
307
+ self.checkpointer.save_and_keep_only(
308
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
309
+ )
310
+ elif stage == sb.Stage.TEST:
311
+ self.hparams.train_logger.log_stats(
312
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
313
+ test_stats=stage_stats,
314
+ )
315
+ with open(self.hparams.wer_file, "w") as w:
316
+ self.wer_metric.write_stats(w)
317
+
318
+ def init_optimizers(self):
319
+ "Initializes the wav2vec2 optimizer and model optimizer"
320
+
321
+ # If the wav2vec encoder is unfrozen, we create the optimizer
322
+ if not self.hparams.wav2vec2.freeze:
323
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
324
+ self.modules.wav2vec2.parameters()
325
+ )
326
+ if self.checkpointer is not None:
327
+ self.checkpointer.add_recoverable(
328
+ "wav2vec_opt", self.wav2vec_optimizer
329
+ )
330
+
331
+ self.model_optimizer = self.hparams.model_opt_class(
332
+ self.hparams.model.parameters()
333
+ )
334
+
335
+ if self.checkpointer is not None:
336
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
337
+
338
+ def zero_grad(self, set_to_none=False):
339
+ if not self.hparams.wav2vec2.freeze:
340
+ self.wav2vec_optimizer.zero_grad(set_to_none)
341
+ self.model_optimizer.zero_grad(set_to_none)
342
+
343
+
344
+ """
345
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
346
+
347
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
348
+ hparams
349
+ )
350
+
351
+
352
+ # We dynamically add the tokenizer to our brain class.
353
+ # NB: This tokenizer corresponds to the one used for the LM!!
354
+ """
355
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
356
+ french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
357
+ #french_asr_model = "r"
358
+
359
+ cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["en_cv.yaml"])
360
+ with open(cvhparams_file) as cvfin:
361
+ cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
362
+ english_asr_model = ASRCV(
363
+ modules=cvhparams["modules"],
364
+ hparams=cvhparams,
365
+ run_opts=cvrun_opts,
366
+ checkpointer=cvhparams["checkpointer"],
367
+ )
368
+ english_asr_model.checkpointer.recover_if_possible()
369
+ asr_brain = ASR(
370
+ modules=hparams["modules"],
371
+ hparams=hparams,
372
+ run_opts=run_opts,
373
+ checkpointer=hparams["checkpointer"],
374
+ )
375
+ asr_brain.checkpointer.recover_if_possible()
376
+ asr_brain.modules.eval()
377
+ english_asr_model.modules.eval()
378
+ french_asr_model.mods.eval()
379
+ """
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Testing
383
+ real = True
384
+ if real :
385
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
386
+ asr_brain.hparams.wer_file = os.path.join(
387
+ hparams["output_folder"], "wer_{}.txt".format(k)
388
+ )
389
+ asr_brain.evaluate(
390
+ test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
391
+ )
392
+ """
393
+
394
+ """
395
+ from torch.nn.utils.rnn import pad_sequence
396
+ def load_paths(wavs_path):
397
+ waveforms = []
398
+ for path in wavs_path :
399
+ waveform, _ = torchaudio.load(path)
400
+ waveforms.append(waveform.squeeze(0))
401
+ # normalize array lengths to the longest array by padding with 0's
402
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
403
+ return torch.tensor(padded_arrays)
404
+
405
+ waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
406
+ embeddings, posteriogram = asr_brain.custom_encode(waveform,None)
407
+ print(embeddings.shape)
408
+ print(posteriogram.shape)
409
+ """
410
+
411
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
412
+ import torchaudio
413
+ import speechbrain as sb
414
+ import torch
415
+ from torch.nn.utils.rnn import pad_sequence
416
+ import torch
417
+ import speechbrain as sb
418
+ import numpy as np
419
+ import torch.optim as optim
420
+ import torch.nn as nn
421
+
422
+ # Commented out IPython magic to ensure Python compatibility.
423
+ # %ls
424
+
425
+ # UTILS FUNCTIONS
426
+ def get_size_dimensions(arr):
427
+ size_dimensions = []
428
+ while isinstance(arr, list):
429
+ size_dimensions.append(len(arr))
430
+ arr = arr[0]
431
+ return size_dimensions
432
+
433
+ def scale_array(batch,n):
434
+ scaled_batch = []
435
+
436
+ for array in batch:
437
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
438
+
439
+ repeat = round(n/len(array))+1
440
+ scaled_length_array= []
441
+
442
+ for i in array:
443
+ for j in range(repeat) :
444
+ if(len(scaled_length_array) == n): break
445
+ scaled_length_array.append(i)
446
+
447
+ scaled_batch.append(scaled_length_array)
448
+
449
+ return torch.tensor(scaled_batch)
450
+
451
+
452
+ def load_paths(wavs_path):
453
+ waveforms = []
454
+ for path in wavs_path :
455
+ waveform, _ = torchaudio.load(path)
456
+ waveforms.append(waveform.squeeze(0))
457
+ # normalize array lengths to the longest array by padding with 0's
458
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
459
+ return torch.tensor(padded_arrays)
460
+
461
+
462
+
463
+ def word_to_vec(input_string):
464
+ mapping= {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54,' ':55}
465
+
466
+ numbers = [mapping[word] for word in input_string if word in mapping]
467
+ return numbers
468
+
469
+ device = 'cuda'
470
+ verbose = 0
471
+ #FLOW LEVEL FUNCTIONS
472
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
473
+
474
+
475
+ post1 = post1.to(device)
476
+ post2 = post2.to(device)
477
+ post3 = post3.to(device)
478
+ embeddings1 = embeddings1.to(device)
479
+ embeddings2 = embeddings2.to(device)
480
+ embeddings3 = embeddings3.to(device)
481
+
482
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
483
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
484
+
485
+ if(verbose !=0):
486
+ print('MERGED POST ',posteriograms_merged.shape)
487
+ print('MERGED emb ',embeddings_merged.shape)
488
+
489
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
490
+
491
+ def decode(model,wavs,wav_lens):
492
+
493
+ with torch.no_grad():
494
+ wav_lens = wav_lens.to(model.device)
495
+ encoder_out = model.encode_batch(wavs, wav_lens)
496
+ predictions = model.decoding_function(encoder_out, wav_lens)
497
+ return predictions
498
+
499
+ def middle_layer(batch, lens):
500
+
501
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
502
+
503
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
504
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
505
+ en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
506
+ x = english_asr_model.modules.enc(en_embeddings)
507
+ en_posteriogram = english_asr_model.modules.ctc_lin(x)
508
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
509
+ if(verbose !=0):
510
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
511
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
512
+
513
+
514
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
515
+ return bilangual_sample
516
+
517
+ class Mixer(sb.core.Brain):
518
+
519
+ def compute_forward(self, batch, stage):
520
+ """Forward computations from the waveform batches to the output probabilities."""
521
+ wavs, wav_lens = batch.sig
522
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
523
+
524
+ if stage == sb.Stage.TRAIN:
525
+ if hasattr(self.hparams, "augmentation"):
526
+ wavs = self.hparams.augmentation(wavs, wav_lens)
527
+
528
+ multi_langual_feats = middle_layer(wavs, wav_lens)
529
+ multi_langual_feats= multi_langual_feats.to(device)
530
+ feats, _ = self.modules.enc(multi_langual_feats)
531
+ logits = self.modules.ctc_lin(feats)
532
+ p_ctc = self.hparams.log_softmax(logits)
533
+
534
+ if stage!= sb.Stage.TRAIN:
535
+ p_tokens = sb.decoders.ctc_greedy_decode(
536
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
537
+ )
538
+ else :
539
+ p_tokens = None
540
+ return p_ctc, wav_lens, p_tokens
541
+
542
+ def compute_objectives(self, predictions, batch, stage):
543
+ """Computes the loss (CTC) given predictions and targets."""
544
+
545
+ p_ctc, wav_lens , predicted_tokens= predictions
546
+
547
+ ids = batch.id
548
+ tokens, tokens_lens = batch.tokens
549
+
550
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
551
+
552
+
553
+ if stage == sb.Stage.VALID:
554
+ predicted_words = [
555
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
556
+ for utt_seq in predicted_tokens
557
+ ]
558
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
559
+ self.wer_metric.append(ids, predicted_words, target_words)
560
+ self.cer_metric.append(ids, predicted_words, target_words)
561
+ if stage ==sb.Stage.TEST :
562
+ if self.hparams.language_modelling:
563
+ predicted_words = []
564
+ for logs in p_ctc:
565
+ text = decoder.decode(logs.detach().cpu().numpy())
566
+ predicted_words.append(text.split(" "))
567
+ else :
568
+ predicted_words = [
569
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
570
+ for utt_seq in predicted_tokens
571
+ ]
572
+
573
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
574
+ self.wer_metric.append(ids, predicted_words, target_words)
575
+ self.cer_metric.append(ids, predicted_words, target_words)
576
+
577
+ return loss
578
+
579
+ def fit_batch(self, batch):
580
+ """Train the parameters given a single batch in input"""
581
+ should_step = self.step % self.grad_accumulation_factor == 0
582
+ # Managing automatic mixed precision
583
+ # TOFIX: CTC fine-tuning currently is unstable
584
+ # This is certainly due to CTC being done in fp16 instead of fp32
585
+ if self.auto_mix_prec:
586
+ with torch.cuda.amp.autocast():
587
+ with self.no_sync():
588
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
589
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
590
+ with self.no_sync(not should_step):
591
+ self.scaler.scale(
592
+ loss / self.grad_accumulation_factor
593
+ ).backward()
594
+ if should_step:
595
+
596
+
597
+ self.scaler.unscale_(self.model_optimizer)
598
+ if self.check_gradients(loss):
599
+ self.scaler.step(self.model_optimizer)
600
+ self.scaler.update()
601
+ self.zero_grad()
602
+ self.optimizer_step += 1
603
+ else:
604
+ # This is mandatory because HF models have a weird behavior with DDP
605
+ # on the forward pass
606
+ with self.no_sync():
607
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
608
+
609
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
610
+
611
+ with self.no_sync(not should_step):
612
+ (loss / self.grad_accumulation_factor).backward()
613
+ if should_step:
614
+ if self.check_gradients(loss):
615
+ self.model_optimizer.step()
616
+ self.zero_grad()
617
+ self.optimizer_step += 1
618
+
619
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
620
+ return loss.detach().cpu()
621
+
622
+ def evaluate_batch(self, batch, stage):
623
+ """Computations needed for validation/test batches"""
624
+ predictions = self.compute_forward(batch, stage=stage)
625
+ with torch.no_grad():
626
+ loss = self.compute_objectives(predictions, batch, stage=stage)
627
+ return loss.detach()
628
+
629
+ def on_stage_start(self, stage, epoch):
630
+ """Gets called at the beginning of each epoch"""
631
+ if stage != sb.Stage.TRAIN:
632
+ self.cer_metric = self.hparams.cer_computer()
633
+ self.wer_metric = self.hparams.error_rate_computer()
634
+
635
+ def on_stage_end(self, stage, stage_loss, epoch):
636
+ """Gets called at the end of an epoch."""
637
+ # Compute/store important stats
638
+ stage_stats = {"loss": stage_loss}
639
+ if stage == sb.Stage.TRAIN:
640
+ self.train_stats = stage_stats
641
+ else:
642
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
643
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
644
+
645
+ # Perform end-of-iteration things, like annealing, logging, etc.
646
+ if stage == sb.Stage.VALID:
647
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
648
+ stage_stats["loss"]
649
+ )
650
+ sb.nnet.schedulers.update_learning_rate(
651
+ self.model_optimizer, new_lr_model
652
+ )
653
+ self.hparams.train_logger.log_stats(
654
+ stats_meta={
655
+ "epoch": epoch,
656
+ "lr_model": old_lr_model,
657
+ },
658
+ train_stats=self.train_stats,
659
+ valid_stats=stage_stats,
660
+ )
661
+ self.checkpointer.save_and_keep_only(
662
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
663
+ )
664
+ elif stage == sb.Stage.TEST:
665
+ self.hparams.train_logger.log_stats(
666
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
667
+ test_stats=stage_stats,
668
+ )
669
+ with open(self.hparams.wer_file, "w") as w:
670
+ self.wer_metric.write_stats(w)
671
+
672
+ def init_optimizers(self):
673
+
674
+ self.model_optimizer = self.hparams.model_opt_class(
675
+ self.hparams.model.parameters()
676
+ )
677
+
678
+ if self.checkpointer is not None:
679
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
680
+
681
+ def zero_grad(self, set_to_none=False):
682
+
683
+ self.model_optimizer.zero_grad(set_to_none)
684
+
685
+
686
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
687
+
688
+ # If distributed_launch=True then
689
+ # create ddp_group with the right communication protocol
690
+ sb.utils.distributed.ddp_init_group(run_opts)
691
+
692
+ with open(hparams_file) as fin:
693
+ hparams = load_hyperpyyaml(fin, overrides)
694
+
695
+ # Create experiment directory
696
+ sb.create_experiment_directory(
697
+ experiment_directory=hparams["output_folder"],
698
+ hyperparams_to_save=hparams_file,
699
+ overrides=overrides,
700
+ )
701
+ def read_labels_file(labels_file):
702
+ with open(labels_file, "r",encoding="utf-8") as lf:
703
+ lines = lf.read().splitlines()
704
+ division = "==="
705
+ numbers = {}
706
+ for line in lines :
707
+ if division in line :
708
+ break
709
+ string, number = line.split("=>")
710
+ number = int(number)
711
+ string = string[1:-2]
712
+ numbers[number] = string
713
+ return [numbers[x] for x in range(len(numbers))]
714
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
715
+ hparams
716
+ )
717
+
718
+
719
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
720
+ labels = [""] + labels[1:-1] + ["1"]
721
+ if hparams["language_modelling"]:
722
+ decoder = build_ctcdecoder(
723
+ labels,
724
+ kenlm_model_path=hparams["ngram_lm_path"], # either .arpa or .bin file
725
+ alpha=0.5, # tuned on a val set
726
+ beta=1, # tuned on a val set
727
+ )
728
+
729
+
730
+
731
+
732
+ mixer = Mixer(
733
+ modules=hparams["modules"],
734
+ hparams=hparams,
735
+ run_opts=run_opts,
736
+ checkpointer=hparams["checkpointer"],
737
+ )
738
+ mixer.tokenizer = label_encoder
739
+
740
+
741
+ mixer.fit(
742
+ mixer.hparams.epoch_counter,
743
+ train_data,
744
+ valid_data,
745
+ train_loader_kwargs=hparams["dataloader_options"],
746
+ valid_loader_kwargs=hparams["test_dataloader_options"],
747
+ )
748
+ print(test_datasets.keys())
749
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
750
+ mixer.hparams.wer_file = os.path.join(
751
+ hparams["output_folder"], "wer_{}.txt".format(k)
752
+ )
753
+ mixer.evaluate(
754
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
755
+ )
756
+
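At test time ctc_lin.py decodes each utterance with a pyctcdecode beam-search decoder built from the characters stored in label_encoder.txt and, optionally, the KenLM model given by ngram_lm_path. The following self-contained sketch only illustrates that decoding call; the label list and the log-probability matrix are fabricated here, whereas the script feeds in the real label set and the model's log_softmax output:

import numpy as np
from pyctcdecode import build_ctcdecoder

labels = ["", " ", "a", "b", "c"]   # index 0 is the CTC blank, as in the script's label list
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=None,          # pass an .arpa or .bin path here to enable the n-gram LM
    alpha=0.5,                      # LM weight, only used when a LM is supplied
    beta=1.0,                       # word insertion bonus
)

log_probs = np.log(np.full((20, len(labels)), 1.0 / len(labels)))  # (frames, vocab) log-probabilities
print(decoder.decode(log_probs))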
TunisianASR/results/14epoch_tunisian/<seed>/env.log ADDED
@@ -0,0 +1,97 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:22:55)
5
+ [GCC 10.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ aiohttp==3.8.5
9
+ aiosignal==1.3.1
10
+ async-timeout==4.0.3
11
+ attrs==23.1.0
12
+ audioread==3.0.0
13
+ certifi==2023.7.22
14
+ cffi==1.15.1
15
+ charset-normalizer==3.2.0
16
+ click==8.1.7
17
+ cmake==3.27.2
18
+ datasets==2.14.4
19
+ decorator==5.1.1
20
+ dill==0.3.7
21
+ exceptiongroup==1.1.3
22
+ filelock==3.12.3
23
+ frozenlist==1.4.0
24
+ fsspec==2023.6.0
25
+ huggingface-hub==0.16.4
26
+ HyperPyYAML==1.2.1
27
+ hypothesis==6.82.7
28
+ idna==3.4
29
+ Jinja2==3.1.2
30
+ jiwer==3.0.3
31
+ joblib==1.3.2
32
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip#sha256=4d002dcde70b52d519cafff4dc0008696c40cff1c9184a531b40c7b45905be6b
33
+ lazy_loader==0.3
34
+ librosa==0.10.1
35
+ lit==16.0.6
36
+ llvmlite==0.40.1
37
+ MarkupSafe==2.1.3
38
+ mpmath==1.3.0
39
+ msgpack==1.0.5
40
+ multidict==6.0.4
41
+ multiprocess==0.70.15
42
+ networkx==3.1
43
+ numba==0.57.1
44
+ numpy==1.24.4
45
+ nvidia-cublas-cu11==11.10.3.66
46
+ nvidia-cuda-cupti-cu11==11.7.101
47
+ nvidia-cuda-nvrtc-cu11==11.7.99
48
+ nvidia-cuda-runtime-cu11==11.7.99
49
+ nvidia-cudnn-cu11==8.5.0.96
50
+ nvidia-cufft-cu11==10.9.0.58
51
+ nvidia-curand-cu11==10.2.10.91
52
+ nvidia-cusolver-cu11==11.4.0.1
53
+ nvidia-cusparse-cu11==11.7.4.91
54
+ nvidia-nccl-cu11==2.14.3
55
+ nvidia-nvtx-cu11==11.7.91
56
+ packaging==23.1
57
+ pandas==2.0.3
58
+ platformdirs==3.10.0
59
+ pooch==1.7.0
60
+ pyarrow==13.0.0
61
+ pycparser==2.21
62
+ pyctcdecode==0.5.0
63
+ pygtrie==2.5.0
64
+ python-dateutil==2.8.2
65
+ pytz==2023.3
66
+ PyYAML==6.0.1
67
+ rapidfuzz==3.2.0
68
+ regex==2023.8.8
69
+ requests==2.31.0
70
+ ruamel.yaml==0.17.28
71
+ ruamel.yaml.clib==0.2.7
72
+ safetensors==0.3.3
73
+ scikit-learn==1.3.0
74
+ scipy==1.11.2
75
+ sentencepiece==0.1.99
76
+ six==1.16.0
77
+ sortedcontainers==2.4.0
78
+ soundfile==0.12.1
79
+ soxr==0.3.6
80
+ speechbrain==0.5.15
81
+ sympy==1.12
82
+ threadpoolctl==3.2.0
83
+ tokenizers==0.13.3
84
+ torch==2.0.1
85
+ torchaudio==2.0.2
86
+ tqdm==4.66.1
87
+ transformers==4.32.1
88
+ triton==2.0.0
89
+ typing_extensions==4.7.1
90
+ tzdata==2023.3
91
+ urllib3==2.0.4
92
+ xxhash==3.3.0
93
+ yarl==1.9.2
94
+ ==============================
95
+ Could not get git revision==============================
96
+ CUDA version:
97
+ 11.7
TunisianASR/results/14epoch_tunisian/<seed>/hyperparams.yaml ADDED
@@ -0,0 +1,200 @@
1
+ # Generated 2023-09-08 from:
2
+ # /gpfsssd/scratch/rech/nou/uzn19yk/switched_data/semi_supervised_test_tunisian.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder:
14
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/<seed>
15
+ wer_file:
16
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/wer.txt
17
+ save_folder:
18
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/save
19
+ train_log:
20
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/train_log.txt
21
+
22
+ # URL for the biggest LeBenchmark wav2vec french.
23
+ wav2vec2_folder:
24
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/save/wav2vec2_checkpoint
25
+
26
+ # Data files
27
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
28
+ train_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/train.tsv # Standard CommonVoice .tsv files
29
+ dev_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/dev.tsv # Standard CommonVoice .tsv files
30
+ test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
31
+ accented_letters: true
32
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
33
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train.csv
34
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
35
+ test_csv:
36
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_semi/unlabeled.csv
37
+
38
+ skip_prep: true # Skip data preparation
39
+
40
+ use_language_modelling: true
41
+ ngram_lm_path: arpas/indomain.arpa
42
+
43
+ # We remove utterances longer than 10s in the train/dev/test sets as
44
+ # longer sentences certainly correspond to "open microphones".
45
+ avoid_if_longer_than: 10.0
46
+ avoid_if_shorter_than: 1.2
47
+
48
+
49
+ # Training parameters
50
+ number_of_epochs: 14
51
+ lr: 1.0
52
+ lr_wav2vec: 0.0001
53
+ sorting: ascending
54
+ auto_mix_prec: false
55
+ sample_rate: 16000
56
+ ckpt_interval_minutes: 30 # save checkpoint every N min
57
+
58
+ # With data_parallel batch_size is split into N jobs
59
+ # With DDP batch_size is multiplied by N jobs
60
+ # Must be 6 per GPU to fit 16GB of VRAM
61
+ batch_size: 10
62
+ test_batch_size: 4
63
+
64
+ dataloader_options:
65
+ batch_size: 10
66
+ num_workers: 6
67
+ test_dataloader_options:
68
+ batch_size: 4
69
+ num_workers: 6
70
+
71
+ # BPE parameters
72
+ token_type: char # ["unigram", "bpe", "char"]
73
+ character_coverage: 1.0
74
+
75
+ # Model parameters
76
+ # activation: !name:torch.nn.LeakyReLU
77
+ wav2vec_output_dim: 1024
78
+ dnn_neurons: 1024
79
+ freeze_wav2vec: false
80
+ freeze_feature_extractor: true
81
+ dropout: 0.15
82
+ warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
83
+
84
+ # Outputs
85
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
86
+
87
+ # Decoding parameters
88
+ # Be sure that the bos and eos index match with the BPEs ones
89
+ blank_index: 0
90
+ unk_index: 1
91
+
92
+ #
93
+ # Functions and classes
94
+ #
95
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
96
+
97
+ limit: 14
98
+
99
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
100
+ sample_rate: 16000
101
+ speeds: [95, 100, 105]
102
+
103
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
104
+ input_shape: [null, null, 1024]
105
+ linear1: !name:speechbrain.nnet.linear.Linear
106
+ n_neurons: 1024
107
+ bias: true
108
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
109
+ activation: !new:torch.nn.LeakyReLU
110
+ drop: !new:torch.nn.Dropout
111
+ p: 0.15
112
+ linear2: !name:speechbrain.nnet.linear.Linear
113
+ n_neurons: 1024
114
+ bias: true
115
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
116
+ activation2: !new:torch.nn.LeakyReLU
117
+ drop2: !new:torch.nn.Dropout
118
+ p: 0.15
119
+ linear3: !name:speechbrain.nnet.linear.Linear
120
+ n_neurons: 1024
121
+ bias: true
122
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
123
+ activation3: !new:torch.nn.LeakyReLU
124
+
125
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
126
+ source: /gpfsstore/rech/nou/uzn19yk/wavlm/
127
+ output_norm: false
128
+ freeze: false
129
+ freeze_feature_extractor: true
130
+ save_path:
131
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/save/wav2vec2_checkpoint
132
+
133
+ #####
134
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
135
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
136
+ # Fairseq github for the multilingual XLSR.
137
+ #
138
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
139
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
140
+ # pretrained_path: !ref <wav2vec2_url>
141
+ # output_norm: True
142
+ # freeze: False
143
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
144
+ #####
145
+
146
+
147
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
148
+
149
+ input_size: 1024
150
+ n_neurons: 40
151
+
152
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
153
+ apply_log: true
154
+
155
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
156
+ blank_index: 0
157
+
158
+ modules:
159
+ wav2vec2: *id001
160
+ enc: *id002
161
+ ctc_lin: *id003
162
+ model: &id004 !new:torch.nn.ModuleList
163
+ - [*id002, *id003]
164
+ model_opt_class: !name:torch.optim.Adadelta
165
+ lr: 1.0
166
+ rho: 0.95
167
+ eps: 1.e-8
168
+
169
+ wav2vec_opt_class: !name:torch.optim.Adam
170
+ lr: 0.0001
171
+
172
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
173
+ initial_value: 1.0
174
+ improvement_threshold: 0.0025
175
+ annealing_factor: 0.8
176
+ patient: 0
177
+
178
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
179
+ initial_value: 0.0001
180
+ improvement_threshold: 0.0025
181
+ annealing_factor: 0.9
182
+ patient: 0
183
+
184
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
185
+ checkpoints_dir:
186
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/save
187
+ recoverables:
188
+ wav2vec2: *id001
189
+ model: *id004
190
+ scheduler_model: *id005
191
+ scheduler_wav2vec: *id006
192
+ counter: *id007
193
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
194
+ save_file:
195
+ /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/results/14epoch_tunisian/1234/train_log.txt
196
+
197
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
198
+
199
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
200
+ split_tokens: true
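
These hyperparameters are plain HyperPyYAML, so they can be loaded and inspected outside the training script. A minimal sketch, assuming the file above is saved locally as TunisianASR/results/14epoch_tunisian/1234/hyperparams.yaml and that the wav2vec2 source and save paths it references exist on the machine (!new: entries are instantiated at load time):

    # Sketch: load the training hyperparameters and inspect a few entries.
    from hyperpyyaml import load_hyperpyyaml

    with open("TunisianASR/results/14epoch_tunisian/1234/hyperparams.yaml") as f:
        hparams = load_hyperpyyaml(f)

    print(hparams["number_of_epochs"])       # 14
    print(list(hparams["modules"].keys()))   # ['wav2vec2', 'enc', 'ctc_lin']
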
TunisianASR/results/14epoch_tunisian/<seed>/log.txt ADDED
The diff for this file is too large to render.
 
TunisianASR/{train_semi.yaml → semi_trained.yaml} RENAMED
@@ -7,7 +7,7 @@
7
  # Seed needs to be set at top of yaml, before objects with parameters are made
8
  seed: 1234
9
  __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
- output_folder: !ref semi_wavlm_large_tunisian_ctc/<seed>
11
  wer_file: !ref <output_folder>/wer.txt
12
  save_folder: !ref <output_folder>/save
13
  train_log: !ref <output_folder>/train_log.txt
@@ -16,19 +16,23 @@ train_log: !ref <output_folder>/train_log.txt
16
  wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
17
 
18
  # Data files
19
- data_folder: /path/to/data # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
20
  train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
21
  dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
22
  test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
23
  accented_letters: True
24
  language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
 
 
25
  test_csv:
26
- - /path/to/test_data
 
 
27
 
28
  skip_prep: True # Skip data preparation
29
 
30
  use_language_modelling: True
31
- ngram_lm_path: outdomain.arpa
32
 
33
  # We remove utterances longer than 10s in the train/dev/test sets as
34
  # longer sentences certainly correspond to "open microphones".
@@ -37,7 +41,7 @@ avoid_if_shorter_than: 1.2
37
 
38
 
39
  # Training parameters
40
- number_of_epochs: 12
41
  lr: 1.0
42
  lr_wav2vec: 0.0001
43
  sorting: ascending
@@ -112,12 +116,25 @@ enc: !new:speechbrain.nnet.containers.Sequential
112
  activation3: !new:torch.nn.LeakyReLU
113
 
114
  wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
115
- source: /gpfsstore/rech/nou/uzn19yk/wavlm/
116
  output_norm: False
117
  freeze: !ref <freeze_wav2vec>
118
  freeze_feature_extractor: !ref <freeze_feature_extractor>
119
  save_path: !ref <wav2vec2_folder>
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  ctc_lin: !new:speechbrain.nnet.linear.Linear
123
  input_size: !ref <dnn_neurons>
 
7
  # Seed needs to be set at top of yaml, before objects with parameters are made
8
  seed: 1234
9
  __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: TunisianASR/results/14epoch_tunisian/1234/
11
  wer_file: !ref <output_folder>/wer.txt
12
  save_folder: !ref <output_folder>/save
13
  train_log: !ref <output_folder>/train_log.txt
 
16
  wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
17
 
18
  # Data files
19
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
20
  train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
21
  dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
22
  test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
23
  accented_letters: True
24
  language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
25
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train.csv
26
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
27
  test_csv:
28
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
29
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/iwslt_test.csv
30
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/taric_test.csv
31
 
32
  skip_prep: True # Skip data preparation
33
 
34
  use_language_modelling: True
35
+ ngram_lm_path: arpas/outdomain.arpa
36
 
37
  # We remove utterances longer than 10s in the train/dev/test sets as
38
  # longer sentences certainly correspond to "open microphones".
 
41
 
42
 
43
  # Training parameters
44
+ number_of_epochs: 14
45
  lr: 1.0
46
  lr_wav2vec: 0.0001
47
  sorting: ascending
 
116
  activation3: !new:torch.nn.LeakyReLU
117
 
118
  wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
119
+ source: wavlm-large/
120
  output_norm: False
121
  freeze: !ref <freeze_wav2vec>
122
  freeze_feature_extractor: !ref <freeze_feature_extractor>
123
  save_path: !ref <wav2vec2_folder>
124
 
125
+ #####
126
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
127
+ # of a HuggingFace one. Here, we provide a URL that is obtained from the
128
+ # Fairseq github for the multilingual XLSR.
129
+ #
130
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
131
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
132
+ # pretrained_path: !ref <wav2vec2_url>
133
+ # output_norm: True
134
+ # freeze: False
135
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
136
+ #####
137
+
138
 
139
  ctc_lin: !new:speechbrain.nnet.linear.Linear
140
  input_size: !ref <dnn_neurons>
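
Since semi_trained.yaml still carries cluster-specific CSV and LM paths, those entries are normally overridden at load time rather than edited in place. A rough sketch of that, following the same parse/override pattern app.py uses (the override values below are placeholders for a local setup):

    # Sketch: override machine-specific entries of semi_trained.yaml at load time.
    import speechbrain as sb
    from hyperpyyaml import load_hyperpyyaml

    overrides = {
        "data_folder": "/path/to/local/data",     # placeholder
        "ngram_lm_path": "arpas/outdomain.arpa",  # placeholder
    }
    hparams_file, run_opts, _ = sb.parse_arguments(["TunisianASR/semi_trained.yaml"])
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)
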
__pycache__/cv_train.cpython-38.pyc ADDED
Binary file (8.43 kB).
 
app.py ADDED
@@ -0,0 +1,768 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from speechbrain.utils.distributed import run_on_main
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from pathlib import Path
9
+ import torchaudio.transforms as T
10
+ from cv_train import ASRCV
11
+ import torchaudio
12
+ import numpy as np
13
+ import kenlm
14
+ from pyctcdecode import build_ctcdecoder
15
+ import re
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ import torch.optim as optim
18
+ import torch.nn as nn
19
+
20
+
21
+ # Commented out IPython magic to ensure Python compatibility.
22
+ hparams_file, run_opts, overrides = sb.parse_arguments(["TunisianASR/semi_trained.yaml"])
23
+
24
+ # If distributed_launch=True then
25
+ # create ddp_group with the right communication protocol
26
+ sb.utils.distributed.ddp_init_group(run_opts)
27
+
28
+ with open(hparams_file) as fin:
29
+ hparams = load_hyperpyyaml(fin, overrides)
30
+
31
+ # Create experiment directory
32
+ sb.create_experiment_directory(
33
+ experiment_directory=hparams["output_folder"],
34
+ hyperparams_to_save=hparams_file,
35
+ overrides=overrides,
36
+ )
37
+ # Dataset prep (parsing the Tunisian CSV files)
38
+
39
+ def dataio_prepare(hparams):
40
+ """This function prepares the datasets to be used in the brain class.
41
+ It also defines the data processing pipeline through user-defined functions."""
42
+
43
+ # 1. Define datasets
44
+ data_folder = hparams["data_folder"]
45
+
46
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
47
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
48
+ )
49
+
50
+ if hparams["sorting"] == "ascending":
51
+ # we sort training data to speed up training and get better results.
52
+ train_data = train_data.filtered_sorted(
53
+ sort_key="duration",
54
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
55
+ )
56
+ # when sorting, do not shuffle in the dataloader! otherwise sorting is pointless
57
+ hparams["dataloader_options"]["shuffle"] = False
58
+
59
+ elif hparams["sorting"] == "descending":
60
+ train_data = train_data.filtered_sorted(
61
+ sort_key="duration",
62
+ reverse=True,
63
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
64
+ )
65
+ # when sorting, do not shuffle in the dataloader! otherwise sorting is pointless
66
+ hparams["dataloader_options"]["shuffle"] = False
67
+
68
+ elif hparams["sorting"] == "random":
69
+ pass
70
+
71
+ else:
72
+ raise NotImplementedError(
73
+ "sorting must be random, ascending or descending"
74
+ )
75
+
76
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
77
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
78
+ )
79
+ # We also sort the validation data so it is faster to validate
80
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
81
+ test_datasets = {}
82
+ for csv_file in hparams["test_csv"]:
83
+ name = Path(csv_file).stem
84
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
85
+ csv_path=csv_file, replacements={"data_root": data_folder}
86
+ )
87
+ test_datasets[name] = test_datasets[name].filtered_sorted(
88
+ sort_key="duration"
89
+ )
90
+
91
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
92
+
93
+
94
+ # 2. Define audio pipeline:
95
+ @sb.utils.data_pipeline.takes("wav")
96
+ @sb.utils.data_pipeline.provides("sig")
97
+ def audio_pipeline(wav):
98
+ info = torchaudio.info(wav)
99
+ sig = sb.dataio.dataio.read_audio(wav)
100
+ if len(sig.shape)>1 :
101
+ sig = torch.mean(sig, dim=1)
102
+ resampled = torchaudio.transforms.Resample(
103
+ info.sample_rate, hparams["sample_rate"],
104
+ )(sig)
105
+ return resampled
106
+
107
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
108
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
109
+
110
+ # 3. Define text pipeline:
111
+ @sb.utils.data_pipeline.takes("wrd")
112
+ @sb.utils.data_pipeline.provides(
113
+ "wrd", "char_list", "tokens_list", "tokens"
114
+ )
115
+ def text_pipeline(wrd):
116
+ yield wrd
117
+ char_list = list(wrd)
118
+ yield char_list
119
+ tokens_list = label_encoder.encode_sequence(char_list)
120
+ yield tokens_list
121
+ tokens = torch.LongTensor(tokens_list)
122
+ yield tokens
123
+
124
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
125
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
126
+ special_labels = {
127
+ "blank_label": hparams["blank_index"],
128
+ "unk_label": hparams["unk_index"]
129
+ }
130
+ label_encoder.load_or_create(
131
+ path=lab_enc_file,
132
+ from_didatasets=[train_data],
133
+ output_key="char_list",
134
+ special_labels=special_labels,
135
+ sequence_input=True,
136
+ )
137
+
138
+ # 4. Set output:
139
+ sb.dataio.dataset.set_output_keys(
140
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
141
+ )
142
+ return train_data, valid_data,test_datasets, label_encoder
143
+
144
+ class ASR(sb.core.Brain):
145
+ def compute_forward(self, batch, stage):
146
+ """Forward computations from the waveform batches to the output probabilities."""
147
+
148
+ batch = batch.to(self.device)
149
+ wavs, wav_lens = batch.sig
150
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
151
+
152
+ if stage == sb.Stage.TRAIN:
153
+ if hasattr(self.hparams, "augmentation"):
154
+ wavs = self.hparams.augmentation(wavs, wav_lens)
155
+
156
+ # Forward pass
157
+ feats = self.modules.wav2vec2(wavs, wav_lens)
158
+ x = self.modules.enc(feats)
159
+ logits = self.modules.ctc_lin(x)
160
+ p_ctc = self.hparams.log_softmax(logits)
161
+
162
+ return p_ctc, wav_lens
163
+
164
+ def custom_encode(self,wavs,wav_lens) :
165
+ wavs = wavs.to("cpu")
166
+ if wav_lens is not None: wav_lens = wav_lens.to(self.device)
167
+
168
+ feats = self.modules.wav2vec2(wavs, wav_lens)
169
+ x = self.modules.enc(feats)
170
+ logits = self.modules.ctc_lin(x)
171
+ p_ctc = self.hparams.log_softmax(logits)
172
+
173
+ return feats,p_ctc
174
+
175
+
176
+
177
+ def compute_objectives(self, predictions, batch, stage):
178
+ """Computes the loss (CTC) given predictions and targets."""
179
+
180
+ p_ctc, wav_lens = predictions
181
+
182
+ ids = batch.id
183
+ tokens, tokens_lens = batch.tokens
184
+
185
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
186
+
187
+ if stage != sb.Stage.TRAIN:
188
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
189
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
190
+ )
191
+ # Decode token terms to words
192
+ if self.hparams.use_language_modelling:
193
+ predicted_words = []
194
+ for logs in p_ctc:
195
+ text = decoder.decode(logs.detach().cpu().numpy())
196
+ predicted_words.append(text.split(" "))
197
+ else:
198
+ predicted_words = [
199
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
200
+ for utt_seq in predicted_tokens
201
+ ]
202
+ # Convert indices to words
203
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
204
+
205
+ self.wer_metric.append(ids, predicted_words, target_words)
206
+ self.cer_metric.append(ids, predicted_words, target_words)
207
+
208
+ return loss
209
+
210
+ def fit_batch(self, batch):
211
+ """Train the parameters given a single batch in input"""
212
+ should_step = self.step % self.grad_accumulation_factor == 0
213
+ # Managing automatic mixed precision
214
+ # TOFIX: CTC fine-tuning currently is unstable
215
+ # This is certainly due to CTC being done in fp16 instead of fp32
216
+ if self.auto_mix_prec:
217
+ with torch.cuda.amp.autocast():
218
+ with self.no_sync():
219
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
220
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
221
+ with self.no_sync(not should_step):
222
+ self.scaler.scale(
223
+ loss / self.grad_accumulation_factor
224
+ ).backward()
225
+ if should_step:
226
+
227
+ if not self.hparams.wav2vec2.freeze:
228
+ self.scaler.unscale_(self.wav2vec_optimizer)
229
+ self.scaler.unscale_(self.model_optimizer)
230
+ if self.check_gradients(loss):
231
+ if not self.hparams.wav2vec2.freeze:
232
+ if self.optimizer_step >= self.hparams.warmup_steps:
233
+ self.scaler.step(self.wav2vec_optimizer)
234
+ self.scaler.step(self.model_optimizer)
235
+ self.scaler.update()
236
+ self.zero_grad()
237
+ self.optimizer_step += 1
238
+ else:
239
+ # This is mandatory because HF models have a weird behavior with DDP
240
+ # on the forward pass
241
+ with self.no_sync():
242
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
243
+
244
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
245
+
246
+ with self.no_sync(not should_step):
247
+ (loss / self.grad_accumulation_factor).backward()
248
+ if should_step:
249
+ if self.check_gradients(loss):
250
+ if not self.hparams.wav2vec2.freeze:
251
+ if self.optimizer_step >= self.hparams.warmup_steps:
252
+ self.wav2vec_optimizer.step()
253
+ self.model_optimizer.step()
254
+ self.zero_grad()
255
+ self.optimizer_step += 1
256
+
257
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
258
+ return loss.detach().cpu()
259
+
260
+ def evaluate_batch(self, batch, stage):
261
+ """Computations needed for validation/test batches"""
262
+ predictions = self.compute_forward(batch, stage=stage)
263
+ with torch.no_grad():
264
+ loss = self.compute_objectives(predictions, batch, stage=stage)
265
+ return loss.detach()
266
+
267
+ def on_stage_start(self, stage, epoch):
268
+ """Gets called at the beginning of each epoch"""
269
+ if stage != sb.Stage.TRAIN:
270
+ self.cer_metric = self.hparams.cer_computer()
271
+ self.wer_metric = self.hparams.error_rate_computer()
272
+
273
+ def on_stage_end(self, stage, stage_loss, epoch):
274
+ """Gets called at the end of an epoch."""
275
+ # Compute/store important stats
276
+ stage_stats = {"loss": stage_loss}
277
+ if stage == sb.Stage.TRAIN:
278
+ self.train_stats = stage_stats
279
+ else:
280
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
281
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
282
+
283
+ # Perform end-of-iteration things, like annealing, logging, etc.
284
+ if stage == sb.Stage.VALID:
285
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
286
+ stage_stats["loss"]
287
+ )
288
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
289
+ stage_stats["loss"]
290
+ )
291
+ sb.nnet.schedulers.update_learning_rate(
292
+ self.model_optimizer, new_lr_model
293
+ )
294
+ if not self.hparams.wav2vec2.freeze:
295
+ sb.nnet.schedulers.update_learning_rate(
296
+ self.wav2vec_optimizer, new_lr_wav2vec
297
+ )
298
+ self.hparams.train_logger.log_stats(
299
+ stats_meta={
300
+ "epoch": epoch,
301
+ "lr_model": old_lr_model,
302
+ "lr_wav2vec": old_lr_wav2vec,
303
+ },
304
+ train_stats=self.train_stats,
305
+ valid_stats=stage_stats,
306
+ )
307
+ self.checkpointer.save_and_keep_only(
308
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
309
+ )
310
+ elif stage == sb.Stage.TEST:
311
+ self.hparams.train_logger.log_stats(
312
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
313
+ test_stats=stage_stats,
314
+ )
315
+ with open(self.hparams.wer_file, "w") as w:
316
+ self.wer_metric.write_stats(w)
317
+
318
+ def init_optimizers(self):
319
+ "Initializes the wav2vec2 optimizer and model optimizer"
320
+
321
+ # If the wav2vec encoder is unfrozen, we create the optimizer
322
+ if not self.hparams.wav2vec2.freeze:
323
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
324
+ self.modules.wav2vec2.parameters()
325
+ )
326
+ if self.checkpointer is not None:
327
+ self.checkpointer.add_recoverable(
328
+ "wav2vec_opt", self.wav2vec_optimizer
329
+ )
330
+
331
+ self.model_optimizer = self.hparams.model_opt_class(
332
+ self.hparams.model.parameters()
333
+ )
334
+
335
+ if self.checkpointer is not None:
336
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
337
+
338
+ def zero_grad(self, set_to_none=False):
339
+ if not self.hparams.wav2vec2.freeze:
340
+ self.wav2vec_optimizer.zero_grad(set_to_none)
341
+ self.model_optimizer.zero_grad(set_to_none)
342
+
343
+
344
+ from speechbrain.pretrained import EncoderASR,EncoderDecoderASR
345
+ french_asr_model = EncoderASR.from_hparams(source="asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
346
+ french_asr_model.to("cpu")
347
+ cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["EnglishCV/train_en_with_wav2vec.yaml"])
348
+ with open(cvhparams_file) as cvfin:
349
+ cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
350
+ english_asr_model = ASRCV(
351
+ modules=cvhparams["modules"],
352
+ hparams=cvhparams,
353
+ run_opts=cvrun_opts,
354
+ checkpointer=cvhparams["checkpointer"],
355
+ )
356
+ english_asr_model.modules.to("cpu")
357
+ english_asr_model.checkpointer.recover_if_possible()
358
+ print("moving to tunisian model")
359
+ asr_brain = ASR(
360
+ modules=hparams["modules"],
361
+ hparams=hparams,
362
+ run_opts=run_opts,
363
+ checkpointer=hparams["checkpointer"],
364
+ )
365
+ asr_brain.modules.to("cpu")
366
+ asr_brain.checkpointer.recover_if_possible()
367
+ asr_brain.modules.eval()
368
+ english_asr_model.modules.eval()
369
+ french_asr_model.mods.eval()
370
+ asr_brain.modules.to("cpu")
371
+
372
+ # Commented out IPython magic to ensure Python compatibility.
373
+ # %ls
374
+
375
+ # UTILS FUNCTIONS
376
+ def get_size_dimensions(arr):
377
+ size_dimensions = []
378
+ while isinstance(arr, list):
379
+ size_dimensions.append(len(arr))
380
+ arr = arr[0]
381
+ return size_dimensions
382
+
383
+ def scale_array(batch,n):
384
+ scaled_batch = []
385
+
386
+ for array in batch:
387
+ if(n < len(array)): raise ValueError("Cannot scale Array down")
388
+
389
+ repeat = round(n/len(array))+1
390
+ scaled_length_array= []
391
+
392
+ for i in array:
393
+ for j in range(repeat) :
394
+ if(len(scaled_length_array) == n): break
395
+ scaled_length_array.append(i)
396
+
397
+ scaled_batch.append(scaled_length_array)
398
+
399
+ return torch.tensor(scaled_batch)
400
+
401
+
402
+ def load_paths(wavs_path):
403
+ waveforms = []
404
+ for path in wavs_path :
405
+ waveform, _ = torchaudio.load(path)
406
+ waveforms.append(waveform.squeeze(0))
407
+ # normalize array lengths to the longest array by padding with 0's
408
+ padded_arrays = pad_sequence(waveforms, batch_first=True)
409
+ return torch.tensor(padded_arrays)
410
+
411
+
412
+
413
+ device = 'cuda'
414
+ verbose = 0
415
+ #FLOW LEVEL FUNCTIONS
416
+ def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):
417
+
418
+
419
+ post1 = post1.to(device)
420
+ post2 = post2.to(device)
421
+ post3 = post3.to(device)
422
+ embeddings1 = embeddings1.to(device)
423
+ embeddings2 = embeddings2.to(device)
424
+ embeddings3 = embeddings3.to(device)
425
+
426
+ posteriograms_merged = torch.cat((post1,post2,post3),dim=2)
427
+ embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)
428
+
429
+ if(verbose !=0):
430
+ print('MERGED POST ',posteriograms_merged.shape)
431
+ print('MERGED emb ',embeddings_merged.shape)
432
+
433
+ return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)
434
+
435
+ def decode(model,wavs,wav_lens):
436
+
437
+ with torch.no_grad():
438
+ wav_lens = wav_lens.to(model.device)
439
+ encoder_out = model.encode_batch(wavs, wav_lens)
440
+ predictions = model.decoding_function(encoder_out, wav_lens)
441
+ return predictions
442
+
443
+ def middle_layer(batch, lens):
444
+
445
+ tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)
446
+
447
+ fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
448
+ fr_posteriogram =french_asr_model.encode_batch(batch,lens)
449
+ en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
450
+ x = english_asr_model.modules.enc(en_embeddings)
451
+ en_posteriogram = english_asr_model.modules.ctc_lin(x)
452
+ #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)
453
+ if(verbose !=0):
454
+ print('[EMBEDDINGS] FR:',fr_embeddings.shape, "EN:",en_embeddings.shape, "TN:", tn_embeddings.shape)
455
+ print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, "EN:",en_posteriogram.shape,"TN:",tn_posteriogram.shape)
456
+
457
+
458
+ bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)
459
+ return bilangual_sample
460
+
461
+ class Mixer(sb.core.Brain):
462
+
463
+ def compute_forward(self, batch, stage):
464
+ """Forward computations from the waveform batches to the output probabilities."""
465
+ wavs, wav_lens = batch.sig
466
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
467
+
468
+ if stage == sb.Stage.TRAIN:
469
+ if hasattr(self.hparams, "augmentation"):
470
+ wavs = self.hparams.augmentation(wavs, wav_lens)
471
+
472
+ multi_langual_feats = middle_layer(wavs, wav_lens)
473
+ multi_langual_feats= multi_langual_feats.to(device)
474
+ feats, _ = self.modules.enc(multi_langual_feats)
475
+ logits = self.modules.ctc_lin(feats)
476
+ p_ctc = self.hparams.log_softmax(logits)
477
+
478
+ if stage!= sb.Stage.TRAIN:
479
+ p_tokens = sb.decoders.ctc_greedy_decode(
480
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
481
+ )
482
+ else :
483
+ p_tokens = None
484
+ return p_ctc, wav_lens, p_tokens
485
+
486
+
487
+ def treat_wav(self,sig):
488
+ multi_langual_feats = middle_layer(sig.to("cpu"), torch.tensor([1]).to("cpu"))
489
+ multi_langual_feats= multi_langual_feats.to(device)
490
+ feats, _ = self.modules.enc(multi_langual_feats)
491
+ logits = self.modules.ctc_lin(feats)
492
+ p_ctc = self.hparams.log_softmax(logits)
493
+ predicted_words =[]
494
+ for logs in p_ctc:
495
+ text = decoder.decode(logs.detach().cpu().numpy())
496
+ predicted_words.append(text.split(" "))
497
+ return " ".join(predicted_words[0])
498
+
499
+
500
+ def compute_objectives(self, predictions, batch, stage):
501
+ """Computes the loss (CTC) given predictions and targets."""
502
+
503
+ p_ctc, wav_lens , predicted_tokens= predictions
504
+
505
+ ids = batch.id
506
+ tokens, tokens_lens = batch.tokens
507
+
508
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
509
+
510
+
511
+ if stage == sb.Stage.VALID:
512
+ predicted_words = [
513
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
514
+ for utt_seq in predicted_tokens
515
+ ]
516
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
517
+ self.wer_metric.append(ids, predicted_words, target_words)
518
+ self.cer_metric.append(ids, predicted_words, target_words)
519
+ if stage ==sb.Stage.TEST :
520
+ if self.hparams.language_modelling:
521
+ predicted_words = []
522
+ for logs in p_ctc:
523
+ text = decoder.decode(logs.detach().cpu().numpy())
524
+ predicted_words.append(text.split(" "))
525
+ else :
526
+ predicted_words = [
527
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
528
+ for utt_seq in predicted_tokens
529
+ ]
530
+
531
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
532
+ self.wer_metric.append(ids, predicted_words, target_words)
533
+ self.cer_metric.append(ids, predicted_words, target_words)
534
+
535
+ return loss
536
+
537
+ def fit_batch(self, batch):
538
+ """Train the parameters given a single batch in input"""
539
+ should_step = self.step % self.grad_accumulation_factor == 0
540
+ # Managing automatic mixed precision
541
+ # TOFIX: CTC fine-tuning currently is unstable
542
+ # This is certainly due to CTC being done in fp16 instead of fp32
543
+ if self.auto_mix_prec:
544
+ with torch.cuda.amp.autocast():
545
+ with self.no_sync():
546
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
547
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
548
+ with self.no_sync(not should_step):
549
+ self.scaler.scale(
550
+ loss / self.grad_accumulation_factor
551
+ ).backward()
552
+ if should_step:
553
+
554
+
555
+ self.scaler.unscale_(self.model_optimizer)
556
+ if self.check_gradients(loss):
557
+ self.scaler.step(self.model_optimizer)
558
+ self.scaler.update()
559
+ self.zero_grad()
560
+ self.optimizer_step += 1
561
+ else:
562
+ # This is mandatory because HF models have a weird behavior with DDP
563
+ # on the forward pass
564
+ with self.no_sync():
565
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
566
+
567
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
568
+
569
+ with self.no_sync(not should_step):
570
+ (loss / self.grad_accumulation_factor).backward()
571
+ if should_step:
572
+ if self.check_gradients(loss):
573
+ self.model_optimizer.step()
574
+ self.zero_grad()
575
+ self.optimizer_step += 1
576
+
577
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
578
+ return loss.detach().cpu()
579
+
580
+ def evaluate_batch(self, batch, stage):
581
+ """Computations needed for validation/test batches"""
582
+ predictions = self.compute_forward(batch, stage=stage)
583
+ with torch.no_grad():
584
+ loss = self.compute_objectives(predictions, batch, stage=stage)
585
+ return loss.detach()
586
+
587
+ def on_stage_start(self, stage, epoch):
588
+ """Gets called at the beginning of each epoch"""
589
+ if stage != sb.Stage.TRAIN:
590
+ self.cer_metric = self.hparams.cer_computer()
591
+ self.wer_metric = self.hparams.error_rate_computer()
592
+
593
+ def on_stage_end(self, stage, stage_loss, epoch):
594
+ """Gets called at the end of an epoch."""
595
+ # Compute/store important stats
596
+ stage_stats = {"loss": stage_loss}
597
+ if stage == sb.Stage.TRAIN:
598
+ self.train_stats = stage_stats
599
+ else:
600
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
601
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
602
+
603
+ # Perform end-of-iteration things, like annealing, logging, etc.
604
+ if stage == sb.Stage.VALID:
605
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
606
+ stage_stats["loss"]
607
+ )
608
+ sb.nnet.schedulers.update_learning_rate(
609
+ self.model_optimizer, new_lr_model
610
+ )
611
+ self.hparams.train_logger.log_stats(
612
+ stats_meta={
613
+ "epoch": epoch,
614
+ "lr_model": old_lr_model,
615
+ },
616
+ train_stats=self.train_stats,
617
+ valid_stats=stage_stats,
618
+ )
619
+ self.checkpointer.save_and_keep_only(
620
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
621
+ )
622
+ elif stage == sb.Stage.TEST:
623
+ self.hparams.train_logger.log_stats(
624
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
625
+ test_stats=stage_stats,
626
+ )
627
+ with open(self.hparams.wer_file, "w") as w:
628
+ self.wer_metric.write_stats(w)
629
+
630
+ def init_optimizers(self):
631
+
632
+ self.model_optimizer = self.hparams.model_opt_class(
633
+ self.hparams.model.parameters()
634
+ )
635
+
636
+ if self.checkpointer is not None:
637
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
638
+
639
+ def zero_grad(self, set_to_none=False):
640
+
641
+ self.model_optimizer.zero_grad(set_to_none)
642
+
643
+
644
+
645
+
646
+ hparams_file, run_opts, overrides = sb.parse_arguments(["cs.yaml"])
647
+
648
+ # If distributed_launch=True then
649
+ # create ddp_group with the right communication protocol
650
+ sb.utils.distributed.ddp_init_group(run_opts)
651
+
652
+ with open(hparams_file) as fin:
653
+ hparams = load_hyperpyyaml(fin, overrides)
654
+
655
+ # Create experiment directory
656
+ sb.create_experiment_directory(
657
+ experiment_directory=hparams["output_folder"],
658
+ hyperparams_to_save=hparams_file,
659
+ overrides=overrides,
660
+ )
661
+ def read_labels_file(labels_file):
662
+ with open(labels_file, "r",encoding="utf-8") as lf:
663
+ lines = lf.read().splitlines()
664
+ division = "==="
665
+ numbers = {}
666
+ for line in lines :
667
+ if division in line :
668
+ break
669
+ string, number = line.split("=>")
670
+ number = int(number)
671
+ string = string[1:-2]
672
+ numbers[number] = string
673
+ return [numbers[x] for x in range(len(numbers))]
674
+
675
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
676
+
677
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
678
+ special_labels = {
679
+ "blank_label": hparams["blank_index"],
680
+ "unk_label": hparams["unk_index"]
681
+ }
682
+ label_encoder.load_or_create(
683
+ path=lab_enc_file,
684
+ from_didatasets=[[]],
685
+ output_key="char_list",
686
+ special_labels=special_labels,
687
+ sequence_input=True,
688
+ )
689
+
690
+
691
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
692
+ labels = [""] + labels[1:-1] + ["1"]
693
+ if hparams["language_modelling"]:
694
+ decoder = build_ctcdecoder(
695
+ labels,
696
+ kenlm_model_path=hparams["ngram_lm_path"], # either .arpa or .bin file
697
+ alpha=0.5, # tuned on a val set
698
+ beta=1, # tuned on a val set
699
+ )
700
+
701
+
702
+
703
+
704
+ mixer = Mixer(
705
+ modules=hparams["modules"],
706
+ hparams=hparams,
707
+ run_opts=run_opts,
708
+ checkpointer=hparams["checkpointer"],
709
+ )
710
+ mixer.tokenizer = label_encoder
711
+ mixer.checkpointer.recover_if_possible()
712
+ mixer.modules.eval()
713
+
714
+
715
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
716
+
717
+
718
+ # We dynamically add the tokenizer to our brain class.
719
+ # NB: This tokenizer corresponds to the one used for the LM!!
720
+
721
+ decoder = build_ctcdecoder(
722
+ labels,
723
+ kenlm_model_path= "arpas/everything.arpa", # either .arpa or .bin file
724
+ alpha=0.5, # tuned on a val set
725
+ beta=1, # tuned on a val set
726
+ )
727
+
728
+ run_opts["device"]="cpu"
729
+
730
+
731
+ device = "cpu"
732
+ mixer.device= "cpu"
733
+ mixer.modules.to("cpu")
734
+
735
+ from enum import Enum, auto
736
+ class Stage(Enum):
737
+ TRAIN = auto()
738
+ VALID = auto()
739
+ TEST = auto()
740
+
741
+ asr_brain.on_evaluate_start()
742
+ asr_brain.modules.eval()
743
+
744
+
745
+ import gradio as gr
746
+
747
+ def treat_wav_file(file_mic,file_upload ,asr=mixer, device="cpu") :
748
+ if (file_mic is not None) and (file_upload is not None):
749
+ warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
750
+ wav = file_mic
751
+ elif (file_mic is None) and (file_upload is None):
752
+ return "ERROR: You have to either use the microphone or upload an audio file"
753
+ elif file_mic is not None:
754
+ wav = file_mic
755
+ else:
756
+ wav = file_upload
757
+ sig, sr = torchaudio.load(wav)
758
+ tensor_wav = sig.to(device)
759
+ resampled = torchaudio.functional.resample( tensor_wav, sr, 16000)
760
+ sentence = asr.treat_wav(resampled)
761
+ return sentence
762
+
763
+ gr.Interface(
764
+ fn=treat_wav_file,
765
+ inputs=[gr.Audio(source="microphone", type='filepath', label = "record", optional = True),
766
+ gr.Audio(source="upload", type='filepath', label="filein", optional=True)]
767
+ ,outputs="text").launch()
768
+
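
For quick offline checks the Gradio launch at the end is not required; the same pipeline can be driven directly through treat_wav_file defined above. A minimal sketch, with a placeholder wav path:

    # Sketch: transcribe a local recording through the same path the Gradio UI uses.
    transcript = treat_wav_file(file_mic=None, file_upload="sample.wav")  # placeholder wav
    print(transcript)
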
TunisianASR/outdomain.arpa → arpas/pluslanguages_everything.arpa RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:24654c1d236bb1bd367125131c847c4a734e69914eda71a6786964c20440d8fe
3
- size 324243244
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a2a593a09bb07c74ac6f1aadbb4c0a6540edb53b9234f3da1aafa9e6d99f983
3
+ size 354742106
asr-wav2vec2-commonvoice-fr/hyperparams.yaml CHANGED
@@ -5,7 +5,7 @@
5
  # ################################
6
 
7
  sample_rate: 16000
8
- wav2vec2_hub: /gpfsscratch/rech/nou/uzn19yk/wav2vec2-FR-7K-large
9
 
10
  # BPE parameters
11
  token_type: unigram # ["unigram", "bpe", "char"]
 
5
  # ################################
6
 
7
  sample_rate: 16000
8
+ wav2vec2_hub: wav2vec2-FR-7K-large/
9
 
10
  # BPE parameters
11
  token_type: unigram # ["unigram", "bpe", "char"]
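
With wav2vec2_hub now pointing at a local wav2vec2-FR-7K-large/ folder, the French model should resolve without network access. A hedged sketch of loading it the same way app.py does, plus a one-file decode (the wav path is a placeholder):

    # Sketch: load the French CommonVoice model and transcribe one file.
    from speechbrain.pretrained import EncoderASR

    french_asr_model = EncoderASR.from_hparams(
        source="asr-wav2vec2-commonvoice-fr",
        savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",
    )
    print(french_asr_model.transcribe_file("sample_fr.wav"))  # placeholder wav
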
cs.yaml CHANGED
@@ -43,7 +43,7 @@ sorting: ascending
43
  auto_mix_prec: False
44
  sample_rate: 16000
45
  language_modelling: True
46
- ngram_lm_path: /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/arpas/pluslanguages_everything.arpa
47
 
48
  # With data_parallel batch_size is split into N jobs
49
  # With DDP batch_size is multiplied by N jobs
 
43
  auto_mix_prec: False
44
  sample_rate: 16000
45
  language_modelling: True
46
+ ngram_lm_path: arpas/pluslanguages_everything.arpa
47
 
48
  # With data_parallel batch_size is split into N jobs
49
  # With DDP batch_size is multiplied by N jobs
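
cs.yaml now points the mixer at the repo-local arpas/pluslanguages_everything.arpa; the corresponding beam-search decoder is built with pyctcdecode, as in app.py. A minimal sketch, assuming labels is the character list read from the saved label_encoder.txt:

    # Sketch: build the KenLM-rescored CTC beam-search decoder used at test time.
    from pyctcdecode import build_ctcdecoder

    decoder = build_ctcdecoder(
        labels,                                              # char vocabulary (assumed loaded)
        kenlm_model_path="arpas/pluslanguages_everything.arpa",
        alpha=0.5,  # LM weight, tuned on a val set
        beta=1,     # word insertion bonus, tuned on a val set
    )
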
cv_train.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ import torchaudio
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
9
+ from speechbrain.utils.data_utils import undo_padding
10
+ from speechbrain.utils.distributed import run_on_main
11
+
12
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
13
+ The system employs a wav2vec2 encoder and a CTC decoder.
14
+ Decoding is performed with greedy decoding (will be extended to beam search).
15
+
16
+ To run this recipe, do the following:
17
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
18
+
19
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
20
+ The wav2vec2 model is pretrained following the model given in the hparams file.
21
+ It may be dependent on the language.
22
+
23
+ The neural network is trained with CTC on sub-word units estimated with
24
+ Byte-Pair Encoding (BPE).
25
+
26
+ The experiment file is flexible enough to support a large variety of
27
+ different systems. By properly changing the parameter files, you can try
28
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
29
+ training languages (all CommonVoice languages), and many
30
+ other possible variations.
31
+
32
+ Authors
33
+ * Titouan Parcollet 2021
34
+ """
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # Define training procedure
40
+ class ASRCV(sb.core.Brain):
41
+ def compute_forward(self, batch, stage):
42
+ """Forward computations from the waveform batches to the output probabilities."""
43
+
44
+ batch = batch.to(self.device)
45
+ wavs, wav_lens = batch.sig
46
+ tokens_bos, _ = batch.tokens_bos
47
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
48
+
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
68
+ tokens, tokens_lens = batch.tokens
69
+
70
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
71
+
72
+ if stage != sb.Stage.TRAIN:
73
+ # Decode token terms to words
74
+ sequence = sb.decoders.ctc_greedy_decode(
75
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
76
+ )
77
+
78
+ predicted_words = self.tokenizer(sequence, task="decode_from_list")
79
+
80
+ # Convert indices to words
81
+ target_words = undo_padding(tokens, tokens_lens)
82
+ target_words = self.tokenizer(target_words, task="decode_from_list")
83
+
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ should_step = self.step % self.grad_accumulation_factor == 0
92
+ # Managing automatic mixed precision
93
+ # TOFIX: CTC fine-tuning currently is unstable
94
+ # This is certainly due to CTC being done in fp16 instead of fp32
95
+ if self.auto_mix_prec:
96
+ with torch.cuda.amp.autocast():
97
+ with self.no_sync():
98
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
99
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
100
+ with self.no_sync(not should_step):
101
+ self.scaler.scale(
102
+ loss / self.grad_accumulation_factor
103
+ ).backward()
104
+ if should_step:
105
+
106
+ if not self.hparams.wav2vec2.freeze:
107
+ self.scaler.unscale_(self.wav2vec_optimizer)
108
+ self.scaler.unscale_(self.model_optimizer)
109
+ if self.check_gradients(loss):
110
+ if not self.hparams.wav2vec2.freeze:
111
+ if self.optimizer_step >= self.hparams.warmup_steps:
112
+ self.scaler.step(self.wav2vec_optimizer)
113
+ self.scaler.step(self.model_optimizer)
114
+ self.scaler.update()
115
+ self.zero_grad()
116
+ self.optimizer_step += 1
117
+ else:
118
+ # This is mandatory because HF models have a weird behavior with DDP
119
+ # on the forward pass
120
+ with self.no_sync():
121
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
122
+
123
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
124
+
125
+ with self.no_sync(not should_step):
126
+ (loss / self.grad_accumulation_factor).backward()
127
+ if should_step:
128
+ if self.check_gradients(loss):
129
+ if not self.hparams.wav2vec2.freeze:
130
+ if self.optimizer_step >= self.hparams.warmup_steps:
131
+ self.wav2vec_optimizer.step()
132
+ self.model_optimizer.step()
133
+ self.zero_grad()
134
+ self.optimizer_step += 1
135
+
136
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
137
+ return loss.detach().cpu()
138
+
139
+ def evaluate_batch(self, batch, stage):
140
+ """Computations needed for validation/test batches"""
141
+ predictions = self.compute_forward(batch, stage=stage)
142
+ with torch.no_grad():
143
+ loss = self.compute_objectives(predictions, batch, stage=stage)
144
+ return loss.detach()
145
+
146
+ def on_stage_start(self, stage, epoch):
147
+ """Gets called at the beginning of each epoch"""
148
+ if stage != sb.Stage.TRAIN:
149
+ self.cer_metric = self.hparams.cer_computer()
150
+ self.wer_metric = self.hparams.error_rate_computer()
151
+
152
+ def on_stage_end(self, stage, stage_loss, epoch):
153
+ """Gets called at the end of an epoch."""
154
+ # Compute/store important stats
155
+ stage_stats = {"loss": stage_loss}
156
+ if stage == sb.Stage.TRAIN:
157
+ self.train_stats = stage_stats
158
+ else:
159
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
160
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
161
+
162
+ # Perform end-of-iteration things, like annealing, logging, etc.
163
+ if stage == sb.Stage.VALID:
164
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
165
+ stage_stats["loss"]
166
+ )
167
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
168
+ stage_stats["loss"]
169
+ )
170
+ sb.nnet.schedulers.update_learning_rate(
171
+ self.model_optimizer, new_lr_model
172
+ )
173
+ if not self.hparams.wav2vec2.freeze:
174
+ sb.nnet.schedulers.update_learning_rate(
175
+ self.wav2vec_optimizer, new_lr_wav2vec
176
+ )
177
+ self.hparams.train_logger.log_stats(
178
+ stats_meta={
179
+ "epoch": epoch,
180
+ "lr_model": old_lr_model,
181
+ "lr_wav2vec": old_lr_wav2vec,
182
+ },
183
+ train_stats=self.train_stats,
184
+ valid_stats=stage_stats,
185
+ )
186
+ self.checkpointer.save_and_keep_only(
187
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
188
+ )
189
+ elif stage == sb.Stage.TEST:
190
+ self.hparams.train_logger.log_stats(
191
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
192
+ test_stats=stage_stats,
193
+ )
194
+ with open(self.hparams.wer_file, "w") as w:
195
+ self.wer_metric.write_stats(w)
196
+
197
+ def init_optimizers(self):
198
+ "Initializes the wav2vec2 optimizer and model optimizer"
199
+
200
+ # If the wav2vec encoder is unfrozen, we create the optimizer
201
+ if not self.hparams.wav2vec2.freeze:
202
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
203
+ self.modules.wav2vec2.parameters()
204
+ )
205
+ if self.checkpointer is not None:
206
+ self.checkpointer.add_recoverable(
207
+ "wav2vec_opt", self.wav2vec_optimizer
208
+ )
209
+
210
+ self.model_optimizer = self.hparams.model_opt_class(
211
+ self.hparams.model.parameters()
212
+ )
213
+
214
+ if self.checkpointer is not None:
215
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
216
+
217
+ def zero_grad(self, set_to_none=False):
218
+ if not self.hparams.wav2vec2.freeze:
219
+ self.wav2vec_optimizer.zero_grad(set_to_none)
220
+ self.model_optimizer.zero_grad(set_to_none)
221
+
222
+
223
+ # Define custom data procedure
224
+ def dataio_prepare(hparams, tokenizer):
225
+ """This function prepares the datasets to be used in the brain class.
226
+ It also defines the data processing pipeline through user-defined functions."""
227
+
228
+ # 1. Define datasets
229
+ data_folder = hparams["data_folder"]
230
+
231
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
232
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
233
+ )
234
+
235
+ if hparams["sorting"] == "ascending":
236
+ # we sort training data to speed up training and get better results.
237
+ train_data = train_data.filtered_sorted(
238
+ sort_key="duration",
239
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
240
+ )
241
+ # when sorting, do not shuffle in the dataloader! otherwise sorting is pointless
242
+ hparams["dataloader_options"]["shuffle"] = False
243
+
244
+ elif hparams["sorting"] == "descending":
245
+ train_data = train_data.filtered_sorted(
246
+ sort_key="duration",
247
+ reverse=True,
248
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
249
+ )
250
+ # when sorting, do not shuffle in the dataloader! otherwise sorting is pointless
251
+ hparams["dataloader_options"]["shuffle"] = False
252
+
253
+ elif hparams["sorting"] == "random":
254
+ pass
255
+
256
+ else:
257
+ raise NotImplementedError(
258
+ "sorting must be random, ascending or descending"
259
+ )
260
+
261
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
262
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
263
+ )
264
+ # We also sort the validation data so it is faster to validate
265
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
266
+
267
+ test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
268
+ csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
269
+ )
270
+
271
+ # We also sort the validation data so it is faster to validate
272
+ test_data = test_data.filtered_sorted(sort_key="duration")
273
+
274
+ datasets = [train_data, valid_data, test_data]
275
+
276
+ # 2. Define audio pipeline:
277
+ @sb.utils.data_pipeline.takes("wav")
278
+ @sb.utils.data_pipeline.provides("sig")
279
+ def audio_pipeline(wav):
280
+ info = torchaudio.info(wav)
281
+ sig = sb.dataio.dataio.read_audio(wav)
282
+ resampled = torchaudio.transforms.Resample(
283
+ info.sample_rate, hparams["sample_rate"],
284
+ )(sig)
285
+ return resampled
286
+
287
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
288
+
289
+ # 3. Define text pipeline:
290
+ @sb.utils.data_pipeline.takes("wrd")
291
+ @sb.utils.data_pipeline.provides(
292
+ "tokens_list", "tokens_bos", "tokens_eos", "tokens"
293
+ )
294
+ def text_pipeline(wrd):
295
+ tokens_list = tokenizer.sp.encode_as_ids(wrd)
296
+ yield tokens_list
297
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
298
+ yield tokens_bos
299
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
300
+ yield tokens_eos
301
+ tokens = torch.LongTensor(tokens_list)
302
+ yield tokens
303
+
304
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
305
+
306
+ # 4. Set output:
307
+ sb.dataio.dataset.set_output_keys(
308
+ datasets, ["id", "sig", "tokens_bos", "tokens_eos", "tokens"],
309
+ )
310
+ return train_data, valid_data, test_data
311
+
312
+
313
+ if __name__ == "__main__":
314
+
315
+ # Load hyperparameters file with command-line overrides
316
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
317
+ with open(hparams_file) as fin:
318
+ hparams = load_hyperpyyaml(fin, overrides)
319
+
320
+ # If --distributed_launch then
321
+ # create ddp_group with the right communication protocol
322
+ sb.utils.distributed.ddp_init_group(run_opts)
323
+
324
+ # Dataset preparation (parsing CommonVoice)
325
+ from common_voice_prepare import prepare_common_voice # noqa
326
+
327
+ # Create experiment directory
328
+ sb.create_experiment_directory(
329
+ experiment_directory=hparams["output_folder"],
330
+ hyperparams_to_save=hparams_file,
331
+ overrides=overrides,
332
+ )
333
+
334
+ # Due to DDP, we do the preparation ONLY on the main python process
335
+ run_on_main(
336
+ prepare_common_voice,
337
+ kwargs={
338
+ "data_folder": hparams["data_folder"],
339
+ "save_folder": hparams["save_folder"],
340
+ "train_tsv_file": hparams["train_tsv_file"],
341
+ "dev_tsv_file": hparams["dev_tsv_file"],
342
+ "test_tsv_file": hparams["test_tsv_file"],
343
+ "accented_letters": hparams["accented_letters"],
344
+ "language": hparams["language"],
345
+ "skip_prep": hparams["skip_prep"],
346
+ },
347
+ )
348
+
349
+ # Defining tokenizer and loading it
350
+ tokenizer = SentencePiece(
351
+ model_dir=hparams["save_folder"],
352
+ vocab_size=hparams["output_neurons"],
353
+ annotation_train=hparams["train_csv"],
354
+ annotation_read="wrd",
355
+ model_type=hparams["token_type"],
356
+ character_coverage=hparams["character_coverage"],
357
+ )
358
+
359
+ # Create the datasets objects as well as tokenization and encoding :-D
360
+ train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
+
362
+ # Trainer initialization
363
+ asr_brain = ASRCV(
364
+ modules=hparams["modules"],
365
+ hparams=hparams,
366
+ run_opts=run_opts,
367
+ checkpointer=hparams["checkpointer"],
368
+ )
369
+
370
+ # Adding objects to trainer.
371
+ asr_brain.tokenizer = tokenizer
372
+
373
+ # Training
374
+ asr_brain.fit(
375
+ asr_brain.hparams.epoch_counter,
376
+ train_data,
377
+ valid_data,
378
+ train_loader_kwargs=hparams["dataloader_options"],
379
+ valid_loader_kwargs=hparams["test_dataloader_options"],
380
+ )
381
+
382
+ # Test
383
+ asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
384
+ asr_brain.evaluate(
385
+ test_data,
386
+ min_key="WER",
387
+ test_loader_kwargs=hparams["test_dataloader_options"],
388
+ )
pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/asr.ckpt
pretrained_models/asr-wav2vec2-commonvoice-fr/custom.py ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py
pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/hyperparams.yaml
pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt
pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ huggingface_hub>=0.7.0
2
+ gradio
3
+ https://github.com/kpu/kenlm/archive/master.zip
4
+ hyperpyyaml>=0.0.1
5
+ joblib>=0.14.1
6
+ numpy>=1.17.0
7
+ packaging
8
+ pre-commit>=2.3.0
9
+ scipy>=1.4.1, <1.9
10
+ sentencepiece>=0.1.91
11
+ SoundFile; sys_platform == 'win32'
12
+ torch>=1.13.0
13
+ torchaudio>=0.9.0
14
+ tqdm>=4.42.0
15
+ transformers
16
+ speechbrain
17
+ pyctcdecode
results/non_semi_final_stac/env.log ADDED
@@ -0,0 +1,347 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.8.5 (default, Sep 4 2020, 07:30:14)
5
+ [GCC 7.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ absl-py==1.2.0
9
+ aiohttp==3.8.1
10
+ aiosignal==1.2.0
11
+ alabaster==0.7.12
12
+ anaconda-client==1.7.2
13
+ anaconda-navigator==1.10.0
14
+ anaconda-project==0.8.3
15
+ antlr4-python3-runtime==4.9.3
16
+ appdirs==1.4.4
17
+ argh==0.26.2
18
+ argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1596828493937/work
19
+ asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
20
+ astroid @ file:///tmp/build/80754af9/astroid_1592495912941/work
21
+ astropy==4.0.2
22
+ async-generator==1.10
23
+ async-timeout==4.0.2
24
+ atomicwrites==1.4.0
25
+ attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work
26
+ audioread==2.1.9
27
+ autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work
28
+ Babel @ file:///tmp/build/80754af9/babel_1605108370292/work
29
+ backcall==0.2.0
30
+ backports.functools-lru-cache==1.6.1
31
+ backports.shutil-get-terminal-size==1.0.0
32
+ backports.tempfile==1.0
33
+ backports.weakref==1.0.post1
34
+ beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1601924105527/work
35
+ bitarray @ file:///tmp/build/80754af9/bitarray_1605065113847/work
36
+ bkcharts==0.2
37
+ black==22.12.0
38
+ bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work
39
+ bokeh @ file:///tmp/build/80754af9/bokeh_1603297833684/work
40
+ boto==2.49.0
41
+ boto3==1.28.43
42
+ botocore==1.31.43
43
+ Bottleneck==1.3.2
44
+ bpemb==0.3.4
45
+ brotlipy==0.7.0
46
+ cachetools==5.2.0
47
+ certifi==2020.6.20
48
+ cffi @ file:///tmp/build/80754af9/cffi_1600699146221/work
49
+ chardet==3.0.4
50
+ charset-normalizer==2.0.12
51
+ click==8.1.3
52
+ cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
53
+ clyent==1.2.2
54
+ colorama @ file:///tmp/build/80754af9/colorama_1603211150991/work
55
+ coloredlogs==15.0.1
56
+ conda==4.9.2
57
+ conda-build==3.20.5
58
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1603018141399/work
59
+ conda-verify==3.4.2
60
+ conllu==4.5.3
61
+ contextlib2==0.6.0.post1
62
+ cryptography @ file:///tmp/build/80754af9/cryptography_1601046815590/work
63
+ cycler==0.10.0
64
+ Cython @ file:///tmp/build/80754af9/cython_1594831566883/work
65
+ cytoolz==0.11.0
66
+ dask @ file:///tmp/build/80754af9/dask-core_1602083700509/work
67
+ datasets==1.18.3
68
+ decorator==4.4.2
69
+ defusedxml==0.6.0
70
+ Deprecated==1.2.14
71
+ diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
72
+ dill==0.3.4
73
+ distributed @ file:///tmp/build/80754af9/distributed_1605066520644/work
74
+ docutils==0.16
75
+ easyocr==1.2.1
76
+ einops==0.3.0
77
+ entrypoints==0.3
78
+ et-xmlfile==1.0.1
79
+ farasapy==0.0.14
80
+ fastcache==1.1.0
81
+ ffmpeg-python==0.2.0
82
+ filelock==3.0.12
83
+ flair==0.12.2
84
+ flake8 @ file:///tmp/build/80754af9/flake8_1601911421857/work
85
+ Flask==1.1.2
86
+ flatbuffers==22.9.24
87
+ frozenlist==1.3.0
88
+ fsspec==2022.3.0
89
+ ftfy==6.1.1
90
+ future==0.18.2
91
+ gdown==4.4.0
92
+ gensim==4.1.2
93
+ gevent @ file:///tmp/build/80754af9/gevent_1601397537062/work
94
+ glob2==0.7
95
+ gmpy2==2.0.8
96
+ google-auth==2.12.0
97
+ google-auth-oauthlib==0.4.6
98
+ greenlet @ file:///tmp/build/80754af9/greenlet_1600874013538/work
99
+ grpcio==1.49.1
100
+ h5py==2.10.0
101
+ HeapDict==1.0.1
102
+ html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
103
+ huggingface-hub==0.16.4
104
+ humanfriendly==10.0
105
+ hyperopt==0.2.7
106
+ idna @ file:///tmp/build/80754af9/idna_1593446292537/work
107
+ imageio @ file:///tmp/build/80754af9/imageio_1594161405741/work
108
+ imagesize==1.2.0
109
+ imhist==0.0.4
110
+ importlib-metadata==5.0.0
111
+ imWatermark==0.0.2
112
+ iniconfig @ file:///tmp/build/80754af9/iniconfig_1602780191262/work
113
+ install==1.3.5
114
+ intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
115
+ invisible-watermark==0.1.5
116
+ ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
117
+ ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work
118
+ ipython-genutils==0.2.0
119
+ ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1601490159889/work
120
+ isort @ file:///tmp/build/80754af9/isort_1602603989581/work
121
+ itsdangerous==1.1.0
122
+ Janome==0.5.0
123
+ jdcal==1.4.1
124
+ jedi @ file:///tmp/build/80754af9/jedi_1592841866100/work
125
+ jeepney @ file:///tmp/build/80754af9/jeepney_1605069705079/work
126
+ Jinja2==2.11.2
127
+ jiwer==2.3.0
128
+ jmespath==1.0.1
129
+ joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work
130
+ json5==0.9.5
131
+ jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
132
+ jupyter==1.0.0
133
+ jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work
134
+ jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1598884538475/work
135
+ jupyter-core==4.6.3
136
+ jupyterlab==2.2.6
137
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
138
+ jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work
139
+ keyring @ file:///tmp/build/80754af9/keyring_1601490835422/work
140
+ kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1604014535162/work
141
+ langdetect==1.0.9
142
+ lazy-object-proxy==1.4.3
143
+ libarchive-c==2.9
144
+ librosa==0.9.1
145
+ llvmlite==0.34.0
146
+ locket==0.2.0
147
+ lxml @ file:///tmp/build/80754af9/lxml_1603216285000/work
148
+ Markdown==3.4.1
149
+ MarkupSafe==1.1.1
150
+ matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1603378225747/work
151
+ mccabe==0.6.1
152
+ mido==1.2.10
153
+ mistune==0.8.4
154
+ mkl-fft==1.2.0
155
+ mkl-random==1.1.1
156
+ mkl-service==2.3.0
157
+ mock==4.0.2
158
+ more-itertools @ file:///tmp/build/80754af9/more-itertools_1605111547926/work
159
+ mpld3==0.3
160
+ mpmath==1.1.0
161
+ msgpack==1.0.0
162
+ multidict==6.0.2
163
+ multipledispatch==0.6.0
164
+ multiprocess==0.70.12.2
165
+ mypy-extensions==0.4.3
166
+ navigator-updater==0.2.1
167
+ nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work
168
+ nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work
169
+ nbformat @ file:///tmp/build/80754af9/nbformat_1602783287752/work
170
+ nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1605115881283/work
171
+ networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work
172
+ nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work
173
+ nose==1.3.7
174
+ notebook @ file:///tmp/build/80754af9/notebook_1601501575118/work
175
+ numba @ file:///tmp/build/80754af9/numba_1600100669015/work
176
+ numexpr==2.7.1
177
+ numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work
178
+ numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
179
+ oauthlib==3.2.1
180
+ olefile==0.46
181
+ omegaconf==2.2.3
182
+ onnx==1.12.0
183
+ onnxruntime==1.12.1
184
+ opencv-python==4.4.0.46
185
+ openpyxl @ file:///tmp/build/80754af9/openpyxl_1598113097404/work
186
+ packaging==20.9
187
+ pandas @ file:///tmp/build/80754af9/pandas_1602088120436/work
188
+ pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
189
+ parso==0.7.0
190
+ partd==1.1.0
191
+ path @ file:///tmp/build/80754af9/path_1598376507494/work
192
+ pathlib2==2.3.5
193
+ pathspec==0.10.3
194
+ pathtools==0.1.2
195
+ patsy==0.5.1
196
+ pep8==1.7.1
197
+ pexpect==4.8.0
198
+ pickleshare==0.7.5
199
+ Pillow @ file:///tmp/build/80754af9/pillow_1603822255246/work
200
+ pkginfo==1.6.1
201
+ platformdirs==2.6.0
202
+ pluggy==0.13.1
203
+ ply==3.11
204
+ pooch==1.6.0
205
+ pptree==3.1
206
+ pretty-midi==0.2.9
207
+ prometheus-client==0.8.0
208
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
209
+ protobuf==3.19.6
210
+ psutil @ file:///tmp/build/80754af9/psutil_1598370257551/work
211
+ ptyprocess==0.6.0
212
+ py @ file:///tmp/build/80754af9/py_1593446248552/work
213
+ py-espeak-ng==0.1.8
214
+ py4j==0.10.9.7
215
+ PyArabic==0.6.15
216
+ pyarrow==7.0.0
217
+ pyasn1==0.4.8
218
+ pyasn1-modules==0.2.8
219
+ pycodestyle==2.6.0
220
+ pycosat==0.6.3
221
+ pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
222
+ pycurl==7.43.0.6
223
+ pyDeprecate==0.3.1
224
+ pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1598885001695/work
225
+ pyflakes==2.2.0
226
+ Pygments @ file:///tmp/build/80754af9/pygments_1604103097372/work
227
+ pylint @ file:///tmp/build/80754af9/pylint_1598623985952/work
228
+ pyodbc===4.0.0-unsupported
229
+ pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
230
+ pyparsing==2.4.7
231
+ pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
232
+ PySocks==1.7.1
233
+ pytest==0.0.0
234
+ python-bidi==0.4.2
235
+ python-crfsuite==0.9.7
236
+ python-dateutil==2.8.1
237
+ python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
238
+ python-language-server @ file:///tmp/build/80754af9/python-language-server_1600454544709/work
239
+ python-Levenshtein==0.12.2
240
+ pytorch-lightning==1.4.2
241
+ pytorch-revgrad==0.2.0
242
+ pytz==2020.1
243
+ PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
244
+ pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
245
+ PyYAML==5.3.1
246
+ pyzmq==19.0.2
247
+ QDarkStyle==2.8.1
248
+ QtAwesome @ file:///tmp/build/80754af9/qtawesome_1602272867890/work
249
+ qtconsole @ file:///tmp/build/80754af9/qtconsole_1600870028330/work
250
+ QtPy==1.9.0
251
+ regex @ file:///tmp/build/80754af9/regex_1602786672676/work
252
+ requests @ file:///tmp/build/80754af9/requests_1592841827918/work
253
+ requests-oauthlib==1.3.1
254
+ resampy==0.2.2
255
+ rope @ file:///tmp/build/80754af9/rope_1602264064449/work
256
+ rsa==4.9
257
+ Rtree==0.9.4
258
+ ruamel-yaml==0.15.87
259
+ s3transfer==0.6.2
260
+ sacremoses==0.0.49
261
+ safetensors==0.3.3
262
+ scikit-image==0.17.2
263
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1598376899566/work
264
+ scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work
265
+ seaborn @ file:///tmp/build/80754af9/seaborn_1600553570093/work
266
+ SecretStorage==3.1.2
267
+ segtok==1.5.11
268
+ Send2Trash==1.5.0
269
+ sentencepiece==0.1.97
270
+ simplegeneric==0.8.1
271
+ singledispatch @ file:///tmp/build/80754af9/singledispatch_1602523705405/work
272
+ sip==4.19.13
273
+ six @ file:///tmp/build/80754af9/six_1605205327372/work
274
+ smart-open==5.2.1
275
+ snowballstemmer==2.0.0
276
+ sortedcollections==1.2.1
277
+ sortedcontainers==2.2.2
278
+ SoundFile==0.10.3.post1
279
+ soupsieve==2.0.1
280
+ sphfile==1.0.3
281
+ Sphinx @ file:///tmp/build/80754af9/sphinx_1597428793432/work
282
+ sphinxcontrib-applehelp==1.0.2
283
+ sphinxcontrib-devhelp==1.0.2
284
+ sphinxcontrib-htmlhelp==1.0.3
285
+ sphinxcontrib-jsmath==1.0.1
286
+ sphinxcontrib-qthelp==1.0.3
287
+ sphinxcontrib-serializinghtml==1.1.4
288
+ sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
289
+ spyder @ file:///tmp/build/80754af9/spyder_1599056981321/work
290
+ spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1599056754858/work
291
+ SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1603397987316/work
292
+ sqlitedict==2.1.0
293
+ statsmodels @ file:///tmp/build/80754af9/statsmodels_1602280205159/work
294
+ sympy @ file:///tmp/build/80754af9/sympy_1605119542615/work
295
+ tables==3.6.1
296
+ tabulate==0.9.0
297
+ tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
298
+ tensorboard==2.10.1
299
+ tensorboard-data-server==0.6.1
300
+ tensorboard-plugin-wit==1.8.1
301
+ terminado==0.9.1
302
+ testpath==0.4.4
303
+ threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
304
+ tifffile==2020.10.1
305
+ tkseem==0.0.3
306
+ tokenizers==0.13.3
307
+ toml @ file:///tmp/build/80754af9/toml_1592853716807/work
308
+ tomli==2.0.1
309
+ toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work
310
+ torch==1.11.0
311
+ torchaudio==0.11.0
312
+ torchmetrics==0.6.0
313
+ torchvision==0.8.2
314
+ tornado==6.0.4
315
+ tqdm==4.64.0
316
+ traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
317
+ transformer-smaller-training-vocab==0.3.1
318
+ transformers==4.33.1
319
+ typing-extensions==4.4.0
320
+ ujson @ file:///tmp/build/80754af9/ujson_1602523317881/work
321
+ unicodecsv==0.14.1
322
+ urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work
323
+ watchdog @ file:///tmp/build/80754af9/watchdog_1593447344699/work
324
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
325
+ webencodings==0.5.1
326
+ Werkzeug==1.0.1
327
+ widgetsnbextension==3.5.1
328
+ Wikipedia-API==0.6.0
329
+ wrapt==1.11.2
330
+ wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1594753850195/work
331
+ xlrd==1.2.0
332
+ XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1602692860603/work
333
+ xlwt==1.3.0
334
+ xmltodict==0.12.0
335
+ xxhash==3.0.0
336
+ yapf @ file:///tmp/build/80754af9/yapf_1593528177422/work
337
+ yarl==1.7.2
338
+ zict==2.0.0
339
+ zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work
340
+ zope.event==4.5.0
341
+ zope.interface @ file:///tmp/build/80754af9/zope.interface_1602002420968/work
342
+ ==============================
343
+ Git revision:
344
+ 8a51838
345
+ ==============================
346
+ CUDA version:
347
+ 11.7
results/non_semi_final_stac/hyperparams.yaml CHANGED
@@ -1,5 +1,5 @@
1
- # Generated 2023-09-08 from:
2
- # /gpfsssd/scratch/rech/nou/uzn19yk/switched_data/stac.yaml
3
  # yamllint disable
4
  # Generated 2023-08-03 from:
5
  # /home/salah/new_tunisian_model/hparams/train_tunisian_withwavlm.yaml
@@ -46,8 +46,7 @@ sorting: ascending
46
  auto_mix_prec: false
47
  sample_rate: 16000
48
  language_modelling: true
49
- ngram_lm_path:
50
- /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/arpas/pluslanguages_everything.arpa
51
 
52
  # With data_parallel batch_size is split into N jobs
53
  # With DDP batch_size is multiplied by N jobs
 
1
+ # Generated 2023-09-20 from:
2
+ # /home/salah/Code_Switched_Tunisian_Speech_Recognition/cs.yaml
3
  # yamllint disable
4
  # Generated 2023-08-03 from:
5
  # /home/salah/new_tunisian_model/hparams/train_tunisian_withwavlm.yaml
 
46
  auto_mix_prec: false
47
  sample_rate: 16000
48
  language_modelling: true
49
+ ngram_lm_path: arpas/pluslanguages_everything.arpa
 
50
 
51
  # With data_parallel batch_size is split into N jobs
52
  # With DDP batch_size is multiplied by N jobs
results/non_semi_final_stac/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
transcribe.ipynb ADDED
@@ -0,0 +1,915 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 30,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n",
13
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n",
14
+ "speechbrain.core - Beginning experiment!\n",
15
+ "speechbrain.core - Experiment folder: TunisianASR/results/14epoch_tunisian/1234/\n",
16
+ "speechbrain.pretrained.fetching - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/hyperparams.yaml.\n",
17
+ "speechbrain.pretrained.fetching - Fetch custom.py: Linking to local file in /home/salah/Code_Switched_Tunisian_Speech_Recognition/asr-wav2vec2-commonvoice-fr/custom.py.\n",
18
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 is frozen.\n",
19
+ "speechbrain.pretrained.fetching - Fetch wav2vec2.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/wav2vec2.ckpt.\n",
20
+ "speechbrain.pretrained.fetching - Fetch asr.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/asr.ckpt.\n",
21
+ "speechbrain.pretrained.fetching - Fetch tokenizer.ckpt: Using existing file/symlink in pretrained_models/asr-wav2vec2-commonvoice-fr/tokenizer.ckpt.\n",
22
+ "speechbrain.utils.parameter_transfer - Loading pretrained files for: wav2vec2, asr, tokenizer\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Some weights of the model checkpoint at wav2vec2-large-lv60/ were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight']\n",
30
+ "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
31
+ "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
32
+ ]
33
+ },
34
+ {
35
+ "name": "stdout",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "speechbrain.lobes.models.huggingface_wav2vec - speechbrain.lobes.models.huggingface_wav2vec - wav2vec 2.0 feature extractor is frozen.\n",
39
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
40
+ "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n",
41
+ "speechbrain.core - 314.4M trainable parameters in ASRCV\n",
42
+ "speechbrain.utils.checkpoints - Loading a checkpoint from EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00\n",
43
+ "moving to tunisian model\n",
44
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
45
+ "speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used\n",
46
+ "speechbrain.core - 314.4M trainable parameters in ASR\n",
47
+ "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "import os\n",
53
+ "import sys\n",
54
+ "import torch\n",
55
+ "import logging\n",
56
+ "import speechbrain as sb\n",
57
+ "from speechbrain.utils.distributed import run_on_main\n",
58
+ "from hyperpyyaml import load_hyperpyyaml\n",
59
+ "from pathlib import Path\n",
60
+ "import torchaudio.transforms as T\n",
61
+ "from cv_train import ASRCV\n",
62
+ "import torchaudio\n",
63
+ "import numpy as np\n",
64
+ "import kenlm\n",
65
+ "from pyctcdecode import build_ctcdecoder\n",
66
+ "import re\n",
67
+ "from torch.nn.utils.rnn import pad_sequence\n",
68
+ "import torch.optim as optim\n",
69
+ "import torch.nn as nn\n",
70
+ "\n",
71
+ "\n",
72
+ "# Commented out IPython magic to ensure Python compatibility.\n",
73
+ "hparams_file, run_opts, overrides = sb.parse_arguments([\"TunisianASR/semi_trained.yaml\"])\n",
74
+ "\n",
75
+ "# If distributed_launch=True then\n",
76
+ "# create ddp_group with the right communication protocol\n",
77
+ "sb.utils.distributed.ddp_init_group(run_opts)\n",
78
+ "\n",
79
+ "with open(hparams_file) as fin:\n",
80
+ " hparams = load_hyperpyyaml(fin, overrides)\n",
81
+ "\n",
82
+ "# Create experiment directory\n",
83
+ "sb.create_experiment_directory(\n",
84
+ " experiment_directory=hparams[\"output_folder\"],\n",
85
+ " hyperparams_to_save=hparams_file,\n",
86
+ " overrides=overrides,\n",
87
+ ")\n",
88
+ "# Dataset prep (parsing Librispeech)\n",
89
+ "\n",
90
+ "def dataio_prepare(hparams):\n",
91
+ " \"\"\"This function prepares the datasets to be used in the brain class.\n",
92
+ " It also defines the data processing pipeline through user-defined functions.\"\"\"\n",
93
+ "\n",
94
+ " # 1. Define datasets\n",
95
+ " data_folder = hparams[\"data_folder\"]\n",
96
+ "\n",
97
+ " train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
98
+ " csv_path=hparams[\"train_csv\"], replacements={\"data_root\": data_folder},\n",
99
+ " )\n",
100
+ "\n",
101
+ " if hparams[\"sorting\"] == \"ascending\":\n",
102
+ " # we sort training data to speed up training and get better results.\n",
103
+ " train_data = train_data.filtered_sorted(\n",
104
+ " sort_key=\"duration\",\n",
105
+ " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n",
106
+ " )\n",
107
+ " # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
108
+ " hparams[\"dataloader_options\"][\"shuffle\"] = False\n",
109
+ "\n",
110
+ " elif hparams[\"sorting\"] == \"descending\":\n",
111
+ " train_data = train_data.filtered_sorted(\n",
112
+ " sort_key=\"duration\",\n",
113
+ " reverse=True,\n",
114
+ " key_max_value={\"duration\": hparams[\"avoid_if_longer_than\"]},\n",
115
+ " )\n",
116
+ " # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
117
+ " hparams[\"dataloader_options\"][\"shuffle\"] = False\n",
118
+ "\n",
119
+ " elif hparams[\"sorting\"] == \"random\":\n",
120
+ " pass\n",
121
+ "\n",
122
+ " else:\n",
123
+ " raise NotImplementedError(\n",
124
+ " \"sorting must be random, ascending or descending\"\n",
125
+ " )\n",
126
+ "\n",
127
+ " valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
128
+ " csv_path=hparams[\"valid_csv\"], replacements={\"data_root\": data_folder},\n",
129
+ " )\n",
130
+ " # We also sort the validation data so it is faster to validate\n",
131
+ " valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
132
+ " test_datasets = {}\n",
133
+ " for csv_file in hparams[\"test_csv\"]:\n",
134
+ " name = Path(csv_file).stem\n",
135
+ " test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(\n",
136
+ " csv_path=csv_file, replacements={\"data_root\": data_folder}\n",
137
+ " )\n",
138
+ " test_datasets[name] = test_datasets[name].filtered_sorted(\n",
139
+ " sort_key=\"duration\"\n",
140
+ " )\n",
141
+ "\n",
142
+ " datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
143
+ "\n",
144
+ "\n",
145
+ " # 2. Define audio pipeline:\n",
146
+ " @sb.utils.data_pipeline.takes(\"wav\")\n",
147
+ " @sb.utils.data_pipeline.provides(\"sig\")\n",
148
+ " def audio_pipeline(wav):\n",
149
+ " info = torchaudio.info(wav)\n",
150
+ " sig = sb.dataio.dataio.read_audio(wav)\n",
151
+ " if len(sig.shape)>1 :\n",
152
+ " sig = torch.mean(sig, dim=1)\n",
153
+ " resampled = torchaudio.transforms.Resample(\n",
154
+ " info.sample_rate, hparams[\"sample_rate\"],\n",
155
+ " )(sig)\n",
156
+ " return resampled\n",
157
+ "\n",
158
+ " sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
159
+ " label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
160
+ "\n",
161
+ " # 3. Define text pipeline:\n",
162
+ " @sb.utils.data_pipeline.takes(\"wrd\")\n",
163
+ " @sb.utils.data_pipeline.provides(\n",
164
+ " \"wrd\", \"char_list\", \"tokens_list\", \"tokens\"\n",
165
+ " )\n",
166
+ " def text_pipeline(wrd):\n",
167
+ " yield wrd\n",
168
+ " char_list = list(wrd)\n",
169
+ " yield char_list\n",
170
+ " tokens_list = label_encoder.encode_sequence(char_list)\n",
171
+ " yield tokens_list\n",
172
+ " tokens = torch.LongTensor(tokens_list)\n",
173
+ " yield tokens\n",
174
+ "\n",
175
+ " sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
176
+ " lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
177
+ " special_labels = {\n",
178
+ " \"blank_label\": hparams[\"blank_index\"],\n",
179
+ " \"unk_label\": hparams[\"unk_index\"]\n",
180
+ " }\n",
181
+ " label_encoder.load_or_create(\n",
182
+ " path=lab_enc_file,\n",
183
+ " from_didatasets=[train_data],\n",
184
+ " output_key=\"char_list\",\n",
185
+ " special_labels=special_labels,\n",
186
+ " sequence_input=True,\n",
187
+ " )\n",
188
+ "\n",
189
+ " # 4. Set output:\n",
190
+ " sb.dataio.dataset.set_output_keys(\n",
191
+ " datasets, [\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],\n",
192
+ " )\n",
193
+ " return train_data, valid_data,test_datasets, label_encoder\n",
194
+ "\n",
195
+ "class ASR(sb.core.Brain):\n",
196
+ " def compute_forward(self, batch, stage):\n",
197
+ " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
198
+ "\n",
199
+ " batch = batch.to(self.device)\n",
200
+ " wavs, wav_lens = batch.sig\n",
201
+ " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
202
+ "\n",
203
+ " if stage == sb.Stage.TRAIN:\n",
204
+ " if hasattr(self.hparams, \"augmentation\"):\n",
205
+ " wavs = self.hparams.augmentation(wavs, wav_lens)\n",
206
+ "\n",
207
+ " # Forward pass\n",
208
+ " feats = self.modules.wav2vec2(wavs, wav_lens)\n",
209
+ " x = self.modules.enc(feats)\n",
210
+ " logits = self.modules.ctc_lin(x)\n",
211
+ " p_ctc = self.hparams.log_softmax(logits)\n",
212
+ "\n",
213
+ " return p_ctc, wav_lens\n",
214
+ "\n",
215
+ " def custom_encode(self,wavs,wav_lens) :\n",
216
+ " wavs = wavs.to(\"cpu\")\n",
217
+ " if(wav_lens is not None): wav_lens.to(self.device)\n",
218
+ "\n",
219
+ " feats = self.modules.wav2vec2(wavs, wav_lens)\n",
220
+ " x = self.modules.enc(feats)\n",
221
+ " logits = self.modules.ctc_lin(x)\n",
222
+ " p_ctc = self.hparams.log_softmax(logits)\n",
223
+ "\n",
224
+ " return feats,p_ctc\n",
225
+ "\n",
226
+ "\n",
227
+ "\n",
228
+ " def compute_objectives(self, predictions, batch, stage):\n",
229
+ " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n",
230
+ "\n",
231
+ " p_ctc, wav_lens = predictions\n",
232
+ "\n",
233
+ " ids = batch.id\n",
234
+ " tokens, tokens_lens = batch.tokens\n",
235
+ "\n",
236
+ " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
237
+ "\n",
238
+ " if stage != sb.Stage.TRAIN:\n",
239
+ " predicted_tokens = sb.decoders.ctc_greedy_decode(\n",
240
+ " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n",
241
+ " )\n",
242
+ " # Decode token terms to words\n",
243
+ " if self.hparams.use_language_modelling:\n",
244
+ " predicted_words = []\n",
245
+ " for logs in p_ctc:\n",
246
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
247
+ " predicted_words.append(text.split(\" \"))\n",
248
+ " else:\n",
249
+ " predicted_words = [\n",
250
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
251
+ " for utt_seq in predicted_tokens\n",
252
+ " ]\n",
253
+ " # Convert indices to words\n",
254
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
255
+ "\n",
256
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
257
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
258
+ "\n",
259
+ " return loss\n",
260
+ "\n",
261
+ " def fit_batch(self, batch):\n",
262
+ " \"\"\"Train the parameters given a single batch in input\"\"\"\n",
263
+ " should_step = self.step % self.grad_accumulation_factor == 0\n",
264
+ " # Managing automatic mixed precision\n",
265
+ " # TOFIX: CTC fine-tuning currently is unstable\n",
266
+ " # This is certainly due to CTC being done in fp16 instead of fp32\n",
267
+ " if self.auto_mix_prec:\n",
268
+ " with torch.cuda.amp.autocast():\n",
269
+ " with self.no_sync():\n",
270
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
271
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
272
+ " with self.no_sync(not should_step):\n",
273
+ " self.scaler.scale(\n",
274
+ " loss / self.grad_accumulation_factor\n",
275
+ " ).backward()\n",
276
+ " if should_step:\n",
277
+ "\n",
278
+ " if not self.hparams.wav2vec2.freeze:\n",
279
+ " self.scaler.unscale_(self.wav2vec_optimizer)\n",
280
+ " self.scaler.unscale_(self.model_optimizer)\n",
281
+ " if self.check_gradients(loss):\n",
282
+ " if not self.hparams.wav2vec2.freeze:\n",
283
+ " if self.optimizer_step >= self.hparams.warmup_steps:\n",
284
+ " self.scaler.step(self.wav2vec_optimizer)\n",
285
+ " self.scaler.step(self.model_optimizer)\n",
286
+ " self.scaler.update()\n",
287
+ " self.zero_grad()\n",
288
+ " self.optimizer_step += 1\n",
289
+ " else:\n",
290
+ " # This is mandatory because HF models have a weird behavior with DDP\n",
291
+ " # on the forward pass\n",
292
+ " with self.no_sync():\n",
293
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
294
+ "\n",
295
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
296
+ "\n",
297
+ " with self.no_sync(not should_step):\n",
298
+ " (loss / self.grad_accumulation_factor).backward()\n",
299
+ " if should_step:\n",
300
+ " if self.check_gradients(loss):\n",
301
+ " if not self.hparams.wav2vec2.freeze:\n",
302
+ " if self.optimizer_step >= self.hparams.warmup_steps:\n",
303
+ " self.wav2vec_optimizer.step()\n",
304
+ " self.model_optimizer.step()\n",
305
+ " self.zero_grad()\n",
306
+ " self.optimizer_step += 1\n",
307
+ "\n",
308
+ " self.on_fit_batch_end(batch, outputs, loss, should_step)\n",
309
+ " return loss.detach().cpu()\n",
310
+ "\n",
311
+ " def evaluate_batch(self, batch, stage):\n",
312
+ " \"\"\"Computations needed for validation/test batches\"\"\"\n",
313
+ " predictions = self.compute_forward(batch, stage=stage)\n",
314
+ " with torch.no_grad():\n",
315
+ " loss = self.compute_objectives(predictions, batch, stage=stage)\n",
316
+ " return loss.detach()\n",
317
+ "\n",
318
+ " def on_stage_start(self, stage, epoch):\n",
319
+ " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
320
+ " if stage != sb.Stage.TRAIN:\n",
321
+ " self.cer_metric = self.hparams.cer_computer()\n",
322
+ " self.wer_metric = self.hparams.error_rate_computer()\n",
323
+ "\n",
324
+ " def on_stage_end(self, stage, stage_loss, epoch):\n",
325
+ " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
326
+ " # Compute/store important stats\n",
327
+ " stage_stats = {\"loss\": stage_loss}\n",
328
+ " if stage == sb.Stage.TRAIN:\n",
329
+ " self.train_stats = stage_stats\n",
330
+ " else:\n",
331
+ " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
332
+ " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
333
+ "\n",
334
+ " # Perform end-of-iteration things, like annealing, logging, etc.\n",
335
+ " if stage == sb.Stage.VALID:\n",
336
+ " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n",
337
+ " stage_stats[\"loss\"]\n",
338
+ " )\n",
339
+ " old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(\n",
340
+ " stage_stats[\"loss\"]\n",
341
+ " )\n",
342
+ " sb.nnet.schedulers.update_learning_rate(\n",
343
+ " self.model_optimizer, new_lr_model\n",
344
+ " )\n",
345
+ " if not self.hparams.wav2vec2.freeze:\n",
346
+ " sb.nnet.schedulers.update_learning_rate(\n",
347
+ " self.wav2vec_optimizer, new_lr_wav2vec\n",
348
+ " )\n",
349
+ " self.hparams.train_logger.log_stats(\n",
350
+ " stats_meta={\n",
351
+ " \"epoch\": epoch,\n",
352
+ " \"lr_model\": old_lr_model,\n",
353
+ " \"lr_wav2vec\": old_lr_wav2vec,\n",
354
+ " },\n",
355
+ " train_stats=self.train_stats,\n",
356
+ " valid_stats=stage_stats,\n",
357
+ " )\n",
358
+ " self.checkpointer.save_and_keep_only(\n",
359
+ " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n",
360
+ " )\n",
361
+ " elif stage == sb.Stage.TEST:\n",
362
+ " self.hparams.train_logger.log_stats(\n",
363
+ " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n",
364
+ " test_stats=stage_stats,\n",
365
+ " )\n",
366
+ " with open(self.hparams.wer_file, \"w\") as w:\n",
367
+ " self.wer_metric.write_stats(w)\n",
368
+ "\n",
369
+ " def init_optimizers(self):\n",
370
+ " \"Initializes the wav2vec2 optimizer and model optimizer\"\n",
371
+ "\n",
372
+ " # If the wav2vec encoder is unfrozen, we create the optimizer\n",
373
+ " if not self.hparams.wav2vec2.freeze:\n",
374
+ " self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(\n",
375
+ " self.modules.wav2vec2.parameters()\n",
376
+ " )\n",
377
+ " if self.checkpointer is not None:\n",
378
+ " self.checkpointer.add_recoverable(\n",
379
+ " \"wav2vec_opt\", self.wav2vec_optimizer\n",
380
+ " )\n",
381
+ "\n",
382
+ " self.model_optimizer = self.hparams.model_opt_class(\n",
383
+ " self.hparams.model.parameters()\n",
384
+ " )\n",
385
+ "\n",
386
+ " if self.checkpointer is not None:\n",
387
+ " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n",
388
+ "\n",
389
+ " def zero_grad(self, set_to_none=False):\n",
390
+ " if not self.hparams.wav2vec2.freeze:\n",
391
+ " self.wav2vec_optimizer.zero_grad(set_to_none)\n",
392
+ " self.model_optimizer.zero_grad(set_to_none)\n",
393
+ "\n",
394
+ "\n",
395
+ "from speechbrain.pretrained import EncoderASR,EncoderDecoderASR\n",
396
+ "french_asr_model = EncoderASR.from_hparams(source=\"asr-wav2vec2-commonvoice-fr\", savedir=\"pretrained_models/asr-wav2vec2-commonvoice-fr\").cuda()\n",
397
+ "french_asr_model.to(\"cpu\")\n",
398
+ "cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments([\"EnglishCV/train_en_with_wav2vec.yaml\"])\n",
399
+ "with open(cvhparams_file) as cvfin:\n",
400
+ " cvhparams = load_hyperpyyaml(cvfin, cvoverrides)\n",
401
+ "english_asr_model = ASRCV(\n",
402
+ " modules=cvhparams[\"modules\"],\n",
403
+ " hparams=cvhparams,\n",
404
+ " run_opts=cvrun_opts,\n",
405
+ " checkpointer=cvhparams[\"checkpointer\"],\n",
406
+ " )\n",
407
+ "english_asr_model.modules.to(\"cpu\")\n",
408
+ "english_asr_model.checkpointer.recover_if_possible()\n",
409
+ "print(\"moving to tunisian model\")\n",
410
+ "asr_brain = ASR(\n",
411
+ " modules=hparams[\"modules\"],\n",
412
+ " hparams=hparams,\n",
413
+ " run_opts=run_opts,\n",
414
+ " checkpointer=hparams[\"checkpointer\"],\n",
415
+ ")\n",
416
+ "asr_brain.modules.to(\"cpu\")\n",
417
+ "asr_brain.checkpointer.recover_if_possible()\n",
418
+ "asr_brain.modules.eval()\n",
419
+ "english_asr_model.modules.eval()\n",
420
+ "french_asr_model.mods.eval()\n",
421
+ "asr_brain.modules.to(\"cpu\")\n",
422
+ "\n",
423
+ "# Commented out IPython magic to ensure Python compatibility.\n",
424
+ "# %ls\n",
425
+ "\n",
426
+ "#UTILS FUNCTIOJNS\n",
427
+ "def get_size_dimensions(arr):\n",
428
+ " size_dimensions = []\n",
429
+ " while isinstance(arr, list):\n",
430
+ " size_dimensions.append(len(arr))\n",
431
+ " arr = arr[0]\n",
432
+ " return size_dimensions\n",
433
+ "\n",
434
+ "def scale_array(batch,n):\n",
435
+ " scaled_batch = []\n",
436
+ "\n",
437
+ " for array in batch:\n",
438
+ " if(n < len(array)): raise ValueError(\"Cannot scale Array down\")\n",
439
+ "\n",
440
+ " repeat = round(n/len(array))+1\n",
441
+ " scaled_length_array= []\n",
442
+ "\n",
443
+ " for i in array:\n",
444
+ " for j in range(repeat) :\n",
445
+ " if(len(scaled_length_array) == n): break\n",
446
+ " scaled_length_array.append(i)\n",
447
+ "\n",
448
+ " scaled_batch.append(scaled_length_array)\n",
449
+ "\n",
450
+ " return torch.tensor(scaled_batch)\n",
451
+ "\n",
452
+ "\n",
453
+ "def load_paths(wavs_path):\n",
454
+ " waveforms = []\n",
455
+ " for path in wavs_path :\n",
456
+ " waveform, _ = torchaudio.load(path)\n",
457
+ " waveforms.append(waveform.squeeze(0))\n",
458
+ " # normalize array length to the bigger arrays by pading with 0's\n",
459
+ " padded_arrays = pad_sequence(waveforms, batch_first=True)\n",
460
+ " return torch.tensor(padded_arrays)\n",
461
+ "\n",
462
+ "\n",
463
+ "\n",
464
+ "device = 'cuda'\n",
465
+ "verbose = 0\n",
466
+ "#FLOW LEVEL FUNCTIONS\n",
467
+ "def merge_strategy(embeddings1, embeddings2, embeddings3,post1, post2,post3):\n",
468
+ "\n",
469
+ "\n",
470
+ " post1 = post1.to(device)\n",
471
+ " post2 = post2.to(device)\n",
472
+ " post3 = post3.to(device)\n",
473
+ " embeddings1 = embeddings1.to(device)\n",
474
+ " embeddings2 = embeddings2.to(device)\n",
475
+ " embeddings3 = embeddings3.to(device)\n",
476
+ "\n",
477
+ " posteriograms_merged = torch.cat((post1,post2,post3),dim=2)\n",
478
+ " embeddings_merged = torch.cat((embeddings1,embeddings2,embeddings3),dim=2)\n",
479
+ "\n",
480
+ " if(verbose !=0):\n",
481
+ " print('MERGED POST ',posteriograms_merged.shape)\n",
482
+ " print('MERGED emb ',embeddings_merged.shape)\n",
483
+ "\n",
484
+ " return torch.cat((posteriograms_merged,embeddings_merged),dim=2).to(device)\n",
485
+ "\n",
486
+ "def decode(model,wavs,wav_lens):\n",
487
+ "\n",
488
+ " with torch.no_grad():\n",
489
+ " wav_lens = wav_lens.to(model.device)\n",
490
+ " encoder_out = model.encode_batch(wavs, wav_lens)\n",
491
+ " predictions = model.decoding_function(encoder_out, wav_lens)\n",
492
+ " return predictions\n",
493
+ "\n",
494
+ "def middle_layer(batch, lens):\n",
495
+ "\n",
496
+ " tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch,None)\n",
497
+ "\n",
498
+ " fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)\n",
499
+ " fr_posteriogram =french_asr_model.encode_batch(batch,lens)\n",
500
+ " en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)\n",
501
+ " x = english_asr_model.modules.enc(en_embeddings)\n",
502
+ " en_posteriogram = english_asr_model.modules.ctc_lin(x)\n",
503
+ " #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings ,lens)\n",
504
+ " if(verbose !=0):\n",
505
+ " print('[EMBEDDINGS] FR:',fr_embeddings.shape, \"EN:\",en_embeddings.shape, \"TN:\", tn_embeddings.shape)\n",
506
+ " print('[POSTERIOGRAM] FR:',fr_posteriogram.shape, \"EN:\",en_posteriogram.shape,\"TN:\",tn_posteriogram.shape)\n",
507
+ "\n",
508
+ "\n",
509
+ " bilangual_sample = merge_strategy(fr_embeddings,en_embeddings,tn_embeddings,fr_posteriogram,en_posteriogram,tn_posteriogram)\n",
510
+ " return bilangual_sample\n",
511
+ "\n",
512
+ "class Mixer(sb.core.Brain):\n",
513
+ "\n",
514
+ " def compute_forward(self, batch, stage):\n",
515
+ " \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
516
+ " wavs, wav_lens = batch.sig\n",
517
+ " wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
518
+ "\n",
519
+ " if stage == sb.Stage.TRAIN:\n",
520
+ " if hasattr(self.hparams, \"augmentation\"):\n",
521
+ " wavs = self.hparams.augmentation(wavs, wav_lens)\n",
522
+ "\n",
523
+ " multi_langual_feats = middle_layer(wavs, wav_lens)\n",
524
+ " multi_langual_feats= multi_langual_feats.to(device)\n",
525
+ " feats, _ = self.modules.enc(multi_langual_feats)\n",
526
+ " logits = self.modules.ctc_lin(feats)\n",
527
+ " p_ctc = self.hparams.log_softmax(logits)\n",
528
+ " \n",
529
+ " if stage!= sb.Stage.TRAIN:\n",
530
+ " p_tokens = sb.decoders.ctc_greedy_decode(\n",
531
+ " p_ctc, wav_lens, blank_id=self.hparams.blank_index\n",
532
+ " )\n",
533
+ " else : \n",
534
+ " p_tokens = None\n",
535
+ " return p_ctc, wav_lens, p_tokens\n",
536
+ " \n",
537
+ " \n",
538
+ " def treat_wav(self,sig):\n",
539
+ " multi_langual_feats = middle_layer(sig.to(\"cpu\"), torch.tensor([1]).to(\"cpu\"))\n",
540
+ " multi_langual_feats= multi_langual_feats.to(device)\n",
541
+ " feats, _ = self.modules.enc(multi_langual_feats)\n",
542
+ " logits = self.modules.ctc_lin(feats)\n",
543
+ " p_ctc = self.hparams.log_softmax(logits)\n",
544
+ " predicted_words =[]\n",
545
+ " for logs in p_ctc:\n",
546
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
547
+ " predicted_words.append(text.split(\" \"))\n",
548
+ " return \" \".join(predicted_words[0])\n",
549
+ " \n",
550
+ "\n",
551
+ " def compute_objectives(self, predictions, batch, stage):\n",
552
+ " \"\"\"Computes the loss (CTC) given predictions and targets.\"\"\"\n",
553
+ "\n",
554
+ " p_ctc, wav_lens , predicted_tokens= predictions\n",
555
+ "\n",
556
+ " ids = batch.id\n",
557
+ " tokens, tokens_lens = batch.tokens\n",
558
+ "\n",
559
+ " loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
560
+ "\n",
561
+ "\n",
562
+ " if stage == sb.Stage.VALID:\n",
563
+ " predicted_words = [\n",
564
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
565
+ " for utt_seq in predicted_tokens\n",
566
+ " ]\n",
567
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
568
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
569
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
570
+ " if stage ==sb.Stage.TEST : \n",
571
+ " if self.hparams.language_modelling:\n",
572
+ " predicted_words = []\n",
573
+ " for logs in p_ctc:\n",
574
+ " text = decoder.decode(logs.detach().cpu().numpy())\n",
575
+ " predicted_words.append(text.split(\" \"))\n",
576
+ " else : \n",
577
+ " predicted_words = [\n",
578
+ " \"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \")\n",
579
+ " for utt_seq in predicted_tokens\n",
580
+ " ]\n",
581
+ "\n",
582
+ " target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
583
+ " self.wer_metric.append(ids, predicted_words, target_words)\n",
584
+ " self.cer_metric.append(ids, predicted_words, target_words)\n",
585
+ "\n",
586
+ " return loss\n",
587
+ "\n",
588
+ " def fit_batch(self, batch):\n",
589
+ " \"\"\"Train the parameters given a single batch in input\"\"\"\n",
590
+ " should_step = self.step % self.grad_accumulation_factor == 0\n",
591
+ " # Managing automatic mixed precision\n",
592
+ " # TOFIX: CTC fine-tuning currently is unstable\n",
593
+ " # This is certainly due to CTC being done in fp16 instead of fp32\n",
594
+ " if self.auto_mix_prec:\n",
595
+ " with torch.cuda.amp.autocast():\n",
596
+ " with self.no_sync():\n",
597
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
598
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
599
+ " with self.no_sync(not should_step):\n",
600
+ " self.scaler.scale(\n",
601
+ " loss / self.grad_accumulation_factor\n",
602
+ " ).backward()\n",
603
+ " if should_step:\n",
604
+ "\n",
605
+ "\n",
606
+ " self.scaler.unscale_(self.model_optimizer)\n",
607
+ " if self.check_gradients(loss):\n",
608
+ " self.scaler.step(self.model_optimizer)\n",
609
+ " self.scaler.update()\n",
610
+ " self.zero_grad()\n",
611
+ " self.optimizer_step += 1\n",
612
+ " else:\n",
613
+ " # This is mandatory because HF models have a weird behavior with DDP\n",
614
+ " # on the forward pass\n",
615
+ " with self.no_sync():\n",
616
+ " outputs = self.compute_forward(batch, sb.Stage.TRAIN)\n",
617
+ "\n",
618
+ " loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)\n",
619
+ "\n",
620
+ " with self.no_sync(not should_step):\n",
621
+ " (loss / self.grad_accumulation_factor).backward()\n",
622
+ " if should_step:\n",
623
+ " if self.check_gradients(loss):\n",
624
+ " self.model_optimizer.step()\n",
625
+ " self.zero_grad()\n",
626
+ " self.optimizer_step += 1\n",
627
+ "\n",
628
+ " self.on_fit_batch_end(batch, outputs, loss, should_step)\n",
629
+ " return loss.detach().cpu()\n",
630
+ "\n",
631
+ " def evaluate_batch(self, batch, stage):\n",
632
+ " \"\"\"Computations needed for validation/test batches\"\"\"\n",
633
+ " predictions = self.compute_forward(batch, stage=stage)\n",
634
+ " with torch.no_grad():\n",
635
+ " loss = self.compute_objectives(predictions, batch, stage=stage)\n",
636
+ " return loss.detach()\n",
637
+ "\n",
638
+ " def on_stage_start(self, stage, epoch):\n",
639
+ " \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
640
+ " if stage != sb.Stage.TRAIN:\n",
641
+ " self.cer_metric = self.hparams.cer_computer()\n",
642
+ " self.wer_metric = self.hparams.error_rate_computer()\n",
643
+ "\n",
644
+ " def on_stage_end(self, stage, stage_loss, epoch):\n",
645
+ " \"\"\"Gets called at the end of an epoch.\"\"\"\n",
646
+ " # Compute/store important stats\n",
647
+ " stage_stats = {\"loss\": stage_loss}\n",
648
+ " if stage == sb.Stage.TRAIN:\n",
649
+ " self.train_stats = stage_stats\n",
650
+ " else:\n",
651
+ " stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
652
+ " stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
653
+ "\n",
654
+ " # Perform end-of-iteration things, like annealing, logging, etc.\n",
655
+ " if stage == sb.Stage.VALID:\n",
656
+ " old_lr_model, new_lr_model = self.hparams.lr_annealing_model(\n",
657
+ " stage_stats[\"loss\"]\n",
658
+ " )\n",
659
+ " sb.nnet.schedulers.update_learning_rate(\n",
660
+ " self.model_optimizer, new_lr_model\n",
661
+ " )\n",
662
+ " self.hparams.train_logger.log_stats(\n",
663
+ " stats_meta={\n",
664
+ " \"epoch\": epoch,\n",
665
+ " \"lr_model\": old_lr_model,\n",
666
+ " },\n",
667
+ " train_stats=self.train_stats,\n",
668
+ " valid_stats=stage_stats,\n",
669
+ " )\n",
670
+ " self.checkpointer.save_and_keep_only(\n",
671
+ " meta={\"WER\": stage_stats[\"WER\"]}, min_keys=[\"WER\"],\n",
672
+ " )\n",
673
+ " elif stage == sb.Stage.TEST:\n",
674
+ " self.hparams.train_logger.log_stats(\n",
675
+ " stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},\n",
676
+ " test_stats=stage_stats,\n",
677
+ " )\n",
678
+ " with open(self.hparams.wer_file, \"w\") as w:\n",
679
+ " self.wer_metric.write_stats(w)\n",
680
+ "\n",
681
+ " def init_optimizers(self):\n",
682
+ "\n",
683
+ " self.model_optimizer = self.hparams.model_opt_class(\n",
684
+ " self.hparams.model.parameters()\n",
685
+ " )\n",
686
+ "\n",
687
+ " if self.checkpointer is not None:\n",
688
+ " self.checkpointer.add_recoverable(\"modelopt\", self.model_optimizer)\n",
689
+ "\n",
690
+ " def zero_grad(self, set_to_none=False):\n",
691
+ "\n",
692
+ " self.model_optimizer.zero_grad(set_to_none)\n",
693
+ "\n",
694
+ "\n"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": null,
700
+ "metadata": {},
701
+ "outputs": [
702
+ {
703
+ "name": "stdout",
704
+ "output_type": "stream",
705
+ "text": [
706
+ "speechbrain.utils.distributed - distributed_launch flag is disabled, this experiment will be executed without DDP.\n",
707
+ "speechbrain.core - Beginning experiment!\n",
708
+ "speechbrain.core - Experiment folder: results/non_semi_final_stac\n",
709
+ "speechbrain.dataio.encoder - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.\n",
710
+ "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n",
711
+ "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n",
712
+ "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n",
713
+ "speechbrain.core - Info: auto_mix_prec arg from hparam file is used\n",
714
+ "speechbrain.core - 60.1M trainable parameters in Mixer\n",
715
+ "speechbrain.utils.checkpoints - Loading a checkpoint from results/non_semi_final_stac/save/CKPT+2023-09-08+01-40-18+00\n",
716
+ "pyctcdecode.decoder - Using arpa instead of binary LM file, decoder instantiation might be slow.\n",
717
+ "pyctcdecode.alphabet - Alphabet determined to be of regular style.\n",
718
+ "pyctcdecode.alphabet - Unigrams and labels don't seem to agree.\n",
719
+ "speechbrain.utils.checkpoints - Loading a checkpoint from TunisianASR/results/14epoch_tunisian/1234/save/CKPT+2023-08-03+01-38-38+00\n"
720
+ ]
721
+ },
722
+ {
723
+ "name": "stderr",
724
+ "output_type": "stream",
725
+ "text": [
726
+ "<ipython-input-32-948c3f4b1130>:120: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
727
+ " inputs=[gr.Audio(source=\"microphone\", type='filepath', label = \"record\", optional = True),\n",
728
+ "<ipython-input-32-948c3f4b1130>:121: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
729
+ " gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)]\n"
730
+ ]
731
+ },
732
+ {
733
+ "name": "stdout",
734
+ "output_type": "stream",
735
+ "text": [
736
+ "Running on local URL: http://127.0.0.1:7860\n",
737
+ "\n",
738
+ "To create a public link, set `share=True` in `launch()`.\n"
739
+ ]
740
+ },
741
+ {
742
+ "data": {
743
+ "text/html": [
744
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
745
+ ],
746
+ "text/plain": [
747
+ "<IPython.core.display.HTML object>"
748
+ ]
749
+ },
750
+ "metadata": {},
751
+ "output_type": "display_data"
752
+ },
753
+ {
754
+ "name": "stderr",
755
+ "output_type": "stream",
756
+ "text": [
757
+ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n",
758
+ " warnings.warn(warning.format(data.dtype))\n",
759
+ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n",
760
+ " warnings.warn(warning.format(data.dtype))\n",
761
+ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n",
762
+ " warnings.warn(warning.format(data.dtype))\n",
763
+ "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/gradio/processing_utils.py:188: UserWarning: Trying to convert audio automatically from int32 to 16-bit int format.\n",
764
+ " warnings.warn(warning.format(data.dtype))\n"
765
+ ]
766
+ }
767
+ ],
768
+ "source": [
769
+ "hparams_file, run_opts, overrides = sb.parse_arguments([\"cs.yaml\"])\n",
770
+ "\n",
771
+ "# If distributed_launch=True then\n",
772
+ "# create ddp_group with the right communication protocol\n",
773
+ "sb.utils.distributed.ddp_init_group(run_opts)\n",
774
+ "\n",
775
+ "with open(hparams_file) as fin:\n",
776
+ " hparams = load_hyperpyyaml(fin, overrides)\n",
777
+ "\n",
778
+ "# Create experiment directory\n",
779
+ "sb.create_experiment_directory(\n",
780
+ " experiment_directory=hparams[\"output_folder\"],\n",
781
+ " hyperparams_to_save=hparams_file,\n",
782
+ " overrides=overrides,\n",
783
+ ")\n",
784
+ "def read_labels_file(labels_file):\n",
785
+ " with open(labels_file, \"r\",encoding=\"utf-8\") as lf:\n",
786
+ " lines = lf.read().splitlines()\n",
787
+ " division = \"===\"\n",
788
+ " numbers = {}\n",
789
+ " for line in lines :\n",
790
+ " if division in line :\n",
791
+ " break\n",
792
+ " string, number = line.split(\"=>\")\n",
793
+ " number = int(number)\n",
794
+ " string = string[1:-2]\n",
795
+ " numbers[number] = string\n",
796
+ " return [numbers[x] for x in range(len(numbers))]\n",
797
+ "\n",
798
+ "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
799
+ "\n",
800
+ "lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
801
+ "special_labels = {\n",
802
+ " \"blank_label\": hparams[\"blank_index\"],\n",
803
+ " \"unk_label\": hparams[\"unk_index\"]\n",
804
+ "}\n",
805
+ "label_encoder.load_or_create(\n",
806
+ " path=lab_enc_file,\n",
807
+ " from_didatasets=[[]],\n",
808
+ " output_key=\"char_list\",\n",
809
+ " special_labels=special_labels,\n",
810
+ " sequence_input=True,\n",
811
+ ")\n",
812
+ "\n",
813
+ "\n",
814
+ "labels = read_labels_file(os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\"))\n",
815
+ "labels = [\"\"] + labels[1:-1] + [\"1\"] \n",
816
+ "if hparams[\"language_modelling\"]:\n",
817
+ " decoder = build_ctcdecoder(\n",
818
+ " labels,\n",
819
+ " kenlm_model_path=hparams[\"ngram_lm_path\"], # either .arpa or .bin file\n",
820
+ " alpha=0.5, # tuned on a val set\n",
821
+ " beta=1, # tuned on a val set\n",
822
+ " )\n",
823
+ "\n",
824
+ "\n",
825
+ "\n",
826
+ "\n",
827
+ "mixer = Mixer(\n",
828
+ " modules=hparams[\"modules\"],\n",
829
+ " hparams=hparams,\n",
830
+ " run_opts=run_opts,\n",
831
+ " checkpointer=hparams[\"checkpointer\"],\n",
832
+ ")\n",
833
+ "mixer.tokenizer = label_encoder\n",
834
+ "mixer.checkpointer.recover_if_possible()\n",
835
+ "mixer.modules.eval()\n",
836
+ "\n",
837
+ "\n",
838
+ "label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
839
+ "\n",
840
+ "\n",
841
+ "# We dynamicaly add the tokenizer to our brain class.\n",
842
+ "# NB: This tokenizer corresponds to the one used for the LM!!\n",
843
+ "\n",
844
+ "decoder = build_ctcdecoder(\n",
845
+ " labels,\n",
846
+ " kenlm_model_path= \"arpas/everything.arpa\", # either .arpa or .bin file\n",
847
+ " alpha=0.5, # tuned on a val set\n",
848
+ " beta=1, # tuned on a val set\n",
849
+ ")\n",
850
+ "\n",
851
+ "run_opts[\"device\"]=\"cpu\"\n",
852
+ "\n",
853
+ "\n",
854
+ "device = \"cpu\"\n",
855
+ "mixer.device= \"cpu\"\n",
856
+ "mixer.modules.to(\"cpu\")\n",
857
+ "\n",
858
+ "from enum import Enum, auto\n",
859
+ "class Stage(Enum):\n",
860
+ " TRAIN = auto()\n",
861
+ " VALID = auto()\n",
862
+ " TEST = auto()\n",
863
+ "\n",
864
+ "asr_brain.on_evaluate_start()\n",
865
+ "asr_brain.modules.eval()\n",
866
+ "\n",
867
+ "\n",
868
+ "import gradio as gr\n",
869
+ "\n",
870
+ "def treat_wav_file(file_mic,file_upload ,asr=mixer, device=\"cpu\") :\n",
871
+ " if (file_mic is not None) and (file_upload is not None):\n",
872
+ " warn_output = \"WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\\n\"\n",
873
+ " wav = file_mic\n",
874
+ " elif (file_mic is None) and (file_upload is None):\n",
875
+ " return \"ERROR: You have to either use the microphone or upload an audio file\"\n",
876
+ " elif file_mic is not None:\n",
877
+ " wav = file_mic\n",
878
+ " else:\n",
879
+ " wav = file_upload\n",
880
+ " sig, sr = torchaudio.load(wav)\n",
881
+ " tensor_wav = sig.to(device)\n",
882
+ " resampled = torchaudio.functional.resample( tensor_wav, sr, 16000)\n",
883
+ " sentence = asr.treat_wav(resampled)\n",
884
+ " return sentence\n",
885
+ "\n",
886
+ "gr.Interface(\n",
887
+ " fn=treat_wav_file, \n",
888
+ " inputs=[gr.Audio(source=\"microphone\", type='filepath', label = \"record\", optional = True), \n",
889
+ " gr.Audio(source=\"upload\", type='filepath', label=\"filein\", optional=True)]\n",
890
+ " ,outputs=\"text\").launch(share= False, debug = True)\n"
891
+ ]
892
+ }
893
+ ],
894
+ "metadata": {
895
+ "kernelspec": {
896
+ "display_name": "Python 3",
897
+ "language": "python",
898
+ "name": "python3"
899
+ },
900
+ "language_info": {
901
+ "codemirror_mode": {
902
+ "name": "ipython",
903
+ "version": 3
904
+ },
905
+ "file_extension": ".py",
906
+ "mimetype": "text/x-python",
907
+ "name": "python",
908
+ "nbconvert_exporter": "python",
909
+ "pygments_lexer": "ipython3",
910
+ "version": "3.8.5"
911
+ }
912
+ },
913
+ "nbformat": 4,
914
+ "nbformat_minor": 5
915
+ }
wav2vec2-FR-7K-large ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 7fa1111246fca6eb198d1caab50fd2e4469bf659
wav2vec2-large-lv60 ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 0cde644b64dac88d8416bec1c92a4099b850ba0b
wavlm-large ADDED
@@ -0,0 +1 @@
1
+ Subproject commit c1423ed94bb01d80a3f5ce5bc39f6026a0f4828c