{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 3 }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install datasets==1.4.1\n", "!pip install transformers==4.4.0\n", "!pip install torchaudio\n", "!pip install librosa\n", "!pip install jiwer\n", "!pip install mecab-python3\n", "!pip install unidic-lite\n", "!pip install audiomentations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor\n", "from datasets import load_dataset, load_metric, ClassLabel, Dataset\n", "from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift, TimeStretch, Shift\n", "from torch.optim.lr_scheduler import LambdaLR\n", "from transformers import Wav2Vec2ForCTC, TrainingArguments, Trainer\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import soundfile as sf\n", "import re\n", "import json\n", "import torchaudio\n", "import librosa\n", "import datasets\n", "import MeCab\n", "import pykakasi\n", "import random\n", "\n", "import torch\n", "from dataclasses import dataclass, field\n", "from typing import Any, Dict, List, Optional, Union" ] }, { "source": [ "# Load dataset and prepare processor" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the JSUT public dataset from the University of Tokyo\n", "!wget http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip\n", "!unzip jsut_ver1.1.zip\n", "\n", "path = 'jsut_ver1.1/basic5000/'\n", "df = pd.read_csv(path + 'transcript_utf8.txt', header=None, delimiter=\":\", names=[\"path\", \"sentence\"], index_col=False)\n", "df[\"path\"] = df[\"path\"].map(lambda x: path + 'wav/' + x + \".wav\")\n", "df.head()\n", "\n", "jsut_voice_train = Dataset.from_pandas(df)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Import the Common Voice training dataset\n", "common_voice_train = load_dataset('common_voice', 'ja', split='train+validation')\n", "common_voice_test = load_dataset('common_voice', 'ja', split='test')\n", "\n", "# Remove unwanted columns\n", "common_voice_train = common_voice_train.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n", "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n", "\n", "# Concatenate Common Voice and the JSUT dataset\n", "common_voice_train = datasets.concatenate_datasets([jsut_voice_train, common_voice_train])" ] },
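{ "source": [ "Optional sanity check (added as an illustration, not part of the original pipeline): confirm how many utterances the merged training set contains and look at one raw transcript before normalization." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sanity check (added): inspect the merged training set before text normalization.\n", "# Assumes the JSUT and Common Voice cells above ran successfully.\n", "print(\"train utterances:\", len(common_voice_train))\n", "print(\"test utterances:\", len(common_voice_test))\n", "\n", "sample = common_voice_train[random.randint(0, len(common_voice_train) - 1)]\n", "print(\"path:\", sample[\"path\"])\n", "print(\"sentence:\", sample[\"sentence\"])" ] },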
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Parse Japanese sentences into space-separated words. Ex: \"pythonが大好きです\" -> \"python が 大好き です\"\n", "wakati = MeCab.Tagger(\"-Owakati\")\n", "\n", "# Characters to strip from the transcripts\n", "chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n", "\n", "def remove_special_characters(batch):\n", "    batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n", "    batch[\"sentence\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).strip()\n", "    return batch\n", "\n", "common_voice_train = common_voice_train.map(remove_special_characters)\n", "common_voice_test = common_voice_test.map(remove_special_characters)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# build the vocabulary file\n", "def extract_all_chars(batch):\n", "    all_text = \" \".join(batch[\"sentence\"])\n", "    vocab = list(set(all_text))\n", "    return {\"vocab\": [vocab], \"all_text\": [all_text]}\n", "\n", "# extract the character vocabulary and full text\n", "vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)\n", "vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)\n", "\n", "# merge the vocabularies from the train and test sets\n", "vocab_list = list(set(vocab_train[\"vocab\"][0]) | set(vocab_test[\"vocab\"][0]))\n", "vocab_dict = {v: k for k, v in enumerate(vocab_list)}\n", "print(len(vocab_dict))\n", "vocab_dict[\"|\"] = vocab_dict[\" \"]\n", "del vocab_dict[\" \"]\n", "\n", "# create unk and pad tokens\n", "vocab_dict[\"[UNK]\"] = len(vocab_dict)\n", "vocab_dict[\"[PAD]\"] = len(vocab_dict)\n", "\n", "# save to json file\n", "with open('vocab.json', 'w') as vocab_file:\n", "    json.dump(vocab_dict, vocab_file, indent=2, ensure_ascii=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_dir = \"./output_models\"\n", "# wrap the tokenizer and feature extractor into a processor\n", "tokenizer = Wav2Vec2CTCTokenizer(\"./vocab.json\", unk_token=\"[UNK]\", pad_token=\"[PAD]\", word_delimiter_token=\"|\")\n", "feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)\n", "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n", "processor.save_pretrained(save_dir)" ] }, { "source": [ "# Prepare train and test dataset" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# resample audio from 48kHz to 16kHz (the sampling rate expected by wav2vec 2.0)\n", "def speech_file_to_array_fn(batch):\n", "    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n", "    batch[\"speech\"] = librosa.resample(np.asarray(speech_array[0].numpy()), 48_000, 16_000)\n", "    batch[\"sampling_rate\"] = 16_000\n", "    batch[\"target_text\"] = batch[\"sentence\"]\n", "    return batch\n", "\n", "common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names, num_proc=4)\n", "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names, num_proc=4)" ] },
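{ "source": [ "Optional check (added as an illustration, not part of the original pipeline): verify that the resampled arrays are 16 kHz and look at the duration of one clip. The column names (`speech`, `sampling_rate`, `target_text`) are the ones created by `speech_file_to_array_fn` above." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check (added): confirm the resampled example layout produced by\n", "# speech_file_to_array_fn before augmentation is applied.\n", "example = common_voice_train[0]\n", "speech = np.asarray(example[\"speech\"])\n", "print(\"sampling rate:\", example[\"sampling_rate\"])\n", "print(\"num samples:\", speech.shape[0])\n", "print(\"duration (s): {:.2f}\".format(speech.shape[0] / example[\"sampling_rate\"]))\n", "print(\"target text:\", example[\"target_text\"])" ] },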
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# apply data augmentation to enrich the training set\n", "augment = Compose([\n", "    AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),\n", "    PitchShift(min_semitones=-1, max_semitones=1, p=0.8),\n", "    Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8),\n", "    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.8)\n", "])\n", "\n", "def augmented_speech(batch, augment):\n", "    samples = np.array(batch[\"speech\"])\n", "    batch[\"speech\"] = augment(samples=samples, sample_rate=16000)\n", "    batch[\"sampling_rate\"] = 16_000\n", "    batch[\"target_text\"] = batch[\"target_text\"]\n", "    return batch\n", "\n", "# augment 50% of the training set\n", "common_voice_train_augmented = common_voice_train.train_test_split(test_size=0.5)['train']\n", "common_voice_train_augmented = common_voice_train_augmented.map(lambda batch: augmented_speech(batch, augment), num_proc=4)\n", "\n", "# concatenate the augmented half with the original training set\n", "common_voice_train = datasets.concatenate_datasets([common_voice_train_augmented, common_voice_train])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", "    # check that all files have the correct sampling rate\n", "    assert (\n", "        len(set(batch[\"sampling_rate\"])) == 1\n", "    ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n", "\n", "    batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n", "\n", "    with processor.as_target_processor():\n", "        batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", "    return batch\n", "\n", "# prepare dataset\n", "common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)\n", "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)" ] }, { "source": [ "# Training" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create the data collator that dynamically pads inputs and labels\n", "@dataclass\n", "class DataCollatorCTCWithPadding:\n", "\n", "    processor: Wav2Vec2Processor\n", "    padding: Union[bool, str] = True\n", "    max_length: Optional[int] = None\n", "    max_length_labels: Optional[int] = None\n", "    pad_to_multiple_of: Optional[int] = None\n", "    pad_to_multiple_of_labels: Optional[int] = None\n", "\n", "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", "        input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n", "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", "\n", "        batch = self.processor.pad(\n", "            input_features,\n", "            padding=self.padding,\n", "            max_length=self.max_length,\n", "            pad_to_multiple_of=self.pad_to_multiple_of,\n", "            return_tensors=\"pt\",\n", "        )\n", "        with self.processor.as_target_processor():\n", "            labels_batch = self.processor.pad(\n", "                label_features,\n", "                padding=self.padding,\n", "                max_length=self.max_length_labels,\n", "                pad_to_multiple_of=self.pad_to_multiple_of_labels,\n", "                return_tensors=\"pt\",\n", "            )\n", "\n", "        # replace padding with -100 so it is ignored by the loss\n", "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", "        batch[\"labels\"] = labels\n", "\n", "        return batch\n", "\n", "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)" ] },
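{ "source": [ "Quick illustration (added, not in the original notebook): run `data_collator` on two prepared examples to see the padded batch it will feed the model. Label padding positions are set to -100 so the CTC loss ignores them." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative only (added): pad two prepared examples into a batch with the\n", "# collator defined above and print the resulting tensor shapes.\n", "demo_features = [common_voice_train[i] for i in range(2)]\n", "demo_batch = data_collator(demo_features)\n", "print(\"input_values:\", demo_batch[\"input_values\"].shape)\n", "print(\"labels:\", demo_batch[\"labels\"].shape)\n", "print(\"label padding uses -100:\", bool((demo_batch[\"labels\"] == -100).any()))" ] },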
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define the WER metric\n", "wer_metric = load_metric(\"wer\")\n", "\n", "def compute_metrics(pred):\n", "    pred_logits = pred.predictions\n", "    pred_ids = np.argmax(pred_logits, axis=-1)\n", "\n", "    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", "    pred_str = processor.batch_decode(pred_ids)\n", "    # we do not want to group tokens when computing the metrics\n", "    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n", "\n", "    wer = wer_metric.compute(predictions=pred_str, references=label_str)\n", "\n", "    return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a custom learning rate scheduler\n", "\n", "# polynomial decay with warmup\n", "def get_polynomial_decay_schedule_with_warmup(\n", "    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.2, last_epoch=-1\n", "):\n", "\n", "    lr_init = optimizer.defaults[\"lr\"]\n", "    assert lr_init > lr_end, f\"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})\"\n", "\n", "    def lr_lambda(current_step: int):\n", "        if current_step < num_warmup_steps:\n", "            return float(current_step) / float(max(1, num_warmup_steps))\n", "        elif current_step > num_training_steps:\n", "            return lr_end / lr_init  # as LambdaLR multiplies by lr_init\n", "        else:\n", "            lr_range = lr_init - lr_end\n", "            decay_steps = num_training_steps - num_warmup_steps\n", "            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n", "            decay = lr_range * pct_remaining ** power + lr_end\n", "            return decay / lr_init  # as LambdaLR multiplies by lr_init\n", "\n", "    return LambdaLR(optimizer, lr_lambda, last_epoch)\n", "\n", "# wrap the custom scheduler into the Trainer\n", "class PolyTrainer(Trainer):\n", "    def __init__(self, *args, **kwargs):\n", "        super().__init__(*args, **kwargs)\n", "\n", "    def create_scheduler(self, num_training_steps: int):\n", "        self.lr_scheduler = get_polynomial_decay_schedule_with_warmup(\n", "            self.optimizer,\n", "            num_warmup_steps=self.args.warmup_steps,\n", "            num_training_steps=num_training_steps)\n", "\n", "    def create_optimizer_and_scheduler(self, num_training_steps: int):\n", "        self.create_optimizer()\n", "        self.create_scheduler(num_training_steps)" ] },
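{ "source": [ "Small preview (added for illustration): drive the custom polynomial-decay schedule defined above with a dummy optimizer to see the warmup and decay behaviour. The step counts used here are arbitrary; the real schedule is created by `PolyTrainer` during training." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative only (added): preview the custom schedule on a dummy optimizer.\n", "dummy_param = torch.nn.Parameter(torch.zeros(1))\n", "dummy_optimizer = torch.optim.AdamW([dummy_param], lr=1e-4)\n", "dummy_scheduler = get_polynomial_decay_schedule_with_warmup(\n", "    dummy_optimizer, num_warmup_steps=1500, num_training_steps=10000\n", ")\n", "\n", "for step in range(10001):\n", "    if step in (0, 750, 1500, 5000, 10000):\n", "        print(\"step {:>5d} -> lr {:.2e}\".format(step, dummy_scheduler.get_last_lr()[0]))\n", "    dummy_optimizer.step()\n", "    dummy_scheduler.step()" ] },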
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load the pretrained model\n", "model = Wav2Vec2ForCTC.from_pretrained(\n", "    \"facebook/wav2vec2-large-xlsr-53\",\n", "    attention_dropout=0.1,\n", "    hidden_dropout=0.1,\n", "    feat_proj_dropout=0.1,\n", "    mask_time_prob=0.1,\n", "    layerdrop=0.1,\n", "    gradient_checkpointing=True,\n", "    ctc_loss_reduction=\"mean\",\n", "    pad_token_id=processor.tokenizer.pad_token_id,\n", "    vocab_size=len(processor.tokenizer)\n", ")\n", "# freeze the convolutional feature extractor\n", "model.freeze_feature_extractor()\n", "\n", "# define training arguments\n", "training_args = TrainingArguments(\n", "    output_dir=save_dir,\n", "    group_by_length=True,\n", "    per_device_train_batch_size=32,\n", "    gradient_accumulation_steps=2,\n", "    evaluation_strategy=\"steps\",\n", "    num_train_epochs=200,\n", "    fp16=True,\n", "    save_steps=2400,\n", "    eval_steps=800,\n", "    logging_steps=800,\n", "    learning_rate=1e-4,\n", "    warmup_steps=1500,\n", "    save_total_limit=2,\n", "    load_best_model_at_end=True,\n", "    metric_for_best_model='wer',\n", "    greater_is_better=False\n", ")\n", "\n", "# wrap everything in the Trainer\n", "trainer = PolyTrainer(\n", "    model=model,\n", "    data_collator=data_collator,\n", "    args=training_args,\n", "    compute_metrics=compute_metrics,\n", "    train_dataset=common_voice_train,\n", "    eval_dataset=common_voice_test,\n", "    tokenizer=processor.feature_extractor,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# training\n", "train_result = trainer.train()" ] }, { "source": [ "# Testing result" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchaudio\n", "from datasets import load_dataset, load_metric\n", "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n", "import MeCab\n", "import pykakasi\n", "import re\n", "\n", "# config\n", "wakati = MeCab.Tagger(\"-Owakati\")\n", "chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n", "\n", "# load the fine-tuned model\n", "processor = Wav2Vec2Processor.from_pretrained(save_dir)\n", "test_model = Wav2Vec2ForCTC.from_pretrained(save_dir)\n", "test_model.to(\"cuda\")\n", "resampler = torchaudio.transforms.Resample(48_000, 16_000)\n", "\n", "# load test data\n", "test_dataset = load_dataset(\"common_voice\", \"ja\", split=\"test\")\n", "wer = load_metric(\"wer\")\n", "\n", "# Preprocessing: normalize the text and read the audio files as arrays\n", "def speech_file_to_array_fn(batch):\n", "    batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n", "    batch[\"sentence\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).strip()\n", "    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n", "    batch[\"speech\"] = resampler(speech_array).squeeze().numpy()\n", "    return batch\n", "\n", "test_dataset = test_dataset.map(speech_file_to_array_fn)\n", "\n", "# Run batched inference on the test set\n", "def evaluate(batch):\n", "    inputs = processor(batch[\"speech\"], sampling_rate=16_000, return_tensors=\"pt\", padding=True)\n", "\n", "    with torch.no_grad():\n", "        logits = test_model(inputs.input_values.to(\"cuda\"), attention_mask=inputs.attention_mask.to(\"cuda\")).logits\n", "\n", "    pred_ids = torch.argmax(logits, dim=-1)\n", "    batch[\"pred_strings\"] = processor.batch_decode(pred_ids)\n", "    return batch\n", "\n", "result = test_dataset.map(evaluate, batched=True, batch_size=8)\n", "\n", "print(\"WER: {:.2f}\".format(100 * wer.compute(predictions=result[\"pred_strings\"], references=result[\"sentence\"])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# print some results\n", "pick = random.randint(0, len(common_voice_test) - 1)\n", "input_dict = processor(common_voice_test[\"input_values\"][pick], return_tensors=\"pt\", padding=True)\n", "logits = test_model(input_dict.input_values.to(\"cuda\")).logits\n", "pred_ids = torch.argmax(logits, dim=-1)[0]\n", "\n", "print(\"Prediction:\")\n", "print(processor.decode(pred_ids).strip())\n", "\n", "print(\"\\nLabel:\")\n", "print(processor.decode(common_voice_test['labels'][pick]))\n" ] } ] }