Commit 79a4bc0
1 Parent(s): 2b994e2

allow for multiple datasets from hf in run
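The change makes dataset_name, dataset_config_name, train_split_name and eval_split_name accept comma-separated lists that are paired positionally, one entry per dataset. A minimal sketch of how those values line up (the example values are taken from the new help strings in the diff below; this snippet is illustrative and not part of the script):

# Sketch: how the comma-separated arguments are paired per dataset.
# Example values come from the help strings added in this commit.
dataset_name = "mozilla-foundation/common_voice_7_0,marinone94/nst_sv"
dataset_config_name = "sv-SE,sv"
train_split_name = "train+validation,all"

for name, config, split in zip(
    dataset_name.split(","),
    dataset_config_name.split(","),
    train_split_name.split(","),
):
    print(name, config, split)
# -> mozilla-foundation/common_voice_7_0 sv-SE train+validation
# -> marinone94/nst_sv sv all

A split value of "None" marks a dataset that should be skipped for that phase (e.g. 'test,None' for a corpus with no evaluation split).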
run_speech_recognition_ctc.py  CHANGED  (+75 -17)

@@ -30,7 +30,7 @@ import datasets
 import numpy as np
 import torch
 import wandb
-from datasets import DatasetDict, load_dataset, load_metric
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
 
 import transformers
 from transformers import (
@@ -140,21 +140,33 @@ class DataTrainingArguments:
     """
 
     dataset_name: str = field(
-        metadata={"help": "The name of the dataset to use (via the datasets library)."}
+        metadata={
+            "help": "The name of the dataset to use (via the datasets library)."
+            " To use multiple datasets, specify them separated by a comma."
+            " e.g.: 'mozilla-foundation/common_voice_7_0,marinone94/nst_sv'"
+        }
     )
     dataset_config_name: str = field(
-        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+        default=None, metadata={
+            "help": "The configuration name of the dataset to use (via the datasets library)."
+            " To use multiple datasets, specify them separated by a comma."
+            " e.g.: 'sv-SE,sv'"
+        }
     )
     train_split_name: str = field(
         default="train+validation",
         metadata={
-            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
+            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train+validation'"
+            " To use multiple datasets, specify them separated by a comma."
+            " e.g.: 'train+validation,all'"
         },
     )
     eval_split_name: str = field(
         default="test",
         metadata={
-            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'test'"
+            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'test'"
+            " To use multiple datasets, specify them separated by a comma."
+            " e.g.: 'test,None'"
         },
     )
     audio_column_name: str = field(
@@ -407,12 +419,36 @@ def main():
     raw_datasets = DatasetDict()
 
     if training_args.do_train:
-        raw_datasets["train"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=data_args.train_split_name,
-            use_auth_token=data_args.use_auth_token,
-        )
+
+        # Multiple datasets might need to be loaded from HF
+        # It assumes they all follow the common voice format
+        for (dataset_name, dataset_config_name, train_split_name) in zip(
+            data_args.dataset_name.split(","),
+            data_args.dataset_config_name.split(","),
+            data_args.train_split_name.split(","),
+        ):
+
+
+            if train_split_name != "None":
+                if "train" not in raw_datasets:
+                    raw_datasets["train"] = load_dataset(
+                        dataset_name,
+                        dataset_config_name,
+                        split=train_split_name,
+                        use_auth_token=data_args.use_auth_token,
+                    )
+                else:
+                    raw_datasets["train"] = concatenate_datasets(
+                        [
+                            raw_datasets["train"],
+                            load_dataset(
+                                dataset_name,
+                                dataset_config_name,
+                                split=train_split_name,
+                                use_auth_token=data_args.use_auth_token,
+                            )
+                        ]
+                    )
 
         if data_args.audio_column_name not in raw_datasets["train"].column_names:
             raise ValueError(
@@ -432,12 +468,34 @@ def main():
            raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
 
     if training_args.do_eval:
-        raw_datasets["eval"] = load_dataset(
-            data_args.dataset_name,
-            data_args.dataset_config_name,
-            split=data_args.eval_split_name,
-            use_auth_token=data_args.use_auth_token,
-        )
+        # Multiple datasets might need to be loaded from HF
+        # It assumes they all follow the common voice format
+        for (dataset_name, dataset_config_name, eval_split_name) in zip(
+            data_args.dataset_name.split(","),
+            data_args.dataset_config_name.split(","),
+            data_args.eval_split_name.split(","),
+        ):
+
+            if train_split_name != "None":
+                if "eval" not in raw_datasets:
+                    raw_datasets["eval"] = load_dataset(
+                        dataset_name,
+                        dataset_config_name,
+                        split=eval_split_name,
+                        use_auth_token=data_args.use_auth_token,
+                    )
+                else:
+                    raw_datasets["eval"] = concatenate_datasets(
+                        [
+                            raw_datasets["eval"],
+                            load_dataset(
+                                dataset_name,
+                                dataset_config_name,
+                                split=train_split_name,
+                                use_auth_token=data_args.use_auth_token,
+                            )
+                        ]
+                    )
 
        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|