ivanlau commited on
Commit
a80f928
·
1 Parent(s): 103417d

Training in progress, step 10

Browse files
.ipynb_checkpoints/finetune-checkpoint.py DELETED
@@ -1,267 +0,0 @@
1
- import json
2
- import random
3
- import re
4
- from dataclasses import dataclass, field
5
- from typing import Any, Dict, List, Optional, Union
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- import torchaudio
11
- import transformers
12
- import datasets
13
- from datasets import ClassLabel, load_dataset, load_metric
14
- from transformers import (Trainer, TrainingArguments, Wav2Vec2CTCTokenizer,
15
- Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC,
16
- Wav2Vec2Processor)
17
-
18
- import argparse
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument('--model', type=str, default="facebook/wav2vec2-xls-r-300m")
21
- parser.add_argument('--unfreeze', action='store_true')
22
- parser.add_argument('--lr', type=float, default=3e-4)
23
- parser.add_argument('--warmup', type=float, default=500)
24
- args = parser.parse_args()
25
-
26
-
27
- print(f"args: {args}")
28
-
29
- common_voice_train = datasets.load_dataset("mozilla-foundation/common_voice_8_0", "zh-HK", split="train+validation", use_auth_token=True)
30
- common_voice_test = datasets.load_dataset("mozilla-foundation/common_voice_8_0", "zh-HK", split="test[:10%]", use_auth_token=True)
31
-
32
- # common_voice_train = datasets.load_dataset("common_voice", "zh-HK", split="train+validation", use_auth_token=True)
33
- # common_voice_test = datasets.load_dataset("common_voice", "zh-HK", split="test[:10%]", use_auth_token=True)
34
-
35
- unused_cols = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
36
- common_voice_train = common_voice_train.remove_columns(unused_cols)
37
- common_voice_test = common_voice_test.remove_columns(unused_cols)
38
-
39
- chars_to_ignore_regex = '[\丶\,\?\.\!\-\;\:"\“\%\‘\”\�\.\⋯\!\-\:\–\。\》\,\)\,\?\;\~\~\…\︰\,\(\」\‧\《\﹔\、\—\/\,\「\﹖\·\']'
40
-
41
- import string
42
- def remove_special_characters(batch):
43
- sen = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
44
- # convert 'D' and 'd' to '啲' if there a 'D' in sentence
45
- # hacky stuff, wont work on 'D', 'd' co-occure with normal english words
46
- # wont work on multiple 'D'
47
- if "d" in sen:
48
- if len([c for c in sen if c in string.ascii_lowercase]) == 1:
49
- sen = sen.replace("d", "啲")
50
- batch["sentence"] = sen
51
- return batch
52
-
53
- common_voice_train = common_voice_train.map(remove_special_characters)
54
- common_voice_test = common_voice_test.map(remove_special_characters)
55
-
56
- def extract_all_chars(batch):
57
- all_text = " ".join(batch["sentence"])
58
- vocab = list(set(all_text))
59
- return {"vocab": [vocab], "all_text": [all_text]}
60
-
61
- vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names,)
62
- vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names,)
63
- vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
64
- vocab_list = [char for char in vocab_list if not char.isascii()] # remove english char from vocab_list, so tokenizer will replace english with [UNK]
65
- vocab_list.append(" ") # previous will remove " " from vocab_list
66
-
67
- vocab_dict = {v: k for k, v in enumerate(vocab_list)}
68
- vocab_dict["|"] = vocab_dict[" "]
69
- del vocab_dict[" "]
70
-
71
- vocab_dict["[UNK]"] = len(vocab_dict)
72
- vocab_dict["[PAD]"] = len(vocab_dict)
73
-
74
- with open("vocab.json", "w") as vocab_file:
75
- json.dump(vocab_dict, vocab_file)
76
-
77
- tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
78
-
79
- feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True,)
80
-
81
- processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
82
- processor.save_pretrained("./finetuned-wav2vec2-xls-r-300m-cantonese")
83
-
84
- # resamplers = {
85
- # 48000: torchaudio.transforms.Resample(48000, 16000),
86
- # 44100: torchaudio.transforms.Resample(44100, 16000),
87
- # }
88
-
89
- # def load_and_resample(batch):
90
- # speech_array, sampling_rate = torchaudio.load(batch["path"])
91
- # batch["array"] = resamplers[sampling_rate](speech_array).squeeze().numpy()
92
- # batch["sampling_rate"] = 16_000
93
- # batch["target_text"] = batch["sentence"]
94
- # return batch
95
-
96
- # common_voice_train = common_voice_train.map(load_and_resample, remove_columns=common_voice_train.column_names,)
97
- # common_voice_test = common_voice_test.map(load_and_resample, remove_columns=common_voice_test.column_names,)
98
-
99
-
100
- common_voice_train = common_voice_train.cast_column('audio', datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
101
- common_voice_test = common_voice_test.cast_column('audio', datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
102
-
103
-
104
- def prepare_dataset(batch):
105
- batch["input_values"] = processor(batch["array"], sampling_rate=batch["sampling_rate"][0]).input_values
106
- with processor.as_target_processor():
107
- batch["labels"] = processor(batch["target_text"]).input_ids
108
- return batch
109
-
110
- print(common_voice_train[0]['audio'])
111
-
112
- common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batched=True,)
113
- common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batched=True,)
114
-
115
-
116
- @dataclass
117
- class DataCollatorCTCWithPadding:
118
- """
119
- Data collator that will dynamically pad the inputs received.
120
- Args:
121
- processor (:class:`~transformers.Wav2Vec2Processor`)
122
- The processor used for proccessing the data.
123
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
124
- Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
125
- among:
126
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
127
- sequence if provided).
128
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
129
- maximum acceptable input length for the model if that argument is not provided.
130
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
131
- different lengths).
132
- max_length (:obj:`int`, `optional`):
133
- Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
134
- max_length_labels (:obj:`int`, `optional`):
135
- Maximum length of the ``labels`` returned list and optionally padding length (see above).
136
- pad_to_multiple_of (:obj:`int`, `optional`):
137
- If set will pad the sequence to a multiple of the provided value.
138
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
139
- 7.5 (Volta).
140
- """
141
-
142
- processor: Wav2Vec2Processor
143
- padding: Union[bool, str] = True
144
- max_length: Optional[int] = None
145
- max_length_labels: Optional[int] = None
146
- pad_to_multiple_of: Optional[int] = None
147
- pad_to_multiple_of_labels: Optional[int] = None
148
-
149
- def __call__(
150
- self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
151
- ) -> Dict[str, torch.Tensor]:
152
- # split inputs and labels since they have to be of different lenghts and need
153
- # different padding methods
154
- input_features = [
155
- {"input_values": feature["input_values"]} for feature in features
156
- ]
157
- label_features = [{"input_ids": feature["labels"]} for feature in features]
158
-
159
- batch = self.processor.pad(
160
- input_features,
161
- padding=self.padding,
162
- max_length=self.max_length,
163
- pad_to_multiple_of=self.pad_to_multiple_of,
164
- return_tensors="pt",
165
- )
166
- with self.processor.as_target_processor():
167
- labels_batch = self.processor.pad(
168
- label_features,
169
- padding=self.padding,
170
- max_length=self.max_length_labels,
171
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
172
- return_tensors="pt",
173
- )
174
-
175
- # replace padding with -100 to ignore loss correctly
176
- labels = labels_batch["input_ids"].masked_fill(
177
- labels_batch.attention_mask.ne(1), -100
178
- )
179
-
180
- batch["labels"] = labels
181
-
182
- return batch
183
-
184
-
185
- data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
186
- # cer_metric = load_metric("./cer")
187
-
188
- # def compute_metrics(pred):
189
- # pred_logits = pred.predictions
190
- # pred_ids = np.argmax(pred_logits, axis=-1)
191
-
192
- # pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
193
-
194
- # pred_str = processor.batch_decode(pred_ids)
195
- # # we do not want to group tokens when computing the metrics
196
- # label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
197
-
198
- # cer = cer_metric.compute(predictions=pred_str, references=label_str)
199
-
200
- # return {"cer": cer}
201
-
202
- def compute_metrics(pred):
203
- pred_logits = pred.predictions
204
- pred_ids = np.argmax(pred_logits, axis=-1)
205
-
206
- pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
207
-
208
- pred_str = tokenizer.batch_decode(pred_ids)
209
- # we do not want to group tokens when computing the metrics
210
- label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
211
-
212
- metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
213
-
214
- return metrics
215
-
216
- model = Wav2Vec2ForCTC.from_pretrained(
217
- args.model,
218
- attention_dropout=0.1,
219
- hidden_dropout=0.1,
220
- feat_proj_dropout=0.0,
221
- mask_time_prob=0.05,
222
- layerdrop=0.1,
223
- gradient_checkpointing=True,
224
- ctc_loss_reduction="mean",
225
- pad_token_id=processor.tokenizer.pad_token_id,
226
- vocab_size=len(processor.tokenizer),
227
- )
228
-
229
- if not args.unfreeze:
230
- model.freeze_feature_extractor()
231
-
232
- training_args = TrainingArguments(
233
- output_dir="./finetuned-wav2vec2-xls-r-300m-cantonese/wav2vec2-xls-r-300m-cantonese",
234
- group_by_length=True,
235
- per_device_train_batch_size=8,
236
- gradient_accumulation_steps=2,
237
- #evaluation_strategy="no",
238
- evaluation_strategy="steps",
239
- #evaluation_strategy="epoch",
240
- eval_steps=400,
241
- #eval_accumulation_steps=60,
242
- num_train_epochs=1,
243
- fp16=True,
244
- fp16_backend="amp",
245
- logging_strategy="steps",
246
- logging_steps=400,
247
- #logging_strategy="epoch",
248
- learning_rate=args.lr,
249
- warmup_steps=100,
250
- save_steps=2376, # every 3 epoch with batch_size 8
251
- #save_strategy="epoch",
252
- save_total_limit=3,
253
- ###################
254
- # fp16_full_eval=True,
255
- dataloader_num_workers=20,
256
- )
257
-
258
- trainer = Trainer(
259
- model=model,
260
- data_collator=data_collator,
261
- args=training_args,
262
- compute_metrics=compute_metrics,
263
- train_dataset=common_voice_train,
264
- eval_dataset=common_voice_test,
265
- tokenizer=processor.feature_extractor,
266
- )
267
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.ipynb_checkpoints/run-checkpoint.sh CHANGED
@@ -5,7 +5,7 @@ python run_speech_recognition_ctc.py \
5
  --output_dir="./" \
