|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
|
from audiomentations import Compose, TimeStretch, PitchShift |
|
from datasets import Audio |
|
from datasets import load_dataset, DatasetDict |
|
import jiwer |
|
import warnings |
|
import pandas as pd |
|
from io import StringIO |
|
from datasets import Dataset, IterableDatasetDict, load_dataset, interleave_datasets, Audio |
|
import evaluate |
|
|
|
import torch |
|
import string |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Union |
|
|
|
from transformers import WhisperForConditionalGeneration |
|
from transformers import WhisperProcessor |
|
from transformers import Seq2SeqTrainingArguments |
|
from transformers import Seq2SeqTrainer |
|
from transformers import WhisperTokenizer |
|
from transformers import WhisperFeatureExtractor |
|
|
|
|
|
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift |
|
import numpy as np |
|
|
|
from transformers import TrainerCallback |
|
from transformers.integrations import WandbCallback |
|
from transformers.trainer_pt_utils import IterableDatasetShard |
|
from torch.utils.data import IterableDataset |
|
from datasets import load_dataset, Audio |
|
from pathlib import Path |
|
import numpy as np |
|
|
|
|
|
import tempfile |
|
|
|
|
|
|
|
|
|
|
|
|
|
# A bare `torch.cuda.is_available()` (notebook leftover) discards its result when run as a
# script; keep and report it so the log shows whether a CUDA device is visible.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
|
|
|
"""## Load Dataset |
|
Loading the Pashto (`ps_af`) dataset from FLEURS.

Combine the train and validation splits into a single training set; the test split is kept for evaluation.
|
""" |
|
|
|
|
|
|
|
|
|
# FLEURS `ps_af` (Pashto): "train+validation" becomes our training split, "test" is held
# out for evaluation. Splits are loaded in the order train, test.
fleurs = DatasetDict(
    {
        split_name: load_dataset(
            "google/fleurs", "ps_af", split=hf_split, use_auth_token=True
        )
        for split_name, hf_split in (("train", "train+validation"), ("test", "test"))
    }
)

# Keep only the columns the pipeline below reads (audio + transcription).
fleurs = fleurs.remove_columns(
    [
        "id",
        "num_samples",
        "path",
        "raw_transcription",
        "gender",
        "lang_id",
        "language",
        "lang_group_id",
    ]
)

print(fleurs)
|
|
|
|
|
# Load the Whisper-small feature extractor and a tokenizer configured for Pashto
# transcription as standalone objects.
# NOTE(review): neither name is referenced again in the visible code — the combined
# `processor` is used instead; confirm they are still needed.
whisper_checkpoint_kwargs = {"language": "Pashto", "task": "transcribe"}
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small", **whisper_checkpoint_kwargs
)
|
|
|
"""### Combine To Create A WhisperProcessor""" |
|
|
|
|
|
# One object bundling the feature extractor and the Pashto/transcribe tokenizer.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="Pashto",
    task="transcribe",
)
|
"openai/whisper-small", language="Pashto", task="transcribe") |
|
|
|
"""### Prepare Data""" |
|
|
|
# Decode every clip at a 16 kHz sampling rate (resampled by 🤗 Datasets when needed).
fleurs = fleurs.cast_column("audio", Audio(sampling_rate=16000))

# Waveform augmentation pipeline for the training set. Each transform is applied with
# probability 0.3; TimeStretch is allowed to change the clip length.
time_stretch = TimeStretch(
    min_rate=0.8, max_rate=1.25, p=0.3, leave_length_unchanged=False
)
pitch_shift = PitchShift(min_semitones=-4, max_semitones=4, p=0.3)
augment_waveform = Compose([time_stretch, pitch_shift])
|
|
|
|
|
def augment_dataset(batch, augmenter=None):
    """Randomly augment the waveform of a single dataset example, in place.

    Args:
        batch: one (non-batched) example whose ``"audio"`` entry is a dict
            holding the decoded ``"array"`` and its ``"sampling_rate"``.
        augmenter: callable accepting ``samples=`` and ``sample_rate=`` and
            returning the augmented waveform. Defaults to the module-level
            ``augment_waveform`` pipeline.

    Returns:
        The same ``batch`` object, with ``batch["audio"]["array"]`` replaced
        by the augmented waveform.
    """
    if augmenter is None:
        augmenter = augment_waveform

    audio = batch["audio"]
    # audiomentations expects float32 samples; cast explicitly instead of relying on
    # its internal float64 -> float32 conversion.
    samples = np.asarray(audio["array"], dtype=np.float32)
    # Use the clip's own rate rather than a hard-coded 16 kHz so this stays correct if
    # the cast_column sampling rate is ever changed.
    audio["array"] = augmenter(samples=samples, sample_rate=audio["sampling_rate"])
    return batch
