marinone94 committed
Commit c6104eb • 1 Parent(s): 79a4bc0

fix oom vocab building. adjust run params

Files changed (2)
  1. run.sh +11 -10
  2. run_speech_recognition_ctc.py +21 -4
run.sh CHANGED
@@ -1,21 +1,22 @@
 python run_speech_recognition_ctc.py \
-    --dataset_name="mozilla-foundation/common_voice_7_0" \
+    --dataset_name="mozilla-foundation/common_voice_7_0,marinone94/nst_sv" \
     --model_name_or_path="KBLab/wav2vec2-large-voxrex" \
-    --dataset_config_name="sv-SE" \
+    --dataset_config_name="sv-SE,distant_channel" \
+    --train_split_name="train+validation,train" \
+    --eval_split_name="test,None" \
     --output_dir="./" \
     --overwrite_output_dir \
-    --num_train_epochs="50" \
-    --per_device_train_batch_size="8" \
-    --per_device_eval_batch_size="8" \
+    --num_train_epochs="5" \
+    --per_device_train_batch_size="16" \
+    --per_device_eval_batch_size="16" \
     --gradient_accumulation_steps="4" \
     --learning_rate="7.5e-5" \
-    --warmup_steps="2000" \
+    --warmup_steps="1000" \
     --length_column_name="input_length" \
-    --evaluation_strategy="steps" \
+    --evaluation_strategy="epoch" \
+    --save_strategy="epoch" \
     --text_column_name="sentence" \
-    --chars_to_ignore , ? . ! \- \; \: \" “ % ‘ ” � — ’ … – \
-    --save_steps="500" \
-    --eval_steps="500" \
+    --chars_to_ignore , ? . ! \- \; \: \" “ % ‘ ” � — ’ … – / \\ \
     --logging_steps="100" \
     --layerdrop="0.0" \
     --activation_dropout="0.1" \
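
The comma-separated values in --dataset_name, --dataset_config_name, --train_split_name and --eval_split_name pair up positionally, with the literal string "None" marking a dataset that contributes nothing to that split. A minimal sketch of how such arguments can be expanded and concatenated; the helper name load_multi and its exact control flow are assumptions for illustration, not the script's code:

from datasets import concatenate_datasets, load_dataset

def load_multi(names, configs, splits):
    # Hypothetical helper: zip the comma-separated CLI values positionally
    # and skip entries whose split is the literal string "None".
    parts = []
    for name, config, split in zip(names.split(","), configs.split(","), splits.split(",")):
        if split == "None":
            continue
        parts.append(load_dataset(name, config, split=split, use_auth_token=True))
    # concatenate_datasets requires the parts to share the same features/columns.
    return concatenate_datasets(parts)

train = load_multi(
    "mozilla-foundation/common_voice_7_0,marinone94/nst_sv",
    "sv-SE,distant_channel",
    "train+validation,train",
)
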
run_speech_recognition_ctc.py CHANGED
@@ -329,8 +329,8 @@ def create_vocabulary_from_data(
     vocabs = datasets.map(
         extract_all_chars,
         batched=True,
-        batch_size=-1,
-        keep_in_memory=True,
+        batch_size=10000,
+        keep_in_memory=False,
         remove_columns=datasets["train"].column_names,
     )
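
This hunk is the OOM fix from the commit message: batch_size=-1 hands the whole dataset to a single extract_all_chars call, and keep_in_memory=True holds the mapped result in RAM, which evidently no longer fit once the NST data was added; a finite batch size processes 10,000 rows per call and keep_in_memory=False spills intermediate Arrow tables to the on-disk cache. A self-contained sketch of the pattern, with toy data standing in for the real datasets:

from datasets import Dataset

ds = Dataset.from_dict({"target_text": ["hej världen ", "w o r d s "]})

def extract_all_chars(batch):
    # One vocab row per batch; with finite batching, callers must union the
    # per-batch vocabs to recover the full character set.
    all_text = " ".join(batch["target_text"])
    return {"vocab": [list(set(all_text))], "all_text": [all_text]}

vocabs = ds.map(
    extract_all_chars,
    batched=True,
    batch_size=10000,      # bounded batches instead of one dataset-sized batch
    keep_in_memory=False,  # cache intermediate results on disk
    remove_columns=ds.column_names,
)
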
 
@@ -449,6 +449,8 @@ def main():
                 )
             ]
         )
+    else:
+        logging.warning(f"Skipping {dataset_name} {dataset_config_name} as split is {train_split_name}")
 
     if data_args.audio_column_name not in raw_datasets["train"].column_names:
         raise ValueError(
@@ -491,11 +493,13 @@ def main():
                 load_dataset(
                     dataset_name,
                     dataset_config_name,
-                    split=train_split_name,
+                    split=eval_split_name,
                     use_auth_token=data_args.use_auth_token,
                 )
             ]
         )
+    else:
+        logging.warning(f"Skipping {dataset_name} {dataset_config_name} as split is {eval_split_name}")
 
     if data_args.max_eval_samples is not None:
         raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
@@ -509,6 +513,12 @@ def main():
         )
     text_column_name = data_args.text_column_name
 
+    def is_text_valid(text):
+        for token in text.split():
+            if len(token) > 1:
+                return True
+        return False
+
     def remove_special_characters(batch):
         if chars_to_ignore_regex is not None:
             batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
@@ -516,6 +526,7 @@ def main():
             batch["target_text"] = batch[text_column_name].lower() + " "
         return batch
 
+    num_workers = data_args.preprocessing_num_workers
     with training_args.main_process_first(desc="dataset map special characters removal"):
         raw_datasets = raw_datasets.map(
             remove_special_characters,
@@ -523,6 +534,13 @@ def main():
             desc="remove special characters from datasets",
         )
 
+    raw_datasets = raw_datasets.filter(
+        is_text_valid,
+        num_proc=num_workers,
+        input_columns=["target_text"],
+        desc="remove single words, single chars and 'W O R D S'",
+    )
+
     # save special tokens for tokenizer
     word_delimiter_token = data_args.word_delimiter_token
     unk_token = data_args.unk_token
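
With input_columns=["target_text"], datasets passes only that column's value to the predicate, so is_text_valid receives a plain string rather than the full example dict. A small runnable sketch of the mechanics, with toy data and assuming the cleaned target_text column is the intended input:

from datasets import Dataset

ds = Dataset.from_dict({"target_text": ["hello there ", "w o r d s ", "a "]})

def is_text_valid(text):
    for token in text.split():
        if len(token) > 1:
            return True
    return False

kept = ds.filter(is_text_valid, input_columns=["target_text"])
print(kept["target_text"])  # ['hello there ']
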
@@ -638,7 +656,6 @@ def main():
     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
-    num_workers = data_args.preprocessing_num_workers
 
     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
 