SalahZa committed on
Commit
8b664ce
1 Parent(s): c445670

first commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.arpa filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,21 @@
- ---
- license: apache-2.0
- ---
+ # Tunisian Arabic ASR Model with wav2vec2
+
+ This repository provides all the necessary tools to perform automatic speech recognition with an end-to-end system pretrained on the Tunisian Arabic dialect.
+
+ ## Performance
+ The performance of the model is:
+
+ | Release | Decoding   | WER (%) | CER (%) |
+ |---------|------------|---------|---------|
+ | v1.0    | Without LM | 11.82   | 6.33    |
+ ## Dataset
+ This ASR model was trained on:
+ * TARIC: the Tunisian Arabic Railway Interaction Corpus, a collection of audio recordings and transcriptions of dialogues in the Tunisian Railway Transport Network - [TARIC Corpus](https://aclanthology.org/L14-1385/)
+ * STAC: a corpus of spoken Tunisian Arabic - [STAC Corpus](https://www.researchgate.net/publication/307583782_Spoken_Tunisian_Arabic_Corpus_STAC_Transcription_and_Annotation)
+ * IWSLT: Tunisian conversational speech - [IWSLT Corpus](https://iwslt.org/2022/dialect)
+ * Tunspeech: our custom dataset
+
+ ## Install
+ ```bash
+ pip install speechbrain transformers
+ ```
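The training recipe below (`train_with_wavlm.py`) rescores the CTC output with the bundled `outdomain.arpa` n-gram language model through `pyctcdecode`. The following is a minimal sketch of that decoding step only, not part of the original recipe: it additionally assumes `pip install pyctcdecode kenlm`, a truncated label list, and a random matrix standing in for the acoustic model's log-probabilities.

```python
# Minimal sketch of LM-fused CTC decoding, mirroring the decoder built in
# train_with_wavlm.py. Labels and scores are placeholders; the full 40-symbol
# inventory lives in semi_wavlm_large_tunisian_ctc/1234/save/label_encoder.txt.
import numpy as np
from pyctcdecode import build_ctcdecoder

labels = ["", "ا", "ه", "ي", "و", "ن", "أ", " ", "م", "1"]  # index 0 = CTC blank; truncated

decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="outdomain.arpa",  # the LFS-tracked n-gram LM in this repo
    alpha=0.5,  # LM weight, same value as the training script
    beta=1.0,   # word-insertion bonus, same value as the training script
)

# (time, vocab) CTC log-probabilities; in the recipe this is p_ctc for one utterance
log_probs = np.log(np.random.dirichlet(np.ones(len(labels)), size=100)).astype(np.float32)
print(decoder.decode(log_probs))
```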
outdomain.arpa ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24654c1d236bb1bd367125131c847c4a734e69914eda71a6786964c20440d8fe
+ size 324243244
semi_wavlm_large_tunisian_ctc/1234/hyperparams.yaml ADDED
@@ -0,0 +1,194 @@
1
+ # Generated 2023-09-08 from:
2
+ # /gpfsdsstore/projects/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: results/semi_wavlm_large_tunisian_ctc/1234
14
+ wer_file: results/semi_wavlm_large_tunisian_ctc/1234/wer.txt
15
+ save_folder: results/semi_wavlm_large_tunisian_ctc/1234/save
16
+ train_log: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
17
+
18
+ # Folder where the pretrained wav2vec2/WavLM checkpoint is stored.
19
+ wav2vec2_folder: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
20
+
21
+ # Data files
22
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
23
+ train_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/train.tsv # Standard CommonVoice .tsv files
24
+ dev_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/dev.tsv # Standard CommonVoice .tsv files
25
+ test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
26
+ accented_letters: true
27
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
28
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train_enhanced.csv
29
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
30
+ test_csv:
31
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
32
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/iwslt_test.csv
33
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/taric_test.csv
34
+
35
+ skip_prep: true # Skip data preparation
36
+
37
+ use_language_modelling: true
38
+ ngram_lm_path: arpas/outdomain.arpa
39
+
40
+ # We remove utterances longer than 10s in the train/dev/test sets, as
41
+ # longer sentences certainly correspond to "open microphones".
42
+ avoid_if_longer_than: 10.0
43
+ avoid_if_shorter_than: 1.2
44
+
45
+
46
+ # Training parameters
47
+ number_of_epochs: 12
48
+ lr: 1.0
49
+ lr_wav2vec: 0.0001
50
+ sorting: ascending
51
+ auto_mix_prec: false
52
+ sample_rate: 16000
53
+ ckpt_interval_minutes: 30 # save checkpoint every N min
54
+
55
+ # With data_parallel batch_size is split into N jobs
56
+ # With DDP batch_size is multiplied by N jobs
57
+ # Must be 6 per GPU to fit 16GB of VRAM
58
+ batch_size: 10
59
+ test_batch_size: 4
60
+
61
+ dataloader_options:
62
+ batch_size: 10
63
+ num_workers: 6
64
+ test_dataloader_options:
65
+ batch_size: 4
66
+ num_workers: 6
67
+
68
+ # BPE parameters
69
+ token_type: char # ["unigram", "bpe", "char"]
70
+ character_coverage: 1.0
71
+
72
+ # Model parameters
73
+ # activation: !name:torch.nn.LeakyReLU
74
+ wav2vec_output_dim: 1024
75
+ dnn_neurons: 1024
76
+ freeze_wav2vec: false
77
+ freeze_feature_extractor: true
78
+ dropout: 0.15
79
+ warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
80
+
81
+ # Outputs
82
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
83
+
84
+ # Decoding parameters
85
+ # Be sure that the bos and eos index match with the BPEs ones
86
+ blank_index: 0
87
+ unk_index: 1
88
+
89
+ #
90
+ # Functions and classes
91
+ #
92
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
93
+
94
+ limit: 12
95
+
96
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
97
+ sample_rate: 16000
98
+ speeds: [95, 100, 105]
99
+
100
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
101
+ input_shape: [null, null, 1024]
102
+ linear1: !name:speechbrain.nnet.linear.Linear
103
+ n_neurons: 1024
104
+ bias: true
105
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
106
+ activation: !new:torch.nn.LeakyReLU
107
+ drop: !new:torch.nn.Dropout
108
+ p: 0.15
109
+ linear2: !name:speechbrain.nnet.linear.Linear
110
+ n_neurons: 1024
111
+ bias: true
112
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
113
+ activation2: !new:torch.nn.LeakyReLU
114
+ drop2: !new:torch.nn.Dropout
115
+ p: 0.15
116
+ linear3: !name:speechbrain.nnet.linear.Linear
117
+ n_neurons: 1024
118
+ bias: true
119
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
120
+ activation3: !new:torch.nn.LeakyReLU
121
+
122
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
123
+ source: /gpfsstore/rech/nou/uzn19yk/wavlm/
124
+ output_norm: false
125
+ freeze: false
126
+ freeze_feature_extractor: true
127
+ save_path: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
128
+
129
+ #####
130
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
131
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
132
+ # Fairseq github for the multilingual XLSR.
133
+ #
134
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
135
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
136
+ # pretrained_path: !ref <wav2vec2_url>
137
+ # output_norm: True
138
+ # freeze: False
139
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
140
+ #####
141
+
142
+
143
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
144
+
145
+ input_size: 1024
146
+ n_neurons: 40
147
+
148
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
149
+ apply_log: true
150
+
151
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
152
+ blank_index: 0
153
+
154
+ modules:
155
+ wav2vec2: *id001
156
+ enc: *id002
157
+ ctc_lin: *id003
158
+ model: &id004 !new:torch.nn.ModuleList
159
+ - [*id002, *id003]
160
+ model_opt_class: !name:torch.optim.Adadelta
161
+ lr: 1.0
162
+ rho: 0.95
163
+ eps: 1.e-8
164
+
165
+ wav2vec_opt_class: !name:torch.optim.Adam
166
+ lr: 0.0001
167
+
168
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
169
+ initial_value: 1.0
170
+ improvement_threshold: 0.0025
171
+ annealing_factor: 0.8
172
+ patient: 0
173
+
174
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
175
+ initial_value: 0.0001
176
+ improvement_threshold: 0.0025
177
+ annealing_factor: 0.9
178
+ patient: 0
179
+
180
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
+ checkpoints_dir: results/semi_wavlm_large_tunisian_ctc/1234/save
182
+ recoverables:
183
+ wav2vec2: *id001
184
+ model: *id004
185
+ scheduler_model: *id005
186
+ scheduler_wav2vec: *id006
187
+ counter: *id007
188
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
189
+ save_file: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
190
+
191
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
192
+
193
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
194
+ split_tokens: true
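The `augmentation` entry above instantiates SpeechBrain's `TimeDomainSpecAugment`, which applies speed perturbation at 95/100/105 % (plus its default waveform-dropping options) to the training batches. A small, self-contained sketch of that module in isolation, using dummy audio and the SpeechBrain 0.5-era API this recipe relies on:

```python
# Sketch: run the augmentation module defined in the hyperparams on dummy audio.
# Shapes and values are placeholders; in the recipe this is applied to batch.sig.
import torch
from speechbrain.lobes.augment import TimeDomainSpecAugment

augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])

wavs = torch.randn(4, 16000)  # batch of 4 one-second waveforms at 16 kHz
lens = torch.ones(4)          # relative lengths (all full length)
augmented = augment(wavs, lens)
print(augmented.shape)        # time dimension may change with speed perturbation
```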
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
+ # yamllint disable
+ WER: 27.83210816487267
+ end-of-epoch: true
+ unixtime: 1693868963.5220973
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3947a24e8dff5a14299b9cf2fe66ffb4d738cb88717de7f0cf7e8547a76e9776
+ size 51
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918
+ size 2
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b363886c229e536bd3c84e0c3e89312d70e00422578e076a62df1b45c9390793
+ size 5
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc1dbeca1e1f1340b08d8ebea6e492f474708dddbbe8cabbcdde5ee9660704f2
+ size 12814446
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3af1791eb9a5bfbfc087d2c10b94634df24cad3ac503ce9ba280a3ecc4737781
+ size 25575663
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c275ab9245b440d1586f72058d9edaac1a2fb3e7a52712aa9a9ad022b99a1c0d
+ size 639
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a88187f7882dc3e10c108f1b7abfbd819285b34bded4e88e91c4ff699c1bb5d2
+ size 643
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:788267bd25ef37623715fa21a975090e5e316fff05971375cd3f62e5160f0743
+ size 1262005979
semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:efa967fdd8067be7d88c18cd197980c9c91f344a3dff2b2518b8381c49f28b1e
+ size 2490361859
semi_wavlm_large_tunisian_ctc/1234/save/label_encoder.txt ADDED
@@ -0,0 +1,44 @@
+ 'ب' => 38
+ 'ا' => 1
+ 'ه' => 2
+ 'ي' => 3
+ 'و' => 4
+ 'ن' => 5
+ 'أ' => 6
+ ' ' => 7
+ 'م' => 8
+ 'ش' => 9
+ 'ل' => 10
+ 'س' => 11
+ 'ت' => 12
+ 'د' => 13
+ 'ر' => 14
+ 'ى' => 15
+ 'ح' => 16
+ 'ط' => 17
+ 'ع' => 18
+ 'ك' => 19
+ 'ف' => 20
+ 'ق' => 21
+ 'آ' => 22
+ 'ة' => 23
+ 'ج' => 24
+ 'ض' => 25
+ 'ز' => 26
+ 'ص' => 27
+ 'إ' => 28
+ 'ث' => 29
+ 'خ' => 30
+ 'ڨ' => 31
+ 'ذ' => 32
+ 'ظ' => 33
+ 'ء' => 34
+ 'غ' => 35
+ 'ئ' => 36
+ 'ؤ' => 37
+ '<blank>' => 0
+ 1 => 39
+ ================
+ 'starting_index' => 0
+ 'unk_label' => 1
+ 'blank_label' => '<blank>'
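The file above is the `CTCTextEncoder` state saved by `dataio_prepare` in `train_with_wavlm.py`: one character per line with its CTC index, followed by the special-label metadata. A minimal sketch (assuming a local clone so the path exists) of reloading it to map Tunisian Arabic text to these indices:

```python
# Sketch: reload the saved character map and round-trip a short string.
from speechbrain.dataio.encoder import CTCTextEncoder

encoder = CTCTextEncoder()
encoder.load("semi_wavlm_large_tunisian_ctc/1234/save/label_encoder.txt")

tokens = encoder.encode_sequence(list("باهي"))  # per the table, e.g. 'ب' -> 38
print(tokens)
print("".join(encoder.decode_ndim(tokens)))     # back to the original characters
```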
train_semi.yaml ADDED
@@ -0,0 +1,175 @@
1
+ # ################################
2
+ # Model: wav2vec2 + DNN + CTC
3
+ # Augmentation: SpecAugment
4
+ # Authors: Titouan Parcollet 2021
5
+ # ################################
6
+
7
+ # Seed needs to be set at top of yaml, before objects with parameters are made
8
+ seed: 1234
9
+ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref semi_wavlm_large_tunisian_ctc/<seed>
11
+ wer_file: !ref <output_folder>/wer.txt
12
+ save_folder: !ref <output_folder>/save
13
+ train_log: !ref <output_folder>/train_log.txt
14
+
15
+ # Folder where the pretrained wav2vec2/WavLM checkpoint is stored.
16
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
17
+
18
+ # Data files
19
+ data_folder: /path/to/data # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
20
+ train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
21
+ dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
22
+ test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
23
+ accented_letters: True
24
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
25
+ test_csv:
26
+ - /path/to/test_data
27
+
28
+ skip_prep: True # Skip data preparation
29
+
30
+ use_language_modelling: True
31
+ ngram_lm_path: outdomain.arpa
32
+
33
+ # We remove utterances longer than 10s in the train/dev/test sets, as
34
+ # longer sentences certainly correspond to "open microphones".
35
+ avoid_if_longer_than: 10.0
36
+ avoid_if_shorter_than: 1.2
37
+
38
+
39
+ # Training parameters
40
+ number_of_epochs: 12
41
+ lr: 1.0
42
+ lr_wav2vec: 0.0001
43
+ sorting: ascending
44
+ auto_mix_prec: False
45
+ sample_rate: 16000
46
+ ckpt_interval_minutes: 30 # save checkpoint every N min
47
+
48
+ # With data_parallel batch_size is split into N jobs
49
+ # With DDP batch_size is multiplied by N jobs
50
+ # Must be 6 per GPU to fit 16GB of VRAM
51
+ batch_size: 10
52
+ test_batch_size: 4
53
+
54
+ dataloader_options:
55
+ batch_size: !ref <batch_size>
56
+ num_workers: 6
57
+ test_dataloader_options:
58
+ batch_size: !ref <test_batch_size>
59
+ num_workers: 6
60
+
61
+ # BPE parameters
62
+ token_type: char # ["unigram", "bpe", "char"]
63
+ character_coverage: 1.0
64
+
65
+ # Model parameters
66
+ # activation: !name:torch.nn.LeakyReLU
67
+ wav2vec_output_dim: 1024
68
+ dnn_neurons: 1024
69
+ freeze_wav2vec: False
70
+ freeze_feature_extractor: True
71
+ dropout: 0.15
72
+ warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
73
+
74
+ # Outputs
75
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
76
+
77
+ # Decoding parameters
78
+ # Be sure that the bos and eos index match with the BPEs ones
79
+ blank_index: 0
80
+ unk_index: 1
81
+
82
+ #
83
+ # Functions and classes
84
+ #
85
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
86
+ limit: !ref <number_of_epochs>
87
+
88
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
89
+ sample_rate: !ref <sample_rate>
90
+ speeds: [95, 100, 105]
91
+
92
+ enc: !new:speechbrain.nnet.containers.Sequential
93
+ input_shape: [null, null, !ref <wav2vec_output_dim>]
94
+ linear1: !name:speechbrain.nnet.linear.Linear
95
+ n_neurons: !ref <dnn_neurons>
96
+ bias: True
97
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
98
+ activation: !new:torch.nn.LeakyReLU
99
+ drop: !new:torch.nn.Dropout
100
+ p: !ref <dropout>
101
+ linear2: !name:speechbrain.nnet.linear.Linear
102
+ n_neurons: !ref <dnn_neurons>
103
+ bias: True
104
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
105
+ activation2: !new:torch.nn.LeakyReLU
106
+ drop2: !new:torch.nn.Dropout
107
+ p: !ref <dropout>
108
+ linear3: !name:speechbrain.nnet.linear.Linear
109
+ n_neurons: !ref <dnn_neurons>
110
+ bias: True
111
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
112
+ activation3: !new:torch.nn.LeakyReLU
113
+
114
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
115
+ source: /gpfsstore/rech/nou/uzn19yk/wavlm/
116
+ output_norm: False
117
+ freeze: !ref <freeze_wav2vec>
118
+ freeze_feature_extractor: !ref <freeze_feature_extractor>
119
+ save_path: !ref <wav2vec2_folder>
120
+
121
+
122
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
123
+ input_size: !ref <dnn_neurons>
124
+ n_neurons: !ref <output_neurons>
125
+
126
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
127
+ apply_log: True
128
+
129
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
130
+ blank_index: !ref <blank_index>
131
+
132
+ modules:
133
+ wav2vec2: !ref <wav2vec2>
134
+ enc: !ref <enc>
135
+ ctc_lin: !ref <ctc_lin>
136
+
137
+ model: !new:torch.nn.ModuleList
138
+ - [!ref <enc>, !ref <ctc_lin>]
139
+
140
+ model_opt_class: !name:torch.optim.Adadelta
141
+ lr: !ref <lr>
142
+ rho: 0.95
143
+ eps: 1.e-8
144
+
145
+ wav2vec_opt_class: !name:torch.optim.Adam
146
+ lr: !ref <lr_wav2vec>
147
+
148
+ lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
149
+ initial_value: !ref <lr>
150
+ improvement_threshold: 0.0025
151
+ annealing_factor: 0.8
152
+ patient: 0
153
+
154
+ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
155
+ initial_value: !ref <lr_wav2vec>
156
+ improvement_threshold: 0.0025
157
+ annealing_factor: 0.9
158
+ patient: 0
159
+
160
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
161
+ checkpoints_dir: !ref <save_folder>
162
+ recoverables:
163
+ wav2vec2: !ref <wav2vec2>
164
+ model: !ref <model>
165
+ scheduler_model: !ref <lr_annealing_model>
166
+ scheduler_wav2vec: !ref <lr_annealing_wav2vec>
167
+ counter: !ref <epoch_counter>
168
+
169
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
170
+ save_file: !ref <train_log>
171
+
172
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
173
+
174
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
175
+ split_tokens: True
train_with_wavlm.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from pathlib import Path
7
+ import os
8
+ import torchaudio
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
11
+ from speechbrain.utils.data_utils import undo_padding
12
+ from speechbrain.utils.distributed import run_on_main
13
+
14
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
15
+ The system employs a wav2vec2 encoder and a CTC decoder.
16
+ Decoding is performed with greedy decoding (will be extended to beam search).
17
+
18
+ To run this recipe, do the following:
19
+ > python train_with_wavlm.py train_semi.yaml
20
+
21
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
22
+ The wav2vec2 encoder is initialized from the pretrained model given in the hparams file.
23
+ It may be dependent on the language.
24
+
25
+ The neural network is trained with CTC on sub-word units estimated with
26
+ Byte Pair Encoding (BPE).
27
+
28
+ The experiment file is flexible enough to support a large variety of
29
+ different systems. By properly changing the parameter files, you can try
30
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
31
+ training languages (all CommonVoice languages), and many
32
+ other possible variations.
33
+
34
+ Authors
35
+ * Titouan Parcollet 2021
36
+ """
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ # Define training procedure
42
+ class ASR(sb.core.Brain):
43
+ def compute_forward(self, batch, stage):
44
+ """Forward computations from the waveform batches to the output probabilities."""
45
+
46
+ batch = batch.to(self.device)
47
+ wavs, wav_lens = batch.sig
48
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens, tokens_lens = batch.tokens
68
+
69
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
70
+
71
+ if stage != sb.Stage.TRAIN:
72
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
73
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
74
+ )
75
+ # Decode token terms to words
76
+ if self.hparams.use_language_modelling:
77
+ predicted_words = []
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+ else:
82
+ predicted_words = [
83
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
84
+ for utt_seq in predicted_tokens
85
+ ]
86
+ # Convert indices to words
87
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
88
+
89
+ self.wer_metric.append(ids, predicted_words, target_words)
90
+ self.cer_metric.append(ids, predicted_words, target_words)
91
+
92
+ return loss
93
+
94
+ def fit_batch(self, batch):
95
+ """Train the parameters given a single batch in input"""
96
+ should_step = self.step % self.grad_accumulation_factor == 0
97
+ # Managing automatic mixed precision
98
+ # TOFIX: CTC fine-tuning currently is unstable
99
+ # This is certainly due to CTC being done in fp16 instead of fp32
100
+ if self.auto_mix_prec:
101
+ with torch.cuda.amp.autocast():
102
+ with self.no_sync():
103
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
104
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
105
+ with self.no_sync(not should_step):
106
+ self.scaler.scale(
107
+ loss / self.grad_accumulation_factor
108
+ ).backward()
109
+ if should_step:
110
+
111
+ if not self.hparams.wav2vec2.freeze:
112
+ self.scaler.unscale_(self.wav2vec_optimizer)
113
+ self.scaler.unscale_(self.model_optimizer)
114
+ if self.check_gradients(loss):
115
+ if not self.hparams.wav2vec2.freeze:
116
+ if self.optimizer_step >= self.hparams.warmup_steps:
117
+ self.scaler.step(self.wav2vec_optimizer)
118
+ self.scaler.step(self.model_optimizer)
119
+ self.scaler.update()
120
+ self.zero_grad()
121
+ self.optimizer_step += 1
122
+ else:
123
+ # This is mandatory because HF models have a weird behavior with DDP
124
+ # on the forward pass
125
+ with self.no_sync():
126
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
127
+
128
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
129
+
130
+ with self.no_sync(not should_step):
131
+ (loss / self.grad_accumulation_factor).backward()
132
+ if should_step:
133
+ if self.check_gradients(loss):
134
+ if not self.hparams.wav2vec2.freeze:
135
+ if self.optimizer_step >= self.hparams.warmup_steps:
136
+ self.wav2vec_optimizer.step()
137
+ self.model_optimizer.step()
138
+ self.zero_grad()
139
+ self.optimizer_step += 1
140
+
141
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
142
+ return loss.detach().cpu()
143
+
144
+ def evaluate_batch(self, batch, stage):
145
+ """Computations needed for validation/test batches"""
146
+ predictions = self.compute_forward(batch, stage=stage)
147
+ with torch.no_grad():
148
+ loss = self.compute_objectives(predictions, batch, stage=stage)
149
+ return loss.detach()
150
+
151
+ def on_stage_start(self, stage, epoch):
152
+ """Gets called at the beginning of each epoch"""
153
+ if stage != sb.Stage.TRAIN:
154
+ self.cer_metric = self.hparams.cer_computer()
155
+ self.wer_metric = self.hparams.error_rate_computer()
156
+
157
+ def on_stage_end(self, stage, stage_loss, epoch):
158
+ """Gets called at the end of an epoch."""
159
+ # Compute/store important stats
160
+ stage_stats = {"loss": stage_loss}
161
+ if stage == sb.Stage.TRAIN:
162
+ self.train_stats = stage_stats
163
+ else:
164
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
165
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
166
+
167
+ # Perform end-of-iteration things, like annealing, logging, etc.
168
+ if stage == sb.Stage.VALID:
169
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
170
+ stage_stats["loss"]
171
+ )
172
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
173
+ stage_stats["loss"]
174
+ )
175
+ sb.nnet.schedulers.update_learning_rate(
176
+ self.model_optimizer, new_lr_model
177
+ )
178
+ if not self.hparams.wav2vec2.freeze:
179
+ sb.nnet.schedulers.update_learning_rate(
180
+ self.wav2vec_optimizer, new_lr_wav2vec
181
+ )
182
+ self.hparams.train_logger.log_stats(
183
+ stats_meta={
184
+ "epoch": epoch,
185
+ "lr_model": old_lr_model,
186
+ "lr_wav2vec": old_lr_wav2vec,
187
+ },
188
+ train_stats=self.train_stats,
189
+ valid_stats=stage_stats,
190
+ )
191
+ self.checkpointer.save_and_keep_only(
192
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
193
+ )
194
+ elif stage == sb.Stage.TEST:
195
+ self.hparams.train_logger.log_stats(
196
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
197
+ test_stats=stage_stats,
198
+ )
199
+ with open(self.hparams.wer_file, "w") as w:
200
+ self.wer_metric.write_stats(w)
201
+
202
+ def init_optimizers(self):
203
+ "Initializes the wav2vec2 optimizer and model optimizer"
204
+
205
+ # If the wav2vec encoder is unfrozen, we create the optimizer
206
+ if not self.hparams.wav2vec2.freeze:
207
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
208
+ self.modules.wav2vec2.parameters()
209
+ )
210
+ if self.checkpointer is not None:
211
+ self.checkpointer.add_recoverable(
212
+ "wav2vec_opt", self.wav2vec_optimizer
213
+ )
214
+
215
+ self.model_optimizer = self.hparams.model_opt_class(
216
+ self.hparams.model.parameters()
217
+ )
218
+
219
+ if self.checkpointer is not None:
220
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
221
+
222
+ def zero_grad(self, set_to_none=False):
223
+ if not self.hparams.wav2vec2.freeze:
224
+ self.wav2vec_optimizer.zero_grad(set_to_none)
225
+ self.model_optimizer.zero_grad(set_to_none)
226
+
227
+
228
+ # Define custom data procedure
229
+ def dataio_prepare(hparams):
230
+ """This function prepares the datasets to be used in the brain class.
231
+ It also defines the data processing pipeline through user-defined functions."""
232
+
233
+ # 1. Define datasets
234
+ data_folder = hparams["data_folder"]
235
+
236
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
237
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
238
+ )
239
+
240
+ if hparams["sorting"] == "ascending":
241
+ # we sort training data to speed up training and get better results.
242
+ train_data = train_data.filtered_sorted(
243
+ sort_key="duration",
244
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
245
+ )
246
+ # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
247
+ hparams["dataloader_options"]["shuffle"] = False
248
+
249
+ elif hparams["sorting"] == "descending":
250
+ train_data = train_data.filtered_sorted(
251
+ sort_key="duration",
252
+ reverse=True,
253
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
254
+ )
255
+ # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
256
+ hparams["dataloader_options"]["shuffle"] = False
257
+
258
+ elif hparams["sorting"] == "random":
259
+ pass
260
+
261
+ else:
262
+ raise NotImplementedError(
263
+ "sorting must be random, ascending or descending"
264
+ )
265
+
266
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
267
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
268
+ )
269
+ # We also sort the validation data so it is faster to validate
270
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
271
+ test_datasets = {}
272
+ for csv_file in hparams["test_csv"]:
273
+ name = Path(csv_file).stem
274
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
275
+ csv_path=csv_file, replacements={"data_root": data_folder}
276
+ )
277
+ test_datasets[name] = test_datasets[name].filtered_sorted(
278
+ sort_key="duration"
279
+ )
280
+
281
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
282
+
283
+
284
+ # 2. Define audio pipeline:
285
+ @sb.utils.data_pipeline.takes("wav")
286
+ @sb.utils.data_pipeline.provides("sig")
287
+ def audio_pipeline(wav):
288
+ info = torchaudio.info(wav)
289
+ sig = sb.dataio.dataio.read_audio(wav)
290
+ resampled = torchaudio.transforms.Resample(
291
+ info.sample_rate, hparams["sample_rate"],
292
+ )(sig)
293
+ return resampled
294
+
295
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
296
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
297
+
298
+ # 3. Define text pipeline:
299
+ @sb.utils.data_pipeline.takes("wrd")
300
+ @sb.utils.data_pipeline.provides(
301
+ "wrd", "char_list", "tokens_list", "tokens"
302
+ )
303
+ def text_pipeline(wrd):
304
+ yield wrd
305
+ char_list = list(wrd)
306
+ yield char_list
307
+ tokens_list = label_encoder.encode_sequence(char_list)
308
+ yield tokens_list
309
+ tokens = torch.LongTensor(tokens_list)
310
+ yield tokens
311
+
312
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
313
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
314
+ special_labels = {
315
+ "blank_label": hparams["blank_index"],
316
+ "unk_label": hparams["unk_index"]
317
+ }
318
+ label_encoder.load_or_create(
319
+ path=lab_enc_file,
320
+ from_didatasets=[train_data],
321
+ output_key="char_list",
322
+ special_labels=special_labels,
323
+ sequence_input=True,
324
+ )
325
+
326
+ # 4. Set output:
327
+ sb.dataio.dataset.set_output_keys(
328
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
329
+ )
330
+ return train_data, valid_data, test_datasets, label_encoder
331
+
332
+
333
+ if __name__ == "__main__":
334
+
335
+ # Load hyperparameters file with command-line overrides
336
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
337
+ with open(hparams_file) as fin:
338
+ hparams = load_hyperpyyaml(fin, overrides)
339
+
340
+ # If --distributed_launch then
341
+ # create ddp_group with the right communication protocol
342
+ sb.utils.distributed.ddp_init_group(run_opts)
343
+
344
+
345
+ # Create experiment directory
346
+ sb.create_experiment_directory(
347
+ experiment_directory=hparams["output_folder"],
348
+ hyperparams_to_save=hparams_file,
349
+ overrides=overrides,
350
+ )
351
+
352
+ # Due to DDP, we do the preparation ONLY on the main python process
353
+ # Defining tokenizer and loading it
354
+ # Create the datasets objects as well as tokenization and encoding :-D
355
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)
356
+ if hparams["use_language_modelling"]:
357
+ print("using language modelling")
358
+ from pyctcdecode import build_ctcdecoder
359
+ ind2lab = label_encoder.ind2lab
360
+ print(ind2lab)
361
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
362
+ labels = [""] + labels[1:-1] + ["1"]
363
+ # Replace the <blank> token with a blank character, needed for PyCTCdecode
364
+ print(labels)
365
+ decoder = build_ctcdecoder(
366
+ labels,
367
+ kenlm_model_path=hparams["ngram_lm_path"], # .arpa or .bin
368
+ alpha=0.5, # Default by KenLM
369
+ beta=1.0, # Default by KenLM
370
+ )
371
+ # Trainer initialization
372
+ asr_brain = ASR(
373
+ modules=hparams["modules"],
374
+ hparams=hparams,
375
+ run_opts=run_opts,
376
+ checkpointer=hparams["checkpointer"],
377
+ )
378
+
379
+ # Adding objects to trainer.
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Training
383
+ asr_brain.fit(
384
+ asr_brain.hparams.epoch_counter,
385
+ train_data,
386
+ valid_data,
387
+ train_loader_kwargs=hparams["dataloader_options"],
388
+ valid_loader_kwargs=hparams["test_dataloader_options"],
389
+ )
390
+
391
+ # Test
392
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
393
+ asr_brain.hparams.wer_file = os.path.join(
394
+ hparams["output_folder"], "wer_{}.txt".format(k)
395
+ )
396
+ asr_brain.evaluate(
397
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
398
+ )
399
+