|
|
|
|
|
print('Augment train set:')
# The augmented audio replaces the training split, so each training example is augmented
# once here rather than on the fly during training.
augmented_train = fleurs["train"].map(augment_dataset, num_proc=10)
fleurs["train"] = augmented_train
|
|
|
"""We can apply the data preparation function to all of our training examples using dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially.""" |
|
|
|
|
|
# Switches read by `prepare_dataset` when cleaning transcriptions.
do_lower_case, do_remove_punctuation = True, True

# Text normaliser applied to transcriptions (and, optionally, to evaluation strings).
normalizer = BasicTextNormalizer()
|
|
|
|
|
def prepare_dataset(batch):
    """Convert one raw example into Whisper model inputs.

    Adds three fields to ``batch`` and returns it:

    * ``input_features`` – features from the processor's feature extractor;
    * ``input_length`` – clip duration in seconds (samples / sampling rate);
    * ``labels`` – tokenizer ids of the transcription, lower-cased and/or
      normalised according to ``do_lower_case`` / ``do_remove_punctuation``.
    """
    clip = batch["audio"]
    waveform, rate = clip["array"], clip["sampling_rate"]

    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]
    batch["input_length"] = len(waveform) / rate

    text = batch["transcription"]
    text = text.lower() if do_lower_case else text
    if do_remove_punctuation:
        # NOTE(review): BasicTextNormalizer replaces every Unicode mark (category "M")
        # with a space; confirm this does not split Pashto words that carry diacritics.
        text = normalizer(text).strip()

    batch["labels"] = processor.tokenizer(text).input_ids
    return batch
