# whisper_small_ps_augmented / whisper_small_ps_augmented.py (author: ihanif)
# Page metadata at capture time: commit 9f0a12a, "Training in progress, step 300", 10.8 kB
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from audiomentations import Compose, TimeStretch, PitchShift
from datasets import Audio
from datasets import load_dataset, DatasetDict
import jiwer
import warnings
import pandas as pd
from io import StringIO
from datasets import Dataset, IterableDatasetDict, load_dataset, interleave_datasets, Audio
import evaluate
import torch
import string
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from transformers import WhisperForConditionalGeneration
from transformers import WhisperProcessor
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperTokenizer
from transformers import WhisperFeatureExtractor
#import wandb
#from IPython.display import clear_output
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
import numpy as np
#from huggingface_hub import notebook_login
from transformers import TrainerCallback
from transformers.integrations import WandbCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset
from datasets import load_dataset, Audio
from pathlib import Path
import numpy as np
#import holoviews as hv
#import panel as pn
import tempfile
#from bokeh.resources import INLINE
#hv.extension("bokeh", logo=False)
#warnings.filterwarnings('ignore')
#clear_output()
# Report GPU availability instead of computing it and throwing the result away
# (leftover notebook cell). The training arguments below enable fp16, which
# needs a CUDA device, so this is worth seeing before the long data prep runs.
print("CUDA available:", torch.cuda.is_available())
"""## Load Dataset
Loading the Pashto (ps_af) dataset from FLEURS.
Combine train and validation set.
"""
# notebook_login()
fleurs = DatasetDict()
# Training data = FLEURS train + validation; the test split is kept for evaluation.
fleurs["train"] = load_dataset(
    "google/fleurs", "ps_af", split="train+validation", use_auth_token=True)
fleurs["test"] = load_dataset(
    "google/fleurs", "ps_af", split="test", use_auth_token=True)
# Drop metadata columns. The remaining ones are presumably just "audio" and
# "transcription", which are what the map functions below read (TODO confirm
# against the FLEURS schema).
fleurs = fleurs.remove_columns(
    ["id", "num_samples", "path", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])
print(fleurs)
# Whisper front-end components, all taken from the same base checkpoint and
# configured for Pashto transcription.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small",
    language="Pashto",
    task="transcribe",
)
"""### Combine To Create A WhisperProcessor"""
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="Pashto",
    task="transcribe",
)
"""### Prepare Data"""
# Decode (and resample if needed) every clip at 16 kHz.
fleurs = fleurs.cast_column("audio", Audio(sampling_rate=16000))
# Random waveform augmentation: each transform fires independently with p=0.3.
augment_waveform = Compose(
    transforms=[
        TimeStretch(
            min_rate=0.8,
            max_rate=1.25,
            leave_length_unchanged=False,
            p=0.3,
        ),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.3),
    ]
)
def augment_dataset(batch):
    """Apply the random waveform augmentation to one training example.

    Runs the module-level ``augment_waveform`` pipeline on the decoded audio
    array and writes the augmented samples back into the same example.

    Args:
        batch: one example as passed by ``Dataset.map``; ``batch["audio"]`` is
            expected to be a decoded audio dict with ``"array"`` and
            ``"sampling_rate"`` keys (the column is cast to ``Audio`` earlier
            in this script).

    Returns:
        The same ``batch`` object, with ``batch["audio"]["array"]`` replaced.
    """
    audio = batch["audio"]
    # Read the clip's real sampling rate instead of hard-coding 16 kHz, so the
    # augmentation stays correct if the Audio cast ever changes. This matches
    # how prepare_dataset obtains the rate.
    audio["array"] = augment_waveform(
        samples=audio["array"], sample_rate=audio["sampling_rate"]
    )
    return batch
print('Augment train set:')
# The augmentation is materialised once by .map over the training split only;
# the test split is left untouched.
fleurs['train'] = fleurs['train'].map(function=augment_dataset, num_proc=10)
"""We can apply the data preparation function to all of our training examples using dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially."""
# Transcription clean-up switches read by prepare_dataset().
do_lower_case, do_remove_punctuation = True, True
# Text normaliser shared by data preparation and metric computation.
normalizer = BasicTextNormalizer()
def prepare_dataset(batch):
    """Convert one decoded example into Whisper training inputs.

    Adds three fields to ``batch`` and returns it:
      * ``input_features``: log-Mel features from the processor's feature extractor;
      * ``input_length``: clip duration in seconds (read later by the length filter);
      * ``labels``: token ids of the (optionally cleaned) transcription.

    Lower-casing and punctuation removal are controlled by the module-level
    ``do_lower_case`` and ``do_remove_punctuation`` flags.
    """
    audio = batch["audio"]
    waveform = audio["array"]
    rate = audio["sampling_rate"]

    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]
    batch["input_length"] = len(waveform) / rate

    text = batch["transcription"]
    text = text.lower() if do_lower_case else text
    if do_remove_punctuation:
        text = normalizer(text).strip()

    batch["labels"] = processor.tokenizer(text).input_ids
    return batch
