# base_mgb2 / preprocess_dataset.py
# Ahmed007's picture
# Training in progress, step 1000
# c8bff8b verified
from datasets import load_dataset, DatasetDict
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from datasets import Audio
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from huggingface_hub import login
import argparse
# Command-line interface.
#
# Every option is a string registered under two spellings, "--name" and
# "-name". Options listed with an empty mapping have no default and are
# therefore None when omitted.
# NOTE(review): "--split" is accepted here but nothing in the visible script
# reads args.split.
my_parser = argparse.ArgumentParser()
for _option, _extra in (
    ("model_name", {"default": "openai/whisper-tiny"}),
    ("hf_token", {}),
    ("dataset_name", {"default": "google/fleurs"}),
    ("split", {"default": "test"}),
    ("subset", {}),
):
    my_parser.add_argument(
        f"--{_option}",
        f"-{_option}",
        type=str,
        action="store",
        **_extra,
    )
args = my_parser.parse_args()
# Pull the parsed command-line options into module-level names.
model_name = args.model_name
dataset_name = args.dataset_name
subset = args.subset
hf_token = args.hf_token

# Authenticate against the Hugging Face Hub with the supplied token.
login(hf_token)

# Name of the column holding the reference text: google/fleurs uses
# "transcription", every other dataset is expected to use "sentence".
text_column = "transcription" if dataset_name == "google/fleurs" else "sentence"

# Optional text clean-up switches consulted by prepare_dataset; both are off.
do_remove_punctuation = False
do_lower_case = False
# Text normaliser; prepare_dataset only calls it when do_remove_punctuation
# is enabled.
normalizer = BasicTextNormalizer()

# The processor provides the feature extractor and tokenizer that
# prepare_dataset uses (processor.feature_extractor / processor.tokenizer).
processor = WhisperProcessor.from_pretrained(
    model_name,
    language="Arabic",
    task="transcribe",
)

# Stand-alone feature extractor and tokenizer for the same checkpoint.
# NOTE(review): neither name is referenced again in the visible script.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(
    model_name,
    language="Arabic",
    task="transcribe",
)

# Load every split of the requested dataset/subset and show its layout.
# NOTE(review): `use_auth_token` is deprecated in recent `datasets` releases
# in favour of `token` -- confirm against the installed version.
dataset = load_dataset(dataset_name, subset, use_auth_token=True)
print(dataset)

# Have the "audio" column decoded at a 16 kHz sampling rate.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
def prepare_dataset(batch):
    """Turn one raw example into model-ready features and labels.

    Reads ``batch["audio"]`` (a mapping with ``"array"`` and
    ``"sampling_rate"``) and the text stored under the module-level
    ``text_column``, then adds three entries to ``batch``:

    * ``"input_features"`` -- first item of the features computed by
      ``processor.feature_extractor`` for the audio array.
    * ``"input_length"`` -- clip duration in seconds (samples / rate).
    * ``"labels"`` -- token ids produced by ``processor.tokenizer`` for the
      (optionally normalised) text.

    The same ``batch`` object is returned after being updated in place.
    """
    clip = batch["audio"]
    waveform = clip["array"]
    rate = clip["sampling_rate"]

    # Log-Mel features for the waveform; keep the single item that is returned.
    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    # Duration in seconds = number of samples / samples per second.
    batch["input_length"] = len(waveform) / rate

    # Optional clean-up of the reference text, driven by the module flags.
    text = batch[text_column]
    text = text.lower() if do_lower_case else text
    if do_remove_punctuation:
        text = normalizer(text).strip()

    # Encode the target text to label ids.
    batch["labels"] = processor.tokenizer(text).input_ids
    return batch
# Work out which raw columns to drop once prepare_dataset has run.
# `dataset.column_names` maps each split name to its column list. Prefer the
# "train" split (the previous behaviour) but fall back to the first available
# split, so a dataset without a "train" split no longer fails with KeyError.
split_columns = dataset.column_names
raw_columns = (
    split_columns["train"]
    if "train" in split_columns
    else next(iter(split_columns.values()))
)

# Apply prepare_dataset to every example and drop the raw columns.
dataset = dataset.map(prepare_dataset, remove_columns=raw_columns)

# The script already called login(hf_token) right after parsing arguments,
# so a second login with the same token is not needed here.

# Build the target repository id once so the printed name and the name
# actually pushed to can never drift apart:
#   arbml/<dataset short name>_preprocessed_<model short name>
repo_id = (
    f"arbml/{dataset_name.split('/')[-1]}_preprocessed_{model_name.split('/')[-1]}"
)
print(f"pushing to {repo_id}")

# Upload the processed dataset as a private Hub repository.
dataset.push_to_hub(repo_id, private=True)