|
|
|
|
|
print('Extract features and normalize data:')
# Replace the raw columns with model inputs (input_features, input_length, labels) and
# expose rows as torch tensors.
raw_columns = fleurs.column_names["train"]
fleurs = fleurs.map(prepare_dataset, remove_columns=raw_columns, num_proc=10)
fleurs = fleurs.with_format("torch")
|
|
|
"""Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature-extractor which could affect the stability of training. We define a function that returns `True` for samples that are less than 30s, and `False` for those that are longer:""" |
|
|
|
# Upper bound (seconds) on training clips; longer audio would be truncated by the
# Whisper feature extractor.
max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when a clip of ``length`` seconds is strictly shorter than ``max_input_length``."""
    return max_input_length > length
|
|
|
|
|
"""We apply our filter function to all samples of our training dataset through 🤗 Datasets' `.filter` method:""" |
|
|
|
# Drop over-long training clips, then shuffle with a fixed seed for reproducibility
# (writer_batch_size=100 rows per cache write).
fleurs["train"] = (
    fleurs["train"]
    .filter(is_audio_in_length_range, input_columns=["input_length"])
    .shuffle(seed=42, writer_batch_size=100)
)
|
|
|
|
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples into a padded batch of tensors.

    Audio features are padded by the processor's feature extractor and label
    ids by its tokenizer. Padded label positions are set to ``-100`` so the
    loss ignores them. If every label sequence starts with the tokenizer's BOS
    token, that leading column is removed (presumably because the model adds
    the decoder start token itself — confirm against the model in use).
    """

    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        audio_inputs = [{"input_features": item["input_features"]} for item in features]
        batch = extractor.pad(audio_inputs, return_tensors="pt")

        text_inputs = [{"input_ids": item["labels"]} for item in features]
        padded = tokenizer.pad(text_inputs, return_tensors="pt")

        is_padding = padded.attention_mask.ne(1)
        label_ids = padded["input_ids"].masked_fill(is_padding, -100)

        starts_with_bos = bool((label_ids[:, 0] == tokenizer.bos_token_id).all())
        batch["labels"] = label_ids[:, 1:] if starts_with_bos else label_ids
        return batch
|
|
|
|
|
"""Let's initialise the data collator we've just defined:""" |
|
|
|
# Collator instance passed to the Seq2SeqTrainer below.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)
|
|
|
"""### Evaluation Metrics |
|
|
|
We'll use the word error rate (WER), the de facto metric for assessing ASR systems, and also report the character error rate (CER). For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load both metrics from 🤗 Evaluate:
|
""" |
|
|
|
|
|
# Word- and character-error-rate metrics from 🤗 Evaluate, used by compute_metrics.
wer_metric, cer_metric = evaluate.load("wer"), evaluate.load("cer")

# When True, predictions and references are passed through `normalizer` before scoring.
do_normalize_eval = True
|
|
|
|
|
def compute_metrics(pred):
    """Compute WER and CER, as percentages, for generated evaluation predictions.

    Args:
        pred: evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, where ``-100`` marks
            positions the collator masked out).

    Returns:
        ``{"wer": float, "cer": float}``, both scaled by 100.
    """
    pad_id = processor.tokenizer.pad_token_id
    # -100 cannot be decoded; map it to the pad token (dropped by skip_special_tokens).
    # np.where builds new arrays, so pred.label_ids / pred.predictions are not mutated.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(text) for text in pred_str]
        label_str = [normalizer(text) for text in label_str]

    # Error rates are undefined for empty references, and jiwer (which backs the
    # evaluate WER/CER metrics) raises on them; score only pairs whose reference still
    # has content after normalisation.
    scored = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    pred_str = [hyp for hyp, _ in scored]
    label_str = [ref for _, ref in scored]

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}
|
|
|
|
|
"""### Load a Pre-Trained Checkpoint """ |
|
|
|
|
|
# Pre-trained Whisper "small" checkpoint that is fine-tuned below.
whisper_model_id = "openai/whisper-small"
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id)
|
|
|
"""Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:""" |
|
|
|
# Disable the KV cache: training enables gradient checkpointing, and the two are
# incompatible.
model.config.use_cache = False
# Generation overrides: suppress no tokens and force no decoder prefix tokens.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = None
|
|
|
"""### Define the Training Configuration |
|
|
|
In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments). |
|
""" |
|
|
|
|
|
# Training configuration for Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # 16 x 4 = 64 examples per device per optimizer step
    learning_rate=1e-5,
    warmup_steps=30,
    max_steps=500,
    gradient_checkpointing=True,
    # fp16 mixed precision needs a CUDA device; fall back to fp32 when none is visible
    # instead of erroring out.
    fp16=torch.cuda.is_available(),
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # With load_best_model_at_end, save_steps must be a multiple of eval_steps.
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=True,
    # Use a real boolean: a string such as "False" is truthy and would enable overwriting
    # the output directory.
    overwrite_output_dir=False,
    # `resume_from_checkpoint` is intentionally not set: TrainingArguments only stores it
    # (as a checkpoint folder path, so "True" would be read as a directory name). Whether to
    # resume is decided where trainer.train() is called.
)
|
|
|
|
|
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=fleurs["train"],
    eval_dataset=fleurs["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

"""We'll save the processor object once before starting training. Since the processor is not trainable, it won't change over the course of training:"""

processor.save_pretrained(training_args.output_dir)

# Train exactly once. Resume from the newest checkpoint-* folder in output_dir if there is
# one, otherwise start fresh: `resume_from_checkpoint=True` raises when output_dir has no
# checkpoint, and a second bare `trainer.train()` call would restart the schedule and train
# the model a second time.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = get_last_checkpoint(training_args.output_dir)
trainer.train(resume_from_checkpoint=last_checkpoint)
|
|
|
"""We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate key-word arguments (kwargs):""" |
|
|
|
# Model-card metadata forwarded to trainer.push_to_hub().
kwargs = dict(
    dataset_tags="google/fleurs",
    dataset="google/fleurs",
    language="ps",
    model_name="Whisper Small Pashto - Augmented",
    finetuned_from="openai/whisper-small",
    tasks="automatic-speech-recognition",
    tags="whisper-event",
)
|
|
|
"""The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command and save the preprocessor object we created:""" |
|
|
|
# Upload the trainer's output to the Hub; `kwargs` carries the model-card metadata.
trainer.push_to_hub(
    **kwargs,
)
|
|