from datasets import Audio, IterableDatasetDict, interleave_datasets, load_dataset

# Candidate corpora for Danish ASR, their configs, and the transcript column each uses
# (Common Voice uses "sentence", FLEURS uses "transcription"). Only Common Voice is
# streamed below.
dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
dataset_config_names = ["da", "da_dk"]
text_column_names = ["sentence", "normalized_text", "text", "transcription"]

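# Helper that streams one or more splits of a dataset; a split like "train+validation"
# loads both splits and interleaves them into a single streaming dataset.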
def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
    if "+" in split:
        dataset_splits = [
            load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
            for split_name in split.split("+")
        ]
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
        return dataset

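# Load Danish Common Voice 11 in streaming mode: train+validation for training, test for evaluation.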
raw_datasets = IterableDatasetDict()

raw_datasets["train"] = load_streaming_dataset(
    "mozilla-foundation/common_voice_11_0", "da", split="train+validation", use_auth_token=True
)
raw_datasets["test"] = load_streaming_dataset(
    "mozilla-foundation/common_voice_11_0", "da", split="test", use_auth_token=True
)

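# The processor bundles the feature extractor (audio -> log-Mel features) and the tokenizer (text -> label ids).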
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Danish", task="transcribe")

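# Whisper expects 16kHz audio; casting the column resamples each example on the fly as it is streamed.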
raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))

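# Optional transcript pre-processing: lower-casing and punctuation removal. Both are
# disabled here, since Whisper is trained to predict cased, punctuated text.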
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

do_lower_case = False
do_remove_punctuation = False

normalizer = BasicTextNormalizer()

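# Pre-process a single example: compute log-Mel input features from the audio,
# record its duration, and encode the transcript to label ids.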
def prepare_dataset(batch):
    # load and (possibly) resample the audio data to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from the input audio array
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]

    # compute the input length of the audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # optional pre-processing steps
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    # encode the target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch

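# Apply the pre-processing to every example, dropping the raw columns, and return torch tensors.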
vectorized_datasets = raw_datasets.map(
    prepare_dataset, remove_columns=list(next(iter(raw_datasets.values())).features)
).with_format("torch")

# shuffle the training split with a buffer so streamed examples are not seen in corpus order
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
    buffer_size=500,
    seed=0,
)

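# Whisper's encoder operates on 30-second inputs, so filter out any training examples longer than that.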
max_input_length = 30.0


def is_audio_in_length_range(length):
    return length < max_input_length


vectorized_datasets["train"] = vectorized_datasets["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

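# Data collator: pads the audio features and the label ids independently, and masks
# padding in the labels with -100 so it is ignored by the loss.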
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods: first pad the audio inputs
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # pad the tokenized label sequences to the longest in the batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so these tokens are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was appended in the tokenization step, cut it here,
        # since it is appended again during training
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

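# Evaluation metric: word error rate (WER), the standard metric for speech recognition.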
import evaluate

metric = evaluate.load("wer")

do_normalize_eval = True

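# Compute WER on the generated predictions, un-masking the -100 labels first. With
# do_normalize_eval, predictions and references are both normalised before scoring,
# giving normalised rather than orthographic WER.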
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id so the labels can be decoded
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

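# Load the pre-trained checkpoint and adapt its generation config for fine-tuning.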
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# do not force any tokens as decoder outputs and do not suppress any tokens during generation
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# disable the cache during training since it is incompatible with gradient checkpointing
model.config.use_cache = False

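# Training configuration. Evaluation runs every 1000 steps, and the best checkpoint
# (lowest WER) is reloaded at the end of training.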
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,  # the value typically used for fine-tuning whisper-small
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

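# Callback to re-shuffle the streaming training dataset at the start of each epoch.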
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset


class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

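# Build the trainer. Passing the processor as `tokenizer` ensures it is saved alongside each checkpoint.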
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)

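# Save the model and processor to the output directory once before training starts.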
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

trainer.train()

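# Metadata for the auto-generated model card on the Hugging Face Hub.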
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0, FLEURS",
    "language": "da",
    "model_name": "Whisper Small da - Common Voice+FLEURS",
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}

trainer.push_to_hub(**kwargs)