print('Extract features and normalize data:')
# Replace the raw columns of every split with model inputs (both splits share
# the schema left after remove_columns, so the train column list is used),
# then expose the result as torch tensors for the collator.
fleurs = fleurs.map(
    function=prepare_dataset,
    num_proc=10,
    remove_columns=fleurs.column_names['train'],
).with_format('torch')
"""Finally, we filter any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature-extractor which could affect the stability of training. We define a function that returns `True` for samples that are less than 30s, and `False` for those that are longer:"""
max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when a clip of ``length`` seconds is shorter than ``max_input_length``.

    Used as the ``.filter`` predicate on the "input_length" column; a clip of
    exactly ``max_input_length`` seconds is rejected.
    """
    return max_input_length > length
"""We apply our filter function to all samples of our training dataset through 🤗 Datasets' `.filter` method:"""
# Keep only training clips shorter than max_input_length, then shuffle the
# survivors with a fixed seed so runs are reproducible.
fleurs['train'] = (
    fleurs['train']
    .filter(is_audio_in_length_range, input_columns=["input_length"])
    .shuffle(seed=42, writer_batch_size=100)
)
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate prepared Whisper examples into one padded training batch.

    Each feature is a dict with ``"input_features"`` (from ``prepare_dataset``)
    and ``"labels"`` (token ids). Audio features are padded by the processor's
    feature extractor and labels by its tokenizer. Label padding is replaced by
    ``-100`` so the loss ignores it.

    Attributes:
        processor: object exposing ``feature_extractor`` and ``tokenizer``
            (a ``WhisperProcessor`` in this script).
        decoder_start_token_id: id the model prepends to the decoder inputs.
            When ``None``, it is looked up from the tokenizer as
            ``<|startoftranscript|>``.
    """

    processor: Any
    # Optional so existing calls that pass only ``processor`` keep working.
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` into a batch.

        Args:
            features: examples with ``"input_features"`` and ``"labels"`` entries.

        Returns:
            The feature-extractor batch with an added ``"labels"`` tensor in
            which padded positions are ``-100``.
        """
        # Inputs and labels have different lengths and padding rules, so they
        # are padded separately: audio features first, as torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding with -100 so the loss ignores those positions.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        start_token_id = self.decoder_start_token_id
        if start_token_id is None:
            start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        # The tokenizer prefixes every label sequence with the start-of-transcript
        # token, and the model prepends the decoder start token again when it
        # shifts the labels right. For Whisper, ``bos_token`` is <|endoftext|>, so
        # the old ``bos_token_id`` comparison never matched and the start token
        # ended up duplicated. Strip it only when every row starts with it.
        if labels.shape[-1] > 0 and (labels[:, 0] == start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
"""Let's initialise the data collator we've just defined:"""
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)
"""### Evaluation Metrics
We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:
"""
# Load both scorers from 🤗 Evaluate: word error rate first, then character error rate.
wer_metric, cer_metric = evaluate.load("wer"), evaluate.load("cer")
# When set, compute_metrics scores normalised text rather than raw decodes.
do_normalize_eval = True
def compute_metrics(pred):
    """Compute WER and CER, as percentages, for generated predictions.

    Args:
        pred: the evaluation prediction object passed by the trainer.
            ``pred.predictions`` holds generated token ids and ``pred.label_ids``
            holds reference ids, with ``-100`` marking ignored positions.

    Returns:
        dict with keys ``"wer"`` and ``"cer"``, each scaled by 100.
    """
    pad_token_id = processor.tokenizer.pad_token_id
    # -100 is not a real token id and cannot be decoded, so map it to the pad
    # token first. np.where builds new arrays rather than overwriting the
    # trainer's ``pred.label_ids`` in place. Predictions are cleaned too,
    # because the trainer may pad them with -100 when concatenating eval
    # batches of different lengths.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_token_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(text) for text in pred_str]
        label_str = [normalizer(text) for text in label_str]

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}
"""### Load a Pre-Trained Checkpoint """
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
"""Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)). Set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible:"""
# Apply all config overrides in one call; each key is set via setattr.
# NOTE(review): recent transformers versions read generation settings from
# model.generation_config. Confirm that these overrides take effect for the
# installed version.
model.config.update(
    {
        "forced_decoder_ids": None,
        "suppress_tokens": [],
        "use_cache": False,
    }
)
"""### Define the Training Configuration
In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).
"""
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=16,
    # increase by 2x for every 2x decrease in batch size
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    warmup_steps=30,
    max_steps=500,
    # Gradient checkpointing is the reason model.config.use_cache is set to False.
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    # Keep the checkpoint with the lowest WER ("wer" is a key returned by compute_metrics).
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    #optim='adamw_bnb_8bit', # 'adamw_bnb_8bit',
    # Must be a real bool: the previous string "False" is truthy.
    overwrite_output_dir=False,
    # resume_from_checkpoint is deliberately not set here. In TrainingArguments
    # it is a checkpoint *path* meant for user scripts (the Trainer does not
    # read it), so the former value "True" named a folder called "True".
    # Resumption is requested through trainer.train(resume_from_checkpoint=...).
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=fleurs['train'],
    eval_dataset=fleurs['test'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
"""We'll save the processor object once before starting training. Since the processor is not trainable, it won't change over the course of training:"""
processor.save_pretrained(training_args.output_dir)

# Train exactly once. trainer.train(resume_from_checkpoint=True) raises when the
# output directory holds no checkpoint yet, which breaks a fresh run. So look up
# the latest checkpoint explicitly and pass None to start from the pre-trained
# weights when there is none. (The old script also called trainer.train() a
# second time, which repeated the whole run.)
# Local import keeps this fix self-contained; transformers is already a dependency.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if Path(training_args.output_dir).is_dir():
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
trainer.train(resume_from_checkpoint=last_checkpoint)

"""We can label our checkpoint with the `whisper-event` tag on push by setting the appropriate key-word arguments (kwargs):"""
kwargs = {
    "dataset_tags": "google/fleurs",
    "dataset": "google/fleurs",  # a 'pretty' name for the training dataset
    "language": "ps",
    "model_name": "Whisper Small Pashto - Augmented",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}
"""The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command; the processor files were already saved to the output directory before training:"""
trainer.push_to_hub(**kwargs)