marinone94
committed on
Commit
•
c6104eb
1
Parent(s):
79a4bc0
fix oom vocab building. adjust run params
Browse files
- run.sh +11 -10
- run_speech_recognition_ctc.py +21 -4
run.sh
CHANGED
@@ -1,21 +1,22 @@
|
|
1 |
python run_speech_recognition_ctc.py \
|
2 |
-
--dataset_name="mozilla-foundation/common_voice_7_0" \
|
3 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
4 |
-
--dataset_config_name="sv-SE" \
|
|
|
|
|
5 |
--output_dir="./" \
|
6 |
--overwrite_output_dir \
|
7 |
-
--num_train_epochs="
|
8 |
-
--per_device_train_batch_size="
|
9 |
-
--per_device_eval_batch_size="
|
10 |
--gradient_accumulation_steps="4" \
|
11 |
--learning_rate="7.5e-5" \
|
12 |
-
--warmup_steps="
|
13 |
--length_column_name="input_length" \
|
14 |
-
--evaluation_strategy="
|
|
|
15 |
--text_column_name="sentence" \
|
16 |
-
--chars_to_ignore , ? . ! \- \; \: \" β % β β οΏ½ β β β¦ β \
|
17 |
-
--save_steps="500" \
|
18 |
-
--eval_steps="500" \
|
19 |
--logging_steps="100" \
|
20 |
--layerdrop="0.0" \
|
21 |
--activation_dropout="0.1" \
|
|
|
1 |
python run_speech_recognition_ctc.py \
|
2 |
+
--dataset_name="mozilla-foundation/common_voice_7_0,marinone94/nst_sv" \
|
3 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
4 |
+
--dataset_config_name="sv-SE,distant_channel" \
|
5 |
+
--train_split_name="train+validation,train"
|
6 |
+
--eval_split_name="test,None"
|
7 |
--output_dir="./" \
|
8 |
--overwrite_output_dir \
|
9 |
+
--num_train_epochs="5" \
|
10 |
+
--per_device_train_batch_size="16" \
|
11 |
+
--per_device_eval_batch_size="16" \
|
12 |
--gradient_accumulation_steps="4" \
|
13 |
--learning_rate="7.5e-5" \
|
14 |
+
--warmup_steps="1000" \
|
15 |
--length_column_name="input_length" \
|
16 |
+
--evaluation_strategy="epoch" \
|
17 |
+
--save_strategy="epoch" \
|
18 |
--text_column_name="sentence" \
|
19 |
+
--chars_to_ignore , ? . ! \- \; \: \" β % β β οΏ½ β β β¦ β / \\ \
|
|
|
|
|
20 |
--logging_steps="100" \
|
21 |
--layerdrop="0.0" \
|
22 |
--activation_dropout="0.1" \
|
run_speech_recognition_ctc.py
CHANGED
@@ -329,8 +329,8 @@ def create_vocabulary_from_data(
|
|
329 |
vocabs = datasets.map(
|
330 |
extract_all_chars,
|
331 |
batched=True,
|
332 |
-
batch_size
|
333 |
-
keep_in_memory=
|
334 |
remove_columns=datasets["train"].column_names,
|
335 |
)
|
336 |
|
@@ -449,6 +449,8 @@ def main():
|
|
449 |
)
|
450 |
]
|
451 |
)
|
|
|
|
|
452 |
|
453 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
454 |
raise ValueError(
|
@@ -491,11 +493,13 @@ def main():
|
|
491 |
load_dataset(
|
492 |
dataset_name,
|
493 |
dataset_config_name,
|
494 |
-
split=
|
495 |
use_auth_token=data_args.use_auth_token,
|
496 |
)
|
497 |
]
|
498 |
)
|
|
|
|
|
499 |
|
500 |
if data_args.max_eval_samples is not None:
|
501 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
@@ -509,6 +513,12 @@ def main():
|
|
509 |
)
|
510 |
text_column_name = data_args.text_column_name
|
511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
def remove_special_characters(batch):
|
513 |
if chars_to_ignore_regex is not None:
|
514 |
batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
|
@@ -516,6 +526,7 @@ def main():
|
|
516 |
batch["target_text"] = batch[text_column_name].lower() + " "
|
517 |
return batch
|
518 |
|
|
|
519 |
with training_args.main_process_first(desc="dataset map special characters removal"):
|
520 |
raw_datasets = raw_datasets.map(
|
521 |
remove_special_characters,
|
@@ -523,6 +534,13 @@ def main():
|
|
523 |
desc="remove special characters from datasets",
|
524 |
)
|
525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
# save special tokens for tokenizer
|
527 |
word_delimiter_token = data_args.word_delimiter_token
|
528 |
unk_token = data_args.unk_token
|
@@ -638,7 +656,6 @@ def main():
|
|
638 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
639 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
640 |
audio_column_name = data_args.audio_column_name
|
641 |
-
num_workers = data_args.preprocessing_num_workers
|
642 |
|
643 |
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
644 |
phoneme_language = data_args.phoneme_language
|
|
|
329 |
vocabs = datasets.map(
|
330 |
extract_all_chars,
|
331 |
batched=True,
|
332 |
+
batch_size=10000,
|
333 |
+
keep_in_memory=False,
|
334 |
remove_columns=datasets["train"].column_names,
|
335 |
)
|
336 |
|
|
|
449 |
)
|
450 |
]
|
451 |
)
|
452 |
+
else:
|
453 |
+
logging.warning(f"{dataset_name} {dataset_config_name} as split is {train_split_name}")
|
454 |
|
455 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
456 |
raise ValueError(
|
|
|
493 |
load_dataset(
|
494 |
dataset_name,
|
495 |
dataset_config_name,
|
496 |
+
split=eval_split_name,
|
497 |
use_auth_token=data_args.use_auth_token,
|
498 |
)
|
499 |
]
|
500 |
)
|
501 |
+
else:
|
502 |
+
logging.warning(f"{dataset_name} {dataset_config_name} as split is {eval_split_name}")
|
503 |
|
504 |
if data_args.max_eval_samples is not None:
|
505 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
|
|
513 |
)
|
514 |
text_column_name = data_args.text_column_name
|
515 |
|
516 |
+
def is_text_valid(text):
|
517 |
+
for token in text.split():
|
518 |
+
if len(token) > 1:
|
519 |
+
return True
|
520 |
+
return False
|
521 |
+
|
522 |
def remove_special_characters(batch):
|
523 |
if chars_to_ignore_regex is not None:
|
524 |
batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
|
|
|
526 |
batch["target_text"] = batch[text_column_name].lower() + " "
|
527 |
return batch
|
528 |
|
529 |
+
num_workers = data_args.preprocessing_num_workers
|
530 |
with training_args.main_process_first(desc="dataset map special characters removal"):
|
531 |
raw_datasets = raw_datasets.map(
|
532 |
remove_special_characters,
|
|
|
534 |
desc="remove special characters from datasets",
|
535 |
)
|
536 |
|
537 |
+
raw_datasets = raw_datasets.filter(
|
538 |
+
is_text_valid,
|
539 |
+
num_proc=num_workers,
|
540 |
+
input_columns=["input_length"],
|
541 |
+
desc="remove single words, single chars and 'W O R D S'",
|
542 |
+
)
|
543 |
+
|
544 |
# save special tokens for tokenizer
|
545 |
word_delimiter_token = data_args.word_delimiter_token
|
546 |
unk_token = data_args.unk_token
|
|
|
656 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
657 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
658 |
audio_column_name = data_args.audio_column_name
|
|
|
659 |
|
660 |
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
661 |
phoneme_language = data_args.phoneme_language
|