6
  --cache_dir="../container_0" \
7
  --overwrite_output_dir \
8
- --num_train_epochs="1" \
9
  --per_device_train_batch_size="8" \
10
  --per_device_eval_batch_size="1" \
11
  --gradient_accumulation_steps="2" \
@@ -26,4 +26,4 @@ python run_speech_recognition_ctc.py \
26
  --push_to_hub \
27
  --do_train \
28
  --do_eval \
29
- --max_duration_in_seconds="3"
 
5
  --output_dir="./" \
6
  --cache_dir="../container_0" \
7
  --overwrite_output_dir \
8
+ --num_train_epochs="20" \
9
  --per_device_train_batch_size="8" \
10
  --per_device_eval_batch_size="1" \
11
  --gradient_accumulation_steps="2" \
 
26
  --push_to_hub \
27
  --do_train \
28
  --do_eval \
29
+ --max_duration_in_seconds="6"
.ipynb_checkpoints/run_speech_recognition_ctc-checkpoint.py CHANGED
@@ -572,9 +572,10 @@ def main():
572
 
573
  # make sure that dataset decodes audio with correct sampling rate
574
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
 
575
  # print("data sample rate:", dataset_sampling_rate) # 48_000
576
  # print("feature sample rate:", feature_extractor.sampling_rate) # 16_000
