{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "ad22981d-cabf-44bc-868c-0f70f0e20ca1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from datasets import Audio, interleave_datasets, IterableDataset, IterableDatasetDict, load_dataset\n", "from transformers import WhisperProcessor\n", "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "from typing import List, Optional" ] }, { "cell_type": "code", "execution_count": 2, "id": "92047288-08ba-4e14-b5e6-46a1bacd6b3f", "metadata": {}, "outputs": [], "source": [ "def load_multiple_streaming_datasets(\n", " dataset_names: List,\n", " dataset_config_names: List,\n", " splits: Optional[List] = None,\n", " text_column_names: Optional[List] = None,\n", " sampling_rate: Optional[int] = 16000,\n", " stopping_strategy: Optional[str] = \"all_exhausted\",\n", " **kwargs\n", ") -> IterableDataset:\n", "\n", " if len(dataset_names) != len(dataset_config_names):\n", " raise ValueError(\n", " f\"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(dataset_config_names)} configs.\"\n", " )\n", "\n", " if splits is not None and len(splits) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits.\"\n", " )\n", "\n", " if text_column_names is not None and len(text_column_names) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(text_column_names)} text column names.\"\n", " )\n", "\n", " splits = splits if splits is not None else [\"train\" for i in range(len(dataset_names))]\n", " text_column_names = (\n", " text_column_names if text_column_names is not None else [\"text\" for i in range(len(dataset_names))]\n", " )\n", "\n", " all_datasets = []\n", " # iterate over the datasets we want to interleave\n", " for i, dataset_name in enumerate(dataset_names):\n", " dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)\n", " # resample to specified sampling rate\n", " dataset = dataset.cast_column(\"audio\", Audio(sampling_rate))\n", " # normalise columns to [\"audio\", \"sentence\"]\n", " if text_column_names[i] != \"sentence\":\n", " dataset = dataset.rename_column(text_column_names[i], \"sentence\")\n", " dataset = dataset.remove_columns(set(dataset.features.keys()) - set([\"audio\", \"sentence\"]))\n", " all_datasets.append(dataset)\n", "\n", " interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)\n", " return interleaved_dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "8e974feb-3a0f-48fb-aaa8-f88bcdd60844", "metadata": {}, "outputs": [], "source": [ "dataset_names = [\"mozilla-foundation/common_voice_11_0\", \"google/fleurs\", \"openslr\", \"collectivat/tv3_parla\", \"projecte-aina/parlament_parla\", \"projecte-aina/parlament_parla\"]\n", "dataset_config_names = [\"ca\", \"ca_es\", \"SLR69\", \"ca\", \"clean\", \"other\"]\n", "text_column_names = [\"sentence\", \"transcription\", \"sentence\", \"text\", \"sentence\", \"sentence\"]" ] 
}, { "cell_type": "code", "execution_count": 4, "id": "ba881a5f-46a4-4475-b01c-5b3c18814118", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 8.30k/8.30k [00:00<00:00, 9.93MB/s]\n", "Downloading readme: 100%|██████████| 12.2k/12.2k [00:00<00:00, 14.5MB/s]\n", "Downloading extra modules: 100%|██████████| 3.44k/3.44k [00:00<00:00, 6.04MB/s]\n", "Downloading extra modules: 100%|██████████| 60.9k/60.9k [00:00<00:00, 32.4MB/s]\n", "Downloading builder script: 100%|██████████| 12.8k/12.8k [00:00<00:00, 14.6MB/s]\n", "Downloading readme: 100%|██████████| 11.2k/11.2k [00:00<00:00, 16.5MB/s]\n", "Downloading builder script: 100%|██████████| 26.9k/26.9k [00:00<00:00, 32.3MB/s]\n", "Downloading metadata: 100%|██████████| 210k/210k [00:00<00:00, 36.4MB/s]\n", "Downloading readme: 100%|██████████| 42.9k/42.9k [00:00<00:00, 26.2MB/s]\n", "Downloading builder script: 100%|██████████| 3.98k/3.98k [00:00<00:00, 7.21MB/s]\n", "Downloading readme: 100%|██████████| 5.15k/5.15k [00:00<00:00, 6.86MB/s]\n", "Using custom data configuration ca\n", "Downloading builder script: 100%|██████████| 5.13k/5.13k [00:00<00:00, 6.95MB/s]\n", "Downloading readme: 100%|██████████| 8.64k/8.64k [00:00<00:00, 10.7MB/s]\n" ] } ], "source": [ "trainset = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, use_auth_token=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": "1c91f4d8-2c4d-4324-9e71-960cc465c2e9", "metadata": {}, "outputs": [], "source": [ "testset = IterableDataset\n", "testset = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ca\", split=\"test\", streaming=True, use_auth_token=True)\n", "testset = testset.cast_column(\"audio\", Audio(sampling_rate=16000))" ] }, { "cell_type": "code", "execution_count": 6, "id": "b09ec5fd-a430-4f89-b56d-6187b612e80c", "metadata": {}, "outputs": [], "source": [ "COLUMNS_TO_KEEP = [\"sentence\", \"audio\"]\n", "all_columns = testset.features\n", "columns_to_remove = set(all_columns) - set(COLUMNS_TO_KEEP)\n", "\n", "testset = testset.remove_columns(columns_to_remove)" ] }, { "cell_type": "code", "execution_count": 7, "id": "564d0e67-dc41-4f32-ae23-836686dfa6a1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainset.features" ] }, { "cell_type": "code", "execution_count": 8, "id": "a64386dc-4a17-4c55-a87c-b5b2e6033a4f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testset.features" ] }, { "cell_type": "code", "execution_count": 9, "id": "02745170-1967-4a4b-9bbc-1723ed31cf33", "metadata": {}, "outputs": [], "source": [ "do_lower_case = True\n", "do_remove_punctuation = True\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 11, "id": "1018c0ad-4a00-43d5-b67f-9eed1e3f759d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: 100%|██████████| 185k/185k [00:00<00:00, 38.7MB/s]\n", "Downloading: 100%|██████████| 828/828 [00:00<00:00, 1.21MB/s]\n", "Downloading: 100%|██████████| 
1.04M/1.04M [00:00<00:00, 61.5MB/s]\n", "Downloading: 100%|██████████| 494k/494k [00:00<00:00, 51.6MB/s]\n", "Downloading: 100%|██████████| 52.7k/52.7k [00:00<00:00, 26.2MB/s]\n", "Downloading: 100%|██████████| 2.11k/2.11k [00:00<00:00, 4.13MB/s]\n", "Downloading: 100%|██████████| 2.06k/2.06k [00:00<00:00, 3.02MB/s]\n" ] } ], "source": [ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", language=\"Catalan\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "5a4d6c05-dcd2-480c-b737-838877bcfd45", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"sentence\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 13, "id": "dfeac3fb-0b3a-42e2-a8fd-87975327e68e", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = trainset.map(prepare_dataset).with_format(\"torch\")\n", "vectorized_testset = testset.map(prepare_dataset).with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "2a4223b6-7a50-49c2-9787-b69809279d76", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.shuffle( buffer_size=500,seed=0,)\n", "vectorized_testset = vectorized_testset.shuffle( buffer_size=500,seed=0,)" ] }, { "cell_type": "code", "execution_count": 15, "id": "cadb0815-739e-408b-895e-5d0d8ecd00d2", "metadata": {}, "outputs": [], "source": [ "MAX_DURATION_IN_SECONDS = 30.0\n", "\n", "def is_audio_length_in_range(input_length):\n", " return input_length < MAX_DURATION_IN_SECONDS" ] }, { "cell_type": "code", "execution_count": 16, "id": "3185083b-8b2b-4832-8ac3-3efcb291019f", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])\n", "vectorized_testset = vectorized_testset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])" ] }, { "cell_type": "code", "execution_count": 17, "id": "5b7a8dfc-617a-44e0-88f0-f39df95faf75", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " # split inputs and labels since they have to be of different lengths and need different padding methods\n", " # first treat the audio inputs by simply returning torch tensors\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # get the tokenized label sequences\n", " 
label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the labels to max length\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replace padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's append later anyways\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "a37c4465-5216-4ffc-ad22-46714003f39f", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }, { "cell_type": "code", "execution_count": 19, "id": "230728cd-ef98-4456-9a4b-05cd11ebce0e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 5.85MB/s]\n" ] } ], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"wer\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "ad8a0ecf-2684-4e11-a031-26e8904b310f", "metadata": {}, "outputs": [], "source": [ "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": 21, "id": "cabcfd51-4fcc-4a5f-a180-da9fa35573bd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: 100%|██████████| 1.96k/1.96k [00:00<00:00, 2.74MB/s]\n", "Downloading: 100%|██████████| 151M/151M [00:01<00:00, 90.8MB/s] \n" ] } ], "source": [ "from transformers import WhisperForConditionalGeneration\n", "\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")" ] }, { "cell_type": "code", "execution_count": 22, "id": "4d802a49-4739-4e30-8273-2b7352805acb", "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": 27, "id": "ad620be1-0fcc-4749-9447-fa680140e6af", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "PyTorch: setting up devices\n" ] } ], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./whisper-tiny-ca/\",\n", " per_device_train_batch_size=32,\n", " gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size\n", " learning_rate=1e-5,\n", " warmup_steps=100,\n", " max_steps=1000,\n", " gradient_checkpointing=True,\n", " 
fp16=True,\n", " evaluation_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=1000,\n", " eval_steps=1000,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"wer\",\n", " greater_is_better=False,\n", " push_to_hub=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 28, "id": "56f59a72-af0d-43b1-bd20-632c163c09d0", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainerCallback\n", "from transformers.trainer_pt_utils import IterableDatasetShard\n", "from torch.utils.data import IterableDataset\n", "\n", "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n", "class ShuffleCallback(TrainerCallback):\n", " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n", " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n", " pass # set_epoch() is handled by the Trainer\n", " elif isinstance(train_dataloader.dataset, IterableDataset):\n", " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)" ] }, { "cell_type": "code", "execution_count": 29, "id": "9076d20d-7a6d-45d8-b386-cffab6f6920a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "max_steps is given, it will override any value given in num_train_epochs\n", "Using cuda_amp half precision backend\n" ] } ], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=vectorized_trainset,\n", " eval_dataset=vectorized_testset,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor,\n", " callbacks=[ShuffleCallback()],\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "id": "af6bb835-8a62-4c13-a20d-3e85df12ab7d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Configuration saved in ./whisper-tiny-ca/config.json\n", "Model weights saved in ./whisper-tiny-ca/pytorch_model.bin\n", "Feature extractor saved in ./whisper-tiny-ca/preprocessor_config.json\n", "tokenizer config file saved in ./whisper-tiny-ca/tokenizer_config.json\n", "Special tokens file saved in ./whisper-tiny-ca/special_tokens_map.json\n", "added tokens file saved in ./whisper-tiny-ca/added_tokens.json\n" ] } ], "source": [ "model.save_pretrained(training_args.output_dir)\n", "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": 31, "id": "8f841688-696e-408c-b3a4-d0f7163efc33", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 64000\n", " Num Epochs = 9223372036854775807\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. 
parallel, distributed & accumulation) = 64\n", " Gradient Accumulation steps = 2\n", " Total optimization steps = 1000\n", " Number of trainable parameters = 37760640\n", "Reading metadata...: 905243it [00:16, 56138.00it/s]\n", "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: sentence, input_length, audio. If sentence, input_length, audio are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n" ] }, { "data": { "text/html": [ "\n", "
    <table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: left;\">\n", "      <th>Step</th>\n", "      <th>Training Loss</th>\n", "      <th>Validation Loss</th>\n", "      <th>Wer</th>\n", "    </tr>\n", "  </thead>\n",
"  <tbody>\n", "    <tr>\n", "      <td>1000</td>\n", "      <td>0.531500</td>\n", "      <td>0.676937</td>\n", "      <td>41.100362</td>\n", "    </tr>\n", "  </tbody>\n", "</table>"
],
"text/plain": [
"