{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Greek speech recognition with a fine-tuned XLSR Wav2Vec2 checkpoint\n",
    "\n",
    "Loads a fine-tuned `Wav2Vec2ForCTC` checkpoint, preprocesses the Greek Common Voice 6.1 test split and transcribes one clip so the prediction can be compared with its reference sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "import librosa\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchaudio\n",
    "from datasets import load_dataset, load_metric\n",
    "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths are relative to the notebook's working directory.\n",
    "CHECKPOINT_DIR = \"wav2vec2-large-xlsr-greek/checkpoint-9200/\"\n",
    "PROCESSOR_DIR = \"wav2vec2-large-xlsr-greek/\"\n",
    "DATA_DIR = \"cv-corpus-6.1-2020-12-11\"\n",
    "\n",
    "TARGET_SAMPLING_RATE = 16_000  # rate the audio is resampled to and declared to the processor\n",
    "NUM_PROC = 8  # worker processes used by Dataset.map\n",
    "\n",
    "# Fall back to the CPU so the notebook still runs on a machine without a GPU.\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw string: the backslashes below are regex escapes, not Python string escapes.\n",
    "chars_to_ignore_regex = r'[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
    "\n",
    "\n",
    "def remove_special_characters(batch):\n",
    "    \"\"\"Store the cleaned transcript of one example in ``text``.\n",
    "\n",
    "    Characters matched by ``chars_to_ignore_regex`` are dropped from\n",
    "    ``sentence``; the result is lower-cased and a trailing space is appended.\n",
    "    \"\"\"\n",
    "    batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
    "    return batch\n",
    "\n",
    "\n",
    "def speech_file_to_array_fn(batch):\n",
    "    \"\"\"Load the audio file at ``path`` and attach waveform, rate and target text.\n",
    "\n",
    "    Only the first row of the tensor returned by ``torchaudio.load`` is kept\n",
    "    (the first channel, given torchaudio's channel-first layout).\n",
    "    \"\"\"\n",
    "    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
    "    batch[\"speech\"] = speech_array[0].numpy()\n",
    "    batch[\"sampling_rate\"] = sampling_rate\n",
    "    batch[\"target_text\"] = batch[\"text\"]\n",
    "    return batch\n",
    "\n",
    "\n",
    "def resample(batch):\n",
    "    \"\"\"Resample ``speech`` from its recorded rate to ``TARGET_SAMPLING_RATE``.\"\"\"\n",
    "    # Use the rate stored with the clip instead of assuming 48 kHz, and pass the\n",
    "    # rates by keyword (recent librosa releases reject positional rates).\n",
    "    batch[\"speech\"] = librosa.resample(\n",
    "        np.asarray(batch[\"speech\"]),\n",
    "        orig_sr=batch[\"sampling_rate\"],\n",
    "        target_sr=TARGET_SAMPLING_RATE,\n",
    "    )\n",
    "    batch[\"sampling_rate\"] = TARGET_SAMPLING_RATE\n",
    "    return batch\n",
    "\n",
    "\n",
    "def prepare_dataset(batch):\n",
    "    \"\"\"Convert a batch of waveforms and transcripts into ``input_values`` and ``labels``.\n",
    "\n",
    "    Uses the module-level ``processor``, which must be defined before this\n",
    "    function is mapped over a dataset.\n",
    "    \"\"\"\n",
    "    # check that all files have the correct sampling rate\n",
    "    assert (\n",
    "        len(set(batch[\"sampling_rate\"])) == 1\n",
    "    ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
    "\n",
    "    batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
    "\n",
    "    # Inside this context the processor tokenizes text instead of extracting audio features.\n",
    "    with processor.as_target_processor():\n",
    "        batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Wav2Vec2ForCTC.from_pretrained(CHECKPOINT_DIR).to(DEVICE)\n",
    "processor = Wav2Vec2Processor.from_pretrained(PROCESSOR_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and preprocess the Common Voice test split\n",
    "\n",
    "Each stage gets its own name so any cell can be re-run without first re-running the ones above it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw test split; kept untouched so reference sentences can be looked up later.\n",
    "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=DATA_DIR, split=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv_test_text = common_voice_test_transcription.remove_columns(\n",
    "    [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"]\n",
    ").map(remove_special_characters, remove_columns=[\"sentence\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv_test_speech = cv_test_text.map(speech_file_to_array_fn, remove_columns=cv_test_text.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv_test_resampled = cv_test_speech.map(resample, num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "common_voice_test = cv_test_resampled.map(\n",
    "    prepare_dataset,\n",
    "    remove_columns=cv_test_resampled.column_names,\n",
    "    batch_size=8,\n",
    "    num_proc=NUM_PROC,\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transcribe one example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change this value to try inference on different CommonVoice extracts\n",
    "example = 123\n",
    "\n",
    "# Index the row first: dataset[column][i] would load the whole column just to pick one item.\n",
    "input_dict = processor(\n",
    "    common_voice_test[example][\"input_values\"],\n",
    "    return_tensors=\"pt\",\n",
    "    sampling_rate=TARGET_SAMPLING_RATE,\n",
    "    padding=True,\n",
    ")\n",
    "\n",
    "# Inference only: skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    logits = model(input_dict.input_values.to(DEVICE)).logits\n",
    "\n",
    "pred_ids = torch.argmax(logits, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction:\n",
      "καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
      "\n",
      "Reference:\n",
      "καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή\n"
     ]
    }
   ],
   "source": [
    "print(\"Prediction:\")\n",
    "print(processor.decode(pred_ids[0]))\n",
    "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
    "\n",
    "print(\"\\nReference:\")\n",
    "print(common_voice_test_transcription[example][\"sentence\"].lower())\n",
    "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda110",
   "language": "python",
   "name": "cuda110"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}