577
-
578
  # # remove long common voice
579
  # def remove_long_common_voicedata(dataset, max_seconds=6):
580
  # #convert pyarrow table to pandas
 
572
 
573
  # make sure that dataset decodes audio with correct sampling rate
574
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
575
+
576
  # print("data sample rate:", dataset_sampling_rate) # 48_000
577
  # print("feature sample rate:", feature_extractor.sampling_rate) # 16_000
578
+
579
  # # remove long common voice
580
  # def remove_long_common_voicedata(dataset, max_seconds=6):
581
  # #convert pyarrow table to pandas
finetune.py DELETED
@@ -1,267 +0,0 @@
1
- import json
2
- import random
3
- import re
4
- from dataclasses import dataclass, field
5
- from typing import Any, Dict, List, Optional, Union
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- import torchaudio
11
- import transformers
12
- import datasets
13
- from datasets import ClassLabel, load_dataset, load_metric
14
- from transformers import (Trainer, TrainingArguments, Wav2Vec2CTCTokenizer,
15
- Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC,
16
- Wav2Vec2Processor)
17
-
18
- import argparse
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument('--model', type=str, default="facebook/wav2vec2-xls-r-300m")
21
- parser.add_argument('--unfreeze', action='store_true')
22
- parser.add_argument('--lr', type=float, default=3e-4)
23
- parser.add_argument('--warmup', type=float, default=500)
24
- args = parser.parse_args()
25
-
26
-
27
- print(f"args: {args}")
28
-
29
- common_voice_train = datasets.load_dataset("mozilla-foundation/common_voice_8_0", "zh-HK", split="train+validation", use_auth_token=True)
30
- common_voice_test = datasets.load_dataset("mozilla-foundation/common_voice_8_0", "zh-HK", split="test[:10%]", use_auth_token=True)
31
-
32
- # common_voice_train = datasets.load_dataset("common_voice", "zh-HK", split="train+validation", use_auth_token=True)
33
- # common_voice_test = datasets.load_dataset("common_voice", "zh-HK", split="test[:10%]", use_auth_token=True)
34
-
35
- unused_cols = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
36
- common_voice_train = common_voice_train.remove_columns(unused_cols)
37
- common_voice_test = common_voice_test.remove_columns(unused_cols)
38
-
39
- chars_to_ignore_regex = '[\丶\,\?\.\!\-\;\:"\“\%\‘\”\�\.\⋯\!\-\:\–\。\》\,\)\,\?\;\~\~\…\︰\,\(\」\‧\《\﹔\、\—\/\,\「\﹖\·\']'
40
-
41
- import string
42
- def remove_special_characters(batch):
43
- sen = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
44
- # convert 'D' and 'd' to '啲' if there a 'D' in sentence
45
- # hacky stuff, wont work on 'D', 'd' co-occure with normal english words
46
- # wont work on multiple 'D'
47
- if "d" in sen:
48
- if len([c for c in sen if c in string.ascii_lowercase]) == 1:
49
- sen = sen.replace("d", "啲")
50
- batch["sentence"] = sen
51
- return batch
52
-
53
- common_voice_train = common_voice_train.map(remove_special_characters)
54
- common_voice_test = common_voice_test.map(remove_special_characters)
55
-
56
- def extract_all_chars(batch):
57
- all_text = " ".join(batch["sentence"])
58
- vocab = list(set(all_text))
59
- return {"vocab": [vocab], "all_text": [all_text]}
60
-
61
- vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names,)
62
- vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names,)
63
- vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
64
- vocab_list = [char for char in vocab_list if not char.isascii()] # remove english char from vocab_list, so tokenizer will replace english with [UNK]
65
- vocab_list.append(" ") # previous will remove " " from vocab_list
66
-
67
- vocab_dict = {v: k for k, v in enumerate(vocab_list)}
68
- vocab_dict["|"] = vocab_dict[" "]
69
- del vocab_dict[" "]
70
-
71
- vocab_dict["[UNK]"] = len(vocab_dict)
72
- vocab_dict["[PAD]"] = len(vocab_dict)
73
-
74
- with open("vocab.json", "w") as vocab_file:
75
- json.dump(vocab_dict, vocab_file)
76
-
77
- tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
78
-
79
- feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True,)
80
-
81
- processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
82
- processor.save_pretrained("./finetuned-wav2vec2-xls-r-300m-cantonese")
83
-
84
- # resamplers = {
85
- # 48000: torchaudio.transforms.Resample(48000, 16000),
86
- # 44100: torchaudio.transforms.Resample(44100, 16000),
87
- # }
88
-
89
- # def load_and_resample(batch):
90
- # speech_array, sampling_rate = torchaudio.load(batch["path"])
91
- # batch["array"] = resamplers[sampling_rate](speech_array).squeeze().numpy()
92
- # batch["sampling_rate"] = 16_000
93
- # batch["target_text"] = batch["sentence"]
94
- # return batch
95
-
96
- # common_voice_train = common_voice_train.map(load_and_resample, remove_columns=common_voice_train.column_names,)
97
- # common_voice_test = common_voice_test.map(load_and_resample, remove_columns=common_voice_test.column_names,)
98
-
99
-
100
- common_voice_train = common_voice_train.cast_column('audio', datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
101
- common_voice_test = common_voice_test.cast_column('audio', datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
102
-
103
-
104
- def prepare_dataset(batch):
105
- batch["input_values"] = processor(batch["array"], sampling_rate=batch["sampling_rate"][0]).input_values
106
- with processor.as_target_processor():
107
- batch["labels"] = processor(batch["target_text"]).input_ids
108
- return batch
109
-
110
- print(common_voice_train[0]['audio'])
111
-
112
- common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batched=True,)
113
- common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batched=True,)
114
-
115
-
116
- @dataclass
117
- class DataCollatorCTCWithPadding:
118
- """
119
- Data collator that will dynamically pad the inputs received.
120
- Args:
121
- processor (:class:`~transformers.Wav2Vec2Processor`)
122
- The processor used for proccessing the data.
123
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
124
- Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
125
- among:
126
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
127
- sequence if provided).
128
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
129
- maximum acceptable input length for the model if that argument is not provided.
130
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
131
- different lengths).
132
- max_length (:obj:`int`, `optional`):
133
- Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
134
- max_length_labels (:obj:`int`, `optional`):
135
- Maximum length of the ``labels`` returned list and optionally padding length (see above).
136
- pad_to_multiple_of (:obj:`int`, `optional`):
137
- If set will pad the sequence to a multiple of the provided value.
138
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
139
- 7.5 (Volta).
140
- """
141
-
142
- processor: Wav2Vec2Processor
143
- padding: Union[bool, str] = True
144
- max_length: Optional[int] = None
145
- max_length_labels: Optional[int] = None
146
- pad_to_multiple_of: Optional[int] = None
147
- pad_to_multiple_of_labels: Optional[int] = None
148
-
149
- def __call__(
150
- self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
151
- ) -> Dict[str, torch.Tensor]:
152
- # split inputs and labels since they have to be of different lenghts and need
153
- # different padding methods
154
- input_features = [
155
- {"input_values": feature["input_values"]} for feature in features
156
- ]
157
- label_features = [{"input_ids": feature["labels"]} for feature in features]
158
-
159
- batch = self.processor.pad(
160
- input_features,
161
- padding=self.padding,
162
- max_length=self.max_length,
163
- pad_to_multiple_of=self.pad_to_multiple_of,
164
- return_tensors="pt",
165
- )
166
- with self.processor.as_target_processor():
167
- labels_batch = self.processor.pad(
168
- label_features,
169
- padding=self.padding,
170
- max_length=self.max_length_labels,
171
- pad_to_multiple_of=self.pad_to_multiple_of_labels,
172
- return_tensors="pt",
173
- )
174
-
175
- # replace padding with -100 to ignore loss correctly
176
- labels = labels_batch["input_ids"].masked_fill(
177
- labels_batch.attention_mask.ne(1), -100
178
- )
179
-
180
- batch["labels"] = labels
181
-
182
- return batch
183
-
184
-
185
- data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
186
- # cer_metric = load_metric("./cer")
187
-
188
- # def compute_metrics(pred):
189
- # pred_logits = pred.predictions
190
- # pred_ids = np.argmax(pred_logits, axis=-1)
191
-
192
- # pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
193
-
194
- # pred_str = processor.batch_decode(pred_ids)
195
- # # we do not want to group tokens when computing the metrics
196
- # label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
197
-
198
- # cer = cer_metric.compute(predictions=pred_str, references=label_str)
199
-
200
- # return {"cer": cer}
201
-
202
- def compute_metrics(pred):
203
- pred_logits = pred.predictions
204
- pred_ids = np.argmax(pred_logits, axis=-1)
205
-
206
- pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
207
-
208
- pred_str = tokenizer.batch_decode(pred_ids)
209
- # we do not want to group tokens when computing the metrics
210
- label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
211
-
212
- metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
213
-
214
- return metrics
215
-
216
- model = Wav2Vec2ForCTC.from_pretrained(
217
- args.model,
218
- attention_dropout=0.1,
219
- hidden_dropout=0.1,
220
- feat_proj_dropout=0.0,
221
- mask_time_prob=0.05,
222
- layerdrop=0.1,
223
- gradient_checkpointing=True,
224
- ctc_loss_reduction="mean",
225
- pad_token_id=processor.tokenizer.pad_token_id,
226
- vocab_size=len(processor.tokenizer),
227
- )
228
-
229
- if not args.unfreeze:
230
- model.freeze_feature_extractor()
231
-
232
- training_args = TrainingArguments(
233
- output_dir="./finetuned-wav2vec2-xls-r-300m-cantonese/wav2vec2-xls-r-300m-cantonese",
234
- group_by_length=True,
235
- per_device_train_batch_size=8,
236
- gradient_accumulation_steps=2,
237
- #evaluation_strategy="no",
238
- evaluation_strategy="steps",
239
- #evaluation_strategy="epoch",
240
- eval_steps=400,
241
- #eval_accumulation_steps=60,
242
- num_train_epochs=1,
243
- fp16=True,
244
- fp16_backend="amp",
245
- logging_strategy="steps",
246
- logging_steps=400,
247
- #logging_strategy="epoch",
248
- learning_rate=args.lr,
249
- warmup_steps=100,
250
- save_steps=2376, # every 3 epoch with batch_size 8
251
- #save_strategy="epoch",
252
- save_total_limit=3,
253
- ###################
254
- # fp16_full_eval=True,
255
- dataloader_num_workers=20,
256
- )
257
-
258
- trainer = Trainer(
259
- model=model,
260
- data_collator=data_collator,
261
- args=training_args,
262
- compute_metrics=compute_metrics,
263
- train_dataset=common_voice_train,
264
- eval_dataset=common_voice_test,
265
- tokenizer=processor.feature_extractor,
266
- )
267
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:adfaab1decb47281256c560d259dcea403f5ee4639cd88d6072d844ed1e991d9
3
  size 1278024433
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dcc7787f7bcc54150b3f999655173fdb6b418594a5bff33ca62d1262157b933
3
  size 1278024433
