lighteternal committed
Commit 8aa27c7
1 Parent(s): c9229a1

Added inference script

Files changed:
- .ipynb_checkpoints/ASR_Inference-checkpoint.ipynb +550 -0
- ASR_Inference.ipynb +279 -0
- README.md +101 -3
.ipynb_checkpoints/ASR_Inference-checkpoint.ipynb
ADDED
@@ -0,0 +1,550 @@
(Jupyter autosave checkpoint; same code cells as ASR_Inference.ipynb below, saved with execution counts and cell outputs, namely a torchaudio "sox" backend deprecation warning, dataset cache-loading messages, and the Prediction/Reference printout that also appears as comments in the README snippet.)
ASR_Inference.ipynb
ADDED
@@ -0,0 +1,279 @@
```
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:41.892030Z",
     "start_time": "2021-03-14T09:33:40.729163Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import Wav2Vec2ForCTC\n",
    "from transformers import Wav2Vec2Processor\n",
    "from datasets import load_dataset, load_metric\n",
    "import re\n",
    "import torchaudio\n",
    "import librosa\n",
    "import numpy as np\n",
    "from datasets import load_dataset, load_metric\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:41.909851Z",
     "start_time": "2021-03-14T09:33:41.906327Z"
    }
   },
   "outputs": [],
   "source": [
    "chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n",
    "\n",
    "def remove_special_characters(batch):\n",
    "    batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n",
    "    return batch\n",
    "\n",
    "def speech_file_to_array_fn(batch):\n",
    "    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
    "    batch[\"speech\"] = speech_array[0].numpy()\n",
    "    batch[\"sampling_rate\"] = sampling_rate\n",
    "    batch[\"target_text\"] = batch[\"text\"]\n",
    "    return batch\n",
    "\n",
    "def resample(batch):\n",
    "    batch[\"speech\"] = librosa.resample(np.asarray(batch[\"speech\"]), 48_000, 16_000)\n",
    "    batch[\"sampling_rate\"] = 16_000\n",
    "    return batch\n",
    "\n",
    "def prepare_dataset(batch):\n",
    "    # check that all files have the correct sampling rate\n",
    "    assert (\n",
    "        len(set(batch[\"sampling_rate\"])) == 1\n",
    "    ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
    "\n",
    "    batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
    "\n",
    "    with processor.as_target_processor():\n",
    "        batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:49.053762Z",
     "start_time": "2021-03-14T09:33:41.922683Z"
    }
   },
   "outputs": [],
   "source": [
    "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-9200/\").to(\"cuda\")\n",
    "processor = Wav2Vec2Processor.from_pretrained(\"wav2vec2-large-xlsr-greek/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:52.413558Z",
     "start_time": "2021-03-14T09:33:49.078466Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=\"cv-corpus-6.1-2020-12-11\", split=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:52.444418Z",
     "start_time": "2021-03-14T09:33:52.441338Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:52.473087Z",
     "start_time": "2021-03-14T09:33:52.468014Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:52.510377Z",
     "start_time": "2021-03-14T09:33:52.501677Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:53.321810Z",
     "start_time": "2021-03-14T09:33:52.533233Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = common_voice_test.map(resample, num_proc=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:53.611415Z",
     "start_time": "2021-03-14T09:33:53.342487Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:33:56.243678Z",
     "start_time": "2021-03-14T09:33:53.632436Z"
    }
   },
   "outputs": [],
   "source": [
    "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=\"./cv-corpus-6.1-2020-12-11\", split=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:36:50.076837Z",
     "start_time": "2021-03-14T09:36:24.943947Z"
    }
   },
   "outputs": [],
   "source": [
    "# Change this value to try inference on different CommonVoice extracts\n",
    "example = 123\n",
    "\n",
    "input_dict = processor(common_voice_test[\"input_values\"][example], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n",
    "\n",
    "logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
    "\n",
    "pred_ids = torch.argmax(logits, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-14T09:36:50.137886Z",
     "start_time": "2021-03-14T09:36:50.134218Z"
    }
   },
   "outputs": [],
   "source": [
    "print(\"Prediction:\")\n",
    "print(processor.decode(pred_ids[0]))\n",
    "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί\n",
    "\n",
    "print(\"\\nReference:\")\n",
    "print(common_voice_test_transcription[\"sentence\"][example].lower())\n",
    "# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda110",
   "language": "python",
   "name": "cuda110"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
```
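The notebook evaluates against the CommonVoice test split via `datasets`. For a quick sanity check on a single recording, the same calls can be strung together without the dataset plumbing. Below is a minimal sketch, not part of the committed notebook, assuming the local `wav2vec2-large-xlsr-greek/` directories used above and a hypothetical `sample.wav` input file:

```python
# Minimal single-file inference sketch (illustrative; assumes the local
# model/processor directories from the notebook and a hypothetical
# "sample.wav" recording).
import torch
import torchaudio
import librosa
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model = Wav2Vec2ForCTC.from_pretrained("wav2vec2-large-xlsr-greek/checkpoint-9200/").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("wav2vec2-large-xlsr-greek/")

# Load the recording and resample to the 16 kHz rate the model expects.
speech_array, sampling_rate = torchaudio.load("sample.wav")
speech = speech_array[0].numpy()
if sampling_rate != 16_000:
    speech = librosa.resample(np.asarray(speech), sampling_rate, 16_000)

# Extract features, run the CTC model, and greedily decode the argmax ids.
inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.decode(pred_ids[0]))
```

This mirrors the notebook's per-example flow (processor, logits, argmax, decode), only swapping the CommonVoice loader for a direct `torchaudio.load`.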
README.md
CHANGED
@@ -24,9 +24,107 @@ Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was re
 
 Similar to Wav2Vec2, XLSR-Wav2Vec2 learns powerful speech representations from hundreds of thousands of hours of unlabeled speech in more than 50 languages. Similarly to BERT's masked language modeling, the model learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network.
 
-### How to use
+### How to use for inference:
 
-Instructions to replicate the process are included in the Jupyter notebook.
+Instructions for testing on CommonVoice extracts are provided in ASR_Inference.ipynb. The snippet is also available below:
+
+```
+#!/usr/bin/env python
+# coding: utf-8
+
+# Loading dependencies and defining preprocessing functions
+
+from transformers import Wav2Vec2ForCTC
+from transformers import Wav2Vec2Processor
+from datasets import load_dataset, load_metric
+import re
+import torchaudio
+import librosa
+import numpy as np
+from datasets import load_dataset, load_metric
+import torch
+
+chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�]'
+
+def remove_special_characters(batch):
+    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
+    return batch
+
+def speech_file_to_array_fn(batch):
+    speech_array, sampling_rate = torchaudio.load(batch["path"])
+    batch["speech"] = speech_array[0].numpy()
+    batch["sampling_rate"] = sampling_rate
+    batch["target_text"] = batch["text"]
+    return batch
+
+def resample(batch):
+    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), 48_000, 16_000)
+    batch["sampling_rate"] = 16_000
+    return batch
+
+def prepare_dataset(batch):
+    # check that all files have the correct sampling rate
+    assert (
+        len(set(batch["sampling_rate"])) == 1
+    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
+
+    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
+
+    with processor.as_target_processor():
+        batch["labels"] = processor(batch["target_text"]).input_ids
+    return batch
+
+
+# Loading model and dataset processor
+
+model = Wav2Vec2ForCTC.from_pretrained("wav2vec2-large-xlsr-greek/checkpoint-9200/").to("cuda")
+processor = Wav2Vec2Processor.from_pretrained("wav2vec2-large-xlsr-greek/")
+
+
+# Preparing speech dataset to be suitable for inference
+
+common_voice_test = load_dataset("common_voice", "el", data_dir="cv-corpus-6.1-2020-12-11", split="test")
+
+common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
+
+common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=["sentence"])
+
+common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)
+
+common_voice_test = common_voice_test.map(resample, num_proc=8)
+
+common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)
+
+
+# Loading test dataset
+
+common_voice_test_transcription = load_dataset("common_voice", "el", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
+
+
+# Performing inference on a random sample. Change the "example" value to try inference on different CommonVoice extracts
+
+example = 123
+
+input_dict = processor(common_voice_test["input_values"][example], return_tensors="pt", sampling_rate=16_000, padding=True)
+
+logits = model(input_dict.input_values.to("cuda")).logits
+
+pred_ids = torch.argmax(logits, dim=-1)
+
+print("Prediction:")
+print(processor.decode(pred_ids[0]))
+# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρί
+
+print("\nReference:")
+print(common_voice_test_transcription["sentence"][example].lower())
+# καμιά φορά τα έπαιρνε και έπαιζε όταν η δουλειά ήταν πιο χαλαρή
+
+```
+
+### How to use for training:
+
+Instructions and code to replicate the process are provided in the Fine_Tune_XLSR_Wav2Vec2_on_Greek_ASR_with_🤗_Transformers.ipynb notebook.
 
 
 ## Metrics
@@ -38,6 +136,6 @@ Instructions to replicate the process are included in the Jupyter notebook.
 | WER | 0.45049 |
 
 
-###
+### Acknowledgment
 Based on the tutorial of Patrick von Platen: https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
 Original colab notebook here: https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb#scrollTo=V7YOT2mnUiea
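The WER in the Metrics table can in principle be re-measured with the `load_metric` helper that the snippet above already imports but never uses. The sketch below is illustrative only, assuming the `model`, `processor`, `common_voice_test`, and `common_voice_test_transcription` objects from the snippet; it is not the author's exact evaluation code, and normalization choices (casing, punctuation) will shift the resulting number.

```python
# Illustrative WER check over the CommonVoice test split, reusing the objects
# built in the README snippet above. One plausible way to arrive at a figure
# like the reported WER, not the author's exact procedure.
import torch
from datasets import load_metric

wer = load_metric("wer")
predictions, references = [], []

for i in range(len(common_voice_test)):
    inputs = processor(common_voice_test["input_values"][i],
                       return_tensors="pt", sampling_rate=16_000, padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda")).logits
    pred_ids = torch.argmax(logits, dim=-1)
    predictions.append(processor.decode(pred_ids[0]))
    # References are lower-cased as in the snippet; stripping the same
    # punctuation as remove_special_characters would make the comparison fairer.
    references.append(common_voice_test_transcription["sentence"][i].lower())

print(f"WER: {wer.compute(predictions=predictions, references=references):.5f}")
```

Greedy argmax decoding matches what the snippet does per example; a language-model-backed decoder would typically lower the WER further.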