run.sh CHANGED
@@ -5,7 +5,7 @@ python run_speech_recognition_ctc.py \
5
  --output_dir="./" \
6
  --cache_dir="../container_0" \
7
  --overwrite_output_dir \
8
- --num_train_epochs="1" \
9
  --per_device_train_batch_size="8" \
10
  --per_device_eval_batch_size="1" \
11
  --gradient_accumulation_steps="2" \
@@ -26,4 +26,4 @@ python run_speech_recognition_ctc.py \
26
  --push_to_hub \
27
  --do_train \
28
  --do_eval \
29
- --max_duration_in_seconds="3"
 
5
  --output_dir="./" \
6
  --cache_dir="../container_0" \
7
  --overwrite_output_dir \
8
+ --num_train_epochs="20" \
9
  --per_device_train_batch_size="8" \
10
  --per_device_eval_batch_size="1" \
11
  --gradient_accumulation_steps="2" \
 
26
  --push_to_hub \
27
  --do_train \
28
  --do_eval \
29
+ --max_duration_in_seconds="6"
run_speech_recognition_ctc.py CHANGED
@@ -572,9 +572,10 @@ def main():
572
 
573
  # make sure that dataset decodes audio with correct sampling rate
574
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
 
575
  # print("data sample rate:", dataset_sampling_rate) # 48_000
576
  # print("feature sample rate:", feature_extractor.sampling_rate) # 16_000
577
-
578
  # # remove long common voice
579
  # def remove_long_common_voicedata(dataset, max_seconds=6):
580
  # #convert pyarrow table to pandas
 
572
 
573
  # make sure that dataset decodes audio with correct sampling rate
574
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
575
+
576
  # print("data sample rate:", dataset_sampling_rate) # 48_000
577
  # print("feature sample rate:", feature_extractor.sampling_rate) # 16_000
578
+
579
  # # remove long common voice
580
  # def remove_long_common_voicedata(dataset, max_seconds=6):
581
  # #convert pyarrow table to pandas
special_tokens_map.json CHANGED
@@ -1 +1 @@
1
- {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:530b8b92ddc86c8636f3692a5bd1f5a725cd44fe2261da30abe50a8651413ae9
3
  size 2991
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6af2b3c8fe0d22f849d5057a3507a71d4f6c5aae8f245106e7ccc49c930315da
3
  size 2991